diff --git a/asyncgit/src/sync/logwalker.rs b/asyncgit/src/sync/logwalker.rs index cc73e6de..56186f6f 100644 --- a/asyncgit/src/sync/logwalker.rs +++ b/asyncgit/src/sync/logwalker.rs @@ -28,11 +28,16 @@ impl<'a> Ord for TimeOrderedCommit<'a> { } } +type LogWalkerFilter = + Box Result>; + /// pub struct LogWalker<'a> { commits: BinaryHeap>, visited: HashSet, limit: usize, + repo: &'a Repository, + filter: Option, } impl<'a> LogWalker<'a> { @@ -47,9 +52,19 @@ impl<'a> LogWalker<'a> { commits, limit, visited: HashSet::with_capacity(1000), + repo, + filter: None, }) } + /// + pub fn filter(self, filter: LogWalkerFilter) -> Self { + Self { + filter: Some(filter), + ..self + } + } + /// pub fn read(&mut self, out: &mut Vec) -> Result { let mut count = 0_usize; @@ -59,7 +74,17 @@ impl<'a> LogWalker<'a> { self.visit(p); } - out.push(c.0.id().into()); + let id: CommitId = c.0.id().into(); + let commit_should_be_included = + if let Some(ref filter) = self.filter { + filter(self.repo, &id)? + } else { + true + }; + + if commit_should_be_included { + out.push(id); + } count += 1; if count == self.limit { @@ -82,9 +107,10 @@ impl<'a> LogWalker<'a> { #[cfg(test)] mod tests { use super::*; + use crate::error::Result; use crate::sync::{ - commit, get_commits_info, stage_add_file, - tests::repo_init_empty, + commit, commit_files::get_commit_diff, get_commits_info, + stage_add_file, tests::repo_init_empty, }; use pretty_assertions::assert_eq; use std::{fs::File, io::Write, path::Path}; @@ -144,4 +170,79 @@ mod tests { Ok(()) } + + #[test] + fn test_logwalker_with_filter() -> Result<()> { + let file_path = Path::new("foo"); + let second_file_path = Path::new("baz"); + let (_td, repo) = repo_init_empty().unwrap(); + let root = repo.path().parent().unwrap(); + let repo_path = root.as_os_str().to_str().unwrap(); + + File::create(&root.join(file_path))?.write_all(b"a")?; + stage_add_file(repo_path, file_path).unwrap(); + + let _first_commit_id = commit(repo_path, "commit1").unwrap(); + + File::create(&root.join(second_file_path))? + .write_all(b"a")?; + stage_add_file(repo_path, second_file_path).unwrap(); + + let second_commit_id = commit(repo_path, "commit2").unwrap(); + + File::create(&root.join(file_path))?.write_all(b"b")?; + stage_add_file(repo_path, file_path).unwrap(); + + let _third_commit_id = commit(repo_path, "commit3").unwrap(); + + let diff_contains_baz = |repo: &Repository, + commit_id: &CommitId| + -> Result { + let diff = get_commit_diff( + &repo, + *commit_id, + Some("baz".into()), + )?; + + let contains_file = diff.deltas().len() > 0; + + Ok(contains_file) + }; + + let mut items = Vec::new(); + let mut walker = LogWalker::new(&repo, 100)? + .filter(Box::new(diff_contains_baz)); + walker.read(&mut items).unwrap(); + + assert_eq!(items.len(), 1); + assert_eq!(items[0], second_commit_id.into()); + + let mut items = Vec::new(); + walker.read(&mut items).unwrap(); + + assert_eq!(items.len(), 0); + + let diff_contains_bar = |repo: &Repository, + commit_id: &CommitId| + -> Result { + let diff = get_commit_diff( + &repo, + *commit_id, + Some("bar".into()), + )?; + + let contains_file = diff.deltas().len() > 0; + + Ok(contains_file) + }; + + let mut items = Vec::new(); + let mut walker = LogWalker::new(&repo, 100)? + .filter(Box::new(diff_contains_bar)); + walker.read(&mut items).unwrap(); + + assert_eq!(items.len(), 0); + + Ok(()) + } }