diff --git a/src/http_auth.zig b/src/http_auth.zig index d6c4bce..6a384de 100644 --- a/src/http_auth.zig +++ b/src/http_auth.zig @@ -92,9 +92,31 @@ pub fn BasicAuth(comptime Lookup: type, comptime kind: BasicAuthStrategy) type { /// Use this to decode the auth_header into user:pass, lookup pass in lookup pub fn authenticateUserPass(self: *Self, auth_header: []const u8) bool { - _ = auth_header; - _ = self; - // TODO + const encoded = auth_header[AuthScheme.Basic.str().len..]; + const decoder = std.base64.standard.Decoder; + var buffer: [0x100]u8 = undefined; + if (decoder.calcSizeForSlice(encoded)) |decoded_size| { + if (decoded_size >= buffer.len) { + return false; + } + var decoded = buffer[0..decoded_size]; + decoder.decode(decoded, encoded) catch return false; + // we have decoded + // we can split + var it = std.mem.split(u8, decoded, ":"); + const user = it.next(); + const pass = it.next(); + if (user == null or pass == null) { + return false; + } + // now, do the lookup + const actual_pw = self.lookup.*.get(user.?); + if (actual_pw) |pw| { + return std.mem.eql(u8, pass.?, pw); + } + } else |_| { + // can't calc slice size --> fallthrough to return false + } return false; } diff --git a/src/http_client_testrunner.zig b/src/http_client_testrunner.zig index c4d02b9..20b91ab 100644 --- a/src/http_client_testrunner.zig +++ b/src/http_client_testrunner.zig @@ -40,5 +40,25 @@ pub fn main() !void { }, a); _ = try p.spawnAndWait(); + std.time.sleep(3 * std.time.ns_per_s); + + p = std.ChildProcess.init(&.{ + "./zig-out/bin/http_client", + "http://127.0.0.1:3000/test", + "Basic", + "QWxsYWRkaW46b3BlbnNlc2FtZQ==", + }, a); + _ = try p.spawnAndWait(); + + std.time.sleep(3 * std.time.ns_per_s); + + p = std.ChildProcess.init(&.{ + "./zig-out/bin/http_client", + "http://127.0.0.1:3000/test", + "Basic", + "QWxsYWRkaW46b3BlbnNlc2FtZQ==-invalid", + }, a); + _ = try p.spawnAndWait(); + // std.time.sleep(3 * std.time.ns_per_s); } diff --git a/src/test_auth.zig b/src/test_auth.zig index 650b1d0..0e97fac 100644 --- a/src/test_auth.zig +++ b/src/test_auth.zig @@ -15,6 +15,7 @@ test "BearerAuthSingle authenticate" { // invalid auth header try std.testing.expectEqual(auth.authenticate("wrong header"), false); try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); + // ok try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); } @@ -33,9 +34,68 @@ test "BearerAuthMulti authenticate" { // invalid auth header try std.testing.expectEqual(auth.authenticate("wrong header"), false); try std.testing.expectEqual(auth.authenticate("Bearer wrong-token"), false); + // ok try std.testing.expectEqual(auth.authenticate("Bearer " ++ token), true); } +test "BasicAuth Token68" { + const a = std.testing.allocator; + const token = "QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; + + // create a set of Token68 entries + const Set = std.StringHashMap(void); + var set = Set.init(a); // set + defer set.deinit(); + try set.put(token, {}); + + // create authenticator + const Authenticator = Authenticators.BasicAuth(Set, .Token68); + var auth = try Authenticator.init(a, &set, null); + defer auth.deinit(); + + // invalid auth header + try std.testing.expectEqual(auth.authenticate("wrong header"), false); + try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); + // ok + try std.testing.expectEqual(auth.authenticate("Basic " ++ token), true); +} + +test "BasicAuth UserPass" { + const a = std.testing.allocator; + + // create a set of User -> Pass entries + const Map = std.StringHashMap([]const u8); + var map = Map.init(a); // set + defer map.deinit(); + + // create user / pass entry + const user = "Alladdin"; + const pass = "opensesame"; + try map.put(user, pass); + + // now, do the encoding for the Basic auth + const token = user ++ ":" ++ pass; + var encoder = std.base64.url_safe.Encoder; + var buffer: [256]u8 = undefined; + const encoded = encoder.encode(&buffer, token); + std.debug.print("\nEncoded: {s}\n", .{encoded}); + + // create authenticator + const Authenticator = Authenticators.BasicAuth(Map, .UserPass); + var auth = try Authenticator.init(a, &map, null); + defer auth.deinit(); + + // invalid auth header + try std.testing.expectEqual(auth.authenticate("wrong header"), false); + try std.testing.expectEqual(auth.authenticate("Basic wrong-token"), false); + // ok + const expected = try std.fmt.allocPrint(a, "Basic {s}", .{encoded}); + std.debug.print("Expected: {s}\n", .{expected}); + + defer a.free(expected); + try std.testing.expectEqual(auth.authenticate(expected), true); +} + const HTTP_RESPONSE: []const u8 = \\ \\ Hello from ZAP!!! @@ -252,7 +312,7 @@ test "BearerAuthSingle authenticateRequest test-unauthorized" { try std.testing.expectEqualStrings("UNAUTHORIZED", received_response); } -test "BasicAuth authenticateRequest" { +test "BasicAuth Token68 authenticateRequest" { const a = std.testing.allocator; const token = "QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; @@ -306,7 +366,7 @@ test "BasicAuth authenticateRequest" { try std.testing.expectEqualStrings(HTTP_RESPONSE, received_response); } -test "BasicAuth authenticateRequest test-unauthorized" { +test "BasicAuth Token68 authenticateRequest test-unauthorized" { const a = std.testing.allocator; const token = "QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; @@ -359,3 +419,133 @@ test "BasicAuth authenticateRequest test-unauthorized" { try std.testing.expectEqualStrings("UNAUTHORIZED", received_response); } + +test "BasicAuth UserPass authenticateRequest" { + const a = std.testing.allocator; + + // setup listener + var listener = zap.SimpleEndpointListener.init( + a, + .{ + .port = 3000, + .on_request = null, + .log = false, + .max_clients = 10, + .max_body_size = 1 * 1024, + }, + ); + defer listener.deinit(); + + // create mini endpoint + var ep = Endpoints.SimpleEndpoint.init(.{ + .path = "/test", + .get = endpoint_http_get, + .unauthorized = endpoint_http_unauthorized, + }); + + // create a set of User -> Pass entries + const Map = std.StringHashMap([]const u8); + var map = Map.init(a); // set + defer map.deinit(); + + // create user / pass entry + const user = "Alladdin"; + const pass = "opensesame"; + try map.put(user, pass); + + // now, do the encoding for the Basic auth + const token = user ++ ":" ++ pass; + var encoder = std.base64.url_safe.Encoder; + var buffer: [256]u8 = undefined; + const encoded = encoder.encode(&buffer, token); + std.debug.print("\nEncoded: {s}\n", .{encoded}); + + // create authenticator + const Authenticator = Authenticators.BasicAuth(Map, .UserPass); + var authenticator = try Authenticator.init(a, &map, null); + defer authenticator.deinit(); + + // create authenticating endpoint + const BearerAuthEndpoint = Endpoints.AuthenticatingEndpoint(Authenticator); + var auth_ep = BearerAuthEndpoint.init(&ep, &authenticator); + + try listener.addEndpoint(auth_ep.getEndpoint()); + + listener.listen() catch {}; + std.debug.print("Listening on 0.0.0.0:3000\n", .{}); + std.debug.print("Please run the following:\n", .{}); + std.debug.print("./zig-out/bin/http_client http://127.0.0.1:3000/test Basic {s}\n", .{encoded}); + + // start worker threads + zap.start(.{ + .threads = 1, + .workers = 0, + }); + + try std.testing.expectEqualStrings(HTTP_RESPONSE, received_response); +} + +test "BasicAuth UserPass authenticateRequest test-unauthorized" { + const a = std.testing.allocator; + + // setup listener + var listener = zap.SimpleEndpointListener.init( + a, + .{ + .port = 3000, + .on_request = null, + .log = false, + .max_clients = 10, + .max_body_size = 1 * 1024, + }, + ); + defer listener.deinit(); + + // create mini endpoint + var ep = Endpoints.SimpleEndpoint.init(.{ + .path = "/test", + .get = endpoint_http_get, + .unauthorized = endpoint_http_unauthorized, + }); + + // create a set of User -> Pass entries + const Map = std.StringHashMap([]const u8); + var map = Map.init(a); // set + defer map.deinit(); + + // create user / pass entry + const user = "Alladdin"; + const pass = "opensesame"; + try map.put(user, pass); + + // now, do the encoding for the Basic auth + const token = user ++ ":" ++ pass; + var encoder = std.base64.url_safe.Encoder; + var buffer: [256]u8 = undefined; + const encoded = encoder.encode(&buffer, token); + std.debug.print("\nEncoded: {s}\n", .{encoded}); + + // create authenticator + const Authenticator = Authenticators.BasicAuth(Map, .UserPass); + var authenticator = try Authenticator.init(a, &map, null); + defer authenticator.deinit(); + + // create authenticating endpoint + const BearerAuthEndpoint = Endpoints.AuthenticatingEndpoint(Authenticator); + var auth_ep = BearerAuthEndpoint.init(&ep, &authenticator); + + try listener.addEndpoint(auth_ep.getEndpoint()); + + listener.listen() catch {}; + std.debug.print("Listening on 0.0.0.0:3000\n", .{}); + std.debug.print("Please run the following:\n", .{}); + std.debug.print("./zig-out/bin/http_client http://127.0.0.1:3000/test Basic {s}-invalid\n", .{encoded}); + + // start worker threads + zap.start(.{ + .threads = 1, + .workers = 0, + }); + + try std.testing.expectEqualStrings("UNAUTHORIZED", received_response); +}