mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Fixed functions
This commit is contained in:
parent
1030e4bc8d
commit
0443221e8b
@ -14,6 +14,7 @@
|
||||
|
||||
module Pos = Utils.Pos
|
||||
module Errors = Utils.Errors
|
||||
module Cli = Utils.Cli
|
||||
|
||||
(** The optional argument subdef allows to choose between differents uids in case the expression is
|
||||
a redefinition of a subvariable *)
|
||||
@ -37,22 +38,26 @@ let translate_unop (op : Ast.unop) : Dcalc.Ast.unop = match op with Not -> Not |
|
||||
|
||||
let rec translate_expr (scope : Scopelang.Ast.ScopeName.t)
|
||||
(def_key : Desugared.Ast.ScopeDef.t option) (ctxt : Name_resolution.context)
|
||||
((expr, pos) : Ast.expression Pos.marked) : Scopelang.Ast.expr Pos.marked =
|
||||
((expr, pos) : Ast.expression Pos.marked) : Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in
|
||||
let rec_helper = translate_expr scope def_key ctxt in
|
||||
match expr with
|
||||
| IfThenElse (e_if, e_then, e_else) ->
|
||||
(EIfThenElse (rec_helper e_if, rec_helper e_then, rec_helper e_else), pos)
|
||||
Bindlib.box_apply3
|
||||
(fun e_if e_then e_else -> (Scopelang.Ast.EIfThenElse (e_if, e_then, e_else), pos))
|
||||
(rec_helper e_if) (rec_helper e_then) (rec_helper e_else)
|
||||
| Binop (op, e1, e2) ->
|
||||
let op_term =
|
||||
Pos.same_pos_as (Scopelang.Ast.EOp (Dcalc.Ast.Binop (translate_binop (Pos.unmark op)))) op
|
||||
in
|
||||
(EApp (op_term, [ rec_helper e1; rec_helper e2 ]), pos)
|
||||
Bindlib.box_apply2
|
||||
(fun e1 e2 -> (Scopelang.Ast.EApp (op_term, [ e1; e2 ]), pos))
|
||||
(rec_helper e1) (rec_helper e2)
|
||||
| Unop (op, e) ->
|
||||
let op_term =
|
||||
Pos.same_pos_as (Scopelang.Ast.EOp (Dcalc.Ast.Unop (translate_unop (Pos.unmark op)))) op
|
||||
in
|
||||
(EApp (op_term, [ rec_helper e ]), pos)
|
||||
Bindlib.box_apply (fun e -> (Scopelang.Ast.EApp (op_term, [ e ]), pos)) (rec_helper e)
|
||||
| Literal l ->
|
||||
let untyped_term =
|
||||
match l with
|
||||
@ -61,7 +66,7 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t)
|
||||
| Bool b -> Scopelang.Ast.ELit (Dcalc.Ast.LBool b)
|
||||
| _ -> Name_resolution.raise_unsupported_feature "literal" pos
|
||||
in
|
||||
(untyped_term, pos)
|
||||
Bindlib.box (untyped_term, pos)
|
||||
| Ident x -> (
|
||||
(* first we check whether this is a local var, then we resort to scope-wide variables *)
|
||||
match def_key with
|
||||
@ -70,14 +75,15 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t)
|
||||
match Desugared.Ast.IdentMap.find_opt x def_ctxt.var_idmap with
|
||||
| None -> (
|
||||
match Desugared.Ast.IdentMap.find_opt x scope_ctxt.var_idmap with
|
||||
| Some uid -> (Scopelang.Ast.ELocation (ScopeVar (uid, pos)), pos)
|
||||
| Some uid -> Bindlib.box (Scopelang.Ast.ELocation (ScopeVar (uid, pos)), pos)
|
||||
| None ->
|
||||
Name_resolution.raise_unknown_identifier "for a\n local or scope-wide variable"
|
||||
(x, pos) )
|
||||
| Some uid -> (Scopelang.Ast.EVar uid, pos) )
|
||||
| Some uid -> Bindlib.box_var uid
|
||||
(* the whole box thing is to accomodate for this case *) )
|
||||
| None -> (
|
||||
match Desugared.Ast.IdentMap.find_opt x scope_ctxt.var_idmap with
|
||||
| Some uid -> (Scopelang.Ast.ELocation (ScopeVar (uid, pos)), pos)
|
||||
| Some uid -> Bindlib.box (Scopelang.Ast.ELocation (ScopeVar (uid, pos)), pos)
|
||||
| None -> Name_resolution.raise_unknown_identifier "for a scope-wide variable" (x, pos) )
|
||||
)
|
||||
| Dotted (e, x) -> (
|
||||
@ -91,31 +97,41 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t)
|
||||
Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes
|
||||
in
|
||||
let subscope_var_uid = Name_resolution.get_var_uid subscope_real_uid ctxt x in
|
||||
( Scopelang.Ast.ELocation
|
||||
(SubScopeVar (subscope_real_uid, (subscope_uid, pos), (subscope_var_uid, pos))),
|
||||
pos )
|
||||
Bindlib.box
|
||||
( Scopelang.Ast.ELocation
|
||||
(SubScopeVar (subscope_real_uid, (subscope_uid, pos), (subscope_var_uid, pos))),
|
||||
pos )
|
||||
| _ ->
|
||||
Name_resolution.raise_unsupported_feature
|
||||
"left hand side of a dotted expression should be an\n\n identifier" pos )
|
||||
| FunCall (f, arg) -> (EApp (rec_helper f, [ rec_helper arg ]), pos)
|
||||
| FunCall (f, arg) ->
|
||||
Bindlib.box_apply2
|
||||
(fun f arg -> (Scopelang.Ast.EApp (f, [ arg ]), pos))
|
||||
(rec_helper f) (rec_helper arg)
|
||||
| _ -> Name_resolution.raise_unsupported_feature "unsupported expression" pos
|
||||
|
||||
(* Translation from the parsed ast to the scope language *)
|
||||
|
||||
let merge_conditions (precond : Scopelang.Ast.expr Pos.marked option)
|
||||
(cond : Scopelang.Ast.expr Pos.marked option) (default_pos : Pos.t) :
|
||||
Scopelang.Ast.expr Pos.marked =
|
||||
let merge_conditions (precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
|
||||
(cond : Scopelang.Ast.expr Pos.marked Bindlib.box option) (default_pos : Pos.t) :
|
||||
Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
match (precond, cond) with
|
||||
| Some precond, Some cond ->
|
||||
let op_term = (Scopelang.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.And), Pos.get_position precond) in
|
||||
(Scopelang.Ast.EApp (op_term, [ precond; cond ]), Pos.get_position precond)
|
||||
let op_term =
|
||||
(Scopelang.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.And), Pos.get_position (Bindlib.unbox precond))
|
||||
in
|
||||
Bindlib.box_apply2
|
||||
(fun precond cond ->
|
||||
(Scopelang.Ast.EApp (op_term, [ precond; cond ]), Pos.get_position precond))
|
||||
precond cond
|
||||
| Some cond, None | None, Some cond -> cond
|
||||
| None, None -> (Scopelang.Ast.ELit (Dcalc.Ast.LBool true), default_pos)
|
||||
| None, None -> Bindlib.box (Scopelang.Ast.ELit (Dcalc.Ast.LBool true), default_pos)
|
||||
|
||||
let process_default (ctxt : Name_resolution.context) (scope : Scopelang.Ast.ScopeName.t)
|
||||
(def_key : Desugared.Ast.ScopeDef.t) (param_uid : Scopelang.Ast.Var.t option)
|
||||
(precond : Scopelang.Ast.expr Pos.marked option) (just : Ast.expression Pos.marked option)
|
||||
(cons : Ast.expression Pos.marked) : Desugared.Ast.rule =
|
||||
(precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
|
||||
(just : Ast.expression Pos.marked option) (cons : Ast.expression Pos.marked) :
|
||||
Desugared.Ast.rule =
|
||||
let just =
|
||||
match just with
|
||||
| Some just -> Some (translate_expr scope (Some def_key) ctxt just)
|
||||
@ -133,30 +149,55 @@ let process_default (ctxt : Name_resolution.context) (scope : Scopelang.Ast.Scop
|
||||
| Dcalc.Ast.TArrow _, None ->
|
||||
Errors.raise_spanned_error
|
||||
"this definition has a function type but the parameter is missing"
|
||||
(Pos.get_position cons)
|
||||
(Pos.get_position (Bindlib.unbox cons))
|
||||
| _, Some _ ->
|
||||
Errors.raise_spanned_error
|
||||
"this definition has a parameter but its type is not a function"
|
||||
(Pos.get_position cons)
|
||||
(Pos.get_position (Bindlib.unbox cons))
|
||||
| _ -> None);
|
||||
parent_rule =
|
||||
None (* for now we don't have a priority mechanism in the syntax but it will happen soon *);
|
||||
}
|
||||
|
||||
let add_var_to_def_idmap (ctxt : Name_resolution.context) (scope_uid : Scopelang.Ast.ScopeName.t)
|
||||
(def_key : Desugared.Ast.ScopeDef.t) (name : string Pos.marked) (var : Scopelang.Ast.Var.t) :
|
||||
Name_resolution.context =
|
||||
{
|
||||
ctxt with
|
||||
scopes =
|
||||
Scopelang.Ast.ScopeMap.update scope_uid
|
||||
(fun scope_ctxt ->
|
||||
match scope_ctxt with
|
||||
| Some scope_ctxt ->
|
||||
Some
|
||||
{
|
||||
scope_ctxt with
|
||||
Name_resolution.definitions =
|
||||
Desugared.Ast.ScopeDefMap.update def_key
|
||||
(fun def_ctxt ->
|
||||
match def_ctxt with
|
||||
| None -> assert false (* should not happen *)
|
||||
| Some (def_ctxt : Name_resolution.def_context) ->
|
||||
Some
|
||||
{
|
||||
Name_resolution.var_idmap =
|
||||
Desugared.Ast.IdentMap.add (Pos.unmark name) var
|
||||
def_ctxt.Name_resolution.var_idmap;
|
||||
})
|
||||
scope_ctxt.Name_resolution.definitions;
|
||||
}
|
||||
| None -> assert false
|
||||
(* should not happen *))
|
||||
ctxt.scopes;
|
||||
}
|
||||
|
||||
(* Process a definition *)
|
||||
let process_def (precond : Scopelang.Ast.expr Pos.marked option)
|
||||
let process_def (precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
|
||||
(scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : Name_resolution.context)
|
||||
(prgm : Desugared.Ast.program) (def : Ast.definition) : Desugared.Ast.program =
|
||||
let scope : Desugared.Ast.scope = Scopelang.Ast.ScopeMap.find scope_uid prgm in
|
||||
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
|
||||
let default_pos = Pos.get_position def.definition_expr in
|
||||
let param_uid (def_uid : Desugared.Ast.ScopeDef.t) : Scopelang.Ast.Var.t option =
|
||||
match def.definition_parameter with
|
||||
| None -> None
|
||||
| Some param ->
|
||||
let def_ctxt = Desugared.Ast.ScopeDefMap.find def_uid scope_ctxt.definitions in
|
||||
Some (Desugared.Ast.IdentMap.find (Pos.unmark param) def_ctxt.var_idmap)
|
||||
in
|
||||
let def_key =
|
||||
match Pos.unmark def.definition_name with
|
||||
| [ x ] ->
|
||||
@ -173,6 +214,14 @@ let process_def (precond : Scopelang.Ast.expr Pos.marked option)
|
||||
Desugared.Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid)
|
||||
| _ -> Errors.raise_spanned_error "Structs are not handled yet" default_pos
|
||||
in
|
||||
(* We add to the name resolution context the name of the parameter variable *)
|
||||
let param_uid, new_ctxt =
|
||||
match def.definition_parameter with
|
||||
| None -> (None, ctxt)
|
||||
| Some param ->
|
||||
let param_var = Scopelang.Ast.Var.make param in
|
||||
(Some param_var, add_var_to_def_idmap ctxt scope_uid def_key param param_var)
|
||||
in
|
||||
let scope_updated =
|
||||
let x_def, x_type =
|
||||
match Desugared.Ast.ScopeDefMap.find_opt def_key scope.scope_defs with
|
||||
@ -187,7 +236,7 @@ let process_def (precond : Scopelang.Ast.expr Pos.marked option)
|
||||
in
|
||||
let x_def =
|
||||
Desugared.Ast.RuleMap.add rule_name
|
||||
(process_default ctxt scope_uid def_key (param_uid def_key) precond def.definition_condition
|
||||
(process_default new_ctxt scope_uid def_key param_uid precond def.definition_condition
|
||||
def.definition_expr)
|
||||
x_def
|
||||
in
|
||||
@ -199,7 +248,7 @@ let process_def (precond : Scopelang.Ast.expr Pos.marked option)
|
||||
Scopelang.Ast.ScopeMap.add scope_uid scope_updated prgm
|
||||
|
||||
(** Process a rule from the surface language *)
|
||||
let process_rule (precond : Scopelang.Ast.expr Pos.marked option)
|
||||
let process_rule (precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
|
||||
(scope : Scopelang.Ast.ScopeName.t) (ctxt : Name_resolution.context)
|
||||
(prgm : Desugared.Ast.program) (rule : Ast.rule) : Desugared.Ast.program =
|
||||
let consequence_expr = Ast.Literal (Ast.Bool (Pos.unmark rule.rule_consequence)) in
|
||||
|
@ -217,16 +217,7 @@ let process_scope_use (ctxt : context) (use : Ast.scope_use) : context =
|
||||
| Ast.Definition def ->
|
||||
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
|
||||
let def_uid = qident_to_scope_def ctxt scope_uid def.definition_name in
|
||||
let def_ctxt =
|
||||
{
|
||||
var_idmap =
|
||||
( match def.definition_parameter with
|
||||
| None -> Desugared.Ast.IdentMap.empty
|
||||
| Some param ->
|
||||
Desugared.Ast.IdentMap.singleton (Pos.unmark param)
|
||||
(Scopelang.Ast.Var.make param) );
|
||||
}
|
||||
in
|
||||
let def_ctxt = { var_idmap = Desugared.Ast.IdentMap.empty } in
|
||||
let scope_ctxt =
|
||||
{
|
||||
scope_ctxt with
|
||||
|
@ -31,7 +31,8 @@ let rec format_typ (fmt : Format.formatter) (typ : typ Pos.marked) : unit =
|
||||
Format.fprintf fmt "(%a)"
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " *@ ") format_typ)
|
||||
ts
|
||||
| TArrow (t1, t2) -> Format.fprintf fmt "%a →@ %a" format_typ_with_parens t1 format_typ t2
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1 format_typ t2
|
||||
|
||||
let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
|
||||
match Pos.unmark l with
|
||||
|
@ -37,9 +37,9 @@ let rec format_typ (fmt : Format.formatter) (ty : typ Pos.marked UnionFind.elem)
|
||||
| TAny -> Format.fprintf fmt "α"
|
||||
| TTuple ts ->
|
||||
Format.fprintf fmt "(%a)"
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " *@ ") format_typ)
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " * ") format_typ)
|
||||
ts
|
||||
| TArrow (t1, t2) -> Format.fprintf fmt "%a →@ %a" format_typ t1 format_typ t2
|
||||
| TArrow (t1, t2) -> Format.fprintf fmt "%a → %a" format_typ t1 format_typ t2
|
||||
|
||||
let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFind.elem) : unit =
|
||||
(* Cli.debug_print (Format.asprintf "Unifying %a and %a" format_typ t1 format_typ t2); *)
|
||||
|
@ -53,16 +53,16 @@ module ScopeDefSet = Set.Make (ScopeDef)
|
||||
(* Scopes *)
|
||||
|
||||
type rule = {
|
||||
just : Scopelang.Ast.expr Pos.marked;
|
||||
cons : Scopelang.Ast.expr Pos.marked;
|
||||
just : Scopelang.Ast.expr Pos.marked Bindlib.box;
|
||||
cons : Scopelang.Ast.expr Pos.marked Bindlib.box;
|
||||
parameter : (Scopelang.Ast.Var.t * Dcalc.Ast.typ Pos.marked) option;
|
||||
parent_rule : RuleName.t option;
|
||||
}
|
||||
|
||||
let empty_rule (pos : Pos.t) (have_parameter : Dcalc.Ast.typ Pos.marked option) : rule =
|
||||
{
|
||||
just = (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), pos);
|
||||
cons = (Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, pos);
|
||||
just = Bindlib.box (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), pos);
|
||||
cons = Bindlib.box (Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, pos);
|
||||
parameter =
|
||||
( match have_parameter with
|
||||
| Some typ -> Some (Scopelang.Ast.Var.make ("dummy", pos), typ)
|
||||
@ -117,6 +117,9 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
|
||||
in
|
||||
RuleMap.fold
|
||||
(fun _ rule acc ->
|
||||
let locs = Scopelang.Ast.locations_used rule.just @ Scopelang.Ast.locations_used rule.cons in
|
||||
let locs =
|
||||
Scopelang.Ast.locations_used (Bindlib.unbox rule.just)
|
||||
@ Scopelang.Ast.locations_used (Bindlib.unbox rule.cons)
|
||||
in
|
||||
add_locs acc locs)
|
||||
def ScopeDefMap.empty
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
module Pos = Utils.Pos
|
||||
module Errors = Utils.Errors
|
||||
module Cli = Utils.Cli
|
||||
|
||||
type rule_tree = Leaf of Ast.rule | Node of Ast.rule * rule_tree list
|
||||
|
||||
@ -46,31 +47,37 @@ let rec def_map_to_tree (def : Ast.rule Ast.RuleMap.t) : rule_tree =
|
||||
in
|
||||
Node (no_parent, tree_children)
|
||||
|
||||
let rec rule_tree_to_expr (is_func : Scopelang.Ast.Var.t option) (tree : rule_tree) :
|
||||
Scopelang.Ast.expr Pos.marked =
|
||||
let rec rule_tree_to_expr ~(toplevel : bool) (is_func : Scopelang.Ast.Var.t option)
|
||||
(tree : rule_tree) : Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
let rule, children = match tree with Leaf r -> (r, []) | Node (r, child) -> (r, child) in
|
||||
(* because each rule has its own variable parameter and we want to convert the whole rule tree
|
||||
into a function, we need to perform some alpha-renaming of all the expressions *)
|
||||
let substitute_parameter (e : Scopelang.Ast.expr Pos.marked) : Scopelang.Ast.expr Pos.marked =
|
||||
let substitute_parameter (e : Scopelang.Ast.expr Pos.marked Bindlib.box) :
|
||||
Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
match (is_func, rule.parameter) with
|
||||
| Some new_param, Some (old_param, _) ->
|
||||
let binder = Bindlib.bind_var old_param (Bindlib.box e) in
|
||||
Bindlib.subst (Bindlib.unbox binder) (Scopelang.Ast.EVar new_param, Pos.no_pos)
|
||||
let binder = Bindlib.bind_var old_param e in
|
||||
Bindlib.box_apply2
|
||||
(fun binder new_param -> Bindlib.subst binder new_param)
|
||||
binder (Bindlib.box_var new_param)
|
||||
| None, None -> e
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
in
|
||||
let just = substitute_parameter rule.Ast.just in
|
||||
let cons = substitute_parameter rule.Ast.cons in
|
||||
let children = List.map (rule_tree_to_expr is_func) children in
|
||||
let default = (Scopelang.Ast.EDefault (just, cons, children), Pos.no_pos) in
|
||||
let children = Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false is_func) children) in
|
||||
let default =
|
||||
Bindlib.box_apply3
|
||||
(fun just cons children -> (Scopelang.Ast.EDefault (just, cons, children), Pos.no_pos))
|
||||
just cons children
|
||||
in
|
||||
match (is_func, rule.parameter) with
|
||||
| None, None -> default
|
||||
| Some new_param, Some (_, typ) ->
|
||||
Bindlib.unbox
|
||||
(Scopelang.Ast.make_abs
|
||||
(Array.of_list [ new_param ])
|
||||
(Bindlib.box default) Pos.no_pos [ typ ] Pos.no_pos)
|
||||
if toplevel then
|
||||
Scopelang.Ast.make_abs (Array.of_list [ new_param ]) default Pos.no_pos [ typ ] Pos.no_pos
|
||||
else default
|
||||
| _ -> assert false
|
||||
|
||||
(* should not happen *)
|
||||
@ -97,7 +104,7 @@ let translate_def (def : Ast.rule Ast.RuleMap.t) : Scopelang.Ast.expr Pos.marked
|
||||
( Some
|
||||
(Format.asprintf "The type of the parameter of this expression is %a"
|
||||
Dcalc.Print.format_typ typ),
|
||||
Pos.get_position r.Ast.cons ))
|
||||
Pos.get_position (Bindlib.unbox r.Ast.cons) ))
|
||||
(Ast.RuleMap.bindings (Ast.RuleMap.filter (fun n r -> not (is_typ n r)) def)))
|
||||
| None -> assert false (* should not happen *)
|
||||
else if all_rules_not_func then None
|
||||
@ -105,10 +112,13 @@ let translate_def (def : Ast.rule Ast.RuleMap.t) : Scopelang.Ast.expr Pos.marked
|
||||
Errors.raise_multispanned_error
|
||||
"some definitions of the same variable are functions while others aren't"
|
||||
( List.map
|
||||
(fun (_, r) -> (Some "This definition is a function:", Pos.get_position r.Ast.cons))
|
||||
(fun (_, r) ->
|
||||
(Some "This definition is a function:", Pos.get_position (Bindlib.unbox r.Ast.cons)))
|
||||
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_func def))
|
||||
@ List.map
|
||||
(fun (_, r) -> (Some "This definition is not a function:", Pos.get_position r.Ast.cons))
|
||||
(fun (_, r) ->
|
||||
( Some "This definition is not a function:",
|
||||
Pos.get_position (Bindlib.unbox r.Ast.cons) ))
|
||||
(Ast.RuleMap.bindings (Ast.RuleMap.filter (fun n r -> not (is_func n r)) def)) )
|
||||
in
|
||||
let dummy_rule = Ast.empty_rule Pos.no_pos is_def_func in
|
||||
@ -123,9 +133,10 @@ let translate_def (def : Ast.rule Ast.RuleMap.t) : Scopelang.Ast.expr Pos.marked
|
||||
def)
|
||||
in
|
||||
let def_tree = def_map_to_tree def in
|
||||
rule_tree_to_expr
|
||||
(Option.map (fun _ -> Scopelang.Ast.Var.make ("param", Pos.no_pos)) is_def_func)
|
||||
def_tree
|
||||
Bindlib.unbox
|
||||
(rule_tree_to_expr ~toplevel:true
|
||||
(Option.map (fun _ -> Scopelang.Ast.Var.make ("ρ", Pos.no_pos)) is_def_func)
|
||||
def_tree)
|
||||
|
||||
let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
let scope_dependencies = Dependency.build_scope_dependencies scope in
|
||||
|
81
src/catala/scope_language/print.ml
Normal file
81
src/catala/scope_language/print.ml
Normal file
@ -0,0 +1,81 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax and social benefits
|
||||
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux
|
||||
<denis.merigoux@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
in compliance with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
or implied. See the License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
module Pos = Utils.Pos
|
||||
open Ast
|
||||
|
||||
let needs_parens (e : expr Pos.marked) : bool =
|
||||
match Pos.unmark e with EAbs _ -> true | _ -> false
|
||||
|
||||
let format_var (fmt : Format.formatter) (v : Var.t) : unit =
|
||||
Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v)
|
||||
|
||||
let format_location (fmt : Format.formatter) (l : location) : unit =
|
||||
match l with
|
||||
| ScopeVar v -> Format.fprintf fmt "%a" ScopeVar.format_t (Pos.unmark v)
|
||||
| SubScopeVar (_, subindex, subvar) ->
|
||||
Format.fprintf fmt "%a.%a" SubScopeName.format_t (Pos.unmark subindex) ScopeVar.format_t
|
||||
(Pos.unmark subvar)
|
||||
|
||||
let rec format_expr (fmt : Format.formatter) (e : expr Pos.marked) : unit =
|
||||
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
|
||||
if needs_parens e then Format.fprintf fmt "(%a)" format_expr e
|
||||
else Format.fprintf fmt "%a" format_expr e
|
||||
in
|
||||
match Pos.unmark e with
|
||||
| ELocation l -> Format.fprintf fmt "%a" format_location l
|
||||
| EVar v -> Format.fprintf fmt "%a" format_var v
|
||||
| ELit l -> Format.fprintf fmt "%a" Dcalc.Print.format_lit (Pos.same_pos_as l e)
|
||||
| EApp ((EAbs (_, binder, taus), _), args) ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
|
||||
let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in
|
||||
Format.fprintf fmt "@[%a%a@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
|
||||
(fun fmt (x, tau, arg) ->
|
||||
Format.fprintf fmt "@[@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n@]" format_var x
|
||||
Dcalc.Print.format_typ tau format_expr arg))
|
||||
xs_tau_arg format_expr body
|
||||
| EAbs (_, binder, taus) ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
|
||||
Format.fprintf fmt "@[<hov 2>λ@ %a@ →@ %a@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
|
||||
(fun fmt (x, tau) ->
|
||||
Format.fprintf fmt "@[(%a:@ %a)@]" format_var x Dcalc.Print.format_typ tau))
|
||||
xs_tau format_expr body
|
||||
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
|
||||
Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop
|
||||
(op, Pos.no_pos) format_with_parens arg2
|
||||
| EApp ((EOp (Unop op), _), [ arg1 ]) ->
|
||||
Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos) format_with_parens
|
||||
arg1
|
||||
| EApp (f, args) ->
|
||||
Format.fprintf fmt "@[%a@ %a@]" format_expr f
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens)
|
||||
args
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Format.fprintf fmt "if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]" format_expr
|
||||
e1 format_expr e2 format_expr e3
|
||||
| EOp (Binop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos)
|
||||
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos)
|
||||
| EDefault (just, cons, subs) ->
|
||||
if List.length subs = 0 then
|
||||
Format.fprintf fmt "@[⟨%a ⊢ %a⟩@]" format_expr just format_expr cons
|
||||
else
|
||||
Format.fprintf fmt "@[<hov 2>⟨%a ⊢ %a |@ %a⟩@]" format_expr just format_expr cons
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") format_expr)
|
||||
subs
|
@ -56,6 +56,7 @@ let merge_defaults (caller : Dcalc.Ast.expr Pos.marked Bindlib.box)
|
||||
|
||||
let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Pos.marked Bindlib.box
|
||||
=
|
||||
(* Cli.debug_print (Format.asprintf "Translating: %a" Print.format_expr e); *)
|
||||
Bindlib.box_apply
|
||||
(fun (x : Dcalc.Ast.expr) -> Pos.same_pos_as x e)
|
||||
( match Pos.unmark e with
|
||||
|
@ -7,8 +7,8 @@ new scope S:
|
||||
param b type bool
|
||||
|
||||
scope S:
|
||||
def f of x [ (x >= x) ] := x + x
|
||||
def f of x [ not b ] := x * x
|
||||
def f of x1 [ (x1 >= x1) ] := x1 + x1
|
||||
def f of x2 [ not b ] := x2 * x2
|
||||
|
||||
def b := false
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user