Skip to content

Commit d6a722f

Browse files
committed
Squashed commit of the following:
commit 2079dda Author: enwask <enwask@ucf.edu> Date: Fri Jun 6 10:34:58 2025 +0200 Remove FindApplications.visit_Node override for coverage commit 60aaf00 Author: enwask <enwask@ucf.edu> Date: Thu Jun 5 00:17:10 2025 +0200 compiler: Lazy IET visitors commit 167d26d Author: enwask <enwask@ucf.edu> Date: Mon Jun 2 12:09:46 2025 +0200 Fix FindSymbols commit 871e42a Author: enwask <enwask@ucf.edu> Date: Sun Jun 1 23:59:20 2025 +0200 Remove TypeAlias for Python 3.9 commit 38230cf Author: enwask <enwask@ucf.edu> Date: Sat May 31 19:22:04 2025 +0100 Fix all the tests failing commit f7bdfbd Author: enwask <enwask@ucf.edu> Date: Fri May 30 16:27:12 2025 +0100 Remove redundant key=id, rename type parameters commit 2b13898 Author: enwask <enwask@ucf.edu> Date: Fri May 30 15:56:10 2025 +0100 Move typing import, add RulesDict type aliases commit 82cc017 Author: enwask <enwask@ucf.edu> Date: Fri May 30 12:52:22 2025 +0100 Move visit_tuple up + remove redundant set casts commit 4d1b64b Author: enwask <enwask@ucf.edu> Date: Fri May 30 11:38:02 2025 +0100 Formatting fixes commit 9181005 Author: enwask <enwask@ucf.edu> Date: Thu May 29 11:25:35 2025 +0100 Lazy FindApplications commit 508d41c Author: enwask <enwask@ucf.edu> Date: Thu May 29 11:23:53 2025 +0100 LazyVisitor generic typing commit 82a11c9 Author: enwask <enwask@ucf.edu> Date: Wed May 28 18:30:40 2025 +0100 Cleanup + faster FindSymbols commit 3d9fe74 Author: enwask <enwask@ucf.edu> Date: Wed May 28 17:25:51 2025 +0100 FindSymbols tweaks commit e56e768 Author: enwask <enwask@ucf.edu> Date: Wed May 28 12:03:16 2025 +0100 Lazy FindNodes commit de94ac3 Author: enwask <enwask@ucf.edu> Date: Tue May 27 18:57:18 2025 +0100 Lazy FindSymbols commit 921125b Author: enwask <enwask@ucf.edu> Date: Tue May 27 18:55:49 2025 +0100 Add LazyVisitor
1 parent 14ff62a commit d6a722f

1 file changed

