mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Arrow List (#404)
This commit is contained in:
commit
5bd140ae5f
2
.gitignore
vendored
2
.gitignore
vendored
@ -13,3 +13,5 @@ legifrance_oauth*
|
||||
node_modules/
|
||||
build.ninja
|
||||
|
||||
.envrc
|
||||
.direnv
|
@ -352,33 +352,39 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
field sc_sig.scope_sig_output_struct (Expr.with_ty m typ)
|
||||
in
|
||||
match Marked.unmark typ with
|
||||
| TArrow (t_in, t_out) ->
|
||||
| TArrow (ts_in, t_out) ->
|
||||
(* Here the output scope struct field is a function so we
|
||||
eta-expand it and insert logging instructions. Invariant:
|
||||
works because user-defined functions in scope have only one
|
||||
argument. *)
|
||||
let param_var = Var.make "param" in
|
||||
works because there is no partial evaluation. *)
|
||||
let params_vars =
|
||||
ListLabels.mapi ts_in ~f:(fun i _ ->
|
||||
Var.make ("param" ^ string_of_int i))
|
||||
in
|
||||
let f_markings =
|
||||
[ScopeName.get_info scope; StructField.get_info field]
|
||||
in
|
||||
Expr.make_abs
|
||||
(Array.of_list [param_var])
|
||||
(Array.of_list params_vars)
|
||||
(tag_with_log_entry
|
||||
(tag_with_log_entry
|
||||
(Expr.eapp
|
||||
(tag_with_log_entry original_field_expr BeginCall
|
||||
f_markings)
|
||||
[
|
||||
tag_with_log_entry
|
||||
(Expr.make_var param_var (Expr.with_ty m t_in))
|
||||
(VarDef (Marked.unmark t_in))
|
||||
(f_markings @ [Marked.mark (Expr.pos e) "input"]);
|
||||
]
|
||||
(ListLabels.mapi (List.combine params_vars ts_in)
|
||||
~f:(fun i (param_var, t_in) ->
|
||||
tag_with_log_entry
|
||||
(Expr.make_var param_var (Expr.with_ty m t_in))
|
||||
(VarDef (Marked.unmark t_in))
|
||||
(f_markings
|
||||
@ [
|
||||
Marked.mark (Expr.pos e)
|
||||
("input" ^ string_of_int i);
|
||||
])))
|
||||
(Expr.with_ty m t_out))
|
||||
(VarDef (Marked.unmark t_out))
|
||||
(f_markings @ [Marked.mark (Expr.pos e) "output"]))
|
||||
EndCall f_markings)
|
||||
[t_in] (Expr.pos e)
|
||||
ts_in (Expr.pos e)
|
||||
| _ -> original_field_expr)
|
||||
(StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs))
|
||||
(Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))
|
||||
@ -443,7 +449,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
| m -> tag_with_log_entry e1_func BeginCall m
|
||||
in
|
||||
let new_args = List.map (translate_expr ctx) args in
|
||||
let input_typ, output_typ =
|
||||
let input_typs, output_typ =
|
||||
(* NOTE: this is a temporary solution, it works because it's assume that
|
||||
all function calls are from scope variable. However, this will change
|
||||
-- for more information see
|
||||
@ -452,8 +458,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in
|
||||
match typ with
|
||||
| TArrow (marked_input_typ, marked_output_typ) ->
|
||||
Marked.unmark marked_input_typ, Marked.unmark marked_output_typ
|
||||
| _ -> TAny, TAny
|
||||
( List.map Marked.unmark marked_input_typ,
|
||||
Marked.unmark marked_output_typ )
|
||||
| _ -> ListLabels.map new_args ~f:(fun _ -> TAny), TAny
|
||||
in
|
||||
match Marked.unmark f with
|
||||
| ELocation (ScopelangScopeVar var) ->
|
||||
@ -467,21 +474,26 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
TopdefName.Map.find (Marked.unmark tvar) ctx.toplevel_vars
|
||||
in
|
||||
match typ with
|
||||
| TArrow ((tin, _), (tout, _)) -> tin, tout
|
||||
| TArrow (tin, (tout, _)) -> List.map Marked.unmark tin, tout
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"Application of non-function toplevel variable")
|
||||
| _ -> TAny, TAny
|
||||
| _ -> ListLabels.map new_args ~f:(fun _ -> TAny), TAny
|
||||
in
|
||||
|
||||
(* Cli.debug_format "new_args %d, input_typs: %d, input_typs %a"
|
||||
(List.length new_args) (List.length input_typs) (Format.pp_print_list
|
||||
Print.typ_debug) (List.map (Marked.mark Pos.no_pos) input_typs); *)
|
||||
let new_args =
|
||||
match markings, new_args with
|
||||
| (_ :: _ as m), [new_arg] ->
|
||||
[
|
||||
tag_with_log_entry new_arg (VarDef input_typ)
|
||||
(m @ [Marked.mark (Expr.pos e) "input"]);
|
||||
]
|
||||
| _ -> new_args
|
||||
ListLabels.mapi (List.combine new_args input_typs)
|
||||
~f:(fun i (new_arg, input_typ) ->
|
||||
match markings with
|
||||
| _ :: _ as m ->
|
||||
tag_with_log_entry new_arg (VarDef input_typ)
|
||||
(m @ [Marked.mark (Expr.pos e) ("input" ^ string_of_int i)])
|
||||
| _ -> new_arg)
|
||||
in
|
||||
|
||||
let new_e = Expr.eapp e1_func new_args m in
|
||||
let new_e =
|
||||
match markings with
|
||||
@ -640,7 +652,7 @@ let translate_rule
|
||||
| OnlyInput -> tau
|
||||
| Reentrant ->
|
||||
if is_func then tau
|
||||
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
|
||||
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
|
||||
scope_let_expr = thunked_or_nonempty_new_e;
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
})
|
||||
@ -935,7 +947,7 @@ let translate_scope_decl
|
||||
match var_ctx.scope_var_typ with
|
||||
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
|
||||
| _ ->
|
||||
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
|
||||
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
|
||||
pos_sigma ))
|
||||
| NoInput -> failwith "should not happen"
|
||||
in
|
||||
|
@ -522,9 +522,9 @@ let interpret_program :
|
||||
match Marked.unmark ty with
|
||||
| TArrow (ty_in, ty_out) ->
|
||||
Expr.make_abs
|
||||
[| Var.make "_" |]
|
||||
(Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
|
||||
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
|
||||
[ty_in] (Expr.mark_pos mark_e)
|
||||
ty_in (Expr.mark_pos mark_e)
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Marked.get_mark ty)
|
||||
"This scope needs input arguments to be executed. But the Catala \
|
||||
|
@ -103,7 +103,7 @@ type rule = {
|
||||
rule_id : RuleName.t;
|
||||
rule_just : expr boxed;
|
||||
rule_cons : expr boxed;
|
||||
rule_parameter : (expr Var.t * typ) option;
|
||||
rule_parameter : (expr Var.t * typ) list option;
|
||||
rule_exception : exception_situation;
|
||||
rule_label : label_situation;
|
||||
}
|
||||
@ -124,45 +124,46 @@ module Rule = struct
|
||||
let c2 = Expr.unbox r2.rule_cons in
|
||||
Expr.compare c1 c2
|
||||
| n -> n)
|
||||
| Some (v1, t1), Some (v2, t2) -> (
|
||||
match Type.compare t1 t2 with
|
||||
| 0 -> (
|
||||
let open Bindlib in
|
||||
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in
|
||||
let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_just)) in
|
||||
let _, j1, j2 = unbind2 b1 b2 in
|
||||
match Expr.compare j1 j2 with
|
||||
| 0 ->
|
||||
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_cons)) in
|
||||
let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_cons)) in
|
||||
let _, c1, c2 = unbind2 b1 b2 in
|
||||
Expr.compare c1 c2
|
||||
| n -> n)
|
||||
| n -> n)
|
||||
| Some l1, Some l2 ->
|
||||
ListLabels.compare l1 l2 ~cmp:(fun (v1, t1) (v2, t2) ->
|
||||
match Type.compare t1 t2 with
|
||||
| 0 -> (
|
||||
let open Bindlib in
|
||||
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in
|
||||
let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_just)) in
|
||||
let _, j1, j2 = unbind2 b1 b2 in
|
||||
match Expr.compare j1 j2 with
|
||||
| 0 ->
|
||||
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_cons)) in
|
||||
let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_cons)) in
|
||||
let _, c1, c2 = unbind2 b1 b2 in
|
||||
Expr.compare c1 c2
|
||||
| n -> n)
|
||||
| n -> n)
|
||||
| None, Some _ -> -1
|
||||
| Some _, None -> 1
|
||||
end
|
||||
|
||||
let empty_rule (pos : Pos.t) (have_parameter : typ option) : rule =
|
||||
let empty_rule (pos : Pos.t) (have_parameter : typ list option) : rule =
|
||||
{
|
||||
rule_just = Expr.box (ELit (LBool false), Untyped { pos });
|
||||
rule_cons = Expr.box (ELit LEmptyError, Untyped { pos });
|
||||
rule_parameter =
|
||||
(match have_parameter with
|
||||
| Some typ -> Some (Var.make "dummy", typ)
|
||||
| Some typs -> Some (List.map (fun typ -> Var.make "dummy", typ) typs)
|
||||
| None -> None);
|
||||
rule_exception = BaseCase;
|
||||
rule_id = RuleName.fresh ("empty", pos);
|
||||
rule_label = Unlabeled;
|
||||
}
|
||||
|
||||
let always_false_rule (pos : Pos.t) (have_parameter : typ option) : rule =
|
||||
let always_false_rule (pos : Pos.t) (have_parameter : typ list option) : rule =
|
||||
{
|
||||
rule_just = Expr.box (ELit (LBool true), Untyped { pos });
|
||||
rule_cons = Expr.box (ELit (LBool false), Untyped { pos });
|
||||
rule_parameter =
|
||||
(match have_parameter with
|
||||
| Some typ -> Some (Var.make "dummy", typ)
|
||||
| Some typs -> Some (List.map (fun typ -> Var.make "dummy", typ) typs)
|
||||
| None -> None);
|
||||
rule_exception = BaseCase;
|
||||
rule_id = RuleName.fresh ("always_false", pos);
|
||||
|
@ -60,15 +60,15 @@ type rule = {
|
||||
rule_id : RuleName.t;
|
||||
rule_just : expr boxed;
|
||||
rule_cons : expr boxed;
|
||||
rule_parameter : (expr Var.t * typ) option;
|
||||
rule_parameter : (expr Var.t * typ) list option;
|
||||
rule_exception : exception_situation;
|
||||
rule_label : label_situation;
|
||||
}
|
||||
|
||||
module Rule : Set.OrderedType with type t = rule
|
||||
|
||||
val empty_rule : Pos.t -> typ option -> rule
|
||||
val always_false_rule : Pos.t -> typ option -> rule
|
||||
val empty_rule : Pos.t -> typ list option -> rule
|
||||
val always_false_rule : Pos.t -> typ list option -> rule
|
||||
|
||||
type assertion = expr boxed
|
||||
type variation_typ = Increasing | Decreasing
|
||||
|
@ -29,7 +29,9 @@ let rule ctx env rule =
|
||||
let env =
|
||||
match rule.rule_parameter with
|
||||
| None -> env
|
||||
| Some (v, ty) -> Typing.Env.add_var v ty env
|
||||
| Some vars_and_types ->
|
||||
ListLabels.fold_right vars_and_types ~init:env ~f:(fun (v, t) ->
|
||||
Typing.Env.add_var v t)
|
||||
in
|
||||
(* Note: we could use the known rule type here to direct typing. We choose not
|
||||
to because it shouldn't be needed for disambiguation, and we prefer to
|
||||
|
@ -938,7 +938,12 @@ let process_default
|
||||
Name_resolution.get_def_typ ctxt (Marked.unmark def_key)
|
||||
in
|
||||
match Marked.unmark def_key_typ, param_uid with
|
||||
| TArrow (t_in, _), Some param_uid -> Some (Marked.unmark param_uid, t_in)
|
||||
| TArrow ([t_in], _), Some param_uid ->
|
||||
Some [Marked.unmark param_uid, t_in]
|
||||
| TArrow _, Some _ ->
|
||||
Errors.raise_spanned_error (Expr.pos cons)
|
||||
"This definition has a function type but there is multiple \
|
||||
arguments."
|
||||
| TArrow _, None ->
|
||||
Errors.raise_spanned_error (Expr.pos cons)
|
||||
"This definition has a function type but the parameter is missing"
|
||||
@ -1204,9 +1209,9 @@ let process_topdef
|
||||
(Marked.get_mark def.S.topdef_name)
|
||||
in
|
||||
let typ =
|
||||
List.fold_right
|
||||
(fun argty retty -> TArrow (argty, retty), ty_pos)
|
||||
arg_types body_type
|
||||
match arg_types with
|
||||
| [] -> body_type
|
||||
| _ -> TArrow (arg_types, body_type), ty_pos
|
||||
in
|
||||
{
|
||||
prgm with
|
||||
|
@ -319,7 +319,9 @@ let process_type (ctxt : context) ((naked_typ, typ_pos) : Surface.Ast.typ) : typ
|
||||
match naked_typ with
|
||||
| Surface.Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos)
|
||||
| Surface.Ast.Func { arg_typ; return_typ } ->
|
||||
( TArrow (process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ),
|
||||
(* TODO Louis: /!\ There is only one argument in the surface syntax for
|
||||
function now. *)
|
||||
( TArrow ([process_base_typ ctxt arg_typ], process_base_typ ctxt return_typ),
|
||||
typ_pos )
|
||||
|
||||
(** Process data declaration *)
|
||||
|
@ -129,8 +129,8 @@ let rec translate_typ (tau : typ) : typ =
|
||||
| TAny -> TAny
|
||||
| TArray ts -> TArray (translate_typ ts)
|
||||
(* catala is not polymorphic *)
|
||||
| TArrow ((TLit TUnit, _), t2) -> TOption (translate_typ t2)
|
||||
| TArrow (t1, t2) -> TArrow (translate_typ t1, translate_typ t2)
|
||||
| TArrow ([(TLit TUnit, _)], t2) -> TOption (translate_typ t2)
|
||||
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
|
||||
end
|
||||
|
||||
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
|
||||
@ -458,7 +458,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
|
||||
thunked, then the variable is context. If it's not thunked, it's a
|
||||
regular input. *)
|
||||
match Marked.unmark typ with
|
||||
| TArrow ((TLit TUnit, _), _) -> false
|
||||
| TArrow ([(TLit TUnit, _)], _) -> false
|
||||
| _ -> true)
|
||||
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
|
||||
| DestructuringSubScopeResults | Assertion ->
|
||||
|
@ -167,8 +167,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
format_enum_name Ast.option_enum
|
||||
| TEnum e -> Format.fprintf fmt "%a.t" format_to_module_name (`Ename e)
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1
|
||||
format_typ_with_parens t2
|
||||
Format.fprintf fmt "@[<hov 2>%a@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " ->@ ")
|
||||
format_typ_with_parens)
|
||||
(t1 @ [t2])
|
||||
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "_"
|
||||
|
||||
|
@ -78,8 +78,11 @@ module To_jsoo = struct
|
||||
Format.fprintf fmt "@[%a@ Js.js_array Js.t@]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "Js.Unsafe.any Js.t"
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "(@[<hov 2>%a, @ %a@]) Js.meth_callback"
|
||||
format_typ_with_parens t1 format_typ_with_parens t2
|
||||
Format.fprintf fmt "(@[<hov 2>unit, @ %a -> %a@]) Js.meth_callback"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.pp_print_string fmt " -> ")
|
||||
format_typ_with_parens)
|
||||
t1 format_typ_with_parens t2
|
||||
|
||||
let rec format_typ_to_jsoo fmt typ =
|
||||
match Marked.unmark typ with
|
||||
@ -153,13 +156,23 @@ module To_jsoo = struct
|
||||
(fun fmt (struct_field, struct_field_type) ->
|
||||
match Marked.unmark struct_field_type with
|
||||
| TArrow (t1, t2) ->
|
||||
let args_names =
|
||||
ListLabels.mapi t1 ~f:(fun i _ ->
|
||||
"function_input" ^ string_of_int i)
|
||||
in
|
||||
Format.fprintf fmt
|
||||
"@[<hov 2>method %a =@ Js.wrap_meth_callback@ @[<hv 2>(@,\
|
||||
fun input ->@ %a (%a.%a (%a input)))@]@]"
|
||||
fun _ %a ->@ %a (%a.%a %a))@]@]"
|
||||
format_struct_field_name_camel_case struct_field
|
||||
(Format.pp_print_list (fun fmt (arg_i, ti) ->
|
||||
Format.fprintf fmt "(%s: %a)" arg_i format_typ ti))
|
||||
(List.combine args_names t1)
|
||||
format_typ_to_jsoo t2 fmt_struct_name ()
|
||||
format_struct_field_name (None, struct_field)
|
||||
format_typ_of_jsoo t1
|
||||
(Format.pp_print_list (fun fmt (i, ti) ->
|
||||
Format.fprintf fmt "@[<hv 2>(%a@ %a)@]"
|
||||
format_typ_of_jsoo ti Format.pp_print_string i))
|
||||
(List.combine args_names t1)
|
||||
| _ ->
|
||||
Format.fprintf fmt "@[<hov 2>val %a =@ %a %a.%a@]"
|
||||
format_struct_field_name_camel_case struct_field
|
||||
|
@ -186,8 +186,11 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
Format.fprintf fmt "Optional[%a]" format_typ some_typ
|
||||
| TEnum e -> Format.fprintf fmt "%a" format_enum_name e
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1
|
||||
format_typ_with_parens t2
|
||||
Format.fprintf fmt "Callable[[%a], %a]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
format_typ_with_parens)
|
||||
t1 format_typ_with_parens t2
|
||||
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
|
||||
| TAny -> Format.fprintf fmt "Any"
|
||||
|
||||
|
@ -239,7 +239,9 @@ let rec get_structs_or_enums_in_type (t : typ) : TVertexSet.t =
|
||||
| TEnum e -> TVertexSet.singleton (TVertex.Enum e)
|
||||
| TArrow (t1, t2) ->
|
||||
TVertexSet.union
|
||||
(get_structs_or_enums_in_type t1)
|
||||
(t1
|
||||
|> List.map get_structs_or_enums_in_type
|
||||
|> List.fold_left TVertexSet.union TVertexSet.empty)
|
||||
(get_structs_or_enums_in_type t2)
|
||||
| TLit _ | TAny -> TVertexSet.empty
|
||||
| TOption t1 | TArray t1 -> get_structs_or_enums_in_type t1
|
||||
|
@ -212,26 +212,29 @@ let rec rule_tree_to_expr
|
||||
~(is_reentrant_var : bool)
|
||||
(ctx : ctx)
|
||||
(def_pos : Pos.t)
|
||||
(is_func : Desugared.Ast.expr Var.t option)
|
||||
(is_func : Desugared.Ast.expr Var.t list option)
|
||||
(tree : rule_tree) : untyped Ast.expr boxed =
|
||||
let emark = Untyped { pos = def_pos } in
|
||||
let exceptions, base_rules =
|
||||
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
|
||||
in
|
||||
(* because each rule has its own variable parameter and we want to convert the
|
||||
whole rule tree into a function, we need to perform some alpha-renaming of
|
||||
all the expressions *)
|
||||
(* because each rule has its own variables parameters and we want to convert
|
||||
the whole rule tree into a function, we need to perform some alpha-renaming
|
||||
of all the expressions *)
|
||||
let substitute_parameter
|
||||
(e : Desugared.Ast.expr boxed)
|
||||
(rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed =
|
||||
match is_func, rule.Desugared.Ast.rule_parameter with
|
||||
| Some new_param, Some (old_param, _) ->
|
||||
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
|
||||
| Some new_params, Some old_params_with_types ->
|
||||
let old_params, _ = List.split old_params_with_types in
|
||||
let old_params = Array.of_list old_params in
|
||||
let new_params = Array.of_list new_params in
|
||||
let binder = Bindlib.bind_mvar old_params (Marked.unmark e) in
|
||||
Marked.mark (Marked.get_mark e)
|
||||
@@ Bindlib.box_apply2
|
||||
(fun binder new_param -> Bindlib.subst binder new_param)
|
||||
(fun binder new_param -> Bindlib.msubst binder new_param)
|
||||
binder
|
||||
(Bindlib.box_var new_param)
|
||||
(new_params |> Array.map Bindlib.box_var |> Bindlib.box_array)
|
||||
| None, None -> e
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
@ -239,20 +242,22 @@ let rec rule_tree_to_expr
|
||||
let ctx =
|
||||
match is_func with
|
||||
| None -> ctx
|
||||
| Some new_param -> (
|
||||
match Var.Map.find_opt new_param ctx.var_mapping with
|
||||
| None ->
|
||||
let new_param_scope = Var.make (Bindlib.name_of new_param) in
|
||||
{
|
||||
ctx with
|
||||
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
|
||||
}
|
||||
| Some _ ->
|
||||
(* We only create a mapping if none exists because [rule_tree_to_expr]
|
||||
is called recursively on the exceptions of the tree and we don't want
|
||||
to create a new Scopelang variable for the parameter at each tree
|
||||
level. *)
|
||||
ctx)
|
||||
| Some new_params ->
|
||||
ListLabels.fold_left new_params ~init:ctx ~f:(fun ctx new_param ->
|
||||
match Var.Map.find_opt new_param ctx.var_mapping with
|
||||
| None ->
|
||||
let new_param_scope = Var.make (Bindlib.name_of new_param) in
|
||||
{
|
||||
ctx with
|
||||
var_mapping =
|
||||
Var.Map.add new_param new_param_scope ctx.var_mapping;
|
||||
}
|
||||
| Some _ ->
|
||||
(* We only create a mapping if none exists because
|
||||
[rule_tree_to_expr] is called recursively on the exceptions of
|
||||
the tree and we don't want to create a new Scopelang variable for
|
||||
the parameter at each tree level. *)
|
||||
ctx)
|
||||
in
|
||||
let base_just_list =
|
||||
List.map
|
||||
@ -301,7 +306,8 @@ let rec rule_tree_to_expr
|
||||
in
|
||||
match is_func, (List.hd base_rules).Desugared.Ast.rule_parameter with
|
||||
| None, None -> default
|
||||
| Some new_param, Some (_, typ) ->
|
||||
| Some new_params, Some ls ->
|
||||
let _, tys = List.split ls in
|
||||
if toplevel then
|
||||
(* When we're creating a function from multiple defaults, we must check
|
||||
that the result returned by the function is not empty, unless we're
|
||||
@ -311,9 +317,12 @@ let rec rule_tree_to_expr
|
||||
let default =
|
||||
if is_reentrant_var then default else Expr.eerroronempty default emark
|
||||
in
|
||||
|
||||
Expr.make_abs
|
||||
[| Var.Map.find new_param ctx.var_mapping |]
|
||||
default [typ] def_pos
|
||||
(new_params
|
||||
|> List.map (fun x -> Var.Map.find x ctx.var_mapping)
|
||||
|> Array.of_list)
|
||||
default tys def_pos
|
||||
else default
|
||||
| _ -> (* should not happen *) assert false
|
||||
|
||||
@ -340,7 +349,7 @@ let translate_def
|
||||
let all_rules_not_func =
|
||||
RuleName.Map.for_all (fun n r -> not (is_rule_func n r)) def
|
||||
in
|
||||
let is_def_func_param_typ : typ option =
|
||||
let is_def_func_param_typ : typ list option =
|
||||
if is_def_func && all_rules_func then
|
||||
match Marked.unmark typ with
|
||||
| TArrow (t_param, _) -> Some t_param
|
||||
@ -379,7 +388,7 @@ let translate_def
|
||||
| Reentrant -> true
|
||||
| _ -> false
|
||||
in
|
||||
let top_value =
|
||||
let top_value : Desugared.Ast.rule option =
|
||||
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
|
||||
(* We add the bottom [false] value for conditions, only for the scope
|
||||
where the condition is declared. Except when the variable is an input,
|
||||
@ -419,13 +428,19 @@ let translate_def
|
||||
let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
|
||||
let empty_error = Expr.elit LEmptyError m in
|
||||
match is_def_func_param_typ with
|
||||
| Some ty ->
|
||||
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
|
||||
| Some tys ->
|
||||
Expr.make_abs
|
||||
(Array.init (List.length tys) (fun _ -> Var.make "_"))
|
||||
empty_error tys (Expr.mark_pos m)
|
||||
| _ -> empty_error
|
||||
else
|
||||
rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
|
||||
(Desugared.Ast.ScopeDef.get_position def_info)
|
||||
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
|
||||
(Option.map
|
||||
(fun l ->
|
||||
ListLabels.mapi l ~f:(fun i _ ->
|
||||
Var.make ("param" ^ string_of_int i)))
|
||||
is_def_func_param_typ)
|
||||
(match top_list, top_value with
|
||||
| [], None ->
|
||||
(* In this case, there are no rules to define the expression and no
|
||||
|
@ -74,7 +74,7 @@ and naked_typ =
|
||||
| TStruct of StructName.t
|
||||
| TEnum of EnumName.t
|
||||
| TOption of typ
|
||||
| TArrow of typ * typ
|
||||
| TArrow of typ list * typ
|
||||
| TArray of typ
|
||||
| TAny
|
||||
|
||||
|
@ -182,7 +182,7 @@ let fold_marks
|
||||
| [] -> invalid_arg "Dcalc.Ast.fold_mark"
|
||||
| Untyped _ :: _ as ms ->
|
||||
Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) }
|
||||
| Typed _ :: _ ->
|
||||
| Typed _ :: _ as ms ->
|
||||
Typed
|
||||
{
|
||||
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
|
||||
@ -713,34 +713,27 @@ let make_abs xs e taus pos =
|
||||
let mark =
|
||||
map_mark
|
||||
(fun _ -> pos)
|
||||
(fun ety ->
|
||||
List.fold_right
|
||||
(fun tx acc -> Marked.mark pos (TArrow (tx, acc)))
|
||||
taus ety)
|
||||
(fun ety -> Marked.mark pos (TArrow (taus, ety)))
|
||||
(Marked.get_mark e)
|
||||
in
|
||||
eabs (bind xs e) taus mark
|
||||
|
||||
let make_app e u pos =
|
||||
let make_app e args pos =
|
||||
let mark =
|
||||
fold_marks
|
||||
(fun _ -> pos)
|
||||
(function
|
||||
| [] -> assert false
|
||||
| fty :: argtys ->
|
||||
List.fold_left
|
||||
(fun tf tx ->
|
||||
match Marked.unmark tf with
|
||||
| TArrow (tx', tr) ->
|
||||
assert (Type.unifiable tx.ty tx');
|
||||
(* wrong arg type *)
|
||||
tr
|
||||
| TAny -> tf
|
||||
| _ -> assert false)
|
||||
fty.ty argtys)
|
||||
(List.map Marked.get_mark (e :: u))
|
||||
| fty :: argtys -> (
|
||||
match Marked.unmark fty.ty with
|
||||
| TArrow (tx', tr) ->
|
||||
assert (Type.unifiable_list tx' (List.map (fun x -> x.ty) argtys));
|
||||
tr
|
||||
| TAny -> fty.ty
|
||||
| _ -> assert false))
|
||||
(List.map Marked.get_mark (e :: args))
|
||||
in
|
||||
eapp e u mark
|
||||
eapp e args mark
|
||||
|
||||
let empty_thunked_term mark =
|
||||
let silent = Var.make "_" in
|
||||
|
@ -404,18 +404,19 @@ let translate :
|
||||
| Eq_dur_dur -> Eq_dur_dur
|
||||
|
||||
let monomorphic_type (op, pos) =
|
||||
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
|
||||
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
|
||||
match op with
|
||||
| Not -> TBool @-> TBool
|
||||
| GetDay -> TDate @-> TInt
|
||||
| GetMonth -> TDate @-> TInt
|
||||
| GetYear -> TDate @-> TInt
|
||||
| FirstDayOfMonth -> TDate @-> TDate
|
||||
| LastDayOfMonth -> TDate @-> TDate
|
||||
| And -> TBool @- TBool @-> TBool
|
||||
| Or -> TBool @- TBool @-> TBool
|
||||
| Xor -> TBool @- TBool @-> TBool
|
||||
let args, ret =
|
||||
match op with
|
||||
| Not -> [TBool], TBool
|
||||
| GetDay -> [TDate], TInt
|
||||
| GetMonth -> [TDate], TInt
|
||||
| GetYear -> [TDate], TInt
|
||||
| FirstDayOfMonth -> [TDate], TDate
|
||||
| LastDayOfMonth -> [TDate], TDate
|
||||
| And -> [TBool; TBool], TBool
|
||||
| Or -> [TBool; TBool], TBool
|
||||
| Xor -> [TBool; TBool], TBool
|
||||
in
|
||||
TArrow (List.map (fun tau -> TLit tau, pos) args, (TLit ret, pos)), pos
|
||||
|
||||
(** Rules for overloads definitions:
|
||||
|
||||
@ -431,62 +432,63 @@ let monomorphic_type (op, pos) =
|
||||
['a], ['b] and ['c], there should be a unique solution for the third. *)
|
||||
|
||||
let resolved_type (op, pos) =
|
||||
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
|
||||
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
|
||||
match op with
|
||||
| Minus_int -> TInt @-> TInt
|
||||
| Minus_rat -> TRat @-> TRat
|
||||
| Minus_mon -> TMoney @-> TMoney
|
||||
| Minus_dur -> TDuration @-> TDuration
|
||||
| ToRat_int -> TInt @-> TRat
|
||||
| ToRat_mon -> TMoney @-> TRat
|
||||
| ToMoney_rat -> TRat @-> TMoney
|
||||
| Round_rat -> TRat @-> TRat
|
||||
| Round_mon -> TMoney @-> TMoney
|
||||
| Add_int_int -> TInt @- TInt @-> TInt
|
||||
| Add_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Add_mon_mon -> TMoney @- TMoney @-> TMoney
|
||||
| Add_dat_dur -> TDate @- TDuration @-> TDate
|
||||
| Add_dur_dur -> TDuration @- TDuration @-> TDuration
|
||||
| Sub_int_int -> TInt @- TInt @-> TInt
|
||||
| Sub_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Sub_mon_mon -> TMoney @- TMoney @-> TMoney
|
||||
| Sub_dat_dat -> TDate @- TDate @-> TDuration
|
||||
| Sub_dat_dur -> TDate @- TDuration @-> TDuration
|
||||
| Sub_dur_dur -> TDuration @- TDuration @-> TDuration
|
||||
| Mult_int_int -> TInt @- TInt @-> TInt
|
||||
| Mult_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Mult_mon_rat -> TMoney @- TRat @-> TMoney
|
||||
| Mult_dur_int -> TDuration @- TInt @-> TDuration
|
||||
| Div_int_int -> TInt @- TInt @-> TRat
|
||||
| Div_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Div_mon_mon -> TMoney @- TMoney @-> TRat
|
||||
| Div_mon_rat -> TMoney @- TRat @-> TMoney
|
||||
| Lt_int_int -> TInt @- TInt @-> TBool
|
||||
| Lt_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Lt_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Lt_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Lt_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Lte_int_int -> TInt @- TInt @-> TBool
|
||||
| Lte_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Lte_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Lte_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Lte_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Gt_int_int -> TInt @- TInt @-> TBool
|
||||
| Gt_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Gt_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Gt_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Gt_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Gte_int_int -> TInt @- TInt @-> TBool
|
||||
| Gte_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Gte_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Gte_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Gte_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Eq_int_int -> TInt @- TInt @-> TBool
|
||||
| Eq_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Eq_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Eq_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Eq_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
let args, ret =
|
||||
match op with
|
||||
| Minus_int -> [TInt], TInt
|
||||
| Minus_rat -> [TRat], TRat
|
||||
| Minus_mon -> [TMoney], TMoney
|
||||
| Minus_dur -> [TDuration], TDuration
|
||||
| ToRat_int -> [TInt], TRat
|
||||
| ToRat_mon -> [TMoney], TRat
|
||||
| ToMoney_rat -> [TRat], TMoney
|
||||
| Round_rat -> [TRat], TRat
|
||||
| Round_mon -> [TMoney], TMoney
|
||||
| Add_int_int -> [TInt; TInt], TInt
|
||||
| Add_rat_rat -> [TRat; TRat], TRat
|
||||
| Add_mon_mon -> [TMoney; TMoney], TMoney
|
||||
| Add_dat_dur -> [TDate; TDuration], TDate
|
||||
| Add_dur_dur -> [TDuration; TDuration], TDuration
|
||||
| Sub_int_int -> [TInt; TInt], TInt
|
||||
| Sub_rat_rat -> [TRat; TRat], TRat
|
||||
| Sub_mon_mon -> [TMoney; TMoney], TMoney
|
||||
| Sub_dat_dat -> [TDate; TDate], TDuration
|
||||
| Sub_dat_dur -> [TDate; TDuration], TDuration
|
||||
| Sub_dur_dur -> [TDuration; TDuration], TDuration
|
||||
| Mult_int_int -> [TInt; TInt], TInt
|
||||
| Mult_rat_rat -> [TRat; TRat], TRat
|
||||
| Mult_mon_rat -> [TMoney; TRat], TMoney
|
||||
| Mult_dur_int -> [TDuration; TInt], TDuration
|
||||
| Div_int_int -> [TInt; TInt], TRat
|
||||
| Div_rat_rat -> [TRat; TRat], TRat
|
||||
| Div_mon_mon -> [TMoney; TMoney], TRat
|
||||
| Div_mon_rat -> [TMoney; TRat], TMoney
|
||||
| Lt_int_int -> [TInt; TInt], TBool
|
||||
| Lt_rat_rat -> [TRat; TRat], TBool
|
||||
| Lt_mon_mon -> [TMoney; TMoney], TBool
|
||||
| Lt_dat_dat -> [TDate; TDate], TBool
|
||||
| Lt_dur_dur -> [TDuration; TDuration], TBool
|
||||
| Lte_int_int -> [TInt; TInt], TBool
|
||||
| Lte_rat_rat -> [TRat; TRat], TBool
|
||||
| Lte_mon_mon -> [TMoney; TMoney], TBool
|
||||
| Lte_dat_dat -> [TDate; TDate], TBool
|
||||
| Lte_dur_dur -> [TDuration; TDuration], TBool
|
||||
| Gt_int_int -> [TInt; TInt], TBool
|
||||
| Gt_rat_rat -> [TRat; TRat], TBool
|
||||
| Gt_mon_mon -> [TMoney; TMoney], TBool
|
||||
| Gt_dat_dat -> [TDate; TDate], TBool
|
||||
| Gt_dur_dur -> [TDuration; TDuration], TBool
|
||||
| Gte_int_int -> [TInt; TInt], TBool
|
||||
| Gte_rat_rat -> [TRat; TRat], TBool
|
||||
| Gte_mon_mon -> [TMoney; TMoney], TBool
|
||||
| Gte_dat_dat -> [TDate; TDate], TBool
|
||||
| Gte_dur_dur -> [TDuration; TDuration], TBool
|
||||
| Eq_int_int -> [TInt; TInt], TBool
|
||||
| Eq_rat_rat -> [TRat; TRat], TBool
|
||||
| Eq_mon_mon -> [TMoney; TMoney], TBool
|
||||
| Eq_dat_dat -> [TDate; TDate], TBool
|
||||
| Eq_dur_dur -> [TDuration; TDuration], TBool
|
||||
in
|
||||
TArrow (List.map (fun tau -> TLit tau, pos) args, (TLit ret, pos)), pos
|
||||
|
||||
let resolve_overload_aux (op : ('a, overloaded) t) (operands : typ_lit list) :
|
||||
('b, resolved) t * [ `Straight | `Reversed ] =
|
||||
|
@ -113,9 +113,15 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
|
||||
(EnumConstructor.Map.bindings (EnumName.Map.find e ctx.ctx_enums))
|
||||
punctuation "]")
|
||||
| TOption t -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" typ t
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 op_style "→"
|
||||
| TArrow ([t1], t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" typ_with_parens t1 op_style "→"
|
||||
typ t2
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a%a%a@ %a@ %a@]" op_style "("
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " op_style ",")
|
||||
typ_with_parens)
|
||||
t1 op_style ")" op_style "→" typ t2
|
||||
| TArray t1 ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "collection" typ t1
|
||||
| TAny -> base_type fmt "any"
|
||||
|
@ -155,7 +155,7 @@ let build_typ_from_sig
|
||||
(pos : Pos.t) : typ =
|
||||
let input_typ = Marked.mark pos (TStruct scope_input_struct_name) in
|
||||
let result_typ = Marked.mark pos (TStruct scope_return_struct_name) in
|
||||
Marked.mark pos (TArrow (input_typ, result_typ))
|
||||
Marked.mark pos (TArrow ([input_typ], result_typ))
|
||||
|
||||
type 'e scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of 'e Var.t
|
||||
|
||||
|
@ -29,7 +29,7 @@ let rec equal ty1 ty2 =
|
||||
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
|
||||
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
|
||||
| TOption t1, TOption t2 -> equal t1 t2
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> equal t1 t2 && equal t1' t2'
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> equal_list t1 t2 && equal t1' t2'
|
||||
| TArray t1, TArray t2 -> equal t1 t2
|
||||
| TAny, TAny -> true
|
||||
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
|
||||
@ -49,7 +49,8 @@ let rec unifiable ty1 ty2 =
|
||||
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
|
||||
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
|
||||
| TOption t1, TOption t2 -> unifiable t1 t2
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') ->
|
||||
unifiable_list t1 t2 && unifiable t1' t2'
|
||||
| TArray t1, TArray t2 -> unifiable t1 t2
|
||||
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
|
||||
_ ) ->
|
||||
@ -66,7 +67,7 @@ let rec compare ty1 ty2 =
|
||||
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
|
||||
| TOption t1, TOption t2 -> compare t1 t2
|
||||
| TArrow (a1, b1), TArrow (a2, b2) -> (
|
||||
match compare a1 a2 with 0 -> compare b1 b2 | n -> n)
|
||||
match List.compare compare a1 a2 with 0 -> compare b1 b2 | n -> n)
|
||||
| TArray t1, TArray t2 -> compare t1 t2
|
||||
| TAny, TAny -> 0
|
||||
| TLit _, _ -> -1
|
||||
|
@ -19,8 +19,9 @@ type t = Definitions.typ
|
||||
val equal : t -> t -> bool
|
||||
val equal_list : t list -> t list -> bool
|
||||
val compare : t -> t -> int
|
||||
|
||||
val unifiable : t -> t -> bool
|
||||
|
||||
val unifiable_list : t list -> t list -> bool
|
||||
(** Similar to [equal], but allows TAny holes *)
|
||||
|
||||
val arrow_return : t -> t
|
||||
|
@ -39,7 +39,7 @@ type unionfind_typ = naked_typ Marked.pos UnionFind.elem
|
||||
|
||||
and naked_typ =
|
||||
| TLit of A.typ_lit
|
||||
| TArrow of unionfind_typ * unionfind_typ
|
||||
| TArrow of unionfind_typ list * unionfind_typ
|
||||
| TTuple of unionfind_typ list
|
||||
| TStruct of A.StructName.t
|
||||
| TEnum of A.EnumName.t
|
||||
@ -56,7 +56,7 @@ let rec typ_to_ast ?(unsafe = false) (ty : unionfind_typ) : A.typ =
|
||||
| TStruct s -> A.TStruct s, pos
|
||||
| TEnum e -> A.TEnum e, pos
|
||||
| TOption t -> A.TOption (typ_to_ast t), pos
|
||||
| TArrow (t1, t2) -> A.TArrow (typ_to_ast t1, typ_to_ast t2), pos
|
||||
| TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos
|
||||
| TArray t1 -> A.TArray (typ_to_ast t1), pos
|
||||
| TAny _ ->
|
||||
if unsafe then A.TAny, pos
|
||||
@ -73,14 +73,14 @@ let rec all_resolved ty =
|
||||
| TAny _ -> false
|
||||
| TLit _ | TStruct _ | TEnum _ -> true
|
||||
| TOption t1 | TArray t1 -> all_resolved t1
|
||||
| TArrow (t1, t2) -> all_resolved t1 && all_resolved t2
|
||||
| TArrow (t1, t2) -> List.for_all all_resolved t1 && all_resolved t2
|
||||
| TTuple ts -> List.for_all all_resolved ts
|
||||
|
||||
let rec ast_to_typ (ty : A.typ) : unionfind_typ =
|
||||
let ty' =
|
||||
match Marked.unmark ty with
|
||||
| A.TLit l -> TLit l
|
||||
| A.TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
|
||||
| A.TArrow (t1, t2) -> TArrow (List.map ast_to_typ t1, ast_to_typ t2)
|
||||
| A.TTuple ts -> TTuple (List.map ast_to_typ ts)
|
||||
| A.TStruct s -> TStruct s
|
||||
| A.TEnum e -> TEnum e
|
||||
@ -118,9 +118,15 @@ let rec format_typ
|
||||
| TEnum e -> Format.fprintf fmt "%a" A.EnumName.format_t e
|
||||
| TOption t ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %s@]" format_typ_with_parens t "eoption"
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1
|
||||
| TArrow ([t1], t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ →@ %a@]" format_typ_with_parens t1
|
||||
format_typ t2
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>(%a)@ →@ %a@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
format_typ_with_parens)
|
||||
t1 format_typ t2
|
||||
| TArray t1 -> (
|
||||
match Marked.unmark (UnionFind.get (UnionFind.find t1)) with
|
||||
| TAny _ when not !Cli.debug_flag -> Format.pp_print_string fmt "collection"
|
||||
@ -149,12 +155,13 @@ let rec unify
|
||||
let () =
|
||||
match Marked.unmark t1_repr, Marked.unmark t2_repr with
|
||||
| TLit tl1, TLit tl2 -> if tl1 <> tl2 then raise_type_error ()
|
||||
| TArrow (t11, t12), TArrow (t21, t22) ->
|
||||
| TArrow (t11, t12), TArrow (t21, t22) -> (
|
||||
unify e t12 t22;
|
||||
unify e t11 t21
|
||||
| TTuple ts1, TTuple ts2 ->
|
||||
if List.length ts1 = List.length ts2 then List.iter2 (unify e) ts1 ts2
|
||||
else raise_type_error ()
|
||||
try List.iter2 (unify e) t11 t21
|
||||
with Invalid_argument _ -> raise_type_error ())
|
||||
| TTuple ts1, TTuple ts2 -> (
|
||||
try List.iter2 (unify e) ts1 ts2
|
||||
with Invalid_argument _ -> raise_type_error ())
|
||||
| TStruct s1, TStruct s2 ->
|
||||
if not (A.StructName.equal s1 s2) then raise_type_error ()
|
||||
| TEnum e1, TEnum e2 ->
|
||||
@ -240,19 +247,19 @@ let polymorphic_op_type (op : ('a, Operator.polymorphic) A.operator Marked.pos)
|
||||
let it = lazy (UnionFind.make (TLit TInt, pos)) in
|
||||
let array a = lazy (UnionFind.make (TArray (Lazy.force a), pos)) in
|
||||
let ( @-> ) x y =
|
||||
lazy (UnionFind.make (TArrow (Lazy.force x, Lazy.force y), pos))
|
||||
lazy (UnionFind.make (TArrow (List.map Lazy.force x, Lazy.force y), pos))
|
||||
in
|
||||
let ty =
|
||||
match Marked.unmark op with
|
||||
| Fold -> (any2 @-> any @-> any2) @-> any2 @-> array any @-> any2
|
||||
| Eq -> any @-> any @-> bt
|
||||
| Map -> (any @-> any2) @-> array any @-> array any2
|
||||
| Filter -> (any @-> bt) @-> array any @-> array any
|
||||
| Reduce -> (any @-> any @-> any) @-> any @-> array any @-> any
|
||||
| Concat -> array any @-> array any @-> array any
|
||||
| Log (PosRecordIfTrueBool, _) -> bt @-> bt
|
||||
| Log _ -> any @-> any
|
||||
| Length -> array any @-> it
|
||||
| Fold -> [[any2; any] @-> any2; any2; array any] @-> any2
|
||||
| Eq -> [any; any] @-> bt
|
||||
| Map -> [[any] @-> any2; array any] @-> array any2
|
||||
| Filter -> [[any] @-> bt; array any] @-> array any
|
||||
| Reduce -> [[any; any] @-> any; any; array any] @-> any
|
||||
| Concat -> [array any; array any] @-> array any
|
||||
| Log (PosRecordIfTrueBool, _) -> [bt] @-> bt
|
||||
| Log _ -> [any] @-> any
|
||||
| Length -> [array any] @-> it
|
||||
in
|
||||
Lazy.force ty
|
||||
|
||||
@ -512,7 +519,10 @@ and typecheck_expr_top_down :
|
||||
A.EnumConstructor.Map.mapi
|
||||
(fun c_name e ->
|
||||
let c_ty = A.EnumConstructor.Map.find c_name cases_ty in
|
||||
let e_ty = unionfind ~pos:e (TArrow (ast_to_typ c_ty, t_ret)) in
|
||||
(* For now our constructors are limited to zero or one argument. If
|
||||
there is a change to allow for multiple arguments, it might be
|
||||
easier to use tuples directly. *)
|
||||
let e_ty = unionfind ~pos:e (TArrow ([ast_to_typ c_ty], t_ret)) in
|
||||
typecheck_expr_top_down ctx env e_ty e)
|
||||
cases
|
||||
in
|
||||
@ -571,11 +581,7 @@ and typecheck_expr_top_down :
|
||||
else
|
||||
let tau_args = List.map ast_to_typ t_args in
|
||||
let t_ret = unionfind (TAny (Any.fresh ())) in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun t_arg acc -> unionfind (TArrow (t_arg, acc)))
|
||||
tau_args t_ret
|
||||
in
|
||||
let t_func = unionfind (TArrow (tau_args, t_ret)) in
|
||||
let mark = uf_mark t_func in
|
||||
assert (List.for_all all_resolved tau_args);
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
@ -590,11 +596,7 @@ and typecheck_expr_top_down :
|
||||
Expr.eabs binder' (List.map typ_to_ast tau_args) mark
|
||||
| A.EApp { f = (EOp { op; tys }, _) as e1; args } ->
|
||||
let t_args = List.map ast_to_typ tys in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun t_arg acc -> unionfind (TArrow (t_arg, acc)))
|
||||
t_args tau
|
||||
in
|
||||
let t_func = unionfind (TArrow (t_args, tau)) in
|
||||
let e1', args' =
|
||||
Operator.kind_dispatch op
|
||||
~polymorphic:(fun _ ->
|
||||
@ -630,22 +632,14 @@ and typecheck_expr_top_down :
|
||||
of the arguments if [f] is [EAbs] before disambiguation. This is also the
|
||||
right order for the [let-in] form. *)
|
||||
let t_args = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) args in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun t_arg acc -> unionfind (TArrow (t_arg, acc)))
|
||||
t_args tau
|
||||
in
|
||||
let t_func = unionfind (TArrow (t_args, tau)) in
|
||||
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
|
||||
let e1' = typecheck_expr_top_down ctx env t_func e1 in
|
||||
Expr.eapp e1' args' context_mark
|
||||
| A.EOp { op; tys } ->
|
||||
let tys' = List.map ast_to_typ tys in
|
||||
let t_ret = unionfind (TAny (Any.fresh ())) in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun t_arg acc -> unionfind (TArrow (t_arg, acc)))
|
||||
tys' t_ret
|
||||
in
|
||||
let t_func = unionfind (TArrow (tys', t_ret)) in
|
||||
unify ctx e t_func tau;
|
||||
let tys, mark =
|
||||
Operator.kind_dispatch op
|
||||
@ -790,7 +784,7 @@ let scope_body ctx env body =
|
||||
UnionFind.make
|
||||
(Marked.mark
|
||||
(get_pos body.A.scope_body_output_struct)
|
||||
(TArrow (ty_in, ty_out))) )
|
||||
(TArrow ([ty_in], ty_out))) )
|
||||
|
||||
let rec scopes ctx env = function
|
||||
| A.Nil -> Bindlib.box A.Nil
|
||||
|
@ -405,10 +405,12 @@ let find_or_create_funcdecl (ctx : context) (v : typed expr Var.t) (ty : typ) :
|
||||
| None -> (
|
||||
match Marked.unmark ty with
|
||||
| TArrow (t1, t2) ->
|
||||
let ctx, z3_t1 = translate_typ ctx (Marked.unmark t1) in
|
||||
let ctx, z3_t1 =
|
||||
List.fold_left_map translate_typ ctx (List.map Marked.unmark t1)
|
||||
in
|
||||
let ctx, z3_t2 = translate_typ ctx (Marked.unmark t2) in
|
||||
let name = unique_name v in
|
||||
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in
|
||||
let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name z3_t1 z3_t2 in
|
||||
let ctx = add_funcdecl v fd ctx in
|
||||
let ctx = add_z3var name v ty ctx in
|
||||
ctx, fd
|
||||
|
12
flake.nix
12
flake.nix
@ -1,7 +1,7 @@
|
||||
{
|
||||
inputs = {
|
||||
flake-utils.url = github:numtide/flake-utils;
|
||||
nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
};
|
||||
|
||||
outputs = {nixpkgs, flake-utils, ...}:
|
||||
@ -19,14 +19,18 @@
|
||||
};
|
||||
defaultPackage = packages.catala;
|
||||
devShell = pkgs.mkShell {
|
||||
inputsFrom = [packages.catala];
|
||||
inputsFrom = [ packages.catala ];
|
||||
buildInputs = [
|
||||
pkgs.inotify-tools
|
||||
ocamlPackages.merlin
|
||||
pkgs.ocamlformat
|
||||
pkgs.ocamlformat_0_21_0
|
||||
ocamlPackages.ocp-indent
|
||||
ocamlPackages.utop
|
||||
ocamlPackages.odoc
|
||||
ocamlPackages.ocaml-lsp
|
||||
pkgs.groff
|
||||
pkgs.obelisk
|
||||
pkgs.ninja
|
||||
];
|
||||
};
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ scope B:
|
||||
$ catala Scopelang -s B
|
||||
let scope B (b: bool|input) =
|
||||
let a.f : integer → integer =
|
||||
λ (param: integer) → ⟨b && param >! 0 ⊢ param -! 1⟩;
|
||||
λ (param0: integer) → ⟨b && param0 >! 0 ⊢ param0 -! 1⟩;
|
||||
call A[a]
|
||||
```
|
||||
|
||||
@ -29,8 +29,8 @@ let A =
|
||||
λ (A_in: A_in {"f_in": integer → integer}) →
|
||||
let f : integer → integer = A_in."f_in" in
|
||||
let f1 : integer → integer =
|
||||
λ (param: integer) → error_empty
|
||||
⟨f param | true ⊢ ⟨true ⊢ param +! 1⟩⟩ in
|
||||
λ (param0: integer) → error_empty
|
||||
⟨f param0 | true ⊢ ⟨true ⊢ param0 +! 1⟩⟩ in
|
||||
A { }
|
||||
```
|
||||
|
||||
@ -40,7 +40,7 @@ let B =
|
||||
λ (B_in: B_in {"b_in": bool}) →
|
||||
let b : bool = B_in."b_in" in
|
||||
let a.f : integer → integer =
|
||||
λ (param: integer) → ⟨b && param >! 0 ⊢ param -! 1⟩ in
|
||||
λ (param0: integer) → ⟨b && param0 >! 0 ⊢ param0 -! 1⟩ in
|
||||
let result : A {} = A (A_in { "f_in"= a.f }) in
|
||||
B { }
|
||||
```
|
||||
|
@ -0,0 +1,17 @@
|
||||
## Test basic toplevel values defs
|
||||
|
||||
```catala
|
||||
declaration glob1 content decimal equals 44.12
|
||||
|
||||
declaration scope S:
|
||||
output a content boolean
|
||||
|
||||
scope S:
|
||||
definition a equals glob1 >= 30.
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala Interpret -s S
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT] a = true
|
||||
```
|
@ -26,7 +26,7 @@ $ catala Interpret -t -s HousingComputation
|
||||
│ ‾‾‾‾‾‾
|
||||
|
||||
[LOG] → HousingComputation.f
|
||||
[LOG] ≔ HousingComputation.f.input: 1
|
||||
[LOG] ≔ HousingComputation.f.input0: 1
|
||||
[LOG] ☛ Definition applied:
|
||||
┌─⯈ tests/test_scope/good/scope_call3.catala_en:7.13-14:
|
||||
└─┐
|
||||
@ -43,10 +43,10 @@ $ catala Interpret -t -s HousingComputation
|
||||
7 │ definition f of x equals (output of RentComputation).f of x
|
||||
│ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
|
||||
|
||||
[LOG] ≔ RentComputation.direct.output: RentComputation { "f"= λ (param: integer) → RentComputation { "f"= λ (param1: integer) → error_empty ⟨true ⊢ λ (param2: integer) → error_empty ⟨true ⊢ param2 +! 1⟩ param1 +! 1⟩ }."f" param }
|
||||
[LOG] ≔ RentComputation.direct.output: RentComputation { "f"= λ (param0: integer) → RentComputation { "f"= λ (param01: integer) → error_empty ⟨true ⊢ λ (param02: integer) → error_empty ⟨true ⊢ param02 +! 1⟩ param01 +! 1⟩ }."f" param0 }
|
||||
[LOG] ← RentComputation.direct
|
||||
[LOG] → RentComputation.f
|
||||
[LOG] ≔ RentComputation.f.input: 1
|
||||
[LOG] ≔ RentComputation.f.input0: 1
|
||||
[LOG] ☛ Definition applied:
|
||||
┌─⯈ tests/test_scope/good/scope_call3.catala_en:16.13-14:
|
||||
└──┐
|
||||
@ -54,7 +54,7 @@ $ catala Interpret -t -s HousingComputation
|
||||
│ ‾
|
||||
|
||||
[LOG] → RentComputation.g
|
||||
[LOG] ≔ RentComputation.g.input: 2
|
||||
[LOG] ≔ RentComputation.g.input0: 2
|
||||
[LOG] ☛ Definition applied:
|
||||
┌─⯈ tests/test_scope/good/scope_call3.catala_en:15.13-14:
|
||||
└──┐
|
||||
@ -70,20 +70,20 @@ $ catala Interpret -t -s HousingComputation
|
||||
[LOG] ≔ HousingComputation.result: 3
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT] f =
|
||||
λ (param: integer) → error_empty
|
||||
λ (param0: integer) → error_empty
|
||||
⟨true ⊢
|
||||
let result : RentComputation {"f": integer → integer} =
|
||||
λ (RentComputation_in: RentComputation_in {}) →
|
||||
let g : integer → integer = error_empty
|
||||
(λ (param1: integer) → error_empty
|
||||
⟨true ⊢ param1 +! 1⟩) in
|
||||
(λ (param01: integer) → error_empty
|
||||
⟨true ⊢ param01 +! 1⟩) in
|
||||
let f : integer → integer = error_empty
|
||||
(λ (param1: integer) → error_empty
|
||||
⟨true ⊢ g param1 +! 1⟩) in
|
||||
(λ (param01: integer) → error_empty
|
||||
⟨true ⊢ g param01 +! 1⟩) in
|
||||
RentComputation { "f"= f } RentComputation_in { } in
|
||||
let result1 : RentComputation {"f": integer → integer} =
|
||||
RentComputation { "f"=
|
||||
λ (param1: integer) → result."f" param1 } in
|
||||
if true then result1 else ∅ ."f" param⟩
|
||||
λ (param01: integer) → result."f" param01 } in
|
||||
if true then result1 else ∅ ."f" param0⟩
|
||||
[RESULT] result = 3
|
||||
```
|
||||
|
Loading…
Reference in New Issue
Block a user