1+ from collections .abc import Callable , Iterable , Iterator
2+ from typing import Literal
3+
14import sympy
25
36from devito .symbolics .queries import (q_indexed , q_function , q_terminal , q_leaf ,
811 'retrieve_terminals' , 'retrieve_symbols' , 'retrieve_dimensions' ,
912 'retrieve_derivatives' , 'search' ]
1013
14+ class Set (set [sympy .Basic ]):
15+
16+ @staticmethod
17+ def wrap (obj : sympy .Basic ) -> set [sympy .Basic ]:
18+ return {obj }
1119
12- class Search :
1320
14- class Set ( set ):
21+ class List ( list [ sympy . Basic ] ):
1522
16- @staticmethod
17- def wrap (obj ) :
18- return { obj }
23+ @staticmethod
24+ def wrap (obj : sympy . Basic ) -> list [ sympy . Basic ] :
25+ return [ obj ]
1926
20- class List (list ):
27+ def update (self , obj : sympy .Basic ) -> None :
28+ self .extend (obj )
29+
2130
22- @staticmethod
23- def wrap (obj ):
24- return [obj ]
31+ modes : dict [Literal ['all' , 'unique' ], type [List ] | type [Set ]] = {
32+ 'all' : List ,
33+ 'unique' : Set
34+ }
2535
26- def update (self , obj ):
27- return self .extend (obj )
2836
29- modes = {
30- 'unique' : Set ,
31- 'all' : List
32- }
37+ class Search :
3338
34- def __init__ (self , query , mode , deep = False ):
39+ def __init__ (self , query : Callable [[sympy .Basic ], bool ],
40+ order : Literal ['postorder' , 'preorder' ], deep : bool = False ) -> None :
3541 """
3642 Search objects in an expression. This is much quicker than the more
3743 general SymPy's find.
@@ -40,98 +46,67 @@ def __init__(self, query, mode, deep=False):
4046 ----------
4147 query
4248 Any query from :mod:`queries`.
43- mode : str
44- Either 'unique' or 'all' (catch all instances) .
49+ order : str
50+ Either `preorder` or `postorder`, for the search order .
4551 deep : bool, optional
4652 If True, propagate the search within an Indexed's indices. Defaults to False.
4753 """
4854 self .query = query
49- self .collection = self . modes [ mode ]
55+ self .order = order
5056 self .deep = deep
5157
52- def _next (self , expr ):
58+ def _next (self , expr ) -> Iterator [ sympy . Basic ] :
5359 if self .deep and expr .is_Indexed :
54- return expr .indices
55- elif q_leaf (expr ):
56- return ()
60+ yield from expr .indices
61+ elif not q_leaf (expr ):
62+ yield from expr .args
63+
64+ def visit (self , expr : sympy .Basic ) -> Iterator [sympy .Basic ]:
65+ """Visit the expression in the specified order."""
66+ if self .order == 'preorder' :
67+ if self .query (expr ):
68+ yield expr
69+ for child in self ._next (expr ):
70+ yield from self .visit (child )
5771 else :
58- return expr .args
72+ for child in self ._next (expr ):
73+ yield from self .visit (child )
74+ if self .query (expr ):
75+ yield expr
5976
60- def dfs (self , expr ):
61- """
62- Perform a DFS search.
6377
64- Parameters
65- ----------
66- expr : expr-like
67- The searched expression.
68- """
69- found = self .collection ()
70- for a in self ._next (expr ):
71- found .update (self .dfs (a ))
72- if self .query (expr ):
73- found .update (self .collection .wrap (expr ))
74- return found
75-
76- def bfs (self , expr ):
77- """
78- Perform a BFS search.
79-
80- Parameters
81- ----------
82- expr : expr-like
83- The searched expression.
84- """
85- found = self .collection ()
86- if self .query (expr ):
87- found .update (self .collection .wrap (expr ))
88- for a in self ._next (expr ):
89- found .update (self .bfs (a ))
90- return found
91-
92- def bfs_first_hit (self , expr ):
93- """
94- Perform a BFS search, returning immediately when a node matches the query.
95-
96- Parameters
97- ----------
98- expr : expr-like
99- The searched expression.
100- """
101- found = self .collection ()
102- if self .query (expr ):
103- found .update (self .collection .wrap (expr ))
104- return found
105- for a in self ._next (expr ):
106- found .update (self .bfs_first_hit (a ))
107- return found
108-
109-
110- def search (exprs , query , mode = 'unique' , visit = 'dfs' , deep = False ):
78+ def search (exprs : sympy .Basic | Iterable [sympy .Basic ],
79+ query : type | Callable [[sympy .Basic ], bool ],
80+ mode : Literal ['all' , 'unique' ] = 'unique' ,
81+ visit : Literal ['dfs' , 'bfs' , 'bfs_first_hit' ] = 'dfs' ,
82+ deep : bool = False ) -> List | Set :
11183 """Interface to Search."""
11284
113- assert mode in Search .modes , "Unknown mode"
85+ assert mode in ('all' , 'unique' ), "Unknown mode"
86+ assert visit in ('dfs' , 'bfs' , 'bfs_first_hit' ), "Unknown visit type"
11487
11588 if isinstance (query , type ):
11689 Q = lambda obj : isinstance (obj , query )
11790 else :
11891 Q = query
11992
120- searcher = Search (Q , mode , deep )
93+ # Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
94+ # is retained in this function's parameters for backwards compatibility
95+ order = 'postorder' if visit == 'dfs' else 'preorder'
96+ searcher = Search (Q , order , deep )
12197
122- found = Search .modes [mode ]()
98+ Collection = modes [mode ]
99+ found = Collection ()
123100 for e in as_tuple (exprs ):
124101 if not isinstance (e , sympy .Basic ):
125102 continue
126103
127- if visit == 'dfs' :
128- found .update (searcher .dfs (e ))
129- elif visit == 'bfs' :
130- found .update (searcher .bfs (e ))
131- elif visit == "bfs_first_hit" :
132- found .update (searcher .bfs_first_hit (e ))
133- else :
134- raise ValueError ("Unknown visit type `%s`" % visit )
104+ for i in searcher .visit (e ):
105+ found .update (Collection .wrap (i ))
106+
107+ if visit == 'bfs_first_hit' :
108+ # Stop at the first hit for this outer expression
109+ break
135110
136111 return found
137112
0 commit comments