Lines changed: 91 additions & 94 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 91 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
"""
66

77
from collections import OrderedDict
8-
from collections.abc import Iterable
8+
from collections.abc import Callable, Iterable, Iterator, Sequence
99
from itertools import chain, groupby
10+
from typing import Any, Generic, TypeVar
1011
import ctypes
1112

1213
import cgen as c
@@ -58,6 +59,47 @@ def always_rebuild(self, o, *args, **kwargs):
5859
return o._rebuild(*new_ops, **okwargs)
5960

6061

62+
ResultType = TypeVar('ResultType')
63+
64+
65+
class LazyVisitor(GenericVisitor, Generic[ResultType]):
66+
67+
"""
68+
A generic visitor that lazily yields results instead of flattening results
69+
from children at every step.
70+
71+
Subclass-defined visit methods (and default_retval) should be generators.
72+
"""
73+
74+
@classmethod
75+
def default_retval(cls) -> Iterator[Any]:
76+
yield from ()
77+
78+
def lookup_method(self, instance) -> Callable[..., Iterator[Any]]:
79+
return super().lookup_method(instance)
80+
81+
def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
82+
"""Visit `o`."""
83+
meth = self.lookup_method(o)
84+
yield from meth(o, *args, **kwargs)
85+
86+
def _post_visit(self, ret: Iterator[Any]) -> ResultType:
87+
"""Postprocess the visitor output before returning it to the caller."""
88+
return list(ret)
89+
90+
def visit_object(self, o: object, **kwargs) -> Iterator[Any]:
91+
yield from self.default_retval()
92+
93+
def visit_Node(self, o: Node, **kwargs) -> Iterator[Any]:
94+
yield from self._visit(o.children, **kwargs)
95+
96+
def visit_tuple(self, o: Sequence[Any]) -> Iterator[Any]:
97+
for i in o:
98+
yield from self._visit(i)
99+
100+
visit_list = visit_tuple
101+
102+
61103
class PrintAST(Visitor):
62104

63105
_depth = 0
@@ -978,16 +1020,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
9781020
return ret
9791021

9801022

981-
class FindSymbols(Visitor):
982-
983-
class Retval(list):
984-
def __init__(self, *retvals):
985-
elements = filter_ordered(flatten(retvals), key=id)
986-
super().__init__(elements)
987-
988-
@classmethod
989-
def default_retval(cls):
990-
return cls.Retval()
1023+
class FindSymbols(LazyVisitor[list[Any]]):
9911024

9921025
"""
9931026
Find symbols in an Iteration/Expression tree.
@@ -1007,31 +1040,30 @@ def default_retval(cls):
10071040
"""
10081041

10091042
def _defines_aliases(n):
1010-
retval = []
10111043
for i in n.defines:
10121044
f = i.function
10131045
if f.is_ArrayBasic:
1014-
retval.extend([f, f.indexed])
1046+
yield from (f, f.indexed)
10151047
else:
1016-
retval.append(i)
1017-
return tuple(retval)
1048+
yield i
10181049

1019-
rules = {
1050+
RulesDict = dict[str, Callable[[Node], Iterator[Any]]]
1051+
rules: RulesDict = {
10201052
'symbolics': lambda n: n.functions,
1021-
'basics': lambda n: [i for i in n.expr_symbols if isinstance(i, Basic)],
1022-
'symbols': lambda n: [i for i in n.expr_symbols
1023-
if isinstance(i, AbstractSymbol)],
1024-
'dimensions': lambda n: [i for i in n.expr_symbols if isinstance(i, Dimension)],
1025-
'indexeds': lambda n: [i for i in n.expr_symbols if i.is_Indexed],
1026-
'indexedbases': lambda n: [i for i in n.expr_symbols
1027-
if isinstance(i, IndexedBase)],
1053+
'basics': lambda n: (i for i in n.expr_symbols if isinstance(i, Basic)),
1054+
'symbols': lambda n: (i for i in n.expr_symbols
1055+
if isinstance(i, AbstractSymbol)),
1056+
'dimensions': lambda n: (i for i in n.expr_symbols if isinstance(i, Dimension)),
1057+
'indexeds': lambda n: (i for i in n.expr_symbols if i.is_Indexed),
1058+
'indexedbases': lambda n: (i for i in n.expr_symbols
1059+
if isinstance(i, IndexedBase)),
10281060
'writes': lambda n: as_tuple(n.writes),
10291061
'defines': lambda n: as_tuple(n.defines),
1030-
'globals': lambda n: [f.base for f in n.functions if f._mem_global],
1062+
'globals': lambda n: (f.base for f in n.functions if f._mem_global),
10311063
'defines-aliases': _defines_aliases
10321064
}
10331065

1034-
def __init__(self, mode='symbolics'):
1066+
def __init__(self, mode: str = 'symbolics') -> None:
10351067
super().__init__()
10361068

10371069
modes = mode.split('|')
@@ -1041,33 +1073,27 @@ def __init__(self, mode='symbolics'):
10411073
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes])
10421074

10431075
def _post_visit(self, ret):
1044-
return sorted(ret, key=lambda i: str(i))
1076+
return sorted(filter_ordered(ret, key=id), key=str)
10451077

1046-
def visit_tuple(self, o):
1047-
return self.Retval(*[self._visit(i) for i in o])
1048-
1049-
visit_list = visit_tuple
1078+
def visit_Node(self, o: Node) -> Iterator[Any]:
1079+
yield from self._visit(o.children)
1080+
yield from self.rule(o)
10501081

1051-
def visit_Node(self, o):
1052-
return self.Retval(self._visit(o.children), self.rule(o))
1053-
1054-
def visit_ThreadedProdder(self, o):
1082+
def visit_ThreadedProdder(self, o) -> Iterator[Any]:
10551083
# TODO: this handle required because ThreadedProdder suffers from the
10561084
# long-standing issue affecting all Node subclasses which rely on
10571085
# multiple inheritance
1058-
return self.Retval(self._visit(o.then_body), self.rule(o))
1059-
1060-
def visit_Operator(self, o):
1061-
ret = self._visit(o.body)
1062-
ret.extend(flatten(self._visit(v) for v in o._func_table.values()))
1063-
return self.Retval(ret, self.rule(o))
1086+
yield from self._visit(o.then_body)
1087+
yield from self.rule(o)
10641088

1089+
def visit_Operator(self, o) -> Iterator[Any]:
1090+
yield from self._visit(o.body)
1091+
yield from self.rule(o)
1092+
for i in o._func_table.values():
1093+
yield from self._visit(i)
10651094

1066-
class FindNodes(Visitor):
10671095

1068-
@classmethod
1069-
def default_retval(cls):
1070-
return []
1096+
class FindNodes(LazyVisitor[list[Node]]):
10711097

10721098
"""
10731099
Find all instances of given type.
@@ -1083,34 +1109,22 @@ def default_retval(cls):
10831109
appears.
10841110
"""
10851111

1086-
rules = {
1112+
RulesDict = dict[str, Callable[[type, Node], bool]]
1113+
rules: RulesDict = {
10871114
'type': lambda match, o: isinstance(o, match),
10881115
'scope': lambda match, o: match in flatten(o.children)
10891116
}
10901117

1091-
def __init__(self, match, mode='type'):
1118+
def __init__(self, match: type, mode: str = 'type'):
10921119
super().__init__()
10931120
self.match = match
10941121
self.rule = self.rules[mode]
10951122

1096-
def visit_object(self, o, ret=None):
1097-
return ret
1098-
1099-
def visit_tuple(self, o, ret=None):
1100-
for i in o:
1101-
ret = self._visit(i, ret=ret)
1102-
return ret
1103-
1104-
visit_list = visit_tuple
1105-
1106-
def visit_Node(self, o, ret=None):
1107-
if ret is None:
1108-
ret = self.default_retval()
1123+
def visit_Node(self, o: Node) -> Iterator[Any]:
11091124
if self.rule(self.match, o):
1110-
ret.append(o)
1125+
yield o
11111126
for i in o.children:
1112-
ret = self._visit(i, ret=ret)
1113-
return ret
1127+
yield from self._visit(i)
11141128

11151129

11161130
class FindWithin(FindNodes):
@@ -1156,53 +1170,36 @@ def visit_Node(self, o, ret=None):
11561170
return found, flag
11571171

11581172

1159-
class FindApplications(Visitor):
1173+
ApplicationType = TypeVar('ApplicationType')
1174+
11601175

1176+
class FindApplications(LazyVisitor[set[ApplicationType]]):
11611177
"""
11621178
Find all SymPy applied functions (aka, `Application`s). The user may refine
11631179
the search by supplying a different target class.
11641180
"""
11651181

1166-
def __init__(self, cls=Application):
1182+
def __init__(self, cls: type[ApplicationType] = Application):
11671183
super().__init__()
11681184
self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic)
11691185

1170-
@classmethod
1171-
def default_retval(cls):
1172-
return set()
1173-
1174-
def visit_object(self, o, **kwargs):
1175-
return self.default_retval()
1176-
1177-
def visit_tuple(self, o, ret=None):
1178-
ret = ret or self.default_retval()
1179-
for i in o:
1180-
ret.update(self._visit(i, ret=ret))
1181-
return ret
1182-
1183-
def visit_Node(self, o, ret=None):
1184-
ret = ret or self.default_retval()
1185-
for i in o.children:
1186-
ret.update(self._visit(i, ret=ret))
1187-
return ret
1186+
def _post_visit(self, ret):
1187+
return set(ret)
11881188

1189-
def visit_Expression(self, o, **kwargs):
1190-
return o.expr.find(self.match)
1189+
def visit_Expression(self, o: Expression, **kwargs) -> Iterator[ApplicationType]:
1190+
yield from o.expr.find(self.match)
11911191

1192-
def visit_Iteration(self, o, **kwargs):
1193-
ret = self._visit(o.children) or self.default_retval()
1194-
ret.update(o.symbolic_min.find(self.match))
1195-
ret.update(o.symbolic_max.find(self.match))
1196-
return ret
1192+
def visit_Iteration(self, o: Iteration, **kwargs) -> Iterator[ApplicationType]:
1193+
yield from self._visit(o.children)
1194+
yield from o.symbolic_min.find(self.match)
1195+
yield from o.symbolic_max.find(self.match)
11971196

1198-
def visit_Call(self, o, **kwargs):
1199-
ret = self.default_retval()
1197+
def visit_Call(self, o: Call, **kwargs) -> Iterator[ApplicationType]:
12001198
for i in o.arguments:
12011199
try:
1202-
ret.update(i.find(self.match))
1200+
yield from i.find(self.match)
12031201
except (AttributeError, TypeError):
1204-
ret.update(self._visit(i, ret=ret))
1205-
return ret
1202+
yield from self._visit(i)
12061203

12071204

12081205
class IsPerfectIteration(Visitor):

0 commit comments

Comments
 (0)