55
66Brent's method combines bisection, secant, and inverse quadratic interpolation to efficiently and robustly find a root of a continuous function. It is guaranteed to converge as long as the root is bracketed.
77
8- See: https://en.wikipedia.org/wiki/Brent%27s_method
8+ See:
9+ https://en.wikipedia.org/wiki/Brent%27s_method
910
10- Author: [ Aryan Singh (2nd year LNMIIT)]
11+ Author: Aryan Singh (2nd year LNMIIT)
1112"""
1213
1314from collections .abc import Callable
1415
1516def brent_method (
1617 function : Callable [[float ], float ],
17- lower : float ,
18- upper : float ,
18+ lower_bound : float ,
19+ upper_bound : float ,
1920 tolerance : float = 1e-14 ,
2021 max_iterations : int = 100 ,
2122) -> float :
@@ -24,16 +25,16 @@ def brent_method(
2425
2526 Args:
2627 function: A continuous function for which the root is sought.
27- lower : One end of the bracketing interval.
28- upper : The other end of the bracketing interval.
28+ lower_bound : One end of the bracketing interval.
29+ upper_bound : The other end of the bracketing interval.
2930 tolerance: The tolerance for convergence (default 1e-14).
3031 max_iterations: Maximum number of iterations (default 100).
3132
3233 Returns:
3334 A float value approximating the root.
3435
3536 Raises:
36- ValueError: If the root is not bracketed in [lower, upper ].
37+ ValueError: If the root is not bracketed in [lower_bound, upper_bound ].
3738
3839 Examples:
3940 >>> brent_method(lambda x: x**3 - 1, -5, 5)
@@ -47,66 +48,73 @@ def brent_method(
4748 ...
4849 ValueError: Root is not bracketed in the interval [4, 1000].
4950 """
50- fa = function (lower )
51- fb = function (upper )
52- if fa * fb >= 0 :
53- raise ValueError (f"Root is not bracketed in the interval [{ lower } , { upper } ]." )
54-
55- if abs (fa ) < abs (fb ):
56- lower , upper = upper , lower
57- fa , fb = fb , fa
58-
59- c = lower
60- fc = fa
61- d = upper - lower
62- mflag = True
51+ function_lower = function (lower_bound )
52+ function_upper = function (upper_bound )
53+ if function_lower * function_upper >= 0 :
54+ error_message = (
55+ "Root is not bracketed in the interval "
56+ f"[{ lower_bound } , { upper_bound } ]."
57+ )
58+ raise ValueError (error_message )
59+
60+ if abs (function_lower ) < abs (function_upper ):
61+ lower_bound , upper_bound = upper_bound , lower_bound
62+ function_lower , function_upper = function_upper , function_lower
63+
64+ previous_bound = lower_bound
65+ function_previous = function_lower
66+ previous_step = upper_bound - lower_bound
67+ bisect_flag = True
6368
6469 for _ in range (max_iterations ):
65- if fb == 0 :
66- return upper
67- if fc not in {fa , fb }:
70+ if function_upper == 0 :
71+ return upper_bound
72+ if function_previous not in {function_lower , function_upper }:
6873 # Inverse quadratic interpolation
6974 s = (
70- lower * fb * fc / ((fa - fb ) * (fa - fc ))
71- + upper * fa * fc / ((fb - fa ) * (fb - fc ))
72- + c * fa * fb / ((fc - fa ) * (fc - fb ))
75+ lower_bound * function_upper * function_previous
76+ / ((function_lower - function_upper ) * (function_lower - function_previous ))
77+ + upper_bound * function_lower * function_previous
78+ / ((function_upper - function_lower ) * (function_upper - function_previous ))
79+ + previous_bound * function_lower * function_upper
80+ / ((function_previous - function_lower ) * (function_previous - function_upper ))
7381 )
7482 else :
7583 # Secant method
76- s = upper - fb * (upper - lower ) / (fb - fa )
84+ s = upper_bound - function_upper * (upper_bound - lower_bound ) / (function_upper - function_lower )
7785
7886 conditions = [
79- not ((3 * lower + upper ) / 4 < s < upper if upper > lower else upper < s < (3 * lower + upper ) / 4 ),
80- mflag and abs (s - upper ) >= abs (upper - c ) / 2 ,
81- not mflag and abs (s - upper ) >= abs (c - d ) / 2 ,
82- mflag and abs (upper - c ) < tolerance ,
83- not mflag and abs (c - d ) < tolerance ,
87+ not ((3 * lower_bound + upper_bound ) / 4 < s < upper_bound if upper_bound > lower_bound else upper_bound < s < (3 * lower_bound + upper_bound ) / 4 ),
88+ bisect_flag and abs (s - upper_bound ) >= abs (upper_bound - previous_bound ) / 2 ,
89+ not bisect_flag and abs (s - upper_bound ) >= abs (previous_bound - previous_step ) / 2 ,
90+ bisect_flag and abs (upper_bound - previous_bound ) < tolerance ,
91+ not bisect_flag and abs (previous_bound - previous_step ) < tolerance ,
8492 ]
8593 if any (conditions ):
86- s = (lower + upper ) / 2
87- mflag = True
94+ s = (lower_bound + upper_bound ) / 2
95+ bisect_flag = True
8896 else :
89- mflag = False
97+ bisect_flag = False
9098
91- fs = function (s )
92- d , c = c , upper
93- fc = fb
99+ function_s = function (s )
100+ previous_step , previous_bound = previous_bound , upper_bound
101+ function_previous = function_upper
94102
95- if fa * fs < 0 :
96- upper = s
97- fb = fs
103+ if function_lower * function_s < 0 :
104+ upper_bound = s
105+ function_upper = function_s
98106 else :
99- lower = s
100- fa = fs
107+ lower_bound = s
108+ function_lower = function_s
101109
102- if abs (fa ) < abs (fb ):
103- lower , upper = upper , lower
104- fa , fb = fb , fa
110+ if abs (function_lower ) < abs (function_upper ):
111+ lower_bound , upper_bound = upper_bound , lower_bound
112+ function_lower , function_upper = function_upper , function_lower
105113
106- if abs (upper - lower ) < tolerance or fb == 0 :
107- return upper
114+ if abs (upper_bound - lower_bound ) < tolerance or function_upper == 0 :
115+ return upper_bound
108116
109- return upper
117+ return upper_bound
110118
111119
112120if __name__ == "__main__" :
0 commit comments