11from collections .abc import Callable , Iterable , Iterator
22from typing import Any , Literal
33
4+ import numpy as np
45import sympy
56
67from devito .symbolics .queries import (q_indexed , q_function , q_terminal , q_leaf ,
1213 'retrieve_derivatives' , 'search' ]
1314
1415
15- class Set ( set ):
16+ Expression = sympy . Basic | np . number | int | float
1617
18+
19+ class Set (set [Expression ]):
1720 @staticmethod
18- def wrap (obj ) -> set :
21+ def wrap (obj : Expression ) -> set [ Expression ] :
1922 return {obj }
2023
2124
22- class List (list ):
23-
25+ class List (list [Expression ]):
2426 @staticmethod
25- def wrap (obj ) -> list :
27+ def wrap (obj : Expression ) -> list [ Expression ] :
2628 return [obj ]
2729
28- def update (self , obj : Iterable [Any ]) -> None :
29- return self .extend (obj )
30-
30+ def update (self , obj : Iterable [Expression ]) -> None :
31+ self .extend (obj )
32+
3133
32- modes : dict [Literal ['all' , 'unique' ], type [List ] | type [Set ]] = {
34+ Mode = Literal ['all' , 'unique' ]
35+ modes : dict [Mode , type [List ] | type [Set ]] = {
3336 'all' : List ,
3437 'unique' : Set
3538}
3639
3740
3841class Search :
39- def __init__ (self , query : Callable [[Any ], bool ], deep : bool = False ) -> None :
42+ def __init__ (self , query : Callable [[Expression ], bool ], deep : bool = False ) -> None :
4043 """
4144 Search objects in an expression. This is much quicker than the more general
4245 SymPy's find.
@@ -51,39 +54,46 @@ def __init__(self, query: Callable[[Any], bool], deep: bool = False) -> None:
5154 self .query = query
5255 self .deep = deep
5356
54- def _next (self , expr ) -> Iterator [Any ]:
57+ def _next (self , expr : Expression ) -> Iterator [Expression ]:
5558 if self .deep and expr .is_Indexed :
5659 yield from expr .indices
5760 elif not q_leaf (expr ):
5861 yield from expr .args
5962
60- def visit_postorder (self , expr ) -> Iterator [Any ]:
63+ def visit_postorder (self , expr : Expression ) -> Iterator [Expression ]:
64+ """
65+ Visit the expression with a postorder traversal, yielding all hits.
66+ """
6167 for i in self ._next (expr ):
6268 yield from self .visit_postorder (i )
6369 if self .query (expr ):
6470 yield expr
6571
66- def visit_preorder (self , expr ) -> Iterator [Any ]:
72+ def visit_preorder (self , expr : Expression ) -> Iterator [Expression ]:
73+ """
74+ Visit the expression with a preorder traversal, yielding all hits.
75+ """
6776 if self .query (expr ):
6877 yield expr
6978 for i in self ._next (expr ):
7079 yield from self .visit_preorder (i )
7180
72- def visit_preorder_first_hit (self , expr ) -> tuple [Any , ...]:
73- """Visit the expression in preorder and return the first hit."""
81+ def visit_preorder_first_hit (self , expr : Expression ) -> Iterator [Expression ]:
82+ """
83+ Visit the expression in preorder and return a tuple containing the first hit,
84+ if any. This can return more than a single result, as it looks for the first
85+ hit from any branch but may find a hit in multiple branches.
86+ """
7487 if self .query (expr ):
75- return (expr ,)
88+ yield expr
89+ return
7690 for i in self ._next (expr ):
77- result = self .visit_preorder_first_hit (i )
78- if result :
79- return result
80- return ()
81-
91+ yield from self .visit_preorder_first_hit (i )
8292
8393
84- def search (exprs ,
94+ def search (exprs : Expression | Iterable [ Expression ] ,
8595 query : type | Callable [[Any ], bool ],
86- mode : Literal [ 'all' , 'unique' ] = 'unique' ,
96+ mode : Mode = 'unique' ,
8797 visit : Literal ['dfs' , 'bfs' , 'bfs_first_hit' ] = 'dfs' ,
8898 deep : bool = False ) -> List | Set :
8999 """Interface to Search."""
@@ -98,13 +108,12 @@ def search(exprs,
98108 # Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
99109 # is retained in this function's parameters for backwards compatibility
100110 searcher = Search (Q , deep )
101-
102111 if visit == 'dfs' :
103- _visit = searcher .visit_postorder
112+ _search = searcher .visit_postorder
104113 elif visit == 'bfs' :
105- _visit = searcher .visit_preorder
114+ _search = searcher .visit_preorder
106115 elif visit == 'bfs_first_hit' :
107- _visit = searcher .visit_preorder_first_hit
116+ _search = searcher .visit_preorder_first_hit
108117 else :
109118 raise ValueError (f"Unknown visit mode '{ visit } '" )
110119
@@ -113,7 +122,7 @@ def search(exprs,
113122 if not isinstance (e , sympy .Basic ):
114123 continue
115124
116- found .update (_visit (e ))
125+ found .update (_search (e ))
117126
118127 return found
119128
0 commit comments