1+ """
2+ Brent's Method for Root Finding
3+
4+ Find a root of a function in a bracketing interval using Brent's method.
5+
6+ Brent's method combines bisection, secant, and inverse quadratic interpolation to efficiently and robustly find a root of a continuous function. It is guaranteed to converge as long as the root is bracketed.
7+
8+ See: https://en.wikipedia.org/wiki/Brent%27s_method
9+
10+ Author: [Aryan Singh (2nd year LNMIIT)]
11+ """
12+
113from collections .abc import Callable
214
315def brent_method (
4- f : Callable [[float ], float ],
5- a : float ,
6- b : float ,
7- tol : float = 1e-14 ,
8- max_iter : int = 100
16+ function : Callable [[float ], float ],
17+ lower : float ,
18+ upper : float ,
19+ tolerance : float = 1e-14 ,
20+ max_iterations : int = 100 ,
921) -> float :
1022 """
11- Root finding using Brent's method.
12-
13- >>> brent_method(lambda x: x**3 - 1, -5, 5)
14- 1.0
15- >>> brent_method(lambda x: x**2 - 4*x + 3, 0, 2)
16- 1.0
17- >>> brent_method(lambda x: x**2 - 4*x + 3, 2, 4)
18- 3.0
19- >>> brent_method(lambda x: x**2 - 4*x + 3, 4, 1000)
20- Traceback (most recent call last):
21- ...
22- ValueError: Root is not bracketed.
23- """
23+ Find a root of a function in a bracketing interval using Brent's method.
2424
25- fa = f (a )
26- fb = f (b )
25+ Args:
26+ function: A continuous function for which the root is sought.
27+ lower: One end of the bracketing interval.
28+ upper: The other end of the bracketing interval.
29+ tolerance: The tolerance for convergence (default 1e-14).
30+ max_iterations: Maximum number of iterations (default 100).
31+
32+ Returns:
33+ A float value approximating the root.
34+
35+ Raises:
36+ ValueError: If the root is not bracketed in [lower, upper].
37+
38+ Examples:
39+ >>> brent_method(lambda x: x**3 - 1, -5, 5)
40+ 1.0
41+ >>> brent_method(lambda x: x**2 - 4*x + 3, 0, 2)
42+ 1.0
43+ >>> brent_method(lambda x: x**2 - 4*x + 3, 2, 4)
44+ 3.0
45+ >>> brent_method(lambda x: x**2 - 4*x + 3, 4, 1000)
46+ Traceback (most recent call last):
47+ ...
48+ ValueError: Root is not bracketed in the interval [4, 1000].
49+ """
50+ fa = function (lower )
51+ fb = function (upper )
2752 if fa * fb >= 0 :
28- raise ValueError ("Root is not bracketed." )
53+ raise ValueError (f "Root is not bracketed in the interval [ { lower } , { upper } ] ." )
2954
3055 if abs (fa ) < abs (fb ):
31- a , b = b , a
56+ lower , upper = upper , lower
3257 fa , fb = fb , fa
3358
34- c = a
59+ c = lower
3560 fc = fa
36- d = b - a # Only d is used, e removed
61+ d = upper - lower
3762 mflag = True
3863
39- for _ in range (max_iter ):
64+ for _ in range (max_iterations ):
4065 if fb == 0 :
41- return b
66+ return upper
4267 if fc not in {fa , fb }:
4368 # Inverse quadratic interpolation
4469 s = (
45- a * fb * fc / ((fa - fb ) * (fa - fc )) +
46- b * fa * fc / ((fb - fa ) * (fb - fc )) +
47- c * fa * fb / ((fc - fa ) * (fc - fb ))
70+ lower * fb * fc / ((fa - fb ) * (fa - fc ))
71+ + upper * fa * fc / ((fb - fa ) * (fb - fc ))
72+ + c * fa * fb / ((fc - fa ) * (fc - fb ))
4873 )
4974 else :
5075 # Secant method
51- s = b - fb * (b - a ) / (fb - fa )
76+ s = upper - fb * (upper - lower ) / (fb - fa )
5277
5378 conditions = [
54- not ((3 * a + b ) / 4 < s < b if b > a else b < s < (3 * a + b ) / 4 ),
55- mflag and abs (s - b ) >= abs (b - c ) / 2 ,
56- not mflag and abs (s - b ) >= abs (c - d ) / 2 ,
57- mflag and abs (b - c ) < tol ,
58- not mflag and abs (c - d ) < tol ,
79+ not ((3 * lower + upper ) / 4 < s < upper if upper > lower else upper < s < (3 * lower + upper ) / 4 ),
80+ mflag and abs (s - upper ) >= abs (upper - c ) / 2 ,
81+ not mflag and abs (s - upper ) >= abs (c - d ) / 2 ,
82+ mflag and abs (upper - c ) < tolerance ,
83+ not mflag and abs (c - d ) < tolerance ,
5984 ]
6085 if any (conditions ):
61- s = (a + b ) / 2
86+ s = (lower + upper ) / 2
6287 mflag = True
6388 else :
6489 mflag = False
6590
66- fs = f (s )
67- d , c = c , b
91+ fs = function (s )
92+ d , c = c , upper
6893 fc = fb
6994
7095 if fa * fs < 0 :
71- b = s
96+ upper = s
7297 fb = fs
7398 else :
74- a = s
99+ lower = s
75100 fa = fs
76101
77102 if abs (fa ) < abs (fb ):
78- a , b = b , a
103+ lower , upper = upper , lower
79104 fa , fb = fb , fa
80105
81- if abs (b - a ) < tol or fb == 0 :
82- return b
106+ if abs (upper - lower ) < tolerance or fb == 0 :
107+ return upper
108+
109+ return upper
83110
84- return b
85111
86112if __name__ == "__main__" :
87- from doctest import testmod
88- testmod ()
113+ import doctest
114+ doctest . testmod ()
0 commit comments