amdgpu,nvptx: unify kernel calling conventions

AmdgpuKernel and PtxKernel are unified into a single Kernel calling convention.
There is no real reason for these to be separate: neither backend may emit the
calling convention of the other, so the two tags never coexist on one target.
This is in the same spirit as the .Interrupt calling convention, which already
lowers to different LLVM calling conventions depending on the target, and it
opens the way for SPIR-V kernels to be exported using the Kernel calling
convention.
Robin Voetter 2023-04-08 14:17:34 +02:00
parent 3f2025f59e
commit f12beb857a
6 changed files with 27 additions and 30 deletions
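
To illustrate the user-facing effect: a kernel entry point no longer selects a
backend-specific tag. A minimal sketch with a hypothetical kernel (addOne is
not part of this diff), assuming a GPU target:

    // Before this commit the calling convention named the backend:
    //   callconv(.PtxKernel)    was accepted only when targeting nvptx/nvptx64
    //   callconv(.AmdgpuKernel) was accepted only when targeting amdgcn
    // After it, a single tag covers every supported GPU target:
    pub export fn addOne(a: i32, out: *i32) callconv(.Kernel) void {
        out.* = a + 1;
    }

Sema still rejects .Kernel on non-GPU targets (see the allowed_platform check
below), so no safety is lost by merging the tags.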


@@ -160,8 +160,7 @@ pub const CallingConvention = enum {
     AAPCSVFP,
     SysV,
     Win64,
-    PtxKernel,
-    AmdgpuKernel,
+    Kernel,
 };

 /// This data structure is used by the Zig language code generation and


@@ -8883,7 +8883,7 @@ fn funcCommon(
         };
         return sema.failWithOwnedErrorMsg(msg);
     }
-    if (!ret_poison and !Type.fnCallingConventionAllowsZigTypes(cc_resolved) and !try sema.validateExternType(return_type, .ret_ty)) {
+    if (!ret_poison and !Type.fnCallingConventionAllowsZigTypes(target, cc_resolved) and !try sema.validateExternType(return_type, .ret_ty)) {
         const msg = msg: {
             const msg = try sema.errMsg(block, ret_ty_src, "return type '{}' not allowed in function with calling convention '{s}'", .{
                 return_type.fmt(sema.mod), @tagName(cc_resolved),
@@ -8961,13 +8961,9 @@ fn funcCommon(
                 .x86_64 => null,
                 else => @as([]const u8, "x86_64"),
             },
-            .PtxKernel => switch (arch) {
-                .nvptx, .nvptx64 => null,
-                else => @as([]const u8, "nvptx and nvptx64"),
-            },
-            .AmdgpuKernel => switch (arch) {
-                .amdgcn => null,
-                else => @as([]const u8, "amdgcn"),
+            .Kernel => switch (arch) {
+                .nvptx, .nvptx64, .amdgcn, .spirv32, .spirv64 => null,
+                else => @as([]const u8, "nvptx, amdgcn and SPIR-V"),
             },
         }) |allowed_platform| {
             return sema.fail(block, cc_src, "callconv '{s}' is only available on {s}, not {s}", .{
@@ -9093,10 +9089,11 @@ fn analyzeParameter(
     comptime_params[i] = param.is_comptime or requires_comptime;
     const this_generic = param.ty.tag() == .generic_poison;
     is_generic.* = is_generic.* or this_generic;
-    if (param.is_comptime and !Type.fnCallingConventionAllowsZigTypes(cc)) {
+    const target = sema.mod.getTarget();
+    if (param.is_comptime and !Type.fnCallingConventionAllowsZigTypes(target, cc)) {
         return sema.fail(block, param_src, "comptime parameters not allowed in function with calling convention '{s}'", .{@tagName(cc)});
     }
-    if (this_generic and !sema.no_partial_func_ty and !Type.fnCallingConventionAllowsZigTypes(cc)) {
+    if (this_generic and !sema.no_partial_func_ty and !Type.fnCallingConventionAllowsZigTypes(target, cc)) {
         return sema.fail(block, param_src, "generic parameters not allowed in function with calling convention '{s}'", .{@tagName(cc)});
     }
     if (!param.ty.isValidParamType()) {
@@ -9112,7 +9109,7 @@
         };
         return sema.failWithOwnedErrorMsg(msg);
     }
-    if (!this_generic and !Type.fnCallingConventionAllowsZigTypes(cc) and !try sema.validateExternType(param.ty, .param_ty)) {
+    if (!this_generic and !Type.fnCallingConventionAllowsZigTypes(target, cc) and !try sema.validateExternType(param.ty, .param_ty)) {
         const msg = msg: {
             const msg = try sema.errMsg(block, param_src, "parameter of type '{}' not allowed in function with calling convention '{s}'", .{
                 param.ty.fmt(sema.mod), @tagName(cc),
@@ -22786,12 +22783,13 @@ fn validateExternType(
         },
         .Fn => {
             if (position != .other) return false;
-            return switch (ty.fnCallingConvention()) {
-                // For now we want to authorize PTX kernel to use zig objects, even if we end up exposing the ABI.
-                // The goal is to experiment with more integrated CPU/GPU code.
-                .PtxKernel => true,
-                else => !Type.fnCallingConventionAllowsZigTypes(ty.fnCallingConvention()),
-            };
+            const target = sema.mod.getTarget();
+            // For now we want to authorize PTX kernel to use zig objects, even if we end up exposing the ABI.
+            // The goal is to experiment with more integrated CPU/GPU code.
+            if (ty.fnCallingConvention() == .Kernel and (target.cpu.arch == .nvptx or target.cpu.arch == .nvptx64)) {
+                return true;
+            }
+            return !Type.fnCallingConventionAllowsZigTypes(target, ty.fnCallingConvention());
         },
         .Enum => {
             var buf: Type.Payload.Bits = undefined;


@@ -10350,11 +10350,8 @@ fn toLlvmCallConv(cc: std.builtin.CallingConvention, target: std.Target) llvm.CallConv {
         .Signal => .AVR_SIGNAL,
         .SysV => .X86_64_SysV,
         .Win64 => .Win64,
-        .PtxKernel => return switch (target.cpu.arch) {
+        .Kernel => return switch (target.cpu.arch) {
             .nvptx, .nvptx64 => .PTX_Kernel,
-            else => unreachable,
-        },
-        .AmdgpuKernel => return switch (target.cpu.arch) {
             .amdgcn => .AMDGPU_KERNEL,
             else => unreachable,
         },


@@ -1,7 +1,7 @@
 //! NVidia PTX (Parallel Thread Execution)
 //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
 //! For this we rely on the nvptx backend of LLVM
-//! Kernel functions need to be marked both as "export" and "callconv(.PtxKernel)"
+//! Kernel functions need to be marked both as "export" and "callconv(.Kernel)"
 const NvPtx = @This();


@@ -4796,9 +4796,12 @@
     }

     /// Asserts the type is a function.
-    pub fn fnCallingConventionAllowsZigTypes(cc: std.builtin.CallingConvention) bool {
+    pub fn fnCallingConventionAllowsZigTypes(target: Target, cc: std.builtin.CallingConvention) bool {
         return switch (cc) {
-            .Unspecified, .Async, .Inline, .PtxKernel => true,
+            .Unspecified, .Async, .Inline => true,
+            // For now we want to authorize PTX kernel to use zig objects, even if we end up exposing the ABI.
+            // The goal is to experiment with more integrated CPU/GPU code.
+            .Kernel => target.cpu.arch == .nvptx or target.cpu.arch == .nvptx64,
             else => false,
         };
     }
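
In other words, whether a kernel may use Zig-only types in its signature now
depends on the target as well as the calling convention: the PTX experiment is
kept for nvptx/nvptx64, while .Kernel behaves like an extern calling convention
everywhere else. A sketch of the visible effect, using a hypothetical kernel
that is not part of this diff:

    // Slices are Zig-only types with no guaranteed extern layout.
    pub export fn scale(xs: []f32) callconv(.Kernel) void {
        xs[0] *= 2.0;
    }
    // Targeting nvptx64: accepted, since fnCallingConventionAllowsZigTypes
    // returns true for .Kernel on nvptx/nvptx64.
    // Targeting amdgcn: rejected by Sema with an error along the lines of
    // "parameter of type '[]f32' not allowed in function with calling
    // convention 'Kernel'".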


@@ -10,7 +10,7 @@ pub fn addCases(ctx: *Cases) !void {
         \\    return a + b;
         \\}
         \\
-        \\pub export fn add_and_substract(a: i32, out: *i32) callconv(.PtxKernel) void {
+        \\pub export fn add_and_substract(a: i32, out: *i32) callconv(.Kernel) void {
         \\    const x = add(a, 7);
         \\    var y = add(2, 0);
         \\    y -= x;
@@ -29,7 +29,7 @@
         \\    );
         \\}
         \\
-        \\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void {
+        \\pub export fn special_reg(a: []const i32, out: []i32) callconv(.Kernel) void {
         \\    const i = threadIdX();
         \\    out[i] = a[i] + 7;
         \\}
@@ -42,7 +42,7 @@
     case.addCompile(
         \\var x: i32 addrspace(.global) = 0;
         \\
-        \\pub export fn increment(out: *i32) callconv(.PtxKernel) void {
+        \\pub export fn increment(out: *i32) callconv(.Kernel) void {
         \\    x += 1;
         \\    out.* = x;
         \\}
@@ -59,7 +59,7 @@
         \\}
         \\
         \\var _sdata: [1024]f32 addrspace(.shared) = undefined;
-        \\pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(.PtxKernel) void {
+        \\pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(.Kernel) void {
         \\    var sdata = @addrSpaceCast(.generic, &_sdata);
         \\    const tid: u32 = threadIdX();
         \\    var sum = d_x[tid];