Skip to content

Commit 8489f90

Browse files
author
debesh
committed
Add Brent's Method for root finding (numerical analysis)
1 parent a71618f commit 8489f90

1 file changed

Lines changed: 115 additions & 0 deletions

File tree

maths/brent_method.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from typing import Callable
2+
3+
def brent_method(
4+
f: Callable[[float], float],
5+
a: float,
6+
b: float,
7+
tol: float = 1e-8,
8+
max_iter: int = 100
9+
) -> float:
10+
"""
11+
Find the root of function f in the interval [a, b] using Brent's Method.
12+
13+
Brent's Method combines bisection, secant, and inverse quadratic interpolation.
14+
15+
Parameters
16+
----------
17+
f : Callable[[float], float]
18+
Function for which to find the root.
19+
a : float
20+
Left endpoint of interval.
21+
b : float
22+
Right endpoint of interval.
23+
tol : float
24+
Tolerance for convergence (default 1e-8).
25+
max_iter : int
26+
Maximum number of iterations (default 100).
27+
28+
Returns
29+
-------
30+
float
31+
Approximate root of f in [a, b].
32+
33+
Raises
34+
------
35+
ValueError
36+
If f(a) and f(b) do not have opposite signs.
37+
38+
Examples
39+
--------
40+
>>> def func(x): return x**3 - x - 2
41+
>>> round(brent_method(func, 1, 2), 5)
42+
1.52138
43+
44+
>>> def func2(x): return x**2 + 1
45+
>>> brent_method(func2, 0, 1)
46+
Traceback (most recent call last):
47+
...
48+
ValueError: f(a) and f(b) must have opposite signs
49+
"""
50+
fa = f(a)
51+
fb = f(b)
52+
53+
if fa * fb >= 0:
54+
raise ValueError("f(a) and f(b) must have opposite signs")
55+
56+
if abs(fa) < abs(fb):
57+
a, b = b, a
58+
fa, fb = fb, fa
59+
60+
c = a
61+
fc = fa
62+
d = e = b - a
63+
64+
for iteration in range(max_iter):
65+
if fb == 0:
66+
return b
67+
68+
if fa != fc and fb != fc:
69+
# Inverse quadratic interpolation
70+
s = (
71+
a * fb * fc / ((fa - fb) * (fa - fc))
72+
+ b * fa * fc / ((fb - fa) * (fb - fc))
73+
+ c * fa * fb / ((fc - fa) * (fc - fb))
74+
)
75+
else:
76+
# Secant method
77+
s = b - fb * (b - a) / (fb - fa)
78+
79+
conditions = [
80+
not ((3 * a + b) / 4 < s < b) if b > a else not (b < s < (3 * a + b) / 4),
81+
iteration > 1 and abs(s - b) >= abs(b - c) / 2,
82+
iteration <= 1 and abs(s - b) >= abs(c - d) / 2,
83+
iteration > 1 and abs(b - c) < tol,
84+
iteration <= 1 and abs(c - d) < tol,
85+
]
86+
87+
if any(conditions):
88+
# Bisection fallback
89+
s = (a + b) / 2
90+
d = e = b - a
91+
92+
fs = f(s)
93+
d, c = c, b
94+
fc = fb
95+
96+
if fa * fs < 0:
97+
b = s
98+
fb = fs
99+
else:
100+
a = s
101+
fa = fs
102+
103+
if abs(fa) < abs(fb):
104+
a, b = b, a
105+
fa, fb = fb, fa
106+
107+
if abs(b - a) < tol:
108+
return b
109+
110+
# If we reach max iterations
111+
return b
112+
if __name__ == "__main__":
113+
import doctest
114+
doctest.testmod(verbose=True)
115+

0 commit comments

Comments
 (0)