Merge pull request #25034 from ziglang/lzma

std.compress: update lzma, lzma2, and xz to new I/O API
Andrew Kelley 2025-08-27 06:49:45 -07:00 committed by GitHub
commit 50edad37ba
13 changed files with 1470 additions and 1612 deletions

lib/std/compress/lzma.zig

@@ -2,89 +2,759 @@ const std = @import("../std.zig");
const math = std.math;
const mem = std.mem;
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const ArrayList = std.ArrayList;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
pub const RangeDecoder = struct {
range: u32,
code: u32,
pub fn init(reader: *Reader) !RangeDecoder {
var counter: u64 = 0;
return initCounting(reader, &counter);
}

pub fn initCounting(reader: *Reader, n_read: *u64) !RangeDecoder {
const reserved = try reader.takeByte();
n_read.* += 1;
if (reserved != 0) return error.InvalidRangeCode;
const code = try reader.takeInt(u32, .big);
n_read.* += 4;
return .{
.range = 0xFFFF_FFFF,
.code = code,
};
}

pub fn isFinished(self: RangeDecoder) bool {
return self.code == 0;
}

fn normalize(self: *RangeDecoder, reader: *Reader, n_read: *u64) !void {
if (self.range < 0x0100_0000) {
self.range <<= 8;
self.code = (self.code << 8) ^ @as(u32, try reader.takeByte());
n_read.* += 1;
}
}
fn getBit(self: *RangeDecoder, reader: *Reader, n_read: *u64) !bool {
self.range >>= 1;
const bit = self.code >= self.range;
if (bit) self.code -= self.range;
try self.normalize(reader, n_read);
return bit;
}
pub fn get(self: *RangeDecoder, reader: *Reader, count: usize, n_read: *u64) !u32 {
var result: u32 = 0;
for (0..count) |_| {
result = (result << 1) ^ @intFromBool(try self.getBit(reader, n_read));
}
return result;
}
pub fn decodeBit(self: *RangeDecoder, reader: *Reader, prob: *u16, n_read: *u64) !bool {
const bound = (self.range >> 11) * prob.*;
if (self.code < bound) {
prob.* += (0x800 - prob.*) >> 5;
self.range = bound;
try self.normalize(reader, n_read);
return false;
} else {
prob.* -= prob.* >> 5;
self.code -= bound;
self.range -= bound;
try self.normalize(reader, n_read);
return true;
}
}
fn parseBitTree(
self: *RangeDecoder,
reader: *Reader,
num_bits: u5,
probs: []u16,
n_read: *u64,
) !u32 {
var tmp: u32 = 1;
var i: @TypeOf(num_bits) = 0;
while (i < num_bits) : (i += 1) {
const bit = try self.decodeBit(reader, &probs[tmp], n_read);
tmp = (tmp << 1) ^ @intFromBool(bit);
}
return tmp - (@as(u32, 1) << num_bits);
}
pub fn parseReverseBitTree(
self: *RangeDecoder,
reader: *Reader,
num_bits: u5,
probs: []u16,
offset: usize,
n_read: *u64,
) !u32 {
var result: u32 = 0;
var tmp: usize = 1;
var i: @TypeOf(num_bits) = 0;
while (i < num_bits) : (i += 1) {
const bit = @intFromBool(try self.decodeBit(reader, &probs[offset + tmp], n_read));
tmp = (tmp << 1) ^ bit;
result ^= @as(u32, bit) << i;
}
return result;
}
};
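The probabilities used by `decodeBit` above are 11-bit fixed point: 0x400 is one half on a 0..0x800 scale, and each decoded bit moves the estimate 1/32 of the way toward the observed outcome. A minimal sketch of that update rule (illustrative, not part of this commit):

test "decodeBit probability update (illustrative)" {
    var prob: u16 = 0x400; // initial estimate: P(bit == 0) == 1/2
    prob += (0x800 - prob) >> 5; // after a decoded 0 bit: 0x400 + 32
    try std.testing.expectEqual(@as(u16, 0x420), prob);
    prob -= prob >> 5; // after a decoded 1 bit: 0x420 - 33
    try std.testing.expectEqual(@as(u16, 0x3ff), prob);
}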
pub const Decode = struct {
properties: Properties,
literal_probs: Vec2d,
pos_slot_decoder: [4]BitTree(6),
align_decoder: BitTree(4),
pos_decoders: [115]u16,
is_match: [192]u16,
is_rep: [12]u16,
is_rep_g0: [12]u16,
is_rep_g1: [12]u16,
is_rep_g2: [12]u16,
is_rep_0long: [192]u16,
state: usize,
rep: [4]usize,
len_decoder: LenDecoder,
rep_len_decoder: LenDecoder,
pub fn init(gpa: Allocator, properties: Properties) !Decode {
return .{
.properties = properties,
.literal_probs = try Vec2d.init(gpa, 0x400, @as(usize, 1) << (properties.lc + properties.lp), 0x300),
.pos_slot_decoder = @splat(.{}),
.align_decoder = .{},
.pos_decoders = @splat(0x400),
.is_match = @splat(0x400),
.is_rep = @splat(0x400),
.is_rep_g0 = @splat(0x400),
.is_rep_g1 = @splat(0x400),
.is_rep_g2 = @splat(0x400),
.is_rep_0long = @splat(0x400),
.state = 0,
.rep = @splat(0),
.len_decoder = .{},
.rep_len_decoder = .{},
};
}
pub fn deinit(self: *Decode, gpa: Allocator) void {
self.literal_probs.deinit(gpa);
self.* = undefined;
}

pub fn resetState(self: *Decode, gpa: Allocator, new_props: Properties) !void {
new_props.validate();
if (self.properties.lc + self.properties.lp == new_props.lc + new_props.lp) {
self.literal_probs.fill(0x400);
} else {
self.literal_probs.deinit(gpa);
self.literal_probs = try Vec2d.init(gpa, 0x400, @as(usize, 1) << (new_props.lc + new_props.lp), 0x300);
}
self.properties = new_props;
for (&self.pos_slot_decoder) |*t| t.reset();
self.align_decoder.reset();
self.pos_decoders = @splat(0x400);
self.is_match = @splat(0x400);
self.is_rep = @splat(0x400);
self.is_rep_g0 = @splat(0x400);
self.is_rep_g1 = @splat(0x400);
self.is_rep_g2 = @splat(0x400);
self.is_rep_0long = @splat(0x400);
self.state = 0;
self.rep = @splat(0);
self.len_decoder.reset();
self.rep_len_decoder.reset();
}
pub fn process(
self: *Decode,
reader: *Reader,
allocating: *Writer.Allocating,
/// `CircularBuffer` or `std.compress.lzma2.AccumBuffer`.
buffer: anytype,
decoder: *RangeDecoder,
n_read: *u64,
) !ProcessingStatus {
const gpa = allocating.allocator;
const writer = &allocating.writer;
const pos_state = buffer.len & ((@as(usize, 1) << self.properties.pb) - 1);
if (!try decoder.decodeBit(reader, &self.is_match[(self.state << 4) + pos_state], n_read)) {
const byte: u8 = try self.decodeLiteral(reader, buffer, decoder, n_read);
try buffer.appendLiteral(gpa, byte, writer);
self.state = if (self.state < 4)
0
else if (self.state < 10)
self.state - 3
else
self.state - 6;
return .more;
}
var len: usize = undefined;
if (try decoder.decodeBit(reader, &self.is_rep[self.state], n_read)) {
if (!try decoder.decodeBit(reader, &self.is_rep_g0[self.state], n_read)) {
if (!try decoder.decodeBit(reader, &self.is_rep_0long[(self.state << 4) + pos_state], n_read)) {
self.state = if (self.state < 7) 9 else 11;
const dist = self.rep[0] + 1;
try buffer.appendLz(gpa, 1, dist, writer);
return .more;
}
} else {
const idx: usize = if (!try decoder.decodeBit(reader, &self.is_rep_g1[self.state], n_read))
1
else if (!try decoder.decodeBit(reader, &self.is_rep_g2[self.state], n_read))
2
else
3;
const dist = self.rep[idx];
var i = idx;
while (i > 0) : (i -= 1) {
self.rep[i] = self.rep[i - 1];
}
self.rep[0] = dist;
}
len = try self.rep_len_decoder.decode(reader, decoder, pos_state, n_read);
self.state = if (self.state < 7) 8 else 11;
} else {
self.rep[3] = self.rep[2];
self.rep[2] = self.rep[1];
self.rep[1] = self.rep[0];
len = try self.len_decoder.decode(reader, decoder, pos_state, n_read);
self.state = if (self.state < 7) 7 else 10;
const rep_0 = try self.decodeDistance(reader, decoder, len, n_read);
self.rep[0] = rep_0;
if (self.rep[0] == 0xFFFF_FFFF) {
if (decoder.isFinished()) {
return .finished;
}
return error.CorruptInput;
}
}
len += 2;
const dist = self.rep[0] + 1;
try buffer.appendLz(gpa, len, dist, writer);
return .more;
}
fn decodeLiteral(
self: *Decode,
reader: *Reader,
/// `CircularBuffer` or `std.compress.lzma2.AccumBuffer`.
buffer: anytype,
decoder: *RangeDecoder,
n_read: *u64,
) !u8 {
const def_prev_byte = 0;
const prev_byte = @as(usize, buffer.lastOr(def_prev_byte));
var result: usize = 1;
const lit_state = ((buffer.len & ((@as(usize, 1) << self.properties.lp) - 1)) << self.properties.lc) +
(prev_byte >> (8 - self.properties.lc));
const probs = try self.literal_probs.get(lit_state);
if (self.state >= 7) {
var match_byte = @as(usize, try buffer.lastN(self.rep[0] + 1));
while (result < 0x100) {
const match_bit = (match_byte >> 7) & 1;
match_byte <<= 1;
const bit = @intFromBool(try decoder.decodeBit(
reader,
&probs[((@as(usize, 1) + match_bit) << 8) + result],
n_read,
));
result = (result << 1) ^ bit;
if (match_bit != bit) {
break;
}
}
}

while (result < 0x100) {
result = (result << 1) ^ @intFromBool(try decoder.decodeBit(reader, &probs[result], n_read));
}
return @truncate(result - 0x100);
}
fn decodeDistance(
self: *Decode,
reader: *Reader,
decoder: *RangeDecoder,
length: usize,
n_read: *u64,
) !usize {
const len_state = if (length > 3) 3 else length;
const pos_slot: usize = try self.pos_slot_decoder[len_state].parse(reader, decoder, n_read);
if (pos_slot < 4) return pos_slot;
const num_direct_bits = @as(u5, @intCast((pos_slot >> 1) - 1));
var result = (2 ^ (pos_slot & 1)) << num_direct_bits;
if (pos_slot < 14) {
result += try decoder.parseReverseBitTree(
reader,
num_direct_bits,
&self.pos_decoders,
result - pos_slot,
n_read,
);
} else {
result += @as(usize, try decoder.get(reader, num_direct_bits - 4, n_read)) << 4;
result += try self.align_decoder.parseReverse(reader, decoder, n_read);
}
return result;
}
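// Worked example of the slot arithmetic above (`^` is XOR in Zig): for
// pos_slot = 5, num_direct_bits = (5 >> 1) - 1 = 1 and the base is
// (2 ^ (5 & 1)) << num_direct_bits = 3 << 1 = 6, so slot 5 covers distances
// 6..7 before the reverse bit tree supplies the low bit.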
/// A circular buffer for LZ sequences
pub const CircularBuffer = struct {
/// Circular buffer
buf: ArrayList(u8),
/// Length of the buffer
dict_size: usize,
/// Buffer memory limit
mem_limit: usize,
/// Current position
cursor: usize,
/// Total number of bytes sent through the buffer
len: usize,
pub fn init(dict_size: usize, mem_limit: usize) CircularBuffer {
return .{
.buf = .{},
.dict_size = dict_size,
.mem_limit = mem_limit,
.cursor = 0,
.len = 0,
};
}
pub fn get(self: CircularBuffer, index: usize) u8 {
return if (0 <= index and index < self.buf.items.len) self.buf.items[index] else 0;
}
pub fn set(self: *CircularBuffer, gpa: Allocator, index: usize, value: u8) !void {
if (index >= self.mem_limit) {
return error.CorruptInput;
}
try self.buf.ensureTotalCapacity(gpa, index + 1);
while (self.buf.items.len < index) {
self.buf.appendAssumeCapacity(0);
}
self.buf.appendAssumeCapacity(value);
}
/// Retrieve the last byte or return a default
pub fn lastOr(self: CircularBuffer, lit: u8) u8 {
return if (self.len == 0)
lit
else
self.get((self.dict_size + self.cursor - 1) % self.dict_size);
}
/// Retrieve the n-th last byte
pub fn lastN(self: CircularBuffer, dist: usize) !u8 {
if (dist > self.dict_size or dist > self.len) {
return error.CorruptInput;
}
const offset = (self.dict_size + self.cursor - dist) % self.dict_size;
return self.get(offset);
}
/// Append a literal
pub fn appendLiteral(
self: *CircularBuffer,
gpa: Allocator,
lit: u8,
writer: *Writer,
) !void {
try self.set(gpa, self.cursor, lit);
self.cursor += 1;
self.len += 1;
// Flush the circular buffer to the output
if (self.cursor == self.dict_size) {
try writer.writeAll(self.buf.items);
self.cursor = 0;
}
}
/// Fetch an LZ sequence (length, distance) from inside the buffer
pub fn appendLz(
self: *CircularBuffer,
gpa: Allocator,
len: usize,
dist: usize,
writer: *Writer,
) !void {
if (dist > self.dict_size or dist > self.len) {
return error.CorruptInput;
}
var offset = (self.dict_size + self.cursor - dist) % self.dict_size;
var i: usize = 0;
while (i < len) : (i += 1) {
const x = self.get(offset);
try self.appendLiteral(gpa, x, writer);
offset += 1;
if (offset == self.dict_size) {
offset = 0;
}
}
}
pub fn finish(self: *CircularBuffer, writer: *Writer) !void {
if (self.cursor > 0) {
try writer.writeAll(self.buf.items[0..self.cursor]);
self.cursor = 0;
}
}
pub fn deinit(self: *CircularBuffer, gpa: Allocator) void {
self.buf.deinit(gpa);
self.* = undefined;
}
};
pub fn BitTree(comptime num_bits: usize) type {
return struct {
probs: [1 << num_bits]u16 = @splat(0x400),
pub fn parse(self: *@This(), reader: *Reader, decoder: *RangeDecoder, n_read: *u64) !u32 {
return decoder.parseBitTree(reader, num_bits, &self.probs, n_read);
}
pub fn parseReverse(
self: *@This(),
reader: *Reader,
decoder: *RangeDecoder,
n_read: *u64,
) !u32 {
return decoder.parseReverseBitTree(reader, num_bits, &self.probs, 0, n_read);
}
pub fn reset(self: *@This()) void {
@memset(&self.probs, 0x400);
}
};
}
pub const LenDecoder = struct {
choice: u16 = 0x400,
choice2: u16 = 0x400,
low_coder: [16]BitTree(3) = @splat(.{}),
mid_coder: [16]BitTree(3) = @splat(.{}),
high_coder: BitTree(8) = .{},
pub fn decode(
self: *LenDecoder,
reader: *Reader,
decoder: *RangeDecoder,
pos_state: usize,
n_read: *u64,
) !usize {
if (!try decoder.decodeBit(reader, &self.choice, n_read)) {
return @as(usize, try self.low_coder[pos_state].parse(reader, decoder, n_read));
} else if (!try decoder.decodeBit(reader, &self.choice2, n_read)) {
return @as(usize, try self.mid_coder[pos_state].parse(reader, decoder, n_read)) + 8;
} else {
return @as(usize, try self.high_coder.parse(reader, decoder, n_read)) + 16;
}
}
pub fn reset(self: *LenDecoder) void {
self.choice = 0x400;
self.choice2 = 0x400;
for (&self.low_coder) |*t| t.reset();
for (&self.mid_coder) |*t| t.reset();
self.high_coder.reset();
}
};
pub const Vec2d = struct {
data: []u16,
cols: usize,
pub fn init(gpa: Allocator, value: u16, w: usize, h: usize) !Vec2d {
const len = try math.mul(usize, w, h);
const data = try gpa.alloc(u16, len);
@memset(data, value);
return .{
.data = data,
.cols = h,
};
}
pub fn deinit(v: *Vec2d, gpa: Allocator) void {
gpa.free(v.data);
v.* = undefined;
}
pub fn fill(v: *Vec2d, value: u16) void {
@memset(v.data, value);
}
fn get(v: Vec2d, row: usize) ![]u16 {
const start_row = try math.mul(usize, row, v.cols);
const end_row = try math.add(usize, start_row, v.cols);
return v.data[start_row..end_row];
}
};
pub const Options = struct {
unpacked_size: UnpackedSize = .read_from_header,
mem_limit: ?usize = null,
allow_incomplete: bool = false,
};
pub const UnpackedSize = union(enum) {
read_from_header,
read_header_but_use_provided: ?u64,
use_provided: ?u64,
};
const ProcessingStatus = enum {
more,
finished,
};
pub const Properties = struct {
lc: u4,
lp: u3,
pb: u3,
fn validate(self: Properties) void {
assert(self.lc <= 8);
assert(self.lp <= 4);
assert(self.pb <= 4);
}
};
pub const Params = struct {
properties: Properties,
dict_size: u32,
unpacked_size: ?u64,
pub fn readHeader(reader: *Reader, options: Options) !Params {
var props = try reader.takeByte();
if (props >= 225) return error.CorruptInput;
const lc: u4 = @intCast(props % 9);
props /= 9;
const lp: u3 = @intCast(props % 5);
props /= 5;
const pb: u3 = @intCast(props);
const dict_size_provided = try reader.takeInt(u32, .little);
const dict_size = @max(0x1000, dict_size_provided);
const unpacked_size = switch (options.unpacked_size) {
.read_from_header => blk: {
const unpacked_size_provided = try reader.takeInt(u64, .little);
const marker_mandatory = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF;
break :blk if (marker_mandatory) null else unpacked_size_provided;
},
.read_header_but_use_provided => |x| blk: {
_ = try reader.takeInt(u64, .little);
break :blk x;
},
.use_provided => |x| x,
};
return .{
.properties = .{ .lc = lc, .lp = lp, .pb = pb },
.dict_size = dict_size,
.unpacked_size = unpacked_size,
};
}
};
};
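As a worked example of `Params.readHeader` above (illustrative, not part of this commit): the bound `props >= 225` holds because only 9 * 5 * 5 = 225 encodings are valid, and the conventional default properties byte 0x5D decodes to lc = 3, lp = 0, pb = 2:

test "properties byte decoding (illustrative)" {
    var props: u8 = 0x5D; // 93
    const lc: u4 = @intCast(props % 9); // 93 % 9 == 3
    props /= 9; // 10
    const lp: u3 = @intCast(props % 5); // 10 % 5 == 0
    props /= 5; // 2
    const pb: u3 = @intCast(props);
    try std.testing.expectEqual(@as(u4, 3), lc);
    try std.testing.expectEqual(@as(u3, 0), lp);
    try std.testing.expectEqual(@as(u3, 2), pb);
}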
pub const Decompress = struct {
gpa: Allocator,
input: *Reader,
reader: Reader,
buffer: Decode.CircularBuffer,
range_decoder: RangeDecoder,
decode: Decode,
err: ?Error,
unpacked_size: ?u64,
pub const Error = error{
OutOfMemory,
ReadFailed,
CorruptInput,
DecompressedSizeMismatch,
EndOfStream,
Overflow,
};
/// Takes ownership of `buffer` which may be resized with `gpa`.
///
/// LZMA was explicitly designed to take advantage of large heap memory
/// being available, with a dictionary size anywhere from 4K to 4G. Thus,
/// this API dynamically allocates the dictionary as-needed.
pub fn initParams(
input: *Reader,
gpa: Allocator,
buffer: []u8,
params: Decode.Params,
mem_limit: usize,
) !Decompress {
return .{
.gpa = gpa,
.input = input,
.buffer = Decode.CircularBuffer.init(params.dict_size, mem_limit),
.range_decoder = try RangeDecoder.init(input),
.decode = try Decode.init(gpa, params.properties),
.reader = .{
.buffer = buffer,
.vtable = &.{
.readVec = readVec,
.stream = stream,
.discard = discard,
},
.seek = 0,
.end = 0,
},
.err = null,
.unpacked_size = params.unpacked_size,
};
}
/// Takes ownership of `buffer` which may be resized with `gpa`.
///
/// LZMA was explicitly designed to take advantage of large heap memory
/// being available, with a dictionary size anywhere from 4K to 4G. Thus,
/// this API dynamically allocates the dictionary as-needed.
pub fn initOptions(
input: *Reader,
gpa: Allocator,
buffer: []u8,
options: Decode.Options,
mem_limit: usize,
) !Decompress {
const params = try Decode.Params.readHeader(input, options);
return initParams(input, gpa, buffer, params, mem_limit);
}
/// Reclaim ownership of the buffer passed to `init`.
pub fn takeBuffer(d: *Decompress) []u8 {
const buffer = d.reader.buffer;
d.reader.buffer = &.{};
return buffer;
}
pub fn deinit(d: *Decompress) void {
const gpa = d.gpa;
gpa.free(d.reader.buffer);
d.buffer.deinit(gpa);
d.decode.deinit(gpa);
d.* = undefined;
}
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
_ = data;
return readIndirect(r);
}
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
_ = w;
_ = limit;
return readIndirect(r);
}
fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
_ = d;
_ = limit;
@panic("TODO");
}
fn readIndirect(r: *Reader) Reader.Error!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
const gpa = d.gpa;
var allocating = Writer.Allocating.initOwnedSlice(gpa, r.buffer);
allocating.writer.end = r.end;
defer {
r.buffer = allocating.writer.buffer;
r.end = allocating.writer.end;
}
if (d.decode.state == math.maxInt(usize)) return error.EndOfStream;
process_next: {
if (d.unpacked_size) |unpacked_size| {
if (d.buffer.len >= unpacked_size) break :process_next;
} else if (d.range_decoder.isFinished()) {
break :process_next;
}
var n_read: u64 = 0;
switch (d.decode.process(d.input, &allocating, &d.buffer, &d.range_decoder, &n_read) catch |err| switch (err) {
error.WriteFailed => {
d.err = error.OutOfMemory;
return error.ReadFailed;
},
error.EndOfStream => {
d.err = error.EndOfStream;
return error.ReadFailed;
},
else => |e| {
d.err = e;
return error.ReadFailed;
},
}) {
.more => return 0,
.finished => break :process_next,
}
}
if (d.unpacked_size) |unpacked_size| {
if (d.buffer.len != unpacked_size) {
d.err = error.DecompressedSizeMismatch;
return error.ReadFailed;
}
}
d.buffer.finish(&allocating.writer) catch |err| switch (err) {
error.WriteFailed => {
d.err = error.OutOfMemory;
return error.ReadFailed;
},
};
d.decode.state = math.maxInt(usize);
return 0;
}
};
test {
_ = @import("lzma/test.zig");
}
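A minimal usage sketch of the new streaming interface (it mirrors lzma/test.zig below; `compressed` and `gpa` are placeholders):

var in: std.Io.Reader = .fixed(compressed);
var d = try std.compress.lzma.Decompress.initOptions(&in, gpa, &.{}, .{}, std.math.maxInt(u32));
defer d.deinit();
const out = try d.reader.allocRemaining(gpa, .unlimited);
defer gpa.free(out);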

lib/std/compress/lzma/decode.zig (deleted)

@@ -1,379 +0,0 @@
const std = @import("../../std.zig");
const assert = std.debug.assert;
const math = std.math;
const Allocator = std.mem.Allocator;
pub const lzbuffer = @import("decode/lzbuffer.zig");
pub const rangecoder = @import("decode/rangecoder.zig");
const LzCircularBuffer = lzbuffer.LzCircularBuffer;
const BitTree = rangecoder.BitTree;
const LenDecoder = rangecoder.LenDecoder;
const RangeDecoder = rangecoder.RangeDecoder;
const Vec2D = @import("vec2d.zig").Vec2D;
pub const Options = struct {
unpacked_size: UnpackedSize = .read_from_header,
memlimit: ?usize = null,
allow_incomplete: bool = false,
};
pub const UnpackedSize = union(enum) {
read_from_header,
read_header_but_use_provided: ?u64,
use_provided: ?u64,
};
const ProcessingStatus = enum {
continue_,
finished,
};
pub const Properties = struct {
lc: u4,
lp: u3,
pb: u3,
fn validate(self: Properties) void {
assert(self.lc <= 8);
assert(self.lp <= 4);
assert(self.pb <= 4);
}
};
pub const Params = struct {
properties: Properties,
dict_size: u32,
unpacked_size: ?u64,
pub fn readHeader(reader: anytype, options: Options) !Params {
var props = try reader.readByte();
if (props >= 225) {
return error.CorruptInput;
}
const lc = @as(u4, @intCast(props % 9));
props /= 9;
const lp = @as(u3, @intCast(props % 5));
props /= 5;
const pb = @as(u3, @intCast(props));
const dict_size_provided = try reader.readInt(u32, .little);
const dict_size = @max(0x1000, dict_size_provided);
const unpacked_size = switch (options.unpacked_size) {
.read_from_header => blk: {
const unpacked_size_provided = try reader.readInt(u64, .little);
const marker_mandatory = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF;
break :blk if (marker_mandatory)
null
else
unpacked_size_provided;
},
.read_header_but_use_provided => |x| blk: {
_ = try reader.readInt(u64, .little);
break :blk x;
},
.use_provided => |x| x,
};
return Params{
.properties = Properties{ .lc = lc, .lp = lp, .pb = pb },
.dict_size = dict_size,
.unpacked_size = unpacked_size,
};
}
};
pub const DecoderState = struct {
lzma_props: Properties,
unpacked_size: ?u64,
literal_probs: Vec2D(u16),
pos_slot_decoder: [4]BitTree(6),
align_decoder: BitTree(4),
pos_decoders: [115]u16,
is_match: [192]u16,
is_rep: [12]u16,
is_rep_g0: [12]u16,
is_rep_g1: [12]u16,
is_rep_g2: [12]u16,
is_rep_0long: [192]u16,
state: usize,
rep: [4]usize,
len_decoder: LenDecoder,
rep_len_decoder: LenDecoder,
pub fn init(
allocator: Allocator,
lzma_props: Properties,
unpacked_size: ?u64,
) !DecoderState {
return .{
.lzma_props = lzma_props,
.unpacked_size = unpacked_size,
.literal_probs = try Vec2D(u16).init(allocator, 0x400, .{ @as(usize, 1) << (lzma_props.lc + lzma_props.lp), 0x300 }),
.pos_slot_decoder = @splat(.{}),
.align_decoder = .{},
.pos_decoders = @splat(0x400),
.is_match = @splat(0x400),
.is_rep = @splat(0x400),
.is_rep_g0 = @splat(0x400),
.is_rep_g1 = @splat(0x400),
.is_rep_g2 = @splat(0x400),
.is_rep_0long = @splat(0x400),
.state = 0,
.rep = @splat(0),
.len_decoder = .{},
.rep_len_decoder = .{},
};
}
pub fn deinit(self: *DecoderState, allocator: Allocator) void {
self.literal_probs.deinit(allocator);
self.* = undefined;
}
pub fn resetState(self: *DecoderState, allocator: Allocator, new_props: Properties) !void {
new_props.validate();
if (self.lzma_props.lc + self.lzma_props.lp == new_props.lc + new_props.lp) {
self.literal_probs.fill(0x400);
} else {
self.literal_probs.deinit(allocator);
self.literal_probs = try Vec2D(u16).init(allocator, 0x400, .{ @as(usize, 1) << (new_props.lc + new_props.lp), 0x300 });
}
self.lzma_props = new_props;
for (&self.pos_slot_decoder) |*t| t.reset();
self.align_decoder.reset();
self.pos_decoders = @splat(0x400);
self.is_match = @splat(0x400);
self.is_rep = @splat(0x400);
self.is_rep_g0 = @splat(0x400);
self.is_rep_g1 = @splat(0x400);
self.is_rep_g2 = @splat(0x400);
self.is_rep_0long = @splat(0x400);
self.state = 0;
self.rep = @splat(0);
self.len_decoder.reset();
self.rep_len_decoder.reset();
}
fn processNextInner(
self: *DecoderState,
allocator: Allocator,
reader: anytype,
writer: anytype,
buffer: anytype,
decoder: *RangeDecoder,
update: bool,
) !ProcessingStatus {
const pos_state = buffer.len & ((@as(usize, 1) << self.lzma_props.pb) - 1);
if (!try decoder.decodeBit(
reader,
&self.is_match[(self.state << 4) + pos_state],
update,
)) {
const byte: u8 = try self.decodeLiteral(reader, buffer, decoder, update);
if (update) {
try buffer.appendLiteral(allocator, byte, writer);
self.state = if (self.state < 4)
0
else if (self.state < 10)
self.state - 3
else
self.state - 6;
}
return .continue_;
}
var len: usize = undefined;
if (try decoder.decodeBit(reader, &self.is_rep[self.state], update)) {
if (!try decoder.decodeBit(reader, &self.is_rep_g0[self.state], update)) {
if (!try decoder.decodeBit(
reader,
&self.is_rep_0long[(self.state << 4) + pos_state],
update,
)) {
if (update) {
self.state = if (self.state < 7) 9 else 11;
const dist = self.rep[0] + 1;
try buffer.appendLz(allocator, 1, dist, writer);
}
return .continue_;
}
} else {
const idx: usize = if (!try decoder.decodeBit(reader, &self.is_rep_g1[self.state], update))
1
else if (!try decoder.decodeBit(reader, &self.is_rep_g2[self.state], update))
2
else
3;
if (update) {
const dist = self.rep[idx];
var i = idx;
while (i > 0) : (i -= 1) {
self.rep[i] = self.rep[i - 1];
}
self.rep[0] = dist;
}
}
len = try self.rep_len_decoder.decode(reader, decoder, pos_state, update);
if (update) {
self.state = if (self.state < 7) 8 else 11;
}
} else {
if (update) {
self.rep[3] = self.rep[2];
self.rep[2] = self.rep[1];
self.rep[1] = self.rep[0];
}
len = try self.len_decoder.decode(reader, decoder, pos_state, update);
if (update) {
self.state = if (self.state < 7) 7 else 10;
}
const rep_0 = try self.decodeDistance(reader, decoder, len, update);
if (update) {
self.rep[0] = rep_0;
if (self.rep[0] == 0xFFFF_FFFF) {
if (decoder.isFinished()) {
return .finished;
}
return error.CorruptInput;
}
}
}
if (update) {
len += 2;
const dist = self.rep[0] + 1;
try buffer.appendLz(allocator, len, dist, writer);
}
return .continue_;
}
fn processNext(
self: *DecoderState,
allocator: Allocator,
reader: anytype,
writer: anytype,
buffer: anytype,
decoder: *RangeDecoder,
) !ProcessingStatus {
return self.processNextInner(allocator, reader, writer, buffer, decoder, true);
}
pub fn process(
self: *DecoderState,
allocator: Allocator,
reader: anytype,
writer: anytype,
buffer: anytype,
decoder: *RangeDecoder,
) !ProcessingStatus {
process_next: {
if (self.unpacked_size) |unpacked_size| {
if (buffer.len >= unpacked_size) {
break :process_next;
}
} else if (decoder.isFinished()) {
break :process_next;
}
switch (try self.processNext(allocator, reader, writer, buffer, decoder)) {
.continue_ => return .continue_,
.finished => break :process_next,
}
}
if (self.unpacked_size) |unpacked_size| {
if (buffer.len != unpacked_size) {
return error.CorruptInput;
}
}
return .finished;
}
fn decodeLiteral(
self: *DecoderState,
reader: anytype,
buffer: anytype,
decoder: *RangeDecoder,
update: bool,
) !u8 {
const def_prev_byte = 0;
const prev_byte = @as(usize, buffer.lastOr(def_prev_byte));
var result: usize = 1;
const lit_state = ((buffer.len & ((@as(usize, 1) << self.lzma_props.lp) - 1)) << self.lzma_props.lc) +
(prev_byte >> (8 - self.lzma_props.lc));
const probs = try self.literal_probs.getMut(lit_state);
if (self.state >= 7) {
var match_byte = @as(usize, try buffer.lastN(self.rep[0] + 1));
while (result < 0x100) {
const match_bit = (match_byte >> 7) & 1;
match_byte <<= 1;
const bit = @intFromBool(try decoder.decodeBit(
reader,
&probs[((@as(usize, 1) + match_bit) << 8) + result],
update,
));
result = (result << 1) ^ bit;
if (match_bit != bit) {
break;
}
}
}
while (result < 0x100) {
result = (result << 1) ^ @intFromBool(try decoder.decodeBit(reader, &probs[result], update));
}
return @as(u8, @truncate(result - 0x100));
}
fn decodeDistance(
self: *DecoderState,
reader: anytype,
decoder: *RangeDecoder,
length: usize,
update: bool,
) !usize {
const len_state = if (length > 3) 3 else length;
const pos_slot = @as(usize, try self.pos_slot_decoder[len_state].parse(reader, decoder, update));
if (pos_slot < 4)
return pos_slot;
const num_direct_bits = @as(u5, @intCast((pos_slot >> 1) - 1));
var result = (2 ^ (pos_slot & 1)) << num_direct_bits;
if (pos_slot < 14) {
result += try decoder.parseReverseBitTree(
reader,
num_direct_bits,
&self.pos_decoders,
result - pos_slot,
update,
);
} else {
result += @as(usize, try decoder.get(reader, num_direct_bits - 4)) << 4;
result += try self.align_decoder.parseReverse(reader, decoder, update);
}
return result;
}
};

lib/std/compress/lzma/decode/lzbuffer.zig (deleted)

@@ -1,228 +0,0 @@
const std = @import("../../../std.zig");
const math = std.math;
const mem = std.mem;
const Allocator = std.mem.Allocator;
const ArrayListUnmanaged = std.ArrayListUnmanaged;
/// An accumulating buffer for LZ sequences
pub const LzAccumBuffer = struct {
/// Buffer
buf: ArrayListUnmanaged(u8),
/// Buffer memory limit
memlimit: usize,
/// Total number of bytes sent through the buffer
len: usize,
const Self = @This();
pub fn init(memlimit: usize) Self {
return Self{
.buf = .{},
.memlimit = memlimit,
.len = 0,
};
}
pub fn appendByte(self: *Self, allocator: Allocator, byte: u8) !void {
try self.buf.append(allocator, byte);
self.len += 1;
}
/// Reset the internal dictionary
pub fn reset(self: *Self, writer: anytype) !void {
try writer.writeAll(self.buf.items);
self.buf.clearRetainingCapacity();
self.len = 0;
}
/// Retrieve the last byte or return a default
pub fn lastOr(self: Self, lit: u8) u8 {
const buf_len = self.buf.items.len;
return if (buf_len == 0)
lit
else
self.buf.items[buf_len - 1];
}
/// Retrieve the n-th last byte
pub fn lastN(self: Self, dist: usize) !u8 {
const buf_len = self.buf.items.len;
if (dist > buf_len) {
return error.CorruptInput;
}
return self.buf.items[buf_len - dist];
}
/// Append a literal
pub fn appendLiteral(
self: *Self,
allocator: Allocator,
lit: u8,
writer: anytype,
) !void {
_ = writer;
if (self.len >= self.memlimit) {
return error.CorruptInput;
}
try self.buf.append(allocator, lit);
self.len += 1;
}
/// Fetch an LZ sequence (length, distance) from inside the buffer
pub fn appendLz(
self: *Self,
allocator: Allocator,
len: usize,
dist: usize,
writer: anytype,
) !void {
_ = writer;
const buf_len = self.buf.items.len;
if (dist > buf_len) {
return error.CorruptInput;
}
var offset = buf_len - dist;
var i: usize = 0;
while (i < len) : (i += 1) {
const x = self.buf.items[offset];
try self.buf.append(allocator, x);
offset += 1;
}
self.len += len;
}
pub fn finish(self: *Self, writer: anytype) !void {
try writer.writeAll(self.buf.items);
self.buf.clearRetainingCapacity();
}
pub fn deinit(self: *Self, allocator: Allocator) void {
self.buf.deinit(allocator);
self.* = undefined;
}
};
/// A circular buffer for LZ sequences
pub const LzCircularBuffer = struct {
/// Circular buffer
buf: ArrayListUnmanaged(u8),
/// Length of the buffer
dict_size: usize,
/// Buffer memory limit
memlimit: usize,
/// Current position
cursor: usize,
/// Total number of bytes sent through the buffer
len: usize,
const Self = @This();
pub fn init(dict_size: usize, memlimit: usize) Self {
return Self{
.buf = .{},
.dict_size = dict_size,
.memlimit = memlimit,
.cursor = 0,
.len = 0,
};
}
pub fn get(self: Self, index: usize) u8 {
return if (0 <= index and index < self.buf.items.len)
self.buf.items[index]
else
0;
}
pub fn set(self: *Self, allocator: Allocator, index: usize, value: u8) !void {
if (index >= self.memlimit) {
return error.CorruptInput;
}
try self.buf.ensureTotalCapacity(allocator, index + 1);
while (self.buf.items.len < index) {
self.buf.appendAssumeCapacity(0);
}
self.buf.appendAssumeCapacity(value);
}
/// Retrieve the last byte or return a default
pub fn lastOr(self: Self, lit: u8) u8 {
return if (self.len == 0)
lit
else
self.get((self.dict_size + self.cursor - 1) % self.dict_size);
}
/// Retrieve the n-th last byte
pub fn lastN(self: Self, dist: usize) !u8 {
if (dist > self.dict_size or dist > self.len) {
return error.CorruptInput;
}
const offset = (self.dict_size + self.cursor - dist) % self.dict_size;
return self.get(offset);
}
/// Append a literal
pub fn appendLiteral(
self: *Self,
allocator: Allocator,
lit: u8,
writer: anytype,
) !void {
try self.set(allocator, self.cursor, lit);
self.cursor += 1;
self.len += 1;
// Flush the circular buffer to the output
if (self.cursor == self.dict_size) {
try writer.writeAll(self.buf.items);
self.cursor = 0;
}
}
/// Fetch an LZ sequence (length, distance) from inside the buffer
pub fn appendLz(
self: *Self,
allocator: Allocator,
len: usize,
dist: usize,
writer: anytype,
) !void {
if (dist > self.dict_size or dist > self.len) {
return error.CorruptInput;
}
var offset = (self.dict_size + self.cursor - dist) % self.dict_size;
var i: usize = 0;
while (i < len) : (i += 1) {
const x = self.get(offset);
try self.appendLiteral(allocator, x, writer);
offset += 1;
if (offset == self.dict_size) {
offset = 0;
}
}
}
pub fn finish(self: *Self, writer: anytype) !void {
if (self.cursor > 0) {
try writer.writeAll(self.buf.items[0..self.cursor]);
self.cursor = 0;
}
}
pub fn deinit(self: *Self, allocator: Allocator) void {
self.buf.deinit(allocator);
self.* = undefined;
}
};

lib/std/compress/lzma/decode/rangecoder.zig (deleted)

@@ -1,181 +0,0 @@
const std = @import("../../../std.zig");
const mem = std.mem;
pub const RangeDecoder = struct {
range: u32,
code: u32,
pub fn init(reader: anytype) !RangeDecoder {
const reserved = try reader.readByte();
if (reserved != 0) {
return error.CorruptInput;
}
return RangeDecoder{
.range = 0xFFFF_FFFF,
.code = try reader.readInt(u32, .big),
};
}
pub fn fromParts(
range: u32,
code: u32,
) RangeDecoder {
return .{
.range = range,
.code = code,
};
}
pub fn set(self: *RangeDecoder, range: u32, code: u32) void {
self.range = range;
self.code = code;
}
pub inline fn isFinished(self: RangeDecoder) bool {
return self.code == 0;
}
inline fn normalize(self: *RangeDecoder, reader: anytype) !void {
if (self.range < 0x0100_0000) {
self.range <<= 8;
self.code = (self.code << 8) ^ @as(u32, try reader.readByte());
}
}
inline fn getBit(self: *RangeDecoder, reader: anytype) !bool {
self.range >>= 1;
const bit = self.code >= self.range;
if (bit)
self.code -= self.range;
try self.normalize(reader);
return bit;
}
pub fn get(self: *RangeDecoder, reader: anytype, count: usize) !u32 {
var result: u32 = 0;
var i: usize = 0;
while (i < count) : (i += 1)
result = (result << 1) ^ @intFromBool(try self.getBit(reader));
return result;
}
pub inline fn decodeBit(self: *RangeDecoder, reader: anytype, prob: *u16, update: bool) !bool {
const bound = (self.range >> 11) * prob.*;
if (self.code < bound) {
if (update)
prob.* += (0x800 - prob.*) >> 5;
self.range = bound;
try self.normalize(reader);
return false;
} else {
if (update)
prob.* -= prob.* >> 5;
self.code -= bound;
self.range -= bound;
try self.normalize(reader);
return true;
}
}
fn parseBitTree(
self: *RangeDecoder,
reader: anytype,
num_bits: u5,
probs: []u16,
update: bool,
) !u32 {
var tmp: u32 = 1;
var i: @TypeOf(num_bits) = 0;
while (i < num_bits) : (i += 1) {
const bit = try self.decodeBit(reader, &probs[tmp], update);
tmp = (tmp << 1) ^ @intFromBool(bit);
}
return tmp - (@as(u32, 1) << num_bits);
}
pub fn parseReverseBitTree(
self: *RangeDecoder,
reader: anytype,
num_bits: u5,
probs: []u16,
offset: usize,
update: bool,
) !u32 {
var result: u32 = 0;
var tmp: usize = 1;
var i: @TypeOf(num_bits) = 0;
while (i < num_bits) : (i += 1) {
const bit = @intFromBool(try self.decodeBit(reader, &probs[offset + tmp], update));
tmp = (tmp << 1) ^ bit;
result ^= @as(u32, bit) << i;
}
return result;
}
};
pub fn BitTree(comptime num_bits: usize) type {
return struct {
probs: [1 << num_bits]u16 = @splat(0x400),
const Self = @This();
pub fn parse(
self: *Self,
reader: anytype,
decoder: *RangeDecoder,
update: bool,
) !u32 {
return decoder.parseBitTree(reader, num_bits, &self.probs, update);
}
pub fn parseReverse(
self: *Self,
reader: anytype,
decoder: *RangeDecoder,
update: bool,
) !u32 {
return decoder.parseReverseBitTree(reader, num_bits, &self.probs, 0, update);
}
pub fn reset(self: *Self) void {
@memset(&self.probs, 0x400);
}
};
}
pub const LenDecoder = struct {
choice: u16 = 0x400,
choice2: u16 = 0x400,
low_coder: [16]BitTree(3) = @splat(.{}),
mid_coder: [16]BitTree(3) = @splat(.{}),
high_coder: BitTree(8) = .{},
pub fn decode(
self: *LenDecoder,
reader: anytype,
decoder: *RangeDecoder,
pos_state: usize,
update: bool,
) !usize {
if (!try decoder.decodeBit(reader, &self.choice, update)) {
return @as(usize, try self.low_coder[pos_state].parse(reader, decoder, update));
} else if (!try decoder.decodeBit(reader, &self.choice2, update)) {
return @as(usize, try self.mid_coder[pos_state].parse(reader, decoder, update)) + 8;
} else {
return @as(usize, try self.high_coder.parse(reader, decoder, update)) + 16;
}
}
pub fn reset(self: *LenDecoder) void {
self.choice = 0x400;
self.choice2 = 0x400;
for (&self.low_coder) |*t| t.reset();
for (&self.mid_coder) |*t| t.reset();
self.high_coder.reset();
}
};

lib/std/compress/lzma/test.zig

@@ -1,24 +1,31 @@
const std = @import("../../std.zig");
const lzma = std.compress.lzma;

fn testDecompress(compressed: []const u8) ![]u8 {
const gpa = std.testing.allocator;
var stream: std.Io.Reader = .fixed(compressed);
var decompressor = try lzma.Decompress.initOptions(&stream, gpa, &.{}, .{}, std.math.maxInt(u32));
defer decompressor.deinit();
return decompressor.reader.allocRemaining(gpa, .unlimited);
}

fn testDecompressEqual(expected: []const u8, compressed: []const u8) !void {
const gpa = std.testing.allocator;
const decomp = try testDecompress(compressed);
defer gpa.free(decomp);
try std.testing.expectEqualSlices(u8, expected, decomp);
}

fn testDecompressError(expected: anyerror, compressed: []const u8) !void {
const gpa = std.testing.allocator;
var stream: std.Io.Reader = .fixed(compressed);
var decompressor = try lzma.Decompress.initOptions(&stream, gpa, &.{}, .{}, std.math.maxInt(u32));
defer decompressor.deinit();
try std.testing.expectError(error.ReadFailed, decompressor.reader.allocRemaining(gpa, .unlimited));
try std.testing.expectEqual(expected, decompressor.err orelse return error.TestFailed);
}

test "decompress empty world" {
@@ -76,24 +83,26 @@ test "known size with end of payload marker" {
test "too big uncompressed size in header" {
try testDecompressError(
error.DecompressedSizeMismatch,
@embedFile("testdata/bad-too_big_size-with_eopm.lzma"),
);
}

test "too small uncompressed size in header" {
try testDecompressError(
error.DecompressedSizeMismatch,
@embedFile("testdata/bad-too_small_size-without_eopm-3.lzma"),
);
}

test "reading one byte" {
const gpa = std.testing.allocator;
const compressed = @embedFile("testdata/good-known_size-with_eopm.lzma");
var stream: std.Io.Reader = .fixed(compressed);
var decompressor = try lzma.Decompress.initOptions(&stream, gpa, &.{}, .{}, std.math.maxInt(u32));
defer decompressor.deinit();
var buffer: [1]u8 = undefined;
try decompressor.reader.readSliceAll(&buffer);
try std.testing.expectEqual(72, buffer[0]);
}

lib/std/compress/lzma/vec2d.zig (deleted)

@@ -1,128 +0,0 @@
const std = @import("../../std.zig");
const math = std.math;
const mem = std.mem;
const Allocator = std.mem.Allocator;
pub fn Vec2D(comptime T: type) type {
return struct {
data: []T,
cols: usize,
const Self = @This();
pub fn init(allocator: Allocator, value: T, size: struct { usize, usize }) !Self {
const len = try math.mul(usize, size[0], size[1]);
const data = try allocator.alloc(T, len);
@memset(data, value);
return Self{
.data = data,
.cols = size[1],
};
}
pub fn deinit(self: *Self, allocator: Allocator) void {
allocator.free(self.data);
self.* = undefined;
}
pub fn fill(self: *Self, value: T) void {
@memset(self.data, value);
}
inline fn _get(self: Self, row: usize) ![]T {
const start_row = try math.mul(usize, row, self.cols);
const end_row = try math.add(usize, start_row, self.cols);
return self.data[start_row..end_row];
}
pub fn get(self: Self, row: usize) ![]const T {
return self._get(row);
}
pub fn getMut(self: *Self, row: usize) ![]T {
return self._get(row);
}
};
}
const testing = std.testing;
const expectEqualSlices = std.testing.expectEqualSlices;
const expectError = std.testing.expectError;
test "init" {
const allocator = testing.allocator;
var vec2d = try Vec2D(i32).init(allocator, 1, .{ 2, 3 });
defer vec2d.deinit(allocator);
try expectEqualSlices(i32, &.{ 1, 1, 1 }, try vec2d.get(0));
try expectEqualSlices(i32, &.{ 1, 1, 1 }, try vec2d.get(1));
}
test "init overflow" {
const allocator = testing.allocator;
try expectError(
error.Overflow,
Vec2D(i32).init(allocator, 1, .{ math.maxInt(usize), math.maxInt(usize) }),
);
}
test "fill" {
const allocator = testing.allocator;
var vec2d = try Vec2D(i32).init(allocator, 0, .{ 2, 3 });
defer vec2d.deinit(allocator);
vec2d.fill(7);
try expectEqualSlices(i32, &.{ 7, 7, 7 }, try vec2d.get(0));
try expectEqualSlices(i32, &.{ 7, 7, 7 }, try vec2d.get(1));
}
test "get" {
var data = [_]i32{ 0, 1, 2, 3, 4, 5, 6, 7 };
const vec2d = Vec2D(i32){
.data = &data,
.cols = 2,
};
try expectEqualSlices(i32, &.{ 0, 1 }, try vec2d.get(0));
try expectEqualSlices(i32, &.{ 2, 3 }, try vec2d.get(1));
try expectEqualSlices(i32, &.{ 4, 5 }, try vec2d.get(2));
try expectEqualSlices(i32, &.{ 6, 7 }, try vec2d.get(3));
}
test "getMut" {
var data = [_]i32{ 0, 1, 2, 3, 4, 5, 6, 7 };
var vec2d = Vec2D(i32){
.data = &data,
.cols = 2,
};
const row = try vec2d.getMut(1);
row[1] = 9;
try expectEqualSlices(i32, &.{ 0, 1 }, try vec2d.get(0));
// (1, 1) should be 9.
try expectEqualSlices(i32, &.{ 2, 9 }, try vec2d.get(1));
try expectEqualSlices(i32, &.{ 4, 5 }, try vec2d.get(2));
try expectEqualSlices(i32, &.{ 6, 7 }, try vec2d.get(3));
}
test "get multiplication overflow" {
const allocator = testing.allocator;
var matrix = try Vec2D(i32).init(allocator, 0, .{ 3, 4 });
defer matrix.deinit(allocator);
const row = (math.maxInt(usize) / 4) + 1;
try expectError(error.Overflow, matrix.get(row));
try expectError(error.Overflow, matrix.getMut(row));
}
test "get addition overflow" {
const allocator = testing.allocator;
var matrix = try Vec2D(i32).init(allocator, 0, .{ 3, 5 });
defer matrix.deinit(allocator);
const row = math.maxInt(usize) / 5;
try expectError(error.Overflow, matrix.get(row));
try expectError(error.Overflow, matrix.getMut(row));
}

lib/std/compress/lzma2.zig

@@ -1,26 +1,282 @@
const std = @import("../std.zig");
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayList;
const lzma = std.compress.lzma;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
/// An accumulating buffer for LZ sequences
pub const AccumBuffer = struct {
/// Buffer
buf: ArrayList(u8),
/// Buffer memory limit
memlimit: usize,
/// Total number of bytes sent through the buffer
len: usize,
pub fn init(memlimit: usize) AccumBuffer {
return .{
.buf = .{},
.memlimit = memlimit,
.len = 0,
};
}
pub fn appendByte(self: *AccumBuffer, allocator: Allocator, byte: u8) !void {
try self.buf.append(allocator, byte);
self.len += 1;
}
/// Reset the internal dictionary
pub fn reset(self: *AccumBuffer, writer: *Writer) !void {
try writer.writeAll(self.buf.items);
self.buf.clearRetainingCapacity();
self.len = 0;
}
/// Retrieve the last byte or return a default
pub fn lastOr(self: AccumBuffer, lit: u8) u8 {
const buf_len = self.buf.items.len;
return if (buf_len == 0)
lit
else
self.buf.items[buf_len - 1];
}
/// Retrieve the n-th last byte
pub fn lastN(self: AccumBuffer, dist: usize) !u8 {
const buf_len = self.buf.items.len;
if (dist > buf_len) {
return error.CorruptInput;
}
return self.buf.items[buf_len - dist];
}
/// Append a literal
pub fn appendLiteral(
self: *AccumBuffer,
allocator: Allocator,
lit: u8,
writer: *Writer,
) !void {
_ = writer;
if (self.len >= self.memlimit) {
return error.CorruptInput;
}
try self.buf.append(allocator, lit);
self.len += 1;
}
/// Fetch an LZ sequence (length, distance) from inside the buffer
pub fn appendLz(
self: *AccumBuffer,
allocator: Allocator,
len: usize,
dist: usize,
writer: *Writer,
) !void {
_ = writer;
const buf_len = self.buf.items.len;
if (dist > buf_len) {
return error.CorruptInput;
}
var offset = buf_len - dist;
var i: usize = 0;
while (i < len) : (i += 1) {
const x = self.buf.items[offset];
try self.buf.append(allocator, x);
offset += 1;
}
self.len += len;
}
pub fn finish(self: *AccumBuffer, writer: *Writer) !void {
try writer.writeAll(self.buf.items);
self.buf.clearRetainingCapacity();
}
pub fn deinit(self: *AccumBuffer, allocator: Allocator) void {
self.buf.deinit(allocator);
self.* = undefined;
}
};
pub const Decode = struct {
lzma_decode: lzma.Decode,
pub fn init(gpa: Allocator) !Decode {
return .{ .lzma_decode = try lzma.Decode.init(gpa, .{ .lc = 0, .lp = 0, .pb = 0 }) };
}
pub fn deinit(self: *Decode, gpa: Allocator) void {
self.lzma_decode.deinit(gpa);
self.* = undefined;
}
/// Returns how many compressed bytes were consumed.
pub fn decompress(d: *Decode, reader: *Reader, allocating: *Writer.Allocating) !u64 {
const gpa = allocating.allocator;
var accum = AccumBuffer.init(std.math.maxInt(usize));
defer accum.deinit(gpa);
var n_read: u64 = 0;
while (true) {
const status = try reader.takeByte();
n_read += 1;
switch (status) {
0 => break,
1 => n_read += try parseUncompressed(reader, allocating, &accum, true),
2 => n_read += try parseUncompressed(reader, allocating, &accum, false),
else => n_read += try d.parseLzma(reader, allocating, &accum, status),
}
}
try accum.finish(&allocating.writer);
return n_read;
}
fn parseLzma(
d: *Decode,
reader: *Reader,
allocating: *Writer.Allocating,
accum: *AccumBuffer,
status: u8,
) !u64 {
if (status & 0x80 == 0) return error.CorruptInput;
const Reset = struct {
dict: bool,
state: bool,
props: bool,
};
const reset: Reset = switch ((status >> 5) & 0x3) {
0 => .{
.dict = false,
.state = false,
.props = false,
},
1 => .{
.dict = false,
.state = true,
.props = false,
},
2 => .{
.dict = false,
.state = true,
.props = true,
},
3 => .{
.dict = true,
.state = true,
.props = true,
},
else => unreachable,
};
var n_read: u64 = 0;
const unpacked_size = blk: {
var tmp: u64 = status & 0x1F;
tmp <<= 16;
tmp |= try reader.takeInt(u16, .big);
n_read += 2;
break :blk tmp + 1;
};
const packed_size = blk: {
const tmp: u17 = try reader.takeInt(u16, .big);
n_read += 2;
break :blk tmp + 1;
};
if (reset.dict) try accum.reset(&allocating.writer);
const ld = &d.lzma_decode;
if (reset.state) {
var new_props = ld.properties;
if (reset.props) {
var props = try reader.takeByte();
n_read += 1;
if (props >= 225) {
return error.CorruptInput;
}
const lc = @as(u4, @intCast(props % 9));
props /= 9;
const lp = @as(u3, @intCast(props % 5));
props /= 5;
const pb = @as(u3, @intCast(props));
if (lc + lp > 4) {
return error.CorruptInput;
}
new_props = .{ .lc = lc, .lp = lp, .pb = pb };
}
try ld.resetState(allocating.allocator, new_props);
}
const expected_unpacked_size = accum.len + unpacked_size;
const start_count = n_read;
var range_decoder = try lzma.RangeDecoder.initCounting(reader, &n_read);
while (true) {
if (accum.len >= expected_unpacked_size) break;
if (range_decoder.isFinished()) break;
switch (try ld.process(reader, allocating, accum, &range_decoder, &n_read)) {
.more => continue,
.finished => break,
}
}
if (accum.len != expected_unpacked_size) return error.DecompressedSizeMismatch;
if (n_read - start_count != packed_size) return error.CompressedSizeMismatch;
return n_read;
}
fn parseUncompressed(
reader: *Reader,
allocating: *Writer.Allocating,
accum: *AccumBuffer,
reset_dict: bool,
) !usize {
const unpacked_size = @as(u17, try reader.takeInt(u16, .big)) + 1;
if (reset_dict) try accum.reset(&allocating.writer);
const gpa = allocating.allocator;
for (0..unpacked_size) |_| {
try accum.appendByte(gpa, try reader.takeByte());
}
return 2 + unpacked_size;
}
};
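For reference, the chunk framing implemented by `decompress` above, traced on the test stream that follows (annotation derived from the code, not part of the commit):

// 0x01: uncompressed chunk, reset dict; size = 0x0005 + 1 = 6 -> "Hello\n"
// 0x02: uncompressed chunk, no reset; size = 0x0006 + 1 = 7 -> "World!\n"
// 0x00: end of stream
// A control byte >= 0x80 starts an LZMA chunk instead: bits 5..6 select the
// reset level and bits 0..4 hold the high bits of (unpacked size - 1).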
test "decompress hello world stream" {
const expected = "Hello\nWorld!\n"; const expected = "Hello\nWorld!\n";
const compressed = &[_]u8{ 0x01, 0x00, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x0A, 0x02, 0x00, 0x06, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, 0x0A, 0x00 }; const compressed = &[_]u8{ 0x01, 0x00, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x0A, 0x02, 0x00, 0x06, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, 0x0A, 0x00 };
const allocator = std.testing.allocator; const gpa = std.testing.allocator;
var decomp = std.array_list.Managed(u8).init(allocator);
defer decomp.deinit(); var decode = try Decode.init(gpa);
var stream = std.io.fixedBufferStream(compressed); defer decode.deinit(gpa);
try decompress(allocator, stream.reader(), decomp.writer());
try std.testing.expectEqualSlices(u8, expected, decomp.items); var stream: std.Io.Reader = .fixed(compressed);
var result: std.Io.Writer.Allocating = .init(gpa);
defer result.deinit();
const n_read = try decode.decompress(&stream, &result);
try std.testing.expectEqual(compressed.len, n_read);
try std.testing.expectEqualStrings(expected, result.written());
} }

lib/std/compress/lzma2/decode.zig (deleted)

@@ -1,169 +0,0 @@
const std = @import("../../std.zig");
const Allocator = std.mem.Allocator;
const lzma = @import("../lzma.zig");
const DecoderState = lzma.decode.DecoderState;
const LzAccumBuffer = lzma.decode.lzbuffer.LzAccumBuffer;
const Properties = lzma.decode.Properties;
const RangeDecoder = lzma.decode.rangecoder.RangeDecoder;
pub const Decoder = struct {
lzma_state: DecoderState,
pub fn init(allocator: Allocator) !Decoder {
return Decoder{
.lzma_state = try DecoderState.init(
allocator,
Properties{
.lc = 0,
.lp = 0,
.pb = 0,
},
null,
),
};
}
pub fn deinit(self: *Decoder, allocator: Allocator) void {
self.lzma_state.deinit(allocator);
self.* = undefined;
}
pub fn decompress(
self: *Decoder,
allocator: Allocator,
reader: anytype,
writer: anytype,
) !void {
var accum = LzAccumBuffer.init(std.math.maxInt(usize));
defer accum.deinit(allocator);
while (true) {
const status = try reader.readByte();
switch (status) {
0 => break,
1 => try parseUncompressed(allocator, reader, writer, &accum, true),
2 => try parseUncompressed(allocator, reader, writer, &accum, false),
else => try self.parseLzma(allocator, reader, writer, &accum, status),
}
}
try accum.finish(writer);
}
fn parseLzma(
self: *Decoder,
allocator: Allocator,
reader: anytype,
writer: anytype,
accum: *LzAccumBuffer,
status: u8,
) !void {
if (status & 0x80 == 0) {
return error.CorruptInput;
}
const Reset = struct {
dict: bool,
state: bool,
props: bool,
};
const reset = switch ((status >> 5) & 0x3) {
0 => Reset{
.dict = false,
.state = false,
.props = false,
},
1 => Reset{
.dict = false,
.state = true,
.props = false,
},
2 => Reset{
.dict = false,
.state = true,
.props = true,
},
3 => Reset{
.dict = true,
.state = true,
.props = true,
},
else => unreachable,
};
const unpacked_size = blk: {
var tmp: u64 = status & 0x1F;
tmp <<= 16;
tmp |= try reader.readInt(u16, .big);
break :blk tmp + 1;
};
const packed_size = blk: {
const tmp: u17 = try reader.readInt(u16, .big);
break :blk tmp + 1;
};
if (reset.dict) {
try accum.reset(writer);
}
if (reset.state) {
var new_props = self.lzma_state.lzma_props;
if (reset.props) {
var props = try reader.readByte();
if (props >= 225) {
return error.CorruptInput;
}
const lc = @as(u4, @intCast(props % 9));
props /= 9;
const lp = @as(u3, @intCast(props % 5));
props /= 5;
const pb = @as(u3, @intCast(props));
if (lc + lp > 4) {
return error.CorruptInput;
}
new_props = Properties{ .lc = lc, .lp = lp, .pb = pb };
}
try self.lzma_state.resetState(allocator, new_props);
}
self.lzma_state.unpacked_size = unpacked_size + accum.len;
var counter = std.io.countingReader(reader);
const counter_reader = counter.reader();
var rangecoder = try RangeDecoder.init(counter_reader);
while (try self.lzma_state.process(allocator, counter_reader, writer, accum, &rangecoder) == .continue_) {}
if (counter.bytes_read != packed_size) {
return error.CorruptInput;
}
}
fn parseUncompressed(
allocator: Allocator,
reader: anytype,
writer: anytype,
accum: *LzAccumBuffer,
reset_dict: bool,
) !void {
const unpacked_size = @as(u17, try reader.readInt(u16, .big)) + 1;
if (reset_dict) {
try accum.reset(writer);
}
var i: @TypeOf(unpacked_size) = 0;
while (i < unpacked_size) : (i += 1) {
try accum.appendByte(allocator, try reader.readByte());
}
}
};

lib/std/compress/xz.zig

@@ -1,165 +1,4 @@
const std = @import("std"); pub const Decompress = @import("xz/Decompress.zig");
const block = @import("xz/block.zig");
const Allocator = std.mem.Allocator;
const Crc32 = std.hash.Crc32;
pub const Check = enum(u4) {
none = 0x00,
crc32 = 0x01,
crc64 = 0x04,
sha256 = 0x0A,
_,
};
fn readStreamFlags(reader: anytype, check: *Check) !void {
const reserved1 = try reader.readByte();
if (reserved1 != 0) return error.CorruptInput;
const byte = try reader.readByte();
if ((byte >> 4) != 0) return error.CorruptInput;
check.* = @enumFromInt(@as(u4, @truncate(byte)));
}
pub fn decompress(allocator: Allocator, reader: anytype) !Decompress(@TypeOf(reader)) {
return Decompress(@TypeOf(reader)).init(allocator, reader);
}
pub fn Decompress(comptime ReaderType: type) type {
return struct {
const Self = @This();
pub const Error = ReaderType.Error || block.Decoder(ReaderType).Error;
pub const Reader = std.io.GenericReader(*Self, Error, read);
allocator: Allocator,
block_decoder: block.Decoder(ReaderType),
in_reader: ReaderType,
fn init(allocator: Allocator, source: ReaderType) !Self {
const magic = try source.readBytesNoEof(6);
if (!std.mem.eql(u8, &magic, &.{ 0xFD, '7', 'z', 'X', 'Z', 0x00 }))
return error.BadHeader;
var check: Check = undefined;
const hash_a = blk: {
var hasher = hashedReader(source, Crc32.init());
try readStreamFlags(hasher.reader(), &check);
break :blk hasher.hasher.final();
};
const hash_b = try source.readInt(u32, .little);
if (hash_a != hash_b)
return error.WrongChecksum;
return Self{
.allocator = allocator,
.block_decoder = try block.decoder(allocator, source, check),
.in_reader = source,
};
}
pub fn deinit(self: *Self) void {
self.block_decoder.deinit();
}
pub fn reader(self: *Self) Reader {
return .{ .context = self };
}
pub fn read(self: *Self, buffer: []u8) Error!usize {
if (buffer.len == 0)
return 0;
const r = try self.block_decoder.read(buffer);
if (r != 0)
return r;
const index_size = blk: {
var hasher = hashedReader(self.in_reader, Crc32.init());
hasher.hasher.update(&[1]u8{0x00});
var counter = std.io.countingReader(hasher.reader());
counter.bytes_read += 1;
const counting_reader = counter.reader();
const record_count = try std.leb.readUleb128(u64, counting_reader);
if (record_count != self.block_decoder.block_count)
return error.CorruptInput;
var i: usize = 0;
while (i < record_count) : (i += 1) {
// TODO: validate records
_ = try std.leb.readUleb128(u64, counting_reader);
_ = try std.leb.readUleb128(u64, counting_reader);
}
while (counter.bytes_read % 4 != 0) {
if (try counting_reader.readByte() != 0)
return error.CorruptInput;
}
const hash_a = hasher.hasher.final();
const hash_b = try counting_reader.readInt(u32, .little);
if (hash_a != hash_b)
return error.WrongChecksum;
break :blk counter.bytes_read;
};
const hash_a = try self.in_reader.readInt(u32, .little);
const hash_b = blk: {
var hasher = hashedReader(self.in_reader, Crc32.init());
const hashed_reader = hasher.reader();
const backward_size = (@as(u64, try hashed_reader.readInt(u32, .little)) + 1) * 4;
if (backward_size != index_size)
return error.CorruptInput;
var check: Check = undefined;
try readStreamFlags(hashed_reader, &check);
break :blk hasher.hasher.final();
};
if (hash_a != hash_b)
return error.WrongChecksum;
const magic = try self.in_reader.readBytesNoEof(2);
if (!std.mem.eql(u8, &magic, &.{ 'Y', 'Z' }))
return error.CorruptInput;
return 0;
}
};
}
pub fn HashedReader(ReaderType: type, HasherType: type) type {
return struct {
child_reader: ReaderType,
hasher: HasherType,
pub const Error = ReaderType.Error;
pub const Reader = std.io.GenericReader(*@This(), Error, read);
pub fn read(self: *@This(), buf: []u8) Error!usize {
const amt = try self.child_reader.read(buf);
self.hasher.update(buf[0..amt]);
return amt;
}
pub fn reader(self: *@This()) Reader {
return .{ .context = self };
}
};
}
pub fn hashedReader(
reader: anytype,
hasher: anytype,
) HashedReader(@TypeOf(reader), @TypeOf(hasher)) {
return .{ .child_reader = reader, .hasher = hasher };
}
test {
    _ = @import("xz/test.zig");
}

lib/std/compress/xz/Decompress.zig

@@ -0,0 +1,319 @@
const Decompress = @This();
const std = @import("../../std.zig");
const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayList;
const Crc32 = std.hash.Crc32;
const Crc64 = std.hash.crc.Crc64Xz;
const Sha256 = std.crypto.hash.sha2.Sha256;
const lzma2 = std.compress.lzma2;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
const assert = std.debug.assert;
/// Underlying compressed data stream to pull bytes from.
input: *Reader,
/// Uncompressed bytes output by this stream implementation.
reader: Reader,
gpa: Allocator,
check: Check,
block_count: usize,
err: ?Error,
pub const Error = error{
ReadFailed,
OutOfMemory,
CorruptInput,
EndOfStream,
WrongChecksum,
Unsupported,
Overflow,
InvalidRangeCode,
DecompressedSizeMismatch,
CompressedSizeMismatch,
};
pub const Check = enum(u4) {
none = 0x00,
crc32 = 0x01,
crc64 = 0x04,
sha256 = 0x0A,
_,
};
pub const StreamFlags = packed struct(u16) {
null: u8 = 0,
check: Check,
reserved: u4 = 0,
};
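// The two stream-flags bytes sit right after the six-byte stream magic: the
// first byte must be zero and the high nibble of the second is reserved, so
// only the check ID carries information. `init` below computes the CRC32 over
// exactly these two bytes before trusting them.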
pub const InitError = error{
NotXzStream,
WrongChecksum,
};
/// XZ uses a series of LZMA2 blocks, each of which specifies a dictionary size
/// anywhere from 4 KiB to 4 GiB. Thus, this API allocates the dictionary
/// dynamically, as needed.
pub fn init(
input: *Reader,
gpa: Allocator,
/// Decompress takes ownership of this buffer and resizes it with `gpa`.
buffer: []u8,
) !Decompress {
const magic = try input.takeArray(6);
if (!std.mem.eql(u8, magic, &.{ 0xFD, '7', 'z', 'X', 'Z', 0x00 }))
return error.NotXzStream;
const computed_checksum = Crc32.hash(try input.peek(@sizeOf(StreamFlags)));
const stream_flags = input.takeStruct(StreamFlags, .little) catch unreachable;
const stored_hash = try input.takeInt(u32, .little);
if (computed_checksum != stored_hash) return error.WrongChecksum;
return .{
.input = input,
.reader = .{
.vtable = &.{
.stream = stream,
.readVec = readVec,
.discard = discard,
},
.buffer = buffer,
.seek = 0,
.end = 0,
},
.gpa = gpa,
.check = stream_flags.check,
.block_count = 0,
.err = null,
};
}
/// Reclaim ownership of the buffer passed to `init`.
pub fn takeBuffer(d: *Decompress) []u8 {
const buffer = d.reader.buffer;
d.reader.buffer = &.{};
return buffer;
}
pub fn deinit(d: *Decompress) void {
const gpa = d.gpa;
gpa.free(d.reader.buffer);
d.* = undefined;
}
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
_ = data;
return readIndirect(r);
}
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
_ = w;
_ = limit;
return readIndirect(r);
}
fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
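    // The vtable hands us only the embedded Reader; @fieldParentPtr recovers
    // the Decompress that contains it.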
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
_ = d;
_ = limit;
@panic("TODO");
}
fn readIndirect(r: *Reader) Reader.Error!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
const gpa = d.gpa;
const input = d.input;
var allocating = Writer.Allocating.initOwnedSlice(gpa, r.buffer);
allocating.writer.end = r.end;
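    // Wrap the reader's buffer in an allocating writer so the block decoder
    // can append decompressed bytes directly into it; the defer below hands
    // the (possibly grown) buffer back to the reader.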
defer {
r.buffer = allocating.writer.buffer;
r.end = allocating.writer.end;
}
if (d.err != null) return error.ReadFailed;
if (d.block_count == std.math.maxInt(usize)) return error.EndOfStream;
readBlock(input, &allocating) catch |err| switch (err) {
error.WriteFailed => {
d.err = error.OutOfMemory;
return error.ReadFailed;
},
error.SuccessfulEndOfStream => {
finish(d) catch |finish_err| {
d.err = finish_err;
return error.ReadFailed;
};
d.block_count = std.math.maxInt(usize);
return error.EndOfStream;
},
else => |e| {
d.err = e;
return error.ReadFailed;
},
};
switch (d.check) {
.none => {},
.crc32 => {
const declared_checksum = try input.takeInt(u32, .little);
// TODO
//const hash_a = Crc32.hash(unpacked_bytes);
//if (hash_a != hash_b) return error.WrongChecksum;
_ = declared_checksum;
},
.crc64 => {
const declared_checksum = try input.takeInt(u64, .little);
// TODO
//const hash_a = Crc64.hash(unpacked_bytes);
//if (hash_a != hash_b) return error.WrongChecksum;
_ = declared_checksum;
},
.sha256 => {
const declared_hash = try input.take(Sha256.digest_length);
// TODO
//var hash_a: [Sha256.digest_length]u8 = undefined;
//Sha256.hash(unpacked_bytes, &hash_a, .{});
//if (!std.mem.eql(u8, &hash_a, &hash_b))
// return error.WrongChecksum;
_ = declared_hash;
},
else => {
d.err = error.Unsupported;
return error.ReadFailed;
},
}
d.block_count += 1;
return 0;
}
fn readBlock(input: *Reader, allocating: *Writer.Allocating) !void {
var packed_size: ?u64 = null;
var unpacked_size: ?u64 = null;
const header_size = h: {
// Read the block header via peeking so that we can hash the whole thing too.
const first_byte: usize = try input.peekByte();
if (first_byte == 0) return error.SuccessfulEndOfStream;
const declared_header_size = first_byte * 4;
try input.fill(declared_header_size);
const header_seek_start = input.seek;
input.toss(1);
const Flags = packed struct(u8) {
last_filter_index: u2,
reserved: u4,
has_packed_size: bool,
has_unpacked_size: bool,
};
const flags = try input.takeStruct(Flags, .little);
const filter_count = @as(u3, flags.last_filter_index) + 1;
if (filter_count > 1) return error.Unsupported;
if (flags.has_packed_size) packed_size = try input.takeLeb128(u64);
if (flags.has_unpacked_size) unpacked_size = try input.takeLeb128(u64);
const FilterId = enum(u64) {
lzma2 = 0x21,
_,
};
const filter_id: FilterId = @enumFromInt(try input.takeLeb128(u64));
if (filter_id != .lzma2) return error.Unsupported;
const properties_size = try input.takeLeb128(u64);
if (properties_size != 1) return error.CorruptInput;
// TODO: use filter properties
_ = try input.takeByte();
const actual_header_size = input.seek - header_seek_start;
if (actual_header_size > declared_header_size) return error.CorruptInput;
const remaining_bytes = declared_header_size - actual_header_size;
for (0..remaining_bytes) |_| {
if (try input.takeByte() != 0) return error.CorruptInput;
}
const header_slice = input.buffer[header_seek_start..][0..declared_header_size];
const computed_checksum = Crc32.hash(header_slice);
const declared_checksum = try input.takeInt(u32, .little);
if (computed_checksum != declared_checksum) return error.WrongChecksum;
break :h declared_header_size;
};
// Compressed Data
var lzma2_decode = try lzma2.Decode.init(allocating.allocator);
defer lzma2_decode.deinit(allocating.allocator);
const before_size = allocating.writer.end;
const packed_bytes_read = try lzma2_decode.decompress(input, allocating);
const unpacked_bytes = allocating.writer.end - before_size;
if (packed_size) |s| {
if (s != packed_bytes_read) return error.CorruptInput;
}
if (unpacked_size) |s| {
if (s != unpacked_bytes) return error.CorruptInput;
}
// Block Padding
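    // (4 - (n % 4)) % 4 pads up to the next multiple of four without adding a
    // whole word when n is already aligned: n == 13 needs 3 padding bytes,
    // n == 16 needs none.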
const block_counter = header_size + packed_bytes_read;
const padding = try input.take(@intCast((4 - (block_counter % 4)) % 4));
for (padding) |byte| {
if (byte != 0) return error.CorruptInput;
}
}
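// finish() below first validates the index (a LEB128 record count, one pair of
// LEB128 values per block, zero padding to a four-byte boundary, CRC32), then
// the stream footer: a CRC32 over the next six bytes, the backward size
// (stored little-endian as size / 4 - 1), a repeat of the stream flags, and
// the terminating "YZ" magic.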
fn finish(d: *Decompress) !void {
const input = d.input;
const index_size = blk: {
// Assume that we already peeked a zero in readBlock().
assert(input.buffered()[0] == 0);
var input_counter: u64 = 1;
var checksum: Crc32 = .init();
checksum.update(&.{0});
input.toss(1);
const record_count = try countLeb128(input, u64, &input_counter, &checksum);
if (record_count != d.block_count)
return error.CorruptInput;
for (0..@intCast(record_count)) |_| {
// TODO: validate records
_ = try countLeb128(input, u64, &input_counter, &checksum);
_ = try countLeb128(input, u64, &input_counter, &checksum);
}
const padding = try input.take(@intCast((4 - (input_counter % 4)) % 4));
for (padding) |byte| {
if (byte != 0) return error.CorruptInput;
}
checksum.update(padding);
const declared_checksum = try input.takeInt(u32, .little);
const computed_checksum = checksum.final();
if (computed_checksum != declared_checksum) return error.WrongChecksum;
break :blk input_counter + padding.len + 4;
};
const declared_checksum = try input.takeInt(u32, .little);
const computed_checksum = Crc32.hash(try input.peek(4 + @sizeOf(StreamFlags)));
if (declared_checksum != computed_checksum) return error.WrongChecksum;
const backward_size = (@as(u64, try input.takeInt(u32, .little)) + 1) * 4;
if (backward_size != index_size) return error.CorruptInput;
input.toss(@sizeOf(StreamFlags));
if (!std.mem.eql(u8, try input.takeArray(2), &.{ 'Y', 'Z' }))
return error.CorruptInput;
}
fn countLeb128(reader: *Reader, comptime T: type, counter: *u64, hasher: *Crc32) !T {
try reader.fill(8);
const start = reader.seek;
const result = try reader.takeLeb128(T);
const read_slice = reader.buffer[start..reader.seek];
hasher.update(read_slice);
counter.* += read_slice.len;
return result;
}
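For reference, a minimal usage sketch of the API above, mirroring the updated
tests later in this commit (the fixture path is a stand-in for any XZ stream):

const std = @import("std");

test "xz.Decompress usage sketch" {
    const gpa = std.testing.allocator;
    const compressed: []const u8 = @embedFile("testdata/good-0-empty.xz");
    var input: std.Io.Reader = .fixed(compressed);
    // Passing an empty buffer is fine: Decompress grows it with `gpa` on demand.
    var decompress = try std.compress.xz.Decompress.init(&input, gpa, &.{});
    defer decompress.deinit();
    const output = try decompress.reader.allocRemaining(gpa, .unlimited);
    defer gpa.free(output);
    try std.testing.expectEqualSlices(u8, "", output);
}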

lib/std/compress/xz/block.zig

@@ -1,208 +0,0 @@
const std = @import("../../std.zig");
const lzma2 = std.compress.lzma2;
const Allocator = std.mem.Allocator;
const ArrayListUnmanaged = std.ArrayListUnmanaged;
const Crc32 = std.hash.Crc32;
const Crc64 = std.hash.crc.Crc64Xz;
const Sha256 = std.crypto.hash.sha2.Sha256;
const xz = std.compress.xz;
const DecodeError = error{
CorruptInput,
EndOfStream,
EndOfStreamWithNoError,
WrongChecksum,
Unsupported,
Overflow,
};
pub fn decoder(allocator: Allocator, reader: anytype, check: xz.Check) !Decoder(@TypeOf(reader)) {
return Decoder(@TypeOf(reader)).init(allocator, reader, check);
}
pub fn Decoder(comptime ReaderType: type) type {
return struct {
const Self = @This();
pub const Error =
ReaderType.Error ||
DecodeError ||
Allocator.Error;
pub const Reader = std.io.GenericReader(*Self, Error, read);
allocator: Allocator,
inner_reader: ReaderType,
check: xz.Check,
err: ?Error,
to_read: ArrayListUnmanaged(u8),
read_pos: usize,
block_count: usize,
fn init(allocator: Allocator, in_reader: ReaderType, check: xz.Check) !Self {
return Self{
.allocator = allocator,
.inner_reader = in_reader,
.check = check,
.err = null,
.to_read = .{},
.read_pos = 0,
.block_count = 0,
};
}
pub fn deinit(self: *Self) void {
self.to_read.deinit(self.allocator);
}
pub fn reader(self: *Self) Reader {
return .{ .context = self };
}
pub fn read(self: *Self, output: []u8) Error!usize {
while (true) {
const unread_len = self.to_read.items.len - self.read_pos;
if (unread_len > 0) {
const n = @min(unread_len, output.len);
@memcpy(output[0..n], self.to_read.items[self.read_pos..][0..n]);
self.read_pos += n;
return n;
}
if (self.err) |e| {
if (e == DecodeError.EndOfStreamWithNoError) {
return 0;
}
return e;
}
if (self.read_pos > 0) {
self.to_read.shrinkRetainingCapacity(0);
self.read_pos = 0;
}
self.readBlock() catch |e| {
self.err = e;
};
}
}
fn readBlock(self: *Self) Error!void {
var block_counter = std.io.countingReader(self.inner_reader);
const block_reader = block_counter.reader();
var packed_size: ?u64 = null;
var unpacked_size: ?u64 = null;
// Block Header
{
var header_hasher = xz.hashedReader(block_reader, Crc32.init());
const header_reader = header_hasher.reader();
const header_size = @as(u64, try header_reader.readByte()) * 4;
if (header_size == 0)
return error.EndOfStreamWithNoError;
const Flags = packed struct(u8) {
last_filter_index: u2,
reserved: u4,
has_packed_size: bool,
has_unpacked_size: bool,
};
const flags = @as(Flags, @bitCast(try header_reader.readByte()));
const filter_count = @as(u3, flags.last_filter_index) + 1;
if (filter_count > 1)
return error.Unsupported;
if (flags.has_packed_size)
packed_size = try std.leb.readUleb128(u64, header_reader);
if (flags.has_unpacked_size)
unpacked_size = try std.leb.readUleb128(u64, header_reader);
const FilterId = enum(u64) {
lzma2 = 0x21,
_,
};
const filter_id = @as(
FilterId,
@enumFromInt(try std.leb.readUleb128(u64, header_reader)),
);
if (@intFromEnum(filter_id) >= 0x4000_0000_0000_0000)
return error.CorruptInput;
if (filter_id != .lzma2)
return error.Unsupported;
const properties_size = try std.leb.readUleb128(u64, header_reader);
if (properties_size != 1)
return error.CorruptInput;
// TODO: use filter properties
_ = try header_reader.readByte();
while (block_counter.bytes_read != header_size) {
if (try header_reader.readByte() != 0)
return error.CorruptInput;
}
const hash_a = header_hasher.hasher.final();
const hash_b = try header_reader.readInt(u32, .little);
if (hash_a != hash_b)
return error.WrongChecksum;
}
// Compressed Data
var packed_counter = std.io.countingReader(block_reader);
try lzma2.decompress(
self.allocator,
packed_counter.reader(),
self.to_read.writer(self.allocator),
);
if (packed_size) |s| {
if (s != packed_counter.bytes_read)
return error.CorruptInput;
}
const unpacked_bytes = self.to_read.items;
if (unpacked_size) |s| {
if (s != unpacked_bytes.len)
return error.CorruptInput;
}
// Block Padding
while (block_counter.bytes_read % 4 != 0) {
if (try block_reader.readByte() != 0)
return error.CorruptInput;
}
switch (self.check) {
.none => {},
.crc32 => {
const hash_a = Crc32.hash(unpacked_bytes);
const hash_b = try self.inner_reader.readInt(u32, .little);
if (hash_a != hash_b)
return error.WrongChecksum;
},
.crc64 => {
const hash_a = Crc64.hash(unpacked_bytes);
const hash_b = try self.inner_reader.readInt(u64, .little);
if (hash_a != hash_b)
return error.WrongChecksum;
},
.sha256 => {
var hash_a: [Sha256.digest_length]u8 = undefined;
Sha256.hash(unpacked_bytes, &hash_a, .{});
var hash_b: [Sha256.digest_length]u8 = undefined;
try self.inner_reader.readNoEof(&hash_b);
if (!std.mem.eql(u8, &hash_a, &hash_b))
return error.WrongChecksum;
},
else => return error.Unsupported,
}
self.block_count += 1;
}
};
}

lib/std/compress/xz/test.zig

@@ -3,48 +3,79 @@ const testing = std.testing;
 const xz = std.compress.xz;

 fn decompress(data: []const u8) ![]u8 {
-    var in_stream = std.io.fixedBufferStream(data);
-    var xz_stream = try xz.decompress(testing.allocator, in_stream.reader());
+    const gpa = testing.allocator;
+    var in_stream: std.Io.Reader = .fixed(data);
+    var xz_stream = try xz.Decompress.init(&in_stream, gpa, &.{});
     defer xz_stream.deinit();
-    return xz_stream.reader().readAllAlloc(testing.allocator, std.math.maxInt(usize));
+    return xz_stream.reader.allocRemaining(gpa, .unlimited);
 }

-fn testReader(data: []const u8, comptime expected: []const u8) !void {
-    const buf = try decompress(data);
-    defer testing.allocator.free(buf);
-    try testing.expectEqualSlices(u8, expected, buf);
+fn testReader(data: []const u8, expected: []const u8) !void {
+    const gpa = testing.allocator;
+
+    const result = try decompress(data);
+    defer gpa.free(result);
+
+    try testing.expectEqualSlices(u8, expected, result);
 }

-test "compressed data" {
+fn testDecompressError(expected: anyerror, compressed: []const u8) !void {
+    const gpa = std.testing.allocator;
+    var stream: std.Io.Reader = .fixed(compressed);
+    var decompressor = try xz.Decompress.init(&stream, gpa, &.{});
+    defer decompressor.deinit();
+    try std.testing.expectError(error.ReadFailed, decompressor.reader.allocRemaining(gpa, .unlimited));
+    try std.testing.expectEqual(expected, decompressor.err orelse return error.TestFailed);
+}
+
+test "fixture good-0-empty.xz" {
     try testReader(@embedFile("testdata/good-0-empty.xz"), "");
-    inline for ([_][]const u8{
-        "good-1-check-none.xz",
-        "good-1-check-crc32.xz",
-        "good-1-check-crc64.xz",
-        "good-1-check-sha256.xz",
-        "good-2-lzma2.xz",
-        "good-1-block_header-1.xz",
-        "good-1-block_header-2.xz",
-        "good-1-block_header-3.xz",
-    }) |filename| {
-        try testReader(@embedFile("testdata/" ++ filename),
+}
+
+const hello_world_text =
     \\Hello
     \\World!
     \\
-        );
-    }
-    inline for ([_][]const u8{
-        "good-1-lzma2-1.xz",
-        "good-1-lzma2-2.xz",
-        "good-1-lzma2-3.xz",
-        "good-1-lzma2-4.xz",
-    }) |filename| {
-        try testReader(@embedFile("testdata/" ++ filename),
+;
+
+test "fixture good-1-check-none.xz" {
+    try testReader(@embedFile("testdata/good-1-check-none.xz"), hello_world_text);
+}
+
+test "fixture good-1-check-crc32.xz" {
+    try testReader(@embedFile("testdata/good-1-check-crc32.xz"), hello_world_text);
+}
+
+test "fixture good-1-check-crc64.xz" {
+    try testReader(@embedFile("testdata/good-1-check-crc64.xz"), hello_world_text);
+}
+
+test "fixture good-1-check-sha256.xz" {
+    try testReader(@embedFile("testdata/good-1-check-sha256.xz"), hello_world_text);
+}
+
+test "fixture good-2-lzma2.xz" {
+    try testReader(@embedFile("testdata/good-2-lzma2.xz"), hello_world_text);
+}
+
+test "fixture good-1-block_header-1.xz" {
+    try testReader(@embedFile("testdata/good-1-block_header-1.xz"), hello_world_text);
+}
+
+test "fixture good-1-block_header-2.xz" {
+    try testReader(@embedFile("testdata/good-1-block_header-2.xz"), hello_world_text);
+}
+
+test "fixture good-1-block_header-3.xz" {
+    try testReader(@embedFile("testdata/good-1-block_header-3.xz"), hello_world_text);
+}
+
+const lorem_ipsum_text =
     \\Lorem ipsum dolor sit amet, consectetur adipisicing
     \\elit, sed do eiusmod tempor incididunt ut
     \\labore et dolore magna aliqua. Ut enim
@@ -56,27 +87,54 @@ test "compressed data" {
     \\non proident, sunt in culpa qui officia
     \\deserunt mollit anim id est laborum.
     \\
-        );
-    }
+;
+
+test "fixture good-1-lzma2-1.xz" {
+    try testReader(@embedFile("testdata/good-1-lzma2-1.xz"), lorem_ipsum_text);
+}
+
+test "fixture good-1-lzma2-2.xz" {
+    try testReader(@embedFile("testdata/good-1-lzma2-2.xz"), lorem_ipsum_text);
+}
+
+test "fixture good-1-lzma2-3.xz" {
+    try testReader(@embedFile("testdata/good-1-lzma2-3.xz"), lorem_ipsum_text);
+}
+
+test "fixture good-1-lzma2-4.xz" {
+    try testReader(@embedFile("testdata/good-1-lzma2-4.xz"), lorem_ipsum_text);
+}

+test "fixture good-1-lzma2-5.xz" {
     try testReader(@embedFile("testdata/good-1-lzma2-5.xz"), "");
 }

-test "unsupported" {
-    inline for ([_][]const u8{
-        "good-1-delta-lzma2.tiff.xz",
-        "good-1-x86-lzma2.xz",
-        "good-1-sparc-lzma2.xz",
-        "good-1-arm64-lzma2-1.xz",
-        "good-1-arm64-lzma2-2.xz",
-        "good-1-3delta-lzma2.xz",
-        "good-1-empty-bcj-lzma2.xz",
-    }) |filename| {
-        try testing.expectError(
-            error.Unsupported,
-            decompress(@embedFile("testdata/" ++ filename)),
-        );
-    }
+test "fixture good-1-delta-lzma2.tiff.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-delta-lzma2.tiff.xz"));
+}
+
+test "fixture good-1-x86-lzma2.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-x86-lzma2.xz"));
+}
+
+test "fixture good-1-sparc-lzma2.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-sparc-lzma2.xz"));
+}
+
+test "fixture good-1-arm64-lzma2-1.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-arm64-lzma2-1.xz"));
+}
+
+test "fixture good-1-arm64-lzma2-2.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-arm64-lzma2-2.xz"));
+}
+
+test "fixture good-1-3delta-lzma2.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-3delta-lzma2.xz"));
+}
+
+test "fixture good-1-empty-bcj-lzma2.xz" {
+    try testDecompressError(error.Unsupported, @embedFile("testdata/good-1-empty-bcj-lzma2.xz"));
 }

 fn testDontPanic(data: []const u8) !void {
@@ -91,6 +149,8 @@ test "size fields: integer overflow avoidance" {
     // These cases were found via fuzz testing and each previously caused
     // an integer overflow when decoding. We just want to ensure they no longer
     // cause a panic.
+    // TODO: this is not a sufficient way to test; tests should always check the
+    // result, not merely ensure that the code does not crash.
     const header_size_overflow = "\xfd7zXZ\x00\x00\x01i\"\xde6z";
     try testDontPanic(header_size_overflow);
     const lzma2_chunk_size_overflow = "\xfd7zXZ\x00\x00\x01i\"\xde6\x02\x00!\x01\x08\x00\x00\x00\xd8\x0f#\x13\x01\xff\xff";

src/Package/Fetch.zig

@@ -1204,12 +1204,10 @@ fn unpackResource(
         },
         .@"tar.xz" => {
             const gpa = f.arena.child_allocator;
-            var dcp = std.compress.xz.decompress(gpa, resource.reader().adaptToOldInterface()) catch |err|
+            var decompress = std.compress.xz.Decompress.init(resource.reader(), gpa, &.{}) catch |err|
                 return f.fail(f.location_tok, try eb.printString("unable to decompress tarball: {t}", .{err}));
-            defer dcp.deinit();
-            var adapter_buffer: [1024]u8 = undefined;
-            var adapter = dcp.reader().adaptToNewApi(&adapter_buffer);
-            return try unpackTarball(f, tmp_directory.handle, &adapter.new_interface);
+            defer decompress.deinit();
+            return try unpackTarball(f, tmp_directory.handle, &decompress.reader);
         },
         .@"tar.zst" => {
             const window_len = std.compress.zstd.default_window_len;