Simplify a few mark operations

This commit is contained in:
Louis Gesbert 2022-09-12 17:23:44 +02:00
parent d93b699a4c
commit 0bb9cce341
4 changed files with 21 additions and 34 deletions

View File

@ -491,8 +491,7 @@ let interpret_program :
(fun (_, ty) -> (fun (_, ty) ->
match Marked.unmark ty with match Marked.unmark ty with
| TArrow ((TLit TUnit, _), ty_in) -> | TArrow ((TLit TUnit, _), ty_in) ->
Expr.empty_thunked_term Expr.empty_thunked_term (Expr.with_ty mark_e ty_in)
(Expr.map_mark (fun pos -> pos) (fun _ -> ty_in) mark_e)
| _ -> | _ ->
Errors.raise_spanned_error (Marked.get_mark ty) Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \ "This scope needs input arguments to be executed. But the Catala \
@ -512,7 +511,7 @@ let interpret_program :
| a :: _ -> Expr.pos a | a :: _ -> Expr.pos a
| [] -> Pos.no_pos | [] -> Pos.no_pos
in in
Expr.map_mark (fun _ -> pos) (fun _ -> targs) mark_e ); Expr.with_ty mark_e ~pos targs );
] ), ] ),
Expr.map_mark Expr.map_mark
(fun pos -> pos) (fun pos -> pos)

View File

@ -39,9 +39,7 @@ let make_none m =
Bindlib.box Bindlib.box
@@ mark @@ mark
@@ EInj @@ EInj
( Marked.mark ( Marked.mark (Expr.with_ty m tunit) (ELit LUnit),
(Expr.map_mark (fun pos -> pos) (fun _ -> tunit) m)
(ELit LUnit),
0, 0,
option_enum, option_enum,
[TLit TUnit, Pos.no_pos; TAny, Pos.no_pos] ) [TLit TUnit, Pos.no_pos; TAny, Pos.no_pos] )

View File

