revset: flatten union nodes in AST to save recursion stack

Maybe it'll also be good to keep RevsetExpression::Union(_) flattened, but
that's not needed to get around stack overflow. The constructed expression
tree is balanced.

test_expand_symbol_alias() is slightly adjusted since there is now more than
one representation for "a|b|c".

Fixes #4031
This commit is contained in:
Yuya Nishihara 2024-07-07 16:23:20 +09:00
parent f90b061808
commit 415c831e30
3 changed files with 63 additions and 7 deletions

View File

@ -824,13 +824,19 @@ pub fn lower_expression(
let lhs = lower_expression(lhs_node, context)?;
let rhs = lower_expression(rhs_node, context)?;
match op {
BinaryOp::Union => Ok(lhs.union(&rhs)),
BinaryOp::Intersection => Ok(lhs.intersection(&rhs)),
BinaryOp::Difference => Ok(lhs.minus(&rhs)),
BinaryOp::DagRange => Ok(lhs.dag_range_to(&rhs)),
BinaryOp::Range => Ok(lhs.range(&rhs)),
}
}
ExpressionKind::UnionAll(nodes) => {
let expressions: Vec<_> = nodes
.iter()
.map(|node| lower_expression(node, context))
.try_collect()?;
Ok(RevsetExpression::union_all(&expressions))
}
ExpressionKind::FunctionCall(function) => lower_function_call(function, context),
ExpressionKind::Modifier(modifier) => {
let name = modifier.name;

View File

@ -311,6 +311,8 @@ pub enum ExpressionKind<'i> {
RangeAll,
Unary(UnaryOp, Box<ExpressionNode<'i>>),
Binary(BinaryOp, Box<ExpressionNode<'i>>, Box<ExpressionNode<'i>>),
/// `x | y | ..`
UnionAll(Vec<ExpressionNode<'i>>),
FunctionCall(Box<FunctionCallNode<'i>>),
/// `name: body`
Modifier(Box<ModifierNode<'i>>),
@ -341,6 +343,10 @@ impl<'i> FoldableExpression<'i> for ExpressionKind<'i> {
let rhs = Box::new(folder.fold_expression(*rhs)?);
Ok(ExpressionKind::Binary(op, lhs, rhs))
}
ExpressionKind::UnionAll(nodes) => {
let nodes = dsl_util::fold_expression_nodes(folder, nodes)?;
Ok(ExpressionKind::UnionAll(nodes))
}
ExpressionKind::FunctionCall(function) => folder.fold_function_call(function, span),
ExpressionKind::Modifier(modifier) => {
let modifier = Box::new(ModifierNode {
@ -392,8 +398,6 @@ pub enum UnaryOp {
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BinaryOp {
/// `|`
Union,
/// `&`
Intersection,
/// `~`
@ -418,6 +422,20 @@ pub struct ModifierNode<'i> {
pub body: ExpressionNode<'i>,
}
/// Combines `lhs` and `rhs` into a single `ExpressionKind::UnionAll` node.
///
/// When `lhs` is already a `UnionAll`, `rhs` is appended to its operand list
/// rather than wrapped in a new nested node, so a left-associated chain such
/// as `x | y | z` becomes one flat list. This saves recursion stack, since a
/// machine-generated query might contain a long chain of unions. `rhs` is
/// always stored as-is, whatever its kind.
///
/// The resulting node's span runs from the start of `lhs` to the end of `rhs`.
fn union_nodes<'i>(lhs: ExpressionNode<'i>, rhs: ExpressionNode<'i>) -> ExpressionNode<'i> {
    // Compute the span up front: both operands are moved into the list below.
    let merged_span = lhs.span.start_pos().span(&rhs.span.end_pos());
    let operands = match lhs.kind {
        // Reuse the existing operand list instead of nesting another union.
        ExpressionKind::UnionAll(mut operands) => {
            operands.push(rhs);
            operands
        }
        _ => vec![lhs, rhs],
    };
    ExpressionNode::new(ExpressionKind::UnionAll(operands), merged_span)
}
pub(super) fn parse_program(revset_str: &str) -> Result<ExpressionNode, RevsetParseError> {
let mut pairs = RevsetParser::parse(Rule::program, revset_str)?;
let first = pairs.next().unwrap();
@ -551,7 +569,7 @@ fn parse_expression_node(pairs: Pairs<Rule>) -> Result<ExpressionNode, RevsetPar
})
.map_infix(|lhs, op, rhs| {
let op_kind = match op.as_rule() {
Rule::union_op => BinaryOp::Union,
Rule::union_op => return Ok(union_nodes(lhs?, rhs?)),
Rule::compat_add_op => Err(not_infix_op(&op, "|", "union"))?,
Rule::intersection_op => BinaryOp::Intersection,
Rule::difference_op => BinaryOp::Difference,
@ -883,6 +901,10 @@ mod tests {
let rhs = Box::new(normalize_tree(*rhs));
ExpressionKind::Binary(op, lhs, rhs)
}
ExpressionKind::UnionAll(nodes) => {
let nodes = normalize_list(nodes);
ExpressionKind::UnionAll(nodes)
}
ExpressionKind::FunctionCall(function) => {
let function = Box::new(normalize_function_call(*function));
ExpressionKind::FunctionCall(function)
@ -1067,7 +1089,11 @@ mod tests {
// Parse the "union" operator
assert_matches!(
parse_into_kind("foo | bar"),
Ok(ExpressionKind::Binary(BinaryOp::Union, _, _))
Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 2
);
assert_matches!(
parse_into_kind("foo | bar | baz"),
Ok(ExpressionKind::UnionAll(nodes)) if nodes.len() == 3
);
// Parse the "difference" operator
assert_matches!(
@ -1479,8 +1505,8 @@ mod tests {
#[test]
fn test_expand_symbol_alias() {
assert_eq!(
with_aliases([("AB", "a|b")]).parse_normalized("AB|c"),
parse_normalized("(a|b)|c")
with_aliases([("AB", "a&b")]).parse_normalized("AB|c"),
parse_normalized("(a&b)|c")
);
assert_eq!(
with_aliases([("AB", "a|b")]).parse_normalized("AB::heads(AB)"),

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::iter;
use std::path::Path;
use assert_matches::assert_matches;
@ -2641,6 +2642,29 @@ fn test_evaluate_expression_union() {
);
}
#[test]
fn test_evaluate_expression_machine_generated_union() {
    let settings = testutils::user_settings();
    let test_repo = TestRepo::init();
    let repo = &test_repo.repo;

    let mut tx = repo.start_transaction(&settings);
    let mut_repo = tx.mut_repo();
    let mut graph_builder = CommitGraphBuilder::new(&settings, mut_repo);
    let root_commit = graph_builder.initial_commit();
    let child_commit = graph_builder.commit_with_parents(&[&root_commit]);

    // A union of 5000 operands must not overflow the stack. Each operand is
    // an "x::y" range rather than a bare commit id, so that the test stays
    // meaningful even if a fast path existed for trivial "commit_id|.."
    // expressions.
    let operand = format!(
        "({}::{})",
        root_commit.id().hex(),
        child_commit.id().hex()
    );
    let revset_str = iter::repeat(operand).take(5000).join("|");

    assert_eq!(
        resolve_commit_ids(mut_repo, &revset_str),
        vec![child_commit.id().clone(), root_commit.id().clone()]
    );
}
#[test]
fn test_evaluate_expression_intersection() {
let settings = testutils::user_settings();