Skip to content

Commit 4283c33

Browse files
committed
brent7
1 parent 8636195 commit 4283c33

1 file changed

Lines changed: 57 additions & 49 deletions

File tree

maths/numerical_analysis/brents_method.py

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
66
Brent's method combines bisection, secant, and inverse quadratic interpolation to efficiently and robustly find a root of a continuous function. It is guaranteed to converge as long as the root is bracketed.
77
8-
See: https://en.wikipedia.org/wiki/Brent%27s_method
8+
See:
9+
https://en.wikipedia.org/wiki/Brent%27s_method
910
10-
Author: [Aryan Singh (2nd year LNMIIT)]
11+
Author: Aryan Singh (2nd year LNMIIT)
1112
"""
1213

1314
from collections.abc import Callable
1415

1516
def brent_method(
1617
function: Callable[[float], float],
17-
lower: float,
18-
upper: float,
18+
lower_bound: float,
19+
upper_bound: float,
1920
tolerance: float = 1e-14,
2021
max_iterations: int = 100,
2122
) -> float:
@@ -24,16 +25,16 @@ def brent_method(
2425
2526
Args:
2627
function: A continuous function for which the root is sought.
27-
lower: One end of the bracketing interval.
28-
upper: The other end of the bracketing interval.
28+
lower_bound: One end of the bracketing interval.
29+
upper_bound: The other end of the bracketing interval.
2930
tolerance: The tolerance for convergence (default 1e-14).
3031
max_iterations: Maximum number of iterations (default 100).
3132
3233
Returns:
3334
A float value approximating the root.
3435
3536
Raises:
36-
ValueError: If the root is not bracketed in [lower, upper].
37+
ValueError: If the root is not bracketed in [lower_bound, upper_bound].
3738
3839
Examples:
3940
>>> brent_method(lambda x: x**3 - 1, -5, 5)
@@ -47,66 +48,73 @@ def brent_method(
4748
...
4849
ValueError: Root is not bracketed in the interval [4, 1000].
4950
"""
50-
fa = function(lower)
51-
fb = function(upper)
52-
if fa * fb >= 0:
53-
raise ValueError(f"Root is not bracketed in the interval [{lower}, {upper}].")
54-
55-
if abs(fa) < abs(fb):
56-
lower, upper = upper, lower
57-
fa, fb = fb, fa
58-
59-
c = lower
60-
fc = fa
61-
d = upper - lower
62-
mflag = True
51+
function_lower = function(lower_bound)
52+
function_upper = function(upper_bound)
53+
if function_lower * function_upper >= 0:
54+
error_message = (
55+
"Root is not bracketed in the interval "
56+
f"[{lower_bound}, {upper_bound}]."
57+
)
58+
raise ValueError(error_message)
59+
60+
if abs(function_lower) < abs(function_upper):
61+
lower_bound, upper_bound = upper_bound, lower_bound
62+
function_lower, function_upper = function_upper, function_lower
63+
64+
previous_bound = lower_bound
65+
function_previous = function_lower
66+
previous_step = upper_bound - lower_bound
67+
bisect_flag = True
6368

6469
for _ in range(max_iterations):
65-
if fb == 0:
66-
return upper
67-
if fc not in {fa, fb}:
70+
if function_upper == 0:
71+
return upper_bound
72+
if function_previous not in {function_lower, function_upper}:
6873
# Inverse quadratic interpolation
6974
s = (
70-
lower * fb * fc / ((fa - fb) * (fa - fc))
71-
+ upper * fa * fc / ((fb - fa) * (fb - fc))
72-
+ c * fa * fb / ((fc - fa) * (fc - fb))
75+
lower_bound * function_upper * function_previous
76+
/ ((function_lower - function_upper) * (function_lower - function_previous))
77+
+ upper_bound * function_lower * function_previous
78+
/ ((function_upper - function_lower) * (function_upper - function_previous))
79+
+ previous_bound * function_lower * function_upper
80+
/ ((function_previous - function_lower) * (function_previous - function_upper))
7381
)
7482
else:
7583
# Secant method
76-
s = upper - fb * (upper - lower) / (fb - fa)
84+
s = upper_bound - function_upper * (upper_bound - lower_bound) / (function_upper - function_lower)
7785

7886
conditions = [
79-
not ((3 * lower + upper) / 4 < s < upper if upper > lower else upper < s < (3 * lower + upper) / 4),
80-
mflag and abs(s - upper) >= abs(upper - c) / 2,
81-
not mflag and abs(s - upper) >= abs(c - d) / 2,
82-
mflag and abs(upper - c) < tolerance,
83-
not mflag and abs(c - d) < tolerance,
87+
not ((3 * lower_bound + upper_bound) / 4 < s < upper_bound if upper_bound > lower_bound else upper_bound < s < (3 * lower_bound + upper_bound) / 4),
88+
bisect_flag and abs(s - upper_bound) >= abs(upper_bound - previous_bound) / 2,
89+
not bisect_flag and abs(s - upper_bound) >= abs(previous_bound - previous_step) / 2,
90+
bisect_flag and abs(upper_bound - previous_bound) < tolerance,
91+
not bisect_flag and abs(previous_bound - previous_step) < tolerance,
8492
]
8593
if any(conditions):
86-
s = (lower + upper) / 2
87-
mflag = True
94+
s = (lower_bound + upper_bound) / 2
95+
bisect_flag = True
8896
else:
89-
mflag = False
97+
bisect_flag = False
9098

91-
fs = function(s)
92-
d, c = c, upper
93-
fc = fb
99+
function_s = function(s)
100+
previous_step, previous_bound = previous_bound, upper_bound
101+
function_previous = function_upper
94102

95-
if fa * fs < 0:
96-
upper = s
97-
fb = fs
103+
if function_lower * function_s < 0:
104+
upper_bound = s
105+
function_upper = function_s
98106
else:
99-
lower = s
100-
fa = fs
107+
lower_bound = s
108+
function_lower = function_s
101109

102-
if abs(fa) < abs(fb):
103-
lower, upper = upper, lower
104-
fa, fb = fb, fa
110+
if abs(function_lower) < abs(function_upper):
111+
lower_bound, upper_bound = upper_bound, lower_bound
112+
function_lower, function_upper = function_upper, function_lower
105113

106-
if abs(upper - lower) < tolerance or fb == 0:
107-
return upper
114+
if abs(upper_bound - lower_bound) < tolerance or function_upper == 0:
115+
return upper_bound
108116

109-
return upper
117+
return upper_bound
110118

111119

112120
if __name__ == "__main__":

0 commit comments

Comments
 (0)