Skip to content

Commit 9181005

Browse files
committed
Lazy FindApplications
1 parent 508d41c commit 9181005

1 file changed

Lines changed: 18 additions & 28 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,53 +1128,43 @@ def visit_Node(self, o: Node) -> Iterator[Any]:
11281128
yield from self._visit(i)
11291129

11301130

1131-
class FindApplications(Visitor):
1131+
TApp = TypeVar('TApp')
1132+
class FindApplications(LazyVisitor[set[TApp]]):
11321133

11331134
"""
11341135
Find all SymPy applied functions (aka, `Application`s). The user may refine
11351136
the search by supplying a different target class.
11361137
"""
11371138

1138-
def __init__(self, cls=Application):
1139+
def __init__(self, cls: type[TApp] = Application):
11391140
super().__init__()
11401141
self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic)
11411142

1142-
@classmethod
1143-
def default_retval(cls):
1144-
return set()
1145-
1146-
def visit_object(self, o, **kwargs):
1147-
return self.default_retval()
1143+
def _post_visit(self, ret):
1144+
return set(ret)
11481145

1149-
def visit_tuple(self, o, ret=None):
1150-
ret = ret or self.default_retval()
1146+
def visit_tuple(self, o: Sequence[Any]) -> Iterator[TApp]:
11511147
for i in o:
1152-
ret.update(self._visit(i, ret=ret))
1153-
return ret
1148+
yield from self._visit(i)
11541149

1155-
def visit_Node(self, o, ret=None):
1156-
ret = ret or self.default_retval()
1150+
def visit_Node(self, o: Node) -> Iterator[TApp]:
11571151
for i in o.children:
1158-
ret.update(self._visit(i, ret=ret))
1159-
return ret
1152+
yield from self._visit(i)
11601153

1161-
def visit_Expression(self, o, **kwargs):
1162-
return o.expr.find(self.match)
1154+
def visit_Expression(self, o: Expression, **kwargs) -> Iterator[TApp]:
1155+
yield from set(o.expr.find(self.match))
11631156

1164-
def visit_Iteration(self, o, **kwargs):
1165-
ret = self._visit(o.children) or self.default_retval()
1166-
ret.update(o.symbolic_min.find(self.match))
1167-
ret.update(o.symbolic_max.find(self.match))
1168-
return ret
1157+
def visit_Iteration(self, o: Iteration, **kwargs) -> Iterator[TApp]:
1158+
yield from self._visit(o.children)
1159+
yield from set(o.symbolic_min.find(self.match))
1160+
yield from set(o.symbolic_max.find(self.match))
11691161

1170-
def visit_Call(self, o, **kwargs):
1171-
ret = self.default_retval()
1162+
def visit_Call(self, o: Call, **kwargs) -> Iterator[TApp]:
11721163
for i in o.arguments:
11731164
try:
1174-
ret.update(i.find(self.match))
1165+
yield from set(i.find(self.match))
11751166
except (AttributeError, TypeError):
1176-
ret.update(self._visit(i, ret=ret))
1177-
return ret
1167+
yield from self._visit(i)
11781168

11791169

11801170
class IsPerfectIteration(Visitor):

0 commit comments

Comments
 (0)