Skip to content

Commit 8e1bc1b

Browse files
committed
Fix search regression + typing
1 parent a3b4048 commit 8e1bc1b

1 file changed

Lines changed: 37 additions & 28 deletions

File tree

devito/symbolics/search.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Callable, Iterable, Iterator
22
from typing import Any, Literal
33

4+
import numpy as np
45
import sympy
56

67
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf,
@@ -12,31 +13,33 @@
1213
'retrieve_derivatives', 'search']
1314

1415

15-
class Set(set):
16+
Expression = sympy.Basic | np.number | int | float
1617

18+
19+
class Set(set[Expression]):
1720
@staticmethod
18-
def wrap(obj) -> set:
21+
def wrap(obj: Expression) -> set[Expression]:
1922
return {obj}
2023

2124

22-
class List(list):
23-
25+
class List(list[Expression]):
2426
@staticmethod
25-
def wrap(obj) -> list:
27+
def wrap(obj: Expression) -> list[Expression]:
2628
return [obj]
2729

28-
def update(self, obj: Iterable[Any]) -> None:
29-
return self.extend(obj)
30-
30+
def update(self, obj: Iterable[Expression]) -> None:
31+
self.extend(obj)
32+
3133

32-
modes: dict[Literal['all', 'unique'], type[List] | type[Set]] = {
34+
Mode = Literal['all', 'unique']
35+
modes: dict[Mode, type[List] | type[Set]] = {
3336
'all': List,
3437
'unique': Set
3538
}
3639

3740

3841
class Search:
39-
def __init__(self, query: Callable[[Any], bool], deep: bool = False) -> None:
42+
def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> None:
4043
"""
4144
Search objects in an expression. This is much quicker than the more general
4245
SymPy's find.
@@ -51,39 +54,46 @@ def __init__(self, query: Callable[[Any], bool], deep: bool = False) -> None:
5154
self.query = query
5255
self.deep = deep
5356

54-
def _next(self, expr) -> Iterator[Any]:
57+
def _next(self, expr: Expression) -> Iterator[Expression]:
5558
if self.deep and expr.is_Indexed:
5659
yield from expr.indices
5760
elif not q_leaf(expr):
5861
yield from expr.args
5962

60-
def visit_postorder(self, expr) -> Iterator[Any]:
63+
def visit_postorder(self, expr: Expression) -> Iterator[Expression]:
64+
"""
65+
Visit the expression with a postorder traversal, yielding all hits.
66+
"""
6167
for i in self._next(expr):
6268
yield from self.visit_postorder(i)
6369
if self.query(expr):
6470
yield expr
6571

66-
def visit_preorder(self, expr) -> Iterator[Any]:
72+
def visit_preorder(self, expr: Expression) -> Iterator[Expression]:
73+
"""
74+
Visit the expression with a preorder traversal, yielding all hits.
75+
"""
6776
if self.query(expr):
6877
yield expr
6978
for i in self._next(expr):
7079
yield from self.visit_preorder(i)
7180

72-
def visit_preorder_first_hit(self, expr) -> tuple[Any, ...]:
73-
"""Visit the expression in preorder and return the first hit."""
81+
def visit_preorder_first_hit(self, expr: Expression) -> Iterator[Expression]:
82+
"""
83+
Visit the expression in preorder and return a tuple containing the first hit,
84+
if any. This can return more than a single result, as it looks for the first
85+
hit from any branch but may find a hit in multiple branches.
86+
"""
7487
if self.query(expr):
75-
return (expr,)
88+
yield expr
89+
return
7690
for i in self._next(expr):
77-
result = self.visit_preorder_first_hit(i)
78-
if result:
79-
return result
80-
return ()
81-
91+
yield from self.visit_preorder_first_hit(i)
8292

8393

84-
def search(exprs,
94+
def search(exprs: Expression | Iterable[Expression],
8595
query: type | Callable[[Any], bool],
86-
mode: Literal['all', 'unique'] = 'unique',
96+
mode: Mode = 'unique',
8797
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
8898
deep: bool = False) -> List | Set:
8999
"""Interface to Search."""
@@ -98,13 +108,12 @@ def search(exprs,
98108
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
99109
# is retained in this function's parameters for backwards compatibility
100110
searcher = Search(Q, deep)
101-
102111
if visit == 'dfs':
103-
_visit = searcher.visit_postorder
112+
_search = searcher.visit_postorder
104113
elif visit == 'bfs':
105-
_visit = searcher.visit_preorder
114+
_search = searcher.visit_preorder
106115
elif visit == 'bfs_first_hit':
107-
_visit = searcher.visit_preorder_first_hit
116+
_search = searcher.visit_preorder_first_hit
108117
else:
109118
raise ValueError(f"Unknown visit mode '{visit}'")
110119

@@ -113,7 +122,7 @@ def search(exprs,
113122
if not isinstance(e, sympy.Basic):
114123
continue
115124

116-
found.update(_visit(e))
125+
found.update(_search(e))
117126

118127
return found
119128

0 commit comments

Comments
 (0)