diff --git a/atuin-client/src/import/bash.rs b/atuin-client/src/import/bash.rs index d5fbef46..1a171625 100644 --- a/atuin-client/src/import/bash.rs +++ b/atuin-client/src/import/bash.rs @@ -1,63 +1,77 @@ -use std::io::{BufRead, BufReader}; -use std::{fs::File, path::Path}; +use std::{ + fs::File, + io::{BufRead, BufReader, Read, Seek}, + path::{Path, PathBuf}, +}; +use directories::UserDirs; use eyre::{eyre, Result}; -use super::count_lines; +use super::{count_lines, Importer}; use crate::history::History; #[derive(Debug)] -pub struct Bash { - file: BufReader, - - pub loc: u64, - pub counter: i64, +pub struct Bash { + file: BufReader, + strbuf: String, + loc: usize, + counter: i64, } -impl Bash { - pub fn new(path: impl AsRef) -> Result { - let file = File::open(path)?; - let mut buf = BufReader::new(file); +impl Bash { + fn new(r: R) -> Result { + let mut buf = BufReader::new(r); let loc = count_lines(&mut buf)?; Ok(Self { file: buf, - loc: loc as u64, + strbuf: String::new(), + loc, counter: 0, }) } +} - fn read_line(&mut self) -> Option> { - let mut line = String::new(); +impl Importer for Bash { + const NAME: &'static str = "bash"; - match self.file.read_line(&mut line) { - Ok(0) => None, - Ok(_) => Some(Ok(line)), - Err(e) => Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 - } + fn histpath() -> Result { + let user_dirs = UserDirs::new().unwrap(); + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".bash_history")) + } + + fn parse(path: impl AsRef) -> Result { + Self::new(File::open(path)?) } } -impl Iterator for Bash { +impl Iterator for Bash { type Item = Result; fn next(&mut self) -> Option { - let line = self.read_line()?; - - if let Err(e) = line { - return Some(Err(e)); // :( + self.strbuf.clear(); + match self.file.read_line(&mut self.strbuf) { + Ok(0) => return None, + Ok(_) => (), + Err(e) => return Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 } - let mut line = line.unwrap(); + self.loc -= 1; - while line.ends_with("\\\n") { - let next_line = self.read_line()?; - - if next_line.is_err() { + while self.strbuf.ends_with("\\\n") { + if self.file.read_line(&mut self.strbuf).is_err() { + // There's a chance that the last line of a command has invalid + // characters, the only safe thing to do is break :/ + // usually just invalid utf8 or smth + // however, we really need to avoid missing history, so it's + // better to have some items that should have been part of + // something else, than to miss things. So break. break; - } + }; - line.push_str(next_line.unwrap().as_str()); + self.loc -= 1; } let time = chrono::Utc::now(); @@ -68,7 +82,7 @@ impl Iterator for Bash { Some(Ok(History::new( time, - line.trim_end().to_string(), + self.strbuf.trim_end().to_string(), String::from("unknown"), -1, -1, @@ -76,4 +90,45 @@ impl Iterator for Bash { None, ))) } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.loc)) + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use super::Bash; + + #[test] + fn test_parse_file() { + let input = r"cargo install atuin +cargo install atuin; \ +cargo update +cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +"; + + let cursor = Cursor::new(input); + let mut bash = Bash::new(cursor).unwrap(); + assert_eq!(bash.loc, 4); + assert_eq!(bash.size_hint(), (0, Some(4))); + + assert_eq!( + &bash.next().unwrap().unwrap().command, + "cargo install atuin" + ); + assert_eq!( + &bash.next().unwrap().unwrap().command, + "cargo install atuin; \\\ncargo update" + ); + assert_eq!( + &bash.next().unwrap().unwrap().command, + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷" + ); + assert!(bash.next().is_none()); + + assert_eq!(bash.size_hint(), (0, Some(0))); + } } diff --git a/atuin-client/src/import/mod.rs b/atuin-client/src/import/mod.rs index 0b21d605..d73d3857 100644 --- a/atuin-client/src/import/mod.rs +++ b/atuin-client/src/import/mod.rs @@ -1,16 +1,24 @@ -use std::fs::File; -use std::io::{BufRead, BufReader, Seek, SeekFrom}; +use std::io::{BufRead, BufReader, Read, Seek, SeekFrom}; +use std::path::{Path, PathBuf}; use eyre::Result; +use crate::history::History; + pub mod bash; pub mod resh; pub mod zsh; // this could probably be sped up -fn count_lines(buf: &mut BufReader) -> Result { +fn count_lines(buf: &mut BufReader) -> Result { let lines = buf.lines().count(); buf.seek(SeekFrom::Start(0))?; Ok(lines) } + +pub trait Importer: IntoIterator> + Sized { + const NAME: &'static str; + fn histpath() -> Result; + fn parse(path: impl AsRef) -> Result; +} diff --git a/atuin-client/src/import/resh.rs b/atuin-client/src/import/resh.rs index 55c9da7f..a0378c36 100644 --- a/atuin-client/src/import/resh.rs +++ b/atuin-client/src/import/resh.rs @@ -1,91 +1,156 @@ +use std::{ + fs::File, + io::{BufRead, BufReader}, + path::{Path, PathBuf}, +}; + +use atuin_common::utils::uuid_v4; +use chrono::{TimeZone, Utc}; +use directories::UserDirs; +use eyre::{eyre, Result}; use serde::Deserialize; +use super::{count_lines, Importer}; +use crate::history::History; + #[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] pub struct ReshEntry { - #[serde(rename = "cmdLine")] pub cmd_line: String, - #[serde(rename = "exitCode")] pub exit_code: i64, pub shell: String, pub uname: String, - #[serde(rename = "sessionId")] pub session_id: String, pub home: String, pub lang: String, - #[serde(rename = "lcAll")] pub lc_all: String, pub login: String, pub pwd: String, - #[serde(rename = "pwdAfter")] pub pwd_after: String, - #[serde(rename = "shellEnv")] pub shell_env: String, pub term: String, - #[serde(rename = "realPwd")] pub real_pwd: String, - #[serde(rename = "realPwdAfter")] pub real_pwd_after: String, pub pid: i64, - #[serde(rename = "sessionPid")] pub session_pid: i64, pub host: String, pub hosttype: String, pub ostype: String, pub machtype: String, pub shlvl: i64, - #[serde(rename = "timezoneBefore")] pub timezone_before: String, - #[serde(rename = "timezoneAfter")] pub timezone_after: String, - #[serde(rename = "realtimeBefore")] pub realtime_before: f64, - #[serde(rename = "realtimeAfter")] pub realtime_after: f64, - #[serde(rename = "realtimeBeforeLocal")] pub realtime_before_local: f64, - #[serde(rename = "realtimeAfterLocal")] pub realtime_after_local: f64, - #[serde(rename = "realtimeDuration")] pub realtime_duration: f64, - #[serde(rename = "realtimeSinceSessionStart")] pub realtime_since_session_start: f64, - #[serde(rename = "realtimeSinceBoot")] pub realtime_since_boot: f64, - #[serde(rename = "gitDir")] pub git_dir: String, - #[serde(rename = "gitRealDir")] pub git_real_dir: String, - #[serde(rename = "gitOriginRemote")] pub git_origin_remote: String, - #[serde(rename = "gitDirAfter")] pub git_dir_after: String, - #[serde(rename = "gitRealDirAfter")] pub git_real_dir_after: String, - #[serde(rename = "gitOriginRemoteAfter")] pub git_origin_remote_after: String, - #[serde(rename = "machineId")] pub machine_id: String, - #[serde(rename = "osReleaseId")] pub os_release_id: String, - #[serde(rename = "osReleaseVersionId")] pub os_release_version_id: String, - #[serde(rename = "osReleaseIdLike")] pub os_release_id_like: String, - #[serde(rename = "osReleaseName")] pub os_release_name: String, - #[serde(rename = "osReleasePrettyName")] pub os_release_pretty_name: String, - #[serde(rename = "reshUuid")] pub resh_uuid: String, - #[serde(rename = "reshVersion")] pub resh_version: String, - #[serde(rename = "reshRevision")] pub resh_revision: String, - #[serde(rename = "partsMerged")] pub parts_merged: bool, pub recalled: bool, - #[serde(rename = "recallLastCmdLine")] pub recall_last_cmd_line: String, pub cols: String, pub lines: String, } + +#[derive(Debug)] +pub struct Resh { + file: BufReader, + strbuf: String, + loc: usize, + counter: i64, +} + +impl Importer for Resh { + const NAME: &'static str = "resh"; + + fn histpath() -> Result { + let user_dirs = UserDirs::new().unwrap(); + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".resh_history.json")) + } + + fn parse(path: impl AsRef) -> Result { + let file = File::open(path)?; + let mut buf = BufReader::new(file); + let loc = count_lines(&mut buf)?; + + Ok(Self { + file: buf, + strbuf: String::new(), + loc, + counter: 0, + }) + } +} + +impl Iterator for Resh { + type Item = Result; + + fn next(&mut self) -> Option { + self.strbuf.clear(); + match self.file.read_line(&mut self.strbuf) { + Ok(0) => return None, + Ok(_) => (), + Err(e) => return Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 + } + + let entry = match serde_json::from_str::(&self.strbuf) { + Ok(e) => e, + Err(e) => { + return Some(Err(eyre!( + "Invalid entry found in resh_history file: {}", + e + ))) + } + }; + + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let timestamp = { + let secs = entry.realtime_before.floor() as i64; + let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as u32; + Utc.timestamp(secs, nanosecs) + }; + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let duration = { + let secs = entry.realtime_after.floor() as i64; + let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as u32; + let difference = Utc.timestamp(secs, nanosecs) - timestamp; + difference.num_nanoseconds().unwrap_or(0) + }; + + Some(Ok(History { + id: uuid_v4(), + timestamp, + duration, + exit: entry.exit_code, + command: entry.cmd_line, + cwd: entry.pwd, + session: uuid_v4(), + hostname: entry.host, + })) + } + + fn size_hint(&self) -> (usize, Option) { + (self.loc, Some(self.loc)) + } +} diff --git a/atuin-client/src/import/zsh.rs b/atuin-client/src/import/zsh.rs index 46e9af63..b3db36b6 100644 --- a/atuin-client/src/import/zsh.rs +++ b/atuin-client/src/import/zsh.rs @@ -1,50 +1,73 @@ // import old shell history! // automatically hoover up all that we can find -use std::io::{BufRead, BufReader}; -use std::{fs::File, path::Path}; +use std::{ + fs::File, + io::{BufRead, BufReader, Read, Seek}, + path::{Path, PathBuf}, +}; use chrono::prelude::*; use chrono::Utc; +use directories::UserDirs; use eyre::{eyre, Result}; use itertools::Itertools; -use super::count_lines; +use super::{count_lines, Importer}; use crate::history::History; #[derive(Debug)] -pub struct Zsh { - file: BufReader, - - pub loc: u64, - pub counter: i64, +pub struct Zsh { + file: BufReader, + strbuf: String, + loc: usize, + counter: i64, } -impl Zsh { - pub fn new(path: impl AsRef) -> Result { - let file = File::open(path)?; - let mut buf = BufReader::new(file); +impl Zsh { + fn new(r: R) -> Result { + let mut buf = BufReader::new(r); let loc = count_lines(&mut buf)?; Ok(Self { file: buf, - loc: loc as u64, + strbuf: String::new(), + loc, counter: 0, }) } +} - fn read_line(&mut self) -> Option> { - let mut line = String::new(); +impl Importer for Zsh { + const NAME: &'static str = "zsh"; - match self.file.read_line(&mut line) { - Ok(0) => None, - Ok(_) => Some(Ok(line)), - Err(e) => Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 + fn histpath() -> Result { + // oh-my-zsh sets HISTFILE=~/.zhistory + // zsh has no default value for this var, but uses ~/.zhistory. + // we could maybe be smarter about this in the future :) + let user_dirs = UserDirs::new().unwrap(); + let home_dir = user_dirs.home_dir(); + + let mut candidates = [".zhistory", ".zsh_history"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => break Err(eyre!("Could not find history file. Try setting $HISTFILE")), + } } } + + fn parse(path: impl AsRef) -> Result { + Self::new(File::open(path)?) + } } -impl Iterator for Zsh { +impl Iterator for Zsh { type Item = Result; fn next(&mut self) -> Option { @@ -52,18 +75,17 @@ impl Iterator for Zsh { // These lines begin with : // So, if the line begins with :, parse it. Otherwise it's just // the command - let line = self.read_line()?; - - if let Err(e) = line { - return Some(Err(e)); // :( + self.strbuf.clear(); + match self.file.read_line(&mut self.strbuf) { + Ok(0) => return None, + Ok(_) => (), + Err(e) => return Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 } - let mut line = line.unwrap(); + self.loc -= 1; - while line.ends_with("\\\n") { - let next_line = self.read_line()?; - - if next_line.is_err() { + while self.strbuf.ends_with("\\\n") { + if self.file.read_line(&mut self.strbuf).is_err() { // There's a chance that the last line of a command has invalid // characters, the only safe thing to do is break :/ // usually just invalid utf8 or smth @@ -71,19 +93,19 @@ impl Iterator for Zsh { // better to have some items that should have been part of // something else, than to miss things. So break. break; - } + }; - line.push_str(next_line.unwrap().as_str()); + self.loc -= 1; } // We have to handle the case where a line has escaped newlines. // Keep reading until we have a non-escaped newline - let extended = line.starts_with(':'); + let extended = self.strbuf.starts_with(':'); if extended { self.counter += 1; - Some(Ok(parse_extended(line.as_str(), self.counter))) + Some(Ok(parse_extended(&self.strbuf, self.counter))) } else { let time = chrono::Utc::now(); let offset = chrono::Duration::seconds(self.counter); @@ -93,7 +115,7 @@ impl Iterator for Zsh { Some(Ok(History::new( time, - line.trim_end().to_string(), + self.strbuf.trim_end().to_string(), String::from("unknown"), -1, -1, @@ -102,6 +124,10 @@ impl Iterator for Zsh { ))) } } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.loc)) + } } fn parse_extended(line: &str, counter: i64) -> History { @@ -133,10 +159,12 @@ fn parse_extended(line: &str, counter: i64) -> History { #[cfg(test)] mod test { + use std::io::Cursor; + use chrono::prelude::*; use chrono::Utc; - use super::parse_extended; + use super::*; #[test] fn test_parse_extended_simple() { @@ -164,4 +192,31 @@ mod test { assert_eq!(parsed.duration, 10_000_000_000); assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); } + + #[test] + fn test_parse_file() { + let input = r": 1613322469:0;cargo install atuin +: 1613322469:10;cargo install atuin; \ +cargo update +: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +"; + + let cursor = Cursor::new(input); + let mut zsh = Zsh::new(cursor).unwrap(); + assert_eq!(zsh.loc, 4); + assert_eq!(zsh.size_hint(), (0, Some(4))); + + assert_eq!(&zsh.next().unwrap().unwrap().command, "cargo install atuin"); + assert_eq!( + &zsh.next().unwrap().unwrap().command, + "cargo install atuin; \\\ncargo update" + ); + assert_eq!( + &zsh.next().unwrap().unwrap().command, + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷" + ); + assert!(zsh.next().is_none()); + + assert_eq!(zsh.size_hint(), (0, Some(0))); + } } diff --git a/src/command/import.rs b/src/command/import.rs index 9a0364da..53940abb 100644 --- a/src/command/import.rs +++ b/src/command/import.rs @@ -1,15 +1,11 @@ -use std::env; -use std::path::PathBuf; +use std::{env, path::PathBuf}; -use atuin_common::utils::uuid_v4; -use chrono::{TimeZone, Utc}; -use directories::UserDirs; use eyre::{eyre, Result}; use structopt::StructOpt; -use atuin_client::history::History; use atuin_client::import::{bash::Bash, zsh::Zsh}; -use atuin_client::{database::Database, import::resh::ReshEntry}; +use atuin_client::{database::Database, import::Importer}; +use atuin_client::{history::History, import::resh::Resh}; use indicatif::ProgressBar; #[derive(StructOpt)] @@ -39,6 +35,8 @@ pub enum Cmd { Resh, } +const BATCH_SIZE: usize = 100; + impl Cmd { pub async fn run(&self, db: &mut (impl Database + Send + Sync)) -> Result<()> { println!(" Atuin "); @@ -55,216 +53,117 @@ impl Cmd { if shell.ends_with("/zsh") { println!("Detected ZSH"); - import_zsh(db).await + import::, _>(db, BATCH_SIZE).await } else { println!("cannot import {} history", shell); Ok(()) } } - Self::Zsh => import_zsh(db).await, - Self::Bash => import_bash(db).await, - Self::Resh => import_resh(db).await, + Self::Zsh => import::, _>(db, BATCH_SIZE).await, + Self::Bash => import::, _>(db, BATCH_SIZE).await, + Self::Resh => import::(db, BATCH_SIZE).await, } } } -async fn import_resh(db: &mut (impl Database + Send + Sync)) -> Result<()> { - let histpath = std::path::Path::new(std::env::var("HOME")?.as_str()).join(".resh_history.json"); +async fn import( + db: &mut DB, + buf_size: usize, +) -> Result<()> +where + I::IntoIter: Send, +{ + println!("Importing history from {}", I::NAME); - println!("Parsing .resh_history.json..."); - #[allow(clippy::filter_map)] - let history = std::fs::read_to_string(histpath)? - .split('\n') - .map(str::trim) - .map(|x| serde_json::from_str::(x)) - .filter_map(|x| match x { - Ok(x) => Some(x), - Err(e) => { - if e.is_eof() { - None - } else { - warn!("Invalid entry found in resh_history file: {}", e); - None - } - } - }) - .map(|x| { - #[allow(clippy::cast_possible_truncation)] - #[allow(clippy::cast_sign_loss)] - let timestamp = { - let secs = x.realtime_before.floor() as i64; - let nanosecs = (x.realtime_before.fract() * 1_000_000_000_f64).round() as u32; - Utc.timestamp(secs, nanosecs) - }; - #[allow(clippy::cast_possible_truncation)] - #[allow(clippy::cast_sign_loss)] - let duration = { - let secs = x.realtime_after.floor() as i64; - let nanosecs = (x.realtime_after.fract() * 1_000_000_000_f64).round() as u32; - let difference = Utc.timestamp(secs, nanosecs) - timestamp; - difference.num_nanoseconds().unwrap_or(0) - }; + let histpath = get_histpath::()?; + let contents = I::parse(histpath)?; - History { - id: uuid_v4(), - timestamp, - duration, - exit: x.exit_code, - command: x.cmd_line, - cwd: x.pwd, - session: uuid_v4(), - hostname: x.host, - } - }) - .collect::>(); - println!("Updating database..."); - - let progress = ProgressBar::new(history.len() as u64); - - let buf_size = 100; - let mut buf = Vec::<_>::with_capacity(buf_size); - - for i in history { - buf.push(i); - - if buf.len() == buf_size { - db.save_bulk(&buf).await?; - progress.inc(buf.len() as u64); - - buf.clear(); - } - } - - if !buf.is_empty() { - db.save_bulk(&buf).await?; - progress.inc(buf.len() as u64); - } - Ok(()) -} - -async fn import_zsh(db: &mut (impl Database + Send + Sync)) -> Result<()> { - // oh-my-zsh sets HISTFILE=~/.zhistory - // zsh has no default value for this var, but uses ~/.zhistory. - // we could maybe be smarter about this in the future :) - - let histpath = env::var("HISTFILE"); - - let histpath = if let Ok(p) = histpath { - let histpath = PathBuf::from(p); - - if !histpath.exists() { - return Err(eyre!( - "Could not find history file {:?}. try updating $HISTFILE", - histpath - )); - } - - histpath + let iter = contents.into_iter(); + let progress = if let (_, Some(upper_bound)) = iter.size_hint() { + ProgressBar::new(upper_bound as u64) } else { - let user_dirs = UserDirs::new().unwrap(); - let home_dir = user_dirs.home_dir(); - - let mut candidates = [".zhistory", ".zsh_history"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break histpath; - } - } - None => return Err(eyre!("Could not find history file. try setting $HISTFILE")), - } - } + ProgressBar::new_spinner() }; - let zsh = Zsh::new(histpath)?; - - let progress = ProgressBar::new(zsh.loc); - - let buf_size = 100; let mut buf = Vec::::with_capacity(buf_size); + let mut iter = progress.wrap_iter(iter); + loop { + // fill until either no more entries + // or until the buffer is full + let done = fill_buf(&mut buf, &mut iter); - for i in zsh - .filter_map(Result::ok) - .filter(|x| !x.command.trim().is_empty()) - { - buf.push(i); + // flush + db.save_bulk(&buf).await?; - if buf.len() == buf_size { - db.save_bulk(&buf).await?; - progress.inc(buf.len() as u64); - - buf.clear(); + if done { + break; } } - if !buf.is_empty() { - db.save_bulk(&buf).await?; - progress.inc(buf.len() as u64); - } - - progress.finish(); println!("Import complete!"); Ok(()) } -// TODO: don't just copy paste this lol -async fn import_bash(db: &mut (impl Database + Send + Sync)) -> Result<()> { - // oh-my-zsh sets HISTFILE=~/.zhistory - // zsh has no default value for this var, but uses ~/.zhistory. - // we could maybe be smarter about this in the future :) - - let histpath = env::var("HISTFILE"); - - let histpath = if let Ok(p) = histpath { - let histpath = PathBuf::from(p); - - if !histpath.exists() { - return Err(eyre!( - "Could not find history file {:?}. try updating $HISTFILE", - histpath - )); - } - - histpath +fn get_histpath() -> Result { + if let Ok(p) = env::var("HISTFILE") { + is_file(PathBuf::from(p)) } else { - let user_dirs = UserDirs::new().unwrap(); - let home_dir = user_dirs.home_dir(); + is_file(I::histpath()?) + } +} - home_dir.join(".bash_history") - }; +fn is_file(p: PathBuf) -> Result { + if p.is_file() { + Ok(p) + } else { + Err(eyre!( + "Could not find history file {:?}. Try setting $HISTFILE", + p + )) + } +} - let bash = Bash::new(histpath)?; +fn fill_buf(buf: &mut Vec, iter: &mut impl Iterator>) -> bool { + buf.clear(); + loop { + match iter.next() { + Some(Ok(t)) => buf.push(t), + Some(Err(_)) => (), + None => break true, + } - let progress = ProgressBar::new(bash.loc); - - let buf_size = 100; - let mut buf = Vec::::with_capacity(buf_size); - - for i in bash - .filter_map(Result::ok) - .filter(|x| !x.command.trim().is_empty()) - { - buf.push(i); - - if buf.len() == buf_size { - db.save_bulk(&buf).await?; - progress.inc(buf.len() as u64); - - buf.clear(); + if buf.len() == buf.capacity() { + break false; } } - - if !buf.is_empty() { - db.save_bulk(&buf).await?; - progress.inc(buf.len() as u64); - } - - progress.finish(); - println!("Import complete!"); - - Ok(()) +} + +#[cfg(test)] +mod tests { + use super::fill_buf; + + #[test] + fn test_fill_buf() { + let mut buf = Vec::with_capacity(4); + let mut iter = vec![ + Ok(1), + Err(2), + Ok(3), + Ok(4), + Err(5), + Ok(6), + Ok(7), + Err(8), + Ok(9), + ] + .into_iter(); + + assert!(!fill_buf(&mut buf, &mut iter)); + assert_eq!(buf, vec![1, 3, 4, 6]); + + assert!(fill_buf(&mut buf, &mut iter)); + assert_eq!(buf, vec![7, 9]); + } }