summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/parser/parser_expr.c148
-rw-r--r--tests/functions/test_lambda_arrow.zc8
2 files changed, 149 insertions, 7 deletions
diff --git a/src/parser/parser_expr.c b/src/parser/parser_expr.c
index e21c983..28c19ad 100644
--- a/src/parser/parser_expr.c
+++ b/src/parser/parser_expr.c
@@ -5250,6 +5250,52 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
if (!skip_check && !type_eq(lhs->type_info, rhs->type_info) &&
!(is_integer_type(lhs->type_info) && is_integer_type(rhs->type_info)))
{
+ // Backward Inference for Lambda Params
+ // LHS is Unknown Var, RHS is Known
+ if (lhs->type == NODE_EXPR_VAR && lhs->type_info &&
+ lhs->type_info->kind == TYPE_UNKNOWN && rhs->type_info &&
+ rhs->type_info->kind != TYPE_UNKNOWN)
+ {
+ // Infer LHS type from RHS
+ Symbol *sym = find_symbol_entry(ctx, lhs->var_ref.name);
+ if (sym)
+ {
+ // Update Symbol
+ sym->type_info = rhs->type_info;
+ sym->type_name = type_to_string(rhs->type_info);
+
+ // Update AST Node
+ lhs->type_info = rhs->type_info;
+ lhs->resolved_type = xstrdup(sym->type_name);
+
+ // Re-check validity (optional, but good)
+ bin->type_info = rhs->type_info;
+ goto inference_success;
+ }
+ }
+
+ // RHS is Unknown Var, LHS is Known
+ if (rhs->type == NODE_EXPR_VAR && rhs->type_info &&
+ rhs->type_info->kind == TYPE_UNKNOWN && lhs->type_info &&
+ lhs->type_info->kind != TYPE_UNKNOWN)
+ {
+ // Infer RHS type from LHS
+ Symbol *sym = find_symbol_entry(ctx, rhs->var_ref.name);
+ if (sym)
+ {
+ // Update Symbol
+ sym->type_info = lhs->type_info;
+ sym->type_name = type_to_string(lhs->type_info);
+
+ // Update AST Node
+ rhs->type_info = lhs->type_info;
+ rhs->resolved_type = xstrdup(sym->type_name);
+
+ bin->type_info = lhs->type_info;
+ goto inference_success;
+ }
+ }
+
char msg[256];
sprintf(msg, "Type mismatch in comparison: cannot compare '%s' and '%s'", t1,
t2);
@@ -5258,6 +5304,8 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
sprintf(suggestion, "Both operands must have compatible types for comparison");
zpanic_with_suggestion(op, msg, suggestion);
+
+ inference_success:;
}
}
else
@@ -5333,6 +5381,51 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
if (!is_ptr_arith && !alias_match)
{
+ // ** Backward Inference for Binary Ops **
+ // Case 1: LHS is Unknown Var, RHS is Known
+ if (lhs->type == NODE_EXPR_VAR && lhs->type_info &&
+ lhs->type_info->kind == TYPE_UNKNOWN && rhs->type_info &&
+ rhs->type_info->kind != TYPE_UNKNOWN)
+ {
+ // Infer LHS type from RHS
+ Symbol *sym = find_symbol_entry(ctx, lhs->var_ref.name);
+ if (sym)
+ {
+ // Update Symbol
+ sym->type_info = rhs->type_info;
+ sym->type_name = type_to_string(rhs->type_info);
+
+ // Update AST Node
+ lhs->type_info = rhs->type_info;
+ lhs->resolved_type = xstrdup(sym->type_name);
+
+ bin->type_info = rhs->type_info;
+ goto bin_inference_success;
+ }
+ }
+
+ // Case 2: RHS is Unknown Var, LHS is Known
+ if (rhs->type == NODE_EXPR_VAR && rhs->type_info &&
+ rhs->type_info->kind == TYPE_UNKNOWN && lhs->type_info &&
+ lhs->type_info->kind != TYPE_UNKNOWN)
+ {
+ // Infer RHS type from LHS
+ Symbol *sym = find_symbol_entry(ctx, rhs->var_ref.name);
+ if (sym)
+ {
+ // Update Symbol
+ sym->type_info = lhs->type_info;
+ sym->type_name = type_to_string(lhs->type_info);
+
+ // Update AST Node
+ rhs->type_info = lhs->type_info;
+ rhs->resolved_type = xstrdup(sym->type_name);
+
+ bin->type_info = lhs->type_info;
+ goto bin_inference_success;
+ }
+ }
+
char msg[256];
sprintf(msg, "Type mismatch in binary operation '%s'", bin->binary.op);
@@ -5343,6 +5436,8 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
t1, t2);
zpanic_with_suggestion(op, msg, suggestion);
+
+ bin_inference_success:;
}
}
}
@@ -5360,21 +5455,21 @@ ASTNode *parse_arrow_lambda_single(ParserContext *ctx, Lexer *l, char *param_nam
lambda->lambda.param_names[0] = param_name;
lambda->lambda.num_params = 1;
- // Default param type: int
+ // Default param type: unknown (to be inferred)
lambda->lambda.param_types = xmalloc(sizeof(char *));
- lambda->lambda.param_types[0] = xstrdup("int");
+ lambda->lambda.param_types[0] = xstrdup("unknown");
- // Create Type Info: int -> int
+ // Create Type Info: unknown -> unknown
Type *t = type_new(TYPE_FUNCTION);
- t->inner = type_new(TYPE_INT); // Return
+ t->inner = type_new(TYPE_UNKNOWN); // Return
t->args = xmalloc(sizeof(Type *));
- t->args[0] = type_new(TYPE_INT); // Arg
+ t->args[0] = type_new(TYPE_UNKNOWN); // Arg
t->arg_count = 1;
lambda->type_info = t;
// Register parameter in scope for body parsing
enter_scope(ctx);
- add_symbol(ctx, param_name, "int", type_new(TYPE_INT));
+ add_symbol(ctx, param_name, "unknown", t->args[0]);
// Body parsing...
ASTNode *body_block = NULL;
@@ -5391,7 +5486,46 @@ ASTNode *parse_arrow_lambda_single(ParserContext *ctx, Lexer *l, char *param_nam
body_block->block.statements = ret;
}
lambda->lambda.body = body_block;
- lambda->lambda.return_type = xstrdup("int");
+
+ // Attempt to infer return type from body if it's a simple return
+ if (lambda->lambda.body->block.statements &&
+ lambda->lambda.body->block.statements->type == NODE_RETURN &&
+ !lambda->lambda.body->block.statements->next)
+ {
+ ASTNode *ret_val = lambda->lambda.body->block.statements->ret.value;
+ if (ret_val->type_info && ret_val->type_info->kind != TYPE_UNKNOWN)
+ {
+ // Update return type
+ if (t->inner)
+ {
+ free(t->inner);
+ }
+ t->inner = ret_val->type_info;
+ }
+ }
+
+ // Update parameter types from symbol table (in case inference happened)
+ Symbol *sym = find_symbol_entry(ctx, param_name);
+ if (sym && sym->type_info && sym->type_info->kind != TYPE_UNKNOWN)
+ {
+ free(lambda->lambda.param_types[0]);
+ lambda->lambda.param_types[0] = type_to_string(sym->type_info);
+ t->args[0] = sym->type_info;
+ }
+ else
+ {
+ // Fallback to int if still unknown
+ free(lambda->lambda.param_types[0]);
+ lambda->lambda.param_types[0] = xstrdup("int");
+ // Update symbol to match fallback
+ if (sym)
+ {
+ sym->type_name = xstrdup("int");
+ sym->type_info = type_new(TYPE_INT);
+ }
+ }
+
+ lambda->lambda.return_type = type_to_string(t->inner);
lambda->lambda.lambda_id = ctx->lambda_counter++;
lambda->lambda.is_expression = 1;
register_lambda(ctx, lambda);
diff --git a/tests/functions/test_lambda_arrow.zc b/tests/functions/test_lambda_arrow.zc
index c976ecf..111487d 100644
--- a/tests/functions/test_lambda_arrow.zc
+++ b/tests/functions/test_lambda_arrow.zc
@@ -18,3 +18,11 @@ test "test_lambda_arrow" {
"compute((a, b) -> a * b, 3, 4) = {res3}";
if res3 != 12 { exit(1); }
}
+
+test "lambda_inference_repro" {
+ var dble = x -> x * 2.0;
+ var res = dble(9.0);
+ if res != 18.0 {
+ exit(1);
+ }
+}