77from collections import OrderedDict
88from collections .abc import Callable , Iterable , Iterator , Sequence
99from itertools import chain , groupby
10+ from typing import Any , Generic , TypeAlias , TypeVar
1011import ctypes
1112
1213import cgen as c
1314from sympy import IndexedBase
1415from sympy .core .function import Application
15- from typing import Any , Generic , TypeVar
1616
1717from devito .exceptions import CompilationError
1818from devito .ir .iet .nodes import (Node , Iteration , Expression , ExpressionBundle ,
@@ -1046,7 +1046,8 @@ def _defines_aliases(n):
10461046 else :
10471047 yield i
10481048
1049- rules = {
1049+ RulesDict : TypeAlias = dict [str , Callable [[Node ], Iterator [Any ]]]
1050+ rules : RulesDict = {
10501051 'symbolics' : lambda n : n .functions ,
10511052 'basics' : lambda n : (i for i in n .expr_symbols if isinstance (i , Basic )),
10521053 'symbols' : lambda n : (i for i in n .expr_symbols
@@ -1068,7 +1069,7 @@ def __init__(self, mode: str = 'symbolics') -> None:
10681069 if len (modes ) == 1 :
10691070 self .rule = self .rules [mode ]
10701071 else :
1071- self .rule = lambda n : chain (* [ self .rules [mode ](n ) for mode in modes ] )
1072+ self .rule = lambda n : chain (self .rules [mode ](n ) for mode in modes )
10721073
10731074 def _post_visit (self , ret ):
10741075 return sorted (filter_ordered (ret , key = id ), key = str )
@@ -1107,7 +1108,8 @@ class FindNodes(LazyVisitor[list[Node]]):
11071108 appears.
11081109 """
11091110
1110- rules : dict [str , Callable [[type , Node ], bool ]] = {
1111+ RulesDict : TypeAlias = dict [str , Callable [[type , Node ], bool ]]
1112+ rules : RulesDict = {
11111113 'type' : lambda match , o : isinstance (o , match ),
11121114 'scope' : lambda match , o : match in flatten (o .children )
11131115 }
0 commit comments