1
0
Fork 0
mirror of https://github.com/zigzap/zap.git synced 2025-10-21 07:34:08 +00:00

AuthResult is now an enum (was: bool)

This commit is contained in:
Rene Schallner 2023-05-08 00:06:57 +02:00
parent ed156ba654
commit 645db5c8d0
3 changed files with 100 additions and 79 deletions

View file

@ -89,69 +89,77 @@ pub fn AuthenticatingEndpoint(comptime Authenticator: type) type {
/// here, the auth_endpoint will be passed in /// here, the auth_endpoint will be passed in
pub fn get(e: *SimpleEndpoint, r: zap.SimpleRequest) void { pub fn get(e: *SimpleEndpoint, r: zap.SimpleRequest) void {
const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e);
if (authEp.authenticator.authenticateRequest(&r) == false) { switch (authEp.authenticator.authenticateRequest(&r)) {
if (e.settings.unauthorized) |foo| { .AuthFailed => {
foo(authEp.endpoint, r); if (e.settings.unauthorized) |unauthorized| {
unauthorized(authEp.endpoint, r);
return; return;
} else { } else {
r.setStatus(.unauthorized); r.setStatus(.unauthorized);
r.sendBody("UNAUTHORIZED") catch return; r.sendBody("UNAUTHORIZED") catch return;
return; return;
} }
},
.AuthOK => authEp.endpoint.settings.get.?(authEp.endpoint, r),
.Handled => {},
} }
// auth successful
authEp.endpoint.settings.get.?(authEp.endpoint, r);
} }
/// here, the auth_endpoint will be passed in /// here, the auth_endpoint will be passed in
pub fn post(e: *SimpleEndpoint, r: zap.SimpleRequest) void { pub fn post(e: *SimpleEndpoint, r: zap.SimpleRequest) void {
const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e);
if (authEp.authenticator.authenticateRequest(&r) == false) { switch (authEp.authenticator.authenticateRequest(&r)) {
if (e.settings.unauthorized) |foo| { .AuthFailed => {
foo(authEp.endpoint, r); if (e.settings.unauthorized) |unauthorized| {
unauthorized(authEp.endpoint, r);
return; return;
} else { } else {
r.setStatus(.unauthorized); r.setStatus(.unauthorized);
r.sendBody("UNAUTHORIZED") catch return; r.sendBody("UNAUTHORIZED") catch return;
return; return;
} }
},
.AuthOK => authEp.endpoint.settings.post.?(authEp.endpoint, r),
.Handled => {},
} }
// auth successful
authEp.endpoint.settings.post.?(authEp.endpoint, r);
} }
/// here, the auth_endpoint will be passed in /// here, the auth_endpoint will be passed in
pub fn put(e: *SimpleEndpoint, r: zap.SimpleRequest) void { pub fn put(e: *SimpleEndpoint, r: zap.SimpleRequest) void {
const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e);
if (authEp.authenticator.authenticateRequest(&r) == false) { switch (authEp.authenticator.authenticateRequest(&r)) {
if (e.settings.unauthorized) |foo| { .AuthFailed => {
foo(authEp.endpoint, r); if (e.settings.unauthorized) |unauthorized| {
unauthorized(authEp.endpoint, r);
return; return;
} else { } else {
r.setStatus(.unauthorized); r.setStatus(.unauthorized);
r.sendBody("UNAUTHORIZED") catch return; r.sendBody("UNAUTHORIZED") catch return;
return; return;
} }
},
.AuthOK => authEp.endpoint.settings.put.?(authEp.endpoint, r),
.Handled => {},
} }
// auth successful
authEp.endpoint.settings.put.?(authEp.endpoint, r);
} }
/// here, the auth_endpoint will be passed in /// here, the auth_endpoint will be passed in
pub fn delete(e: *SimpleEndpoint, r: zap.SimpleRequest) void { pub fn delete(e: *SimpleEndpoint, r: zap.SimpleRequest) void {
const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e);
if (authEp.authenticator.authenticateRequest(&r) == false) { switch (authEp.authenticator.authenticateRequest(&r)) {
if (e.settings.unauthorized) |foo| { .AuthFailed => {
foo(authEp.endpoint, r); if (e.settings.unauthorized) |unauthorized| {
unauthorized(authEp.endpoint, r);
return; return;
} else { } else {
r.setStatus(.unauthorized); r.setStatus(.unauthorized);
r.sendBody("UNAUTHORIZED") catch return; r.sendBody("UNAUTHORIZED") catch return;
return; return;
} }
},
.AuthOK => authEp.endpoint.settings.delete.?(authEp.endpoint, r),
.Handled => {},
} }
// auth successful
authEp.endpoint.settings.delete.?(authEp.endpoint, r);
} }
}; };
} }

View file

