mirror of
				https://github.com/zigzap/zap.git
				synced 2025-10-20 15:14:08 +00:00 
			
		
		
		
	AuthResult is now an enum (was: bool)
This commit is contained in:
		
							parent
							
								
									ed156ba654
								
							
						
					
					
						commit
						645db5c8d0
					
				
					 3 changed files with 100 additions and 79 deletions
				
			
		|  | @ -89,69 +89,77 @@ pub fn AuthenticatingEndpoint(comptime Authenticator: type) type { | |||
|         /// here, the auth_endpoint will be passed in | ||||
|         pub fn get(e: *SimpleEndpoint, r: zap.SimpleRequest) void { | ||||
|             const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); | ||||
|             if (authEp.authenticator.authenticateRequest(&r) == false) { | ||||
|                 if (e.settings.unauthorized) |foo| { | ||||
|                     foo(authEp.endpoint, r); | ||||
|                     return; | ||||
|                 } else { | ||||
|                     r.setStatus(.unauthorized); | ||||
|                     r.sendBody("UNAUTHORIZED") catch return; | ||||
|                     return; | ||||
|                 } | ||||
|             switch (authEp.authenticator.authenticateRequest(&r)) { | ||||
|                 .AuthFailed => { | ||||
|                     if (e.settings.unauthorized) |unauthorized| { | ||||
|                         unauthorized(authEp.endpoint, r); | ||||
|                         return; | ||||
|                     } else { | ||||
|                         r.setStatus(.unauthorized); | ||||
|                         r.sendBody("UNAUTHORIZED") catch return; | ||||
|                         return; | ||||
|                     } | ||||
|                 }, | ||||
|                 .AuthOK => authEp.endpoint.settings.get.?(authEp.endpoint, r), | ||||
|                 .Handled => {}, | ||||
|             } | ||||
|             // auth successful | ||||
|             authEp.endpoint.settings.get.?(authEp.endpoint, r); | ||||
|         } | ||||
| 
 | ||||
|         /// here, the auth_endpoint will be passed in | ||||
|         pub fn post(e: *SimpleEndpoint, r: zap.SimpleRequest) void { | ||||
|             const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); | ||||
|             if (authEp.authenticator.authenticateRequest(&r) == false) { | ||||
|                 if (e.settings.unauthorized) |foo| { | ||||
|                     foo(authEp.endpoint, r); | ||||
|                     return; | ||||
|                 } else { | ||||
|                     r.setStatus(.unauthorized); | ||||
|                     r.sendBody("UNAUTHORIZED") catch return; | ||||
|                     return; | ||||
|                 } | ||||
|             switch (authEp.authenticator.authenticateRequest(&r)) { | ||||
|                 .AuthFailed => { | ||||
|                     if (e.settings.unauthorized) |unauthorized| { | ||||
|                         unauthorized(authEp.endpoint, r); | ||||
|                         return; | ||||
|                     } else { | ||||
|                         r.setStatus(.unauthorized); | ||||
|                         r.sendBody("UNAUTHORIZED") catch return; | ||||
|                         return; | ||||
|                     } | ||||
|                 }, | ||||
|                 .AuthOK => authEp.endpoint.settings.post.?(authEp.endpoint, r), | ||||
|                 .Handled => {}, | ||||
|             } | ||||
|             // auth successful | ||||
|             authEp.endpoint.settings.post.?(authEp.endpoint, r); | ||||
|         } | ||||
| 
 | ||||
|         /// here, the auth_endpoint will be passed in | ||||
|         pub fn put(e: *SimpleEndpoint, r: zap.SimpleRequest) void { | ||||
|             const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); | ||||
|             if (authEp.authenticator.authenticateRequest(&r) == false) { | ||||
|                 if (e.settings.unauthorized) |foo| { | ||||
|                     foo(authEp.endpoint, r); | ||||
|                     return; | ||||
|                 } else { | ||||
|                     r.setStatus(.unauthorized); | ||||
|                     r.sendBody("UNAUTHORIZED") catch return; | ||||
|                     return; | ||||
|                 } | ||||
|             switch (authEp.authenticator.authenticateRequest(&r)) { | ||||
|                 .AuthFailed => { | ||||
|                     if (e.settings.unauthorized) |unauthorized| { | ||||
|                         unauthorized(authEp.endpoint, r); | ||||
|                         return; | ||||
|                     } else { | ||||
|                         r.setStatus(.unauthorized); | ||||
|                         r.sendBody("UNAUTHORIZED") catch return; | ||||
|                         return; | ||||
|                     } | ||||
|                 }, | ||||
|                 .AuthOK => authEp.endpoint.settings.put.?(authEp.endpoint, r), | ||||
|                 .Handled => {}, | ||||
|             } | ||||
|             // auth successful | ||||
|             authEp.endpoint.settings.put.?(authEp.endpoint, r); | ||||
|         } | ||||
| 
 | ||||
|         /// here, the auth_endpoint will be passed in | ||||
|         pub fn delete(e: *SimpleEndpoint, r: zap.SimpleRequest) void { | ||||
|             const authEp: *Self = @fieldParentPtr(Self, "auth_endpoint", e); | ||||
|             if (authEp.authenticator.authenticateRequest(&r) == false) { | ||||
|                 if (e.settings.unauthorized) |foo| { | ||||
|                     foo(authEp.endpoint, r); | ||||
|                     return; | ||||
|                 } else { | ||||
|                     r.setStatus(.unauthorized); | ||||
|                     r.sendBody("UNAUTHORIZED") catch return; | ||||
|                     return; | ||||
|                 } | ||||
|             switch (authEp.authenticator.authenticateRequest(&r)) { | ||||
|                 .AuthFailed => { | ||||
|                     if (e.settings.unauthorized) |unauthorized| { | ||||
|                         unauthorized(authEp.endpoint, r); | ||||
|                         return; | ||||
|                     } else { | ||||
|                         r.setStatus(.unauthorized); | ||||
|                         r.sendBody("UNAUTHORIZED") catch return; | ||||
|                         return; | ||||
|                     } | ||||
|                 }, | ||||
|                 .AuthOK => authEp.endpoint.settings.delete.?(authEp.endpoint, r), | ||||
|                 .Handled => {}, | ||||
|             } | ||||
|             // auth successful | ||||
|             authEp.endpoint.settings.delete.?(authEp.endpoint, r); | ||||
|         } | ||||
|     }; | ||||
| } | ||||
|  |  | |||
|  | @ -48,6 +48,19 @@ const BasicAuthStrategy = enum { | |||
|     Token68, | ||||
| }; | ||||
| 
 | ||||
