mirror of
https://github.com/CatalaLang/catala.git
synced 2024-09-19 16:28:12 +03:00
Generated OCaml has valid syntax
This commit is contained in:
parent
27b6303982
commit
50bccd8d13
1
examples/tutorial_en/.gitignore
vendored
1
examples/tutorial_en/.gitignore
vendored
@ -11,3 +11,4 @@ _minted*
|
||||
*.toc
|
||||
*.pyg
|
||||
*.d
|
||||
*.ml
|
@ -14,6 +14,8 @@
|
||||
|
||||
open Utils
|
||||
|
||||
module ScopeName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
@ -164,3 +166,5 @@ let make_let_in (x : Var.t) (tau : typ Pos.marked) (e1 : expr Pos.marked Bindlib
|
||||
(Bindlib.box_list [ e1 ])
|
||||
|
||||
type binder = (expr, expr Pos.marked) Bindlib.binder
|
||||
|
||||
type program = { decl_ctx : decl_ctx; scopes : (Var.t * expr Pos.marked) list }
|
||||
|
@ -134,7 +134,7 @@ let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
Bindlib.unbox
|
||||
(rule_tree_to_expr ~toplevel:true
|
||||
(Ast.ScopeDef.get_position def_info)
|
||||
(Option.map (fun _ -> Scopelang.Ast.Var.make ("ρ", Pos.no_pos)) is_def_func)
|
||||
(Option.map (fun _ -> Scopelang.Ast.Var.make ("param", Pos.no_pos)) is_def_func)
|
||||
( match top_list with
|
||||
| [] ->
|
||||
(* In this case, there are no rules to define the expression *)
|
||||
|
@ -140,17 +140,17 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
|
||||
Cli.debug_print "Collecting rules...";
|
||||
let prgm = Desugared.Desugared_to_scope.translate_program prgm in
|
||||
Cli.debug_print "Translating to default calculus...";
|
||||
let prgm, ctx = Scopelang.Scope_to_dcalc.translate_program prgm scope_uid in
|
||||
let prgm, prgm_expr = Scopelang.Scope_to_dcalc.translate_program prgm scope_uid in
|
||||
(* Cli.debug_print (Format.asprintf "Output program:@\n%a" (Dcalc.Print.format_expr ctx)
|
||||
prgm); *)
|
||||
Cli.debug_print "Typechecking...";
|
||||
let _typ = Dcalc.Typing.infer_type ctx prgm in
|
||||
let _typ = Dcalc.Typing.infer_type prgm.decl_ctx prgm_expr in
|
||||
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a" Dcalc.Print.format_typ
|
||||
typ); *)
|
||||
match backend with
|
||||
| Cli.Run ->
|
||||
Cli.debug_print "Starting interpretation...";
|
||||
let results = Dcalc.Interpreter.interpret_program ctx prgm in
|
||||
let results = Dcalc.Interpreter.interpret_program prgm.decl_ctx prgm_expr in
|
||||
let results =
|
||||
List.sort
|
||||
(fun (v1, _) (v2, _) -> String.compare (Bindlib.name_of v1) (Bindlib.name_of v2))
|
||||
@ -163,12 +163,13 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
|
||||
(fun (var, result) ->
|
||||
Cli.result_print
|
||||
(Format.asprintf "@[<hov 2>%s@ =@ %a@]" (Bindlib.name_of var)
|
||||
(Dcalc.Print.format_expr ctx) result))
|
||||
(Dcalc.Print.format_expr prgm.decl_ctx)
|
||||
result))
|
||||
results;
|
||||
0
|
||||
| Cli.OCaml ->
|
||||
Cli.debug_print "Compiling program into OCaml...";
|
||||
let prgm, ctx = Lcalc.Compile_with_exceptions.translate_expr prgm ctx in
|
||||
let prgm = Lcalc.Compile_with_exceptions.translate_program prgm in
|
||||
let source_file =
|
||||
match source_file with
|
||||
| FileName f -> f
|
||||
@ -183,7 +184,7 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
|
||||
Cli.debug_print (Printf.sprintf "Writing to %s..." output_file);
|
||||
let oc = open_out output_file in
|
||||
let fmt = Format.formatter_of_out_channel oc in
|
||||
Lcalc.To_ocaml.format_program ctx fmt prgm;
|
||||
Lcalc.To_ocaml.format_program fmt prgm;
|
||||
close_out oc;
|
||||
0
|
||||
| _ -> assert false
|
||||
|
@ -94,3 +94,5 @@ let make_let_in (x : Var.t) (tau : D.typ Pos.marked) (e1 : expr Pos.marked Bindl
|
||||
(Bindlib.box_list [ e1 ])
|
||||
|
||||
type binder = (expr, expr Pos.marked) Bindlib.binder
|
||||
|
||||
type program = { decl_ctx : D.decl_ctx; scopes : (Var.t * expr Pos.marked) list }
|
||||
|
@ -27,6 +27,8 @@ let option_ctx =
|
||||
[ D.EnumConstructor.fresh ("Some", Pos.no_pos); D.EnumConstructor.fresh ("None", Pos.no_pos) ]
|
||||
option_sig
|
||||
|
||||
let handle_default pos = A.make_var (A.Var.make ("fold_exceptions", pos), pos)
|
||||
|
||||
let translate_lit (l : D.lit) : A.expr =
|
||||
match l with
|
||||
| D.LBool l -> A.ELit (A.LBool l)
|
||||
@ -39,69 +41,17 @@ let translate_lit (l : D.lit) : A.expr =
|
||||
| D.LEmptyError -> A.ERaise A.EmptyError
|
||||
|
||||
let rec translate_default (ctx : ctx) (exceptions : D.expr Pos.marked list)
|
||||
(_just : D.expr Pos.marked) (_cons : D.expr Pos.marked) (pos_default : Pos.t) :
|
||||
(just : D.expr Pos.marked) (cons : D.expr Pos.marked) (pos_default : Pos.t) :
|
||||
A.expr Pos.marked Bindlib.box =
|
||||
let none_expr =
|
||||
Bindlib.box (A.EInj ((A.ELit A.LUnit, pos_default), 1, option_enum, option_sig), pos_default)
|
||||
in
|
||||
let some_expr e =
|
||||
Bindlib.box_apply (fun e -> (A.EInj (e, 0, option_enum, option_sig), pos_default)) e
|
||||
in
|
||||
let acc_var = A.Var.make ("acc", pos_default) in
|
||||
let exc_var = A.Var.make ("exc", pos_default) in
|
||||
let acc_some_var = A.Var.make ("acc", pos_default) in
|
||||
let acc_none_var = A.Var.make ("_", pos_default) in
|
||||
let exc_some_var = A.Var.make ("_", pos_default) in
|
||||
let exc_none_var = A.Var.make ("_", pos_default) in
|
||||
let exc_none_case_body = some_expr (A.make_var (acc_some_var, pos_default)) in
|
||||
let exc_some_case_body = Bindlib.box (A.ERaise A.ConflictError, pos_default) in
|
||||
let acc_some_case_body =
|
||||
Bindlib.box_apply4
|
||||
(fun some_exc_var exc_none_case exc_some_case none_expr ->
|
||||
( A.EMatch
|
||||
( (A.ECatch (some_exc_var, A.EmptyError, none_expr), pos_default),
|
||||
[ exc_some_case; exc_none_case ],
|
||||
option_enum ),
|
||||
pos_default ))
|
||||
(some_expr (A.make_var (exc_var, pos_default)))
|
||||
(A.make_abs [| exc_none_var |] exc_none_case_body pos_default
|
||||
[ (D.TLit D.TUnit, pos_default) ] pos_default)
|
||||
(A.make_abs [| exc_some_var |] exc_some_case_body pos_default [ (D.TAny, pos_default) ]
|
||||
pos_default)
|
||||
none_expr
|
||||
in
|
||||
let acc_none_case_body =
|
||||
Bindlib.box_apply2
|
||||
(fun some_exc_var none_expr ->
|
||||
(A.ECatch (some_exc_var, A.EmptyError, none_expr), pos_default))
|
||||
(some_expr (A.make_var (exc_var, pos_default)))
|
||||
none_expr
|
||||
in
|
||||
let fold_body =
|
||||
Bindlib.box_apply3
|
||||
(fun acc_var acc_none_case acc_some_case ->
|
||||
(A.EMatch (acc_var, [ acc_some_case; acc_none_case ], option_enum), pos_default))
|
||||
(A.make_var (acc_var, pos_default))
|
||||
(A.make_abs [| acc_none_var |] acc_none_case_body pos_default
|
||||
[ (D.TLit D.TUnit, pos_default) ] pos_default)
|
||||
(A.make_abs [| acc_some_var |] acc_some_case_body pos_default [ (D.TAny, pos_default) ]
|
||||
pos_default)
|
||||
in
|
||||
let fold_func =
|
||||
A.make_abs [| acc_var; exc_var |] fold_body pos_default
|
||||
[ (D.TAny, pos_default); (D.TAny, pos_default) ]
|
||||
pos_default
|
||||
in
|
||||
let exceptions = List.map (translate_expr ctx) exceptions in
|
||||
let exceptions =
|
||||
A.make_app
|
||||
(Bindlib.box (A.EOp (D.Ternop D.Fold), pos_default))
|
||||
A.make_app (handle_default pos_default)
|
||||
[
|
||||
fold_func;
|
||||
none_expr;
|
||||
Bindlib.box_apply
|
||||
(fun exceptions -> (A.EArray exceptions, pos_default))
|
||||
(Bindlib.box_list exceptions);
|
||||
translate_expr ctx just;
|
||||
translate_expr ctx cons;
|
||||
]
|
||||
pos_default
|
||||
in
|
||||
@ -163,6 +113,27 @@ and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : A.expr Pos.marked Bindl
|
||||
| D.EDefault (exceptions, just, cons) ->
|
||||
translate_default ctx exceptions just cons (Pos.get_position e)
|
||||
|
||||
let translate_expr (e : D.expr Pos.marked) (ctx : D.decl_ctx) : A.expr Pos.marked * D.decl_ctx =
|
||||
( Bindlib.unbox (translate_expr D.VarMap.empty e),
|
||||
{ ctx with D.ctx_enums = D.EnumMap.add option_enum option_ctx ctx.D.ctx_enums } )
|
||||
let translate_program (prgm : D.program) : A.program =
|
||||
{
|
||||
scopes =
|
||||
(let acc, _ =
|
||||
List.fold_left
|
||||
(fun ((acc, ctx) : 'a * A.Var.t D.VarMap.t) (n, e) ->
|
||||
let new_n = A.Var.make (Bindlib.name_of n, Pos.no_pos) in
|
||||
let new_acc =
|
||||
( new_n,
|
||||
Bindlib.unbox
|
||||
(translate_expr (D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx) e) )
|
||||
:: acc
|
||||
in
|
||||
let new_ctx = D.VarMap.add n new_n ctx in
|
||||
(new_acc, new_ctx))
|
||||
([], D.VarMap.empty) prgm.scopes
|
||||
in
|
||||
List.rev acc);
|
||||
decl_ctx =
|
||||
{
|
||||
prgm.decl_ctx with
|
||||
D.ctx_enums = D.EnumMap.add option_enum option_ctx prgm.decl_ctx.D.ctx_enums;
|
||||
};
|
||||
}
|
||||
|
@ -16,17 +16,14 @@ open Utils
|
||||
open Ast
|
||||
|
||||
let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
|
||||
Dcalc.Print.format_lit fmt
|
||||
(Pos.same_pos_as
|
||||
( match Pos.unmark l with
|
||||
| LBool b -> Dcalc.Ast.LBool b
|
||||
| LInt i -> Dcalc.Ast.LInt i
|
||||
| LUnit -> Dcalc.Ast.LUnit
|
||||
| LRat i -> Dcalc.Ast.LRat i
|
||||
| LMoney e -> Dcalc.Ast.LMoney e
|
||||
| LDate d -> Dcalc.Ast.LDate d
|
||||
| LDuration d -> Dcalc.Ast.LDuration d )
|
||||
l)
|
||||
match Pos.unmark l with
|
||||
| LBool b -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LBool b) l)
|
||||
| LInt i -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LInt i) l)
|
||||
| LUnit -> Dcalc.Print.format_lit fmt (Pos.same_pos_as Dcalc.Ast.LUnit l)
|
||||
| LRat i -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LRat i) l)
|
||||
| LMoney e -> Format.fprintf fmt "mk_money@ %.2f" Q.(to_float (of_bigint e / of_int 100))
|
||||
| LDate d -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LDate d) l)
|
||||
| LDuration d -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LDuration d) l)
|
||||
|
||||
let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) =
|
||||
Format.fprintf fmt "%s"
|
||||
@ -107,7 +104,11 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : u
|
||||
| TAny -> Format.fprintf fmt "_"
|
||||
|
||||
let format_var (fmt : Format.formatter) (v : Var.t) : unit =
|
||||
Format.fprintf fmt "%s" (String.lowercase_ascii (Bindlib.name_of v))
|
||||
let lowercase_name = String.lowercase_ascii (Bindlib.name_of v) in
|
||||
let lowercase_name =
|
||||
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") lowercase_name
|
||||
in
|
||||
Format.fprintf fmt "%s" lowercase_name
|
||||
|
||||
let needs_parens (_e : expr Pos.marked) : bool = true
|
||||
|
||||
@ -134,11 +135,11 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
|
||||
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
|
||||
es
|
||||
| ETuple (es, Some s) ->
|
||||
Format.fprintf fmt "%a {@[<hov 2>%a@]}" Dcalc.Ast.StructName.format_t s
|
||||
Format.fprintf fmt "{@[<hov 2>%a@]}"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt (e, struct_field) ->
|
||||
Format.fprintf fmt "\"%a\":@ %a" Dcalc.Ast.StructFieldName.format_t struct_field
|
||||
Format.fprintf fmt "%a=@ %a" Dcalc.Ast.StructFieldName.format_t struct_field
|
||||
format_expr e))
|
||||
(List.combine es (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs)))
|
||||
| EArray es ->
|
||||
@ -147,11 +148,17 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
|
||||
es
|
||||
| ETupleAccess (e1, n, s, _ts) -> (
|
||||
| ETupleAccess (e1, n, s, ts) -> (
|
||||
match s with
|
||||
| None -> Format.fprintf fmt "%a.%d" format_expr e1 n
|
||||
| None ->
|
||||
Format.fprintf fmt "let@ %a@ = %a@ in@ x"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt i -> Format.fprintf fmt "%s" (if i = n then "x" else "_")))
|
||||
(List.mapi (fun i _ -> i) ts)
|
||||
format_expr e1
|
||||
| Some s ->
|
||||
Format.fprintf fmt "%a.\"%a\"" format_expr e1 Dcalc.Ast.StructFieldName.format_t
|
||||
Format.fprintf fmt "%a.%a" format_expr e1 Dcalc.Ast.StructFieldName.format_t
|
||||
(fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) )
|
||||
| EInj (e, n, en, _ts) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Ast.EnumConstructor.format_t
|
||||
@ -218,7 +225,7 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
|
||||
| EAssert e' -> Format.fprintf fmt "@[<hov 2>assert@ (%a)@]" format_expr e'
|
||||
| ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception exc
|
||||
| ECatch (e1, exc, e2) ->
|
||||
Format.fprintf fmt "@[<hov 2>try@ %a@ with %a -> %a@]" format_expr e1 format_exception exc
|
||||
Format.fprintf fmt "@[<hov 2>try@ %a@ with@ %a@ ->@ %a@]" format_expr e1 format_exception exc
|
||||
format_expr e2
|
||||
|
||||
let format_ctx (fmt : Format.formatter) (ctx : D.decl_ctx) : unit =
|
||||
@ -235,7 +242,7 @@ let format_ctx (fmt : Format.formatter) (ctx : D.decl_ctx) : unit =
|
||||
struct_fields))
|
||||
(Dcalc.Ast.StructMap.bindings ctx.Dcalc.Ast.ctx_structs)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n")
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun _fmt (enum_name, enum_cons) ->
|
||||
Format.fprintf fmt "type %a =@\n@[<hov 2> %a@]@\n" format_enum_name enum_name
|
||||
(Format.pp_print_list
|
||||
@ -244,7 +251,15 @@ let format_ctx (fmt : Format.formatter) (ctx : D.decl_ctx) : unit =
|
||||
Format.fprintf fmt "| %a@ of@ %a" Dcalc.Ast.EnumConstructor.format_t enum_cons
|
||||
format_typ enum_cons_type))
|
||||
enum_cons))
|
||||
(Dcalc.Ast.EnumMap.bindings ctx.Dcalc.Ast.ctx_enums)
|
||||
(List.filter
|
||||
(* option is a polymorphic type which we don't handle well... *)
|
||||
(fun (e, _) -> e <> Compile_with_exceptions.option_enum)
|
||||
(Dcalc.Ast.EnumMap.bindings ctx.Dcalc.Ast.ctx_enums))
|
||||
|
||||
let format_program (ctx : D.decl_ctx) (fmt : Format.formatter) (e : Ast.expr Pos.marked) : unit =
|
||||
Format.fprintf fmt "%a\n\n%a" format_ctx ctx (format_expr ctx) e
|
||||
let format_program (fmt : Format.formatter) (p : Ast.program) : unit =
|
||||
Format.fprintf fmt "%a@\n@\n%a" format_ctx p.decl_ctx
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
|
||||
(fun fmt (name, e) ->
|
||||
Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var name (format_expr p.decl_ctx) e))
|
||||
p.scopes
|
||||
|
@ -18,7 +18,7 @@ open Utils
|
||||
|
||||
(** {1 Identifiers} *)
|
||||
|
||||
module ScopeName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) ()
|
||||
module ScopeName = Dcalc.Ast.ScopeName
|
||||
|
||||
module ScopeNameSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName)
|
||||
|
||||
|
@ -518,7 +518,7 @@ let build_scope_typ_from_sig (scope_sig : (Ast.ScopeVar.t * Dcalc.Ast.typ) list)
|
||||
scope_sig result_typ
|
||||
|
||||
let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName.t) :
|
||||
Dcalc.Ast.expr Pos.marked * Dcalc.Ast.decl_ctx =
|
||||
Dcalc.Ast.program * Dcalc.Ast.expr Pos.marked =
|
||||
let scope_dependencies = Dependency.build_program_dep_graph prgm in
|
||||
Dependency.check_for_cycle_in_scope scope_dependencies;
|
||||
Dependency.check_type_cycles prgm.program_structs prgm.program_enums;
|
||||
@ -558,17 +558,16 @@ let translate_program (prgm : Ast.program) (top_level_scope_name : Ast.ScopeName
|
||||
let acc = Dcalc.Ast.make_var (snd (Ast.ScopeMap.find top_level_scope_name sctx), Pos.no_pos) in
|
||||
(* the resulting expression is the list of definitions of all the scopes, ending with the
|
||||
top-level scope. *)
|
||||
( Bindlib.unbox
|
||||
(let acc =
|
||||
List.fold_right
|
||||
(fun scope_name (acc : Dcalc.Ast.expr Pos.marked Bindlib.box) ->
|
||||
let scope = Ast.ScopeMap.find scope_name prgm.program_scopes in
|
||||
let pos_scope = Pos.get_position (Ast.ScopeName.get_info scope.scope_decl_name) in
|
||||
let scope_expr = translate_scope_decl struct_ctx enum_ctx sctx scope_name scope in
|
||||
let scope_sig, dvar = Ast.ScopeMap.find scope_name sctx in
|
||||
let scope_typ = build_scope_typ_from_sig scope_sig pos_scope in
|
||||
Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc)
|
||||
scope_ordering acc
|
||||
in
|
||||
acc),
|
||||
decl_ctx )
|
||||
let whole_program_expr, scopes =
|
||||
List.fold_right
|
||||
(fun scope_name (acc, scopes) ->
|
||||
let scope = Ast.ScopeMap.find scope_name prgm.program_scopes in
|
||||
let pos_scope = Pos.get_position (Ast.ScopeName.get_info scope.scope_decl_name) in
|
||||
let scope_expr = translate_scope_decl struct_ctx enum_ctx sctx scope_name scope in
|
||||
let scope_sig, dvar = Ast.ScopeMap.find scope_name sctx in
|
||||
let scope_typ = build_scope_typ_from_sig scope_sig pos_scope in
|
||||
( Dcalc.Ast.make_let_in dvar scope_typ scope_expr acc,
|
||||
(dvar, Bindlib.unbox scope_expr) :: scopes ))
|
||||
scope_ordering (acc, [])
|
||||
in
|
||||
({ scopes; decl_ctx }, Bindlib.unbox whole_program_expr)
|
||||
|
Loading…
Reference in New Issue
Block a user