Visit the whole Expr tree while lowering params

This commit is contained in:
Agus Zubiaga 2024-08-17 15:10:54 -03:00
parent bcd8e7e28a
commit 6588a32195
No known key found for this signature in database

View File

@ -44,118 +44,306 @@ pub fn lower(
var_store,
};
let mut index = 0;
while index < decls.len() {
let tag = decls.declarations[index];
match tag {
Value => {
env.lower_expr(&mut decls.expressions[index].value);
if let Some(new_arg) = env.home_params_argument() {
decls.convert_value_to_function(index, vec![new_arg], env.var_store);
}
}
Function(fn_def_index) => {
if let Some((_, mark, pattern)) = env.home_params_argument() {
let var = env.var_store.fresh();
decls.function_bodies[fn_def_index.index()]
.value
.arguments
.push((var, mark, pattern));
if let Some(ann) = &mut decls.annotations[index] {
if let Type::Function(args, _, _) = &mut ann.signature {
args.push(Type::Variable(var))
}
}
}
env.lower_expr(&mut decls.expressions[index].value)
}
Recursive(_) => { /* todo */ }
MutualRecursion { length, .. } => {
/* todo */
index += length as usize;
}
TailRecursive(_) => { /* todo */ }
Destructure(_) => { /* todo */ }
Expectation => { /* todo */ }
ExpectationFx => { /* todo */ }
}
index += 1;
}
env.lower_decls(decls);
}
impl<'a> LowerParams<'a> {
fn lower_expr(&mut self, expr: &mut Expr) {
match expr {
ParamsVar {
symbol,
var,
params_symbol,
params_var,
} => {
// A referece to a top-level value def in an imported module with params
*expr = self.call_params_var(*symbol, *var, *params_symbol, *params_var);
}
Var(symbol, var) => {
if self.is_params_extended_home_symbol(symbol) {
// A reference to a top-level value def in the home module with params
let params = self.home_params.as_ref().unwrap();
*expr =
self.call_params_var(*symbol, *var, params.whole_symbol, params.whole_var);
}
}
Call(fun, args, _called_via) => {
for arg in args.iter_mut() {
// todo: params var in arg
self.lower_expr(&mut arg.1.value);
}
fn lower_decls(&mut self, decls: &mut Declarations) {
let mut index = 0;
match fun.1.value {
// A call to a function in an imported module with params
ParamsVar {
symbol,
var,
params_var,
params_symbol,
} => {
args.push((params_var, Loc::at_zero(Var(params_symbol, params_var))));
fun.1.value = Var(symbol, var);
while index < decls.len() {
let tag = decls.declarations[index];
match tag {
Value => {
self.lower_expr(&mut decls.expressions[index].value);
if let Some(new_arg) = self.home_params_argument() {
decls.convert_value_to_function(index, vec![new_arg], self.var_store);
}
Var(symbol, _var) => {
if self.is_params_extended_home_symbol(&symbol) {
// A call to a top-level function in the home module with params
let params = self.home_params.as_ref().unwrap();
args.push((
params.whole_var,
Loc::at_zero(Var(params.whole_symbol, params.whole_var)),
));
}
Function(fn_def_index) => {
if let Some((_, mark, pattern)) = self.home_params_argument() {
let var = self.var_store.fresh();
decls.function_bodies[fn_def_index.index()]
.value
.arguments
.push((var, mark, pattern));
if let Some(ann) = &mut decls.annotations[index] {
if let Type::Function(args, _, _) = &mut ann.signature {
args.push(Type::Variable(var))
}
}
}
_ => self.lower_expr(&mut fun.1.value),
}
}
Closure(ClosureData {
function_type: _,
closure_type: _,
return_type: _,
name: _,
captured_symbols: _,
recursive: _,
arguments: _,
loc_body,
}) => {
// todo: capture params?
self.lower_expr(&mut loc_body.value);
self.lower_expr(&mut decls.expressions[index].value)
}
Recursive(_) => { /* todo */ }
MutualRecursion { length, .. } => {
/* todo */
index += length as usize;
}
TailRecursive(_) => { /* todo */ }
Destructure(_) => { /* todo */ }
Expectation => { /* todo */ }
ExpectationFx => { /* todo */ }
}
index += 1;
}
}
fn lower_expr(&mut self, expr: &mut Expr) {
let mut expr_stack = vec![expr];
while let Some(expr) = expr_stack.pop() {
match expr {
// Nodes to lower
ParamsVar {
symbol,
var,
params_symbol,
params_var,
} => {
// A referece to a top-level value def in an imported module with params
*expr = self.call_params_var(*symbol, *var, *params_symbol, *params_var);
}
Var(symbol, var) => {
if self.is_params_extended_home_symbol(symbol) {
// A reference to a top-level value def in the home module with params
let params = self.home_params.as_ref().unwrap();
*expr = self.call_params_var(
*symbol,
*var,
params.whole_symbol,
params.whole_var,
);
}
}
Call(fun, args, _called_via) => {
expr_stack.reserve(args.len() + 1);
match fun.1.value {
// A call to a function in an imported module with params
ParamsVar {
symbol,
var,
params_var,
params_symbol,
} => {
args.push((params_var, Loc::at_zero(Var(params_symbol, params_var))));
fun.1.value = Var(symbol, var);
}
Var(symbol, _var) => {
if self.is_params_extended_home_symbol(&symbol) {
// A call to a top-level function in the home module with params
let params = self.home_params.as_ref().unwrap();
args.push((
params.whole_var,
Loc::at_zero(Var(params.whole_symbol, params.whole_var)),
));
}
}
_ => expr_stack.push(&mut fun.1.value),
}
for (_, arg) in args.iter_mut() {
expr_stack.push(&mut arg.value);
}
}
Closure(ClosureData {
function_type: _,
closure_type: _,
return_type: _,
name: _,
captured_symbols: _,
recursive: _,
arguments: _,
loc_body,
}) => {
// todo: capture params?
expr_stack.push(&mut loc_body.value);
}
// Nodes to walk
LetNonRec(def, cont) => {
expr_stack.reserve(2);
expr_stack.push(&mut def.loc_expr.value);
expr_stack.push(&mut cont.value);
}
LetRec(defs, cont, _cycle_mark) => {
expr_stack.reserve(defs.len() + 1);
for def in defs {
expr_stack.push(&mut def.loc_expr.value);
}
expr_stack.push(&mut cont.value);
}
When {
loc_cond,
branches,
cond_var: _,
expr_var: _,
region: _,
branches_cond_var: _,
exhaustive: _,
} => {
expr_stack.reserve(branches.len() + 1);
expr_stack.push(&mut loc_cond.value);
for branch in branches.iter_mut() {
expr_stack.push(&mut branch.value.value);
}
}
If {
branches,
final_else,
cond_var: _,
branch_var: _,
} => {
expr_stack.reserve(branches.len() * 2 + 1);
for (cond, ret) in branches.iter_mut() {
expr_stack.push(&mut cond.value);
expr_stack.push(&mut ret.value);
}
expr_stack.push(&mut final_else.value);
}
RunLowLevel {
args,
op: _,
ret_var: _,
}
| ForeignCall {
foreign_symbol: _,
args,
ret_var: _,
} => {
expr_stack.extend(args.iter_mut().map(|(_, arg)| arg));
}
List {
elem_var: _,
loc_elems,
} => {
expr_stack.extend(loc_elems.iter_mut().map(|loc_elem| &mut loc_elem.value));
}
Record {
record_var: _,
fields,
} => {
expr_stack.extend(
fields
.iter_mut()
.map(|(_, field)| &mut field.loc_expr.value),
);
}
Tuple {
tuple_var: _,
elems,
} => {
expr_stack.extend(elems.iter_mut().map(|(_, elem)| &mut elem.value));
}
ImportParams(_, _, Some((_, params_expr))) => {
expr_stack.push(params_expr);
}
Crash { msg, ret_var: _ } => {
expr_stack.push(&mut msg.value);
}
RecordAccess {
loc_expr,
record_var: _,
ext_var: _,
field_var: _,
field: _,
} => expr_stack.push(&mut loc_expr.value),
TupleAccess {
loc_expr,
tuple_var: _,
ext_var: _,
elem_var: _,
index: _,
} => expr_stack.push(&mut loc_expr.value),
RecordUpdate {
updates,
record_var: _,
ext_var: _,
symbol: _,
} => expr_stack.extend(
updates
.iter_mut()
.map(|(_, field)| &mut field.loc_expr.value),
),
Tag {
arguments,
tag_union_var: _,
ext_var: _,
name: _,
} => expr_stack.extend(arguments.iter_mut().map(|(_, arg)| &mut arg.value)),
OpaqueRef {
argument,
opaque_var: _,
name: _,
specialized_def_type: _,
type_arguments: _,
lambda_set_variables: _,
} => expr_stack.push(&mut argument.1.value),
Expect {
loc_condition,
loc_continuation,
lookups_in_cond: _,
} => {
expr_stack.reserve(2);
expr_stack.push(&mut loc_condition.value);
expr_stack.push(&mut loc_continuation.value);
}
ExpectFx {
loc_condition,
loc_continuation,
lookups_in_cond: _,
} => {
expr_stack.reserve(2);
expr_stack.push(&mut loc_condition.value);
expr_stack.push(&mut loc_continuation.value);
}
Dbg {
loc_message,
loc_continuation,
source_location: _,
source: _,
variable: _,
symbol: _,
} => {
expr_stack.reserve(2);
expr_stack.push(&mut loc_message.value);
expr_stack.push(&mut loc_continuation.value);
}
RecordAccessor(_)
| ImportParams(_, _, None)
| ZeroArgumentTag {
closure_name: _,
variant_var: _,
ext_var: _,
name: _,
}
| OpaqueWrapFunction(_)
| EmptyRecord
| TypedHole(_)
| RuntimeError(_)
| Num(_, _, _, _)
| Int(_, _, _, _, _)
| Float(_, _, _, _, _)
| Str(_)
| SingleQuote(_, _, _, _)
| IngestedFile(_, _, _)
| AbilityMember(_, _, _) => { /* terminal */ }
}
_ => { /* todo */ }
}
}