diff --git a/src/catala/default_calculus/ast.ml b/src/catala/default_calculus/ast.ml index a0cd9dc6..abd7e06f 100644 --- a/src/catala/default_calculus/ast.ml +++ b/src/catala/default_calculus/ast.ml @@ -34,9 +34,22 @@ type expr = module Var = struct type t = expr Pos.marked Bindlib.var + let make (s : string Pos.marked) = + Bindlib.new_var (fun x -> (EVar x, Pos.get_position s)) (Pos.unmark s) + let compare x y = Bindlib.compare_vars x y end module VarMap = Map.Make (Var) +let make_var (x : Var.t) : expr Pos.marked Bindlib.box = Bindlib.box_var x + +let make_abs (x : Var.t) (e : expr Pos.marked Bindlib.box) (pos_binder : Pos.t) (tau : typ) + (pos : Pos.t) : expr Pos.marked Bindlib.box = + Bindlib.box_apply (fun b -> (EAbs (pos_binder, b, tau), pos)) (Bindlib.bind_var x e) + +let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box) (pos : Pos.t) : + expr Pos.marked Bindlib.box = + Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e u + type binder = (expr, expr Pos.marked) Bindlib.binder diff --git a/src/catala/scope_language/ast.ml b/src/catala/scope_language/ast.ml index 7e8c7438..903afce8 100644 --- a/src/catala/scope_language/ast.ml +++ b/src/catala/scope_language/ast.ml @@ -24,7 +24,9 @@ module ScopeVar = Uid.Make (Uid.MarkedString) module ScopeVarSet = Set.Make (ScopeVar) module ScopeVarMap = Map.Make (ScopeVar) -type location = ScopeVar of ScopeVar.t | SubScopeVar of ScopeName.t * SubScopeName.t * ScopeVar.t +type location = + | ScopeVar of ScopeVar.t Pos.marked + | SubScopeVar of ScopeName.t * SubScopeName.t Pos.marked * ScopeVar.t Pos.marked type expr = | ELocation of location @@ -34,7 +36,20 @@ type expr = | EApp of expr Pos.marked * expr Pos.marked | EDefault of expr Pos.marked * expr Pos.marked * expr Pos.marked list -type rule = Definition of location * Dcalc.Ast.typ * expr | Call of ScopeName.t * SubScopeName.t +module Var = struct + type t = expr Pos.marked Bindlib.var + + let make (s : string Pos.marked) = + Bindlib.new_var (fun x -> (EVar x, Pos.get_position s)) (Pos.unmark s) + + let compare x y = Bindlib.compare_vars x y +end + +module VarMap = Map.Make (Var) + +type rule = + | Definition of location * Dcalc.Ast.typ * expr Pos.marked + | Call of ScopeName.t * SubScopeName.t type scope_decl = { scope_decl_name : ScopeName.t; scope_decl_rules : rule list } diff --git a/src/catala/scope_language/scope_to_dcalc.ml b/src/catala/scope_language/scope_to_dcalc.ml index 9fb8a0ff..9daaf970 100644 --- a/src/catala/scope_language/scope_to_dcalc.ml +++ b/src/catala/scope_language/scope_to_dcalc.ml @@ -13,20 +13,117 @@ the License. *) module Pos = Utils.Pos +module Errors = Utils.Errors -type ctx = Ast.location list +type ctx = { + scope_vars : (Dcalc.Ast.Var.t * Dcalc.Ast.typ) Ast.ScopeVarMap.t; + subscope_vars : (Dcalc.Ast.Var.t * Dcalc.Ast.typ) Ast.ScopeVarMap.t Ast.SubScopeMap.t; + local_vars : Dcalc.Ast.Var.t Ast.VarMap.t; +} + +let empty_ctx = + { + scope_vars = Ast.ScopeVarMap.empty; + subscope_vars = Ast.SubScopeMap.empty; + local_vars = Ast.VarMap.empty; + } type scope_ctx = Dcalc.Ast.Var.t Ast.ScopeMap.t -let hole_var : Dcalc.Ast.Var.t = Bindlib.new_var (fun x -> (Dcalc.Ast.EVar x, Pos.no_pos)) "hole" +let hole_var : Dcalc.Ast.Var.t = Dcalc.Ast.Var.make ("hole", Pos.no_pos) -let translate_rules (_p : scope_ctx) (_ctx : ctx) (_rules : Ast.rule list) : Dcalc.Ast.expr * ctx = - assert false +let merge_operator_var : Dcalc.Ast.Var.t = Dcalc.Ast.Var.make ("merge", Pos.no_pos) -let translate_scope_decl (p : scope_ctx) (sigma : Ast.scope_decl) : Dcalc.Ast.expr = - let ctx = [] in - let _defs, ctx = translate_rules p ctx sigma.scope_decl_rules in - let _scope_variables = - List.filter_map (fun l -> match l with Ast.ScopeVar v -> Some v | _ -> None) ctx +let merge_operator_expr : Dcalc.Ast.expr Pos.marked Bindlib.box = Bindlib.box_var merge_operator_var + +let hole_expr : Dcalc.Ast.expr Pos.marked Bindlib.box = Bindlib.box_var hole_var + +let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Pos.marked = + Pos.same_pos_as + ( match Pos.unmark e with + | EVar v -> Dcalc.Ast.EVar (Ast.VarMap.find v ctx.local_vars) + | ELit l -> Dcalc.Ast.ELit l + | EApp (e1, e2) -> Dcalc.Ast.EApp (translate_expr ctx e1, translate_expr ctx e2) + | EAbs (pos_binder, binder, typ) -> + let x, body = Bindlib.unbind binder in + let new_x = Dcalc.Ast.Var.make (Bindlib.name_of x, pos_binder) in + let body = + translate_expr { ctx with local_vars = Ast.VarMap.add x new_x ctx.local_vars } body + in + let binder = Bindlib.unbox (Bindlib.bind_var new_x (Bindlib.box body)) in + Dcalc.Ast.EAbs (pos_binder, binder, typ) + | EDefault (just, cons, subs) -> + Dcalc.Ast.EDefault + (translate_expr ctx just, translate_expr ctx cons, List.map (translate_expr ctx) subs) + | ELocation (ScopeVar a) -> + Dcalc.Ast.EVar (fst (Ast.ScopeVarMap.find (Pos.unmark a) ctx.scope_vars)) + | ELocation (SubScopeVar (_, s, a)) -> + Dcalc.Ast.EVar + (fst + (Ast.ScopeVarMap.find (Pos.unmark a) + (Ast.SubScopeMap.find (Pos.unmark s) ctx.subscope_vars))) ) + e + +let translate_rule (_p : scope_ctx) (ctx : ctx) (rule : Ast.rule) : + Dcalc.Ast.expr Pos.marked Bindlib.box * ctx = + match rule with + | Definition (ScopeVar a, tau, e) -> + let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in + let a_var = Dcalc.Ast.Var.make a_name in + let next_e = + Dcalc.Ast.make_abs a_var hole_expr (Pos.get_position a) tau (Pos.get_position e) + in + let silent1 = Dcalc.Ast.Var.make ("silent", Pos.get_position e) in + let silent2 = Dcalc.Ast.Var.make ("silent", Pos.get_position e) in + let wrapped_e = + Dcalc.Ast.make_abs silent1 + (Bindlib.box (translate_expr ctx e)) + (Pos.get_position e) Dcalc.Ast.TUnit (Pos.get_position e) + in + let a_expr = Dcalc.Ast.make_var a_var in + let merged_expr = Dcalc.Ast.make_app merge_operator_expr a_expr (Pos.get_position e) in + let merged_expr = Dcalc.Ast.make_app merged_expr wrapped_e (Pos.get_position e) in + let merged_thunked = + Dcalc.Ast.make_abs silent2 merged_expr (Pos.get_position e) Dcalc.Ast.TUnit + (Pos.get_position e) + in + let final_e = Dcalc.Ast.make_app merged_thunked next_e (Pos.get_position e) in + let new_ctx = + { ctx with scope_vars = Ast.ScopeVarMap.add (Pos.unmark a) (a_var, tau) ctx.scope_vars } + in + (final_e, new_ctx) + | Definition (SubScopeVar _, _tau, _e) -> + Errors.raise_error "translation of subscope vars definitions unimplemented" + | Call _ -> Errors.raise_error "translation of subscope calls unimplemented" + +let translate_rules (p : scope_ctx) (ctx : ctx) (rules : Ast.rule list) : + Dcalc.Ast.expr Pos.marked Bindlib.box * ctx = + let acc = hole_expr in + List.fold_left + (fun (acc, ctx) rule -> + let new_e, ctx = translate_rule p ctx rule in + let acc = Bindlib.unbox (Bindlib.bind_var hole_var acc) in + let new_acc = Bindlib.subst acc (Bindlib.unbox new_e) in + (Bindlib.box new_acc, ctx)) + (acc, ctx) rules + +let translate_scope_decl (p : scope_ctx) (sigma : Ast.scope_decl) : Dcalc.Ast.expr Pos.marked = + let ctx = empty_ctx in + let rules, ctx = translate_rules p ctx sigma.scope_decl_rules in + let scope_variables = Ast.ScopeVarMap.bindings ctx.scope_vars in + let pos_sigma = Pos.get_position (Ast.ScopeName.get_info sigma.scope_decl_name) in + let return_exp = + Dcalc.Ast.ETuple + (List.map (fun (_, (dcalc_var, _)) -> (Dcalc.Ast.EVar dcalc_var, pos_sigma)) scope_variables) in - assert false + let func_acc = rules in + let func_acc = + List.fold_right + (fun (_, (dcalc_var, tau)) func_acc -> + Dcalc.Ast.make_abs dcalc_var func_acc pos_sigma + (Dcalc.Ast.TArrow ((Dcalc.Ast.TUnit, pos_sigma), (tau, pos_sigma))) + pos_sigma) + scope_variables func_acc + in + let func_acc = Bindlib.unbox (Bindlib.bind_var hole_var func_acc) in + Bindlib.subst func_acc (return_exp, pos_sigma)