Skip to content

Commit 80ae77a

Browse files
committed
add segment_tree_node.py
1 parent d9d56b1 commit 80ae77a

1 file changed

Lines changed: 186 additions & 0 deletions

File tree

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
2+
class Node:
3+
def __init__(self, start: int, end: int) -> None:
4+
# Initializes a segment tree node with start and end indices
5+
self.start = start
6+
self.end = end
7+
self.value: int = 0
8+
self.left: Node = self
9+
self.right: Node = self
10+
11+
12+
class SegmentTree:
13+
def __init__(self, nums: list[int], mode: str = "max") -> None:
14+
"""
15+
Initializes the Segment Tree.
16+
:param nums: List of integers to build the tree from.
17+
:param mode: Operation mode of the tree ('max' or 'sum').
18+
"""
19+
self.size = len(nums)
20+
self.mode = mode
21+
if mode not in {"max", "sum"}:
22+
self.mode = "max" # Default to max if invalid mode is given
23+
24+
# Build the tree from the input list
25+
self.root: Node = self.build(0, self.size - 1, nums)
26+
27+
def build(self, start: int, end: int, nums: list[int]) -> Node:
28+
"""
29+
Recursively builds the segment tree.
30+
:param start: Start index of the segment.
31+
:param end: End index of the segment.
32+
:param nums: Original input array.
33+
:return: Root node of the constructed subtree.
34+
35+
>>> tree = SegmentTree([1, 2, 3, 4, 5], mode="max")
36+
>>> tree.root.value
37+
5
38+
"""
39+
if start > end:
40+
return Node(0, 0)
41+
42+
if start == end:
43+
# Leaf node
44+
n = Node(start, end)
45+
n.value = nums[start]
46+
return n
47+
48+
mid = (start + end) // 2
49+
root = Node(start, end)
50+
root.left = self.build(start, mid, nums)
51+
root.right = self.build(mid + 1, end, nums)
52+
53+
# Set the value according to the mode
54+
if self.mode == "max":
55+
root.value = max(root.left.value, root.right.value)
56+
else:
57+
root.value = root.left.value + root.right.value
58+
59+
return root
60+
61+
def max_in_range(self, start_index: int, end_index: int) -> int:
62+
"""
63+
Queries the maximum value in a given range.
64+
Only works in 'max' mode.
65+
66+
>>> tree = SegmentTree([1, 2, 3, 4, 5], mode="max")
67+
>>> tree.max_in_range(1, 3)
68+
4
69+
"""
70+
if self.mode == "sum":
71+
raise Exception("Current Segment Tree doesn't support finding max")
72+
73+
if start_index > end_index or start_index < 0 or end_index >= self.size:
74+
raise Exception("Invalid index")
75+
76+
if self.root is None:
77+
raise ValueError("Tree not initialized")
78+
79+
return self.query(self.root, start_index, end_index, 0, self.size - 1)
80+
81+
def sum_in_range(self, start_index: int, end_index: int) -> int:
82+
"""
83+
Queries the sum of values in a given range.
84+
Only works in 'sum' mode.
85+
86+
>>> tree = SegmentTree([1, 2, 3, 4, 5], mode="sum")
87+
>>> tree.sum_in_range(1, 3)
88+
9
89+
"""
90+
if self.mode == "max":
91+
raise Exception("Current Segment Tree doesn't support summing")
92+
93+
if start_index > end_index or start_index < 0 or end_index >= self.size:
94+
raise Exception("Invalid index")
95+
96+
if self.root is None:
97+
raise ValueError("Tree not initialized")
98+
99+
return self.query(self.root, start_index, end_index, 0, self.size - 1)
100+
101+
def query(
102+
self, node: Node, start_index: int, end_index: int, start: int, end: int
103+
) -> int:
104+
"""
105+
Recursively queries a value (max or sum) in a given range.
106+
:param node: Current node in the tree.
107+
:param start_index: Query start index.
108+
:param end_index: Query end index.
109+
:param start: Node's segment start.
110+
:param end: Node's segment end.
111+
:return: Result of query in the range.
112+
113+
>>> tree = SegmentTree([1, 2, 3, 4, 5], mode="max")
114+
>>> tree.query(tree.root, 1, 3, 0, 4)
115+
4
116+
"""
117+
# Complete overlap
118+
if start_index <= start and end <= end_index:
119+
return node.value
120+
121+
mid = (start + end) // 2
122+
123+
if end_index <= mid:
124+
# Entire range is in the left child
125+
return self.query(node.left, start_index, end_index, start, mid)
126+
elif start_index > mid:
127+
# Entire range is in the right child
128+
return self.query(node.right, start_index, end_index, mid + 1, end)
129+
elif self.mode == "max":
130+
return max(
131+
self.query(node.left, start_index, end_index, start, mid),
132+
self.query(node.right, start_index, end_index, mid + 1, end),
133+
)
134+
else:
135+
return self.query(
136+
node.left, start_index, end_index, start, mid
137+
) + self.query(node.right, start_index, end_index, mid + 1, end)
138+
139+
def update(self, index: int, new_value: int) -> None:
140+
"""
141+
Updates a value at a specific index in the segment tree.
142+
:param index: Index to update.
143+
:param new_value: New value to set.
144+
145+
>>> tree = SegmentTree([1, 2, 3, 4, 5], mode="max")
146+
>>> tree.update(2, 6)
147+
>>> tree.max_in_range(1, 3)
148+
6
149+
"""
150+
if index < 0 or index >= self.size:
151+
raise Exception("Invalid index")
152+
153+
self.modify(self.root, index, new_value, 0, self.size - 1)
154+
155+
def modify(
156+
self, node: Node, index: int, new_value: int, start: int, end: int
157+
) -> None:
158+
"""
159+
Recursively updates the tree to reflect a change at a specific index.
160+
:param node: Current node being processed.
161+
:param index: Index to update.
162+
:param new_value: New value to assign.
163+
:param start: Start index of node's segment.
164+
:param end: End index of node's segment.
165+
166+
>>> tree = SegmentTree([1, 2, 3, 4, 5], mode="max")
167+
>>> tree.modify(tree.root, 2, 6, 0, 4)
168+
>>> tree.max_in_range(0, 4)
169+
6
170+
"""
171+
if start == end:
172+
node.value = new_value
173+
return
174+
175+
mid = (start + end) // 2
176+
177+
if index <= mid:
178+
self.modify(node.left, index, new_value, start, mid)
179+
else:
180+
self.modify(node.right, index, new_value, mid + 1, end)
181+
182+
# Recompute current node's value after update
183+
if self.mode == "max":
184+
node.value = max(node.left.value, node.right.value)
185+
else:
186+
node.value = node.left.value + node.right.value

0 commit comments

Comments
 (0)