Skip to content

Commit f0bc509

Browse files
committed
misc: Fix FindWithin (again)
1 parent 48373f5 commit f0bc509

1 file changed

Lines changed: 43 additions & 16 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,36 +1131,63 @@ class FindWithin(FindNodes):
11311131
collecting matching nodes after `stop` is found.
11321132
"""
11331133

1134-
# Dummy object to signal the end of the search
1135-
STOP = object()
1134+
# Sentinel values to signal the start/end of a matching window
1135+
SET_FLAG = object()
1136+
UNSET_FLAG = object()
11361137

11371138
def __init__(self, match: type, start: Node, stop: Node | None = None) -> None:
11381139
super().__init__(match)
11391140
self.start = start
11401141
self.stop = stop
11411142

1142-
def _post_visit(self, ret: Iterator[Node]) -> list[Node]:
1143-
ret = super()._post_visit(ret)
1144-
if ret and ret[-1] is self.STOP:
1145-
ret.pop()
1146-
return ret
1143+
def _post_visit(self, ret: Iterator[Node | object]) -> list[Node]:
1144+
return super()._post_visit(i for i in ret
1145+
if i not in (self.SET_FLAG, self.UNSET_FLAG))
1146+
1147+
def visit_object(self, o: object, flag: bool = False) -> Iterator[Node | object]:
1148+
yield self.SET_FLAG if flag else self.UNSET_FLAG
1149+
1150+
def visit_tuple(self, o: Sequence[Any], flag: bool = False) -> Iterator[Node | object]:
1151+
for el in o:
1152+
for i in self._visit(el, flag=flag):
1153+
# New flag state is yielded at the end of child results
1154+
if i is self.SET_FLAG:
1155+
flag = True
1156+
continue
1157+
if i is self.UNSET_FLAG:
1158+
flag = False
1159+
continue
1160+
1161+
# Regular object
1162+
yield i
11471163

1148-
def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node]:
1149-
if o is self.start:
1150-
flag = True
1164+
yield self.SET_FLAG if flag else self.UNSET_FLAG
1165+
1166+
visit_list = visit_tuple
1167+
1168+
def visit_Node(self, o: Node, flag: bool = False) -> Iterator[Node | object]:
1169+
flag = flag or (o is self.start)
11511170

11521171
if flag and self.rule(self.match, o):
11531172
yield o
1173+
11541174
for child in o.children:
11551175
for i in self._visit(child, flag=flag):
1156-
if flag and i is self.STOP:
1157-
yield self.STOP
1158-
return
1159-
1176+
# New flag state is yielded at the end of child results
1177+
if i is self.SET_FLAG:
1178+
flag = True
1179+
continue
1180+
if i is self.UNSET_FLAG:
1181+
if flag:
1182+
yield self.UNSET_FLAG
1183+
return
1184+
continue
1185+
1186+
# Regular object
11601187
yield i
11611188

1162-
if flag and o is self.stop:
1163-
yield self.STOP
1189+
flag &= (o is not self.stop)
1190+
yield self.SET_FLAG if flag else self.UNSET_FLAG
11641191

11651192

11661193
ApplicationType = TypeVar('ApplicationType')

0 commit comments

Comments
 (0)