@ -48,6 +48,19 @@ const BasicAuthStrategy = enum {
Token68, Token68,
}; };
pub const AuthResult = enum {
/// authentication / authorization was successful
AuthOK,
/// authentication / authorization failed
AuthFailed,
/// the authenticator handled the request that didn't pass authentication /
/// authorization .
/// this is used to implement authenticators that redirect to a login
/// page. An AuthenticatingEndpoint will not do the default, which is trying
/// to call the `unauthorized` callback or.
Handled,
};
/// HTTP Basic Authentication RFC 7617 /// HTTP Basic Authentication RFC 7617
/// "Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" /// "Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="
/// user-pass strings: "$username:$password" -> base64 /// user-pass strings: "$username:$password" -> base64
@ -91,7 +104,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
} }
/// Use this to decode the auth_header into user:pass, lookup pass in lookup /// Use this to decode the auth_header into user:pass, lookup pass in lookup
pub fn authenticateUserPass(self: *Self, auth_header: []const u8) bool { pub fn authenticateUserPass(self: *Self, auth_header: []const u8) AuthResult {
zap.debug("AuthenticateUserPass\n", .{}); zap.debug("AuthenticateUserPass\n", .{});
const encoded = auth_header[AuthScheme.Basic.str().len..]; const encoded = auth_header[AuthScheme.Basic.str().len..];
const decoder = std.base64.standard.Decoder; const decoder = std.base64.standard.Decoder;
@ -102,7 +115,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
"ERROR: UserPassAuth: decoded_size {d} >= buffer.len {d}\n", "ERROR: UserPassAuth: decoded_size {d} >= buffer.len {d}\n",
.{ decoded_size, buffer.len }, .{ decoded_size, buffer.len },
); );
return false; return .AuthFailed;
} }
var decoded = buffer[0..decoded_size]; var decoded = buffer[0..decoded_size];
decoder.decode(decoded, encoded) catch |err| { decoder.decode(decoded, encoded) catch |err| {
@ -110,7 +123,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
"ERROR: UserPassAuth: unable to decode `{s}`: {any}\n", "ERROR: UserPassAuth: unable to decode `{s}`: {any}\n",
.{ encoded, err }, .{ encoded, err },
); );
return false; return .AuthFailed;
}; };
// we have decoded // we have decoded
// we can split // we can split
@ -122,7 +135,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
"ERROR: UserPassAuth: user {any} or pass {any} is null\n", "ERROR: UserPassAuth: user {any} or pass {any} is null\n",
.{ user, pass }, .{ user, pass },
); );
return false; return .AuthFailed;
} }
// now, do the lookup // now, do the lookup
const actual_pw = self.lookup.*.get(user.?); const actual_pw = self.lookup.*.get(user.?);
@ -132,13 +145,13 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
"INFO: UserPassAuth for user `{s}`: `{s}` == pass `{s}` = {}\n", "INFO: UserPassAuth for user `{s}`: `{s}` == pass `{s}` = {}\n",
.{ user.?, pw, pass.?, ret }, .{ user.?, pw, pass.?, ret },
); );
return ret; return if (ret) .AuthOK else .AuthFailed;
} else { } else {
zap.debug( zap.debug(
"ERROR: UserPassAuth: user `{s}` not found in map of size {d}!\n", "ERROR: UserPassAuth: user `{s}` not found in map of size {d}!\n",
.{ user.?, self.lookup.*.count() }, .{ user.?, self.lookup.*.count() },
); );
return false; return .AuthFailed;
} }
} else |err| { } else |err| {
// can't calc slice size --> fallthrough to return false // can't calc slice size --> fallthrough to return false
@ -146,20 +159,20 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
"ERROR: UserPassAuth: cannot calc slize size for encoded `{s}`: {any} \n", "ERROR: UserPassAuth: cannot calc slize size for encoded `{s}`: {any} \n",
.{ encoded, err }, .{ encoded, err },
); );
return false; return .AuthFailed;
} }
zap.debug("UNREACHABLE\n", .{}); zap.debug("UNREACHABLE\n", .{});
return false; return .AuthFailed;
} }
/// Use this to just look up if the base64-encoded auth_header exists in lookup /// Use this to just look up if the base64-encoded auth_header exists in lookup
pub fn authenticateToken68(self: *Self, auth_header: []const u8) bool { pub fn authenticateToken68(self: *Self, auth_header: []const u8) AuthResult {
const token = auth_header[AuthScheme.Basic.str().len..]; const token = auth_header[AuthScheme.Basic.str().len..];
return self.lookup.*.contains(token); return if (self.lookup.*.contains(token)) .AuthOK else .AuthFailed;
} }
// dispatch based on kind // dispatch based on kind
pub fn authenticate(self: *Self, auth_header: []const u8) bool { pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult {
zap.debug("AUTHENTICATE\n", .{}); zap.debug("AUTHENTICATE\n", .{});
// switch (self.kind) { // switch (self.kind) {
switch (kind) { switch (kind) {
@ -167,7 +180,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
.Token68 => return self.authenticateToken68(auth_header), .Token68 => return self.authenticateToken68(auth_header),
} }
} }
pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult {
zap.debug("AUTHENTICATE REQUEST\n", .{}); zap.debug("AUTHENTICATE REQUEST\n", .{});
if (extractAuthHeader(.Basic, r)) |auth_header| { if (extractAuthHeader(.Basic, r)) |auth_header| {
zap.debug("Authentication Header found!\n", .{}); zap.debug("Authentication Header found!\n", .{});
@ -180,7 +193,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type {
} }
} }
zap.debug("NO fitting Auth Header found!\n", .{}); zap.debug("NO fitting Auth Header found!\n", .{});
return false; return .AuthFailed;
} }
}; };
} }
@ -213,19 +226,19 @@ pub const BearerAuthSingle = struct {
.realm = if (realm) |the_realm| try allocator.dupe(u8, the_realm) else null, .realm = if (realm) |the_realm| try allocator.dupe(u8, the_realm) else null,
}; };
} }
pub fn authenticate(self: *Self, auth_header: []const u8) bool { pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult {
if (checkAuthHeader(.Bearer, auth_header) == false) { if (checkAuthHeader(.Bearer, auth_header) == false) {
return false; return .AuthFailed;
} }
const token = auth_header[AuthScheme.Bearer.str().len..]; const token = auth_header[AuthScheme.Bearer.str().len..];
return std.mem.eql(u8, token, self.token); return if (std.mem.eql(u8, token, self.token)) .AuthOK else .AuthFailed;
} }
pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult {
if (extractAuthHeader(.Bearer, r)) |auth_header| { if (extractAuthHeader(.Bearer, r)) |auth_header| {
return self.authenticate(auth_header); return self.authenticate(auth_header);
} }
return false; return .AuthFailed;
} }
pub fn deinit(self: *Self) void { pub fn deinit(self: *Self) void {
@ -271,19 +284,19 @@ pub fn BearerAuthMulti(comptime Lookup: type) type {
} }
} }
pub fn authenticate(self: *Self, auth_header: []const u8) bool { pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult {
if (checkAuthHeader(.Bearer, auth_header) == false) { if (checkAuthHeader(.Bearer, auth_header) == false) {
return false; return .AuthFailed;
} }
const token = auth_header[AuthScheme.Bearer.str().len..]; const token = auth_header[AuthScheme.Bearer.str().len..];
return self.lookup.*.contains(token); return if (self.lookup.*.contains(token)) .AuthOK else .AuthFailed;
} }
pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult {
if (extractAuthHeader(.Bearer, r)) |auth_header| { if (extractAuthHeader(.Bearer, r)) |auth_header| {
return self.authenticate(auth_header); return self.authenticate(auth_header);
} }
return false; return .AuthFailed;
} }
}; };
} }

