Skip to content

Commit 130efc1

Browse files
committed
Lazy search
1 parent d6a722f commit 130efc1

1 file changed

Lines changed: 60 additions & 85 deletions

File tree

devito/symbolics/search.py

Lines changed: 60 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from collections.abc import Callable, Iterable, Iterator
2+
from typing import Literal
3+
14
import sympy
25

36
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf,
@@ -8,30 +11,33 @@
811
'retrieve_terminals', 'retrieve_symbols', 'retrieve_dimensions',
912
'retrieve_derivatives', 'search']
1013

14+
class Set(set[sympy.Basic]):
15+
16+
@staticmethod
17+
def wrap(obj: sympy.Basic) -> set[sympy.Basic]:
18+
return {obj}
1119

12-
class Search:
1320

14-
class Set(set):
21+
class List(list[sympy.Basic]):
1522

16-
@staticmethod
17-
def wrap(obj):
18-
return {obj}
23+
@staticmethod
24+
def wrap(obj: sympy.Basic) -> list[sympy.Basic]:
25+
return [obj]
1926

20-
class List(list):
27+
def update(self, obj: sympy.Basic) -> None:
28+
self.extend(obj)
29+
2130

22-
@staticmethod
23-
def wrap(obj):
24-
return [obj]
31+
modes: dict[Literal['all', 'unique'], type[List] | type[Set]] = {
32+
'all': List,
33+
'unique': Set
34+
}
2535

26-
def update(self, obj):
27-
return self.extend(obj)
2836

29-
modes = {
30-
'unique': Set,
31-
'all': List
32-
}
37+
class Search:
3338

34-
def __init__(self, query, mode, deep=False):
39+
def __init__(self, query: Callable[[sympy.Basic], bool],
40+
order: Literal['postorder', 'preorder'], deep: bool = False) -> None:
3541
"""
3642
Search objects in an expression. This is much quicker than the more
3743
general SymPy's find.
@@ -40,98 +46,67 @@ def __init__(self, query, mode, deep=False):
4046
----------
4147
query
4248
Any query from :mod:`queries`.
43-
mode : str
44-
Either 'unique' or 'all' (catch all instances).
49+
order : str
50+
Either `preorder` or `postorder`, for the search order.
4551
deep : bool, optional
4652
If True, propagate the search within an Indexed's indices. Defaults to False.
4753
"""
4854
self.query = query
49-
self.collection = self.modes[mode]
55+
self.order = order
5056
self.deep = deep
5157

52-
def _next(self, expr):
58+
def _next(self, expr) -> Iterator[sympy.Basic]:
5359
if self.deep and expr.is_Indexed:
54-
return expr.indices
55-
elif q_leaf(expr):
56-
return ()
60+
yield from expr.indices
61+
elif not q_leaf(expr):
62+
yield from expr.args
63+
64+
def visit(self, expr: sympy.Basic) -> Iterator[sympy.Basic]:
65+
"""Visit the expression in the specified order."""
66+
if self.order == 'preorder':
67+
if self.query(expr):
68+
yield expr
69+
for child in self._next(expr):
70+
yield from self.visit(child)
5771
else:
58-
return expr.args
72+
for child in self._next(expr):
73+
yield from self.visit(child)
74+
if self.query(expr):
75+
yield expr
5976

60-
def dfs(self, expr):
61-
"""
62-
Perform a DFS search.
6377

64-
Parameters
65-
----------
66-
expr : expr-like
67-
The searched expression.
68-
"""
69-
found = self.collection()
70-
for a in self._next(expr):
71-
found.update(self.dfs(a))
72-
if self.query(expr):
73-
found.update(self.collection.wrap(expr))
74-
return found
75-
76-
def bfs(self, expr):
77-
"""
78-
Perform a BFS search.
79-
80-
Parameters
81-
----------
82-
expr : expr-like
83-
The searched expression.
84-
"""
85-
found = self.collection()
86-
if self.query(expr):
87-
found.update(self.collection.wrap(expr))
88-
for a in self._next(expr):
89-
found.update(self.bfs(a))
90-
return found
91-
92-
def bfs_first_hit(self, expr):
93-
"""
94-
Perform a BFS search, returning immediately when a node matches the query.
95-
96-
Parameters
97-
----------
98-
expr : expr-like
99-
The searched expression.
100-
"""
101-
found = self.collection()
102-
if self.query(expr):
103-
found.update(self.collection.wrap(expr))
104-
return found
105-
for a in self._next(expr):
106-
found.update(self.bfs_first_hit(a))
107-
return found
108-
109-
110-
def search(exprs, query, mode='unique', visit='dfs', deep=False):
78+
def search(exprs: sympy.Basic | Iterable[sympy.Basic],
79+
query: type | Callable[[sympy.Basic], bool],
80+
mode: Literal['all', 'unique'] = 'unique',
81+
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
82+
deep: bool = False) -> List | Set:
11183
"""Interface to Search."""
11284

113-
assert mode in Search.modes, "Unknown mode"
85+
assert mode in ('all', 'unique'), "Unknown mode"
86+
assert visit in ('dfs', 'bfs', 'bfs_first_hit'), "Unknown visit type"
11487

11588
if isinstance(query, type):
11689
Q = lambda obj: isinstance(obj, query)
11790
else:
11891
Q = query
11992

120-
searcher = Search(Q, mode, deep)
93+
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
94+
# is retained in this function's parameters for backwards compatibility
95+
order = 'postorder' if visit == 'dfs' else 'preorder'
96+
searcher = Search(Q, order, deep)
12197

122-
found = Search.modes[mode]()
98+
Collection = modes[mode]
99+
found = Collection()
123100
for e in as_tuple(exprs):
124101
if not isinstance(e, sympy.Basic):
125102
continue
126103

127-
if visit == 'dfs':
128-
found.update(searcher.dfs(e))
129-
elif visit == 'bfs':
130-
found.update(searcher.bfs(e))
131-
elif visit == "bfs_first_hit":
132-
found.update(searcher.bfs_first_hit(e))
133-
else:
134-
raise ValueError("Unknown visit type `%s`" % visit)
104+
for i in searcher.visit(e):
105+
found.update(Collection.wrap(i))
106+
107+
if visit == 'bfs_first_hit':
108+
# Stop at the first hit for this outer expression
109+
break
135110

136111
return found
137112

0 commit comments

Comments
 (0)