Switched scope functions input to struct instead of many arguments

This commit is contained in:
Denis Merigoux 2021-02-01 15:57:19 +01:00
parent 1faa35900b
commit d88ccc38f6
8 changed files with 161 additions and 63 deletions

View File

@ -355,16 +355,16 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark
| None ->
Errors.raise_spanned_error
(Format.asprintf
"the tuple has %d components but the %i-th element was requested (should not \
"The tuple has %d components but the %i-th element was requested (should not \
happen if the term was well-type)"
(List.length es) n)
(Pos.get_position e1) )
| _ ->
Errors.raise_spanned_error
(Format.asprintf
"the expression should be a tuple with %d components but is not (should not happen \
if the term was well-typed)"
n)
"The expression %a should be a tuple with %d components but is not (should not \
happen if the term was well-typed)"
(Print.format_expr ctx) e n)
(Pos.get_position e1) )
| EInj (e1, n, en, ts) ->
let e1' = evaluate_expr ctx e1 in
@ -453,15 +453,21 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark
(** Interpret a program. This function expects an expression typed as a function whose argument are
all thunked. The function is executed by providing for each argument a thunked empty default. *)
let interpret_program (ctx : Ast.decl_ctx) (e : Ast.expr Pos.marked) :
(Ast.Var.t * Ast.expr Pos.marked) list =
(Uid.MarkedString.info * Ast.expr Pos.marked) list =
match Pos.unmark (evaluate_expr ctx e) with
| Ast.EAbs (_, binder, taus) -> (
| Ast.EAbs (_, _, [ (Ast.TTuple (taus, Some s_in), _) ]) -> (
let application_term = List.map (fun _ -> empty_thunked_term) taus in
let to_interpret = (Ast.EApp (e, application_term), Pos.no_pos) in
let to_interpret =
(Ast.EApp (e, [ (Ast.ETuple (application_term, Some s_in), Pos.no_pos) ]), Pos.no_pos)
in
match Pos.unmark (evaluate_expr ctx to_interpret) with
| Ast.ETuple (args, Some _) ->
let vars, _ = Bindlib.unmbind binder in
List.map2 (fun arg var -> (var, arg)) args (Array.to_list vars)
| Ast.ETuple (args, Some s_out) ->
let s_out_fields =
List.map
(fun (f, _) -> Ast.StructFieldName.get_info f)
(Ast.StructMap.find s_out ctx.ctx_structs)
in
List.map2 (fun arg var -> (var, arg)) args s_out_fields
| _ ->
Errors.raise_spanned_error
"The interpretation of a program should always yield a struct corresponding to the \

View File

@ -82,7 +82,7 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter)
let rec unify (ctx : Ast.decl_ctx) (t1 : typ Pos.marked UnionFind.elem)
(t2 : typ Pos.marked UnionFind.elem) : unit =
let unify = unify ctx in
(* Cli.debug_print (Format.asprintf "Unifying %a and %a" format_typ t1 format_typ t2); *)
(* Cli.debug_print (Format.asprintf "Unifying %a and %a" (format_typ ctx) t1 (format_typ ctx) t2); *)
let t1_repr = UnionFind.get (UnionFind.find t1) in
let t2_repr = UnionFind.get (UnionFind.find t2) in
let raise_type_error (t1_pos : Pos.t) (t2_pos : Pos.t) : 'a =
@ -225,6 +225,7 @@ type env = typ Pos.marked UnionFind.elem A.VarMap.t
(** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.marked) :
typ Pos.marked UnionFind.elem =
(* Cli.debug_print (Format.asprintf "Looking for type of %a" (Print.format_expr ctx) e); *)
try
let out =
match Pos.unmark e with
@ -339,7 +340,8 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po
es;
UnionFind.make (Pos.same_pos_as (TArray cell_type) e)
in
(* Cli.debug_print (Format.asprintf "Found type of %a: %a" Print.format_expr e format_typ out); *)
(* Cli.debug_print (Format.asprintf "Found type of %a: %a" (Print.format_expr ctx) e (format_typ
ctx) out); *)
out
with Errors.StructuredError (msg, err_pos) when List.length err_pos = 2 ->
raise
@ -351,7 +353,8 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po
(** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.marked)
(tau : typ Pos.marked UnionFind.elem) : unit =
(* Cli.debug_print (Format.asprintf "Typechecking %a : %a" Print.format_expr e format_typ tau); *)
(* Cli.debug_print (Format.asprintf "Typechecking %a : %a" (Print.format_expr ctx) e (format_typ
ctx) tau); *)
try
match Pos.unmark e with
| EVar v -> (

View File

@ -144,8 +144,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
let prgm, prgm_expr, type_ordering =
Scopelang.Scope_to_dcalc.translate_program prgm scope_uid
in
(* Cli.debug_print (Format.asprintf "Output program:@\n%a" (Dcalc.Print.format_expr ctx)
prgm); *)
(* Cli.debug_print (Format.asprintf "Output program:@\n%a" (Dcalc.Print.format_expr
prgm.decl_ctx) prgm_expr); *)
Cli.debug_print "Typechecking...";
let _typ = Dcalc.Typing.infer_type prgm.decl_ctx prgm_expr in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a" Dcalc.Print.format_typ
@ -154,18 +154,25 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
| Cli.Run ->
Cli.debug_print "Starting interpretation...";
let results = Dcalc.Interpreter.interpret_program prgm.decl_ctx prgm_expr in
let out_regex = Re.Pcre.regexp "\\_out$" in
let results =
List.sort
(fun (v1, _) (v2, _) -> String.compare (Bindlib.name_of v1) (Bindlib.name_of v2))
List.map
(fun ((v1, v1_pos), e1) ->
let v1 = Re.Pcre.substitute ~rex:out_regex ~subst:(fun _ -> "") v1 in
((v1, v1_pos), e1))
results
in
let results =
List.sort (fun ((v1, _), _) ((v2, _), _) -> String.compare v1 v2) results
in
Cli.result_print
(Format.asprintf "Computation successful!%s"
(if List.length results > 0 then " Results:" else ""));
List.iter
(fun (var, result) ->
(fun ((var, _), result) ->
Cli.result_print
(Format.asprintf "@[<hov 2>%s@ =@ %a@]" (Bindlib.name_of var)
(Format.asprintf "@[<hov 2>%s@ =@ %a@]" var
(Dcalc.Print.format_expr prgm.decl_ctx)
result))
results;

View File

@ -15,7 +15,13 @@
open Utils
type scope_sigs_ctx =
((Ast.ScopeVar.t * Dcalc.Ast.typ) list * Dcalc.Ast.Var.t * Ast.StructName.t) Ast.ScopeMap.t
(* list of scope variables with their types *)
( (Ast.ScopeVar.t * Dcalc.Ast.typ) list
* (* var representing the scope *) Dcalc.Ast.Var.t
* (* var representing the scope input inside the scope func *) Dcalc.Ast.Var.t
* (* scope input *) Ast.StructName.t
* (* scope output *) Ast.StructName.t )
Ast.ScopeMap.t
type ctx = {
structs : Ast.struct_ctx;
@ -370,7 +376,11 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
let out_e = Dcalc.Ast.make_app intermediate_e [ thunked_new_e ] (Pos.get_position e) in
(out_e, new_ctx)
| Call (subname, subindex) ->
let all_subscope_vars, scope_dcalc_var, called_scope_return_struct =
let ( all_subscope_vars,
scope_dcalc_var,
_,
called_scope_input_struct,
called_scope_return_struct ) =
Ast.ScopeMap.find subname ctx.scopes_parameters
in
let subscope_vars_defined =
@ -380,6 +390,7 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
let subscope_var_not_yet_defined subvar =
not (Ast.ScopeVarMap.mem subvar subscope_vars_defined)
in
let pos_call = Pos.get_position (Ast.SubScopeName.get_info subindex) in
let subscope_args =
List.map
(fun (subvar, _) ->
@ -387,9 +398,15 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
Bindlib.box Dcalc.Interpreter.empty_thunked_term
else
let a_var, _ = Ast.ScopeVarMap.find subvar subscope_vars_defined in
Dcalc.Ast.make_var (a_var, Pos.get_position (Ast.SubScopeName.get_info subindex)))
Dcalc.Ast.make_var (a_var, pos_call))
all_subscope_vars
in
let subscope_struct_arg =
Bindlib.box_apply
(fun subscope_args ->
(Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct), pos_call))
(Bindlib.box_list subscope_args)
in
let all_subscope_vars_dcalc =
List.map
(fun (subvar, tau) ->
@ -427,8 +444,8 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list)
let call_expr =
tag_with_log_entry
(Bindlib.box_apply2
(fun e u -> (Dcalc.Ast.EApp (e, u), Pos.no_pos))
subscope_func (Bindlib.box_list subscope_args))
(fun e u -> (Dcalc.Ast.EApp (e, [ u ]), Pos.no_pos))
subscope_func subscope_struct_arg)
Dcalc.Ast.EndCall
[
(sigma_name, pos_sigma);
@ -499,7 +516,10 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
Dcalc.Ast.expr Pos.marked Bindlib.box * Dcalc.Ast.struct_ctx =
let ctx = empty_ctx struct_ctx enum_ctx sctx scope_name in
let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in
let scope_variables, _, scope_return_struct_name = Ast.ScopeMap.find sigma.scope_decl_name sctx in
let scope_variables, _, scope_input_var, scope_input_struct_name, scope_return_struct_name =
Ast.ScopeMap.find sigma.scope_decl_name sctx
in
let pos_sigma = Pos.get_position sigma_info in
let rules, ctx = translate_rules ctx sigma.scope_decl_rules sigma_info scope_return_struct_name in
let scope_variables =
List.map
@ -508,36 +528,81 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
(x, tau, dcalc_x))
scope_variables
in
let pos_sigma = Pos.get_position sigma_info in
(* first we create variables from the fields of the input struct *)
let rules, _ =
List.fold_right
(fun (_, tau, dvar) (acc, i) ->
let result_access =
Bindlib.box_apply
(fun r ->
( Dcalc.Ast.ETupleAccess
( r,
i,
Some scope_input_struct_name,
List.map
(fun (_, t, _) ->
( Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (t, pos_sigma)),
pos_sigma ))
scope_variables ),
pos_sigma ))
(Dcalc.Ast.make_var (scope_input_var, pos_sigma))
in
( Dcalc.Ast.make_let_in dvar
(Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (tau, pos_sigma)), pos_sigma)
result_access acc,
i - 1 ))
scope_variables
(rules, List.length scope_variables - 1)
in
let scope_return_struct_fields =
List.map
(fun (_, tau, dvar) ->
let struct_field_name = Ast.StructFieldName.fresh (Bindlib.name_of dvar, pos_sigma) in
let struct_field_name =
Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_out", pos_sigma)
in
(struct_field_name, (tau, pos_sigma)))
scope_variables
in
let new_struct_ctx =
Ast.StructMap.singleton scope_return_struct_name scope_return_struct_fields
let scope_input_struct_fields =
List.map
(fun (_, tau, dvar) ->
let struct_field_name =
Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma)
in
( struct_field_name,
(Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (tau, pos_sigma)), pos_sigma) ))
scope_variables
in
( Dcalc.Ast.make_abs
(Array.of_list (List.map (fun (_, _, x) -> x) scope_variables))
rules pos_sigma
(List.map
(fun (_, tau, _) ->
(Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (tau, pos_sigma)), pos_sigma))
scope_variables)
let new_struct_ctx =
Ast.StructMap.add scope_input_struct_name scope_input_struct_fields
(Ast.StructMap.singleton scope_return_struct_name scope_return_struct_fields)
in
( Dcalc.Ast.make_abs [| scope_input_var |] rules pos_sigma
[
( Dcalc.Ast.TTuple (List.map snd scope_input_struct_fields, Some scope_input_struct_name),
pos_sigma );
]
pos_sigma,
new_struct_ctx )
let build_scope_typ_from_sig (scope_sig : (Ast.ScopeVar.t * Dcalc.Ast.typ) list)
(scope_struct_name : Ast.StructName.t) (pos : Pos.t) : Dcalc.Ast.typ Pos.marked =
(scope_input_struct_name : Ast.StructName.t) (scope_return_struct_name : Ast.StructName.t)
(pos : Pos.t) : Dcalc.Ast.typ Pos.marked =
let result_typ =
(Dcalc.Ast.TTuple (List.map (fun (_, tau) -> (tau, pos)) scope_sig, Some scope_struct_name), pos)
( Dcalc.Ast.TTuple
(List.map (fun (_, tau) -> (tau, pos)) scope_sig, Some scope_return_struct_name),
pos )
in
List.fold_right
(fun (_, arg_t) acc ->
(Dcalc.Ast.TArrow ((Dcalc.Ast.TArrow ((TLit TUnit, pos), (arg_t, pos)), pos), acc), pos))
scope_sig result_typ
let input_typ =
( Dcalc.Ast.TTuple
( List.map
(fun (_, tau) -> (Dcalc.Ast.TArrow ((TLit TUnit, pos), (tau, pos)), pos))
scope_sig,
Some scope_input_struct_name ),
pos )
in
(Dcalc.Ast.TArrow (input_typ, result_typ), pos)
let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName.t) :
Dcalc.Ast.program * Dcalc.Ast.expr Pos.marked * Dependency.TVertex.t list =
@ -571,6 +636,14 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
Ast.StructName.fresh
(Pos.map_under_mark (fun s -> s ^ "_out") (Ast.ScopeName.get_info scope_name))
in
let scope_input_var =
Dcalc.Ast.Var.make
(Pos.map_under_mark (fun s -> s ^ "_in") (Ast.ScopeName.get_info scope_name))
in
let scope_input_struct_name =
Ast.StructName.fresh
(Pos.map_under_mark (fun s -> s ^ "_in") (Ast.ScopeName.get_info scope_name))
in
( List.map
(fun (scope_var, tau) ->
let tau = translate_typ (ctx_for_typ_translation scope_name) tau in
@ -578,6 +651,8 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
(scope_var, Pos.unmark tau))
(Ast.ScopeVarMap.bindings scope.scope_sig),
scope_dvar,
scope_input_var,
scope_input_struct_name,
scope_return_struct_name ))
prgm.program_scopes
in
@ -585,7 +660,7 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
returning *)
let acc =
Dcalc.Ast.make_var
(let _, x, _ = Ast.ScopeMap.find top_level_scope_name sctx in
(let _, x, _, _, _ = Ast.ScopeMap.find top_level_scope_name sctx in
(x, Pos.no_pos))
in
(* the resulting expression is the list of definitions of all the scopes, ending with the
@ -595,18 +670,23 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
(fun scope_name (acc, scopes, decl_ctx) ->
let scope = Ast.ScopeMap.find scope_name prgm.program_scopes in
let pos_scope = Pos.get_position (Ast.ScopeName.get_info scope.scope_decl_name) in
let scope_expr, scope_struct =
let scope_expr, scope_out_struct =
translate_scope_decl struct_ctx enum_ctx sctx scope_name scope
in
let scope_sig, dvar, scope_struct_name = Ast.ScopeMap.find scope_name sctx in
let scope_typ = build_scope_typ_from_sig scope_sig scope_struct_name pos_scope in
let scope_sig, dvar, _, scope_input_struct_name, scope_return_struct_name =
Ast.ScopeMap.find scope_name sctx
in
let scope_typ =
build_scope_typ_from_sig scope_sig scope_input_struct_name scope_return_struct_name
pos_scope
in
let decl_ctx =
{
decl_ctx with
Dcalc.Ast.ctx_structs =
Ast.StructMap.union
(fun _ _ -> assert false (* should not happen *))
decl_ctx.Dcalc.Ast.ctx_structs scope_struct;
decl_ctx.Dcalc.Ast.ctx_structs scope_out_struct;
}
in
( Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc,

View File

@ -1,16 +1,18 @@
module Allocations_familiales = Law_source.Allocations_familiales
module AF = Allocations_familiales
open Catala.Runtime
let compute_allocations_familiales ~(current_date : CalendarLib.Date.t)
~(children : Allocations_familiales.enfant_entree array) ~(income : int)
~(residence : Allocations_familiales.collectivite) : float =
~(children : AF.enfant_entree array) ~(income : int) ~(residence : AF.collectivite) : float =
let result =
Allocations_familiales.interface_allocations_familiales
(fun _ -> date_of_calendar_date current_date)
(fun _ -> children)
no_input
(fun _ -> money_of_units_integers income)
(fun _ -> residence)
no_input
AF.interface_allocations_familiales
{
AF.date_courante_in = (fun _ -> date_of_calendar_date current_date);
AF.enfants_in = (fun _ -> children);
AF.enfants_a_charge_in = no_input;
AF.ressources_menage_in = (fun _ -> money_of_units_integers income);
AF.residence_in = (fun _ -> residence);
AF.montant_verse_in = no_input;
}
in
money_to_float result.Allocations_familiales.montant_verse
money_to_float result.AF.montant_verse_out

View File

@ -1,5 +1,5 @@
[ERROR] Error during typechecking, incompatible types:
[ERROR] --> F [Case3: any[71]]
[ERROR] --> F [Case3: any[77]]
[ERROR] --> E [Case1: integer | Case2: unit]
[ERROR]
[ERROR] Error coming from typechecking the following expression:
@ -9,7 +9,7 @@
[ERROR] | ^
[ERROR] + Article
[ERROR]
[ERROR] Type F [Case3: any[71]] coming from expression:
[ERROR] Type F [Case3: any[77]] coming from expression:
[ERROR] --> test_enum/bad/quick_pattern_2.catala
[ERROR] |
[ERROR] 28 | def y := x with Case3

View File

@ -1,5 +1,5 @@
[ERROR] Error during typechecking, incompatible types:
[ERROR] --> F [Case3: any[18] | Case4: any[19]]
[ERROR] --> F [Case3: any[20] | Case4: any[21]]
[ERROR] --> E [Case1: unit | Case2: unit]
[ERROR]
[ERROR] Error coming from typechecking the following expression:
@ -9,7 +9,7 @@
[ERROR] | ^
[ERROR] + Article
[ERROR]
[ERROR] Type F [Case3: any[18] | Case4: any[19]] coming from expression:
[ERROR] Type F [Case3: any[20] | Case4: any[21]] coming from expression:
[ERROR] --> test_enum/bad/quick_pattern_3.catala
[ERROR] |
[ERROR] 18 | def y := x with Case3

View File

@ -1,5 +1,5 @@
[ERROR] Error during typechecking, incompatible types:
[ERROR] --> F [Case3: any[18]]
[ERROR] --> F [Case3: any[20]]
[ERROR] --> E [Case1: unit | Case2: unit]
[ERROR]
[ERROR] Error coming from typechecking the following expression:
@ -9,7 +9,7 @@
[ERROR] | ^
[ERROR] + Test
[ERROR]
[ERROR] Type F [Case3: any[18]] coming from expression:
[ERROR] Type F [Case3: any[20]] coming from expression:
[ERROR] --> test_enum/bad/quick_pattern_4.catala
[ERROR] |
[ERROR] 17 | def y := x with Case3