mirror of
https://github.com/CatalaLang/catala.git
synced 2024-09-19 16:28:12 +03:00
Desugared to scope complete but untested [skip ci]
This commit is contained in:
parent
cf8c6233d9
commit
5ef7e45e11
@ -157,16 +157,15 @@ let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) :
|
||||
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.Ast.EDefault} expression in the
|
||||
scope language. The [~toplevel] parameter is used to know when to place the toplevel binding in
|
||||
the case of functions. *)
|
||||
let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
(is_func : Scopelang.Ast.Var.t option) (tree : rule_tree) :
|
||||
Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t)
|
||||
(is_func : Ast.Var.t option) (tree : rule_tree) : Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
let exceptions, base_rules =
|
||||
match tree with Leaf r -> ([], r) | Node (exceptions, r) -> (exceptions, r)
|
||||
in
|
||||
(* because each rule has its own variable parameter and we want to convert the whole rule tree
|
||||
into a function, we need to perform some alpha-renaming of all the expressions *)
|
||||
let substitute_parameter (e : Scopelang.Ast.expr Pos.marked Bindlib.box) (rule : Ast.rule) :
|
||||
Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
let substitute_parameter (e : Ast.expr Pos.marked Bindlib.box) (rule : Ast.rule) :
|
||||
Ast.expr Pos.marked Bindlib.box =
|
||||
match (is_func, rule.Ast.rule_parameter) with
|
||||
| Some new_param, Some (old_param, _) ->
|
||||
let binder = Bindlib.bind_var old_param e in
|
||||
@ -177,12 +176,36 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
in
|
||||
let ctx =
|
||||
match is_func with
|
||||
| None -> ctx
|
||||
| Some new_param -> (
|
||||
match Ast.VarMap.find_opt new_param ctx.var_mapping with
|
||||
| None ->
|
||||
let new_param_scope = Scopelang.Ast.Var.make (Bindlib.name_of new_param, def_pos) in
|
||||
{ ctx with var_mapping = Ast.VarMap.add new_param new_param_scope ctx.var_mapping }
|
||||
| Some _ ->
|
||||
(* We only create a mapping if none exists because [rule_tree_to_expr] is called
|
||||
recursively on the exceptions of the tree and we don't want to create a new Scopelang
|
||||
variable for the parameter at each tree level. *)
|
||||
ctx)
|
||||
in
|
||||
let base_just_list =
|
||||
List.map (fun rule -> substitute_parameter rule.Ast.rule_just rule) base_rules
|
||||
in
|
||||
let base_cons_list =
|
||||
List.map (fun rule -> substitute_parameter rule.Ast.rule_cons rule) base_rules
|
||||
in
|
||||
let translate_and_unbox_list (list : Ast.expr Pos.marked Bindlib.box list) :
|
||||
Scopelang.Ast.expr Pos.marked Bindlib.box list =
|
||||
List.map
|
||||
(fun e ->
|
||||
(* There are two levels of boxing here, the outermost is introduced by the [translate_expr]
|
||||
function for which all of the bindings should have been closed by now, so we can safely
|
||||
unbox. *)
|
||||
Bindlib.unbox (Bindlib.box_apply (translate_expr ctx) e))
|
||||
list
|
||||
in
|
||||
let default_containing_base_cases =
|
||||
Bindlib.box_apply2
|
||||
(fun base_just_list base_cons_list ->
|
||||
@ -194,10 +217,11 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
(Scopelang.Ast.ELit (Dcalc.Ast.LBool false), def_pos),
|
||||
(Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, def_pos) ),
|
||||
def_pos ))
|
||||
(Bindlib.box_list base_just_list) (Bindlib.box_list base_cons_list)
|
||||
(Bindlib.box_list (translate_and_unbox_list base_just_list))
|
||||
(Bindlib.box_list (translate_and_unbox_list base_cons_list))
|
||||
in
|
||||
let exceptions =
|
||||
Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false def_pos is_func) exceptions)
|
||||
Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions)
|
||||
in
|
||||
let default =
|
||||
Bindlib.box_apply2
|
||||
@ -221,7 +245,9 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
(Scopelang.Ast.ErrorOnEmpty default, def_pos))
|
||||
default
|
||||
in
|
||||
Scopelang.Ast.make_abs (Array.of_list [ new_param ]) default def_pos [ typ ] def_pos
|
||||
Scopelang.Ast.make_abs
|
||||
(Array.of_list [ Ast.VarMap.find new_param ctx.var_mapping ])
|
||||
default def_pos [ typ ] def_pos
|
||||
else default
|
||||
| _ -> (* should not happen *) assert false
|
||||
|
||||
@ -229,7 +255,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
|
||||
(** Translates a definition inside a scope, the resulting expression should be an {!constructor:
|
||||
Dcalc.Ast.EDefault} *)
|
||||
let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
let translate_def (ctx : ctx) (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
(typ : Scopelang.Ast.typ Pos.marked) (io : Scopelang.Ast.io) ~(is_cond : bool)
|
||||
~(is_subscope_var : bool) : Scopelang.Ast.expr Pos.marked =
|
||||
(* Here, we have to transform this list of rules into a default tree. *)
|
||||
@ -292,9 +318,11 @@ let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
then (ELit LEmptyError, Pos.no_pos)
|
||||
else
|
||||
Bindlib.unbox
|
||||
(rule_tree_to_expr ~toplevel:true
|
||||
(rule_tree_to_expr ~toplevel:true ctx
|
||||
(Ast.ScopeDef.get_position def_info)
|
||||
(Option.map (fun _ -> Scopelang.Ast.Var.make ("param", Pos.no_pos)) is_def_func_param_typ)
|
||||
(Option.map
|
||||
(fun _ -> Ast.Var.make ("param", Ast.ScopeDef.get_position def_info))
|
||||
is_def_func_param_typ)
|
||||
(match top_list with
|
||||
| [] ->
|
||||
(* In this case, there are no rules to define the expression *)
|
||||
@ -302,7 +330,7 @@ let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
| _ -> Node (top_list, [ top_value ])))
|
||||
|
||||
(** Translates a scope *)
|
||||
let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
let scope_dependencies = Dependency.build_scope_dependencies scope in
|
||||
Dependency.check_for_cycle scope scope_dependencies;
|
||||
let scope_ordering = Dependency.correct_computation_ordering scope_dependencies in
|
||||
@ -311,8 +339,10 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
(List.map
|
||||
(fun vertex ->
|
||||
match vertex with
|
||||
| Dependency.Vertex.Var (var : Scopelang.Ast.ScopeVar.t) -> (
|
||||
let scope_def = Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs in
|
||||
| Dependency.Vertex.Var (var, state) -> (
|
||||
let scope_def =
|
||||
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, state)) scope.scope_defs
|
||||
in
|
||||
let var_def = scope_def.scope_def_rules in
|
||||
let var_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
@ -321,8 +351,7 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
(* If the variable is tagged as input, then it shall not be redefined. *)
|
||||
Errors.raise_multispanned_error
|
||||
"It is impossible to give a definition to a scope variable tagged as input."
|
||||
(( Some "Incriminated variable:",
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info var) )
|
||||
((Some "Incriminated variable:", Pos.get_position (Ast.ScopeVar.get_info var))
|
||||
:: List.map
|
||||
(fun (rule, _) ->
|
||||
( Some "Incriminated variable definition:",
|
||||
@ -331,14 +360,22 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
| OnlyInput -> [] (* we do not provide any definition for an input-only variable *)
|
||||
| _ ->
|
||||
let expr_def =
|
||||
translate_def (Ast.ScopeDef.Var var) var_def var_typ scope_def.Ast.scope_def_io
|
||||
~is_cond ~is_subscope_var:false
|
||||
translate_def ctx
|
||||
(Ast.ScopeDef.Var (var, state))
|
||||
var_def var_typ scope_def.Ast.scope_def_io ~is_cond ~is_subscope_var:false
|
||||
in
|
||||
let scope_var =
|
||||
match (Ast.ScopeVarMap.find var ctx.scope_var_mapping, state) with
|
||||
| WholeVar v, None -> v
|
||||
| States states, Some state -> List.assoc state states
|
||||
| _ -> failwith "should not happen"
|
||||
in
|
||||
[
|
||||
Scopelang.Ast.Definition
|
||||
( ( Scopelang.Ast.ScopeVar
|
||||
(var, Pos.get_position (Scopelang.Ast.ScopeVar.get_info var)),
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info var) ),
|
||||
( scope_var,
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info scope_var) ),
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info scope_var) ),
|
||||
var_typ,
|
||||
scope_def.Ast.scope_def_io,
|
||||
expr_def );
|
||||
@ -384,8 +421,7 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
tagged as input or context."
|
||||
((Some "Incriminated subscope:", Ast.ScopeDef.get_position def_key)
|
||||
:: ( Some "Incriminated variable:",
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info sub_scope_var)
|
||||
)
|
||||
Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) )
|
||||
:: List.map
|
||||
(fun (rule, _) ->
|
||||
( Some "Incriminated subscope variable definition:",
|
||||
@ -400,27 +436,31 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
[
|
||||
(Some "Incriminated subscope:", Ast.ScopeDef.get_position def_key);
|
||||
( Some "Incriminated variable:",
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info sub_scope_var)
|
||||
);
|
||||
Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) );
|
||||
]
|
||||
| _ -> ());
|
||||
(* Now that all is good, we can proceed with translating this redefinition
|
||||
to a proper Scopelang term. *)
|
||||
let expr_def =
|
||||
translate_def def_key def def_typ scope_def.Ast.scope_def_io ~is_cond
|
||||
translate_def ctx def_key def def_typ scope_def.Ast.scope_def_io ~is_cond
|
||||
~is_subscope_var:true
|
||||
in
|
||||
let subscop_real_name =
|
||||
Scopelang.Ast.SubScopeMap.find sub_scope_index scope.scope_sub_scopes
|
||||
in
|
||||
let var_pos =
|
||||
Pos.get_position (Scopelang.Ast.ScopeVar.get_info sub_scope_var)
|
||||
in
|
||||
let var_pos = Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) in
|
||||
Scopelang.Ast.Definition
|
||||
( ( Scopelang.Ast.SubScopeVar
|
||||
( subscop_real_name,
|
||||
(sub_scope_index, var_pos),
|
||||
(sub_scope_var, var_pos) ),
|
||||
match
|
||||
Ast.ScopeVarMap.find sub_scope_var ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar v -> (v, var_pos)
|
||||
| States states ->
|
||||
(* When defining a sub-scope variable, we always define its
|
||||
first state in the sub-scope. *)
|
||||
(snd (List.hd states), var_pos) ),
|
||||
var_pos ),
|
||||
def_typ,
|
||||
scope_def.Ast.scope_def_io,
|
||||
@ -433,17 +473,43 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
sub_scope_vars_redefs @ [ Scopelang.Ast.Call (sub_scope, sub_scope_index) ])
|
||||
scope_ordering)
|
||||
in
|
||||
(* Then, after having computed all the scopes variables, we add the assertions *)
|
||||
(* Then, after having computed all the scopes variables, we add the assertions. TODO: the
|
||||
assertions should be interleaved with the definitions! *)
|
||||
let scope_decl_rules =
|
||||
scope_decl_rules
|
||||
@ List.map (fun e -> Scopelang.Ast.Assertion (Bindlib.unbox e)) scope.Ast.scope_assertions
|
||||
@ List.map
|
||||
(fun e ->
|
||||
let scope_e = translate_expr ctx e in
|
||||
Bindlib.unbox (Bindlib.box_apply (fun scope_e -> Scopelang.Ast.Assertion scope_e) scope_e))
|
||||
(Bindlib.unbox (Bindlib.box_list scope.Ast.scope_assertions))
|
||||
in
|
||||
let scope_sig =
|
||||
Scopelang.Ast.ScopeVarSet.fold
|
||||
(fun var acc ->
|
||||
let scope_def = Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs in
|
||||
let typ = scope_def.scope_def_typ in
|
||||
Scopelang.Ast.ScopeVarMap.add var (typ, scope_def.scope_def_io) acc)
|
||||
Ast.ScopeVarMap.fold
|
||||
(fun var (states : Ast.var_or_states) acc ->
|
||||
match states with
|
||||
| WholeVar ->
|
||||
let scope_def = Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs in
|
||||
let typ = scope_def.scope_def_typ in
|
||||
Scopelang.Ast.ScopeVarMap.add
|
||||
(match Ast.ScopeVarMap.find var ctx.scope_var_mapping with
|
||||
| WholeVar v -> v
|
||||
| States _ -> failwith "should not happen")
|
||||
(typ, scope_def.scope_def_io) acc
|
||||
| States states ->
|
||||
(* What happens in the case of variables with multiple states is interesting. We need to
|
||||
create as many Scopelang.Var entries in the scope signature as there are states. *)
|
||||
List.fold_left
|
||||
(fun acc (state : Ast.StateName.t) ->
|
||||
let scope_def =
|
||||
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, Some state)) scope.scope_defs
|
||||
in
|
||||
Scopelang.Ast.ScopeVarMap.add
|
||||
(match Ast.ScopeVarMap.find var ctx.scope_var_mapping with
|
||||
| WholeVar _ -> failwith "should not happen"
|
||||
| States states' -> List.assoc state states')
|
||||
(scope_def.scope_def_typ, scope_def.scope_def_io)
|
||||
acc)
|
||||
acc states)
|
||||
scope.scope_vars Scopelang.Ast.ScopeVarMap.empty
|
||||
in
|
||||
{
|
||||
@ -455,8 +521,45 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
(** {1 API} *)
|
||||
|
||||
let translate_program (pgrm : Ast.program) : Scopelang.Ast.program =
|
||||
(* First we give mappings to all the locations between Desugared and Scopelang. This involves
|
||||
creating a new Scopelang scope variable for every state of a Desugared variable. *)
|
||||
let ctx =
|
||||
Scopelang.Ast.ScopeMap.fold
|
||||
(fun _scope scope_decl ctx ->
|
||||
Ast.ScopeVarMap.fold
|
||||
(fun scope_var (states : Ast.var_or_states) ctx ->
|
||||
match states with
|
||||
| Ast.WholeVar ->
|
||||
{
|
||||
ctx with
|
||||
scope_var_mapping =
|
||||
Ast.ScopeVarMap.add scope_var
|
||||
(WholeVar (Scopelang.Ast.ScopeVar.fresh (Ast.ScopeVar.get_info scope_var)))
|
||||
ctx.scope_var_mapping;
|
||||
}
|
||||
| States states ->
|
||||
{
|
||||
ctx with
|
||||
scope_var_mapping =
|
||||
Ast.ScopeVarMap.add scope_var
|
||||
(States
|
||||
(List.map
|
||||
(fun state ->
|
||||
( state,
|
||||
Scopelang.Ast.ScopeVar.fresh
|
||||
(let state_name, state_pos = Ast.StateName.get_info state in
|
||||
( Pos.unmark (Ast.ScopeVar.get_info scope_var) ^ "_" ^ state_name,
|
||||
state_pos )) ))
|
||||
states))
|
||||
ctx.scope_var_mapping;
|
||||
})
|
||||
scope_decl.Ast.scope_vars ctx)
|
||||
pgrm.Ast.program_scopes
|
||||
{ scope_var_mapping = Ast.ScopeVarMap.empty; var_mapping = Ast.VarMap.empty }
|
||||
in
|
||||
{
|
||||
Scopelang.Ast.program_scopes = Scopelang.Ast.ScopeMap.map translate_scope pgrm.program_scopes;
|
||||
Scopelang.Ast.program_scopes =
|
||||
Scopelang.Ast.ScopeMap.map (translate_scope ctx) pgrm.program_scopes;
|
||||
Scopelang.Ast.program_structs = pgrm.program_structs;
|
||||
Scopelang.Ast.program_enums = pgrm.program_enums;
|
||||
}
|
||||
|
@ -1171,20 +1171,44 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : Desu
|
||||
}
|
||||
acc
|
||||
| states ->
|
||||
List.fold_left
|
||||
(fun acc state ->
|
||||
let def_key = Desugared.Ast.ScopeDef.Var (v, Some state) in
|
||||
Desugared.Ast.ScopeDefMap.add def_key
|
||||
{
|
||||
Desugared.Ast.scope_def_rules = Desugared.Ast.RuleMap.empty;
|
||||
Desugared.Ast.scope_def_typ = v_sig.var_sig_typ;
|
||||
Desugared.Ast.scope_def_label_groups =
|
||||
Name_resolution.label_groups ctxt s_uid def_key;
|
||||
Desugared.Ast.scope_def_is_condition = v_sig.var_sig_is_condition;
|
||||
Desugared.Ast.scope_def_io = attribute_to_io v_sig.var_sig_io;
|
||||
}
|
||||
acc)
|
||||
acc states)
|
||||
fst
|
||||
(List.fold_left
|
||||
(fun (acc, i) state ->
|
||||
let def_key = Desugared.Ast.ScopeDef.Var (v, Some state) in
|
||||
( Desugared.Ast.ScopeDefMap.add def_key
|
||||
{
|
||||
Desugared.Ast.scope_def_rules = Desugared.Ast.RuleMap.empty;
|
||||
Desugared.Ast.scope_def_typ = v_sig.var_sig_typ;
|
||||
Desugared.Ast.scope_def_label_groups =
|
||||
Name_resolution.label_groups ctxt s_uid def_key;
|
||||
Desugared.Ast.scope_def_is_condition =
|
||||
v_sig.var_sig_is_condition;
|
||||
Desugared.Ast.scope_def_io =
|
||||
(* The first state should have the input I/O of the
|
||||
original variable, and the last state should have the
|
||||
output I/O of the original variable. All intermediate
|
||||
states shall have "internal" I/O.*)
|
||||
(let original_io = attribute_to_io v_sig.var_sig_io in
|
||||
let io_input =
|
||||
if i = 0 then original_io.io_input
|
||||
else
|
||||
( Scopelang.Ast.NoInput,
|
||||
Pos.get_position
|
||||
(Desugared.Ast.StateName.get_info state) )
|
||||
in
|
||||
let io_output =
|
||||
if i = List.length states - 1 then
|
||||
original_io.io_output
|
||||
else
|
||||
( false,
|
||||
Pos.get_position
|
||||
(Desugared.Ast.StateName.get_info state) )
|
||||
in
|
||||
{ io_input; io_output });
|
||||
}
|
||||
acc,
|
||||
i + 1 ))
|
||||
(acc, 0) states))
|
||||
s_context.Name_resolution.var_idmap Desugared.Ast.ScopeDefMap.empty
|
||||
in
|
||||
let scope_and_subscope_vars_defs =
|
||||
|
Loading…
Reference in New Issue
Block a user