| pub const AuthResult = enum { | ||||
|     /// authentication / authorization was successful | ||||
|     AuthOK, | ||||
|     /// authentication / authorization failed | ||||
|     AuthFailed, | ||||
|     /// the authenticator handled the request that didn't pass authentication / | ||||
|     /// authorization . | ||||
|     /// this is used to implement authenticators that redirect to a login | ||||
|     /// page. An AuthenticatingEndpoint will not do the default, which is trying | ||||
|     /// to call the `unauthorized` callback or. | ||||
|     Handled, | ||||
| }; | ||||
| 
 | ||||
| /// HTTP Basic Authentication RFC 7617 | ||||
| /// "Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" | ||||
| /// user-pass strings: "$username:$password" -> base64 | ||||
|  | @ -91,7 +104,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|         } | ||||
| 
 | ||||
|         /// Use this to decode the auth_header into user:pass, lookup pass in lookup | ||||
|         pub fn authenticateUserPass(self: *Self, auth_header: []const u8) bool { | ||||
|         pub fn authenticateUserPass(self: *Self, auth_header: []const u8) AuthResult { | ||||
|             zap.debug("AuthenticateUserPass\n", .{}); | ||||
|             const encoded = auth_header[AuthScheme.Basic.str().len..]; | ||||
|             const decoder = std.base64.standard.Decoder; | ||||
|  | @ -102,7 +115,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                         "ERROR: UserPassAuth: decoded_size {d} >= buffer.len {d}\n", | ||||
|                         .{ decoded_size, buffer.len }, | ||||
|                     ); | ||||
|                     return false; | ||||
|                     return .AuthFailed; | ||||
|                 } | ||||
|                 var decoded = buffer[0..decoded_size]; | ||||
|                 decoder.decode(decoded, encoded) catch |err| { | ||||
|  | @ -110,7 +123,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                         "ERROR: UserPassAuth: unable to decode `{s}`: {any}\n", | ||||
|                         .{ encoded, err }, | ||||
|                     ); | ||||
|                     return false; | ||||
|                     return .AuthFailed; | ||||
|                 }; | ||||
|                 // we have decoded | ||||
|                 // we can split | ||||
|  | @ -122,7 +135,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                         "ERROR: UserPassAuth: user {any} or pass {any} is null\n", | ||||
|                         .{ user, pass }, | ||||
|                     ); | ||||
|                     return false; | ||||
|                     return .AuthFailed; | ||||
|                 } | ||||
|                 // now, do the lookup | ||||
|                 const actual_pw = self.lookup.*.get(user.?); | ||||
|  | @ -132,13 +145,13 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                         "INFO: UserPassAuth for user `{s}`: `{s}` == pass `{s}` = {}\n", | ||||
|                         .{ user.?, pw, pass.?, ret }, | ||||
|                     ); | ||||
|                     return ret; | ||||
|                     return if (ret) .AuthOK else .AuthFailed; | ||||
|                 } else { | ||||
|                     zap.debug( | ||||
|                         "ERROR: UserPassAuth: user `{s}` not found in map of size {d}!\n", | ||||
|                         .{ user.?, self.lookup.*.count() }, | ||||
|                     ); | ||||
|                     return false; | ||||
|                     return .AuthFailed; | ||||
|                 } | ||||
|             } else |err| { | ||||
|                 // can't calc slice size --> fallthrough to return false | ||||
|  | @ -146,20 +159,20 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                     "ERROR: UserPassAuth: cannot calc slize size for encoded `{s}`: {any} \n", | ||||
|                     .{ encoded, err }, | ||||
|                 ); | ||||
|                 return false; | ||||
|                 return .AuthFailed; | ||||
|             } | ||||
|             zap.debug("UNREACHABLE\n", .{}); | ||||
|             return false; | ||||
|             return .AuthFailed; | ||||
|         } | ||||
| 
 | ||||
