Simplify monomorphisation, and preserve type annotations

This commit is contained in:
Louis Gesbert 2024-02-06 15:33:40 +01:00
parent 22674cd15d
commit df70c5dd57
3 changed files with 88 additions and 98 deletions

View File

@ -212,112 +212,96 @@ let is_some c =
(assert (EnumConstructor.equal Expr.none_constr c);
false)
(* We output a typed expr but the types in the output are wrong, it should be
untyped and re-typed later. *)
let rec monomorphize_expr
(monomorphized_instances : monomorphized_instances)
(e : typed expr) : typed expr boxed =
let typ = Expr.ty e in
match Mark.remove e with
(e0 : typed expr) : typed expr boxed =
let ty0 = Expr.ty e0 in
(* Keys in [monomorphized_instances] are before monomorphization, so collect
this top-down *)
let f_expr = monomorphize_expr monomorphized_instances in
let f_ty = monomorphize_typ monomorphized_instances in
(* Proceed bottom-up: apply first to the sub-terms *)
let e = Expr.map ~f:f_expr ~typ:f_ty ~op:Fun.id e0 in
let m = Mark.get e in
let map_box f = Expr.Box.app1 e (fun e -> f (Mark.remove e)) m in
map_box
@@ function
| ETuple args ->
let new_args = List.map (monomorphize_expr monomorphized_instances) args in
let tuple_instance = Type.Map.find typ monomorphized_instances.tuples in
let fields =
StructField.Map.of_list
@@ List.map2
(fun new_arg (tuple_field, _) -> tuple_field, new_arg)
new_args tuple_instance.fields
let tuple_instance = Type.Map.find ty0 monomorphized_instances.tuples in
EStruct
{
name = tuple_instance.name;
fields =
StructField.Map.of_list
@@ List.map2
(fun (tuple_field, _) arg -> tuple_field, arg)
tuple_instance.fields args;
}
| ETupleAccess { e; index; _ } ->
(* The type of the tuple needs to be recovered from the untransformed
expr *)
let tup_ty =
match e0 with ETupleAccess { e; _ }, _ -> Expr.ty e | _ -> assert false
in
Expr.estruct ~name:tuple_instance.name ~fields (Mark.get e)
| ETupleAccess { e = e1; index; _ } ->
let tuple_instance =
Type.Map.find (Expr.ty e1) monomorphized_instances.tuples
in
let new_e1 = monomorphize_expr monomorphized_instances e1 in
Expr.estructaccess ~name:tuple_instance.name
~field:(fst (List.nth tuple_instance.fields index))
~e:new_e1 (Mark.get e)
| EMatch { name; e = e1; cases } when EnumName.equal name Expr.option_enum ->
let new_e1 = monomorphize_expr monomorphized_instances e1 in
let new_cases =
EnumConstructor.Map.bindings
(EnumConstructor.Map.map
(monomorphize_expr monomorphized_instances)
cases)
Type.Map.find tup_ty monomorphized_instances.tuples
in
EStructAccess
{
name = tuple_instance.name;
e;
field = fst (List.nth tuple_instance.fields index);
}
| EMatch { name; e; cases } when EnumName.equal name Expr.option_enum ->
let option_instance = Type.Map.find ty0 monomorphized_instances.options in
EMatch
{
name = option_instance.name;
e;
cases =
EnumConstructor.Map.fold
(fun c ->
EnumConstructor.Map.add
(if is_some c then option_instance.some_cons
else option_instance.none_cons))
cases EnumConstructor.Map.empty;
}
| EInj { name; e; cons } when EnumName.equal name Expr.option_enum ->
let option_instance =
Type.Map.find
(match Mark.remove (Expr.ty e1) with
| TOption t -> t
| _ -> failwith "should not happen")
(match Mark.remove ty0 with TOption t -> t | _ -> assert false)
monomorphized_instances.options
in
let new_cases =
match new_cases with
| [(n1, e1); (n2, e2)] -> (
match is_some n1, is_some n2 with
| true, false ->
[option_instance.some_cons, e1; option_instance.none_cons, e2]
| false, true ->
[option_instance.some_cons, e2; option_instance.none_cons, e1]
| _ -> failwith "should not happen")
| _ -> failwith "should not happen"
EInj
{
name = option_instance.name;
e;
cons =
(if is_some cons then option_instance.some_cons
else option_instance.none_cons);
}
| EArray elts as e ->
let elt_ty =
match Mark.remove ty0 with TArray t -> t | _ -> assert false
in
let new_cases = EnumConstructor.Map.of_list new_cases in
Expr.ematch ~name:option_instance.name ~e:new_e1 ~cases:new_cases
(Mark.get e)
| EInj { name; e = e1; cons } when EnumName.equal name Expr.option_enum ->
let option_instance =
Type.Map.find
(match Mark.remove (Expr.ty e) with
| TOption t -> t
| _ -> failwith "should not happen")
monomorphized_instances.options
in
let new_e1 = monomorphize_expr monomorphized_instances e1 in
let new_cons =
if is_some cons then option_instance.some_cons
else option_instance.none_cons
in
Expr.einj ~name:option_instance.name ~e:new_e1 ~cons:new_cons (Mark.get e)
(* We do not forget to tweak types stored directly in the AST in the nodes
of kind [EAbs], [EApp] and [EAppOp]. *)
| EAbs { binder; tys } ->
let new_tys = List.map (monomorphize_typ monomorphized_instances) tys in
let vars, body = Bindlib.unmbind binder in
let new_body = monomorphize_expr monomorphized_instances body in
Expr.make_abs vars new_body new_tys (Expr.pos e)
| EApp { f; args; tys } ->
let new_f = monomorphize_expr monomorphized_instances f in
let new_args = List.map (monomorphize_expr monomorphized_instances) args in
let new_tys = List.map (monomorphize_typ monomorphized_instances) tys in
Expr.eapp ~f:new_f ~args:new_args ~tys:new_tys (Mark.get e)
| EAppOp { op; args; tys } ->
let new_args = List.map (monomorphize_expr monomorphized_instances) args in
let new_tys = List.map (monomorphize_typ monomorphized_instances) tys in
Expr.eappop ~op ~args:new_args ~tys:new_tys (Mark.get e)
| EArray args ->
let new_args = List.map (monomorphize_expr monomorphized_instances) args in
let array_instance =
Type.Map.find
(match Mark.remove (Expr.ty e) with
| TArray t -> t
| _ -> failwith "should not happen")
monomorphized_instances.arrays
in
Expr.estruct ~name:array_instance.name
~fields:
(StructField.Map.add array_instance.content_field
(Expr.earray new_args (Mark.get e))
(StructField.Map.singleton array_instance.len_field
(Expr.elit
(LInt (Runtime.integer_of_int (List.length args)))
(Mark.get e))))
(Mark.get e)
| _ -> Expr.map ~f:(monomorphize_expr monomorphized_instances) e
let array_instance = Type.Map.find elt_ty monomorphized_instances.arrays in
EStruct
{
name = array_instance.name;
fields =
StructField.Map.of_list
[
( array_instance.len_field,
( ELit (LInt (Runtime.integer_of_int (List.length elts))),
Expr.with_ty m (TLit TInt, Expr.mark_pos m) ) );
( array_instance.content_field,
(e, Expr.with_ty m (TArray (f_ty elt_ty), Expr.mark_pos m)) );
];
}
| e -> e
let program (prg : typed program) :
untyped program * Scopelang.Dependency.TVertex.t list =
typed program * Scopelang.Dependency.TVertex.t list =
let monomorphized_instances = collect_monomorphized_instances prg in
(* First we remove the polymorphic option type *)
let prg =
@ -417,7 +401,6 @@ let program (prg : typed program) :
scope_body)))
~varf:Fun.id prg.code_items
in
let prg = Program.untype { prg with code_items } in
( prg,
( { prg with code_items },
Scopelang.Dependency.check_type_cycles prg.decl_ctx.ctx_structs
prg.decl_ctx.ctx_enums )

View File

@ -18,7 +18,7 @@ open Shared_ast
open Ast
val program :
typed program -> untyped program * Scopelang.Dependency.TVertex.t list
typed program -> typed program * Scopelang.Dependency.TVertex.t list
(** This function performs type monomorphization in a Catala program with two
main actions: {ul
{- transforms tuples into named structs.}

View File

@ -474,13 +474,20 @@ end
module ExprGen (C : EXPR_PARAM) = struct
let rec expr_aux :
type a.
type a t.
Bindlib.ctxt ->
Ocolor_types.color4 list ->
Format.formatter ->
(a, 't) gexpr ->
(a, t) gexpr ->
unit =
fun bnd_ctx colors fmt e ->
(* (* Uncomment for type annotations everywhere *)
* (fun f ->
* Format.fprintf fmt "@[<hv 1>(%a:@ %a)@]"
* f e
* typ_debug
* (match Mark.get e with Typed {ty; _} -> ty | _ -> TAny,Pos.no_pos))
* @@ fun fmt e -> *)
let exprb bnd_ctx colors e = expr_aux bnd_ctx colors e in
let exprc colors e = exprb bnd_ctx colors e in
let expr e = exprc colors e in