diff --git a/examples/endpoint/stopendpoint.zig b/examples/endpoint/stopendpoint.zig index f96ec37..0e5d067 100644 --- a/examples/endpoint/stopendpoint.zig +++ b/examples/endpoint/stopendpoint.zig @@ -14,14 +14,14 @@ pub fn init(path: []const u8) StopEndpoint { }; } -pub fn get(e: *StopEndpoint, r: zap.Request) anyerror!void { +pub fn get(e: *StopEndpoint, r: zap.Request) !void { _ = e; _ = r; zap.stop(); } -pub fn post(_: *StopEndpoint, _: zap.Request) anyerror!void {} -pub fn put(_: *StopEndpoint, _: zap.Request) anyerror!void {} -pub fn delete(_: *StopEndpoint, _: zap.Request) anyerror!void {} -pub fn patch(_: *StopEndpoint, _: zap.Request) anyerror!void {} -pub fn options(_: *StopEndpoint, _: zap.Request) anyerror!void {} +pub fn post(_: *StopEndpoint, _: zap.Request) !void {} +pub fn put(_: *StopEndpoint, _: zap.Request) !void {} +pub fn delete(_: *StopEndpoint, _: zap.Request) !void {} +pub fn patch(_: *StopEndpoint, _: zap.Request) !void {} +pub fn options(_: *StopEndpoint, _: zap.Request) !void {} diff --git a/examples/endpoint/userweb.zig b/examples/endpoint/userweb.zig index d45370c..2d13bf9 100644 --- a/examples/endpoint/userweb.zig +++ b/examples/endpoint/userweb.zig @@ -43,8 +43,9 @@ fn userIdFromPath(self: *UserWeb, path: []const u8) ?usize { return null; } -pub fn put(_: *UserWeb, _: zap.Request) anyerror!void {} -pub fn get(self: *UserWeb, r: zap.Request) anyerror!void { +pub fn put(_: *UserWeb, _: zap.Request) !void {} + +pub fn get(self: *UserWeb, r: zap.Request) !void { if (r.path) |path| { // /users if (path.len == self.path.len) { @@ -69,7 +70,7 @@ fn listUsers(self: *UserWeb, r: zap.Request) !void { } } -pub fn post(self: *UserWeb, r: zap.Request) anyerror!void { +pub fn post(self: *UserWeb, r: zap.Request) !void { if (r.body) |body| { const maybe_user: ?std.json.Parsed(User) = std.json.parseFromSlice(User, self.alloc, body, .{}) catch null; if (maybe_user) |u| { @@ -86,7 +87,7 @@ pub fn post(self: *UserWeb, r: zap.Request) anyerror!void { } } -pub fn patch(self: *UserWeb, r: zap.Request) anyerror!void { +pub fn patch(self: *UserWeb, r: zap.Request) !void { if (r.path) |path| { if (self.userIdFromPath(path)) |id| { if (self._users.get(id)) |_| { @@ -109,7 +110,7 @@ pub fn patch(self: *UserWeb, r: zap.Request) anyerror!void { } } -pub fn delete(self: *UserWeb, r: zap.Request) anyerror!void { +pub fn delete(self: *UserWeb, r: zap.Request) !void { if (r.path) |path| { if (self.userIdFromPath(path)) |id| { var jsonbuf: [128]u8 = undefined; @@ -124,7 +125,7 @@ pub fn delete(self: *UserWeb, r: zap.Request) anyerror!void { } } -pub fn options(_: *UserWeb, r: zap.Request) anyerror!void { +pub fn options(_: *UserWeb, r: zap.Request) !void { try r.setHeader("Access-Control-Allow-Origin", "*"); try r.setHeader("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS"); r.setStatus(zap.http.StatusCode.no_content); diff --git a/examples/endpoint_auth/endpoint_auth.zig b/examples/endpoint_auth/endpoint_auth.zig index 57b2a37..0cf4879 100644 --- a/examples/endpoint_auth/endpoint_auth.zig +++ b/examples/endpoint_auth/endpoint_auth.zig @@ -1,3 +1,9 @@ +//! +//! Part of the Zap examples. +//! +//! Build me with `zig build endpoint_auth`. +//! Run me with `zig build run-endpoint_auth`. +//! const std = @import("std"); const zap = @import("zap"); diff --git a/src/App.zig b/src/App.zig index 4699b77..d53cba3 100644 --- a/src/App.zig +++ b/src/App.zig @@ -153,12 +153,56 @@ pub fn Create(comptime Context: type) type { "patch", "options", }; + const params_to_check = [_]type{ + *T, + Allocator, + *Context, + Request, + }; inline for (methods_to_check) |method| { if (@hasDecl(T, method)) { const Method = @TypeOf(@field(T, method)); - const Expected = fn (_: *T, _: Allocator, _: *Context, _: Request) anyerror!void; - if (Method != Expected) { - @compileError(method ++ " method of " ++ @typeName(T) ++ " has wrong type:\n" ++ @typeName(Method) ++ "\nexpected:\n" ++ @typeName(Expected)); + const method_info = @typeInfo(Method); + if (method_info != .@"fn") { + @compileError("Expected `" ++ @typeName(T) ++ "." ++ method ++ "` to be a request handler method, got: " ++ @typeName(Method)); + } + + // now check parameters + const params = method_info.@"fn".params; + if (params.len != params_to_check.len) { + @compileError(std.fmt.comptimePrint( + "Expected method `{s}.{s}` to have {d} parameters, got {d}", + .{ + @typeName(T), + method, + params_to_check.len, + params.len, + }, + )); + } + + inline for (params_to_check, 0..) |param_type_expected, i| { + if (params[i].type.? != param_type_expected) { + @compileError(std.fmt.comptimePrint( + "Expected parameter {d} of method {s}.{s} to be {s}, got {s}", + .{ + i + 1, + @typeName(T), + method, + @typeName(param_type_expected), + @typeName(params[i].type.?), + }, + )); + } + } + + const ret_type = method_info.@"fn".return_type.?; + const ret_info = @typeInfo(ret_type); + if (ret_info != .error_union) { + @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: " ++ @typeName(ret_type)); + } + if (ret_info.error_union.payload != void) { + @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: !" ++ @typeName(ret_info.error_union.payload)); } } else { @compileError(@typeName(T) ++ " has no method named `" ++ method ++ "`"); diff --git a/src/endpoint.zig b/src/endpoint.zig index c32a489..ae67fbc 100644 --- a/src/endpoint.zig +++ b/src/endpoint.zig @@ -95,10 +95,57 @@ pub fn checkEndpointType(T: type) void { "patch", "options", }; + + const params_to_check = [_]type{ + *T, + Request, + }; + inline for (methods_to_check) |method| { if (@hasDecl(T, method)) { - if (@TypeOf(@field(T, method)) != fn (_: *T, _: Request) anyerror!void) { - @compileError(method ++ " method of " ++ @typeName(T) ++ " has wrong type:\n" ++ @typeName(@TypeOf(T.get)) ++ "\nexpected:\n" ++ @typeName(fn (_: *T, _: Request) anyerror!void)); + const Method = @TypeOf(@field(T, method)); + const method_info = @typeInfo(Method); + if (method_info != .@"fn") { + @compileError("Expected `" ++ @typeName(T) ++ "." ++ method ++ "` to be a request handler method, got: " ++ @typeName(Method)); + } + + // now check parameters + const params = method_info.@"fn".params; + if (params.len != params_to_check.len) { + @compileError(std.fmt.comptimePrint( + "Expected method `{s}.{s}` to have {d} parameters, got {d}", + .{ + @typeName(T), + method, + params_to_check.len, + params.len, + }, + )); + } + + inline for (params_to_check, 0..) |param_type_expected, i| { + if (params[i].type.? != param_type_expected) { + @compileError(std.fmt.comptimePrint( + "Expected parameter {d} of method {s}.{s} to be {s}, got {s}", + .{ + i + 1, + @typeName(T), + method, + @typeName(param_type_expected), + @typeName(params[i].type.?), + }, + )); + } + } + + // check return type + const ret_type = method_info.@"fn".return_type.?; + const ret_info = @typeInfo(ret_type); + if (ret_info != .error_union) { + @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: " ++ @typeName(ret_type)); + } + if (ret_info.error_union.payload != void) { + @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: !" ++ @typeName(ret_info.error_union.payload)); } } else { @compileError(@typeName(T) ++ " has no method named `" ++ method ++ "`");