Skip to content

Commit cdaeabe

Browse files
author
debesh
committed
Fix parameter names and lint issues in Brent's Method
1 parent a14bceb commit cdaeabe

1 file changed

Lines changed: 53 additions & 52 deletions

File tree

maths/brent_method.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,114 +2,115 @@
22

33

44
def brent_method(
5-
f: Callable[[float], float],
6-
a: float,
7-
b: float,
5+
func: Callable[[float], float],
6+
left: float,
7+
right: float,
88
tol: float = 1e-8,
99
max_iter: int = 100,
1010
) -> float:
1111
"""
12-
Find the root of function f in the interval [a, b] using Brent's Method.
12+
Find the root of function func in the interval [left, right] using Brent's Method.
1313
1414
Brent's Method combines bisection, secant, and inverse quadratic interpolation.
1515
16+
1617
Parameters
1718
----------
18-
f : Callable[[float], float]
19+
func : Callable[[float], float]
1920
Function for which to find the root.
20-
a : float
21+
left : float
2122
Left endpoint of interval.
22-
b : float
23+
right : float
2324
Right endpoint of interval.
2425
tol : float
2526
Tolerance for convergence (default 1e-8).
2627
max_iter : int
2728
Maximum number of iterations (default 100).
2829
30+
2931
Returns
3032
-------
3133
float
32-
Approximate root of f in [a, b].
34+
Approximate root of func in [left, right].
3335
3436
Raises
3537
------
3638
ValueError
37-
If f(a) and f(b) do not have opposite signs.
39+
If func(left) and func(right) do not have opposite signs.
3840
3941
Examples
4042
--------
41-
>>> def func(x): return x**3 - x - 2
42-
>>> round(brent_method(func, 1, 2), 5)
43+
>>> def f(x): return x**3 - x - 2
44+
>>> round(brent_method(f, 1, 2), 5)
4345
1.52138
4446
45-
>>> def func2(x): return x**2 + 1
46-
>>> brent_method(func2, 0, 1)
47+
>>> def f2(x): return x**2 + 1
48+
>>> brent_method(f2, 0, 1)
4749
Traceback (most recent call last):
4850
...
49-
ValueError: f(a) and f(b) must have opposite signs
51+
ValueError: func(left) and func(right) must have opposite signs
5052
"""
51-
fa = f(a)
52-
fb = f(b)
53+
fl = func(left)
54+
fr = func(right)
5355

54-
if fa * fb >= 0:
55-
raise ValueError("f(a) and f(b) must have opposite signs")
56+
if fl * fr >= 0:
57+
raise ValueError("func(left) and func(right) must have opposite signs")
5658

57-
if abs(fa) < abs(fb):
58-
a, b = b, a
59-
fa, fb = fb, fa
59+
if abs(fl) < abs(fr):
60+
left, right = right, left
61+
fl, fr = fr, fl
6062

61-
c = a
62-
fc = fa
63-
d = e = b - a
63+
c = left
64+
fc = fl
65+
d = right - left
6466

6567
for iteration in range(max_iter):
66-
if fb == 0:
67-
return b
68+
if fr == 0:
69+
return right
6870

69-
if fa != fc and fb != fc:
71+
if fc not in (fl, fr):
7072
# Inverse quadratic interpolation
7173
s = (
72-
a * fb * fc / ((fa - fb) * (fa - fc))
73-
+ b * fa * fc / ((fb - fa) * (fb - fc))
74-
+ c * fa * fb / ((fc - fa) * (fc - fb))
74+
left * fr * fc / ((fl - fr) * (fl - fc))
75+
+ right * fl * fc / ((fr - fl) * (fr - fc))
76+
+ c * fl * fr / ((fc - fl) * (fc - fr))
7577
)
7678
else:
7779
# Secant method
78-
s = b - fb * (b - a) / (fb - fa)
80+
s = right - fr * (right - left) / (fr - fl)
7981

8082
conditions = [
81-
not ((3 * a + b) / 4 < s < b) if b > a else not (b < s < (3 * a + b) / 4),
82-
iteration > 1 and abs(s - b) >= abs(b - c) / 2,
83-
iteration <= 1 and abs(s - b) >= abs(c - d) / 2,
84-
iteration > 1 and abs(b - c) < tol,
83+
not ((3 * left + right) / 4 < s < right) if right > left else not (right < s < (3 * left + right) / 4),
84+
iteration > 1 and abs(s - right) >= abs(right - c) / 2,
85+
iteration <= 1 and abs(s - right) >= abs(c - d) / 2,
86+
iteration > 1 and abs(right - c) < tol,
8587
iteration <= 1 and abs(c - d) < tol,
8688
]
8789

8890
if any(conditions):
8991
# Bisection fallback
90-
s = (a + b) / 2
91-
d = e = b - a
92+
s = (left + right) / 2
93+
d = right - left
9294

93-
fs = f(s)
94-
d, c = c, b
95-
fc = fb
95+
fs = func(s)
96+
d, c = c, right
97+
fc = fr
9698

97-
if fa * fs < 0:
98-
b = s
99-
fb = fs
99+
if fl * fs < 0:
100+
right = s
101+
fr = fs
100102
else:
101-
a = s
102-
fa = fs
103+
left = s
104+
fl = fs
103105

104-
if abs(fa) < abs(fb):
105-
a, b = b, a
106-
fa, fb = fb, fa
106+
if abs(fl) < abs(fr):
107+
left, right = right, left
108+
fl, fr = fr, fl
107109

108-
if abs(b - a) < tol:
109-
return b
110+
if abs(right - left) < tol:
111+
return right
110112

111-
# If we reach max iterations
112-
return b
113+
return right
113114

114115

115116
if __name__ == "__main__":

0 commit comments

Comments
 (0)