Skip to content

Instantly share code, notes, and snippets.

@bczhc
Last active December 29, 2025 14:40
Show Gist options
  • Select an option

  • Save bczhc/fc15a28d5bafb0aedfab79cec29ba328 to your computer and use it in GitHub Desktop.

Select an option

Save bczhc/fc15a28d5bafb0aedfab79cec29ba328 to your computer and use it in GitHub Desktop.
Huffman编码文件解压缩器 #huffman
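//! Huffman-coding file compressor/decompressor.
//!
//! Streaming compression/decompression is not supported.
//!
//! Hand-written as a practice project, without bit-manipulation crates such as
//! `bitvec` or `bitstream-io`; there is still plenty of room for optimization.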
#![feature(file_buffered)]
use byteorder::{ReadBytesExt, WriteBytesExt, LE};
use clap::Parser;
use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashMap};
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufRead, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::{io, mem};
/// A single bit of a Huffman code; `true` stands for a 1 bit.
type Bit = bool;
// Command-line arguments, parsed with clap's derive API.
// NOTE: the `///` doc comments on the fields are intentionally left untouched:
// clap's derive turns them into the `--help` text, so editing them would change
// program output. English glosses are given in the `//` lines.
#[derive(Parser)]
struct Args {
// Input file path.
/// 输入文件路径
input: PathBuf,
// Output file path.
/// 输出文件路径
output: PathBuf,
// Decompress instead of compress (compression is the default).
// Short and long flags are derived from the field name by clap.
/// 是否执行解压 (默认为压缩)
#[arg(short, long)]
decompress: bool,
// Compress, then decompress again (round-trip test mode handled in `main`).
/// 压缩再解压缩
#[arg(short, long)]
test: bool,
}
/// Program entry point: parses [`Args`] and dispatches.
///
/// * default: compress `input` into `output`;
/// * `--decompress`: decompress `input` into `output`;
/// * `--test`: compress `input` into `output`, decompress that into a scratch
///   file in the system temp directory, fail unless the result is byte-identical
///   to `input`, then print the compressed size as a percentage of the original.
fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    if args.test {
        // `temp_dir()` instead of a hard-coded "/tmp" so test mode also works off Unix.
        let decompressed = std::env::temp_dir().join("decompressed");
        compress_file(&args.input, &args.output)?;
        decompress_file(&args.output, &decompressed)?;
        // A round-trip test is only meaningful if the result is actually compared.
        // Both files are read fully into memory, which is acceptable for a diagnostic mode.
        let original = std::fs::read(&args.input)?;
        let round_tripped = std::fs::read(&decompressed)?;
        anyhow::ensure!(
            original == round_tripped,
            "round-trip mismatch: {:?} differs from {:?}",
            decompressed,
            args.input
        );
        let original_size = read_file_size(&args.input)?;
        let compressed_size = read_file_size(&args.output)?;
        println!(
            "{} -> {}, compression factor: {:.2}%",
            original_size,
            compressed_size,
            compressed_size as f64 / original_size as f64 * 100.0,
        );
        return Ok(());
    }
    if args.decompress {
        println!("正在解压: {:?} -> {:?}", args.input, args.output);
        decompress_file(&args.input, &args.output)?;
    } else {
        println!("正在压缩: {:?} -> {:?}", args.input, args.output);
        compress_file(&args.input, &args.output)?;
    }
    Ok(())
}
/// Returns the length in bytes of the file at `path`, taken from its metadata.
///
/// # Errors
/// Propagates the I/O error if the metadata cannot be read (e.g. missing file).
fn read_file_size(path: impl AsRef<Path>) -> io::Result<u64> {
    std::fs::metadata(path).map(|meta| meta.len())
}
fn pack_bits(bits: &[Bit]) -> (usize, Vec<u8>) {
// unwrap is used in this function - won't fail for Vec<u8>
let buf = Vec::new();
let mut writer = BitStream::new(buf);
for &b in bits {
writer.write(b).unwrap();
}
let len = writer.len();
let buf = writer.finish().unwrap();
(len, buf)
}
fn decompress_file(input: impl AsRef<Path>, output: impl AsRef<Path>) -> io::Result<()> {
let mut reader = File::open_buffered(input)?;
let mut writer = File::create_buffered(output)?;
let original_size = reader.read_u64::<LE>()?;
let huffman_entry_num = reader.read_u8()? + 1;
let mut huffman_entries = Vec::new();
for _ in 0..huffman_entry_num {
huffman_entries.push(HuffmanEntry::<u8>::read_from(&mut reader)?);
}
let lookup_map = huffman_entries
.into_iter()
.map(|x| {
let mut bits = vec![false; x.bits_len as usize];
for (i, b) in bits.iter_mut().enumerate() {
*b = get_bit(x.packed_bits[i / 8], (i % 8) as u8);
}
(bits, x.symbol)
})
.collect::<HashMap<_, _>>();
let mut bits: Vec<bool> = Vec::with_capacity(32);
let mut bytes_written = 0_u64;
'a: for b in reader.bytes() {
let byte = b?;
for i in 0..8_u8 {
let bit = get_bit(byte, i);
bits.push(bit);
let lookup = lookup_map.get(&bits);
if let Some(&symbol) = lookup {
writer.write_u8(symbol)?;
bytes_written += 1;
if bytes_written == original_size {
break 'a;
}
bits.clear();
}
}
}
println!("Done");
Ok(())
}
/// Compresses `input` into `output` with a static Huffman code over bytes.
///
/// Output layout:
/// 1. `original_size: u64` (little-endian), the number of input bytes;
/// 2. `entry_count - 1: u8`, stored minus one so that 256 entries fit;
/// 3. `entry_count` table entries (`HuffmanEntry::write_to` format);
/// 4. the code of every input byte, MSB-first, zero-padded to a whole byte.
///
/// The input is read twice: once to count byte frequencies, once to encode.
/// An empty input produces a valid file (size 0 plus a one-entry placeholder
/// table) instead of panicking. The output is reproducible for a given input.
///
/// # Errors
/// Propagates any I/O error from opening, seeking, reading or writing.
fn compress_file(input: impl AsRef<Path>, output: impl AsRef<Path>) -> io::Result<()> {
    let mut reader = File::open_buffered(input)?;
    let mut writer = File::create_buffered(output)?;
    reader.seek(SeekFrom::End(0))?;
    let original_size = reader.stream_position()?;
    reader.seek(SeekFrom::Start(0))?;
    writer.write_u64::<LE>(original_size)?;

    println!("Scanning byte frequencies...");
    let mut freq: Vec<(u8, u64)> = scan_bytes_freq(&mut reader)?.into_iter().collect();
    if freq.is_empty() {
        // Empty input: the tree builder needs at least one leaf and the format
        // always stores at least one table entry. A zero-frequency placeholder
        // satisfies both and never contributes any code bits.
        freq.push((0, 0));
    }
    // HashMap iteration order is randomized per process; sorting makes the
    // tree, and therefore the compressed file, deterministic.
    freq.sort_unstable();

    println!("Building Huffman table...");
    let table = build_huffman_tree(freq).collect_encoding_table();
    let entry_count_minus_one =
        u8::try_from(table.len() - 1).expect("a byte alphabet has at most 256 symbols");
    writer.write_u8(entry_count_minus_one)?;
    for entry in &table {
        entry.write_to(&mut writer)?;
    }
    let huffman_map = table
        .into_iter()
        .map(|x| (x.symbol, x))
        .collect::<HashMap<_, _>>();

    let mut bit_writer = BitStream::new(&mut writer);
    println!("Reset file position and start compressing now...");
    reader.seek(SeekFrom::Start(0))?;
    for b in reader.bytes() {
        let entry = &huffman_map[&b?];
        for idx in 0..usize::from(entry.bits_len) {
            bit_writer.write(get_bit(entry.packed_bits[idx / 8], (idx % 8) as u8))?;
        }
    }
    // Pads the trailing partial byte and flushes `writer`.
    bit_writer.finish()?;
    println!("Done");
    Ok(())
}
/// Returns `x` with bit `i` replaced by `bit`.
///
/// Bits are indexed from the MSB (`i == 0`) down to the LSB (`i == 7`).
#[inline(always)]
fn set_bit(x: u8, i: u8, bit: bool) -> u8 {
    let mask = 0b1000_0000_u8 >> i;
    if bit {
        x | mask
    } else {
        x & !mask
    }
}
/// Returns bit `i` of `x`, indexed from the MSB (`i == 0`) to the LSB (`i == 7`).
#[inline(always)]
fn get_bit(x: u8, i: u8) -> bool {
    (x & (0b1000_0000 >> i)) != 0
}
/// Writes individual bits, MSB-first, to an underlying [`Write`].
///
/// Bits are gathered in a working byte that is handed to the inner writer as
/// soon as it holds eight bits. A trailing partial byte is only written by
/// [`BitStream::finish`], which zero-pads it first; there is no `Drop` impl,
/// so dropping the stream without `finish` loses those bits.
struct BitStream<W: Write> {
    /// Sink for completed bytes.
    writer: W,
    /// Number of bits written so far, padding included.
    len: usize,
    /// Byte currently being filled.
    byte: u8,
}
impl<W: Write> BitStream<W> {
    /// Creates an empty bit stream on top of `writer`.
    fn new(writer: W) -> Self {
        BitStream {
            writer,
            byte: 0,
            len: 0,
        }
    }

    /// Appends `bit`, emitting the working byte once it is full.
    fn write(&mut self, bit: bool) -> io::Result<()> {
        let slot = (self.len % 8) as u8;
        self.byte = set_bit(self.byte, slot, bit);
        self.len += 1;
        match self.len % 8 {
            0 => self.writer.write_u8(self.byte),
            _ => Ok(()),
        }
    }

    /// Flushes the inner writer.
    fn flush_underlying(&mut self) -> io::Result<()> {
        self.writer.flush()
    }

    /// Writes zero bits until the stream is byte-aligned, which emits the
    /// working byte. Does nothing when already aligned.
    fn pad_current(&mut self) -> io::Result<()> {
        while self.len % 8 != 0 {
            self.write(false)?;
        }
        Ok(())
    }

    /// Pads the last partial byte with zeros, flushes, and returns the inner writer.
    fn finish(mut self) -> io::Result<W> {
        self.pad_current()?;
        self.flush_underlying()?;
        Ok(self.writer)
    }

    /// Number of bits written so far (padding included).
    fn len(&self) -> usize {
        self.len
    }
}
/// Counts how often each byte value occurs in `reader`.
///
/// Consumes the reader to its end. Byte values that never occur are absent
/// from the returned map.
///
/// # Errors
/// Returns the first I/O error encountered while reading.
fn scan_bytes_freq<R: Read + BufRead>(reader: R) -> io::Result<HashMap<u8, u64>> {
    reader.bytes().try_fold(
        HashMap::new(),
        |mut counts: HashMap<u8, u64>, byte| -> io::Result<HashMap<u8, u64>> {
            *counts.entry(byte?).or_default() += 1;
            Ok(counts)
        },
    )
}
/// Node of a Huffman tree over symbols of type `T`.
///
/// Every node carries the total frequency of the symbols beneath it. The
/// `Eq`/`Ord` impls for `Node` compare by that frequency only.
#[derive(Debug)]
enum Node<T> {
/// Merge of two subtrees.
Internal {
/// Sum of the children's frequencies.
freq: u64,
/// Lower-frequency child; its codes get a `0` bit at this level.
left: Box<Node<T>>,
/// Higher-or-equal-frequency child; its codes get a `1` bit at this level.
right: Box<Node<T>>,
},
/// A single symbol.
Leaf {
/// Number of occurrences of `item` in the input.
freq: u64,
/// The symbol itself.
item: T,
},
}
/// One entry of the Huffman code table: a symbol together with its code.
struct HuffmanEntry<T> {
/// The symbol this code stands for.
symbol: T,
/// Number of code bits stored in `packed_bits`.
bits_len: u8,
/// The code bits, MSB-first, in `bits_len.div_ceil(8)` bytes.
packed_bits: Vec<u8>,
}
/// Serialization of a value into a byte stream.
trait WriteTo {
/// Writes `self` to `writer`.
fn write_to<W: Write>(&self, writer: W) -> io::Result<()>;
}
/// Deserialization of a value from a byte stream; the counterpart of `WriteTo`.
trait ReadFrom: Sized {
    /// Reads one value from `reader`.
    fn read_from<R: Read>(reader: R) -> io::Result<Self>;
}
/// Wire format: `[symbol: u8][bits_len: u8][packed_bits: bits_len.div_ceil(8) bytes]`.
impl WriteTo for HuffmanEntry<u8> {
    fn write_to<W: Write>(&self, mut writer: W) -> io::Result<()> {
        writer.write_all(&[self.symbol, self.bits_len])?;
        writer.write_all(&self.packed_bits)
    }
}
impl ReadFrom for HuffmanEntry<u8> {
fn read_from<R: Read>(mut reader: R) -> io::Result<Self> {
let symbol = reader.read_u8()?;
let bits_len = reader.read_u8()?;
let mut bits = vec![0_u8; bits_len.div_ceil(8) as usize];
reader.read_exact(&mut bits)?;
Ok(Self {
symbol,
bits_len,
packed_bits: bits,
})
}
}
impl<T> Node<T> {
    /// Total frequency of all symbols in this subtree.
    fn freq(&self) -> u64 {
        match self {
            Node::Internal { freq, .. } | Node::Leaf { freq, .. } => *freq,
        }
    }

    /// Consumes the tree and returns one `HuffmanEntry` per leaf.
    ///
    /// A leaf's code is its root-to-leaf path: a `0` bit for every step into
    /// `left` and a `1` bit for every step into `right`, packed MSB-first.
    /// Entries come out in depth-first, left-before-right order.
    ///
    /// A tree that is a single leaf would otherwise get an empty code, so that
    /// symbol is given the one-bit code `1` instead.
    ///
    /// # Panics
    /// Panics if some code is longer than 255 bits (it would not fit `bits_len: u8`).
    fn collect_encoding_table(self: Box<Node<T>>) -> Vec<HuffmanEntry<T>> {
        /// Depth-first walk; `path` holds the bits from the root to `node`.
        fn walk<T>(node: Box<Node<T>>, path: &mut Vec<Bit>, out: &mut Vec<HuffmanEntry<T>>) {
            match *node {
                Node::Leaf { item, .. } => {
                    let (bits_len, packed_bits) = pack_bits(path);
                    out.push(HuffmanEntry {
                        symbol: item,
                        packed_bits,
                        bits_len: bits_len.try_into().expect("Unexpected too long path"),
                    });
                }
                Node::Internal { left, right, .. } => {
                    path.push(false);
                    walk(left, path, out);
                    path.pop();
                    path.push(true);
                    walk(right, path, out);
                    path.pop();
                }
            }
        }

        if let Node::Leaf { item, .. } = *self {
            return vec![HuffmanEntry {
                symbol: item,
                bits_len: 1,
                packed_bits: vec![0b1000_0000 /* MSB-first bit '1' */],
            }];
        }
        let mut result = Vec::new();
        let mut path = Vec::new();
        walk(self, &mut path, &mut result);
        result
    }
}
// Nodes compare by frequency alone; symbols and children are ignored. This is
// the ordering the Huffman construction's `BinaryHeap` needs. As a consequence,
// two nodes with equal frequency are "equal" even if they hold different symbols.
impl<T> Eq for Node<T> {}
impl<T> PartialEq<Self> for Node<T> {
    fn eq(&self, other: &Self) -> bool {
        self.freq() == other.freq()
    }
}
impl<T> PartialOrd<Self> for Node<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: defer to `Ord` so the two orderings can never diverge.
        Some(self.cmp(other))
    }
}
impl<T> Ord for Node<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.freq().cmp(&other.freq())
    }
}
/// Builds a Huffman tree from `(symbol, frequency)` pairs and returns its root.
///
/// Repeatedly merges the two lowest-frequency nodes (taken from a min-heap)
/// into an internal node until one node remains. In every merge the
/// lower-frequency node becomes the left child. Ties between equal frequencies
/// are resolved by the heap, so the shape depends on the input order.
///
/// # Panics
/// Panics if `input` is empty, because a Huffman tree needs at least one symbol.
fn build_huffman_tree<T>(input: impl IntoIterator<Item = (T, u64)>) -> Box<Node<T>> {
    // `Reverse` turns std's max-heap into a min-heap on frequency.
    let mut heap = BinaryHeap::new();
    for (item, freq) in input {
        heap.push(Reverse(Box::new(Node::Leaf { freq, item })));
    }
    while heap.len() > 1 {
        // `heap.len() > 1` guarantees both pops succeed.
        let mut take1 = heap.pop().unwrap();
        let mut take2 = heap.pop().unwrap();
        // Keep the smaller frequency on the left. The min-heap already pops it
        // first, so this is a defensive guard for that invariant.
        if take1.0.freq() > take2.0.freq() {
            mem::swap(&mut take1, &mut take2);
        }
        let new_node = Node::Internal {
            freq: take1.0.freq() + take2.0.freq(),
            left: take1.0,
            right: take2.0,
        };
        heap.push(Reverse(Box::new(new_node)));
    }
    // The only remaining node is the root.
    heap.pop()
        .expect("build_huffman_tree: input must contain at least one symbol")
        .0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment