diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 2e8ba1c5..09ce1387 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -24,7 +24,31 @@ type 'm ctx = { globally_bound_vars : ('m expr, typ) Var.Map.t; } -let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys +(** Function types will be transformed in this way throughout, including in + [decl_ctx] *) +let rec translate_type t = + let pos = Mark.get t in + match Mark.remove t with + | TArrow (t1, t2) -> + ( TTuple + [ + ( TArrow + ( (TClosureEnv, Pos.no_pos) :: List.map translate_type t1, + translate_type t2 ), + Pos.no_pos ); + TClosureEnv, Pos.no_pos; + ], + pos ) + | TDefault t' -> TDefault (translate_type t'), pos + | TOption t' -> TOption (translate_type t'), pos + | TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t + | TArray ts -> TArray (translate_type ts), pos + | TTuple ts -> TTuple (List.map translate_type ts), pos + +let translate_mark e = Mark.map_mark (Expr.map_ty translate_type) e + +let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t = + fun m1 m2 -> Var.Map.union (fun _ a _ -> Some a) m1 m2 (** {1 Transforming closures}*) @@ -33,19 +57,20 @@ let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10 (environment-passing closure conversion). *) let rec transform_closures_expr : - type m. m ctx -> m expr -> m expr Var.Set.t * m expr boxed = + type m. m ctx -> m expr -> (m expr, m mark) Var.Map.t * m expr boxed = fun ctx e -> + let e = translate_mark e in let m = Mark.get e in match Mark.remove e with | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _ | ELit _ | EExternal _ | EAssert _ | EFatalError _ | EIfThenElse _ | ERaiseEmpty | ECatchEmpty _ -> - Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union + Expr.map_gather ~acc:Var.Map.empty ~join:join_vars ~f:(transform_closures_expr ctx) e | EVar v -> ( match Var.Map.find_opt v ctx.globally_bound_vars with - | None -> Var.Set.singleton v, (Bindlib.box_var v, m) + | None -> Var.Map.singleton v m, (Bindlib.box_var v, m) | Some (TArrow (targs, tret), _) -> (* Here we eta-expand the argument to make sure function pointers are correctly casted as closures *) @@ -69,13 +94,13 @@ let rec transform_closures_expr : { ctx with globally_bound_vars = - Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars; + Var.Map.add v (Expr.maybe_ty m) ctx.globally_bound_vars; } in Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e) in Bindlib.unbox boxed - | Some _ -> Var.Set.empty, (Bindlib.box_var v, m)) + | Some _ -> Var.Map.empty, (Bindlib.box_var v, m)) | EMatch { e; cases; name } -> let free_vars, new_e = (transform_closures_expr ctx) e in (* We do not close the clotures inside the arms of the match expression, @@ -89,13 +114,11 @@ let rec transform_closures_expr : let new_free_vars, new_body = (transform_closures_expr ctx) body in let new_free_vars = Array.fold_left - (fun acc v -> Var.Set.remove v acc) + (fun acc v -> Var.Map.remove v acc) new_free_vars vars in let new_binder = Expr.bind vars new_body in - ( Var.Set.union free_vars - (Var.Set.diff new_free_vars - (Var.Set.of_list (Array.to_list vars))), + ( join_vars free_vars new_free_vars, EnumConstructor.Map.add cons (Expr.eabs new_binder tys (Mark.get e1)) new_cases ) @@ -109,54 +132,58 @@ let rec transform_closures_expr : let vars, body = Bindlib.unmbind binder in let free_vars, new_body = (transform_closures_expr ctx) body in let free_vars = - Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars + Array.fold_left (fun acc v -> Var.Map.remove v acc) free_vars vars in let new_binder = Expr.bind vars new_body in let free_vars, new_args = List.fold_right (fun arg (free_vars, new_args) -> let new_free_vars, new_arg = (transform_closures_expr ctx) arg in - Var.Set.union free_vars new_free_vars, new_arg :: new_args) + join_vars free_vars new_free_vars, new_arg :: new_args) args (free_vars, []) in ( free_vars, Expr.eapp - ~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos) + ~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos) ~args:new_args ~tys m ) | EAbs { binder; tys } -> (* λ x.t *) - let binder_mark = Expr.with_ty m (TAny, Expr.mark_pos m) in - let binder_pos = Expr.mark_pos binder_mark in + let binder_pos = Expr.mark_pos m in + let mark_ty ty = Expr.with_ty m ty in (* Converting the closure. *) let vars, body = Bindlib.unmbind binder in (* t *) let body_vars, new_body = (transform_closures_expr ctx) body in (* [[t]] *) let extra_vars = - Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars)) + Array.fold_left (fun m v -> Var.Map.remove v m) body_vars vars + in + let extra_vars_list = Var.Map.bindings extra_vars in + let extra_vars_types = + List.map (fun (_, m) -> Expr.maybe_ty m) extra_vars_list in - let extra_vars_list = Var.Set.elements extra_vars in (* x1, ..., xn *) let code_var = Var.make ctx.name_context in (* code *) let closure_env_arg_var = Var.make "env" in let closure_env_var = Var.make "env" in - let any_ty = TAny, binder_pos in + let env_ty = TTuple extra_vars_types, binder_pos in (* let env = from_closure_env env in let arg0 = env.0 in ... *) let new_closure_body = - Expr.make_let_in closure_env_var any_ty + Expr.make_let_in closure_env_var env_ty (Expr.eappop ~op:(Operator.FromClosureEnv, binder_pos) ~tys:[TClosureEnv, binder_pos] - ~args:[Expr.evar closure_env_arg_var binder_mark] - binder_mark) + ~args: + [Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, binder_pos))] + (mark_ty env_ty)) (Expr.make_multiple_let_in - (Array.of_list extra_vars_list) - (List.map (fun _ -> any_ty) extra_vars_list) + (Array.of_list (List.map fst extra_vars_list)) + extra_vars_types (List.mapi (fun i _ -> Expr.make_tupleaccess - (Expr.evar closure_env_var binder_mark) + (Expr.evar closure_env_var (mark_ty env_ty)) i (List.length extra_vars_list) binder_pos) @@ -167,33 +194,39 @@ let rec transform_closures_expr : (* fun env arg0 ... -> new_closure_body *) let new_closure = Expr.make_abs - (Array.concat [Array.make 1 closure_env_arg_var; vars]) + (Array.append [| closure_env_arg_var |] vars) new_closure_body ((TClosureEnv, binder_pos) :: tys) (Expr.pos e) in + let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in ( extra_vars, - Expr.make_let_in code_var - (TAny, Expr.pos e) - new_closure + Expr.make_let_in code_var new_closure_ty new_closure (Expr.make_tuple - ((Bindlib.box_var code_var, binder_mark) + ((Bindlib.box_var code_var, mark_ty new_closure_ty) :: [ Expr.eappop ~op:(Operator.ToClosureEnv, binder_pos) - ~tys:[TAny, Expr.pos e] + ~tys: + [ + ( (if extra_vars_list = [] then TLit TUnit + else TTuple extra_vars_types), + binder_pos ); + ] ~args: [ - (if extra_vars_list = [] then Expr.elit LUnit binder_mark + (if extra_vars_list = [] then + Expr.elit LUnit (mark_ty (TLit TUnit, binder_pos)) else Expr.etuple (List.map - (fun extra_var -> - Bindlib.box_var extra_var, binder_mark) + (fun (extra_var, m) -> + ( Bindlib.box_var extra_var, + Expr.with_pos binder_pos m )) extra_vars_list) - m); + (mark_ty (TTuple extra_vars_types, binder_pos))); ] - (Mark.get e); + (mark_ty (TClosureEnv, binder_pos)); ]) m) (Expr.pos e) ) @@ -219,16 +252,16 @@ let rec transform_closures_expr : let new_arg = Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg) in - Var.Set.union free_vars new_free_vars, new_arg :: new_args + join_vars free_vars new_free_vars, new_arg :: new_args | _ -> let new_free_vars, new_arg = transform_closures_expr ctx arg in - Var.Set.union free_vars new_free_vars, new_arg :: new_args) - args (Var.Set.empty, []) + join_vars free_vars new_free_vars, new_arg :: new_args) + args (Var.Map.empty, []) in free_vars, Expr.eappop ~op ~tys ~args:new_args (Mark.get e) | EAppOp _ -> (* This corresponds to an operator call, which we don't want to transform *) - Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union + Expr.map_gather ~acc:Var.Map.empty ~join:join_vars ~f:(transform_closures_expr ctx) e | EApp { f = EVar v, f_m; args; tys } @@ -239,12 +272,16 @@ let rec transform_closures_expr : List.fold_right (fun arg (free_vars, new_args) -> let new_free_vars, new_arg = (transform_closures_expr ctx) arg in - Var.Set.union free_vars new_free_vars, new_arg :: new_args) - args (Var.Set.empty, []) + join_vars free_vars new_free_vars, new_arg :: new_args) + args (Var.Map.empty, []) in free_vars, Expr.eapp ~f:(Expr.evar v f_m) ~args:new_args ~tys m | EApp { f = e1; args; tys } -> let free_vars, new_e1 = (transform_closures_expr ctx) e1 in + let tys = List.map translate_type tys in + let pos = Expr.mark_pos m in + let env_arg_ty = TClosureEnv, Expr.pos new_e1 in + let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in let code_env_var = Var.make "code_and_env" in let code_env_expr = let pos = Expr.pos e1 in @@ -252,8 +289,7 @@ let rec transform_closures_expr : (Expr.with_ty (Mark.get e1) ( TTuple [ - ( TArrow ((TClosureEnv, pos) :: tys, (TAny, Expr.pos e)), - Expr.pos e ); + TArrow ((TClosureEnv, pos) :: tys, Expr.maybe_ty m), Expr.pos e; TClosureEnv, pos; ], pos )) @@ -264,24 +300,23 @@ let rec transform_closures_expr : List.fold_right (fun arg (free_vars, new_args) -> let new_free_vars, new_arg = (transform_closures_expr ctx) arg in - Var.Set.union free_vars new_free_vars, new_arg :: new_args) + join_vars free_vars new_free_vars, new_arg :: new_args) args (free_vars, []) in let call_expr = - let m1 = Mark.get e1 in - let pos = Expr.mark_pos m in - let env_arg_ty = TClosureEnv, Expr.pos e1 in - let fun_ty = TArrow (env_arg_ty :: tys, (TAny, Expr.pos e)), Expr.pos e in + let m1 = Mark.get new_e1 in Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty] [ Expr.make_tupleaccess code_env_expr 0 2 pos; Expr.make_tupleaccess code_env_expr 1 2 pos; ] - (Expr.eapp - ~f:(Bindlib.box_var code_var, m1) - ~args:((Bindlib.box_var env_var, m1) :: new_args) - ~tys:(env_arg_ty :: tys) m) - (Expr.pos e) + (Expr.make_app + (Bindlib.box_var code_var, Expr.with_ty m1 fun_ty) + ((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args) + (env_arg_ty + :: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys) + pos) + pos in ( free_vars, Expr.make_let_in code_env_var @@ -393,33 +428,15 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = capture footprint. See [tests/tests_func/good/scope_call_func_struct_closure.catala_en]. *) let new_decl_ctx = - let rec replace_fun_typs t = - match Mark.remove t with - | TArrow (t1, t2) -> - ( TTuple - [ - ( TArrow - ( (TClosureEnv, Pos.no_pos) :: List.map replace_fun_typs t1, - replace_fun_typs t2 ), - Pos.no_pos ); - TClosureEnv, Pos.no_pos; - ], - Mark.get t ) - | TDefault t' -> TDefault (replace_fun_typs t'), Mark.get t - | TOption t' -> TOption (replace_fun_typs t'), Mark.get t - | TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t - | TArray ts -> TArray (replace_fun_typs ts), Mark.get t - | TTuple ts -> TTuple (List.map replace_fun_typs ts), Mark.get t - in { p.decl_ctx with ctx_structs = StructName.Map.map - (StructField.Map.map replace_fun_typs) + (StructField.Map.map translate_type) p.decl_ctx.ctx_structs; ctx_enums = EnumName.Map.map - (EnumConstructor.Map.map replace_fun_typs) + (EnumConstructor.Map.map translate_type) p.decl_ctx.ctx_enums; (* Toplevel definitions may not contain scope calls or take functions as arguments at the moment, which ensures that their interfaces aren't @@ -489,9 +506,7 @@ let rec hoist_closures_expr : args (collected_closures, []) in ( collected_closures, - Expr.eapp - ~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos) - ~args:new_args ~tys m ) + Expr.eapp ~f:(Expr.eabs new_binder tys e1_pos) ~args:new_args ~tys m ) | EAppOp { op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op; @@ -525,20 +540,16 @@ let rec hoist_closures_expr : in collected_closures, Expr.eappop ~op ~args:new_args ~tys (Mark.get e) | EAbs { tys; _ } -> - (* this is the closure we want to hoist*) + (* this is the closure we want to hoist *) let closure_var = Var.make ("closure_" ^ name_context) in (* TODO: This will end up as a toplevel name. However for now we assume toplevel names are unique, but this breaks this assertions and can lead to name wrangling in the backends. We need to have a better system for name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but also OCaml, Python, etc. *) - ( [ - { - name = closure_var; - ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m; - closure = Expr.rebox e; - }; - ], + let pos = Expr.mark_pos m in + let ty = Expr.maybe_ty ~typ:(TArrow (tys, (TAny, pos))) m in + ( [{ name = closure_var; ty; closure = Expr.rebox e }], Expr.make_var closure_var m ) | EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _ @@ -660,9 +671,9 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box = (** {1 Closure conversion}*) -let closure_conversion (p : 'm program) : untyped program = +let closure_conversion (p : 'm program) : 'm program = let new_p = transform_closures_program p in let new_p = hoist_closures_program (Bindlib.unbox new_p) in (* FIXME: either fix the types of the marks, or remove the types annotations during the main processing (rather than requiring a new traversal) *) - Program.untype (Bindlib.unbox new_p) + Bindlib.unbox new_p diff --git a/compiler/lcalc/closure_conversion.mli b/compiler/lcalc/closure_conversion.mli index 415f4681..8dbad9fb 100644 --- a/compiler/lcalc/closure_conversion.mli +++ b/compiler/lcalc/closure_conversion.mli @@ -21,4 +21,4 @@ After closure conversion, closure hoisting is perform and all closures end up as toplevel definitions. *) -val closure_conversion : 'm Ast.program -> Shared_ast.untyped Ast.program +val closure_conversion : 'm Ast.program -> 'm Ast.program diff --git a/compiler/shared_ast/var.ml b/compiler/shared_ast/var.ml index d74be626..12f64c83 100644 --- a/compiler/shared_ast/var.ml +++ b/compiler/shared_ast/var.ml @@ -100,6 +100,7 @@ module Map = struct let empty = empty let singleton v x = singleton (t v) x let add v x m = add (t v) x m + let remove v m = remove (t v) m let update v f m = update (t v) f m let find v m = find (t v) m let find_opt v m = find_opt (t v) m diff --git a/compiler/shared_ast/var.mli b/compiler/shared_ast/var.mli index 0aa92bda..0e741dca 100644 --- a/compiler/shared_ast/var.mli +++ b/compiler/shared_ast/var.mli @@ -64,6 +64,7 @@ module Map : sig val empty : ('e, 'x) t val singleton : 'e var -> 'x -> ('e, 'x) t val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t + val remove : 'e var -> ('e, 'x) t -> ('e, 'x) t val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t val find : 'e var -> ('e, 'x) t -> 'x val find_opt : 'e var -> ('e, 'x) t -> 'x option