Skip to content

Commit 13a414d

Browse files
Add Brent’s Method for root finding
1 parent a71618f commit 13a414d

1 file changed

Lines changed: 100 additions & 0 deletions

File tree

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
Brent's Method for root finding.
3+
4+
This function implements Brent's Method, an efficient algorithm for finding the
5+
root of a function. It combines the bisection method, the secant method, and
6+
inverse quadratic interpolation.
7+
8+
Reference:
9+
- https://en.wikipedia.org/wiki/Brent%27s_method
10+
- https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brentq.html
11+
12+
13+
>>> def f(x): return x**3 - x - 2
14+
>>> round(brent_method(f, 1, 2), 6)
15+
1.52138
16+
>>> brent_method(f, 1, 1.5) # No sign change, should raise an error
17+
Traceback (most recent call last):
18+
...
19+
ValueError: f(a) and f(b) must have different signs
20+
"""
21+
22+
from collections.abc import Callable
23+
24+
25+
def brent_method(
26+
f: Callable[[float], float],
27+
a: float,
28+
b: float,
29+
tol: float = 1e-7,
30+
max_iter: int = 100,
31+
) -> float:
32+
"""
33+
Find a root of the function f in the interval [a, b] using Brent's method.
34+
35+
Args:
36+
f: The function for which we are trying to find a root.
37+
a: The start of the interval.
38+
b: The end of the interval.
39+
tol: The allowed error of the result.
40+
max_iter: Maximum number of iterations.
41+
42+
Returns:
43+
A root of f in [a, b], accurate to within tol.
44+
45+
Raises:
46+
ValueError: If f(a) and f(b) do not have opposite signs.
47+
RuntimeError: If the root is not found within max_iter iterations.
48+
"""
49+
fa = f(a)
50+
fb = f(b)
51+
if fa * fb >= 0:
52+
raise ValueError("f(a) and f(b) must have different signs")
53+
54+
if abs(fa) < abs(fb):
55+
a, b = b, a
56+
fa, fb = fb, fa
57+
58+
c, fc = a, fa
59+
d = e = b - a
60+
61+
for _ in range(max_iter):
62+
if fb == 0:
63+
return b
64+
if fc not in (fa, fb):
65+
# Inverse quadratic interpolation
66+
s = (
67+
a * fb * fc / ((fa - fb) * (fa - fc))
68+
+ b * fa * fc / ((fb - fa) * (fb - fc))
69+
+ c * fa * fb / ((fc - fa) * (fc - fb))
70+
)
71+
else:
72+
# Secant Method
73+
s = b - fb * (b - a) / (fb - fa)
74+
75+
conditions = [
76+
not ((3 * a + b) / 4 < s < b) if b > a else not (b < s < (3 * a + b) / 4),
77+
(e is not None and abs(s - b) >= abs(e / 2)),
78+
(d is not None and abs(d) >= abs(e / 2)),
79+
abs(b - a) < tol,
80+
]
81+
if any(conditions):
82+
s = (a + b) / 2 # Bisection method
83+
e = d = b - a
84+
else:
85+
d = e
86+
e = b - s
87+
88+
fs = f(s)
89+
c, fc = b, fb
90+
if fa * fs < 0:
91+
b, fb = s, fs
92+
else:
93+
a, fa = s, fs
94+
if abs(fa) < abs(fb):
95+
a, b = b, a
96+
fa, fb = fb, fa
97+
if abs(b - a) < tol:
98+
return b
99+
100+
raise RuntimeError("Maximum number of iterations reached without convergence")

0 commit comments

Comments
 (0)