Skip to content

Commit a3b4048

Browse files
committed
Better lazy search
1 parent 130efc1 commit a3b4048

1 file changed

Lines changed: 50 additions & 43 deletions

File tree

devito/symbolics/search.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Iterable, Iterator
2-
from typing import Literal
2+
from typing import Any, Literal
33

44
import sympy
55

@@ -11,21 +11,22 @@
1111
'retrieve_terminals', 'retrieve_symbols', 'retrieve_dimensions',
1212
'retrieve_derivatives', 'search']
1313

14-
class Set(set[sympy.Basic]):
14+
15+
class Set(set):
1516

1617
@staticmethod
17-
def wrap(obj: sympy.Basic) -> set[sympy.Basic]:
18+
def wrap(obj) -> set:
1819
return {obj}
1920

2021

21-
class List(list[sympy.Basic]):
22+
class List(list):
2223

2324
@staticmethod
24-
def wrap(obj: sympy.Basic) -> list[sympy.Basic]:
25+
def wrap(obj) -> list:
2526
return [obj]
2627

27-
def update(self, obj: sympy.Basic) -> None:
28-
self.extend(obj)
28+
def update(self, obj: Iterable[Any]) -> None:
29+
return self.extend(obj)
2930

3031

3132
modes: dict[Literal['all', 'unique'], type[List] | type[Set]] = {
@@ -35,55 +36,59 @@ def update(self, obj: sympy.Basic) -> None:
3536

3637

3738
class Search:
38-
39-
def __init__(self, query: Callable[[sympy.Basic], bool],
40-
order: Literal['postorder', 'preorder'], deep: bool = False) -> None:
39+
def __init__(self, query: Callable[[Any], bool], deep: bool = False) -> None:
4140
"""
42-
Search objects in an expression. This is much quicker than the more
43-
general SymPy's find.
41+
Search objects in an expression. This is much quicker than the more general
42+
SymPy's find.
4443
4544
Parameters
4645
----------
4746
query
4847
Any query from :mod:`queries`.
49-
order : str
50-
Either `preorder` or `postorder`, for the search order.
5148
deep : bool, optional
5249
If True, propagate the search within an Indexed's indices. Defaults to False.
5350
"""
5451
self.query = query
55-
self.order = order
5652
self.deep = deep
5753

58-
def _next(self, expr) -> Iterator[sympy.Basic]:
54+
def _next(self, expr) -> Iterator[Any]:
5955
if self.deep and expr.is_Indexed:
6056
yield from expr.indices
6157
elif not q_leaf(expr):
6258
yield from expr.args
6359

64-
def visit(self, expr: sympy.Basic) -> Iterator[sympy.Basic]:
65-
"""Visit the expression in the specified order."""
66-
if self.order == 'preorder':
67-
if self.query(expr):
68-
yield expr
69-
for child in self._next(expr):
70-
yield from self.visit(child)
71-
else:
72-
for child in self._next(expr):
73-
yield from self.visit(child)
74-
if self.query(expr):
75-
yield expr
76-
77-
78-
def search(exprs: sympy.Basic | Iterable[sympy.Basic],
79-
query: type | Callable[[sympy.Basic], bool],
60+
def visit_postorder(self, expr) -> Iterator[Any]:
61+
for i in self._next(expr):
62+
yield from self.visit_postorder(i)
63+
if self.query(expr):
64+
yield expr
65+
66+
def visit_preorder(self, expr) -> Iterator[Any]:
67+
if self.query(expr):
68+
yield expr
69+
for i in self._next(expr):
70+
yield from self.visit_preorder(i)
71+
72+
def visit_preorder_first_hit(self, expr) -> tuple[Any, ...]:
73+
"""Visit the expression in preorder and return the first hit."""
74+
if self.query(expr):
75+
return (expr,)
76+
for i in self._next(expr):
77+
result = self.visit_preorder_first_hit(i)
78+
if result:
79+
return result
80+
return ()
81+
82+
83+
84+
def search(exprs,
85+
query: type | Callable[[Any], bool],
8086
mode: Literal['all', 'unique'] = 'unique',
8187
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
8288
deep: bool = False) -> List | Set:
8389
"""Interface to Search."""
8490

8591
assert mode in ('all', 'unique'), "Unknown mode"
86-
assert visit in ('dfs', 'bfs', 'bfs_first_hit'), "Unknown visit type"
8792

8893
if isinstance(query, type):
8994
Q = lambda obj: isinstance(obj, query)
@@ -92,21 +97,23 @@ def search(exprs: sympy.Basic | Iterable[sympy.Basic],
9297

9398
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
9499
# is retained in this function's parameters for backwards compatibility
95-
order = 'postorder' if visit == 'dfs' else 'preorder'
96-
searcher = Search(Q, order, deep)
100+
searcher = Search(Q, deep)
101+
102+
if visit == 'dfs':
103+
_visit = searcher.visit_postorder
104+
elif visit == 'bfs':
105+
_visit = searcher.visit_preorder
106+
elif visit == 'bfs_first_hit':
107+
_visit = searcher.visit_preorder_first_hit
108+
else:
109+
raise ValueError(f"Unknown visit mode '{visit}'")
97110

98-
Collection = modes[mode]
99-
found = Collection()
111+
found = modes[mode]()
100112
for e in as_tuple(exprs):
101113
if not isinstance(e, sympy.Basic):
102114
continue
103115

104-
for i in searcher.visit(e):
105-
found.update(Collection.wrap(i))
106-
107-
if visit == 'bfs_first_hit':
108-
# Stop at the first hit for this outer expression
109-
break
116+
found.update(_visit(e))
110117

111118
return found
112119

0 commit comments

Comments
 (0)