author     Zuhaitz Méndez Fernández de Aránguiz <zuhaitz@debian>   2026-01-18 19:54:32 +0000
committer  Zuhaitz Méndez Fernández de Aránguiz <zuhaitz@debian>   2026-01-18 19:54:32 +0000
commit     8401c970ece366592dcbfaa0affe1a1a1bc18ac8 (patch)
tree       a83b2864e51782ae9a186b27fa6124eae8422357 /src
parent     fc4abb77ecab8fe3c497e13d7f7d6d8f832514b2 (diff)
CUDA Interop, baby.
Diffstat (limited to 'src')
-rw-r--r--  src/ast/ast.h                 16
-rw-r--r--  src/codegen/codegen.c         64
-rw-r--r--  src/codegen/codegen_utils.c   17
-rw-r--r--  src/main.c                     7
-rw-r--r--  src/parser/parser_core.c      24
-rw-r--r--  src/parser/parser_stmt.c     111
-rw-r--r--  src/zprep.h                    1

7 files changed, 238 insertions, 2 deletions
diff --git a/src/ast/ast.h b/src/ast/ast.h
index 2288860..cef68c6 100644
--- a/src/ast/ast.h
+++ b/src/ast/ast.h
@@ -122,7 +122,8 @@ typedef enum
NODE_TRY,
NODE_REFLECTION,
NODE_AWAIT,
- NODE_REPL_PRINT
+ NODE_REPL_PRINT,
+ NODE_CUDA_LAUNCH
} NodeType;
// ** AST Node Structure **
@@ -176,6 +177,10 @@ struct ASTNode
char *section; // @section("name")
int is_async; // async function
int is_comptime; // @comptime function
+ // CUDA qualifiers
+ int cuda_global; // @global -> __global__
+ int cuda_device; // @device -> __device__
+ int cuda_host; // @host -> __host__
} func;
struct
@@ -539,6 +544,15 @@ struct ASTNode
{
ASTNode *expr;
} repl_print;
+
+ struct
+ {
+ ASTNode *call; // The kernel call (NODE_EXPR_CALL)
+ ASTNode *grid; // Grid dimensions expression
+ ASTNode *block; // Block dimensions expression
+ ASTNode *shared_mem; // Optional shared memory size (NULL = default)
+ ASTNode *stream; // Optional CUDA stream (NULL = default)
+ } cuda_launch;
};
};
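
For orientation, the five fields of cuda_launch line up with the parameters of the standard CUDA triple-chevron launch, whose last two slots are optional and default to zero bytes of dynamic shared memory and the default stream. A minimal sketch of the correspondence (kernel and variable names are illustrative, not taken from this diff):

    // call       -> vec_add(d_a, d_b, d_out, n)
    // grid/block -> blocks / threads
    // shared_mem -> smem_bytes   (optional)
    // stream     -> s            (optional)
    vec_add<<<blocks, threads, smem_bytes, s>>>(d_a, d_b, d_out, n);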
diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c
index 644e9f5..a371548 100644
--- a/src/codegen/codegen.c
+++ b/src/codegen/codegen.c
@@ -2252,6 +2252,70 @@ void codegen_node_single(ParserContext *ctx, ASTNode *node, FILE *out)
fprintf(out, ";\n");
}
break;
+ case NODE_CUDA_LAUNCH:
+ {
+ // Emit CUDA kernel launch: kernel<<<grid, block, shared, stream>>>(args);
+ ASTNode *call = node->cuda_launch.call;
+
+ // Get kernel name from callee
+ if (call->call.callee->type == NODE_EXPR_VAR)
+ {
+ fprintf(out, " %s<<<", call->call.callee->var_ref.name);
+ }
+ else
+ {
+ fprintf(out, " ");
+ codegen_expression(ctx, call->call.callee, out);
+ fprintf(out, "<<<");
+ }
+
+ // Grid dimension
+ codegen_expression(ctx, node->cuda_launch.grid, out);
+ fprintf(out, ", ");
+
+ // Block dimension
+ codegen_expression(ctx, node->cuda_launch.block, out);
+
+ // Optional shared memory size
+ if (node->cuda_launch.shared_mem || node->cuda_launch.stream)
+ {
+ fprintf(out, ", ");
+ if (node->cuda_launch.shared_mem)
+ {
+ codegen_expression(ctx, node->cuda_launch.shared_mem, out);
+ }
+ else
+ {
+ fprintf(out, "0");
+ }
+ }
+
+ // Optional CUDA stream
+ if (node->cuda_launch.stream)
+ {
+ fprintf(out, ", ");
+ codegen_expression(ctx, node->cuda_launch.stream, out);
+ }
+
+ fprintf(out, ">>>(");
+
+ // Arguments
+ ASTNode *arg = call->call.args;
+ int first = 1;
+ while (arg)
+ {
+ if (!first)
+ {
+ fprintf(out, ", ");
+ }
+ codegen_expression(ctx, arg, out);
+ first = 0;
+ arg = arg->next;
+ }
+
+ fprintf(out, ");\n");
+ break;
+ }
default:
codegen_expression(ctx, node, out);
fprintf(out, ";\n");
diff --git a/src/codegen/codegen_utils.c b/src/codegen/codegen_utils.c
index b1fcf4c..af1c862 100644
--- a/src/codegen/codegen_utils.c
+++ b/src/codegen/codegen_utils.c
@@ -535,6 +535,23 @@ void emit_func_signature(FILE *out, ASTNode *func, const char *name_override)
return;
}
+ // Emit CUDA qualifiers (for both forward declarations and definitions)
+ if (g_config.use_cuda)
+ {
+ if (func->func.cuda_global)
+ {
+ fprintf(out, "__global__ ");
+ }
+ if (func->func.cuda_device)
+ {
+ fprintf(out, "__device__ ");
+ }
+ if (func->func.cuda_host)
+ {
+ fprintf(out, "__host__ ");
+ }
+ }
+
// Return type
char *ret_str;
if (func->func.ret_type_info)
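
Because the emission is gated on g_config.use_cuda, the @global/@device/@host attributes are inert in plain C or --cpp builds, and because the three flags are independent, combining @host and @device on one function yields the usual __host__ __device__ pairing. Emitted signatures would look roughly like this (names illustrative):

    __global__ void vec_add(float *a, float *b, float *out, int n);
    __device__ float clampf(float x, float lo, float hi);
    __host__ __device__ int round_up(int x, int align);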
diff --git a/src/main.c b/src/main.c
index edc3723..2b7dcf0 100644
--- a/src/main.c
+++ b/src/main.c
@@ -48,6 +48,7 @@ void print_usage()
printf(" -q, --quiet Quiet output\n");
printf(" -c Compile only (produce .o)\n");
printf(" --cpp Use C++ mode.\n");
+ printf(" --cuda Use CUDA mode (requires nvcc).\n");
}
int main(int argc, char **argv)
@@ -147,6 +148,12 @@ int main(int argc, char **argv)
strcpy(g_config.cc, "g++");
g_config.use_cpp = 1;
}
+ else if (strcmp(arg, "--cuda") == 0)
+ {
+ strcpy(g_config.cc, "nvcc");
+ g_config.use_cuda = 1;
+ g_config.use_cpp = 1; // CUDA implies C++ mode
+ }
else if (strcmp(arg, "--check") == 0)
{
g_config.mode_check = 1;
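
A hypothetical invocation (the binary name and source file extension here are illustrative): --cuda swaps the backing compiler to nvcc and implies --cpp, since nvcc compiles device code under C++ rules.

    zprep --cuda saxpy.zp    # compiles via nvcc, C++ mode implied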
diff --git a/src/parser/parser_core.c b/src/parser/parser_core.c
index 3e683fb..c3c91fe 100644
--- a/src/parser/parser_core.c
+++ b/src/parser/parser_core.c
@@ -63,6 +63,9 @@ ASTNode *parse_program_nodes(ParserContext *ctx, Lexer *l)
int attr_weak = 0;
int attr_export = 0;
int attr_comptime = 0;
+ int attr_cuda_global = 0; // @global -> __global__
+ int attr_cuda_device = 0; // @device -> __device__
+ int attr_cuda_host = 0; // @host -> __host__
char *deprecated_msg = NULL;
char *attr_section = NULL;
@@ -232,7 +235,23 @@ ASTNode *parse_program_nodes(ParserContext *ctx, Lexer *l)
}
else
{
- zwarn_at(attr, "Unknown attribute: %.*s", attr.len, attr.start);
+ // Checking for CUDA attributes...
+ if (0 == strncmp(attr.start, "global", 6) && 6 == attr.len)
+ {
+ attr_cuda_global = 1;
+ }
+ else if (0 == strncmp(attr.start, "device", 6) && 6 == attr.len)
+ {
+ attr_cuda_device = 1;
+ }
+ else if (0 == strncmp(attr.start, "host", 4) && 4 == attr.len)
+ {
+ attr_cuda_host = 1;
+ }
+ else
+ {
+ zwarn_at(attr, "Unknown attribute: %.*s", attr.len, attr.start);
+ }
}
t = lexer_peek(l);
@@ -469,6 +488,9 @@ ASTNode *parse_program_nodes(ParserContext *ctx, Lexer *l)
s->func.pure = attr_pure;
s->func.section = attr_section;
s->func.is_comptime = attr_comptime;
+ s->func.cuda_global = attr_cuda_global;
+ s->func.cuda_device = attr_cuda_device;
+ s->func.cuda_host = attr_cuda_host;
if (attr_deprecated && s->func.name)
{
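
The new attributes are matched on exact length as well as prefix, so a misspelling such as @globals still falls through to the unknown-attribute warning. The mapping that emit_func_signature later applies is:

    @global  ->  __global__    (kernel entry point, callable from host)
    @device  ->  __device__    (device-only function)
    @host    ->  __host__      (host-only function)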
diff --git a/src/parser/parser_stmt.c b/src/parser/parser_stmt.c
index 5307768..daf3f72 100644
--- a/src/parser/parser_stmt.c
+++ b/src/parser/parser_stmt.c
@@ -2503,6 +2503,117 @@ ASTNode *parse_statement(ParserContext *ctx, Lexer *l)
return parse_guard(ctx, l);
}
+ // CUDA launch: launch kernel(args) with { grid: X, block: Y };
+ if (strncmp(tk.start, "launch", 6) == 0 && tk.len == 6)
+ {
+ Token launch_tok = lexer_next(l); // eat 'launch'
+
+ // Parse the kernel call expression
+ ASTNode *call = parse_expression(ctx, l);
+ if (!call || call->type != NODE_EXPR_CALL)
+ {
+ zpanic_at(launch_tok, "Expected kernel call after 'launch'");
+ }
+
+ // Expect 'with'
+ Token with_tok = lexer_peek(l);
+ if (with_tok.type != TOK_IDENT || strncmp(with_tok.start, "with", 4) != 0 ||
+ with_tok.len != 4)
+ {
+ zpanic_at(with_tok, "Expected 'with' after kernel call in launch statement");
+ }
+ lexer_next(l); // eat 'with'
+
+ // Expect '{' for configuration block
+ if (lexer_peek(l).type != TOK_LBRACE)
+ {
+ zpanic_at(lexer_peek(l), "Expected '{' after 'with' in launch statement");
+ }
+ lexer_next(l); // eat '{'
+
+ ASTNode *grid = NULL;
+ ASTNode *block = NULL;
+ ASTNode *shared_mem = NULL;
+ ASTNode *stream = NULL;
+
+ // Parse configuration fields
+ while (lexer_peek(l).type != TOK_RBRACE && lexer_peek(l).type != TOK_EOF)
+ {
+ Token field_name = lexer_next(l);
+ if (field_name.type != TOK_IDENT)
+ {
+ zpanic_at(field_name, "Expected field name in launch configuration");
+ }
+
+ // Expect ':'
+ if (lexer_peek(l).type != TOK_COLON)
+ {
+ zpanic_at(lexer_peek(l), "Expected ':' after field name");
+ }
+ lexer_next(l); // eat ':'
+
+ // Parse value expression
+ ASTNode *value = parse_expression(ctx, l);
+
+ // Assign to appropriate field
+ if (strncmp(field_name.start, "grid", 4) == 0 && field_name.len == 4)
+ {
+ grid = value;
+ }
+ else if (strncmp(field_name.start, "block", 5) == 0 && field_name.len == 5)
+ {
+ block = value;
+ }
+ else if (strncmp(field_name.start, "shared_mem", 10) == 0 && field_name.len == 10)
+ {
+ shared_mem = value;
+ }
+ else if (strncmp(field_name.start, "stream", 6) == 0 && field_name.len == 6)
+ {
+ stream = value;
+ }
+ else
+ {
+ zpanic_at(field_name, "Unknown launch configuration field (expected: grid, "
+ "block, shared_mem, stream)");
+ }
+
+ // Optional comma
+ if (lexer_peek(l).type == TOK_COMMA)
+ {
+ lexer_next(l);
+ }
+ }
+
+ // Expect '}'
+ if (lexer_peek(l).type != TOK_RBRACE)
+ {
+ zpanic_at(lexer_peek(l), "Expected '}' to close launch configuration");
+ }
+ lexer_next(l); // eat '}'
+
+ // Optional trailing ';'
+ if (lexer_peek(l).type == TOK_SEMICOLON)
+ {
+ lexer_next(l);
+ }
+
+ // Require at least grid and block
+ if (!grid || !block)
+ {
+ zpanic_at(launch_tok, "Launch configuration requires at least 'grid' and 'block'");
+ }
+
+ ASTNode *n = ast_create(NODE_CUDA_LAUNCH);
+ n->cuda_launch.call = call;
+ n->cuda_launch.grid = grid;
+ n->cuda_launch.block = block;
+ n->cuda_launch.shared_mem = shared_mem;
+ n->cuda_launch.stream = stream;
+ n->token = launch_tok;
+ return n;
+ }
+
// Do-while loop: do { body } while condition;
if (strncmp(tk.start, "do", 2) == 0 && tk.len == 2)
{
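
End to end, the statement grammar above combined with the earlier codegen case gives launches like the following: the configuration fields may appear in any order, the trailing comma and semicolon are optional, and grid plus block are mandatory (kernel and variable names illustrative):

    launch vec_add(d_a, d_b, d_out, n) with { grid: blocks, block: threads, stream: s };
    // emitted as:
    vec_add<<<blocks, threads, 0, s>>>(d_a, d_b, d_out, n);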
diff --git a/src/zprep.h b/src/zprep.h
index f9bb6b6..18c4c51 100644
--- a/src/zprep.h
+++ b/src/zprep.h
@@ -182,6 +182,7 @@ typedef struct
int is_freestanding; // 1 if --freestanding.
int mode_transpile; // 1 if 'transpile' command.
int use_cpp; // 1 if --cpp (emit C++ compatible code).
+ int use_cuda; // 1 if --cuda (emit CUDA-compatible code).
// GCC Flags accumulator.
char gcc_flags[4096];