55"""
66
77from collections import OrderedDict
8- from collections .abc import Iterable
8+ from collections .abc import Callable , Iterable , Iterator , Sequence
99from itertools import chain , groupby
10+ from typing import Any , Generic , TypeVar
1011import ctypes
1112
1213import cgen as c
@@ -58,6 +59,47 @@ def always_rebuild(self, o, *args, **kwargs):
5859 return o ._rebuild (* new_ops , ** okwargs )
5960
6061
62+ ResultType = TypeVar ('ResultType' )
63+
64+
65+ class LazyVisitor (GenericVisitor , Generic [ResultType ]):
66+
67+ """
68+ A generic visitor that lazily yields results instead of flattening results
69+ from children at every step.
70+
71+ Subclass-defined visit methods (and default_retval) should be generators.
72+ """
73+
74+ @classmethod
75+ def default_retval (cls ) -> Iterator [Any ]:
76+ yield from ()
77+
78+ def lookup_method (self , instance ) -> Callable [..., Iterator [Any ]]:
79+ return super ().lookup_method (instance )
80+
81+ def _visit (self , o , * args , ** kwargs ) -> Iterator [Any ]:
82+ """Visit `o`."""
83+ meth = self .lookup_method (o )
84+ yield from meth (o , * args , ** kwargs )
85+
86+ def _post_visit (self , ret : Iterator [Any ]) -> ResultType :
87+ """Postprocess the visitor output before returning it to the caller."""
88+ return list (ret )
89+
90+ def visit_object (self , o : object , ** kwargs ) -> Iterator [Any ]:
91+ yield from self .default_retval ()
92+
93+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Any ]:
94+ yield from self ._visit (o .children , ** kwargs )
95+
96+ def visit_tuple (self , o : Sequence [Any ]) -> Iterator [Any ]:
97+ for i in o :
98+ yield from self ._visit (i )
99+
100+ visit_list = visit_tuple
101+
102+
61103class PrintAST (Visitor ):
62104
63105 _depth = 0
@@ -978,16 +1020,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
9781020 return ret
9791021
9801022
981- class FindSymbols (Visitor ):
982-
983- class Retval (list ):
984- def __init__ (self , * retvals ):
985- elements = filter_ordered (flatten (retvals ), key = id )
986- super ().__init__ (elements )
987-
988- @classmethod
989- def default_retval (cls ):
990- return cls .Retval ()
1023+ class FindSymbols (LazyVisitor [list [Any ]]):
9911024
9921025 """
9931026 Find symbols in an Iteration/Expression tree.
@@ -1007,31 +1040,30 @@ def default_retval(cls):
10071040 """
10081041
10091042 def _defines_aliases (n ):
1010- retval = []
10111043 for i in n .defines :
10121044 f = i .function
10131045 if f .is_ArrayBasic :
1014- retval . extend ([ f , f .indexed ] )
1046+ yield from ( f , f .indexed )
10151047 else :
1016- retval .append (i )
1017- return tuple (retval )
1048+ yield i
10181049
1019- rules = {
1050+ RulesDict = dict [str , Callable [[Node ], Iterator [Any ]]]
1051+ rules : RulesDict = {
10201052 'symbolics' : lambda n : n .functions ,
1021- 'basics' : lambda n : [ i for i in n .expr_symbols if isinstance (i , Basic )] ,
1022- 'symbols' : lambda n : [ i for i in n .expr_symbols
1023- if isinstance (i , AbstractSymbol )] ,
1024- 'dimensions' : lambda n : [ i for i in n .expr_symbols if isinstance (i , Dimension )] ,
1025- 'indexeds' : lambda n : [ i for i in n .expr_symbols if i .is_Indexed ] ,
1026- 'indexedbases' : lambda n : [ i for i in n .expr_symbols
1027- if isinstance (i , IndexedBase )] ,
1053+ 'basics' : lambda n : ( i for i in n .expr_symbols if isinstance (i , Basic )) ,
1054+ 'symbols' : lambda n : ( i for i in n .expr_symbols
1055+ if isinstance (i , AbstractSymbol )) ,
1056+ 'dimensions' : lambda n : ( i for i in n .expr_symbols if isinstance (i , Dimension )) ,
1057+ 'indexeds' : lambda n : ( i for i in n .expr_symbols if i .is_Indexed ) ,
1058+ 'indexedbases' : lambda n : ( i for i in n .expr_symbols
1059+ if isinstance (i , IndexedBase )) ,
10281060 'writes' : lambda n : as_tuple (n .writes ),
10291061 'defines' : lambda n : as_tuple (n .defines ),
1030- 'globals' : lambda n : [ f .base for f in n .functions if f ._mem_global ] ,
1062+ 'globals' : lambda n : ( f .base for f in n .functions if f ._mem_global ) ,
10311063 'defines-aliases' : _defines_aliases
10321064 }
10331065
1034- def __init__ (self , mode = 'symbolics' ):
1066+ def __init__ (self , mode : str = 'symbolics' ) -> None :
10351067 super ().__init__ ()
10361068
10371069 modes = mode .split ('|' )
@@ -1041,33 +1073,27 @@ def __init__(self, mode='symbolics'):
10411073 self .rule = lambda n : chain (* [self .rules [mode ](n ) for mode in modes ])
10421074
10431075 def _post_visit (self , ret ):
1044- return sorted (ret , key = lambda i : str ( i ) )
1076+ return sorted (filter_ordered ( ret , key = id ), key = str )
10451077
1046- def visit_tuple (self , o ):
1047- return self .Retval (* [self ._visit (i ) for i in o ])
1048-
1049- visit_list = visit_tuple
1078+ def visit_Node (self , o : Node ) -> Iterator [Any ]:
1079+ yield from self ._visit (o .children )
1080+ yield from self .rule (o )
10501081
1051- def visit_Node (self , o ):
1052- return self .Retval (self ._visit (o .children ), self .rule (o ))
1053-
1054- def visit_ThreadedProdder (self , o ):
1082+ def visit_ThreadedProdder (self , o ) -> Iterator [Any ]:
10551083 # TODO: this handle required because ThreadedProdder suffers from the
10561084 # long-standing issue affecting all Node subclasses which rely on
10571085 # multiple inheritance
1058- return self .Retval (self ._visit (o .then_body ), self .rule (o ))
1059-
1060- def visit_Operator (self , o ):
1061- ret = self ._visit (o .body )
1062- ret .extend (flatten (self ._visit (v ) for v in o ._func_table .values ()))
1063- return self .Retval (ret , self .rule (o ))
1086+ yield from self ._visit (o .then_body )
1087+ yield from self .rule (o )
10641088
1089+ def visit_Operator (self , o ) -> Iterator [Any ]:
1090+ yield from self ._visit (o .body )
1091+ yield from self .rule (o )
1092+ for i in o ._func_table .values ():
1093+ yield from self ._visit (i )
10651094
1066- class FindNodes (Visitor ):
10671095
1068- @classmethod
1069- def default_retval (cls ):
1070- return []
1096+ class FindNodes (LazyVisitor [list [Node ]]):
10711097
10721098 """
10731099 Find all instances of given type.
@@ -1083,34 +1109,22 @@ def default_retval(cls):
10831109 appears.
10841110 """
10851111
1086- rules = {
1112+ RulesDict = dict [str , Callable [[type , Node ], bool ]]
1113+ rules : RulesDict = {
10871114 'type' : lambda match , o : isinstance (o , match ),
10881115 'scope' : lambda match , o : match in flatten (o .children )
10891116 }
10901117
1091- def __init__ (self , match , mode = 'type' ):
1118+ def __init__ (self , match : type , mode : str = 'type' ):
10921119 super ().__init__ ()
10931120 self .match = match
10941121 self .rule = self .rules [mode ]
10951122
1096- def visit_object (self , o , ret = None ):
1097- return ret
1098-
1099- def visit_tuple (self , o , ret = None ):
1100- for i in o :
1101- ret = self ._visit (i , ret = ret )
1102- return ret
1103-
1104- visit_list = visit_tuple
1105-
1106- def visit_Node (self , o , ret = None ):
1107- if ret is None :
1108- ret = self .default_retval ()
1123+ def visit_Node (self , o : Node ) -> Iterator [Any ]:
11091124 if self .rule (self .match , o ):
1110- ret . append ( o )
1125+ yield o
11111126 for i in o .children :
1112- ret = self ._visit (i , ret = ret )
1113- return ret
1127+ yield from self ._visit (i )
11141128
11151129
11161130class FindWithin (FindNodes ):
@@ -1156,53 +1170,36 @@ def visit_Node(self, o, ret=None):
11561170 return found , flag
11571171
11581172
1159- class FindApplications (Visitor ):
1173+ ApplicationType = TypeVar ('ApplicationType' )
1174+
11601175
1176+ class FindApplications (LazyVisitor [set [ApplicationType ]]):
11611177 """
11621178 Find all SymPy applied functions (aka, `Application`s). The user may refine
11631179 the search by supplying a different target class.
11641180 """
11651181
1166- def __init__ (self , cls = Application ):
1182+ def __init__ (self , cls : type [ ApplicationType ] = Application ):
11671183 super ().__init__ ()
11681184 self .match = lambda i : isinstance (i , cls ) and not isinstance (i , Basic )
11691185
1170- @classmethod
1171- def default_retval (cls ):
1172- return set ()
1173-
1174- def visit_object (self , o , ** kwargs ):
1175- return self .default_retval ()
1176-
1177- def visit_tuple (self , o , ret = None ):
1178- ret = ret or self .default_retval ()
1179- for i in o :
1180- ret .update (self ._visit (i , ret = ret ))
1181- return ret
1182-
1183- def visit_Node (self , o , ret = None ):
1184- ret = ret or self .default_retval ()
1185- for i in o .children :
1186- ret .update (self ._visit (i , ret = ret ))
1187- return ret
1186+ def _post_visit (self , ret ):
1187+ return set (ret )
11881188
1189- def visit_Expression (self , o , ** kwargs ):
1190- return o .expr .find (self .match )
1189+ def visit_Expression (self , o : Expression , ** kwargs ) -> Iterator [ ApplicationType ] :
1190+ yield from o .expr .find (self .match )
11911191
1192- def visit_Iteration (self , o , ** kwargs ):
1193- ret = self ._visit (o .children ) or self .default_retval ()
1194- ret .update (o .symbolic_min .find (self .match ))
1195- ret .update (o .symbolic_max .find (self .match ))
1196- return ret
1192+ def visit_Iteration (self , o : Iteration , ** kwargs ) -> Iterator [ApplicationType ]:
1193+ yield from self ._visit (o .children )
1194+ yield from o .symbolic_min .find (self .match )
1195+ yield from o .symbolic_max .find (self .match )
11971196
1198- def visit_Call (self , o , ** kwargs ):
1199- ret = self .default_retval ()
1197+ def visit_Call (self , o : Call , ** kwargs ) -> Iterator [ApplicationType ]:
12001198 for i in o .arguments :
12011199 try :
1202- ret . update ( i .find (self .match ) )
1200+ yield from i .find (self .match )
12031201 except (AttributeError , TypeError ):
1204- ret .update (self ._visit (i , ret = ret ))
1205- return ret
1202+ yield from self ._visit (i )
12061203
12071204
12081205class IsPerfectIteration (Visitor ):
0 commit comments