@ -43,11 +43,10 @@ module A = Ast
open Shared_ast open Shared_ast
type 'm hoists = ('m A.expr, 'm D.expr) Var.Map.t type 'm hoists = ('m A.expr, 'm D.expr) Var.Map.t
(** Hoists definition. It represent bindings between [A.Var.t] and (** Hoists definition. It represent bindings between [A.Var.t] and [D.expr]. *)
[D.naked_expr]. *)
type 'm info = { type 'm info = {
naked_expr : 'm A.expr Bindlib.box; expr : 'm A.expr Bindlib.box;
var : 'm A.expr Var.t; var : 'm A.expr Var.t;
is_pure : bool; is_pure : bool;
} }
@ -104,7 +103,7 @@ let add_var
(is_pure : bool) (is_pure : bool)
(ctx : 'm ctx) : 'm ctx = (ctx : 'm ctx) : 'm ctx =
let new_var = Var.make (Bindlib.name_of var) in let new_var = Var.make (Bindlib.name_of var) in
let naked_expr = Expr.make_var (new_var, mark) in let expr = Expr.make_var (new_var, mark) in
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Print.var var Print.var (* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Print.var var Print.var
new_var; *) new_var; *)
@ -112,7 +111,7 @@ let add_var
ctx with ctx with
vars = vars =
Var.Map.update var Var.Map.update var
(fun _ -> Some { naked_expr; var = new_var; is_pure }) (fun _ -> Some { expr; var = new_var; is_pure })
ctx.vars; ctx.vars;
} }
@ -174,7 +173,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
created a variable %a to replace it" Print.var v Print.var v'; *) created a variable %a to replace it" Print.var v Print.var v'; *)
Expr.make_var (v', mark), Var.Map.singleton v' e Expr.make_var (v', mark), Var.Map.singleton v' e
else (find ~info:"should never happen" v ctx).naked_expr, Var.Map.empty else (find ~info:"should never happen" v ctx).expr, Var.Map.empty
| EApp ((EVar v, p), [(ELit LUnit, _)]) -> | EApp ((EVar v, p), [(ELit LUnit, _)]) ->
if not (find ~info:"search for a variable" v ctx).is_pure then if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = Var.make (Bindlib.name_of v) in let v' = Var.make (Bindlib.name_of v) in
@ -309,7 +308,7 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
match hoist with match hoist with
(* Here we have to handle only the cases appearing in hoists, as defined (* Here we have to handle only the cases appearing in hoists, as defined
the [translate_and_hoist] function. *) the [translate_and_hoist] function. *)
| EVar v -> (find ~info:"should never happen" v ctx).naked_expr | EVar v -> (find ~info:"should never happen" v ctx).expr
| EDefault (excep, just, cons) -> | EDefault (excep, just, cons) ->
let excep' = List.map (translate_expr ctx) excep in let excep' = List.map (translate_expr ctx) excep in
let just' = translate_expr ctx just in let just' = translate_expr ctx just in
@ -376,12 +375,12 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
} -> } ->
(* special case : the subscope variable is thunked (context i/o). We remove (* special case : the subscope variable is thunked (context i/o). We remove
this thunking. *) this thunking. *)
let _, naked_expr = Bindlib.unmbind binder in let _, expr = Bindlib.unmbind binder in
let var_is_pure = true in let var_is_pure = true in
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Print.var var; *) (* Cli.debug_print @@ Format.asprintf "unbinding %a" Print.var var; *)
let vmark = Expr.map_mark (fun _ -> pos) (fun _ -> typ) emark in let vmark = Expr.with_ty emark ~pos typ in
let ctx' = add_var vmark var var_is_pure ctx in let ctx' = add_var vmark var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var = (find ~info:"variable that was just created" var ctx').var in
let new_next = translate_scope_let ctx' next in let new_next = translate_scope_let ctx' next in
@ -395,13 +394,13 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
scope_let_next = new_next; scope_let_next = new_next;
scope_let_pos = pos; scope_let_pos = pos;
}) })
(translate_expr ctx ~append_esome:false naked_expr) (translate_expr ctx ~append_esome:false expr)
(Bindlib.bind_var new_var new_next) (Bindlib.bind_var new_var new_next)
| ScopeLet | ScopeLet
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ; scope_let_typ = typ;
scope_let_expr = (ErrorOnEmpty _, emark) as naked_expr; scope_let_expr = (ErrorOnEmpty _, emark) as expr;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
@ -409,7 +408,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
let var_is_pure = true in let var_is_pure = true in
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Print.var var; *) (* Cli.debug_print @@ Format.asprintf "unbinding %a" Print.var var; *)
let vmark = Expr.map_mark (fun _ -> pos) (fun _ -> typ) emark in let vmark = Expr.with_ty emark ~pos typ in
let ctx' = add_var vmark var var_is_pure ctx in let ctx' = add_var vmark var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var = (find ~info:"variable that was just created" var ctx').var in
Bindlib.box_apply2 Bindlib.box_apply2
@ -422,25 +421,25 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
scope_let_next = new_next; scope_let_next = new_next;
scope_let_pos = pos; scope_let_pos = pos;
}) })
(translate_expr ctx ~append_esome:false naked_expr) (translate_expr ctx ~append_esome:false expr)
(Bindlib.bind_var new_var (translate_scope_let ctx' next)) (Bindlib.bind_var new_var (translate_scope_let ctx' next))
| ScopeLet | ScopeLet
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_pos = pos; scope_let_pos = pos;
scope_let_expr = naked_expr; scope_let_expr = expr;
_; _;
} -> } ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: found an SubScopeVarDefinition that does not satisfy \ "Internal Error: found an SubScopeVarDefinition that does not satisfy \
the invariants when translating Dcalc to Lcalc without exceptions: \ the invariants when translating Dcalc to Lcalc without exceptions: \
@[<hov 2>%a@]" @[<hov 2>%a@]"
(Expr.format ctx.decl_ctx) naked_expr (Expr.format ctx.decl_ctx) expr
| ScopeLet | ScopeLet
{ {
scope_let_kind = kind; scope_let_kind = kind;
scope_let_typ = typ; scope_let_typ = typ;
scope_let_expr = naked_expr; scope_let_expr = expr;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
@ -460,9 +459,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
in in
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Print.var var; *) (* Cli.debug_print @@ Format.asprintf "unbinding %a" Print.var var; *)
let vmark = let vmark = Expr.with_ty (Marked.get_mark expr) ~pos typ in
Expr.map_mark (fun _ -> pos) (fun _ -> typ) (Marked.get_mark naked_expr)
in
let ctx' = add_var vmark var var_is_pure ctx in let ctx' = add_var vmark var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var = (find ~info:"variable that was just created" var ctx').var in
Bindlib.box_apply2 Bindlib.box_apply2
@ -475,7 +472,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
scope_let_next = new_next; scope_let_next = new_next;
scope_let_pos = pos; scope_let_pos = pos;
}) })
(translate_expr ctx ~append_esome:false naked_expr) (translate_expr ctx ~append_esome:false expr)
(Bindlib.bind_var new_var (translate_scope_let ctx' next)) (Bindlib.bind_var new_var (translate_scope_let ctx' next))
let translate_scope_body let translate_scope_body

View File

@ -73,14 +73,7 @@ let merge_defaults
let m = Marked.get_mark (Bindlib.unbox caller) in let m = Marked.get_mark (Bindlib.unbox caller) in
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
Expr.make_app caller Expr.make_app caller
[ [Bindlib.box (ELit LUnit, Expr.with_ty m (Marked.mark pos (TLit TUnit)))]
Bindlib.box
( ELit LUnit,
Expr.map_mark
(fun _ -> pos)
(fun _ -> Marked.mark pos (TLit TUnit))
m );
]
pos pos
in in
let body = let body =