diff --git a/sorts/topological_sort.py b/sorts/topological_sort.py index efce8165fcac..bf66558ad376 100644 --- a/sorts/topological_sort.py +++ b/sorts/topological_sort.py @@ -15,27 +15,53 @@ vertices: list[str] = ["a", "b", "c", "d", "e"] -def topological_sort(start: str, visited: list[str], sort: list[str]) -> list[str]: - """Perform topological sort on a directed acyclic graph.""" - current = start - # add current to visited +def _visit( + current: str, + visited: list[str], + post_order: list[str], + graph: dict[str, list[str]], +) -> None: + """Visit all descendants of the current vertex using DFS.""" visited.append(current) - neighbors = edges[current] - for neighbor in neighbors: - # if neighbor not in visited, visit + for neighbor in graph[current]: if neighbor not in visited: - sort = topological_sort(neighbor, visited, sort) - # if all neighbors visited add current to sort - sort.append(current) - # if all vertices haven't been visited select a new one to visit - if len(visited) != len(vertices): - for vertice in vertices: + _visit(neighbor, visited, post_order, graph) + post_order.append(current) + + +def topological_sort( + start: str, + visited: list[str], + sort: list[str], + graph: dict[str, list[str]] | None = None, + vertices_list: list[str] | None = None, +) -> list[str]: + """ + Perform topological sort on a directed acyclic graph. + + >>> result = topological_sort("a", [], [], edges, vertices) + >>> all( + ... result.index(parent) < result.index(child) + ... for parent, children in edges.items() + ... for child in children + ... ) + True + """ + if graph is None: + graph = edges + if vertices_list is None: + vertices_list = list(graph) + + _visit(start, visited, sort, graph) + if len(visited) != len(vertices_list): + for vertice in vertices_list: if vertice not in visited: - sort = topological_sort(vertice, visited, sort) - # return sort + _visit(vertice, visited, sort, graph) + sort.reverse() return sort if __name__ == "__main__": - sort = topological_sort("a", [], []) - print(sort) + result = topological_sort("a", [], [], edges, vertices) + assert result == ["a", "b", "e", "d", "c"] + print(result)