diff --git a/examples/functional.pith b/examples/functional.pith index 3b4baae..adc0cc3 100644 --- a/examples/functional.pith +++ b/examples/functional.pith @@ -1,4 +1,4 @@ -# functional list operations — map, filter, reduce with lambdas +# functional list operations — map, filter, reduce as list methods import std.fmt as fmt @@ -6,29 +6,29 @@ fn main(): nums := [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # map: transform each element - doubled := map(nums, fn(x: Int) => x * 2) + doubled := nums.map(fn(x: Int) => x * 2) print("doubled: " + fmt.ints(doubled)) - squared := map(nums, fn(x: Int) => x * x) + squared := nums.map(fn(x: Int) => x * x) print("squared: " + fmt.ints(squared)) # filter: keep elements matching predicate - evens := filter(nums, fn(x: Int) => 1 - (x % 2)) + evens := nums.filter(fn(x: Int) => x % 2 == 0) print("evens: " + fmt.ints(evens)) - big := filter(nums, fn(x: Int) => x / 6) + big := nums.filter(fn(x: Int) => x > 5) print("big (>5): " + fmt.ints(big)) # reduce: fold into single value - total := reduce(nums, 0, fn(acc: Int, x: Int) => acc + x) + total := nums.reduce(0, fn(acc: Int, x: Int) => acc + x) print("sum: {total}") small := [1, 2, 3, 4, 5] - product := reduce(small, 1, fn(acc: Int, x: Int) => acc * x) + product := small.reduce(1, fn(acc: Int, x: Int) => acc * x) print("product 1-5: {product}") - # chaining: filter then map - big_doubled := map(filter(nums, fn(x: Int) => x / 6), fn(x: Int) => x * 2) + # chaining reads left to right: filter then map + big_doubled := nums.filter(fn(x: Int) => x > 5).map(fn(x: Int) => x * 2) print("big doubled: " + fmt.ints(big_doubled)) # map.get_default for counters diff --git a/self-host/checker.pith b/self-host/checker.pith index 94efc21..65a3924 100644 --- a/self-host/checker.pith +++ b/self-host/checker.pith @@ -2308,6 +2308,80 @@ fn check_list_reduce_builtin(arg_indices: List[Int], scope_id: Int) -> Int: return TID_ERR return init_tid +# Method-call forms of the higher-order list builtins. The receiver is the +# list (its element type is info.inner), so these take one fewer argument than +# the free-function forms above: list.map(fn), list.filter(fn), list.reduce(init, fn). + +fn check_list_map_method(info: TypeInfo, args: List[Int], scope_id: Int) -> Int: + if args.len() != 1: + report_error("E207", "map expects 1 argument, got " + args.len().to_string()) + return TID_ERR + fn_tid := c_check_expr(call_arg_expr_index(args[0]), scope_id) + if is_error_type(fn_tid): + return TID_ERR + fn_info := get_type_info(fn_tid) + if fn_info.kind != "function": + report_error("E208", "map expects a function argument") + return TID_ERR + if fn_info.param_types.len() != 1: + report_error("E207", "map callback expects 1 parameter, got " + fn_info.param_types.len().to_string()) + return TID_ERR + if types_structurally_equal(fn_info.param_types[0], info.inner) == false: + report_error("E219", "map callback parameter mismatch: expected " + get_type_name(info.inner) + ", got " + get_type_name(fn_info.param_types[0])) + return TID_ERR + return intern_list_type(fn_info.return_type) + +fn check_list_filter_method(info: TypeInfo, args: List[Int], scope_id: Int) -> Int: + if args.len() != 1: + report_error("E207", "filter expects 1 argument, got " + args.len().to_string()) + return TID_ERR + fn_tid := c_check_expr(call_arg_expr_index(args[0]), scope_id) + if is_error_type(fn_tid): + return TID_ERR + fn_info := get_type_info(fn_tid) + if fn_info.kind != "function": + report_error("E208", "filter expects a function argument") + return TID_ERR + if fn_info.param_types.len() != 1: + report_error("E207", "filter callback expects 1 parameter, got " + fn_info.param_types.len().to_string()) + return TID_ERR + if types_structurally_equal(fn_info.param_types[0], info.inner) == false: + report_error("E219", "filter callback parameter mismatch: expected " + get_type_name(info.inner) + ", got " + get_type_name(fn_info.param_types[0])) + return TID_ERR + if fn_info.return_type != lookup_type_id("Bool"): + if is_integer_type(fn_info.return_type) == false: + report_error("E219", "filter callback must return Bool or Int-like truthy value, got " + get_type_name(fn_info.return_type)) + return TID_ERR + return intern_list_type(info.inner) + +fn check_list_reduce_method(info: TypeInfo, args: List[Int], scope_id: Int) -> Int: + if args.len() != 2: + report_error("E207", "reduce expects 2 arguments, got " + args.len().to_string()) + return TID_ERR + init_tid := c_check_expr(call_arg_expr_index(args[0]), scope_id) + fn_tid := c_check_expr(call_arg_expr_index(args[1]), scope_id) + if is_error_type(init_tid): + return TID_ERR + if is_error_type(fn_tid): + return TID_ERR + fn_info := get_type_info(fn_tid) + if fn_info.kind != "function": + report_error("E208", "reduce expects a function as the second argument") + return TID_ERR + if fn_info.param_types.len() != 2: + report_error("E207", "reduce callback expects 2 parameters, got " + fn_info.param_types.len().to_string()) + return TID_ERR + if types_structurally_equal(fn_info.param_types[0], init_tid) == false: + report_error("E219", "reduce accumulator parameter mismatch: expected " + get_type_name(init_tid) + ", got " + get_type_name(fn_info.param_types[0])) + return TID_ERR + if types_structurally_equal(fn_info.param_types[1], info.inner) == false: + report_error("E219", "reduce element parameter mismatch: expected " + get_type_name(info.inner) + ", got " + get_type_name(fn_info.param_types[1])) + return TID_ERR + if types_structurally_equal(fn_info.return_type, init_tid) == false: + report_error("E219", "reduce callback return mismatch: expected " + get_type_name(init_tid) + ", got " + get_type_name(fn_info.return_type)) + return TID_ERR + return init_tid + # --------------------------------------------------------------- # interface bounds checking # --------------------------------------------------------------- @@ -2925,6 +2999,12 @@ fn check_list_method_call(method: String, info: TypeInfo, args: List[Int], scope if method == "sort": expect_no_arguments(method, args) return intern_list_type(elem) + if method == "map": + return check_list_map_method(info, args, scope_id) + if method == "filter": + return check_list_filter_method(info, args, scope_id) + if method == "reduce": + return check_list_reduce_method(info, args, scope_id) # not a built-in list method return - 1 diff --git a/self-host/ir_driver b/self-host/ir_driver index 809a681..d1f3789 100755 Binary files a/self-host/ir_driver and b/self-host/ir_driver differ diff --git a/self-host/ir_emitter_core.pith b/self-host/ir_emitter_core.pith index 2d0f373..06aadf0 100644 --- a/self-host/ir_emitter_core.pith +++ b/self-host/ir_emitter_core.pith @@ -2515,8 +2515,9 @@ fn ir_emit_json_decode_file_call(idx: Int, node: Node, path_input: Bool) -> Int: fn ir_emit_list_map_call(idx: Int, node: Node) -> Int: if node.children.len() != 3: return - 1 - list_idx := ir_call_arg_expr_index(node.children[1]) - mapper_idx := ir_call_arg_expr_index(node.children[2]) + return ir_emit_list_map_lowering(idx, ir_call_arg_expr_index(node.children[1]), ir_call_arg_expr_index(node.children[2])) + +fn ir_emit_list_map_lowering(idx: Int, list_idx: Int, mapper_idx: Int) -> Int: list_r := ir_expr(list_idx) mapper_r := ir_expr(mapper_idx) mapper_name := ir_temp_name("__map_fn") @@ -2533,8 +2534,9 @@ fn ir_emit_list_map_call(idx: Int, node: Node) -> Int: fn ir_emit_list_filter_call(idx: Int, node: Node) -> Int: if node.children.len() != 3: return - 1 - list_idx := ir_call_arg_expr_index(node.children[1]) - keep_idx := ir_call_arg_expr_index(node.children[2]) + return ir_emit_list_filter_lowering(idx, ir_call_arg_expr_index(node.children[1]), ir_call_arg_expr_index(node.children[2])) + +fn ir_emit_list_filter_lowering(idx: Int, list_idx: Int, keep_idx: Int) -> Int: list_r := ir_expr(list_idx) keep_r := ir_expr(keep_idx) keep_name := ir_temp_name("__filter_fn") @@ -2552,9 +2554,9 @@ fn ir_emit_list_filter_call(idx: Int, node: Node) -> Int: fn ir_emit_list_reduce_call(idx: Int, node: Node) -> Int: if node.children.len() != 4: return - 1 - list_idx := ir_call_arg_expr_index(node.children[1]) - init_idx := ir_call_arg_expr_index(node.children[2]) - reducer_idx := ir_call_arg_expr_index(node.children[3]) + return ir_emit_list_reduce_lowering(idx, ir_call_arg_expr_index(node.children[1]), ir_call_arg_expr_index(node.children[2]), ir_call_arg_expr_index(node.children[3])) + +fn ir_emit_list_reduce_lowering(idx: Int, list_idx: Int, init_idx: Int, reducer_idx: Int) -> Int: list_r := ir_expr(list_idx) init_r := ir_expr(init_idx) reducer_r := ir_expr(reducer_idx) @@ -2602,7 +2604,7 @@ fn ir_emit_list_contains_method(node: Node, obj_type: String) -> Int: index_r := ir_emit_list_index_search(node, obj_type) return ir_emit_list_contains_from_index(index_r) -fn ir_emit_self_hosted_list_method(node: Node, mname: String, obj_type: String) -> Int: +fn ir_emit_self_hosted_list_method(idx: Int, node: Node, mname: String, obj_type: String) -> Int: if not ir_is_list_runtime_type(obj_type): return - 1 if mname == "is_empty": @@ -2611,6 +2613,14 @@ fn ir_emit_self_hosted_list_method(node: Node, mname: String, obj_type: String) return ir_emit_list_index_search(node, obj_type) if mname == "contains": return ir_emit_list_contains_method(node, obj_type) + if mname == "map": + # list.map(fn): receiver is the list, child 0; the mapper is arg 0 + return ir_emit_list_map_lowering(idx, node.children[0], ir_method_call_arg_expr(node, 0)) + if mname == "filter": + return ir_emit_list_filter_lowering(idx, node.children[0], ir_method_call_arg_expr(node, 0)) + if mname == "reduce": + # list.reduce(init, fn): init is arg 0, reducer is arg 1 + return ir_emit_list_reduce_lowering(idx, node.children[0], ir_method_call_arg_expr(node, 0), ir_method_call_arg_expr(node, 1)) return - 1 fn ir_emit_index_json_decode_call(idx: Int, node: Node, callee_idx: Int) -> Int: @@ -3274,7 +3284,7 @@ fn ir_emit_method_call(idx: Int) -> Int: return r obj_type := ir_method_receiver_type(node) - self_hosted_list_r := ir_emit_self_hosted_list_method(node, mname, obj_type) + self_hosted_list_r := ir_emit_self_hosted_list_method(idx, node, mname, obj_type) if self_hosted_list_r >= 0: return self_hosted_list_r mut emit_name := ir_method_emit_name(mname, obj_type) diff --git a/tests/cases/test_list_method_syntax.pith b/tests/cases/test_list_method_syntax.pith new file mode 100644 index 0000000..5e22706 --- /dev/null +++ b/tests/cases/test_list_method_syntax.pith @@ -0,0 +1,28 @@ +# map, filter, reduce called as methods on a list, including chaining + +fn ints_to_string(xs: List[Int]) -> String: + mut s := "" + for x in xs: + s = s + x.to_string() + " " + return s + +fn main(): + nums := [1, 2, 3, 4, 5, 6] + + doubled := nums.map(fn(x: Int) => x * 2) + print("map: " + ints_to_string(doubled)) + + evens := nums.filter(fn(x: Int) => x % 2 == 0) + print("filter: " + ints_to_string(evens)) + + total := nums.reduce(0, fn(acc: Int, x: Int) => acc + x) + print("reduce: {total}") + + # chaining reads left to right + chained := nums.map(fn(x: Int) => x * 10).filter(fn(x: Int) => x > 25) + print("chained: " + ints_to_string(chained)) + + # string elements + words := ["hi", "hello", "yo", "world"] + longish := words.filter(fn(w: String) => w.len() > 2) + print("long words: {longish.len()}") diff --git a/tests/expected/test_list_method_syntax.txt b/tests/expected/test_list_method_syntax.txt new file mode 100644 index 0000000..234fa69 --- /dev/null +++ b/tests/expected/test_list_method_syntax.txt @@ -0,0 +1,5 @@ +map: 2 4 6 8 10 12 +filter: 2 4 6 +reduce: 21 +chained: 30 40 50 60 +long words: 2