Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions maths/numerical_analysis/brents_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Brent's Method for Root Finding

Find a root of a function in a bracketing interval using Brent's method.

Brent's method combines bisection, secant, and inverse quadratic interpolation to efficiently and robustly find a root of a continuous function. It is guaranteed to converge as long as the root is bracketed.

Check failure on line 6 in maths/numerical_analysis/brents_method.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

maths/numerical_analysis/brents_method.py:6:89: E501 Line too long (207 > 88)

See:
https://en.wikipedia.org/wiki/Brent%27s_method

Author: Aryan Singh (2nd year LNMIIT)
"""

from collections.abc import Callable


def brent_method(
function: Callable[[float], float],
lower_bound: float,
upper_bound: float,
tolerance: float = 1e-14,
max_iterations: int = 100,
) -> float:
"""
Find a root of a function in a bracketing interval using Brent's method.

Args:
function: A continuous function for which the root is sought.
lower_bound: One end of the bracketing interval.
upper_bound: The other end of the bracketing interval.
tolerance: The tolerance for convergence (default 1e-14).
max_iterations: Maximum number of iterations (default 100).

Returns:
A float value approximating the root.

Raises:
ValueError: If the root is not bracketed in [lower_bound, upper_bound].

Examples:
>>> brent_method(lambda x: x**3 - 1, -5, 5)
1.0
>>> brent_method(lambda x: x**2 - 4*x + 3, 0, 2)
1.0
>>> brent_method(lambda x: x**2 - 4*x + 3, 2, 4)
3.0
>>> brent_method(lambda x: x**2 - 4*x + 3, 4, 1000)
Traceback (most recent call last):
...
ValueError: Root is not bracketed in the interval [4, 1000].
"""
function_lower = function(lower_bound)
function_upper = function(upper_bound)
if function_lower * function_upper >= 0:
error_message = (
f"Root is not bracketed in the interval [{lower_bound}, {upper_bound}]."
)
raise ValueError(error_message)

if abs(function_lower) < abs(function_upper):
lower_bound, upper_bound = upper_bound, lower_bound
function_lower, function_upper = function_upper, function_lower

previous_bound = lower_bound
function_previous = function_lower
previous_step = upper_bound - lower_bound
bisect_flag = True

for _ in range(max_iterations):
if function_upper == 0:
return upper_bound
if function_previous not in {function_lower, function_upper}:
# Inverse quadratic interpolation
s = (
lower_bound
* function_upper
* function_previous
/ (
(function_lower - function_upper)
* (function_lower - function_previous)
)
+ upper_bound
* function_lower
* function_previous
/ (
(function_upper - function_lower)
* (function_upper - function_previous)
)
+ previous_bound
* function_lower
* function_upper
/ (
(function_previous - function_lower)
* (function_previous - function_upper)
)
)
else:
# Secant method
s = upper_bound - function_upper * (upper_bound - lower_bound) / (
function_upper - function_lower
)

conditions = [
not (
(3 * lower_bound + upper_bound) / 4 < s < upper_bound
if upper_bound > lower_bound
else upper_bound < s < (3 * lower_bound + upper_bound) / 4
),
bisect_flag
and abs(s - upper_bound) >= abs(upper_bound - previous_bound) / 2,
not bisect_flag
and abs(s - upper_bound) >= abs(previous_bound - previous_step) / 2,
bisect_flag and abs(upper_bound - previous_bound) < tolerance,
not bisect_flag and abs(previous_bound - previous_step) < tolerance,
]
if any(conditions):
s = (lower_bound + upper_bound) / 2
bisect_flag = True
else:
bisect_flag = False

function_s = function(s)
previous_step, previous_bound = previous_bound, upper_bound
function_previous = function_upper

if function_lower * function_s < 0:
upper_bound = s
function_upper = function_s
else:
lower_bound = s
function_lower = function_s

if abs(function_lower) < abs(function_upper):
lower_bound, upper_bound = upper_bound, lower_bound
function_lower, function_upper = function_upper, function_lower

if abs(upper_bound - lower_bound) < tolerance or function_upper == 0:
return upper_bound

return upper_bound


if __name__ == "__main__":
import doctest

doctest.testmod()
Loading