Nothing to see here, move along meow
1#![allow(clippy::single_match)]
2
3use crate::error::FsError;
4
/// Shortest match the compressor emits; match lengths are stored in the token
/// relative to this value.
const MIN_MATCH: usize = 4;
/// Number of bits kept from the 32-bit hash, i.e. log2 of the hash table size.
const HASH_LOG: usize = 12;
/// Number of entries in the match-candidate hash table.
const HASH_SIZE: usize = 1 << HASH_LOG;
/// Width in bits of the match-length field (low nibble of a token).
const ML_BITS: u32 = 4;
/// Width in bits of the literal-run field; also the shift that places it in
/// the high nibble of a token.
const RUN_BITS: u32 = 4;
/// Largest literal-run value that fits in the token; this value signals that
/// extension bytes follow.
const RUN_MASK: usize = (1 << RUN_BITS) - 1;
/// Largest match-length value that fits in the token; this value signals that
/// extension bytes follow.
const ML_MASK: usize = (1 << ML_BITS) - 1;
12
13fn hash4(data: &[u8], pos: usize) -> usize {
14 let val = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
15 ((val.wrapping_mul(2654435761)) >> (32 - HASH_LOG)) as usize
16}
17
/// Appends the extension bytes for `length` to `dst` at `*pos`.
///
/// The encoding is one `255` byte for every full 255 contained in `length`,
/// followed by a single terminating byte holding the remainder (`0..=254`).
/// The terminator is always written, even when the remainder is zero.
///
/// `*pos` is advanced past every byte that was written. Returns `true` when
/// the whole encoding fit, and `false` when `dst` ran out of room; in that
/// case as many `255` bytes as fit have been written and no terminator.
fn write_length(dst: &mut [u8], pos: &mut usize, length: usize) -> bool {
    let full_bytes = length / 255;
    let terminator = (length % 255) as u8;

    // Write as many 255s as there is room for; a shortfall is detected below
    // because the cursor then sits at the end of `dst`.
    let room = dst.len().saturating_sub(*pos);
    let written = full_bytes.min(room);
    if let Some(run) = dst.get_mut(*pos..*pos + written) {
        run.fill(255);
    }
    *pos += written;

    if *pos < dst.len() {
        dst[*pos] = terminator;
        *pos += 1;
        true
    } else {
        false
    }
}
42
43pub fn compress(src: &[u8], dst: &mut [u8], table: &mut [u16; HASH_SIZE]) -> Option<usize> {
44 let src_len = src.len();
45 match src_len < MIN_MATCH + 1 {
46 true => return None,
47 false => {}
48 }
49
50 table.iter_mut().for_each(|entry| *entry = 0);
51 let mut dst_pos = 0usize;
52 let mut anchor = 0usize;
53 let mut src_pos = 0usize;
54 let src_limit = src_len.saturating_sub(5);
55
56 core::iter::from_fn(|| match src_pos >= src_limit {
57 true => None,
58 false => {
59 let h = hash4(src, src_pos);
60 let match_pos = table[h] as usize;
61 table[h] = src_pos as u16;
62
63 let matched = match_pos < src_pos
64 && src_pos - match_pos <= u16::MAX as usize
65 && src[match_pos..match_pos + 4] == src[src_pos..src_pos + 4];
66
67 match matched {
68 false => {
69 src_pos += 1;
70 Some(())
71 }
72 true => {
73 let lit_len = src_pos - anchor;
74 let match_len =
75 count_match(src, src_pos + MIN_MATCH, match_pos + MIN_MATCH) + MIN_MATCH;
76 let offset = (src_pos - match_pos) as u16;
77
78 let token_pos = dst_pos;
79 match token_pos < dst.len() {
80 true => dst[token_pos] = 0,
81 false => return None,
82 }
83 dst_pos += 1;
84
85 let token_lit = match lit_len >= RUN_MASK {
86 true => RUN_MASK,
87 false => lit_len,
88 };
89 dst[token_pos] = (token_lit << RUN_BITS) as u8;
90
91 match lit_len >= RUN_MASK {
92 true => {
93 let extra = lit_len - RUN_MASK;
94 match write_length(dst, &mut dst_pos, extra) {
95 true => {}
96 false => return None,
97 }
98 }
99 false => {}
100 }
101
102 match dst_pos + lit_len + 2 <= dst.len() {
103 true => {
104 dst[dst_pos..dst_pos + lit_len]
105 .copy_from_slice(&src[anchor..anchor + lit_len]);
106 dst_pos += lit_len;
107 }
108 false => return None,
109 }
110
111 let offset_bytes = offset.to_le_bytes();
112 dst[dst_pos] = offset_bytes[0];
113 dst[dst_pos + 1] = offset_bytes[1];
114 dst_pos += 2;
115
116 let ml = match_len - MIN_MATCH;
117 let token_ml = match ml >= ML_MASK {
118 true => ML_MASK,
119 false => ml,
120 };
121 dst[token_pos] |= token_ml as u8;
122
123 match ml >= ML_MASK {
124 true => {
125 let extra = ml - ML_MASK;
126 match write_length(dst, &mut dst_pos, extra) {
127 true => {}
128 false => return None,
129 }
130 }
131 false => {}
132 }
133
134 src_pos += match_len;
135 anchor = src_pos;
136
137 match src_pos < src_limit {
138 true => {
139 table[hash4(src, src_pos)] = src_pos as u16;
140 }
141 false => {}
142 }
143
144 Some(())
145 }
146 }
147 }
148 })
149 .last();
150
151 let remaining = src_len - anchor;
152 let token_pos = dst_pos;
153 match token_pos < dst.len() {
154 true => {}
155 false => return None,
156 }
157 dst_pos += 1;
158
159 let token_lit = match remaining >= RUN_MASK {
160 true => RUN_MASK,
161 false => remaining,
162 };
163 dst[token_pos] = (token_lit << RUN_BITS) as u8;
164
165 match remaining >= RUN_MASK {
166 true => {
167 let extra = remaining - RUN_MASK;
168 match write_length(dst, &mut dst_pos, extra) {
169 true => {}
170 false => return None,
171 }
172 }
173 false => {}
174 }
175
176 match dst_pos + remaining <= dst.len() {
177 true => {
178 dst[dst_pos..dst_pos + remaining].copy_from_slice(&src[anchor..anchor + remaining]);
179 dst_pos += remaining;
180 Some(dst_pos)
181 }
182 false => None,
183 }
184}
185
/// Counts how many consecutive bytes agree between `src[a..]` and `src[b..]`.
///
/// Comparison stops at the first differing byte or when either position
/// reaches the end of `src`. Returns 0 if either start is past the end.
fn count_match(src: &[u8], a: usize, b: usize) -> usize {
    match (src.get(a..), src.get(b..)) {
        (Some(ahead), Some(behind)) => ahead
            .iter()
            .zip(behind)
            .take_while(|(x, y)| x == y)
            .count(),
        _ => 0,
    }
}
200
/// Cursor pair threaded through decompression, one token at a time.
struct DecompState {
    /// Index of the next unread byte in the compressed input.
    src_pos: usize,
    /// Index of the next byte to write in the output buffer.
    dst_pos: usize,
}
205
/// Outcome of decoding a single sequence from the compressed input.
enum DecompStep {
    /// A full sequence (literals and match) was decoded; carries the
    /// advanced cursors.
    Continue(DecompState),
    /// The input is exhausted; carries the final output position.
    Done(usize),
    /// The input is malformed or the output buffer is too small.
    Corrupt,
}
211
212pub fn decompress(src: &[u8], dst: &mut [u8]) -> Result<usize, FsError> {
213 let initial = DecompState {
214 src_pos: 0,
215 dst_pos: 0,
216 };
217
218 let result = (0..src.len()).try_fold(initial, |state, _| match state.src_pos >= src.len() {
219 true => Err(Ok(state.dst_pos)),
220 false => match decompress_token(src, dst, state) {
221 DecompStep::Continue(s) => Ok(s),
222 DecompStep::Done(n) => Err(Ok(n)),
223 DecompStep::Corrupt => Err(Err(FsError::DecompressError)),
224 },
225 });
226
227 match result {
228 Ok(s) => Ok(s.dst_pos),
229 Err(Ok(n)) => Ok(n),
230 Err(Err(e)) => Err(e),
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Decompresses `packed` into `scratch` and asserts the output equals `expected`.
    fn assert_decodes_to(packed: &[u8], scratch: &mut [u8], expected: &[u8]) {
        let written = decompress(packed, scratch).unwrap();
        assert_eq!(&scratch[..written], expected);
    }

    #[test]
    fn roundtrip_zeros() {
        let src = [0u8; 4096];
        let mut compressed = [0u8; 4096];
        let mut table = [0u16; HASH_SIZE];
        let clen = compress(&src, &mut compressed, &mut table).expect("should compress zeros");
        assert!(clen < src.len());

        let mut decompressed = [0u8; 4096];
        assert_decodes_to(&compressed[..clen], &mut decompressed, &src[..]);
    }

    #[test]
    fn roundtrip_repeating_pattern() {
        let src: Vec<u8> = (0..4096).map(|i| (i % 251) as u8).collect();
        let mut compressed = [0u8; 8192];
        let mut table = [0u16; HASH_SIZE];
        let clen = compress(&src, &mut compressed, &mut table).expect("should compress pattern");

        let mut decompressed = [0u8; 4096];
        assert_decodes_to(&compressed[..clen], &mut decompressed, &src[..]);
    }

    #[test]
    fn roundtrip_all_same_byte() {
        let src = [0xAA_u8; 4096];
        let mut compressed = [0u8; 4096];
        let mut table = [0u16; HASH_SIZE];
        let clen = compress(&src, &mut compressed, &mut table).expect("should compress");
        assert!(clen < 100);

        let mut decompressed = [0u8; 4096];
        assert_decodes_to(&compressed[..clen], &mut decompressed, &src[..]);
    }

    #[test]
    fn roundtrip_sequential_bytes() {
        let src: Vec<u8> = (0..4096).map(|i| (i & 0xFF) as u8).collect();
        let mut compressed = [0u8; 8192];
        let mut table = [0u16; HASH_SIZE];
        // Compression is allowed to decline; only a produced stream is checked.
        if let Some(clen) = compress(&src, &mut compressed, &mut table) {
            let mut decompressed = [0u8; 4096];
            assert_decodes_to(&compressed[..clen], &mut decompressed, &src[..]);
        }
    }

    #[test]
    fn roundtrip_small_data() {
        let src = b"hello world, this is a test of lz4 compression";
        let mut compressed = [0u8; 256];
        let mut table = [0u16; HASH_SIZE];
        // Compression is allowed to decline; only a produced stream is checked.
        if let Some(clen) = compress(src, &mut compressed, &mut table) {
            let mut decompressed = [0u8; 256];
            assert_decodes_to(&compressed[..clen], &mut decompressed, &src[..]);
        }
    }

    #[test]
    fn compress_too_short_returns_none() {
        let src = [0u8; 4];
        let mut compressed = [0u8; 64];
        let mut table = [0u16; HASH_SIZE];
        assert!(compress(&src, &mut compressed, &mut table).is_none());
    }

    #[test]
    fn decompress_empty_input() {
        let mut dst = [0u8; 64];
        let result = decompress(&[], &mut dst);
        assert_eq!(result.unwrap(), 0);
    }

    #[test]
    fn decompress_corrupt_offset_zero() {
        // Token with no literals followed by a match offset of zero.
        let corrupt = [0x00, 0x00, 0x00];
        let mut dst = [0u8; 64];
        assert!(decompress(&corrupt, &mut dst).is_err());
    }

    #[test]
    fn roundtrip_block_sized_structured_data() {
        // 64 records of 64 bytes: an 8-byte little-endian index, then filler.
        let mut src = [0u8; 4096];
        for (i, record) in src.chunks_exact_mut(64).enumerate() {
            record[..8].copy_from_slice(&(i as u64).to_le_bytes());
            record[8..].fill(0xAB);
        }
        let mut compressed = [0u8; 8192];
        let mut table = [0u16; HASH_SIZE];
        let clen =
            compress(&src, &mut compressed, &mut table).expect("structured data compresses");

        let mut decompressed = [0u8; 4096];
        assert_decodes_to(&compressed[..clen], &mut decompressed, &src[..]);
    }
}
345
346fn decompress_token(src: &[u8], dst: &mut [u8], state: DecompState) -> DecompStep {
347 let DecompState {
348 mut src_pos,
349 mut dst_pos,
350 } = state;
351
352 let src_len = src.len();
353 let dst_len = dst.len();
354
355 match src_pos >= src_len {
356 true => return DecompStep::Done(dst_pos),
357 false => {}
358 }
359
360 let token = src[src_pos];
361 src_pos += 1;
362
363 let mut lit_len = (token >> 4) as usize;
364 match lit_len == RUN_MASK {
365 true => {
366 core::iter::from_fn(|| match src_pos < src_len {
367 true => {
368 let extra = src[src_pos] as usize;
369 src_pos += 1;
370 lit_len += extra;
371 match extra == 255 {
372 true => Some(()),
373 false => None,
374 }
375 }
376 false => None,
377 })
378 .last();
379 }
380 false => {}
381 }
382
383 match src_pos + lit_len <= src_len && dst_pos + lit_len <= dst_len {
384 true => {
385 dst[dst_pos..dst_pos + lit_len].copy_from_slice(&src[src_pos..src_pos + lit_len]);
386 src_pos += lit_len;
387 dst_pos += lit_len;
388 }
389 false => return DecompStep::Corrupt,
390 }
391
392 match src_pos >= src_len {
393 true => return DecompStep::Done(dst_pos),
394 false => {}
395 }
396
397 match src_pos + 2 <= src_len {
398 true => {}
399 false => return DecompStep::Corrupt,
400 }
401 let offset = u16::from_le_bytes([src[src_pos], src[src_pos + 1]]) as usize;
402 src_pos += 2;
403
404 match offset == 0 || offset > dst_pos {
405 true => return DecompStep::Corrupt,
406 false => {}
407 }
408
409 let mut match_len = (token & 0x0F) as usize + MIN_MATCH;
410 match (token & 0x0F) as usize == ML_MASK {
411 true => {
412 core::iter::from_fn(|| match src_pos < src_len {
413 true => {
414 let extra = src[src_pos] as usize;
415 src_pos += 1;
416 match_len += extra;
417 match extra == 255 {
418 true => Some(()),
419 false => None,
420 }
421 }
422 false => None,
423 })
424 .last();
425 }
426 false => {}
427 }
428
429 match dst_pos + match_len <= dst_len {
430 true => {
431 let match_start = dst_pos - offset;
432 (0..match_len).for_each(|i| {
433 dst[dst_pos + i] = dst[match_start + i];
434 });
435 dst_pos += match_len;
436 DecompStep::Continue(DecompState { src_pos, dst_pos })
437 }
438 false => DecompStep::Corrupt,
439 }
440}