enable Gpu address spaces (#10884)

This commit is contained in:
gwenzek 2022-02-21 20:05:27 +01:00 committed by GitHub
parent d8da9a01fc
commit 628e9e6d04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 1 deletions

View file

@ -157,6 +157,12 @@ pub const AddressSpace = enum {
gs,
fs,
ss,
// GPU address spaces
global,
constant,
param,
shared,
local,
};
/// This data structure is used by the Zig language code generation and

View file

@ -18006,10 +18006,14 @@ pub fn analyzeAddrspace(
const address_space = addrspace_tv.val.toEnum(std.builtin.AddressSpace);
const target = sema.mod.getTarget();
const arch = target.cpu.arch;
const is_gpu = arch == .nvptx or arch == .nvptx64;
const supported = switch (address_space) {
.generic => true,
.gs, .fs, .ss => (arch == .i386 or arch == .x86_64) and ctx == .pointer,
// TODO: check that .shared and .local are left uninitialized
.global, .param, .shared, .local => is_gpu,
.constant => is_gpu and (ctx == .constant),
};
if (!supported) {
@ -18020,7 +18024,6 @@ pub fn analyzeAddrspace(
.constant => "constant values",
.pointer => "pointers",
};
return sema.fail(
block,
src,

View file

@ -801,6 +801,16 @@ pub const DeclGen = struct {
.gs => llvm.address_space.x86.gs,
.fs => llvm.address_space.x86.fs,
.ss => llvm.address_space.x86.ss,
else => unreachable,
},
.nvptx, .nvptx64 => switch (address_space) {
.generic => llvm.address_space.default,
.global => llvm.address_space.nvptx.global,
.constant => llvm.address_space.nvptx.constant,
.param => llvm.address_space.nvptx.param,
.shared => llvm.address_space.nvptx.shared,
.local => llvm.address_space.nvptx.local,
else => unreachable,
},
else => switch (address_space) {
.generic => llvm.address_space.default,

View file

@ -16,4 +16,5 @@ pub fn addCases(ctx: *TestContext) !void {
try @import("stage2/riscv64.zig").addCases(ctx);
try @import("stage2/plan9.zig").addCases(ctx);
try @import("stage2/x86_64.zig").addCases(ctx);
try @import("stage2/nvptx.zig").addCases(ctx);
}

57
test/stage2/nvptx.zig Normal file
View file

@ -0,0 +1,57 @@
const std = @import("std");
const TestContext = @import("../../src/test.zig").TestContext;
const nvptx = std.zig.CrossTarget{
.cpu_arch = .nvptx64,
.os_tag = .cuda,
};
pub fn addCases(ctx: *TestContext) !void {
{
var case = ctx.exeUsingLlvmBackend("simple addition and subtraction", nvptx);
case.compiles(
\\fn add(a: i32, b: i32) i32 {
\\ return a + b;
\\}
\\
\\pub export fn main(a: i32, out: *i32) callconv(.PtxKernel) void {
\\ const x = add(a, 7);
\\ var y = add(2, 0);
\\ y -= x;
\\ out.* = y;
\\}
);
}
{
var case = ctx.exeUsingLlvmBackend("read special registers", nvptx);
case.compiles(
\\fn tid() usize {
\\ var tid = asm volatile ("mov.u32 \t$0, %tid.x;"
\\ : [ret] "=r" (-> u32),
\\ );
\\ return @as(usize, tid);
\\}
\\
\\pub export fn main(a: []const i32, out: []i32) callconv(.PtxKernel) void {
\\ const i = tid();
\\ out[i] = a[i] + 7;
\\}
);
}
{
var case = ctx.exeUsingLlvmBackend("address spaces", nvptx);
case.compiles(
\\var x: u32 addrspace(.global) = 0;
\\
\\pub export fn increment(out: *i32) callconv(.PtxKernel) void {
\\ x += 1;
\\ out.* = x;
\\}
);
}
}