diff --git a/compiler/lcalc/monomorphize.ml b/compiler/lcalc/monomorphize.ml index bf89104c..8a13233f 100644 --- a/compiler/lcalc/monomorphize.ml +++ b/compiler/lcalc/monomorphize.ml @@ -212,112 +212,96 @@ let is_some c = (assert (EnumConstructor.equal Expr.none_constr c); false) -(* We output a typed expr but the types in the output are wrong, it should be - untyped and re-typed later. *) let rec monomorphize_expr (monomorphized_instances : monomorphized_instances) - (e : typed expr) : typed expr boxed = - let typ = Expr.ty e in - match Mark.remove e with + (e0 : typed expr) : typed expr boxed = + let ty0 = Expr.ty e0 in + (* Keys in [monomorphized_instances] are before monomorphization, so collect + this top-down *) + let f_expr = monomorphize_expr monomorphized_instances in + let f_ty = monomorphize_typ monomorphized_instances in + (* Proceed bottom-up: apply first to the sub-terms *) + let e = Expr.map ~f:f_expr ~typ:f_ty ~op:Fun.id e0 in + let m = Mark.get e in + let map_box f = Expr.Box.app1 e (fun e -> f (Mark.remove e)) m in + map_box + @@ function | ETuple args -> - let new_args = List.map (monomorphize_expr monomorphized_instances) args in - let tuple_instance = Type.Map.find typ monomorphized_instances.tuples in - let fields = - StructField.Map.of_list - @@ List.map2 - (fun new_arg (tuple_field, _) -> tuple_field, new_arg) - new_args tuple_instance.fields + let tuple_instance = Type.Map.find ty0 monomorphized_instances.tuples in + EStruct + { + name = tuple_instance.name; + fields = + StructField.Map.of_list + @@ List.map2 + (fun (tuple_field, _) arg -> tuple_field, arg) + tuple_instance.fields args; + } + | ETupleAccess { e; index; _ } -> + (* The type of the tuple needs to be recovered from the untransformed + expr *) + let tup_ty = + match e0 with ETupleAccess { e; _ }, _ -> Expr.ty e | _ -> assert false in - Expr.estruct ~name:tuple_instance.name ~fields (Mark.get e) - | ETupleAccess { e = e1; index; _ } -> let tuple_instance = - Type.Map.find (Expr.ty e1) monomorphized_instances.tuples - in - let new_e1 = monomorphize_expr monomorphized_instances e1 in - Expr.estructaccess ~name:tuple_instance.name - ~field:(fst (List.nth tuple_instance.fields index)) - ~e:new_e1 (Mark.get e) - | EMatch { name; e = e1; cases } when EnumName.equal name Expr.option_enum -> - let new_e1 = monomorphize_expr monomorphized_instances e1 in - let new_cases = - EnumConstructor.Map.bindings - (EnumConstructor.Map.map - (monomorphize_expr monomorphized_instances) - cases) + Type.Map.find tup_ty monomorphized_instances.tuples in + EStructAccess + { + name = tuple_instance.name; + e; + field = fst (List.nth tuple_instance.fields index); + } + | EMatch { name; e; cases } when EnumName.equal name Expr.option_enum -> + let option_instance = Type.Map.find ty0 monomorphized_instances.options in + EMatch + { + name = option_instance.name; + e; + cases = + EnumConstructor.Map.fold + (fun c -> + EnumConstructor.Map.add + (if is_some c then option_instance.some_cons + else option_instance.none_cons)) + cases EnumConstructor.Map.empty; + } + | EInj { name; e; cons } when EnumName.equal name Expr.option_enum -> let option_instance = Type.Map.find - (match Mark.remove (Expr.ty e1) with - | TOption t -> t - | _ -> failwith "should not happen") + (match Mark.remove ty0 with TOption t -> t | _ -> assert false) monomorphized_instances.options in - let new_cases = - match new_cases with - | [(n1, e1); (n2, e2)] -> ( - match is_some n1, is_some n2 with - | true, false -> - [option_instance.some_cons, e1; option_instance.none_cons, e2] - | false, true -> - [option_instance.some_cons, e2; option_instance.none_cons, e1] - | _ -> failwith "should not happen") - | _ -> failwith "should not happen" + EInj + { + name = option_instance.name; + e; + cons = + (if is_some cons then option_instance.some_cons + else option_instance.none_cons); + } + | EArray elts as e -> + let elt_ty = + match Mark.remove ty0 with TArray t -> t | _ -> assert false in - let new_cases = EnumConstructor.Map.of_list new_cases in - Expr.ematch ~name:option_instance.name ~e:new_e1 ~cases:new_cases - (Mark.get e) - | EInj { name; e = e1; cons } when EnumName.equal name Expr.option_enum -> - let option_instance = - Type.Map.find - (match Mark.remove (Expr.ty e) with - | TOption t -> t - | _ -> failwith "should not happen") - monomorphized_instances.options - in - let new_e1 = monomorphize_expr monomorphized_instances e1 in - let new_cons = - if is_some cons then option_instance.some_cons - else option_instance.none_cons - in - Expr.einj ~name:option_instance.name ~e:new_e1 ~cons:new_cons (Mark.get e) - (* We do not forget to tweak types stored directly in the AST in the nodes - of kind [EAbs], [EApp] and [EAppOp]. *) - | EAbs { binder; tys } -> - let new_tys = List.map (monomorphize_typ monomorphized_instances) tys in - let vars, body = Bindlib.unmbind binder in - let new_body = monomorphize_expr monomorphized_instances body in - Expr.make_abs vars new_body new_tys (Expr.pos e) - | EApp { f; args; tys } -> - let new_f = monomorphize_expr monomorphized_instances f in - let new_args = List.map (monomorphize_expr monomorphized_instances) args in - let new_tys = List.map (monomorphize_typ monomorphized_instances) tys in - Expr.eapp ~f:new_f ~args:new_args ~tys:new_tys (Mark.get e) - | EAppOp { op; args; tys } -> - let new_args = List.map (monomorphize_expr monomorphized_instances) args in - let new_tys = List.map (monomorphize_typ monomorphized_instances) tys in - Expr.eappop ~op ~args:new_args ~tys:new_tys (Mark.get e) - | EArray args -> - let new_args = List.map (monomorphize_expr monomorphized_instances) args in - let array_instance = - Type.Map.find - (match Mark.remove (Expr.ty e) with - | TArray t -> t - | _ -> failwith "should not happen") - monomorphized_instances.arrays - in - Expr.estruct ~name:array_instance.name - ~fields: - (StructField.Map.add array_instance.content_field - (Expr.earray new_args (Mark.get e)) - (StructField.Map.singleton array_instance.len_field - (Expr.elit - (LInt (Runtime.integer_of_int (List.length args))) - (Mark.get e)))) - (Mark.get e) - | _ -> Expr.map ~f:(monomorphize_expr monomorphized_instances) e + let array_instance = Type.Map.find elt_ty monomorphized_instances.arrays in + EStruct + { + name = array_instance.name; + fields = + StructField.Map.of_list + [ + ( array_instance.len_field, + ( ELit (LInt (Runtime.integer_of_int (List.length elts))), + Expr.with_ty m (TLit TInt, Expr.mark_pos m) ) ); + ( array_instance.content_field, + (e, Expr.with_ty m (TArray (f_ty elt_ty), Expr.mark_pos m)) ); + ]; + } + | e -> e let program (prg : typed program) : - untyped program * Scopelang.Dependency.TVertex.t list = + typed program * Scopelang.Dependency.TVertex.t list = let monomorphized_instances = collect_monomorphized_instances prg in (* First we remove the polymorphic option type *) let prg = @@ -417,7 +401,6 @@ let program (prg : typed program) : scope_body))) ~varf:Fun.id prg.code_items in - let prg = Program.untype { prg with code_items } in - ( prg, + ( { prg with code_items }, Scopelang.Dependency.check_type_cycles prg.decl_ctx.ctx_structs prg.decl_ctx.ctx_enums ) diff --git a/compiler/lcalc/monomorphize.mli b/compiler/lcalc/monomorphize.mli index ab523bef..f154fed2 100644 --- a/compiler/lcalc/monomorphize.mli +++ b/compiler/lcalc/monomorphize.mli @@ -18,7 +18,7 @@ open Shared_ast open Ast val program : - typed program -> untyped program * Scopelang.Dependency.TVertex.t list + typed program -> typed program * Scopelang.Dependency.TVertex.t list (** This function performs type monomorphization in a Catala program with two main actions: {ul {- transforms tuples into named structs.} diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index da811645..e3bcfa97 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -474,13 +474,20 @@ end module ExprGen (C : EXPR_PARAM) = struct let rec expr_aux : - type a. + type a t. Bindlib.ctxt -> Ocolor_types.color4 list -> Format.formatter -> - (a, 't) gexpr -> + (a, t) gexpr -> unit = fun bnd_ctx colors fmt e -> + (* (* Uncomment for type annotations everywhere *) + * (fun f -> + * Format.fprintf fmt "@[(%a:@ %a)@]" + * f e + * typ_debug + * (match Mark.get e with Typed {ty; _} -> ty | _ -> TAny,Pos.no_pos)) + * @@ fun fmt e -> *) let exprb bnd_ctx colors e = expr_aux bnd_ctx colors e in let exprc colors e = exprb bnd_ctx colors e in let expr e = exprc colors e in