mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
implementation of fold, reduce, map & filter in without exceptions
Work in progress: not working for filter & map
This commit is contained in:
parent
61830bc348
commit
69ac8ca929
@ -203,8 +203,9 @@ let _ = monad_handle_default
|
||||
|
||||
let trans_var ctx (x : 'm D.expr Var.t) : 'm Ast.expr Var.t =
|
||||
let new_ = (Var.Map.find x ctx).var in
|
||||
Cli.debug_format "before: %a after: %a" Print.var_debug x Print.var_debug new_;
|
||||
|
||||
(* Cli.debug_format "before: %a after: %a" Print.var_debug x Print.var_debug
|
||||
new_; *)
|
||||
new_
|
||||
|
||||
let trans_op : type m. (dcalc, m) Op.t -> (lcalc, m) Op.t =
|
||||
@ -300,12 +301,125 @@ let rec trans ctx (e : 'm D.expr) : (lcalc, 'm mark) boxed_gexpr =
|
||||
monad_bind_var (trans ctx' body) var' (trans ctx arg) ~mark
|
||||
| EApp { f = EApp { f = EOp { op = Op.Log _; _ }, _; args = _ }, _; _ } ->
|
||||
assert false
|
||||
(* | EApp { f = EOp { op = Op.Fold; tys }, opmark; args = [f; init; l] } -> (*
|
||||
*) let x1 = Var.make "x1" in let x2 = var.make "x2" in let body = monad let
|
||||
f' = assert false in monad_mbind (Expr.eop (trans_op Op.Fold) tys opmark)
|
||||
[f'; trans ctx init; trans ctx l] ~mark | EApp { f = EOp { op = Op.Fold; _
|
||||
}, _; _ } -> (* Cannot happend: folds must be fully determined *) assert
|
||||
false *)
|
||||
| EApp { f = EOp { op = Op.Fold; tys }, opmark; args = [f; init; l] } ->
|
||||
(* The function f should have type b -> a -> a. Hence, its translation has
|
||||
type [b] -> [a] -> option [a]. But we need a function of type option [b]
|
||||
-> option [a] -> option [a] for the type checking of fold. Hence, we
|
||||
"iota-expand" the function as follows: [λ x y. bindm x y. [f] x y] *)
|
||||
let x1 = Var.make "x1" in
|
||||
let x2 = Var.make "x2" in
|
||||
let f' =
|
||||
monad_bind_cont ~mark
|
||||
(fun f ->
|
||||
monad_return ~mark
|
||||
(Expr.eabs
|
||||
(Expr.bind [| x1; x2 |]
|
||||
(monad_mbind_cont ~mark
|
||||
(fun vars ->
|
||||
Expr.eapp (Expr.evar f m)
|
||||
(ListLabels.map vars ~f:(fun v -> Expr.evar v m))
|
||||
m)
|
||||
[Expr.evar x1 m; Expr.evar x2 m]))
|
||||
[TAny, pos; TAny, pos]
|
||||
m))
|
||||
(trans ctx f)
|
||||
in
|
||||
monad_mbind
|
||||
(Expr.eop (trans_op Op.Fold) tys opmark)
|
||||
[f'; monad_return ~mark (trans ctx init); trans ctx l]
|
||||
~mark
|
||||
| EApp { f = EOp { op = Op.Fold; _ }, _; _ } ->
|
||||
(* Cannot happend: folds must be fully determined *) assert false
|
||||
| EApp { f = EOp { op = Op.Reduce; tys }, opmark; args = [f; init; l] } ->
|
||||
(* The function f should have type b -> a -> a. Hence, its translation has
|
||||
type [b] -> [a] -> option [a]. But we need a function of type option [b]
|
||||
-> option [a] -> option [a] for the type checking of fold. Hence, we
|
||||
"iota-expand" the function as follows: [λ x y. bindm x y. [f] x y] *)
|
||||
let x1 = Var.make "x1" in
|
||||
let x2 = Var.make "x2" in
|
||||
let f' =
|
||||
monad_bind_cont ~mark
|
||||
(fun f ->
|
||||
monad_return ~mark
|
||||
(Expr.eabs
|
||||
(Expr.bind [| x1; x2 |]
|
||||
(monad_mbind_cont ~mark
|
||||
(fun vars ->
|
||||
Expr.eapp (Expr.evar f m)
|
||||
(ListLabels.map vars ~f:(fun v -> Expr.evar v m))
|
||||
m)
|
||||
[Expr.evar x1 m; Expr.evar x2 m]))
|
||||
[TAny, pos; TAny, pos]
|
||||
m))
|
||||
(trans ctx f)
|
||||
in
|
||||
monad_mbind
|
||||
(Expr.eop (trans_op Op.Reduce) tys opmark)
|
||||
[f'; monad_return ~mark (trans ctx init); trans ctx l]
|
||||
~mark
|
||||
| EApp { f = EOp { op = Op.Reduce; _ }, _; _ } ->
|
||||
(* Cannot happend: folds must be fully determined *) assert false
|
||||
| EApp { f = EOp { op = Op.Map; tys }, opmark; args = [f; l] } ->
|
||||
(* The function f should have type b -> a -> a. Hence, its translation has
|
||||
type [b] -> [a] -> option [a]. But we need a function of type option [b]
|
||||
-> option [a] -> option [a] for the type checking of fold. Hence, we
|
||||
"iota-expand" the function as follows: [λ x y. bindm x y. [f] x y] *)
|
||||
let x1 = Var.make "x1" in
|
||||
let f' =
|
||||
monad_bind_cont ~mark
|
||||
(fun f ->
|
||||
monad_return ~mark
|
||||
(Expr.eabs
|
||||
(Expr.bind [| x1 |]
|
||||
(monad_mbind_cont ~mark
|
||||
(fun vars ->
|
||||
Expr.eapp (Expr.evar f m)
|
||||
(ListLabels.map vars ~f:(fun v -> Expr.evar v m))
|
||||
m)
|
||||
[Expr.evar x1 m]))
|
||||
[TAny, pos]
|
||||
m))
|
||||
(trans ctx f)
|
||||
in
|
||||
monad_mbind_cont
|
||||
(fun vars ->
|
||||
monad_return ~mark
|
||||
(Expr.eapp
|
||||
(Expr.eop (trans_op Op.Map) tys opmark)
|
||||
(ListLabels.map vars ~f:(fun v -> Expr.evar v m))
|
||||
mark))
|
||||
[f'; trans ctx l]
|
||||
~mark
|
||||
| EApp { f = EOp { op = Op.Map; _ }, _; _ } ->
|
||||
(* Cannot happend: folds must be fully determined *) assert false
|
||||
| EApp { f = EOp { op = Op.Filter; tys }, opmark; args = [f; l] } ->
|
||||
(* The function f should have type b -> a -> a. Hence, its translation has
|
||||
type [b] -> [a] -> option [a]. But we need a function of type option [b]
|
||||
-> option [a] -> option [a] for the type checking of fold. Hence, we
|
||||
"iota-expand" the function as follows: [λ x y. bindm x y. [f] x y] *)
|
||||
let x1 = Var.make "x1" in
|
||||
let f' =
|
||||
monad_bind_cont ~mark
|
||||
(fun f ->
|
||||
monad_return ~mark
|
||||
(Expr.eabs
|
||||
(Expr.bind [| x1 |]
|
||||
(monad_mbind_cont ~mark
|
||||
(fun vars ->
|
||||
Expr.eapp (Expr.evar f m)
|
||||
(ListLabels.map vars ~f:(fun v -> Expr.evar v m))
|
||||
m)
|
||||
[Expr.evar x1 m]))
|
||||
[TAny, pos]
|
||||
m))
|
||||
(trans ctx f)
|
||||
in
|
||||
monad_mbind
|
||||
(Expr.eop (trans_op Op.Filter) tys opmark)
|
||||
[f'; trans ctx l]
|
||||
~mark
|
||||
| EApp { f = EOp { op = Op.Filter; _ }, _; _ } ->
|
||||
(* Cannot happend: folds must be fully determined *) assert false
|
||||
| EApp { f = EOp { op; tys }, opmark; args } ->
|
||||
let res =
|
||||
monad_mmap
|
||||
@ -495,7 +609,7 @@ let rec trans_scope_let ctx s =
|
||||
let next_var' = Var.translate next_var in
|
||||
let ctx' =
|
||||
Var.Map.add next_var
|
||||
{ info_pure = false; is_scope = false; var = next_var' }
|
||||
{ info_pure = true; is_scope = false; var = next_var' }
|
||||
ctx
|
||||
in
|
||||
|
||||
|
@ -18,5 +18,4 @@ let translate_program_with_exceptions =
|
||||
Compile_with_exceptions.translate_program
|
||||
|
||||
let translate_program_without_exceptions p =
|
||||
Shared_ast.Typing.program ~leave_unresolved:true
|
||||
(Compile_without_exceptions.translate_program p)
|
||||
Compile_without_exceptions.translate_program p
|
||||
|
@ -19,7 +19,7 @@ val translate_program_with_exceptions : 'm Dcalc.Ast.program -> 'm Ast.program
|
||||
translation uses exceptions to handle empty default terms. *)
|
||||
|
||||
val translate_program_without_exceptions :
|
||||
Shared_ast.typed Dcalc.Ast.program -> Shared_ast.typed Ast.program
|
||||
Shared_ast.typed Dcalc.Ast.program -> Shared_ast.untyped Ast.program
|
||||
(** Translation from the default calculus to the lambda calculus. This
|
||||
translation uses an option monad to handle empty defaults terms. This
|
||||
transformation is one piece to permit to compile toward legacy languages
|
||||
|
@ -94,7 +94,9 @@ let rec ast_to_typ (ty : A.typ) : unionfind_typ =
|
||||
|
||||
let typ_needs_parens (t : unionfind_typ) : bool =
|
||||
let t = UnionFind.get (UnionFind.find t) in
|
||||
match Marked.unmark t with TArrow _ | TArray _ -> true | _ -> false
|
||||
match Marked.unmark t with
|
||||
| TArrow _ | TArray _ | TOption _ -> true
|
||||
| _ -> false
|
||||
|
||||
let rec format_typ
|
||||
(ctx : A.decl_ctx)
|
||||
@ -117,7 +119,8 @@ let rec format_typ
|
||||
| TStruct s -> Format.fprintf fmt "%a" A.StructName.format_t s
|
||||
| TEnum e -> Format.fprintf fmt "%a" A.EnumName.format_t e
|
||||
| TOption t ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %s@]" format_typ_with_parens t "eoption"
|
||||
Format.fprintf fmt "@[<hov 2>(%a)@ %s@]" format_typ_with_parens t
|
||||
"ffeoption"
|
||||
| TArrow ([t1], t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ →@ %a@]" format_typ_with_parens t1
|
||||
format_typ t2
|
||||
@ -379,13 +382,13 @@ and typecheck_expr_top_down :
|
||||
| A.Typed { A.ty; _ } -> unify ctx e tau (ast_to_typ ty)
|
||||
in
|
||||
let context_mark = { uf = tau; pos = pos_e } in
|
||||
let uf_mark uf =
|
||||
let mark_with_tau_and_unify uf =
|
||||
(* Unify with the supplied type first, and return the mark *)
|
||||
unify ctx e uf tau;
|
||||
{ uf; pos = pos_e }
|
||||
in
|
||||
let unionfind ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||
let ty_mark ty = uf_mark (unionfind ty) in
|
||||
let ty_mark ty = mark_with_tau_and_unify (unionfind ty) in
|
||||
match Marked.unmark e with
|
||||
| A.ELocation loc ->
|
||||
let ty_opt =
|
||||
@ -403,7 +406,7 @@ and typecheck_expr_top_down :
|
||||
Errors.raise_spanned_error pos_e "Reference to %a not found"
|
||||
(Expr.format ctx) e
|
||||
in
|
||||
Expr.elocation loc (uf_mark (ast_to_typ ty))
|
||||
Expr.elocation loc (mark_with_tau_and_unify (ast_to_typ ty))
|
||||
| A.EStruct { name; fields } ->
|
||||
let mark = ty_mark (TStruct name) in
|
||||
let str_ast = A.StructName.Map.find name ctx.A.ctx_structs in
|
||||
@ -496,7 +499,7 @@ and typecheck_expr_top_down :
|
||||
in
|
||||
A.StructField.Map.find field str
|
||||
in
|
||||
let mark = uf_mark fld_ty in
|
||||
let mark = mark_with_tau_and_unify fld_ty in
|
||||
Expr.edstructaccess e_struct' field (Some name) mark
|
||||
| A.EStructAccess { e = e_struct; name; field } ->
|
||||
let fld_ty =
|
||||
@ -517,48 +520,53 @@ and typecheck_expr_top_down :
|
||||
"Structure %a doesn't define a field %a" A.StructName.format_t name
|
||||
A.StructField.format_t field
|
||||
in
|
||||
let mark = uf_mark fld_ty in
|
||||
let mark = mark_with_tau_and_unify fld_ty in
|
||||
let e_struct' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind (TStruct name)) e_struct
|
||||
in
|
||||
Expr.estructaccess e_struct' field name mark
|
||||
| A.EInj { name; cons; e = e_enum }
|
||||
when name = Definitions.option_enum && cons = Definitions.some_constr ->
|
||||
when Definitions.EnumName.compare name Definitions.option_enum = 0
|
||||
&& Definitions.EnumConstructor.compare cons Definitions.some_constr = 0
|
||||
->
|
||||
let cell_type = unionfind (TAny (Any.fresh ())) in
|
||||
let mark = uf_mark (unionfind (TOption cell_type)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
|
||||
let e_enum' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env cell_type e_enum
|
||||
in
|
||||
Expr.einj e_enum' cons name mark
|
||||
| A.EInj { name; cons; e = e_enum }
|
||||
when name = Definitions.option_enum && cons = Definitions.none_constr ->
|
||||
when Definitions.EnumName.compare name Definitions.option_enum = 0
|
||||
&& Definitions.EnumConstructor.compare cons Definitions.none_constr = 0
|
||||
->
|
||||
let cell_type = unionfind (TAny (Any.fresh ())) in
|
||||
let mark = uf_mark (unionfind (TOption cell_type)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
|
||||
let e_enum' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TLit TUnit))
|
||||
e_enum
|
||||
in
|
||||
Expr.einj e_enum' cons name mark
|
||||
| A.EInj { name; cons; e = e_enum } ->
|
||||
let mark = uf_mark (unionfind (TEnum name)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TEnum name)) in
|
||||
let e_enum' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(A.EnumConstructor.Map.find cons (A.EnumName.Map.find name env.enums))
|
||||
e_enum
|
||||
in
|
||||
Expr.einj e_enum' cons name mark
|
||||
| A.EMatch { e = e1; name; cases } when name = Definitions.option_enum ->
|
||||
let cell_type = TAny (Any.fresh ()) in
|
||||
let t_arg = unionfind ~pos:e1 (TOption (unionfind ~pos:e1 cell_type)) in
|
||||
| A.EMatch { e = e1; name; cases }
|
||||
when Definitions.EnumName.compare name Definitions.option_enum = 0 ->
|
||||
let cell_type = unionfind ~pos:e1 (TAny (Any.fresh ())) in
|
||||
let t_arg = unionfind ~pos:e1 (TOption cell_type) in
|
||||
let cases_ty =
|
||||
ListLabels.fold_right2
|
||||
[A.none_constr; A.some_constr]
|
||||
[TLit TUnit; cell_type] ~f:A.EnumConstructor.Map.add
|
||||
~init:A.EnumConstructor.Map.empty
|
||||
[unionfind ~pos:e1 (TLit TUnit); cell_type]
|
||||
~f:A.EnumConstructor.Map.add ~init:A.EnumConstructor.Map.empty
|
||||
in
|
||||
let t_ret = TAny (Any.fresh ()) in
|
||||
let mark = uf_mark (unionfind ~pos:e t_ret) in
|
||||
let t_ret = unionfind ~pos:e (TAny (Any.fresh ())) in
|
||||
let mark = mark_with_tau_and_unify t_ret in
|
||||
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_arg e1 in
|
||||
let cases' =
|
||||
A.EnumConstructor.MapLabels.merge cases cases_ty ~f:(fun _ e e_ty ->
|
||||
@ -566,8 +574,7 @@ and typecheck_expr_top_down :
|
||||
| Some e, Some e_ty ->
|
||||
Some
|
||||
(typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind ~pos:e
|
||||
(TArrow ([unionfind ~pos:e e_ty], unionfind ~pos:e t_ret)))
|
||||
(unionfind ~pos:e (TArrow ([e_ty], t_ret)))
|
||||
e)
|
||||
| _ -> assert false)
|
||||
in
|
||||
@ -576,7 +583,7 @@ and typecheck_expr_top_down :
|
||||
| A.EMatch { e = e1; name; cases } ->
|
||||
let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in
|
||||
let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in
|
||||
let mark = uf_mark t_ret in
|
||||
let mark = mark_with_tau_and_unify t_ret in
|
||||
let e1' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TEnum name))
|
||||
e1
|
||||
@ -597,7 +604,7 @@ and typecheck_expr_top_down :
|
||||
let scope_out_struct =
|
||||
(A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name
|
||||
in
|
||||
let mark = uf_mark (unionfind (TStruct scope_out_struct)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TStruct scope_out_struct)) in
|
||||
let vars = A.ScopeName.Map.find scope env.scopes in
|
||||
let args' =
|
||||
A.ScopeVar.Map.mapi
|
||||
@ -622,11 +629,11 @@ and typecheck_expr_top_down :
|
||||
Errors.raise_spanned_error pos_e
|
||||
"Variable %s not found in the current context" (Bindlib.name_of v)
|
||||
in
|
||||
Expr.evar (Var.translate v) (uf_mark tau')
|
||||
Expr.evar (Var.translate v) (mark_with_tau_and_unify tau')
|
||||
| A.ELit lit -> Expr.elit lit (ty_mark (lit_type lit))
|
||||
| A.ETuple es ->
|
||||
let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in
|
||||
let mark = uf_mark (unionfind (TTuple tys)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TTuple tys)) in
|
||||
let es' =
|
||||
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) tys es
|
||||
in
|
||||
@ -656,8 +663,7 @@ and typecheck_expr_top_down :
|
||||
let tau_args = List.map ast_to_typ t_args in
|
||||
let t_ret = unionfind (TAny (Any.fresh ())) in
|
||||
let t_func = unionfind (TArrow (tau_args, t_ret)) in
|
||||
let mark = uf_mark t_func in
|
||||
if not leave_unresolved then assert (List.for_all all_resolved tau_args);
|
||||
let mark = mark_with_tau_and_unify t_func in
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map Var.translate xs in
|
||||
let env =
|
||||
@ -742,10 +748,12 @@ and typecheck_expr_top_down :
|
||||
let tys, mark =
|
||||
Operator.kind_dispatch op
|
||||
~polymorphic:(fun op ->
|
||||
tys, uf_mark (polymorphic_op_type (Marked.mark pos_e op)))
|
||||
( tys,
|
||||
mark_with_tau_and_unify (polymorphic_op_type (Marked.mark pos_e op))
|
||||
))
|
||||
~monomorphic:(fun op ->
|
||||
let mark =
|
||||
uf_mark
|
||||
mark_with_tau_and_unify
|
||||
(ast_to_typ (Operator.monomorphic_type (Marked.mark pos_e op)))
|
||||
in
|
||||
List.map (typ_to_ast ~leave_unresolved) tys', mark)
|
||||
@ -756,7 +764,8 @@ and typecheck_expr_top_down :
|
||||
{ uf = t_func; pos = pos_e } ))
|
||||
~resolved:(fun op ->
|
||||
let mark =
|
||||
uf_mark (ast_to_typ (Operator.resolved_type (Marked.mark pos_e op)))
|
||||
mark_with_tau_and_unify
|
||||
(ast_to_typ (Operator.resolved_type (Marked.mark pos_e op)))
|
||||
in
|
||||
List.map (typ_to_ast ~leave_unresolved) tys', mark)
|
||||
in
|
||||
@ -782,7 +791,7 @@ and typecheck_expr_top_down :
|
||||
in
|
||||
Expr.eifthenelse cond' et' ef' context_mark
|
||||
| A.EAssert e1 ->
|
||||
let mark = uf_mark (unionfind (TLit TUnit)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TLit TUnit)) in
|
||||
let e1' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind ~pos:e1 (TLit TBool))
|
||||
@ -794,7 +803,7 @@ and typecheck_expr_top_down :
|
||||
Expr.eerroronempty e1' context_mark
|
||||
| A.EArray es ->
|
||||
let cell_type = unionfind (TAny (Any.fresh ())) in
|
||||
let mark = uf_mark (unionfind (TArray cell_type)) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TArray cell_type)) in
|
||||
let es' =
|
||||
List.map (typecheck_expr_top_down ~leave_unresolved ctx env cell_type) es
|
||||
in
|
||||
|
Loading…
Reference in New Issue
Block a user