From 13a414d48d7b49f9c84e97de654213ad5b2e858b Mon Sep 17 00:00:00 2001 From: harshgupta2125 Date: Thu, 2 Oct 2025 11:48:42 +0530 Subject: [PATCH 1/2] =?UTF-8?q?Add=20Brent=E2=80=99s=20Method=20for=20root?= =?UTF-8?q?=20finding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- maths/numerical_analysis/brent_method.py | 100 +++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 maths/numerical_analysis/brent_method.py diff --git a/maths/numerical_analysis/brent_method.py b/maths/numerical_analysis/brent_method.py new file mode 100644 index 000000000000..1aea2a9a45d1 --- /dev/null +++ b/maths/numerical_analysis/brent_method.py @@ -0,0 +1,100 @@ +""" +Brent's Method for root finding. + +This function implements Brent's Method, an efficient algorithm for finding the +root of a function. It combines the bisection method, the secant method, and +inverse quadratic interpolation. + +Reference: +- https://en.wikipedia.org/wiki/Brent%27s_method +- https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brentq.html + + +>>> def f(x): return x**3 - x - 2 +>>> round(brent_method(f, 1, 2), 6) +1.52138 +>>> brent_method(f, 1, 1.5) # No sign change, should raise an error +Traceback (most recent call last): + ... +ValueError: f(a) and f(b) must have different signs +""" + +from collections.abc import Callable + + +def brent_method( + f: Callable[[float], float], + a: float, + b: float, + tol: float = 1e-7, + max_iter: int = 100, +) -> float: + """ + Find a root of the function f in the interval [a, b] using Brent's method. + + Args: + f: The function for which we are trying to find a root. + a: The start of the interval. + b: The end of the interval. + tol: The allowed error of the result. + max_iter: Maximum number of iterations. + + Returns: + A root of f in [a, b], accurate to within tol. + + Raises: + ValueError: If f(a) and f(b) do not have opposite signs. + RuntimeError: If the root is not found within max_iter iterations. + """ + fa = f(a) + fb = f(b) + if fa * fb >= 0: + raise ValueError("f(a) and f(b) must have different signs") + + if abs(fa) < abs(fb): + a, b = b, a + fa, fb = fb, fa + + c, fc = a, fa + d = e = b - a + + for _ in range(max_iter): + if fb == 0: + return b + if fc not in (fa, fb): + # Inverse quadratic interpolation + s = ( + a * fb * fc / ((fa - fb) * (fa - fc)) + + b * fa * fc / ((fb - fa) * (fb - fc)) + + c * fa * fb / ((fc - fa) * (fc - fb)) + ) + else: + # Secant Method + s = b - fb * (b - a) / (fb - fa) + + conditions = [ + not ((3 * a + b) / 4 < s < b) if b > a else not (b < s < (3 * a + b) / 4), + (e is not None and abs(s - b) >= abs(e / 2)), + (d is not None and abs(d) >= abs(e / 2)), + abs(b - a) < tol, + ] + if any(conditions): + s = (a + b) / 2 # Bisection method + e = d = b - a + else: + d = e + e = b - s + + fs = f(s) + c, fc = b, fb + if fa * fs < 0: + b, fb = s, fs + else: + a, fa = s, fs + if abs(fa) < abs(fb): + a, b = b, a + fa, fb = fb, fa + if abs(b - a) < tol: + return b + + raise RuntimeError("Maximum number of iterations reached without convergence") From ead646c3f1b36e829d1636ce482fbb027de2f2db Mon Sep 17 00:00:00 2001 From: harshgupta2125 Date: Thu, 2 Oct 2025 12:06:25 +0530 Subject: [PATCH 2/2] =?UTF-8?q?Add=20Brent=E2=80=99s=20Method=20for=20root?= =?UTF-8?q?=20finding=20with=20descriptive=20parameters?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- maths/numerical_analysis/brent_method.py | 58 ++++++++++++------------ 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/maths/numerical_analysis/brent_method.py b/maths/numerical_analysis/brent_method.py index 1aea2a9a45d1..c66e72131c57 100644 --- a/maths/numerical_analysis/brent_method.py +++ b/maths/numerical_analysis/brent_method.py @@ -23,9 +23,9 @@ def brent_method( - f: Callable[[float], float], - a: float, - b: float, + func: Callable[[float], float], + left_bound: float, + right_bound: float, tol: float = 1e-7, max_iter: int = 100, ) -> float: @@ -33,9 +33,9 @@ def brent_method( Find a root of the function f in the interval [a, b] using Brent's method. Args: - f: The function for which we are trying to find a root. - a: The start of the interval. - b: The end of the interval. + func: The function for which we are trying to find a root. + left_bound: The start of the interval. + right_bound: The end of the interval. tol: The allowed error of the result. max_iter: Maximum number of iterations. @@ -46,55 +46,57 @@ def brent_method( ValueError: If f(a) and f(b) do not have opposite signs. RuntimeError: If the root is not found within max_iter iterations. """ - fa = f(a) - fb = f(b) + fa = func(left_bound) + fb = func(right_bound) if fa * fb >= 0: raise ValueError("f(a) and f(b) must have different signs") if abs(fa) < abs(fb): - a, b = b, a + left_bound, right_bound = right_bound, left_bound fa, fb = fb, fa - c, fc = a, fa - d = e = b - a + c, fc = left_bound, fa + d = e = right_bound - left_bound for _ in range(max_iter): if fb == 0: - return b + return right_bound if fc not in (fa, fb): # Inverse quadratic interpolation s = ( - a * fb * fc / ((fa - fb) * (fa - fc)) - + b * fa * fc / ((fb - fa) * (fb - fc)) + left_bound * fb * fc / ((fa - fb) * (fa - fc)) + + right_bound * fa * fc / ((fb - fa) * (fb - fc)) + c * fa * fb / ((fc - fa) * (fc - fb)) ) else: # Secant Method - s = b - fb * (b - a) / (fb - fa) + s = right_bound - fb * (right_bound - left_bound) / (fb - fa) conditions = [ - not ((3 * a + b) / 4 < s < b) if b > a else not (b < s < (3 * a + b) / 4), - (e is not None and abs(s - b) >= abs(e / 2)), + not ((3 * left_bound + right_bound) / 4 < s < right_bound) + if right_bound > left_bound + else not (right_bound < s < (3 * left_bound + right_bound) / 4), + (e is not None and abs(s - right_bound) >= abs(e / 2)), (d is not None and abs(d) >= abs(e / 2)), - abs(b - a) < tol, + abs(right_bound - left_bound) < tol, ] if any(conditions): - s = (a + b) / 2 # Bisection method - e = d = b - a + s = (left_bound + right_bound) / 2 # Bisection method + e = d = right_bound - left_bound else: d = e - e = b - s + e = right_bound - s - fs = f(s) - c, fc = b, fb + fs = func(s) + c, fc = right_bound, fb if fa * fs < 0: - b, fb = s, fs + right_bound, fb = s, fs else: - a, fa = s, fs + left_bound, fa = s, fs if abs(fa) < abs(fb): - a, b = b, a + left_bound, right_bound = right_bound, left_bound fa, fb = fb, fa - if abs(b - a) < tol: - return b + if abs(right_bound - left_bound) < tol: + return right_bound raise RuntimeError("Maximum number of iterations reached without convergence")