Skip to content

Commit 508d41c

Browse files
committed
LazyVisitor generic typing
1 parent 82a11c9 commit 508d41c

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import cgen as c
1313
from sympy import IndexedBase
1414
from sympy.core.function import Application
15-
from typing import Any
15+
from typing import Any, Generic, TypeVar
1616

1717
from devito.exceptions import CompilationError
1818
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
@@ -58,7 +58,8 @@ def always_rebuild(self, o, *args, **kwargs):
5858
return o._rebuild(*new_ops, **okwargs)
5959

6060

61-
class LazyVisitor(GenericVisitor):
61+
TResult = TypeVar('TResult')
62+
class LazyVisitor(GenericVisitor, Generic[TResult]):
6263

6364
"""
6465
A generic visitor that lazily yields results instead of flattening results
@@ -79,7 +80,7 @@ def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
7980
meth = self.lookup_method(o)
8081
yield from meth(o, *args, **kwargs)
8182

82-
def _post_visit(self, ret: Iterator[Any]) -> list[Any]:
83+
def _post_visit(self, ret: Iterator[Any]) -> TResult:
8384
"""Postprocess the visitor output before returning it to the caller."""
8485
return list(ret)
8586

@@ -1010,7 +1011,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10101011
return ret
10111012

10121013

1013-
class FindSymbols(LazyVisitor):
1014+
class FindSymbols(LazyVisitor[list[Any]]):
10141015

10151016
"""
10161017
Find symbols in an Iteration/Expression tree.
@@ -1061,7 +1062,7 @@ def __init__(self, mode: str = 'symbolics') -> None:
10611062
else:
10621063
self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes])
10631064

1064-
def _post_visit(self, ret: Iterable[Any]) -> Iterable[Any]:
1065+
def _post_visit(self, ret):
10651066
return sorted(filter_ordered(ret, key = id), key=str)
10661067

10671068
def visit_tuple(self, o: Sequence[Any]) -> Iterator[Any]:
@@ -1088,7 +1089,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
10881089
yield from self._visit(i)
10891090

10901091

1091-
class FindNodes(LazyVisitor):
1092+
class FindNodes(LazyVisitor[list[Node]]):
10921093

10931094
"""
10941095
Find all instances of given type.
@@ -1104,12 +1105,12 @@ class FindNodes(LazyVisitor):
11041105
appears.
11051106
"""
11061107

1107-
rules = {
1108+
rules: dict[str, Callable[[type, Node], bool]] = {
11081109
'type': lambda match, o: isinstance(o, match),
11091110
'scope': lambda match, o: match in flatten(o.children)
11101111
}
11111112

1112-
def __init__(self, match, mode='type'):
1113+
def __init__(self, match: type, mode: str = 'type'):
11131114
super().__init__()
11141115
self.match = match
11151116
self.rule = self.rules[mode]

0 commit comments

Comments
 (0)