1010from __future__ import annotations
1111
1212from sys import maxsize
13- from typing import Generic , TypeVar
14-
15- T = TypeVar ("T" )
16-
1713
1814def get_parent_position (position : int ) -> int :
1915 """
20- Heap helper function to get the position of the parent of the current node
16+ heap helper function get the position of the parent of the current node
2117
2218 >>> get_parent_position(1)
2319 0
@@ -29,7 +25,7 @@ def get_parent_position(position: int) -> int:
2925
3026def get_child_left_position (position : int ) -> int :
3127 """
32- Heap helper function to get the position of the left child of the current node
28+ heap helper function get the position of the left child of the current node
3329
3430 >>> get_child_left_position(0)
3531 1
@@ -39,26 +35,29 @@ def get_child_left_position(position: int) -> int:
3935
4036def get_child_right_position (position : int ) -> int :
4137 """
42- Heap helper function to get the position of the right child of the current node
38+ heap helper function get the position of the right child of the current node
4339
4440 >>> get_child_right_position(0)
4541 2
4642 """
4743 return (2 * position ) + 2
4844
4945
50- class MinPriorityQueue ( Generic [T ]) :
46+ class MinPriorityQueue [T ]:
5147 """
5248 Minimum Priority Queue Class
5349
5450 Functions:
55- is_empty: Check if the priority queue is empty
56- push: Add an element with given priority to the queue
57- extract_min: Remove and return the element with lowest weight (highest priority)
58- update_key: Update the weight of the given key
59- _bubble_up: Place a node at proper position (upward movement)
60- _bubble_down: Place a node at proper position (downward movement)
61- _swap_nodes: Swap nodes at given positions
51+ is_empty: function to check if the priority queue is empty
52+ push: function to add an element with given priority to the queue
53+ extract_min: function to remove and return the element with lowest weight (highest
54+ priority)
55+ update_key: function to update the weight of the given key
56+ _bubble_up: helper function to place a node at the proper position (upward
57+ movement)
58+ _bubble_down: helper function to place a node at the proper position (downward
59+ movement)
60+ _swap_nodes: helper function to swap the nodes at the given positions
6261
6362 >>> queue = MinPriorityQueue()
6463
@@ -92,18 +91,18 @@ def __repr__(self) -> str:
9291 return str (self .heap )
9392
9493 def is_empty (self ) -> bool :
95- """ Check if the priority queue is empty"""
94+ # Check if the priority queue is empty
9695 return self .elements == 0
9796
9897 def push (self , elem : T , weight : int ) -> None :
99- """ Add an element with given priority to the queue"""
98+ # Add an element with given priority to the queue
10099 self .heap .append ((elem , weight ))
101100 self .position_map [elem ] = self .elements
102101 self .elements += 1
103102 self ._bubble_up (elem )
104103
105104 def extract_min (self ) -> T :
106- """ Remove and return the element with lowest weight (highest priority)"""
105+ # Remove and return the element with lowest weight (highest priority)
107106 if self .elements > 1 :
108107 self ._swap_nodes (0 , self .elements - 1 )
109108 elem , _ = self .heap .pop ()
@@ -115,7 +114,7 @@ def extract_min(self) -> T:
115114 return elem
116115
117116 def update_key (self , elem : T , weight : int ) -> None :
118- """ Update the weight of the given key"""
117+ # Update the weight of the given key
119118 position = self .position_map [elem ]
120119 self .heap [position ] = (elem , weight )
121120 if position > 0 :
@@ -129,50 +128,48 @@ def update_key(self, elem: T, weight: int) -> None:
129128 self ._bubble_down (elem )
130129
131130 def _bubble_up (self , elem : T ) -> None :
132- """Place node at proper position (upward movement) - internal use only"""
131+ # Place a node at the proper position (upward movement) [to be used internally
132+ # only]
133133 curr_pos = self .position_map [elem ]
134134 if curr_pos == 0 :
135- return
135+ return None
136136 parent_position = get_parent_position (curr_pos )
137137 _ , weight = self .heap [curr_pos ]
138138 _ , parent_weight = self .heap [parent_position ]
139139 if parent_weight > weight :
140140 self ._swap_nodes (parent_position , curr_pos )
141- self ._bubble_up (elem )
141+ return self ._bubble_up (elem )
142+ return None
142143
143144 def _bubble_down (self , elem : T ) -> None :
144- """Place node at proper position (downward movement) - internal use only"""
145+ # Place a node at the proper position (downward movement) [to be used
146+ # internally only]
145147 curr_pos = self .position_map [elem ]
146148 _ , weight = self .heap [curr_pos ]
147149 child_left_position = get_child_left_position (curr_pos )
148150 child_right_position = get_child_right_position (curr_pos )
149-
150- # Check if both children exist
151151 if child_left_position < self .elements and child_right_position < self .elements :
152152 _ , child_left_weight = self .heap [child_left_position ]
153153 _ , child_right_weight = self .heap [child_right_position ]
154154 if child_right_weight < child_left_weight and child_right_weight < weight :
155155 self ._swap_nodes (child_right_position , curr_pos )
156- self ._bubble_down (elem )
157- return
158-
159- # Check left child
156+ return self ._bubble_down (elem )
160157 if child_left_position < self .elements :
161158 _ , child_left_weight = self .heap [child_left_position ]
162159 if child_left_weight < weight :
163160 self ._swap_nodes (child_left_position , curr_pos )
164- self ._bubble_down (elem )
165- return
166-
167- # Check right child
161+ return self ._bubble_down (elem )
162+ else :
163+ return None
168164 if child_right_position < self .elements :
169165 _ , child_right_weight = self .heap [child_right_position ]
170166 if child_right_weight < weight :
171167 self ._swap_nodes (child_right_position , curr_pos )
172- self ._bubble_down (elem )
168+ return self ._bubble_down (elem )
169+ return None
173170
174171 def _swap_nodes (self , node1_pos : int , node2_pos : int ) -> None :
175- """ Swap nodes at given positions"""
172+ # Swap the nodes at the given positions
176173 node1_elem = self .heap [node1_pos ][0 ]
177174 node2_elem = self .heap [node2_pos ][0 ]
178175 self .heap [node1_pos ], self .heap [node2_pos ] = (
@@ -183,13 +180,13 @@ def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
183180 self .position_map [node2_elem ] = node1_pos
184181
185182
186- class GraphUndirectedWeighted ( Generic [T ]) :
183+ class GraphUndirectedWeighted [T ]:
187184 """
188185 Graph Undirected Weighted Class
189186
190187 Functions:
191- add_node: Add a node to the graph
192- add_edge: Add an edge between two nodes with given weight
188+ add_node: function to add a node in the graph
189+ add_edge: function to add an edge between 2 nodes in the graph
193190 """
194191
195192 def __init__ (self ) -> None :
@@ -203,26 +200,25 @@ def __len__(self) -> int:
203200 return self .nodes
204201
205202 def add_node (self , node : T ) -> None :
206- """ Add a node to the graph if not already present"""
203+ # Add a node in the graph if it is not in the graph
207204 if node not in self .connections :
208205 self .connections [node ] = {}
209206 self .nodes += 1
210207
211208 def add_edge (self , node1 : T , node2 : T , weight : int ) -> None :
212- """ Add an edge between two nodes with given weight"""
209+ # Add an edge between 2 nodes in the graph
213210 self .add_node (node1 )
214211 self .add_node (node2 )
215212 self .connections [node1 ][node2 ] = weight
216213 self .connections [node2 ][node1 ] = weight
217214
218215
219- def prims_algo (
216+ def prims_algo [ T ] (
220217 graph : GraphUndirectedWeighted [T ],
221218) -> tuple [dict [T , int ], dict [T , T | None ]]:
222219 """
223- Prim's algorithm for minimum spanning tree
224-
225220 >>> graph = GraphUndirectedWeighted()
221+
226222 >>> graph.add_edge("a", "b", 3)
227223 >>> graph.add_edge("b", "c", 10)
228224 >>> graph.add_edge("c", "d", 5)
@@ -231,53 +227,38 @@ def prims_algo(
231227
232228 >>> dist, parent = prims_algo(graph)
233229
234- >>> dist["b"]
230+ >>> abs( dist["a"] - dist[" b"])
235231 3
236- >>> dist["c"]
237- 10
238- >>> dist["d"]
239- 5
240- >>> parent["b"]
241- 'a'
242- >>> parent["c"]
243- 'b'
244- >>> parent["d"]
245- 'c'
232+ >>> abs(dist["d"] - dist["b"])
233+ 15
234+ >>> abs(dist["a"] - dist["c"])
235+ 13
246236 """
247- # Initialize distance and parent dictionaries using dict.fromkeys
237+ # prim's algorithm for minimum spanning tree
248238 dist : dict [T , int ] = dict .fromkeys (graph .connections , maxsize )
249- parent : dict [T , T | None ] = dict .fromkeys (graph .connections , None )
239+ parent : dict [T , T | None ] = dict .fromkeys (graph .connections )
250240
251- # Create priority queue and add all nodes
252241 priority_queue : MinPriorityQueue [T ] = MinPriorityQueue ()
253- for node in graph . connections :
254- priority_queue .push (node , dist [ node ] )
242+ for node , weight in dist . items () :
243+ priority_queue .push (node , weight )
255244
256- # Return if graph is empty
257245 if priority_queue .is_empty ():
258246 return dist , parent
259-
260- # Start with first node
261- start_node = priority_queue .extract_min ()
262- dist [start_node ] = 0
263-
264- # Update neighbors of start node
265- for neighbor , weight in graph .connections [start_node ].items ():
266- if dist [neighbor ] > weight :
267- dist [neighbor ] = weight
268- priority_queue .update_key (neighbor , weight )
269- parent [neighbor ] = start_node
270-
271- # Main algorithm loop
247+ # initialization
248+ node = priority_queue .extract_min ()
249+ dist [node ] = 0
250+ for neighbour in graph .connections [node ]:
251+ if dist [neighbour ] > dist [node ] + graph .connections [node ][neighbour ]:
252+ dist [neighbour ] = dist [node ] + graph .connections [node ][neighbour ]
253+ priority_queue .update_key (neighbour , dist [neighbour ])
254+ parent [neighbour ] = node
255+
256+ # running prim's algorithm
272257 while not priority_queue .is_empty ():
273258 node = priority_queue .extract_min ()
274-
275- # Explore neighbors of current node
276- for neighbor , weight in graph .connections [node ].items ():
277- # Update if found better connection to tree
278- if dist [neighbor ] > weight :
279- dist [neighbor ] = weight
280- priority_queue .update_key (neighbor , weight )
281- parent [neighbor ] = node
282-
259+ for neighbour in graph .connections [node ]:
260+ if dist [neighbour ] > dist [node ] + graph .connections [node ][neighbour ]:
261+ dist [neighbour ] = dist [node ] + graph .connections [node ][neighbour ]
262+ priority_queue .update_key (neighbour , dist [neighbour ])
263+ parent [neighbour ] = node
283264 return dist , parent
0 commit comments