1212import cgen as c
1313from sympy import IndexedBase
1414from sympy .core .function import Application
15- from typing import Any
15+ from typing import Any , Generic , TypeVar
1616
1717from devito .exceptions import CompilationError
1818from devito .ir .iet .nodes import (Node , Iteration , Expression , ExpressionBundle ,
@@ -58,7 +58,8 @@ def always_rebuild(self, o, *args, **kwargs):
5858 return o ._rebuild (* new_ops , ** okwargs )
5959
6060
61- class LazyVisitor (GenericVisitor ):
61+ TResult = TypeVar ('TResult' )
62+ class LazyVisitor (GenericVisitor , Generic [TResult ]):
6263
6364 """
6465 A generic visitor that lazily yields results instead of flattening results
@@ -79,7 +80,7 @@ def _visit(self, o, *args, **kwargs) -> Iterator[Any]:
7980 meth = self .lookup_method (o )
8081 yield from meth (o , * args , ** kwargs )
8182
82- def _post_visit (self , ret : Iterator [Any ]) -> list [ Any ] :
83+ def _post_visit (self , ret : Iterator [Any ]) -> TResult :
8384 """Postprocess the visitor output before returning it to the caller."""
8485 return list (ret )
8586
@@ -1010,7 +1011,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10101011 return ret
10111012
10121013
1013- class FindSymbols (LazyVisitor ):
1014+ class FindSymbols (LazyVisitor [ list [ Any ]] ):
10141015
10151016 """
10161017 Find symbols in an Iteration/Expression tree.
@@ -1061,7 +1062,7 @@ def __init__(self, mode: str = 'symbolics') -> None:
10611062 else :
10621063 self .rule = lambda n : chain (* [self .rules [mode ](n ) for mode in modes ])
10631064
1064- def _post_visit (self , ret : Iterable [ Any ]) -> Iterable [ Any ] :
1065+ def _post_visit (self , ret ) :
10651066 return sorted (filter_ordered (ret , key = id ), key = str )
10661067
10671068 def visit_tuple (self , o : Sequence [Any ]) -> Iterator [Any ]:
@@ -1088,7 +1089,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
10881089 yield from self ._visit (i )
10891090
10901091
1091- class FindNodes (LazyVisitor ):
1092+ class FindNodes (LazyVisitor [ list [ Node ]] ):
10921093
10931094 """
10941095 Find all instances of given type.
@@ -1104,12 +1105,12 @@ class FindNodes(LazyVisitor):
11041105 appears.
11051106 """
11061107
1107- rules = {
1108+ rules : dict [ str , Callable [[ type , Node ], bool ]] = {
11081109 'type' : lambda match , o : isinstance (o , match ),
11091110 'scope' : lambda match , o : match in flatten (o .children )
11101111 }
11111112
1112- def __init__ (self , match , mode = 'type' ):
1113+ def __init__ (self , match : type , mode : str = 'type' ):
11131114 super ().__init__ ()
11141115 self .match = match
11151116 self .rule = self .rules [mode ]
0 commit comments