|
2 | 2 |
|
3 | 3 |
|
4 | 4 | def brent_method( |
5 | | - f: Callable[[float], float], |
6 | | - a: float, |
7 | | - b: float, |
| 5 | + func: Callable[[float], float], |
| 6 | + left: float, |
| 7 | + right: float, |
8 | 8 | tol: float = 1e-8, |
9 | 9 | max_iter: int = 100, |
10 | 10 | ) -> float: |
11 | 11 | """ |
12 | | - Find the root of function f in the interval [a, b] using Brent's Method. |
| 12 | + Find the root of function func in the interval [left, right] using Brent's Method. |
13 | 13 |
|
14 | 14 | Brent's Method combines bisection, secant, and inverse quadratic interpolation. |
15 | 15 |
|
| 16 | +
|
16 | 17 | Parameters |
17 | 18 | ---------- |
18 | | - f : Callable[[float], float] |
| 19 | + func : Callable[[float], float] |
19 | 20 | Function for which to find the root. |
20 | | - a : float |
| 21 | + left : float |
21 | 22 | Left endpoint of interval. |
22 | | - b : float |
| 23 | + right : float |
23 | 24 | Right endpoint of interval. |
24 | 25 | tol : float |
25 | 26 | Tolerance for convergence (default 1e-8). |
26 | 27 | max_iter : int |
27 | 28 | Maximum number of iterations (default 100). |
28 | 29 |
|
| 30 | +
|
29 | 31 | Returns |
30 | 32 | ------- |
31 | 33 | float |
32 | | - Approximate root of f in [a, b]. |
| 34 | + Approximate root of func in [left, right]. |
33 | 35 |
|
34 | 36 | Raises |
35 | 37 | ------ |
36 | 38 | ValueError |
37 | | - If f(a) and f(b) do not have opposite signs. |
| 39 | + If func(left) and func(right) do not have opposite signs. |
38 | 40 |
|
39 | 41 | Examples |
40 | 42 | -------- |
41 | | - >>> def func(x): return x**3 - x - 2 |
42 | | - >>> round(brent_method(func, 1, 2), 5) |
| 43 | + >>> def f(x): return x**3 - x - 2 |
| 44 | + >>> round(brent_method(f, 1, 2), 5) |
43 | 45 | 1.52138 |
44 | 46 |
|
45 | | - >>> def func2(x): return x**2 + 1 |
46 | | - >>> brent_method(func2, 0, 1) |
| 47 | + >>> def f2(x): return x**2 + 1 |
| 48 | + >>> brent_method(f2, 0, 1) |
47 | 49 | Traceback (most recent call last): |
48 | 50 | ... |
49 | | - ValueError: f(a) and f(b) must have opposite signs |
| 51 | + ValueError: func(left) and func(right) must have opposite signs |
50 | 52 | """ |
51 | | - fa = f(a) |
52 | | - fb = f(b) |
| 53 | + fl = func(left) |
| 54 | + fr = func(right) |
53 | 55 |
|
54 | | - if fa * fb >= 0: |
55 | | - raise ValueError("f(a) and f(b) must have opposite signs") |
| 56 | + if fl * fr >= 0: |
| 57 | + raise ValueError("func(left) and func(right) must have opposite signs") |
56 | 58 |
|
57 | | - if abs(fa) < abs(fb): |
58 | | - a, b = b, a |
59 | | - fa, fb = fb, fa |
| 59 | + if abs(fl) < abs(fr): |
| 60 | + left, right = right, left |
| 61 | + fl, fr = fr, fl |
60 | 62 |
|
61 | | - c = a |
62 | | - fc = fa |
63 | | - d = e = b - a |
| 63 | + c = left |
| 64 | + fc = fl |
| 65 | + d = right - left |
64 | 66 |
|
65 | 67 | for iteration in range(max_iter): |
66 | | - if fb == 0: |
67 | | - return b |
| 68 | + if fr == 0: |
| 69 | + return right |
68 | 70 |
|
69 | | - if fa != fc and fb != fc: |
| 71 | + if fc not in (fl, fr): |
70 | 72 | # Inverse quadratic interpolation |
71 | 73 | s = ( |
72 | | - a * fb * fc / ((fa - fb) * (fa - fc)) |
73 | | - + b * fa * fc / ((fb - fa) * (fb - fc)) |
74 | | - + c * fa * fb / ((fc - fa) * (fc - fb)) |
| 74 | + left * fr * fc / ((fl - fr) * (fl - fc)) |
| 75 | + + right * fl * fc / ((fr - fl) * (fr - fc)) |
| 76 | + + c * fl * fr / ((fc - fl) * (fc - fr)) |
75 | 77 | ) |
76 | 78 | else: |
77 | 79 | # Secant method |
78 | | - s = b - fb * (b - a) / (fb - fa) |
| 80 | + s = right - fr * (right - left) / (fr - fl) |
79 | 81 |
|
80 | 82 | conditions = [ |
81 | | - not ((3 * a + b) / 4 < s < b) if b > a else not (b < s < (3 * a + b) / 4), |
82 | | - iteration > 1 and abs(s - b) >= abs(b - c) / 2, |
83 | | - iteration <= 1 and abs(s - b) >= abs(c - d) / 2, |
84 | | - iteration > 1 and abs(b - c) < tol, |
| 83 | + not ((3 * left + right) / 4 < s < right) if right > left else not (right < s < (3 * left + right) / 4), |
| 84 | + iteration > 1 and abs(s - right) >= abs(right - c) / 2, |
| 85 | + iteration <= 1 and abs(s - right) >= abs(c - d) / 2, |
| 86 | + iteration > 1 and abs(right - c) < tol, |
85 | 87 | iteration <= 1 and abs(c - d) < tol, |
86 | 88 | ] |
87 | 89 |
|
88 | 90 | if any(conditions): |
89 | 91 | # Bisection fallback |
90 | | - s = (a + b) / 2 |
91 | | - d = e = b - a |
| 92 | + s = (left + right) / 2 |
| 93 | + d = right - left |
92 | 94 |
|
93 | | - fs = f(s) |
94 | | - d, c = c, b |
95 | | - fc = fb |
| 95 | + fs = func(s) |
| 96 | + d, c = c, right |
| 97 | + fc = fr |
96 | 98 |
|
97 | | - if fa * fs < 0: |
98 | | - b = s |
99 | | - fb = fs |
| 99 | + if fl * fs < 0: |
| 100 | + right = s |
| 101 | + fr = fs |
100 | 102 | else: |
101 | | - a = s |
102 | | - fa = fs |
| 103 | + left = s |
| 104 | + fl = fs |
103 | 105 |
|
104 | | - if abs(fa) < abs(fb): |
105 | | - a, b = b, a |
106 | | - fa, fb = fb, fa |
| 106 | + if abs(fl) < abs(fr): |
| 107 | + left, right = right, left |
| 108 | + fl, fr = fr, fl |
107 | 109 |
|
108 | | - if abs(b - a) < tol: |
109 | | - return b |
| 110 | + if abs(right - left) < tol: |
| 111 | + return right |
110 | 112 |
|
111 | | - # If we reach max iterations |
112 | | - return b |
| 113 | + return right |
113 | 114 |
|
114 | 115 |
|
115 | 116 | if __name__ == "__main__": |
|
0 commit comments