diff --git a/compiler/passes/src/common/graph/mod.rs b/compiler/passes/src/common/graph/mod.rs index 1880317e2e..18164ee896 100644 --- a/compiler/passes/src/common/graph/mod.rs +++ b/compiler/passes/src/common/graph/mod.rs @@ -79,26 +79,19 @@ impl DiGraph { // The set of nodes that are on the path to the current node in the search. let mut discovered: IndexSet = IndexSet::new(); // Check if there is a cycle in the graph starting from `node`. - if self.contains_cycle_from(*node, &mut discovered, &mut finished) { - let path = match discovered.pop() { - // TODO: Should this error more silently? - None => unreachable!("If `contains_cycle_from` returns `true`, `discovered` is not empty."), - Some(node) => { - let mut path = vec![node]; - // Backtrack through the discovered nodes to find the cycle. - while let Some(next) = discovered.pop() { - // Add the node to the path. - path.push(next); - // If the node is the same as the first node in the path, we have found the cycle. - if next == node { - break; - } - } - // Reverse the path to get the cycle in the correct order. - path.reverse(); - path + if let Some(node) = self.contains_cycle_from(*node, &mut discovered, &mut finished) { + let mut path = vec![node]; + // Backtrack through the discovered nodes to find the cycle. + while let Some(next) = discovered.pop() { + // Add the node to the path. + path.push(next); + // If the node is the same as the first node in the path, we have found the cycle. + if next == node { + break; } - }; + } + // Reverse the path to get the cycle in the correct order. + path.reverse(); // A cycle was detected. Return the path of the cycle. return Err(GraphError::CycleDetected(path)); } @@ -109,8 +102,13 @@ impl DiGraph { } // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search. + // If there is no cycle, returns `None`. + // If there is a cycle, returns the node that was most recently discovered. // Nodes are added to to `finished` in topological order. - fn contains_cycle_from(&self, node: N, discovered: &mut IndexSet, finished: &mut IndexSet) -> bool { + fn contains_cycle_from(&self, node: N, discovered: &mut IndexSet, finished: &mut IndexSet) -> Option{ + println!("discovered: {:?}", discovered); + println!("finished: {:?}", finished); + println!("node: {:?}\n", node); // Add the node to the set of discovered nodes. discovered.insert(node); @@ -121,12 +119,13 @@ impl DiGraph { if discovered.contains(child) { // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle. // Note that this case is always hit when there is a cycle. - discovered.insert(*child); - return true; + return Some(*child); } // If the node has not been explored, explore it. - if !finished.contains(child) && self.contains_cycle_from(*child, discovered, finished) { - return true; + if !finished.contains(child) { + if let Some(child) = self.contains_cycle_from(*child, discovered, finished) { + return Some(child); + } } } } @@ -136,6 +135,67 @@ impl DiGraph { // Add the node to the set of finished nodes. finished.insert(node); - false + None } } + +#[cfg(test)] +mod test { + use super::*; + + impl Node for u32 {} + + #[test] + fn test_toposort() { + let mut graph = DiGraph::::new(IndexSet::new()); + + graph.add_edge(1, 2); + graph.add_edge(1, 3); + graph.add_edge(2, 4); + graph.add_edge(3, 4); + graph.add_edge(4, 5); + + // At this point, the graph looks like: + // 1 + // / \ + // 2 3 + // \ / + // 4 + // | + // 5 + + let result = graph.topological_sort(); + assert!(result.is_ok()); + + let order: Vec = result.unwrap().into_iter().collect(); + let expected = Vec::from([5u32, 4, 2, 3, 1]); + assert_eq!(order, expected); + } + + #[test] + fn test_toposort_cycle() { + let mut graph = DiGraph::::new(IndexSet::new()); + + graph.add_edge(1, 2); + graph.add_edge(2, 3); + graph.add_edge(2, 4); + graph.add_edge(4, 1); + + // At this point, the graph looks like: + // 1 + // | + // 2 + // / \ + // 3 4 + // | + // 1 + + let result = graph.topological_sort(); + assert!(result.is_err()); + + let GraphError::CycleDetected(cycle) = result.unwrap_err(); + let expected = Vec::from([1u32, 2, 4, 1]); + assert_eq!(cycle, expected); + } + +}