Add simple download progress tracker

This commit is contained in:
Ching Pei Yang 2023-11-26 13:55:35 +01:00
parent 69b72866db
commit b55336f093
No known key found for this signature in database
GPG Key ID: 062FBBCE1D0C5DD9

View File

@ -1,5 +1,5 @@
use std::{
io::{self, Read},
io::{self, Read, Write},
path::Path,
};
@ -275,9 +275,17 @@ pub fn download_and_hash(
Encoding::new(content_encoding, url)?
};
// Use .take to prevent a malicious server from sending back bytes
// until system resources are exhausted!
decompress_into(dest_dir, encoding, resp.take(max_download_bytes))
if let Some(content_len) = resp.content_length() {
// Print download progress to stdout if we know the content length
//
// Use .take to prevent a malicious server from sending back bytes
// until system resources are exhausted!
let resp = ProgressReporter::new(resp.take(max_download_bytes), content_len as usize);
decompress_into(dest_dir, encoding, resp)
} else {
decompress_into(dest_dir, encoding, resp.take(max_download_bytes))
}
}
/// The content encodings we support
@ -405,3 +413,44 @@ impl<R: Read> Read for HashReader<R> {
Ok(bytes_read)
}
}
/// Prints download progress to stdout
struct ProgressReporter<R: Read> {
read: usize,
total: usize,
last_reported: usize,
reader: R,
}
impl<R: Read> ProgressReporter<R> {
fn new(reader: R, total: usize) -> Self {
ProgressReporter {
read: 0,
last_reported: 0,
total,
reader,
}
}
}
impl<R: Read> Read for ProgressReporter<R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let size = self.reader.read(buf)?;
self.read += size;
print!(
"\u{001b}[2K\u{001b}[G[{:.1} / {:.1} MB]",
self.read as f32 / 1_000_000.0,
self.total as f32 / 1_000_000.0,
);
std::io::stdout().flush()?;
self.last_reported = self.read;
if self.read >= self.total {
println!("");
}
Ok(size)
}
}