diff --git a/src/endpoint.zig b/src/endpoint.zig index 7fcb305..4c55e3f 100644 --- a/src/endpoint.zig +++ b/src/endpoint.zig @@ -77,117 +77,14 @@ const Request = zap.Request; const ListenerSettings = zap.HttpListenerSettings; const HttpListener = zap.HttpListener; -const ImplementedMethods = struct { - get: bool = false, - head: bool = false, - post: bool = false, - put: bool = false, - delete: bool = false, - patch: bool = false, - options: bool = false, - custom_method: bool = false, -}; - -pub fn checkEndpointType(T: type) ImplementedMethods { - var implemented_methods: ImplementedMethods = .{}; - - if (@hasField(T, "path")) { - if (@FieldType(T, "path") != []const u8) { - @compileError(@typeName(@FieldType(T, "path")) ++ " has wrong type, expected: []const u8"); - } - } else { - @compileError(@typeName(T) ++ " has no path field"); - } - - if (@hasField(T, "error_strategy")) { - if (@FieldType(T, "error_strategy") != ErrorStrategy) { - @compileError(@typeName(@FieldType(T, "error_strategy")) ++ " has wrong type, expected: zap.Endpoint.ErrorStrategy"); - } - } else { - @compileError(@typeName(T) ++ " has no error_strategy field"); - } - - // TODO: use field names of ImplementedMethods - const methods_to_check = [_][]const u8{ - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "custom_method", - }; - - const params_to_check = [_]type{ - *T, - Request, - }; - - inline for (methods_to_check) |method| { - if (@hasDecl(T, method)) { - const Method = @TypeOf(@field(T, method)); - const method_info = @typeInfo(Method); - if (method_info != .@"fn") { - @compileError("Expected `" ++ @typeName(T) ++ "." ++ method ++ "` to be a request handler method, got: " ++ @typeName(Method)); - } - - // now check parameters - const params = method_info.@"fn".params; - if (params.len != params_to_check.len) { - @compileError(std.fmt.comptimePrint( - "Expected method `{s}.{s}` to have {d} parameters, got {d}", - .{ - @typeName(T), - method, - params_to_check.len, - params.len, - }, - )); - } - - inline for (params_to_check, 0..) |param_type_expected, i| { - if (params[i].type.? != param_type_expected) { - @compileError(std.fmt.comptimePrint( - "Expected parameter {d} of method {s}.{s} to be {s}, got {s}", - .{ - i + 1, - @typeName(T), - method, - @typeName(param_type_expected), - @typeName(params[i].type.?), - }, - )); - } - } - - // check return type - const ret_type = method_info.@"fn".return_type.?; - const ret_info = @typeInfo(ret_type); - if (ret_info != .error_union) { - @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: " ++ @typeName(ret_type)); - } - if (ret_info.error_union.payload != void) { - @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: !" ++ @typeName(ret_info.error_union.payload)); - } - @field(implemented_methods, method) = true; - } else { - @compileError(@typeName(T) ++ " has no method named `" ++ method ++ "`"); - // TODO: shall we warn? - // No, we should provide a default implementation that calls - // "unhandled request" callback, and if that's not defined, log the - // request as being unhandled. - } - } - return implemented_methods; -} - pub const Binder = struct { pub const Interface = struct { - call: *const fn (*Interface, zap.Request) anyerror!void = undefined, + pub const CallBack = *const fn (*Interface, zap.Request) anyerror!void; + pub const MethodCallbacks = [std.meta.fields(zap.http.Method).len]CallBack; + + methods: MethodCallbacks, path: []const u8, destroy: *const fn (*Interface, std.mem.Allocator) void = undefined, - implemented_methods: ImplementedMethods = undefined, }; pub fn Bind(ArbitraryEndpoint: type) type { return struct { @@ -211,17 +108,13 @@ pub const Binder = struct { try self.onRequest(r); } + pub fn onUnimplementedInterface(_: *Interface, _: zap.Request) !void { + return error.NotImplemented; + } + pub fn onRequest(self: *Bound, r: zap.Request) !void { - const ret = switch (r.methodAsEnum()) { - .GET => self.endpoint.*.get(r), - .HEAD => self.endpoint.*.head(r), - .POST => self.endpoint.*.post(r), - .PUT => self.endpoint.*.put(r), - .DELETE => self.endpoint.*.delete(r), - .PATCH => self.endpoint.*.patch(r), - .OPTIONS => self.endpoint.*.options(r), - .UNKNOWN => self.endpoint.*.custom_method(r), - }; + const method_index = @intFromEnum(r.methodAsEnum()); + const ret = self.interface.methods[method_index](self, r); if (ret) { // handled without error } else |err| { @@ -232,19 +125,105 @@ pub const Binder = struct { } } } + + fn setupEndpoint() Interface.MethodCallbacks { + const T = ArbitraryEndpoint; + if (@hasField(T, "path")) { + if (@FieldType(T, "path") != []const u8) { + @compileError(@typeName(@FieldType(T, "path")) ++ " has wrong type, expected: []const u8"); + } + } else { + @compileError(@typeName(T) ++ " has no path field"); + } + + if (@hasField(T, "error_strategy")) { + if (@FieldType(T, "error_strategy") != ErrorStrategy) { + @compileError(@typeName(@FieldType(T, "error_strategy")) ++ " has wrong type, expected: zap.Endpoint.ErrorStrategy"); + } + } else { + @compileError(@typeName(T) ++ " has no error_strategy field"); + } + + const params_to_check = [_]type{ + *T, + Request, + }; + + comptime { + var method_callbacks: Interface.MethodCallbacks = undefined; + for (std.meta.tags(zap.http.Method)) |http_method| { + const method: []const u8 = if (http_method != .unknown) @tagName(http_method) else "custom_method"; + if (@hasDecl(T, method)) { + const Method = @TypeOf(@field(T, method)); + const method_info = @typeInfo(Method); + if (method_info != .@"fn") { + @compileError("Expected `" ++ @typeName(T) ++ "." ++ method ++ "` to be a request handler method, got: " ++ @typeName(Method)); + } + + // now check parameters + const params = method_info.@"fn".params; + if (params.len != params_to_check.len) { + @compileError(std.fmt.comptimePrint( + "Expected method `{s}.{s}` to have {d} parameters, got {d}", + .{ + @typeName(T), + method, + params_to_check.len, + params.len, + }, + )); + } + + for (params_to_check, 0..) |param_type_expected, i| { + if (params[i].type.? != param_type_expected) { + @compileError(std.fmt.comptimePrint( + "Expected parameter {d} of method {s}.{s} to be {s}, got {s}", + .{ + i + 1, + @typeName(T), + method, + @typeName(param_type_expected), + @typeName(params[i].type.?), + }, + )); + } + } + + // check return type + const ret_type = method_info.@"fn".return_type.?; + const ret_info = @typeInfo(ret_type); + if (ret_info != .error_union) { + @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: " ++ @typeName(ret_type)); + } + if (ret_info.error_union.payload != void) { + @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: !" ++ @typeName(ret_info.error_union.payload)); + } + + method_callbacks[@intFromEnum(http_method)] = Bound.onRequestInterface; + } else { + @compileError(@typeName(T) ++ " has no method named `" ++ method ++ "`"); + // TODO: shall we warn? + // No, we should provide a default implementation that calls + // "unhandled request" callback, and if that's not defined, log the + // request as being unhandled. + } + } + return method_callbacks; + } + } }; } pub fn init(ArbitraryEndpoint: type, value: *ArbitraryEndpoint) Binder.Bind(ArbitraryEndpoint) { - const implemented_methods = checkEndpointType(ArbitraryEndpoint); const BoundEp = Binder.Bind(ArbitraryEndpoint); + const methods = BoundEp.setupEndpoint(); return .{ .endpoint = value, .interface = .{ .path = value.path, .call = BoundEp.onRequestInterface, .destroy = BoundEp.destroy, - .implemented_methods = implemented_methods, + .implemented_methods = methods, }, }; } diff --git a/src/http.zig b/src/http.zig index f07ad2b..76af9a3 100644 --- a/src/http.zig +++ b/src/http.zig @@ -131,21 +131,21 @@ pub const StatusCode = enum(u16) { }; pub const Method = enum { - GET, - HEAD, - POST, - PUT, - DELETE, - PATCH, - OPTIONS, - UNKNOWN, + get, + head, + post, + put, + delete, + patch, + options, + unknown, }; pub fn methodToEnum(method: ?[]const u8) Method { { if (method) |m| { - return std.meta.stringToEnum(Method, m) orelse .UNKNOWN; + return std.meta.stringToEnum(Method, m) orelse .unknown; } - return Method.UNKNOWN; + return .unknown; } }