@@ -59,10 +59,11 @@ def always_rebuild(self, o, *args, **kwargs):
5959 return o ._rebuild (* new_ops , ** okwargs )
6060
6161
62- ResultType = TypeVar ('ResultType' )
62+ YieldType = TypeVar ('YieldType' , covariant = True )
63+ ResultType = TypeVar ('ResultType' , covariant = True )
6364
6465
65- class LazyVisitor (GenericVisitor , Generic [ResultType ]):
66+ class LazyVisitor (GenericVisitor , Generic [YieldType , ResultType ]):
6667
6768 """
6869 A generic visitor that lazily yields results instead of flattening results
@@ -71,25 +72,25 @@ class LazyVisitor(GenericVisitor, Generic[ResultType]):
7172 Subclass-defined visit methods should be generators.
7273 """
7374
74- def lookup_method (self , instance ) -> Callable [..., Iterator [Any ]]:
75+ def lookup_method (self , instance ) -> Callable [..., Iterator [YieldType ]]:
7576 return super ().lookup_method (instance )
7677
77- def _visit (self , o , * args , ** kwargs ) -> Iterator [Any ]:
78+ def _visit (self , o , * args , ** kwargs ) -> Iterator [YieldType ]:
7879 meth = self .lookup_method (o )
7980 yield from meth (o , * args , ** kwargs )
8081
81- def _post_visit (self , ret : Iterator [Any ]) -> ResultType :
82+ def _post_visit (self , ret : Iterator [YieldType ]) -> ResultType :
8283 return list (ret )
8384
84- def visit_object (self , o : object , ** kwargs ) -> Iterator [Any ]:
85+ def visit_object (self , o : object , ** kwargs ) -> Iterator [YieldType ]:
8586 yield from ()
8687
87- def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Any ]:
88+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [YieldType ]:
8889 yield from self ._visit (o .children , ** kwargs )
8990
90- def visit_tuple (self , o : Sequence [Any ]) -> Iterator [Any ]:
91+ def visit_tuple (self , o : Sequence [Any ], ** kwargs ) -> Iterator [YieldType ]:
9192 for i in o :
92- yield from self ._visit (i )
93+ yield from self ._visit (i , ** kwargs )
9394
9495 visit_list = visit_tuple
9596
@@ -1014,7 +1015,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
10141015 return ret
10151016
10161017
1017- class FindSymbols (LazyVisitor [list [Any ]]):
1018+ class FindSymbols (LazyVisitor [Any , list [Any ]]):
10181019
10191020 """
10201021 Find symbols in an Iteration/Expression tree.
@@ -1088,7 +1089,7 @@ def visit_Operator(self, o) -> Iterator[Any]:
10881089 yield from self ._visit (i )
10891090
10901091
1091- class FindNodes (LazyVisitor [list [Node ]]):
1092+ class FindNodes (LazyVisitor [Node , list [Node ]]):
10921093
10931094 """
10941095 Find all instances of given type.
@@ -1110,65 +1111,61 @@ class FindNodes(LazyVisitor[list[Node]]):
11101111 'scope' : lambda match , o : match in flatten (o .children )
11111112 }
11121113
1113- def __init__ (self , match : type , mode : str = 'type' ):
1114+ def __init__ (self , match : type , mode : str = 'type' ) -> None :
11141115 super ().__init__ ()
11151116 self .match = match
11161117 self .rule = self .rules [mode ]
11171118
1118- def visit_Node (self , o : Node ) -> Iterator [Any ]:
1119+ def visit_Node (self , o : Node , ** kwargs ) -> Iterator [Node ]:
11191120 if self .rule (self .match , o ):
11201121 yield o
11211122 for i in o .children :
1122- yield from self ._visit (i )
1123+ yield from self ._visit (i , ** kwargs )
11231124
11241125
11251126class FindWithin (FindNodes ):
11261127
1127- @classmethod
1128- def default_retval (cls ):
1129- return [], False
1130-
11311128 """
11321129 Like FindNodes, but given an additional parameter `within=(start, stop)`,
11331130 it starts collecting matching nodes only after `start` is found, and stops
11341131 collecting matching nodes after `stop` is found.
11351132 """
11361133
1137- def __init__ (self , match , start , stop = None ):
1134+ # Dummy object to signal the end of the search
1135+ STOP = object ()
1136+
1137+ def __init__ (self , match : type , start : Node , stop : Node | None = None ) -> None :
11381138 super ().__init__ (match )
11391139 self .start = start
11401140 self .stop = stop
11411141
1142- def visit (self , o , ret = None ):
1143- found , _ = self ._visit (o , ret = ret )
1144- return found
1145-
1146- def visit_Node (self , o , ret = None ):
1147- if ret is None :
1148- ret = self .default_retval ()
1149- found , flag = ret
1142+ def _post_visit (self , ret : Iterator [Node ]) -> list [Node ]:
1143+ ret = super ()._post_visit (ret )
1144+ if ret [- 1 ] is self .STOP :
1145+ ret .pop ()
1146+ return ret
11501147
1148+ def visit_Node (self , o : Node , flag : bool = False ) -> Iterator [Node ]:
11511149 if o is self .start :
11521150 flag = True
11531151
11541152 if flag and self .rule (self .match , o ):
1155- found . append ( o )
1153+ yield o
11561154 for i in o .children :
1157- found , newflag = self ._visit (i , ret = ( found , flag ))
1158- if flag and not newflag :
1159- return found , newflag
1160- flag = newflag
1155+ for r in self ._visit (i , flag = flag ):
1156+ yield r
1157+ if r is self . STOP :
1158+ return
11611159
11621160 if o is self .stop :
1163- flag = False
1164-
1165- return found , flag
1161+ yield self .STOP
11661162
11671163
11681164ApplicationType = TypeVar ('ApplicationType' )
11691165
11701166
1171- class FindApplications (LazyVisitor [set [ApplicationType ]]):
1167+ class FindApplications (LazyVisitor [ApplicationType , set [ApplicationType ]]):
1168+
11721169 """
11731170 Find all SymPy applied functions (aka, `Application`s). The user may refine
11741171 the search by supplying a different target class.
0 commit comments