11from collections .abc import Callable , Iterable , Iterator
2- from typing import Literal
2+ from typing import Any , Literal
33
44import sympy
55
1111 'retrieve_terminals' , 'retrieve_symbols' , 'retrieve_dimensions' ,
1212 'retrieve_derivatives' , 'search' ]
1313
14- class Set (set [sympy .Basic ]):
14+
15+ class Set (set ):
1516
1617 @staticmethod
17- def wrap (obj : sympy . Basic ) -> set [ sympy . Basic ] :
18+ def wrap (obj ) -> set :
1819 return {obj }
1920
2021
21- class List (list [ sympy . Basic ] ):
22+ class List (list ):
2223
2324 @staticmethod
24- def wrap (obj : sympy . Basic ) -> list [ sympy . Basic ] :
25+ def wrap (obj ) -> list :
2526 return [obj ]
2627
27- def update (self , obj : sympy . Basic ) -> None :
28- self .extend (obj )
28+ def update (self , obj : Iterable [ Any ] ) -> None :
29+ return self .extend (obj )
2930
3031
3132modes : dict [Literal ['all' , 'unique' ], type [List ] | type [Set ]] = {
@@ -35,55 +36,59 @@ def update(self, obj: sympy.Basic) -> None:
3536
3637
3738class Search :
38-
39- def __init__ (self , query : Callable [[sympy .Basic ], bool ],
40- order : Literal ['postorder' , 'preorder' ], deep : bool = False ) -> None :
39+ def __init__ (self , query : Callable [[Any ], bool ], deep : bool = False ) -> None :
4140 """
42- Search objects in an expression. This is much quicker than the more
43- general SymPy's find.
41+ Search objects in an expression. This is much quicker than the more general
42+ SymPy's find.
4443
4544 Parameters
4645 ----------
4746 query
4847 Any query from :mod:`queries`.
49- order : str
50- Either `preorder` or `postorder`, for the search order.
5148 deep : bool, optional
5249 If True, propagate the search within an Indexed's indices. Defaults to False.
5350 """
5451 self .query = query
55- self .order = order
5652 self .deep = deep
5753
58- def _next (self , expr ) -> Iterator [sympy . Basic ]:
54+ def _next (self , expr ) -> Iterator [Any ]:
5955 if self .deep and expr .is_Indexed :
6056 yield from expr .indices
6157 elif not q_leaf (expr ):
6258 yield from expr .args
6359
64- def visit (self , expr : sympy .Basic ) -> Iterator [sympy .Basic ]:
65- """Visit the expression in the specified order."""
66- if self .order == 'preorder' :
67- if self .query (expr ):
68- yield expr
69- for child in self ._next (expr ):
70- yield from self .visit (child )
71- else :
72- for child in self ._next (expr ):
73- yield from self .visit (child )
74- if self .query (expr ):
75- yield expr
76-
77-
78- def search (exprs : sympy .Basic | Iterable [sympy .Basic ],
79- query : type | Callable [[sympy .Basic ], bool ],
60+ def visit_postorder (self , expr ) -> Iterator [Any ]:
61+ for i in self ._next (expr ):
62+ yield from self .visit_postorder (i )
63+ if self .query (expr ):
64+ yield expr
65+
66+ def visit_preorder (self , expr ) -> Iterator [Any ]:
67+ if self .query (expr ):
68+ yield expr
69+ for i in self ._next (expr ):
70+ yield from self .visit_preorder (i )
71+
72+ def visit_preorder_first_hit (self , expr ) -> tuple [Any , ...]:
73+ """Visit the expression in preorder and return the first hit."""
74+ if self .query (expr ):
75+ return (expr ,)
76+ for i in self ._next (expr ):
77+ result = self .visit_preorder_first_hit (i )
78+ if result :
79+ return result
80+ return ()
81+
82+
83+
84+ def search (exprs ,
85+ query : type | Callable [[Any ], bool ],
8086 mode : Literal ['all' , 'unique' ] = 'unique' ,
8187 visit : Literal ['dfs' , 'bfs' , 'bfs_first_hit' ] = 'dfs' ,
8288 deep : bool = False ) -> List | Set :
8389 """Interface to Search."""
8490
8591 assert mode in ('all' , 'unique' ), "Unknown mode"
86- assert visit in ('dfs' , 'bfs' , 'bfs_first_hit' ), "Unknown visit type"
8792
8893 if isinstance (query , type ):
8994 Q = lambda obj : isinstance (obj , query )
@@ -92,21 +97,23 @@ def search(exprs: sympy.Basic | Iterable[sympy.Basic],
9297
9398 # Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
9499 # is retained in this function's parameters for backwards compatibility
95- order = 'postorder' if visit == 'dfs' else 'preorder'
96- searcher = Search (Q , order , deep )
100+ searcher = Search (Q , deep )
101+
102+ if visit == 'dfs' :
103+ _visit = searcher .visit_postorder
104+ elif visit == 'bfs' :
105+ _visit = searcher .visit_preorder
106+ elif visit == 'bfs_first_hit' :
107+ _visit = searcher .visit_preorder_first_hit
108+ else :
109+ raise ValueError (f"Unknown visit mode '{ visit } '" )
97110
98- Collection = modes [mode ]
99- found = Collection ()
111+ found = modes [mode ]()
100112 for e in as_tuple (exprs ):
101113 if not isinstance (e , sympy .Basic ):
102114 continue
103115
104- for i in searcher .visit (e ):
105- found .update (Collection .wrap (i ))
106-
107- if visit == 'bfs_first_hit' :
108- # Stop at the first hit for this outer expression
109- break
116+ found .update (_visit (e ))
110117
111118 return found
112119
0 commit comments