std.compress.zstandard: check FSE bitstreams are fully consumed

This commit is contained in:
dweiller 2023-02-12 04:33:20 +11:00
parent 1530e73648
commit 373d8ef26e
3 changed files with 32 additions and 16 deletions

View file

@ -391,15 +391,21 @@ pub const DecodeState = struct {
try self.literal_stream_reader.init(bytes); try self.literal_stream_reader.init(bytes);
} }
fn isLiteralStreamEmpty(self: *DecodeState) bool {
switch (self.literal_streams) {
.one => return self.literal_stream_reader.isEmpty(),
.four => return self.literal_stream_index == 3 and self.literal_stream_reader.isEmpty(),
}
}
const LiteralBitsError = error{ const LiteralBitsError = error{
BitStreamHasNoStartBit, BitStreamHasNoStartBit,
UnexpectedEndOfLiteralStream, UnexpectedEndOfLiteralStream,
}; };
fn readLiteralsBits( fn readLiteralsBits(
self: *DecodeState, self: *DecodeState,
comptime T: type,
bit_count_to_read: usize, bit_count_to_read: usize,
) LiteralBitsError!T { ) LiteralBitsError!u16 {
return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: { return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
if (self.literal_streams == .four and self.literal_stream_index < 3) { if (self.literal_streams == .four and self.literal_stream_index < 3) {
try self.nextLiteralMultiStream(); try self.nextLiteralMultiStream();
@ -461,7 +467,7 @@ pub const DecodeState = struct {
while (i < len) : (i += 1) { while (i < len) : (i += 1) {
var prefix: u16 = 0; var prefix: u16 = 0;
while (true) { while (true) {
const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| { const new_bits = self.readLiteralsBits(bit_count_to_read) catch |err| {
return err; return err;
}; };
prefix <<= bit_count_to_read; prefix <<= bit_count_to_read;
@ -533,7 +539,7 @@ pub const DecodeState = struct {
while (i < len) : (i += 1) { while (i < len) : (i += 1) {
var prefix: u16 = 0; var prefix: u16 = 0;
while (true) { while (true) {
const new_bits = try self.readLiteralsBits(u16, bit_count_to_read); const new_bits = try self.readLiteralsBits(bit_count_to_read);
prefix <<= bit_count_to_read; prefix <<= bit_count_to_read;
prefix |= new_bits; prefix |= new_bits;
bits_read += bit_count_to_read; bits_read += bit_count_to_read;
@ -659,13 +665,10 @@ pub fn decodeBlock(
sequence_size_limit -= decompressed_size; sequence_size_limit -= decompressed_size;
} }
if (bit_stream.bit_reader.bit_count != 0) { if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock; return error.MalformedCompressedBlock;
} }
bytes_read += bit_stream_bytes.len;
} }
if (bytes_read != block_size) return error.MalformedCompressedBlock;
if (decode_state.literal_written_count < literals.header.regenerated_size) { if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count; const len = literals.header.regenerated_size - decode_state.literal_written_count;
@ -675,7 +678,9 @@ pub fn decodeBlock(
bytes_written += len; bytes_written += len;
} }
consumed_count.* += bytes_read; if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
consumed_count.* += block_size;
return bytes_written; return bytes_written;
}, },
.reserved => return error.ReservedBlock, .reserved => return error.ReservedBlock,
@ -749,13 +754,10 @@ pub fn decodeBlockRingBuffer(
sequence_size_limit -= decompressed_size; sequence_size_limit -= decompressed_size;
} }
if (bit_stream.bit_reader.bit_count != 0) { if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock; return error.MalformedCompressedBlock;
} }
bytes_read += bit_stream_bytes.len;
} }
if (bytes_read != block_size) return error.MalformedCompressedBlock;
if (decode_state.literal_written_count < literals.header.regenerated_size) { if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count; const len = literals.header.regenerated_size - decode_state.literal_written_count;
@ -764,7 +766,9 @@ pub fn decodeBlockRingBuffer(
bytes_written += len; bytes_written += len;
} }
consumed_count.* += bytes_read; if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
consumed_count.* += block_size;
if (bytes_written > block_size_max) return error.BlockSizeOverMaximum; if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
return bytes_written; return bytes_written;
}, },
@ -837,7 +841,7 @@ pub fn decodeBlockReader(
sequence_size_limit -= decompressed_size; sequence_size_limit -= decompressed_size;
bytes_written += decompressed_size; bytes_written += decompressed_size;
} }
if (bit_stream.bit_reader.bit_count != 0) { if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock; return error.MalformedCompressedBlock;
} }
} }
@ -849,6 +853,8 @@ pub fn decodeBlockReader(
bytes_written += len; bytes_written += len;
} }
if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
if (bytes_written > block_size_max) return error.BlockSizeOverMaximum; if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
if (block_reader_limited.bytes_left != 0) return error.MalformedCompressedBlock; if (block_reader_limited.bytes_left != 0) return error.MalformedCompressedBlock;
decode_state.literal_written_count = 0; decode_state.literal_written_count = 0;

View file

@ -86,6 +86,10 @@ fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entr
odd_state = odd_data.baseline + odd_bits; odd_state = odd_data.baseline + odd_bits;
} else return error.MalformedHuffmanTree; } else return error.MalformedHuffmanTree;
if (!huff_bits.isEmpty()) {
return error.MalformedHuffmanTree;
}
return i + 1; // stream contains all but the last symbol return i + 1; // stream contains all but the last symbol
} }

View file

@ -36,7 +36,9 @@ pub const ReverseBitReader = struct {
pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void { pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
self.byte_reader = ReversedByteReader.init(bytes); self.byte_reader = ReversedByteReader.init(bytes);
self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader()); self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {} var i: usize = 0;
while (i < 8 and 0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) : (i += 1) {}
if (i == 8) return error.BitStreamHasNoStartBit;
} }
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U { pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
@ -50,6 +52,10 @@ pub const ReverseBitReader = struct {
pub fn alignToByte(self: *@This()) void { pub fn alignToByte(self: *@This()) void {
self.bit_reader.alignToByte(); self.bit_reader.alignToByte();
} }
pub fn isEmpty(self: ReverseBitReader) bool {
return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0;
}
}; };
pub fn BitReader(comptime Reader: type) type { pub fn BitReader(comptime Reader: type) type {