Skip to content

Commit d69462b

Browse files
committed
Slightly faster search
1 parent 8e1bc1b commit d69462b

1 file changed

Lines changed: 9 additions & 11 deletions

File tree

devito/symbolics/search.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable, Iterable, Iterator
2+
from itertools import chain
23
from typing import Any, Literal
34

45
import numpy as np
@@ -54,11 +55,12 @@ def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> N
5455
self.query = query
5556
self.deep = deep
5657

57-
def _next(self, expr: Expression) -> Iterator[Expression]:
58+
def _next(self, expr: Expression) -> Iterable[Expression]:
5859
if self.deep and expr.is_Indexed:
59-
yield from expr.indices
60-
elif not q_leaf(expr):
61-
yield from expr.args
60+
return expr.indices
61+
elif q_leaf(expr):
62+
return ()
63+
return expr.args
6264

6365
def visit_postorder(self, expr: Expression) -> Iterator[Expression]:
6466
"""
@@ -116,13 +118,9 @@ def search(exprs: Expression | Iterable[Expression],
116118
_search = searcher.visit_preorder_first_hit
117119
else:
118120
raise ValueError(f"Unknown visit mode '{visit}'")
119-
120-
found = modes[mode]()
121-
for e in as_tuple(exprs):
122-
if not isinstance(e, sympy.Basic):
123-
continue
124-
125-
found.update(_search(e))
121+
122+
exprs = filter(lambda e: isinstance(e, sympy.Basic), as_tuple(exprs))
123+
found = modes[mode](chain(*map(_search, exprs)))
126124

127125
return found
128126

0 commit comments

Comments
 (0)