summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/codegen/codegen.c56
-rw-r--r--src/codegen/codegen_decl.c104
-rw-r--r--src/parser/parser_expr.c204
-rw-r--r--src/parser/parser_stmt.c12
4 files changed, 356 insertions, 20 deletions
diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c
index 32fbdcf..c6211ad 100644
--- a/src/codegen/codegen.c
+++ b/src/codegen/codegen.c
@@ -703,21 +703,57 @@ void codegen_expression(ParserContext *ctx, ASTNode *node, FILE *out)
char mixin_func_name[128];
sprintf(mixin_func_name, "%s__%s", base, method);
+ char *resolved_method_suffix = NULL;
+
if (!find_func(ctx, mixin_func_name))
{
- // Method not found on primary struct, check mixins
- ASTNode *def = find_struct_def(ctx, base);
- if (def && def->type == NODE_STRUCT && def->strct.used_structs)
+ // Try resolving as a trait method: Struct__Trait_Method
+ StructRef *ref = ctx->parsed_impls_list;
+ while (ref)
+ {
+ if (ref->node && ref->node->type == NODE_IMPL_TRAIT)
+ {
+ if (strcmp(ref->node->impl_trait.target_type, base) == 0)
+ {
+ char trait_mangled[256];
+ sprintf(trait_mangled, "%s__%s_%s", base,
+ ref->node->impl_trait.trait_name, method);
+ if (find_func(ctx, trait_mangled))
+ {
+ char *suffix =
+ xmalloc(strlen(ref->node->impl_trait.trait_name) +
+ strlen(method) + 2);
+ sprintf(suffix, "%s_%s", ref->node->impl_trait.trait_name,
+ method);
+ resolved_method_suffix = suffix;
+ break;
+ }
+ }
+ }
+ ref = ref->next;
+ }
+
+ if (resolved_method_suffix)
+ {
+ method = resolved_method_suffix;
+ }
+ else
{
- for (int k = 0; k < def->strct.used_struct_count; k++)
+ // Method not found on primary struct, check mixins
+ ASTNode *def = find_struct_def(ctx, base);
+ if (def && def->type == NODE_STRUCT && def->strct.used_structs)
{
- char mixin_check[128];
- sprintf(mixin_check, "%s__%s", def->strct.used_structs[k], method);
- if (find_func(ctx, mixin_check))
+ for (int k = 0; k < def->strct.used_struct_count; k++)
{
- call_base = def->strct.used_structs[k];
- need_cast = 1;
- break;
+ char mixin_check[128];
+ sprintf(mixin_check, "%s__%s", def->strct.used_structs[k],
+ method);
+ if (find_func(ctx, mixin_check))
+ {
+ call_base = def->strct.used_structs[k];
+ need_cast = 1;
+ break;
+ }
}
}
}
diff --git a/src/codegen/codegen_decl.c b/src/codegen/codegen_decl.c
index 932420f..e7bd3f1 100644
--- a/src/codegen/codegen_decl.c
+++ b/src/codegen/codegen_decl.c
@@ -424,6 +424,28 @@ void emit_struct_defs(ParserContext *ctx, ASTNode *node, FILE *out)
}
}
+// Helper to substitute 'Self' with replacement string
+static char *substitute_proto_self(const char *type_str, const char *replacement)
+{
+ if (!type_str)
+ {
+ return NULL;
+ }
+ if (strcmp(type_str, "Self") == 0)
+ {
+ return xstrdup(replacement);
+ }
+ // Handle pointers (Self* -> replacement*)
+ if (strncmp(type_str, "Self", 4) == 0)
+ {
+ char *rest = (char *)type_str + 4;
+ char *buf = xmalloc(strlen(replacement) + strlen(rest) + 1);
+ sprintf(buf, "%s%s", replacement, rest);
+ return buf;
+ }
+ return xstrdup(type_str);
+}
+
// Emit trait definitions.
void emit_trait_defs(ASTNode *node, FILE *out)
{
@@ -440,8 +462,10 @@ void emit_trait_defs(ASTNode *node, FILE *out)
ASTNode *m = node->trait.methods;
while (m)
{
- fprintf(out, " %s (*%s)(", m->func.ret_type,
- parse_original_method_name(m->func.name));
+ char *ret_safe = substitute_proto_self(m->func.ret_type, "void*");
+ fprintf(out, " %s (*%s)(", ret_safe, parse_original_method_name(m->func.name));
+ free(ret_safe);
+
int has_self = (m->func.args && strstr(m->func.args, "self"));
if (!has_self)
{
@@ -454,7 +478,32 @@ void emit_trait_defs(ASTNode *node, FILE *out)
{
fprintf(out, ", ");
}
- fprintf(out, "%s", m->func.args);
+ char *args_safe = xstrdup(m->func.args);
+ // TODO: better replace, but for now this works.
+ char *p = strstr(args_safe, "Self");
+ while (p)
+ {
+ // Check word boundary
+ if ((p == args_safe || !isalnum(p[-1])) && !isalnum(p[4]))
+ {
+ int off = p - args_safe;
+ char *new_s = xmalloc(strlen(args_safe) + 10);
+ strncpy(new_s, args_safe, off);
+ new_s[off] = 0;
+ strcat(new_s, "void*");
+ strcat(new_s, p + 4);
+ free(args_safe);
+ args_safe = new_s;
+ p = strstr(args_safe + off + 5, "Self");
+ }
+ else
+ {
+ p = strstr(p + 1, "Self");
+ }
+ }
+
+ fprintf(out, "%s", args_safe);
+ free(args_safe);
}
fprintf(out, ");\n");
m = m->next;
@@ -467,7 +516,9 @@ void emit_trait_defs(ASTNode *node, FILE *out)
while (m)
{
const char *orig = parse_original_method_name(m->func.name);
- fprintf(out, "%s %s__%s(%s* self", m->func.ret_type, node->trait.name, orig,
+ char *ret_sub = substitute_proto_self(m->func.ret_type, node->trait.name);
+
+ fprintf(out, "%s %s__%s(%s* self", ret_sub, node->trait.name, orig,
node->trait.name);
int has_self = (m->func.args && strstr(m->func.args, "self"));
@@ -478,17 +529,45 @@ void emit_trait_defs(ASTNode *node, FILE *out)
char *comma = strchr(m->func.args, ',');
if (comma)
{
- fprintf(out, ", %s", comma + 1);
+ // Substitute Self -> TraitName in wrapper args
+ char *args_sub = xstrdup(comma + 1);
+ char *p = strstr(args_sub, "Self");
+ while (p)
+ {
+ int off = p - args_sub;
+ char *new_s =
+ xmalloc(strlen(args_sub) + strlen(node->trait.name) + 5);
+ strncpy(new_s, args_sub, off);
+ new_s[off] = 0;
+ strcat(new_s, node->trait.name);
+ strcat(new_s, p + 4);
+ free(args_sub);
+ args_sub = new_s;
+ p = strstr(args_sub + off + strlen(node->trait.name), "Self");
+ }
+
+ fprintf(out, ", %s", args_sub);
+ free(args_sub);
}
}
else
{
- fprintf(out, ", %s", m->func.args);
+ fprintf(out, ", %s", m->func.args); // TODO: recursive subst
}
}
fprintf(out, ") {\n");
- fprintf(out, " return self->vtable->%s(self->self", orig);
+ int ret_is_self = (strcmp(m->func.ret_type, "Self") == 0);
+
+ if (ret_is_self)
+ {
+ // Special handling: return (Trait){.self = call(), .vtable = self->vtable}
+ fprintf(out, " void* ret = self->vtable->%s(self->self", orig);
+ }
+ else
+ {
+ fprintf(out, " return self->vtable->%s(self->self", orig);
+ }
if (m->func.args)
{
@@ -510,7 +589,16 @@ void emit_trait_defs(ASTNode *node, FILE *out)
}
free(call_args);
}
- fprintf(out, ");\n}\n");
+ fprintf(out, ");\n");
+
+ if (ret_is_self)
+ {
+ fprintf(out, " return (%s){.self = ret, .vtable = self->vtable};\n",
+ node->trait.name);
+ }
+
+ fprintf(out, "}\n\n");
+ free(ret_sub);
m = m->next;
}
diff --git a/src/parser/parser_expr.c b/src/parser/parser_expr.c
index a79bb21..092b86b 100644
--- a/src/parser/parser_expr.c
+++ b/src/parser/parser_expr.c
@@ -3308,7 +3308,122 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
if (op.type == TOK_LPAREN)
{
ASTNode *call = ast_create(NODE_EXPR_CALL);
- call->call.callee = lhs;
+
+ // Method Resolution Logic (Struct Method -> Trait Method)
+ ASTNode *self_arg = NULL;
+ FuncSig *resolved_sig = NULL;
+ char *resolved_name = NULL;
+
+ if (lhs->type == NODE_EXPR_MEMBER)
+ {
+ Type *lt = lhs->member.target->type_info;
+ int is_lhs_ptr = 0;
+ char *alloc_name = NULL;
+ char *struct_name =
+ resolve_struct_name_from_type(ctx, lt, &is_lhs_ptr, &alloc_name);
+
+ if (struct_name)
+ {
+ char mangled[256];
+ sprintf(mangled, "%s__%s", struct_name, lhs->member.field);
+ FuncSig *sig = find_func(ctx, mangled);
+
+ if (!sig)
+ {
+ // Trait method lookup: Struct__Trait_Method
+ StructRef *ref = ctx->parsed_impls_list;
+ while (ref)
+ {
+ if (ref->node && ref->node->type == NODE_IMPL_TRAIT)
+ {
+ if (ref->node->impl_trait.target_type &&
+ strcmp(ref->node->impl_trait.target_type, struct_name) == 0)
+ {
+ char trait_mangled[512];
+ snprintf(trait_mangled, 512, "%s__%s_%s", struct_name,
+ ref->node->impl_trait.trait_name, lhs->member.field);
+ if (find_func(ctx, trait_mangled))
+ {
+ sig = find_func(ctx, trait_mangled);
+ strcpy(mangled, trait_mangled);
+ break;
+ }
+ }
+ }
+ ref = ref->next;
+ }
+ }
+
+ if (sig)
+ {
+ resolved_name = xstrdup(mangled);
+ resolved_sig = sig;
+
+ // Create 'self' argument
+ ASTNode *obj = lhs->member.target;
+
+ // Handle Reference/Pointer adjustment based on signature
+ if (sig->total_args > 0 && sig->arg_types[0] &&
+ sig->arg_types[0]->kind == TYPE_POINTER)
+ {
+ if (!is_lhs_ptr)
+ {
+ // Function expects ptr, have value -> &obj
+ int is_rvalue =
+ (obj->type == NODE_EXPR_CALL || obj->type == NODE_EXPR_BINARY ||
+ obj->type == NODE_EXPR_STRUCT_INIT ||
+ obj->type == NODE_EXPR_CAST || obj->type == NODE_MATCH);
+
+ ASTNode *addr = ast_create(NODE_EXPR_UNARY);
+ addr->unary.op = is_rvalue ? xstrdup("&_rval") : xstrdup("&");
+ addr->unary.operand = obj;
+ addr->type_info = type_new_ptr(lt);
+ self_arg = addr;
+ }
+ else
+ {
+ self_arg = obj;
+ }
+ }
+ else
+ {
+ // Function expects value
+ if (is_lhs_ptr)
+ {
+ // Have ptr, need value -> *obj
+ ASTNode *deref = ast_create(NODE_EXPR_UNARY);
+ deref->unary.op = xstrdup("*");
+ deref->unary.operand = obj;
+ if (lt && lt->kind == TYPE_POINTER && lt->inner)
+ {
+ deref->type_info = lt->inner;
+ }
+ self_arg = deref;
+ }
+ else
+ {
+ self_arg = obj;
+ }
+ }
+ }
+ }
+ if (alloc_name)
+ {
+ free(alloc_name);
+ }
+ }
+
+ if (resolved_name)
+ {
+ ASTNode *callee = ast_create(NODE_EXPR_VAR);
+ callee->var_ref.name = resolved_name;
+ call->call.callee = callee;
+ }
+ else
+ {
+ call->call.callee = lhs;
+ }
+
ASTNode *head = NULL, *tail = NULL;
char **arg_names = NULL;
int arg_count = 0;
@@ -3390,12 +3505,44 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
{
zpanic_at(lexer_peek(l), "Expected )");
}
+
+ // Prepend 'self' argument if resolved
+ if (self_arg)
+ {
+ self_arg->next = head;
+ head = self_arg;
+ arg_count++;
+
+ if (has_named)
+ {
+ // Prepend NULL to arg_names for self
+ char **new_names = xmalloc(sizeof(char *) * arg_count);
+ new_names[0] = NULL;
+ for (int i = 0; i < arg_count - 1; i++)
+ {
+ new_names[i + 1] = arg_names[i];
+ }
+ free(arg_names);
+ arg_names = new_names;
+ }
+ }
+
call->call.args = head;
call->call.arg_names = has_named ? arg_names : NULL;
call->call.arg_count = arg_count;
call->resolved_type = xstrdup("unknown");
- if (lhs->type_info && lhs->type_info->kind == TYPE_FUNCTION && lhs->type_info->inner)
+
+ if (resolved_sig)
+ {
+ call->type_info = resolved_sig->ret_type;
+ if (call->type_info)
+ {
+ call->resolved_type = type_to_string(call->type_info);
+ }
+ }
+ else if (lhs->type_info && lhs->type_info->kind == TYPE_FUNCTION &&
+ lhs->type_info->inner)
{
call->type_info = lhs->type_info->inner;
}
@@ -3695,6 +3842,33 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
sprintf(mangled, "%s__%s", struct_name, node->member.field);
FuncSig *sig = find_func(ctx, mangled);
+
+ if (!sig)
+ {
+ // Try resolving as a trait method: Struct__Trait__Method
+ StructRef *ref = ctx->parsed_impls_list;
+ while (ref)
+ {
+ if (ref->node && ref->node->type == NODE_IMPL_TRAIT)
+ {
+ const char *t_struct = ref->node->impl_trait.target_type;
+ if (t_struct && strcmp(t_struct, struct_name) == 0)
+ {
+ char trait_mangled[512];
+ snprintf(trait_mangled, 512, "%s__%s_%s", struct_name,
+ ref->node->impl_trait.trait_name, node->member.field);
+ if (find_func(ctx, trait_mangled))
+ {
+ strcpy(mangled, trait_mangled); // Update mangled name
+ sig = find_func(ctx, mangled);
+ break;
+ }
+ }
+ }
+ ref = ref->next;
+ }
+ }
+
if (sig)
{
// It is a method! Create a Function Type Info to carry the return
@@ -4070,6 +4244,32 @@ ASTNode *parse_expr_prec(ParserContext *ctx, Lexer *l, Precedence min_prec)
FuncSig *sig = find_func(ctx, mangled);
+ if (!sig)
+ {
+ // Try resolving as a trait method: Struct__Trait__Method
+ StructRef *ref = ctx->parsed_impls_list;
+ while (ref)
+ {
+ if (ref->node && ref->node->type == NODE_IMPL_TRAIT)
+ {
+ const char *t_struct = ref->node->impl_trait.target_type;
+ if (t_struct && strcmp(t_struct, struct_name) == 0)
+ {
+ char trait_mangled[512];
+ snprintf(trait_mangled, 512, "%s__%s_%s", struct_name,
+ ref->node->impl_trait.trait_name, method);
+ if (find_func(ctx, trait_mangled))
+ {
+ strcpy(mangled, trait_mangled); // Update mangled name
+ sig = find_func(ctx, mangled);
+ break;
+ }
+ }
+ }
+ ref = ref->next;
+ }
+ }
+
if (sig)
{
ASTNode *call = ast_create(NODE_EXPR_CALL);
diff --git a/src/parser/parser_stmt.c b/src/parser/parser_stmt.c
index a8c8df5..7873d51 100644
--- a/src/parser/parser_stmt.c
+++ b/src/parser/parser_stmt.c
@@ -3398,6 +3398,12 @@ ASTNode *parse_impl(ParserContext *ctx, Lexer *l)
char *na = patch_self_args(f->func.args, name2);
free(f->func.args);
f->func.args = na;
+
+ // Register function for lookup
+ register_func(ctx, mangled, f->func.arg_count, f->func.defaults, f->func.arg_types,
+ f->func.ret_type_info, f->func.is_varargs, f->func.is_async,
+ f->token);
+
if (!h)
{
h = f;
@@ -3424,6 +3430,12 @@ ASTNode *parse_impl(ParserContext *ctx, Lexer *l)
char *na = patch_self_args(f->func.args, name2);
free(f->func.args);
f->func.args = na;
+
+ // Register function for lookup
+ register_func(ctx, mangled, f->func.arg_count, f->func.defaults,
+ f->func.arg_types, f->func.ret_type_info, f->func.is_varargs,
+ f->func.is_async, f->token);
+
if (!h)
{
h = f;