Skip to content

Commit a1a379c

Browse files
authored
Update minimum_spanning_tree_prims2.py
1 parent 1273319 commit a1a379c

1 file changed

Lines changed: 63 additions & 82 deletions

File tree

graphs/minimum_spanning_tree_prims2.py

Lines changed: 63 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,10 @@
1010
from __future__ import annotations
1111

1212
from sys import maxsize
13-
from typing import Generic, TypeVar
14-
15-
T = TypeVar("T")
16-
1713

1814
def get_parent_position(position: int) -> int:
1915
"""
20-
Heap helper function to get the position of the parent of the current node
16+
heap helper function get the position of the parent of the current node
2117
2218
>>> get_parent_position(1)
2319
0
@@ -29,7 +25,7 @@ def get_parent_position(position: int) -> int:
2925

3026
def get_child_left_position(position: int) -> int:
3127
"""
32-
Heap helper function to get the position of the left child of the current node
28+
heap helper function get the position of the left child of the current node
3329
3430
>>> get_child_left_position(0)
3531
1
@@ -39,26 +35,29 @@ def get_child_left_position(position: int) -> int:
3935

4036
def get_child_right_position(position: int) -> int:
4137
"""
42-
Heap helper function to get the position of the right child of the current node
38+
heap helper function get the position of the right child of the current node
4339
4440
>>> get_child_right_position(0)
4541
2
4642
"""
4743
return (2 * position) + 2
4844

4945