View file

@ -17,10 +17,10 @@ test "BearerAuthSingle authenticate" {
defer auth.deinit(); defer auth.deinit();
// invalid auth header // invalid auth header
try std.testing.expectEqual(auth.authenticate("wrong header"), false); try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed);
try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), .AuthFailed);
// ok // ok
try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), .AuthOK);
} }
test "BearerAuthMulti authenticate" { test "BearerAuthMulti authenticate" {
@ -37,10 +37,10 @@ test "BearerAuthMulti authenticate" {
defer auth.deinit(); defer auth.deinit();
// invalid auth header // invalid auth header
try std.testing.expectEqual(auth.authenticate("wrong header"), false); try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed);
try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), .AuthFailed);
// ok // ok
try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), .AuthOK);
} }
test "BasicAuth Token68" { test "BasicAuth Token68" {
@ -59,10 +59,10 @@ test "BasicAuth Token68" {
defer auth.deinit(); defer auth.deinit();
// invalid auth header // invalid auth header
try std.testing.expectEqual(auth.authenticate("wrong header"), false); try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed);
try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), .AuthFailed);
// ok // ok
try std.testing.expectEqual(auth.authenticate("Basic " ++ token), true); try std.testing.expectEqual(auth.authenticate("Basic " ++ token), .AuthOK);
} }
test "BasicAuth UserPass" { test "BasicAuth UserPass" {
@ -90,13 +90,13 @@ test "BasicAuth UserPass" {
defer auth.deinit(); defer auth.deinit();
// invalid auth header // invalid auth header
try std.testing.expectEqual(auth.authenticate("wrong header"), false); try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed);
try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), .AuthFailed);
// ok // ok
const expected = try std.fmt.allocPrint(a, "Basic {s}", .{encoded}); const expected = try std.fmt.allocPrint(a, "Basic {s}", .{encoded});
defer a.free(expected); defer a.free(expected);
try std.testing.expectEqual(auth.authenticate(expected), true); try std.testing.expectEqual(auth.authenticate(expected), .AuthOK);
} }
const HTTP_RESPONSE: []const u8 = const HTTP_RESPONSE: []const u8 =