Skip to content

Commit 504c6a6

Browse files
author
Kcstring
committed
fix: address topological sort review
1 parent a0a8209 commit 504c6a6

1 file changed

Lines changed: 43 additions & 17 deletions

File tree

sorts/topological_sort.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,53 @@
1515
vertices: list[str] = ["a", "b", "c", "d", "e"]
1616

1717

18-
def topological_sort(start: str, visited: list[str], sort: list[str]) -> list[str]:
19-
"""Perform topological sort on a directed acyclic graph."""
20-
current = start
21-
# add current to visited
18+
def _visit(
19+
current: str,
20+
visited: list[str],
21+
post_order: list[str],
22+
graph: dict[str, list[str]],
23+
) -> None:
24+
"""Visit all descendants of the current vertex using DFS."""
2225
visited.append(current)
23-
neighbors = edges[current]
24-
for neighbor in neighbors:
25-
# if neighbor not in visited, visit
26+
for neighbor in graph[current]:
2627
if neighbor not in visited:
27-
sort = topological_sort(neighbor, visited, sort)
28-
# if all neighbors visited add current before its descendants
29-
sort.insert(0, current)
30-
# if all vertices haven't been visited select a new one to visit
31-
if len(visited) != len(vertices):
32-
for vertice in vertices:
28+
_visit(neighbor, visited, post_order, graph)
29+
post_order.append(current)
30+
31+
32+
def topological_sort(
33+
start: str,
34+
visited: list[str],
35+
sort: list[str],
36+
graph: dict[str, list[str]] | None = None,
37+
vertices_list: list[str] | None = None,
38+
) -> list[str]:
39+
"""
40+
Perform topological sort on a directed acyclic graph.
41+
42+
>>> result = topological_sort("a", [], [], edges, vertices)
43+
>>> all(
44+
... result.index(parent) < result.index(child)
45+
... for parent, children in edges.items()
46+
... for child in children
47+
... )
48+
True
49+
"""
50+
if graph is None:
51+
graph = edges
52+
if vertices_list is None:
53+
vertices_list = list(graph)
54+
55+
_visit(start, visited, sort, graph)
56+
if len(visited) != len(vertices_list):
57+
for vertice in vertices_list:
3358
if vertice not in visited:
34-
sort = topological_sort(vertice, visited, sort)
35-
# return sort
59+
_visit(vertice, visited, sort, graph)
60+
sort.reverse()
3661
return sort
3762

3863

3964
if __name__ == "__main__":
40-
sort = topological_sort("a", [], [])
41-
print(sort)
65+
result = topological_sort("a", [], [], edges, vertices)
66+
assert result == ["a", "b", "e", "d", "c"]
67+
print(result)

0 commit comments

Comments
 (0)