Skip to content

Commit 8636195

Browse files
committed
brent4
2 parents e3f1eae + 0aeb7f7 commit 8636195

1 file changed

Lines changed: 72 additions & 46 deletions

File tree

Lines changed: 72 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,114 @@
1+
"""
2+
Brent's Method for Root Finding
3+
4+
Find a root of a function in a bracketing interval using Brent's method.
5+
6+
Brent's method combines bisection, secant, and inverse quadratic interpolation to efficiently and robustly find a root of a continuous function. It is guaranteed to converge as long as the root is bracketed.
7+
8+
See: https://en.wikipedia.org/wiki/Brent%27s_method
9+
10+
Author: [Aryan Singh (2nd year LNMIIT)]
11+
"""
12+
113
from collections.abc import Callable
214

315
def brent_method(
4-
f: Callable[[float], float],
5-
a: float,
6-
b: float,
7-
tol: float = 1e-14,
8-
max_iter: int = 100
16+
function: Callable[[float], float],
17+
lower: float,
18+
upper: float,
19+
tolerance: float = 1e-14,
20+
max_iterations: int = 100,
921
) -> float:
1022
"""
11-
Root finding using Brent's method.
12-
13-
>>> brent_method(lambda x: x**3 - 1, -5, 5)
14-
1.0
15-
>>> brent_method(lambda x: x**2 - 4*x + 3, 0, 2)
16-
1.0
17-
>>> brent_method(lambda x: x**2 - 4*x + 3, 2, 4)
18-
3.0
19-
>>> brent_method(lambda x: x**2 - 4*x + 3, 4, 1000)
20-
Traceback (most recent call last):
21-
...
22-
ValueError: Root is not bracketed.
23-
"""
23+
Find a root of a function in a bracketing interval using Brent's method.
2424
25-
fa = f(a)
26-
fb = f(b)
25+
Args:
26+
function: A continuous function for which the root is sought.
27+
lower: One end of the bracketing interval.
28+
upper: The other end of the bracketing interval.
29+
tolerance: The tolerance for convergence (default 1e-14).
30+
max_iterations: Maximum number of iterations (default 100).
31+
32+
Returns:
33+
A float value approximating the root.
34+
35+
Raises:
36+
ValueError: If the root is not bracketed in [lower, upper].
37+
38+
Examples:
39+
>>> brent_method(lambda x: x**3 - 1, -5, 5)
40+
1.0
41+
>>> brent_method(lambda x: x**2 - 4*x + 3, 0, 2)
42+
1.0
43+
>>> brent_method(lambda x: x**2 - 4*x + 3, 2, 4)
44+
3.0
45+
>>> brent_method(lambda x: x**2 - 4*x + 3, 4, 1000)
46+
Traceback (most recent call last):
47+
...
48+
ValueError: Root is not bracketed in the interval [4, 1000].
49+
"""
50+
fa = function(lower)
51+
fb = function(upper)
2752
if fa * fb >= 0:
28-
raise ValueError("Root is not bracketed.")
53+
raise ValueError(f"Root is not bracketed in the interval [{lower}, {upper}].")
2954

3055
if abs(fa) < abs(fb):
31-
a, b = b, a
56+
lower, upper = upper, lower
3257
fa, fb = fb, fa
3358

34-
c = a
59+
c = lower
3560
fc = fa
36-
d = b - a # Only d is used, e removed
61+
d = upper - lower
3762
mflag = True
3863

39-
for _ in range(max_iter):
64+
for _ in range(max_iterations):
4065
if fb == 0:
41-
return b
66+
return upper
4267
if fc not in {fa, fb}:
4368
# Inverse quadratic interpolation
4469
s = (
45-
a * fb * fc / ((fa - fb) * (fa - fc)) +
46-
b * fa * fc / ((fb - fa) * (fb - fc)) +
47-
c * fa * fb / ((fc - fa) * (fc - fb))
70+
lower * fb * fc / ((fa - fb) * (fa - fc))
71+
+ upper * fa * fc / ((fb - fa) * (fb - fc))
72+
+ c * fa * fb / ((fc - fa) * (fc - fb))
4873
)
4974
else:
5075
# Secant method
51-
s = b - fb * (b - a) / (fb - fa)
76+
s = upper - fb * (upper - lower) / (fb - fa)
5277

5378
conditions = [
54-
not ((3 * a + b) / 4 < s < b if b > a else b < s < (3 * a + b) / 4),
55-
mflag and abs(s - b) >= abs(b - c) / 2,
56-
not mflag and abs(s - b) >= abs(c - d) / 2,
57-
mflag and abs(b - c) < tol,
58-
not mflag and abs(c - d) < tol,
79+
not ((3 * lower + upper) / 4 < s < upper if upper > lower else upper < s < (3 * lower + upper) / 4),
80+
mflag and abs(s - upper) >= abs(upper - c) / 2,
81+
not mflag and abs(s - upper) >= abs(c - d) / 2,
82+
mflag and abs(upper - c) < tolerance,
83+
not mflag and abs(c - d) < tolerance,
5984
]
6085
if any(conditions):
61-
s = (a + b) / 2
86+
s = (lower + upper) / 2
6287
mflag = True
6388
else:
6489
mflag = False
6590

66-
fs = f(s)
67-
d, c = c, b
91+
fs = function(s)
92+
d, c = c, upper
6893
fc = fb
6994

7095
if fa * fs < 0:
71-
b = s
96+
upper = s
7297
fb = fs
7398
else:
74-
a = s
99+
lower = s
75100
fa = fs
76101

77102
if abs(fa) < abs(fb):
78-
a, b = b, a
103+
lower, upper = upper, lower
79104
fa, fb = fb, fa
80105

81-
if abs(b - a) < tol or fb == 0:
82-
return b
106+
if abs(upper - lower) < tolerance or fb == 0:
107+
return upper
108+
109+
return upper
83110

84-
return b
85111

86112
if __name__ == "__main__":
87-
from doctest import testmod
88-
testmod()
113+
import doctest
114+
doctest.testmod()

0 commit comments

Comments
 (0)