diff options
| author | Zuhaitz <zuhaitz.zechhub@gmail.com> | 2026-01-25 23:23:43 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-01-25 23:23:43 +0000 |
| commit | b568f67d75553bbecd2cadc4d61b330b8aea2ad2 (patch) | |
| tree | 93de523967424146ba2b4ccb0f728c47cdbe2251 /src | |
| parent | 18b0932249b0df8ddea159ba187cb9c3587197da (diff) | |
| parent | 59951529ba67d3316a01afd45808c1b20b20c1e1 (diff) | |
Merge pull request #109 from iryuken/main
Fix generic struct pointer instantiation bug #105
Diffstat (limited to 'src')
| -rw-r--r-- | src/parser/parser_core.c | 28 | ||||
| -rw-r--r-- | src/parser/parser_expr.c | 127 | ||||
| -rw-r--r-- | src/parser/parser_struct.c | 14 | ||||
| -rw-r--r-- | src/parser/parser_utils.c | 165 |
4 files changed, 315 insertions, 19 deletions
diff --git a/src/parser/parser_core.c b/src/parser/parser_core.c index 5e685d7..1245a55 100644 --- a/src/parser/parser_core.c +++ b/src/parser/parser_core.c @@ -591,22 +591,34 @@ static ASTNode *generate_derive_impls(ParserContext *ctx, ASTNode *strct, char * } char cmp[256]; - ASTNode *fdef = find_struct_def(ctx, ft); - if (fdef && fdef->type == NODE_ENUM) + // Detect pointer using type_info OR string check (fallback) + int is_ptr = 0; + if (f->type_info && f->type_info->kind == TYPE_POINTER) { - // Enum field: compare tags (pointer access via auto-deref) + is_ptr = 1; + } + // Fallback: check if type string ends with '*' + if (!is_ptr && ft && strchr(ft, '*')) + { + is_ptr = 1; + } + + // Only look up struct def for non-pointer types + ASTNode *fdef = is_ptr ? NULL : find_struct_def(ctx, ft); + + if (!is_ptr && fdef && fdef->type == NODE_ENUM) + { + // Enum field: compare tags sprintf(cmp, "self.%s.tag == other.%s.tag", fn, fn); } - else if (fdef && fdef->type == NODE_STRUCT) + else if (!is_ptr && fdef && fdef->type == NODE_STRUCT) { - // Struct field: use _eq function, pass addresses - // self.field is L-value, other.field is L-value (auto-deref from - // pointer) We need addresses of them: &self.field, &other.field + // Struct field: use __eq function sprintf(cmp, "%s__eq(&self.%s, &other.%s)", ft, fn, fn); } else { - // Primitive or unknown: use == (auto-deref) + // Primitive, POINTER, or unknown: use == sprintf(cmp, "self.%s == other.%s", fn, fn); } strcat(body, cmp); diff --git a/src/parser/parser_expr.c b/src/parser/parser_expr.c index 638d668..8005e8e 100644 --- a/src/parser/parser_expr.c +++ b/src/parser/parser_expr.c @@ -5195,6 +5195,32 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec) char *struct_name = resolve_struct_name_from_type(ctx, lt, &is_lhs_ptr, &allocated_name); + // If we are comparing pointers with == or !=, do NOT rewrite to .eq() + // We want pointer equality, not value equality (which requires dereferencing) + // But strict check: Only if BOTH are pointers. If one is value, we might need rewrite. + if (is_lhs_ptr && struct_name && + (strcmp(bin->binary.op, "==") == 0 || strcmp(bin->binary.op, "!=") == 0)) + { + int is_rhs_ptr = 0; + char *r_alloc = NULL; + char *r_name = + resolve_struct_name_from_type(ctx, rhs->type_info, &is_rhs_ptr, &r_alloc); + if (r_alloc) + { + free(r_alloc); + } + + if (is_rhs_ptr) + { + // Both are pointers: Skip rewrite to allow pointer comparison + if (allocated_name) + { + free(allocated_name); + } + struct_name = NULL; + } + } + if (struct_name) { char mangled[256]; @@ -5423,8 +5449,15 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec) } } + int lhs_is_num = + is_integer_type(lhs->type_info) || lhs->type_info->kind == TYPE_F32 || + lhs->type_info->kind == TYPE_F64 || lhs->type_info->kind == TYPE_FLOAT; + int rhs_is_num = + is_integer_type(rhs->type_info) || rhs->type_info->kind == TYPE_F32 || + rhs->type_info->kind == TYPE_F64 || rhs->type_info->kind == TYPE_FLOAT; + if (!skip_check && !type_eq(lhs->type_info, rhs->type_info) && - !(is_integer_type(lhs->type_info) && is_integer_type(rhs->type_info))) + !(lhs_is_num && rhs_is_num)) { char msg[256]; sprintf(msg, "Type mismatch in comparison: cannot compare '%s' and '%s'", t1, @@ -5554,16 +5587,92 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec) } } - char msg[256]; - sprintf(msg, "Type mismatch in binary operation '%s'", bin->binary.op); + // Allow assigning 0 to pointer (NULL) + int is_null_assign = 0; + if (strcmp(bin->binary.op, "=") == 0) + { + int lhs_is_ptr = (lhs->type_info->kind == TYPE_POINTER || + lhs->type_info->kind == TYPE_STRING || + (t1 && strstr(t1, "*") != NULL)); + if (lhs_is_ptr && rhs->type == NODE_EXPR_LITERAL && + rhs->literal.int_val == 0) + { + is_null_assign = 1; + } + } - char suggestion[512]; - sprintf(suggestion, - "Left operand has type '%s', right operand has type '%s'\n = " - "note: Consider casting one operand to match the other", - t1, t2); + if (!is_null_assign) + { + // Check for arithmetic promotion (Int * Float, etc) + int lhs_is_num = is_integer_type(lhs->type_info) || + lhs->type_info->kind == TYPE_F32 || + lhs->type_info->kind == TYPE_F64 || + lhs->type_info->kind == TYPE_FLOAT; + int rhs_is_num = is_integer_type(rhs->type_info) || + rhs->type_info->kind == TYPE_F32 || + rhs->type_info->kind == TYPE_F64 || + rhs->type_info->kind == TYPE_FLOAT; + + int valid_arith = 0; + if (lhs_is_num && rhs_is_num) + { + if (strcmp(bin->binary.op, "+") == 0 || + strcmp(bin->binary.op, "-") == 0 || + strcmp(bin->binary.op, "*") == 0 || + strcmp(bin->binary.op, "/") == 0) + { + valid_arith = 1; + // Result is the float type if one is float + if (lhs->type_info->kind == TYPE_F64 || + rhs->type_info->kind == TYPE_F64) + { + bin->type_info = lhs->type_info->kind == TYPE_F64 + ? lhs->type_info + : rhs->type_info; + } + else if (lhs->type_info->kind == TYPE_F32 || + rhs->type_info->kind == TYPE_F32 || + lhs->type_info->kind == TYPE_FLOAT || + rhs->type_info->kind == TYPE_FLOAT) + { + // Pick the float type. If both float, pick lhs. + if (lhs->type_info->kind == TYPE_F32 || + lhs->type_info->kind == TYPE_FLOAT) + { + bin->type_info = lhs->type_info; + } + else + { + bin->type_info = rhs->type_info; + } + } + else + { + // Both int (but failed equality check previously? - rare + // but possible if diff int types) If diff int types, we + // usually allow it in C (promotion). For now, assume LHS + // dominates or standard promotion. + bin->type_info = lhs->type_info; + } + } + } - zpanic_with_suggestion(op, msg, suggestion); + if (!valid_arith) + { + char msg[256]; + sprintf(msg, "Type mismatch in binary operation '%s'", + bin->binary.op); + + char suggestion[512]; + sprintf( + suggestion, + "Left operand has type '%s', right operand has type '%s'\n = " + "note: Consider casting one operand to match the other", + t1, t2); + + zpanic_with_suggestion(op, msg, suggestion); + } + } bin_inference_success:; } diff --git a/src/parser/parser_struct.c b/src/parser/parser_struct.c index e776bb8..17d8e93 100644 --- a/src/parser/parser_struct.c +++ b/src/parser/parser_struct.c @@ -740,12 +740,14 @@ ASTNode *parse_struct(ParserContext *ctx, Lexer *l, int is_union) // Named use -> Composition (Add field, don't flatten) Token field_name = lexer_next(l); lexer_next(l); // eat : - char *field_type_str = parse_type(ctx, l); + Type *ft = parse_type_formal(ctx, l); + char *field_type_str = type_to_string(ft); expect(l, TOK_SEMICOLON, "Expected ;"); ASTNode *nf = ast_create(NODE_FIELD); nf->field.name = token_strdup(field_name); nf->field.type = field_type_str; + nf->type_info = ft; if (!h) { @@ -792,6 +794,12 @@ ASTNode *parse_struct(ParserContext *ctx, Lexer *l, int is_union) ASTNode *nf = ast_create(NODE_FIELD); nf->field.name = xstrdup(f->field.name); nf->field.type = xstrdup(f->field.type); + // Copy type info? Ideally deep copy or ref + // For now, we leave it NULL or shallow copy if needed, but mixins usually + // aren't generic params themselves in the same way. + // Let's shallow copy for safety if it exists. + nf->type_info = f->type_info; + if (!h) { h = nf; @@ -821,11 +829,13 @@ ASTNode *parse_struct(ParserContext *ctx, Lexer *l, int is_union) { Token f_name = lexer_next(l); expect(l, TOK_COLON, "Expected :"); - char *f_type = parse_type(ctx, l); + Type *ft = parse_type_formal(ctx, l); + char *f_type = type_to_string(ft); ASTNode *f = ast_create(NODE_FIELD); f->field.name = token_strdup(f_name); f->field.type = f_type; + f->type_info = ft; f->field.bit_width = 0; // Optional bit width: name: type : 3 diff --git a/src/parser/parser_utils.c b/src/parser/parser_utils.c index 385d36c..86e1b50 100644 --- a/src/parser/parser_utils.c +++ b/src/parser/parser_utils.c @@ -1882,6 +1882,129 @@ FuncSig *find_func(ParserContext *ctx, const char *name) return NULL; } +// Helper function to recursively scan AST for sizeof types and trigger instantiation of generic +// structs +static void trigger_sizeof_instantiations(ParserContext *ctx, ASTNode *node) +{ + if (!node) + { + return; + } + + // Process current node + if (node->type == NODE_EXPR_SIZEOF && node->size_of.target_type) + { + const char *type_str = node->size_of.target_type; + if (strchr(type_str, '_')) + { + // Remove trailing '*' or 'Ptr' if present + char *type_copy = xstrdup(type_str); + char *star = strchr(type_copy, '*'); + if (star) + { + *star = '\0'; + } + else + { + // Check for "Ptr" suffix and remove it + size_t len = strlen(type_copy); + if (len > 3 && strcmp(type_copy + len - 3, "Ptr") == 0) + { + type_copy[len - 3] = '\0'; + } + } + + char *underscore = strrchr(type_copy, '_'); + if (underscore && underscore > type_copy) + { + *underscore = '\0'; + char *template_name = type_copy; + char *concrete_arg = underscore + 1; + + // Check if this is a known generic template + GenericTemplate *gt = ctx->templates; + int found = 0; + while (gt) + { + if (strcmp(gt->name, template_name) == 0) + { + found = 1; + break; + } + gt = gt->next; + } + + if (found) + { + char *unmangled = unmangle_ptr_suffix(concrete_arg); + Token dummy_tok = {0}; + instantiate_generic(ctx, template_name, concrete_arg, unmangled, dummy_tok); + free(unmangled); + } + } + free(type_copy); + } + } + + // Recursively visit children based on node type + switch (node->type) + { + case NODE_FUNCTION: + trigger_sizeof_instantiations(ctx, node->func.body); + break; + case NODE_BLOCK: + trigger_sizeof_instantiations(ctx, node->block.statements); + break; + case NODE_VAR_DECL: + trigger_sizeof_instantiations(ctx, node->var_decl.init_expr); + break; + case NODE_RETURN: + trigger_sizeof_instantiations(ctx, node->ret.value); + break; + case NODE_EXPR_BINARY: + trigger_sizeof_instantiations(ctx, node->binary.left); + trigger_sizeof_instantiations(ctx, node->binary.right); + break; + case NODE_EXPR_UNARY: + trigger_sizeof_instantiations(ctx, node->unary.operand); + break; + case NODE_EXPR_CALL: + trigger_sizeof_instantiations(ctx, node->call.callee); + trigger_sizeof_instantiations(ctx, node->call.args); + break; + case NODE_EXPR_MEMBER: + trigger_sizeof_instantiations(ctx, node->member.target); + break; + case NODE_EXPR_INDEX: + trigger_sizeof_instantiations(ctx, node->index.array); + trigger_sizeof_instantiations(ctx, node->index.index); + break; + case NODE_EXPR_CAST: + trigger_sizeof_instantiations(ctx, node->cast.expr); + break; + case NODE_IF: + trigger_sizeof_instantiations(ctx, node->if_stmt.condition); + trigger_sizeof_instantiations(ctx, node->if_stmt.then_body); + trigger_sizeof_instantiations(ctx, node->if_stmt.else_body); + break; + case NODE_WHILE: + trigger_sizeof_instantiations(ctx, node->while_stmt.condition); + trigger_sizeof_instantiations(ctx, node->while_stmt.body); + break; + case NODE_FOR: + trigger_sizeof_instantiations(ctx, node->for_stmt.init); + trigger_sizeof_instantiations(ctx, node->for_stmt.condition); + trigger_sizeof_instantiations(ctx, node->for_stmt.step); + trigger_sizeof_instantiations(ctx, node->for_stmt.body); + break; + default: + break; + } + + // Visit next sibling + trigger_sizeof_instantiations(ctx, node->next); +} + char *instantiate_function_template(ParserContext *ctx, const char *name, const char *concrete_type, const char *unmangled_type) { @@ -2020,6 +2143,11 @@ char *instantiate_function_template(ParserContext *ctx, const char *name, const { return NULL; } + + // Scan the function body for sizeof expressions and trigger instantiation + // of any generic structs referenced there (e.g., sizeof(RcInner_int32_t)) + trigger_sizeof_instantiations(ctx, new_fn->func.body); + free(new_fn->func.name); new_fn->func.name = xstrdup(mangled); new_fn->func.generic_params = NULL; @@ -2265,6 +2393,43 @@ ASTNode *copy_fields_replacing(ParserContext *ctx, ASTNode *fields, const char * } } + // Additional check: if type_info is a pointer to a struct with a mangled name, + // instantiate that struct as well (fixes cases like RcInner<T>* where the + // string check above might not catch it) + if (n->type_info && n->type_info->kind == TYPE_POINTER && n->type_info->inner) + { + Type *inner = n->type_info->inner; + if (inner->kind == TYPE_STRUCT && inner->name && strchr(inner->name, '_')) + { + // Extract template name by checking against known templates + // We can't use strrchr because types like "Inner_int32_t" have multiple underscores + char *template_name = NULL; + char *concrete_arg = NULL; + + // Try each known template to see if the type name starts with it + GenericTemplate *gt = ctx->templates; + while (gt) + { + size_t tlen = strlen(gt->name); + // Check if name starts with template name followed by underscore + if (strncmp(inner->name, gt->name, tlen) == 0 && inner->name[tlen] == '_') + { + template_name = gt->name; + concrete_arg = inner->name + tlen + 1; // Skip template name and underscore + break; + } + gt = gt->next; + } + + if (template_name && concrete_arg) + { + char *unmangled = unmangle_ptr_suffix(concrete_arg); + instantiate_generic(ctx, template_name, concrete_arg, unmangled, fields->token); + free(unmangled); + } + } + } + n->next = copy_fields_replacing(ctx, fields->next, param, concrete); return n; } |
