|
15 | 15 | vertices: list[str] = ["a", "b", "c", "d", "e"] |
16 | 16 |
|
17 | 17 |
|
18 | | -def topological_sort(start: str, visited: list[str], sort: list[str]) -> list[str]: |
19 | | - """Perform topological sort on a directed acyclic graph.""" |
20 | | - current = start |
21 | | - # add current to visited |
| 18 | +def _visit( |
| 19 | + current: str, |
| 20 | + visited: list[str], |
| 21 | + post_order: list[str], |
| 22 | + graph: dict[str, list[str]], |
| 23 | +) -> None: |
| 24 | + """Visit all descendants of the current vertex using DFS.""" |
22 | 25 | visited.append(current) |
23 | | - neighbors = edges[current] |
24 | | - for neighbor in neighbors: |
25 | | - # if neighbor not in visited, visit |
| 26 | + for neighbor in graph[current]: |
26 | 27 | if neighbor not in visited: |
27 | | - sort = topological_sort(neighbor, visited, sort) |
28 | | - # if all neighbors visited add current before its descendants |
29 | | - sort.insert(0, current) |
30 | | - # if all vertices haven't been visited select a new one to visit |
31 | | - if len(visited) != len(vertices): |
32 | | - for vertice in vertices: |
| 28 | + _visit(neighbor, visited, post_order, graph) |
| 29 | + post_order.append(current) |
| 30 | + |
| 31 | + |
| 32 | +def topological_sort( |
| 33 | + start: str, |
| 34 | + visited: list[str], |
| 35 | + sort: list[str], |
| 36 | + graph: dict[str, list[str]] | None = None, |
| 37 | + vertices_list: list[str] | None = None, |
| 38 | +) -> list[str]: |
| 39 | + """ |
| 40 | + Perform topological sort on a directed acyclic graph. |
| 41 | +
|
| 42 | + >>> result = topological_sort("a", [], [], edges, vertices) |
| 43 | + >>> all( |
| 44 | + ... result.index(parent) < result.index(child) |
| 45 | + ... for parent, children in edges.items() |
| 46 | + ... for child in children |
| 47 | + ... ) |
| 48 | + True |
| 49 | + """ |
| 50 | + if graph is None: |
| 51 | + graph = edges |
| 52 | + if vertices_list is None: |
| 53 | + vertices_list = list(graph) |
| 54 | + |
| 55 | + _visit(start, visited, sort, graph) |
| 56 | + if len(visited) != len(vertices_list): |
| 57 | + for vertice in vertices_list: |
33 | 58 | if vertice not in visited: |
34 | | - sort = topological_sort(vertice, visited, sort) |
35 | | - # return sort |
| 59 | + _visit(vertice, visited, sort, graph) |
| 60 | + sort.reverse() |
36 | 61 | return sort |
37 | 62 |
|
38 | 63 |
|
39 | 64 | if __name__ == "__main__": |
40 | | - sort = topological_sort("a", [], []) |
41 | | - print(sort) |
| 65 | + result = topological_sort("a", [], [], edges, vertices) |
| 66 | + assert result == ["a", "b", "e", "d", "c"] |
| 67 | + print(result) |
0 commit comments