50-
class MinPriorityQueue(Generic[T]):
46+
class MinPriorityQueue[T]:
5147
"""
5248
Minimum Priority Queue Class
5349
5450
Functions:
55-
is_empty: Check if the priority queue is empty
56-
push: Add an element with given priority to the queue
57-
extract_min: Remove and return the element with lowest weight (highest priority)
58-
update_key: Update the weight of the given key
59-
_bubble_up: Place a node at proper position (upward movement)
60-
_bubble_down: Place a node at proper position (downward movement)
61-
_swap_nodes: Swap nodes at given positions
51+
is_empty: function to check if the priority queue is empty
52+
push: function to add an element with given priority to the queue
53+
extract_min: function to remove and return the element with lowest weight (highest
54+
priority)
55+
update_key: function to update the weight of the given key
56+
_bubble_up: helper function to place a node at the proper position (upward
57+
movement)
58+
_bubble_down: helper function to place a node at the proper position (downward
59+
movement)
60+
_swap_nodes: helper function to swap the nodes at the given positions
6261
6362
>>> queue = MinPriorityQueue()
6463
@@ -92,18 +91,18 @@ def __repr__(self) -> str:
9291
return str(self.heap)
9392

9493
def is_empty(self) -> bool:
95-
"""Check if the priority queue is empty"""
94+
# Check if the priority queue is empty
9695
return self.elements == 0
9796

9897
def push(self, elem: T, weight: int) -> None:
99-
"""Add an element with given priority to the queue"""
98+
# Add an element with given priority to the queue
10099
self.heap.append((elem, weight))
101100
self.position_map[elem] = self.elements
102101
self.elements += 1
103102
self._bubble_up(elem)
104103

105104
def extract_min(self) -> T:
106-
"""Remove and return the element with lowest weight (highest priority)"""
105+
# Remove and return the element with lowest weight (highest priority)
107106
if self.elements > 1:
108107
self._swap_nodes(0, self.elements - 1)
109108
elem, _ = self.heap.pop()
@@ -115,7 +114,7 @@ def extract_min(self) -> T:
115114
return elem
116115

117116
def update_key(self, elem: T, weight: int) -> None:
118-
"""Update the weight of the given key"""
117+
# Update the weight of the given key
119118
position = self.position_map[elem]
120119
self.heap[position] = (elem, weight)
121120
if position > 0:
@@ -129,50 +128,48 @@ def update_key(self, elem: T, weight: int) -> None:
129128
self._bubble_down(elem)
130129

131130
def _bubble_up(self, elem: T) -> None:
132-
"""Place node at proper position (upward movement) - internal use only"""
131+
# Place a node at the proper position (upward movement) [to be used internally
132+
# only]
133133
curr_pos = self.position_map[elem]
134134
if curr_pos == 0:
135-
return
135+
return None
136136
parent_position = get_parent_position(curr_pos)
137137
_, weight = self.heap[curr_pos]
138138
_, parent_weight = self.heap[parent_position]
139139
if parent_weight > weight:
140140
self._swap_nodes(parent_position, curr_pos)
141-
self._bubble_up(elem)
141+
return self._bubble_up(elem)
142+
return None
142143

143144
def _bubble_down(self, elem: T) -> None:
144-
"""Place node at proper position (downward movement) - internal use only"""
145+
# Place a node at the proper position (downward movement) [to be used
146+
# internally only]
145147
curr_pos = self.position_map[elem]
146148
_, weight = self.heap[curr_pos]
147149
child_left_position = get_child_left_position(curr_pos)
148150
child_right_position = get_child_right_position(curr_pos)
149-
150-
# Check if both children exist
151151
if child_left_position < self.elements and child_right_position < self.elements:
152152
_, child_left_weight = self.heap[child_left_position]
153153
_, child_right_weight = self.heap[child_right_position]
154154
if child_right_weight < child_left_weight and child_right_weight < weight:
155155
self._swap_nodes(child_right_position, curr_pos)
156-
self._bubble_down(elem)
157-
return
158-
159-
# Check left child
156+
return self._bubble_down(elem)
160157
if child_left_position < self.elements:
161158
_, child_left_weight = self.heap[child_left_position]
162159
if child_left_weight < weight:
163160
self._swap_nodes(child_left_position, curr_pos)
164-
self._bubble_down(elem)
165-
return
166-
167-
# Check right child
161+
return self._bubble_down(elem)
162+
else:
163+
return None
168164
if child_right_position < self.elements:
169165
_, child_right_weight = self.heap[child_right_position]
170166
if child_right_weight < weight:
171167
self._swap_nodes(child_right_position, curr_pos)
172-
self._bubble_down(elem)
168+
return self._bubble_down(elem)
169+
return None
173170

174171
def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
175-
"""Swap nodes at given positions"""
172+
# Swap the nodes at the given positions
176173
node1_elem = self.heap[node1_pos][0]
177174
node2_elem = self.heap[node2_pos][0]
178175
self.heap[node1_pos], self.heap[node2_pos] = (
@@ -183,13 +180,13 @@ def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
183180
self.position_map[node2_elem] = node1_pos
184181

185182

186-
class GraphUndirectedWeighted(Generic[T]):
183+
class GraphUndirectedWeighted[T]:
187184
"""
188185
Graph Undirected Weighted Class
189186
190187
Functions:
191-
add_node: Add a node to the graph
192-
add_edge: Add an edge between two nodes with given weight
188+
add_node: function to add a node in the graph
189+
add_edge: function to add an edge between 2 nodes in the graph
193190
"""
194191

195192
def __init__(self) -> None:
@@ -203,26 +200,25 @@ def __len__(self) -> int:
203200
return self.nodes
204201

205202
def add_node(self, node: T) -> None:
206-
"""Add a node to the graph if not already present"""
203+
# Add a node in the graph if it is not in the graph
207204
if node not in self.connections:
208205
self.connections[node] = {}
209206
self.nodes += 1
210207

211208
def add_edge(self, node1: T, node2: T, weight: int) -> None:
212-
"""Add an edge between two nodes with given weight"""
209+
# Add an edge between 2 nodes in the graph
213210
self.add_node(node1)
214211
self.add_node(node2)
215212
self.connections[node1][node2] = weight
216213
self.connections[node2][node1] = weight
217214

218215

219-
def prims_algo(
216+
def prims_algo[T](
220217
graph: GraphUndirectedWeighted[T],
221218
) -> tuple[dict[T, int], dict[T, T | None]]:
222219
"""
223-
Prim's algorithm for minimum spanning tree
224-
225220
>>> graph = GraphUndirectedWeighted()
221+
226222
>>> graph.add_edge("a", "b", 3)
227223
>>> graph.add_edge("b", "c", 10)
228224
>>> graph.add_edge("c", "d", 5)
@@ -231,53 +227,38 @@ def prims_algo(
231227
232228
>>> dist, parent = prims_algo(graph)
233229
234-
>>> dist["b"]
230+
>>> abs(dist["a"] - dist["b"])
235231
3
236-
>>> dist["c"]
237-
10
238-
>>> dist["d"]
239-
5
240-
>>> parent["b"]
241-
'a'
242-
>>> parent["c"]
243-
'b'
244-
>>> parent["d"]
245-
'c'
232+
>>> abs(dist["d"] - dist["b"])
233+
15
234+
>>> abs(dist["a"] - dist["c"])
235+
13
246236
"""
247-
# Initialize distance and parent dictionaries using dict.fromkeys
237+
# prim's algorithm for minimum spanning tree
248238
dist: dict[T, int] = dict.fromkeys(graph.connections, maxsize)
249-
parent: dict[T, T | None] = dict.fromkeys(graph.connections, None)
239+
parent: dict[T, T | None] = dict.fromkeys(graph.connections)
250240

251-
# Create priority queue and add all nodes
252241
priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
253-
for node in graph.connections:
254-
priority_queue.push(node, dist[node])
242+
for node, weight in dist.items():
243+
priority_queue.push(node, weight)
255244

256-
# Return if graph is empty
257245
if priority_queue.is_empty():
258246
return dist, parent
259-
260-
# Start with first node
261-
start_node = priority_queue.extract_min()
262-
dist[start_node] = 0
263-
264-
# Update neighbors of start node
265-
for neighbor, weight in graph.connections[start_node].items():
266-
if dist[neighbor] > weight:
267-
dist[neighbor] = weight
268-
priority_queue.update_key(neighbor, weight)
269-
parent[neighbor] = start_node
270-
271-
# Main algorithm loop
247+
# initialization
248+
node = priority_queue.extract_min()
249+
dist[node] = 0
250+
for neighbour in graph.connections[node]:
251+
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
252+
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
253+
priority_queue.update_key(neighbour, dist[neighbour])
254+
parent[neighbour] = node
255+
256+
# running prim's algorithm
272257
while not priority_queue.is_empty():
273258
node = priority_queue.extract_min()
274-
275-
# Explore neighbors of current node
276-
for neighbor, weight in graph.connections[node].items():
277-
# Update if found better connection to tree
278-
if dist[neighbor] > weight:
279-
dist[neighbor] = weight
280-
priority_queue.update_key(neighbor, weight)
281-
parent[neighbor] = node
282-
259+
for neighbour in graph.connections[node]:
260+
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
261+
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
262+
priority_queue.update_key(neighbour, dist[neighbour])
263+
parent[neighbour] = node
283264
return dist, parent

0 commit comments

Comments
 (0)