|         /// Use this to just look up if the base64-encoded auth_header exists in lookup | ||||
|         pub fn authenticateToken68(self: *Self, auth_header: []const u8) bool { | ||||
|         pub fn authenticateToken68(self: *Self, auth_header: []const u8) AuthResult { | ||||
|             const token = auth_header[AuthScheme.Basic.str().len..]; | ||||
|             return self.lookup.*.contains(token); | ||||
|             return if (self.lookup.*.contains(token)) .AuthOK else .AuthFailed; | ||||
|         } | ||||
| 
 | ||||
|         // dispatch based on kind | ||||
|         pub fn authenticate(self: *Self, auth_header: []const u8) bool { | ||||
|         pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult { | ||||
|             zap.debug("AUTHENTICATE\n", .{}); | ||||
|             // switch (self.kind) { | ||||
|             switch (kind) { | ||||
|  | @ -167,7 +180,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                 .Token68 => return self.authenticateToken68(auth_header), | ||||
|             } | ||||
|         } | ||||
|         pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { | ||||
|         pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult { | ||||
|             zap.debug("AUTHENTICATE REQUEST\n", .{}); | ||||
|             if (extractAuthHeader(.Basic, r)) |auth_header| { | ||||
|                 zap.debug("Authentication Header found!\n", .{}); | ||||
|  | @ -180,7 +193,7 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { | |||
|                 } | ||||
|             } | ||||
|             zap.debug("NO fitting Auth Header found!\n", .{}); | ||||
|             return false; | ||||
|             return .AuthFailed; | ||||
|         } | ||||
|     }; | ||||
| } | ||||
|  | @ -213,19 +226,19 @@ pub const BearerAuthSingle = struct { | |||
|             .realm = if (realm) |the_realm| try allocator.dupe(u8, the_realm) else null, | ||||
|         }; | ||||
|     } | ||||
|     pub fn authenticate(self: *Self, auth_header: []const u8) bool { | ||||
|     pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult { | ||||
|         if (checkAuthHeader(.Bearer, auth_header) == false) { | ||||
|             return false; | ||||
|             return .AuthFailed; | ||||
|         } | ||||
|         const token = auth_header[AuthScheme.Bearer.str().len..]; | ||||
|         return std.mem.eql(u8, token, self.token); | ||||
|         return if (std.mem.eql(u8, token, self.token)) .AuthOK else .AuthFailed; | ||||
|     } | ||||
| 
 | ||||
|     pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { | ||||
|     pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult { | ||||
|         if (extractAuthHeader(.Bearer, r)) |auth_header| { | ||||
|             return self.authenticate(auth_header); | ||||
|         } | ||||
|         return false; | ||||
|         return .AuthFailed; | ||||
|     } | ||||
| 
 | ||||
|     pub fn deinit(self: *Self) void { | ||||
|  | @ -271,19 +284,19 @@ pub fn BearerAuthMulti(comptime Lookup: type) type { | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         pub fn authenticate(self: *Self, auth_header: []const u8) bool { | ||||
|         pub fn authenticate(self: *Self, auth_header: []const u8) AuthResult { | ||||
|             if (checkAuthHeader(.Bearer, auth_header) == false) { | ||||
|                 return false; | ||||
|                 return .AuthFailed; | ||||
|             } | ||||
|             const token = auth_header[AuthScheme.Bearer.str().len..]; | ||||
|             return self.lookup.*.contains(token); | ||||
|             return if (self.lookup.*.contains(token)) .AuthOK else .AuthFailed; | ||||
|         } | ||||
| 
 | ||||
|         pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) bool { | ||||
|         pub fn authenticateRequest(self: *Self, r: *const zap.SimpleRequest) AuthResult { | ||||
|             if (extractAuthHeader(.Bearer, r)) |auth_header| { | ||||
|                 return self.authenticate(auth_header); | ||||
|             } | ||||
|             return false; | ||||
|             return .AuthFailed; | ||||
|         } | ||||
|     }; | ||||
| } | ||||
|  |  | |||
|  | @ -17,10 +17,10 @@ test "BearerAuthSingle authenticate" { | |||
|     defer auth.deinit(); | ||||
| 
 | ||||
|     // invalid auth header | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), .AuthFailed); | ||||
|     // ok | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), .AuthOK); | ||||
| } | ||||
| 
 | ||||
| test "BearerAuthMulti authenticate" { | ||||
|  | @ -37,10 +37,10 @@ test "BearerAuthMulti authenticate" { | |||
|     defer auth.deinit(); | ||||
| 
 | ||||
|     // invalid auth header | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), .AuthFailed); | ||||
|     // ok | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); | ||||
|     try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), .AuthOK); | ||||
| } | ||||
| 
 | ||||
| test "BasicAuth Token68" { | ||||
|  | @ -59,10 +59,10 @@ test "BasicAuth Token68" { | |||
|     defer auth.deinit(); | ||||
| 
 | ||||
|     // invalid auth header | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); | ||||
|     try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), .AuthFailed); | ||||
|     // ok | ||||
|     try std.testing.expectEqual(auth.authenticate("Basic " ++ token), true); | ||||
|     try std.testing.expectEqual(auth.authenticate("Basic " ++ token), .AuthOK); | ||||
| } | ||||
| 
 | ||||
| test "BasicAuth UserPass" { | ||||
|  | @ -90,13 +90,13 @@ test "BasicAuth UserPass" { | |||
|     defer auth.deinit(); | ||||
| 
 | ||||
|     // invalid auth header | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); | ||||
|     try std.testing.expectEqual(auth.authenticate("wrong header"), .AuthFailed); | ||||
|     try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), .AuthFailed); | ||||
|     // ok | ||||
|     const expected = try std.fmt.allocPrint(a, "Basic {s}", .{encoded}); | ||||
| 
 | ||||
|     defer a.free(expected); | ||||
|     try std.testing.expectEqual(auth.authenticate(expected), true); | ||||
|     try std.testing.expectEqual(auth.authenticate(expected), .AuthOK); | ||||
| } | ||||
| 
 | ||||
| const HTTP_RESPONSE: []const u8 = | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 Rene Schallner
						Rene Schallner