Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions examples/functional.pith
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
# functional list operations — map, filter, reduce with lambdas
# functional list operations — map, filter, reduce as list methods

import std.fmt as fmt

fn main():
nums := [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# map: transform each element
doubled := map(nums, fn(x: Int) => x * 2)
doubled := nums.map(fn(x: Int) => x * 2)
print("doubled: " + fmt.ints(doubled))

squared := map(nums, fn(x: Int) => x * x)
squared := nums.map(fn(x: Int) => x * x)
print("squared: " + fmt.ints(squared))

# filter: keep elements matching predicate
evens := filter(nums, fn(x: Int) => 1 - (x % 2))
evens := nums.filter(fn(x: Int) => x % 2 == 0)
print("evens: " + fmt.ints(evens))

big := filter(nums, fn(x: Int) => x / 6)
big := nums.filter(fn(x: Int) => x > 5)
print("big (>5): " + fmt.ints(big))

# reduce: fold into single value
total := reduce(nums, 0, fn(acc: Int, x: Int) => acc + x)
total := nums.reduce(0, fn(acc: Int, x: Int) => acc + x)
print("sum: {total}")

small := [1, 2, 3, 4, 5]
product := reduce(small, 1, fn(acc: Int, x: Int) => acc * x)
product := small.reduce(1, fn(acc: Int, x: Int) => acc * x)
print("product 1-5: {product}")

# chaining: filter then map
big_doubled := map(filter(nums, fn(x: Int) => x / 6), fn(x: Int) => x * 2)
# chaining reads left to right: filter then map
big_doubled := nums.filter(fn(x: Int) => x > 5).map(fn(x: Int) => x * 2)
print("big doubled: " + fmt.ints(big_doubled))

# map.get_default for counters
Expand Down
80 changes: 80 additions & 0 deletions self-host/checker.pith
Original file line number Diff line number Diff line change
Expand Up @@ -2308,6 +2308,80 @@ fn check_list_reduce_builtin(arg_indices: List[Int], scope_id: Int) -> Int:
return TID_ERR
return init_tid

# Method-call forms of the higher-order list builtins. The receiver is the
# list (its element type is info.inner), so these take one fewer argument than
# the free-function forms above: list.map(fn), list.filter(fn), list.reduce(init, fn).

fn check_list_map_method(info: TypeInfo, args: List[Int], scope_id: Int) -> Int:
if args.len() != 1:
report_error("E207", "map expects 1 argument, got " + args.len().to_string())
return TID_ERR
fn_tid := c_check_expr(call_arg_expr_index(args[0]), scope_id)
if is_error_type(fn_tid):
return TID_ERR
fn_info := get_type_info(fn_tid)
if fn_info.kind != "function":
report_error("E208", "map expects a function argument")
return TID_ERR
if fn_info.param_types.len() != 1:
report_error("E207", "map callback expects 1 parameter, got " + fn_info.param_types.len().to_string())
return TID_ERR
if types_structurally_equal(fn_info.param_types[0], info.inner) == false:
report_error("E219", "map callback parameter mismatch: expected " + get_type_name(info.inner) + ", got " + get_type_name(fn_info.param_types[0]))
return TID_ERR
return intern_list_type(fn_info.return_type)

fn check_list_filter_method(info: TypeInfo, args: List[Int], scope_id: Int) -> Int:
if args.len() != 1:
report_error("E207", "filter expects 1 argument, got " + args.len().to_string())
return TID_ERR
fn_tid := c_check_expr(call_arg_expr_index(args[0]), scope_id)
if is_error_type(fn_tid):
return TID_ERR
fn_info := get_type_info(fn_tid)
if fn_info.kind != "function":
report_error("E208", "filter expects a function argument")
return TID_ERR
if fn_info.param_types.len() != 1:
report_error("E207", "filter callback expects 1 parameter, got " + fn_info.param_types.len().to_string())
return TID_ERR
if types_structurally_equal(fn_info.param_types[0], info.inner) == false:
report_error("E219", "filter callback parameter mismatch: expected " + get_type_name(info.inner) + ", got " + get_type_name(fn_info.param_types[0]))
return TID_ERR
if fn_info.return_type != lookup_type_id("Bool"):
if is_integer_type(fn_info.return_type) == false:
report_error("E219", "filter callback must return Bool or Int-like truthy value, got " + get_type_name(fn_info.return_type))
return TID_ERR
return intern_list_type(info.inner)

fn check_list_reduce_method(info: TypeInfo, args: List[Int], scope_id: Int) -> Int:
if args.len() != 2:
report_error("E207", "reduce expects 2 arguments, got " + args.len().to_string())
return TID_ERR
init_tid := c_check_expr(call_arg_expr_index(args[0]), scope_id)
fn_tid := c_check_expr(call_arg_expr_index(args[1]), scope_id)
if is_error_type(init_tid):
return TID_ERR
if is_error_type(fn_tid):
return TID_ERR
fn_info := get_type_info(fn_tid)
if fn_info.kind != "function":
report_error("E208", "reduce expects a function as the second argument")
return TID_ERR
if fn_info.param_types.len() != 2:
report_error("E207", "reduce callback expects 2 parameters, got " + fn_info.param_types.len().to_string())
return TID_ERR
if types_structurally_equal(fn_info.param_types[0], init_tid) == false:
report_error("E219", "reduce accumulator parameter mismatch: expected " + get_type_name(init_tid) + ", got " + get_type_name(fn_info.param_types[0]))
return TID_ERR
if types_structurally_equal(fn_info.param_types[1], info.inner) == false:
report_error("E219", "reduce element parameter mismatch: expected " + get_type_name(info.inner) + ", got " + get_type_name(fn_info.param_types[1]))
return TID_ERR
if types_structurally_equal(fn_info.return_type, init_tid) == false:
report_error("E219", "reduce callback return mismatch: expected " + get_type_name(init_tid) + ", got " + get_type_name(fn_info.return_type))
return TID_ERR
return init_tid

# ---------------------------------------------------------------
# interface bounds checking
# ---------------------------------------------------------------
Expand Down Expand Up @@ -2925,6 +2999,12 @@ fn check_list_method_call(method: String, info: TypeInfo, args: List[Int], scope
if method == "sort":
expect_no_arguments(method, args)
return intern_list_type(elem)
if method == "map":
return check_list_map_method(info, args, scope_id)
if method == "filter":
return check_list_filter_method(info, args, scope_id)
if method == "reduce":
return check_list_reduce_method(info, args, scope_id)
# not a built-in list method
return - 1

Expand Down
Binary file modified self-host/ir_driver
Binary file not shown.
28 changes: 19 additions & 9 deletions self-host/ir_emitter_core.pith
Original file line number Diff line number Diff line change
Expand Up @@ -2515,8 +2515,9 @@ fn ir_emit_json_decode_file_call(idx: Int, node: Node, path_input: Bool) -> Int:
fn ir_emit_list_map_call(idx: Int, node: Node) -> Int:
if node.children.len() != 3:
return - 1
list_idx := ir_call_arg_expr_index(node.children[1])
mapper_idx := ir_call_arg_expr_index(node.children[2])
return ir_emit_list_map_lowering(idx, ir_call_arg_expr_index(node.children[1]), ir_call_arg_expr_index(node.children[2]))

fn ir_emit_list_map_lowering(idx: Int, list_idx: Int, mapper_idx: Int) -> Int:
list_r := ir_expr(list_idx)
mapper_r := ir_expr(mapper_idx)
mapper_name := ir_temp_name("__map_fn")
Expand All @@ -2533,8 +2534,9 @@ fn ir_emit_list_map_call(idx: Int, node: Node) -> Int:
fn ir_emit_list_filter_call(idx: Int, node: Node) -> Int:
if node.children.len() != 3:
return - 1
list_idx := ir_call_arg_expr_index(node.children[1])
keep_idx := ir_call_arg_expr_index(node.children[2])
return ir_emit_list_filter_lowering(idx, ir_call_arg_expr_index(node.children[1]), ir_call_arg_expr_index(node.children[2]))

fn ir_emit_list_filter_lowering(idx: Int, list_idx: Int, keep_idx: Int) -> Int:
list_r := ir_expr(list_idx)
keep_r := ir_expr(keep_idx)
keep_name := ir_temp_name("__filter_fn")
Expand All @@ -2552,9 +2554,9 @@ fn ir_emit_list_filter_call(idx: Int, node: Node) -> Int:
fn ir_emit_list_reduce_call(idx: Int, node: Node) -> Int:
if node.children.len() != 4:
return - 1
list_idx := ir_call_arg_expr_index(node.children[1])
init_idx := ir_call_arg_expr_index(node.children[2])
reducer_idx := ir_call_arg_expr_index(node.children[3])
return ir_emit_list_reduce_lowering(idx, ir_call_arg_expr_index(node.children[1]), ir_call_arg_expr_index(node.children[2]), ir_call_arg_expr_index(node.children[3]))

fn ir_emit_list_reduce_lowering(idx: Int, list_idx: Int, init_idx: Int, reducer_idx: Int) -> Int:
list_r := ir_expr(list_idx)
init_r := ir_expr(init_idx)
reducer_r := ir_expr(reducer_idx)
Expand Down Expand Up @@ -2602,7 +2604,7 @@ fn ir_emit_list_contains_method(node: Node, obj_type: String) -> Int:
index_r := ir_emit_list_index_search(node, obj_type)
return ir_emit_list_contains_from_index(index_r)

fn ir_emit_self_hosted_list_method(node: Node, mname: String, obj_type: String) -> Int:
fn ir_emit_self_hosted_list_method(idx: Int, node: Node, mname: String, obj_type: String) -> Int:
if not ir_is_list_runtime_type(obj_type):
return - 1
if mname == "is_empty":
Expand All @@ -2611,6 +2613,14 @@ fn ir_emit_self_hosted_list_method(node: Node, mname: String, obj_type: String)
return ir_emit_list_index_search(node, obj_type)
if mname == "contains":
return ir_emit_list_contains_method(node, obj_type)
if mname == "map":
# list.map(fn): receiver is the list, child 0; the mapper is arg 0
return ir_emit_list_map_lowering(idx, node.children[0], ir_method_call_arg_expr(node, 0))
if mname == "filter":
return ir_emit_list_filter_lowering(idx, node.children[0], ir_method_call_arg_expr(node, 0))
if mname == "reduce":
# list.reduce(init, fn): init is arg 0, reducer is arg 1
return ir_emit_list_reduce_lowering(idx, node.children[0], ir_method_call_arg_expr(node, 0), ir_method_call_arg_expr(node, 1))
return - 1

fn ir_emit_index_json_decode_call(idx: Int, node: Node, callee_idx: Int) -> Int:
Expand Down Expand Up @@ -3274,7 +3284,7 @@ fn ir_emit_method_call(idx: Int) -> Int:
return r

obj_type := ir_method_receiver_type(node)
self_hosted_list_r := ir_emit_self_hosted_list_method(node, mname, obj_type)
self_hosted_list_r := ir_emit_self_hosted_list_method(idx, node, mname, obj_type)
if self_hosted_list_r >= 0:
return self_hosted_list_r
mut emit_name := ir_method_emit_name(mname, obj_type)
Expand Down
28 changes: 28 additions & 0 deletions tests/cases/test_list_method_syntax.pith
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# map, filter, reduce called as methods on a list, including chaining

fn ints_to_string(xs: List[Int]) -> String:
mut s := ""
for x in xs:
s = s + x.to_string() + " "
return s

fn main():
nums := [1, 2, 3, 4, 5, 6]

doubled := nums.map(fn(x: Int) => x * 2)
print("map: " + ints_to_string(doubled))

evens := nums.filter(fn(x: Int) => x % 2 == 0)
print("filter: " + ints_to_string(evens))

total := nums.reduce(0, fn(acc: Int, x: Int) => acc + x)
print("reduce: {total}")

# chaining reads left to right
chained := nums.map(fn(x: Int) => x * 10).filter(fn(x: Int) => x > 25)
print("chained: " + ints_to_string(chained))

# string elements
words := ["hi", "hello", "yo", "world"]
longish := words.filter(fn(w: String) => w.len() > 2)
print("long words: {longish.len()}")
5 changes: 5 additions & 0 deletions tests/expected/test_list_method_syntax.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
map: 2 4 6 8 10 12
filter: 2 4 6
reduce: 21
chained: 30 40 50 60
long words: 2
Loading