From 645db5c8d06ad20e0310329579484463e32e6028 Mon Sep 17 00:00:00 2001 From: Rene Schallner Date: Mon, 8 May 2023 00:06:57 +0200 Subject: [PATCH] AuthResult is now an enum (was: bool) --- src/endpoint.zig | 96 ++++++++++++++++++++++------------------- src/http_auth.zig | 59 +++++++++++++++---------- src/tests/test_auth.zig | 24 +++++------ 3 files changed, 100 insertions(+), 79 deletions(-) diff --git a/src/endpoint.zig b/src/endpoint.zig index 3a24431..a53c6c3 100644 --- a/src/endpoint.zig +++ b/src/endpoint.zig @@ -89,69 +89,77 @@ pub fn AuthenticatingEndpoint(comptime Authenticator: type) type { /// here, the auth_endpoint will be passed in pub fn get(e: *SimpleEndpoint, r: zap.SimpleRequest) void { const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); - if (authEp.authenticator.authenticateRequest(&r) == false) { - if (e.settings.unauthorized) |foo| { - foo(authEp.endpoint, r); - return; - } else { - r.setStatus(.unauthorized); - r.sendBody("UNAUTHORIZED") catch return; - return; - } + switch (authEp.authenticator.authenticateRequest(&r)) { + .AuthFailed => { + if (e.settings.unauthorized) |unauthorized| { + unauthorized(authEp.endpoint, r); + return; + } else { + r.setStatus(.unauthorized); + r.sendBody("UNAUTHORIZED") catch return; + return; + } + }, + .AuthOK => authEp.endpoint.settings.get.?(authEp.endpoint, r), + .Handled => {}, } - // auth successful - authEp.endpoint.settings.get.?(authEp.endpoint, r); } /// here, the auth_endpoint will be passed in pub fn post(e: *SimpleEndpoint, r: zap.SimpleRequest) void { const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); - if (authEp.authenticator.authenticateRequest(&r) == false) { - if (e.settings.unauthorized) |foo| { - foo(authEp.endpoint, r); - return; - } else { - r.setStatus(.unauthorized); - r.sendBody("UNAUTHORIZED") catch return; - return; - } + switch (authEp.authenticator.authenticateRequest(&r)) { + .AuthFailed => { + if (e.settings.unauthorized) |unauthorized| { + unauthorized(authEp.endpoint, r); + return; + } else { + r.setStatus(.unauthorized); + r.sendBody("UNAUTHORIZED") catch return; + return; + } + }, + .AuthOK => authEp.endpoint.settings.post.?(authEp.endpoint, r), + .Handled => {}, } - // auth successful - authEp.endpoint.settings.post.?(authEp.endpoint, r); } /// here, the auth_endpoint will be passed in pub fn put(e: *SimpleEndpoint, r: zap.SimpleRequest) void { const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); - if (authEp.authenticator.authenticateRequest(&r) == false) { - if (e.settings.unauthorized) |foo| { - foo(authEp.endpoint, r); - return; - } else { - r.setStatus(.unauthorized); - r.sendBody("UNAUTHORIZED") catch return; - return; - } + switch (authEp.authenticator.authenticateRequest(&r)) { + .AuthFailed => { + if (e.settings.unauthorized) |unauthorized| { + unauthorized(authEp.endpoint, r); + return; + } else { + r.setStatus(.unauthorized); + r.sendBody("UNAUTHORIZED") catch return; + return; + } + }, + .AuthOK => authEp.endpoint.settings.put.?(authEp.endpoint, r), + .Handled => {}, } - // auth successful - authEp.endpoint.settings.put.?(authEp.endpoint, r); } /// here, the auth_endpoint will be passed in pub fn delete(e: *SimpleEndpoint, r: zap.SimpleRequest) void { const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); - if (authEp.authenticator.authenticateRequest(&r) == false) { - if (e.settings.unauthorized) |foo| { - foo(authEp.endpoint, r); - return; - } else { - r.setStatus(.unauthorized); - r.sendBody("UNAUTHORIZED") catch return; - return; - } + switch (authEp.authenticator.authenticateRequest(&r)) { + .AuthFailed => { + if (e.settings.unauthorized) |unauthorized| { + unauthorized(authEp.endpoint, r); + return; + } else { + r.setStatus(.unauthorized); + r.sendBody("UNAUTHORIZED") catch return; + return; + } + }, + .AuthOK => authEp.endpoint.settings.delete.?(authEp.endpoint, r), + .Handled => {}, } - // auth successful - authEp.endpoint.settings.delete.?(authEp.endpoint, r); } }; } diff --git a/src/http_auth.zig b/src/http_auth.zig index 88574e1..be96f21 100644 --- a/src/http_auth.zig +++ b/src/http_auth.zig @@ -48,6 +48,19 @@ const BasicAuthStrategy = enum { Token68, }; +pub const AuthResult = enum { + /// authentication / authorization was successful + AuthOK, + /// authentication / authorization failed + AuthFailed, + /// the authenticator handled the request that didn't pass authentication / + /// authorization . + /// this is used to implement authenticators that redirect to a login + /// page. An AuthenticatingEndpoint will not do the default, which is trying + /// to call the `unauthorized` callback or. + Handled, +}; + /// HTTP Basic Authentication RFC 7617 /// "Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" /// user-pass strings: "$username:$password" -> base64 @@ -91,7 +104,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { } /// Use this to decode the auth_header into user:pass, lookup pass in lookup - pub fn authenticateUserPass(self: *Self, auth_header: []const u8) bool { + pub fn authenticateUserPass(self: *Self, auth_header: []const u8) AuthResult { zap.debug("AuthenticateUserPass\n", .{}); const encoded = auth_header[AuthScheme.Basic.str().len..]; const decoder = std.base64.standard.Decoder; @@ -102,7 +115,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { "ERROR: UserPassAuth: decoded_size {d} >= buffer.len {d}\n", .{ decoded_size, buffer.len }, ); - return false; + return .AuthFailed; } var decoded = buffer[0..decoded_size]; decoder.decode(decoded, encoded) catch |err| { @@ -110,7 +123,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { "ERROR: UserPassAuth: unable to decode `{s}`: {any}\n", .{ encoded, err }, ); - return false; + return .AuthFailed; }; // we have decoded // we can split @@ -122,7 +135,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { "ERROR: UserPassAuth: user {any} or pass {any} is null\n", .{ user, pass }, ); - return false; + return .AuthFailed; } // now, do the lookup const actual_pw = self.lookup.*.get(user.?); @@ -132,13 +145,13 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { "INFO: UserPassAuth for user `{s}`: `{s}` == pass `{s}` = {}\n", .{ user.?, pw, pass.?, ret }, ); - return ret; + return if (ret) .AuthOK else .AuthFailed; } else { zap.debug( "ERROR: UserPassAuth: user `{s}` not found in map of size {d}!\n", .{ user.?, self.lookup.*.count() }, ); - return false; + return .AuthFailed; } } else |err| { // can't calc slice size --> fallthrough to return false @@ -146,20 +159,20 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { "ERROR: UserPassAuth: cannot calc slize size for encoded `{s}`: {any} \n", .{ encoded, err }, ); - return false; + return .AuthFailed; } zap.debug("UNREACHABLE\n", .{}); - return false; + return .AuthFailed; } /// Use this to just look up if the base64-encoded auth_header exists in lookup - pub fn authenticateToken68(self: *Self, auth_header: []const u8) bool { + pub fn authenticateToken68(self: *Self, auth_header: []const u8) AuthResult { const token = auth_header[AuthScheme.Basic.str().len..]; - return self.lookup.*.contains(token); + return if (self.lookup.*.contains(token)) .AuthOK else .AuthFailed; } // dispatch based on kind - pub fn authenticate(self: *Self, auth_header: []const u8) bool { + pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult { zap.debug("AUTHENTICATE\n", .{}); // switch (self.kind) { switch (kind) { @@ -167,7 +180,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { .Token68 => return self.authenticateToken68(auth_header), } } - pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { + pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult { zap.debug("AUTHENTICATE REQUEST\n", .{}); if (extractAuthHeader(.Basic, r)) |auth_header| { zap.debug("Authentication Header found!\n", .{}); @@ -180,7 +193,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { } } zap.debug("NO fitting Auth Header found!\n", .{}); - return false; + return .AuthFailed; } }; } @@ -213,19 +226,19 @@ pub const BearerAuthSingle = struct { .realm = if (realm) |the_realm| try allocator.dupe(u8, the_realm) else null, }; } - pub fn authenticate(self: *Self, auth_header: []const u8) bool { + pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult { if (checkAuthHeader(.Bearer, auth_header) == false) { - return false; + return .AuthFailed; } const token = auth_header[AuthScheme.Bearer.str().len..]; - return std.mem.eql(u8, token, self.token); + return if (std.mem.eql(u8, token, self.token)) .AuthOK else .AuthFailed; } - pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { + pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult { if (extractAuthHeader(.Bearer, r)) |auth_header| { return self.authenticate(auth_header); } - return false; + return .AuthFailed; } pub fn deinit(self: *Self) void { @@ -271,19 +284,19 @@ pub fn BearerAuthMulti(comptime Lookup: type) type { } } - pub fn authenticate(self: *Self, auth_header: []const u8) bool { + pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult { if (checkAuthHeader(.Bearer, auth_header) == false) { - return false; + return .AuthFailed; } const token = auth_header[AuthScheme.Bearer.str().len..]; - return self.lookup.*.contains(token); + return if (self.lookup.*.contains(token)) .AuthOK else .AuthFailed; } - pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { + pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult { if (extractAuthHeader(.Bearer, r)) |auth_header| { return self.authenticate(auth_header); } - return false; + return .AuthFailed; } }; } diff --git a/src/tests/test_auth.zig b/src/tests/test_auth.zig index 6fc56bb..2b5e246 100644 --- a/src/tests/test_auth.zig +++ b/src/tests/test_auth.zig @@ -17,10 +17,10 @@ test "BearerAuthSingle authenticate" { defer auth.deinit(); // invalid auth header - try std.testing.expectEqual(auth.authenticate("wrong header"), false); - try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); + try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); + try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), .AuthFailed); // ok - try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); + try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), .AuthOK); } test "BearerAuthMulti authenticate" { @@ -37,10 +37,10 @@ test "BearerAuthMulti authenticate" { defer auth.deinit(); // invalid auth header - try std.testing.expectEqual(auth.authenticate("wrong header"), false); - try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); + try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); + try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), .AuthFailed); // ok - try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); + try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), .AuthOK); } test "BasicAuth Token68" { @@ -59,10 +59,10 @@ test "BasicAuth Token68" { defer auth.deinit(); // invalid auth header - try std.testing.expectEqual(auth.authenticate("wrong header"), false); - try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); + try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); + try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), .AuthFailed); // ok - try std.testing.expectEqual(auth.authenticate("Basic " ++ token), true); + try std.testing.expectEqual(auth.authenticate("Basic " ++ token), .AuthOK); } test "BasicAuth UserPass" { @@ -90,13 +90,13 @@ test "BasicAuth UserPass" { defer auth.deinit(); // invalid auth header - try std.testing.expectEqual(auth.authenticate("wrong header"), false); - try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); + try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); + try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), .AuthFailed); // ok const expected = try std.fmt.allocPrint(a, "Basic {s}", .{encoded}); defer a.free(expected); - try std.testing.expectEqual(auth.authenticate(expected), true); + try std.testing.expectEqual(auth.authenticate(expected), .AuthOK); } const HTTP_RESPONSE: []const u8 =