Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions self-host/checker.pith
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,22 @@ fn check_while_statement(node: Node, scope_id: Int):
# for statement
# ---------------------------------------------------------------

# If `iter_type` is an iterator (a struct with a `next() -> T?` method), return
# the element type T; otherwise return -1. Kept small and separate so the map
# reads compile reliably.
fn iterator_element_type(iter_type: Int) -> Int:
info := get_type_info(iter_type)
if info.kind != "struct":
return 0 - 1
next_key := get_type_name(iter_type) + ".next"
if method_type_map.contains_key(next_key) == false:
return 0 - 1
next_fn_info := get_type_info(method_type_map[next_key])
ret_info := get_type_info(next_fn_info.return_type)
if ret_info.kind != "optional":
return 0 - 1
return ret_info.inner

fn check_for_statement(node: Node, scope_id: Int):
# node.value: "x" or "x, i" (binding names)
# children[0]: iterable expression
Expand Down Expand Up @@ -1244,14 +1260,23 @@ fn check_for_statement(node: Node, scope_id: Int):
elif iter_info.kind == "set":
elem_type = iter_info.inner
else:
report_error("E217", "cannot iterate over " + get_type_name(iter_type))
return
# iterator protocol: a struct with a `next() -> T?` method is iterable,
# yielding T until next() returns none
it_elem := iterator_element_type(iter_type)
if it_elem >= 0:
elem_type = it_elem
else:
report_error("E217", "cannot iterate over " + get_type_name(iter_type))
return
# create loop scope
loop_scope := create_loop_scope(scope_id)
# bind the loop variable
if node.children.len() >= 3:
elem_name := get_node(node.children[2]).value
define_binding(loop_scope, elem_name, elem_type, false)
# cache the element type on the loop-var node so the ir emitter can
# read it back (used to type the value extracted from an iterator)
expression_type_cache.insert(node.children[2], elem_type)
if node.children.len() >= 4:
idx_name := get_node(node.children[3]).value
define_binding(loop_scope, idx_name, tid_int, false)
Expand Down
Binary file modified self-host/ir_driver
Binary file not shown.
64 changes: 64 additions & 0 deletions self-host/ir_emitter_core.pith
Original file line number Diff line number Diff line change
Expand Up @@ -4067,6 +4067,13 @@ fn ir_emit_for_stmt(node: Node):
if get_node(node.children[0]).kind == "range":
ir_emit_for_range_stmt(node, names)
return
# iterator protocol: a struct with a `next() -> T?` method drives the
# loop via next() instead of index-based access
iter_struct := ir_infer_type(node.children[0])
next_name := ir_lookup_impl_method_name(iter_struct, "next")
if next_name.len() > 0:
ir_emit_for_iterator_stmt(node, names, next_name)
return
iter_type := ir_infer_type(node.children[0])
mut iter_r := ir_expr(node.children[0])
item_kind := ir_for_iter_item_kind(iter_type)
Expand Down Expand Up @@ -4186,6 +4193,63 @@ fn ir_emit_for_range_stmt(node: Node, names: IrForBindingNames):
ir_emit("jmp " + head_l)
ir_emit("label " + end_l)

# `for x in it` where `it` is an iterator (a struct with `next() -> T?`).
# Drives the loop by calling next() each iteration: the result is an optional
# tuple (is_some flag at field 0, value at field 8); stop when the flag is 0.
fn ir_emit_for_iterator_stmt(node: Node, names: IrForBindingNames, next_name: String):
iter_r := ir_expr(node.children[0])
fid := ir_for_count
ir_for_count = ir_for_count + 1
it_var := "__for_it_" + fid.to_string()
fp := "__for_pos_" + fid.to_string()
has_index := names.index_name.len() > 0
ir_emit("store " + it_var + " " + iter_r.to_string())
if has_index:
zero_r := ir_reg()
ir_emit("iconst " + zero_r.to_string() + " 0")
ir_emit("store " + fp + " " + zero_r.to_string())
# the element type was cached on the loop-var node by the checker
mut elem_kind := ir_retkind(ir_type_from_tid(c_get_expr_type(node.children[2])))
if elem_kind.len() == 0:
elem_kind = "unknown"
head_l := ir_label()
body_l := ir_label()
step_l := ir_label()
end_l := ir_label()
ir_emit("label " + head_l)
it_cur := ir_reg()
ir_emit("load " + it_cur.to_string() + " " + it_var)
opt_r := ir_reg()
ir_emit("call " + opt_r.to_string() + " " + next_name + " tuple 1 " + it_cur.to_string())
flag_r := ir_optional_flag_field(opt_r)
ir_emit("brif " + flag_r.to_string() + " " + body_l + " " + end_l)
ir_emit("label " + body_l)
val_r := ir_reg()
ir_emit("field " + val_r.to_string() + " " + opt_r.to_string() + " 8 " + elem_kind + " value")
ir_var_types.insert(names.value_name, elem_kind)
ir_emit("store " + names.value_name + " " + val_r.to_string())
if has_index:
pos_r := ir_reg()
ir_emit("load " + pos_r.to_string() + " " + fp)
ir_var_types.insert(names.index_name, "int")
ir_emit("store " + names.index_name + " " + pos_r.to_string())
ir_break_stack = ir_push_string(ir_break_stack, end_l)
ir_continue_stack = ir_push_string(ir_continue_stack, step_l)
ir_block(node.children[1])
ir_break_stack.remove(ir_break_stack.len() - 1)
ir_continue_stack.remove(ir_continue_stack.len() - 1)
ir_emit("label " + step_l)
if has_index:
p3 := ir_reg()
ir_emit("load " + p3.to_string() + " " + fp)
pone_r := ir_reg()
ir_emit("iconst " + pone_r.to_string() + " 1")
pinc_r := ir_reg()
ir_emit("add " + pinc_r.to_string() + " " + p3.to_string() + " " + pone_r.to_string())
ir_emit("store " + fp + " " + pinc_r.to_string())
ir_emit("jmp " + head_l)
ir_emit("label " + end_l)

fn ir_stmt_control(idx: Int, node: Node):
if node.kind == "if" or node.kind == "if_stmt":
ir_emit_if_stmt(node)
Expand Down
54 changes: 54 additions & 0 deletions tests/cases/test_for_iterator.pith
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# a `for` loop drives any struct that has a `next() -> T?` method (the iterator
# protocol), calling next() until it returns none

struct Counter:
cur: Int
hi: Int

impl Counter:
fn next() -> Int?:
if self.cur >= self.hi:
return none
v := self.cur
self.cur = self.cur + 1
return v

struct Words:
items: List[String]
pos: Int

impl Words:
fn next() -> String?:
if self.pos >= self.items.len():
return none
w := self.items[self.pos]
self.pos = self.pos + 1
return w

fn main():
mut total := 0
for x in Counter(0, 5):
total = total + x
print("sum: {total}")

# break / continue
mut parts := ""
for x in Counter(0, 10):
if x == 2:
continue
if x == 5:
break
parts = parts + x.to_string()
print("bc: {parts}")

# value + index binding (index is the 0-based position)
mut vi := ""
for v, i in Counter(10, 13):
vi = vi + "(" + v.to_string() + "@" + i.to_string() + ")"
print("vi: {vi}")

# string-element iterator
mut joined := ""
for w in Words(["red", "green", "blue"], 0):
joined = joined + w + " "
print("words: {joined}")
4 changes: 4 additions & 0 deletions tests/expected/test_for_iterator.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
sum: 10
bc: 0134
vi: (10@0)(11@1)(12@2)
words: red green blue
Loading