Merge branch 'master' into refactor-clerk-w-ninja

This commit is contained in:
Emile Rolley 2022-02-15 20:36:15 +01:00 committed by GitHub
commit e58c6c52b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 435 additions and 104 deletions

38
.nix/cmdliner.nix Normal file
View File

@ -0,0 +1,38 @@
{ lib, stdenv, fetchurl, ocaml, findlib, ocamlbuild, topkg, result }:
let
pname = "cmdliner";
in
assert lib.versionAtLeast ocaml.version "4.01.0";
let param =
{
version = "1.1.0";
hash = "sha256-irWd4HTlJSYuz3HMgi1de2GVL2qus0QjeCe1WdsSs8Q=";
}
; in
stdenv.mkDerivation rec {
name = "ocaml${ocaml.version}-${pname}-${version}";
inherit (param) version;
src = fetchurl {
url = "https://erratique.ch/software/${pname}/releases/${pname}-${version}.tbz";
inherit (param) hash;
};
nativeBuildInputs = [ ocaml ocamlbuild findlib ];
buildInputs = [ topkg ];
propagatedBuildInputs = [ result ];
inherit (topkg) buildPhase installPhase;
meta = with lib; {
homepage = "https://erratique.ch/software/cmdliner";
description = "An OCaml module for the declarative definition of command line interfaces";
license = licenses.bsd3;
platforms = ocaml.meta.platforms or [];
maintainers = [ ];
};
}

View File

@ -114,6 +114,8 @@ let catala_backend_to_string (backend : Cli.backend_option) : string =
| Cli.Html -> "Html"
| Cli.Python -> "Python"
| Cli.Typecheck -> "Typecheck"
| Cli.Scalc -> "Scalc"
| Cli.Lcalc -> "Lcalc"
type expected_output_descr = {
base_filename : string;

View File

@ -21,7 +21,7 @@ depends: [
"menhirLib" {>= "20200211"}
"unionFind" {>= "20200320"}
"bindlib" {>= "5.0.1"}
"cmdliner" {>= "1.0.4"}
"cmdliner" {>= "1.1.0"}
"re" {>= "1.9.0"}
"zarith" {>= "1.12"}
"zarith_stubs_js" {>= "v0.14.1"}

View File

@ -107,7 +107,7 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ Pos.
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_typ_with_parens t1 format_operator ""
format_typ t2
| TArray t1 -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_base_type "array" format_typ t1
| TAny -> Format.fprintf fmt "any"
| TAny -> format_base_type fmt "any"
(* (EmileRolley) NOTE: seems to be factorizable with Lcalc.Print.format_lit. *)
let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
@ -229,13 +229,13 @@ let rec format_expr ?(debug : bool = false) (ctx : Ast.decl_ctx) (fmt : Format.f
(fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n))
format_expr e
| EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ @[<hov 2>%a@]@]" format_keyword "match" format_expr e
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword "match" format_expr e
format_keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n| ")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, c) ->
Format.fprintf fmt "@[<hov 2>%a%a@ %a@]" format_enum_constructor c format_punctuation
":" format_expr e))
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" format_punctuation "|"
format_enum_constructor c format_punctuation ":" format_expr e))
(List.combine es (List.map fst (Ast.EnumMap.find e_name ctx.ctx_enums)))
| ELit l -> format_lit fmt (Pos.same_pos_as l e)
| EApp ((EAbs ((binder, _), taus), _), args) ->

View File

@ -32,10 +32,14 @@ val format_punctuation : Format.formatter -> string -> unit
val format_operator : Format.formatter -> string -> unit
val format_lit_style : Format.formatter -> string -> unit
(** {1 Formatters} *)
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
val format_enum_constructor : Format.formatter -> Ast.EnumConstructor.t -> unit
val format_tlit : Format.formatter -> Ast.typ_lit -> unit
val format_typ : Ast.decl_ctx -> Format.formatter -> Ast.typ Pos.marked -> unit

View File

@ -72,6 +72,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
else if backend = "python" then Cli.Python
else if backend = "proof" then Cli.Proof
else if backend = "typecheck" then Cli.Typecheck
else if backend = "lcalc" then Cli.Lcalc
else if backend = "scalc" then Cli.Scalc
else
Errors.raise_error
(Printf.sprintf "The selected backend (%s) is not supported by Catala" backend)
@ -251,7 +253,7 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
result))
results;
0
| Cli.OCaml | Cli.Python ->
| Cli.OCaml | Cli.Python | Cli.Lcalc | Cli.Scalc ->
Cli.debug_print "Compiling program into lambda calculus...";
let prgm = Lcalc.Compile_with_exceptions.translate_program prgm in
let prgm =
@ -261,29 +263,79 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
end
else prgm
in
if backend = Cli.Lcalc then begin
let fmt, at_end =
match output_file with
| Some f ->
let oc = open_out f in
(Format.formatter_of_out_channel oc, fun _ -> close_out oc)
| None -> (Format.std_formatter, fun _ -> ())
in
if Option.is_some ex_scope then
Format.fprintf fmt "%a\n"
(Lcalc.Print.format_scope ~debug prgm.decl_ctx)
(let body =
List.find (fun body -> body.Lcalc.Ast.scope_body_name = scope_uid) prgm.scopes
in
body)
else
Format.fprintf fmt "%a\n"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(fun fmt scope -> (Lcalc.Print.format_scope prgm.decl_ctx) fmt scope))
prgm.scopes;
at_end ();
exit 0
end;
let source_file =
match source_file with
| FileName f -> f
| Contents _ ->
Errors.raise_error "This backend does not work if the input is not a file"
in
let output_file (extension : string) : string =
let new_output_file (extension : string) : string =
match output_file with
| Some f -> f
| None -> Filename.remove_extension source_file ^ extension
in
(match backend with
| Cli.OCaml ->
let output_file = output_file ".ml" in
let output_file = new_output_file ".ml" in
Cli.debug_print (Printf.sprintf "Writing to %s..." output_file);
let oc = open_out output_file in
let fmt = Format.formatter_of_out_channel oc in
Cli.debug_print "Compiling program into OCaml...";
Lcalc.To_ocaml.format_program fmt prgm type_ordering;
close_out oc
| Cli.Python ->
| Cli.Python | Cli.Scalc ->
let prgm = Scalc.Compile_from_lambda.translate_program prgm in
let output_file = output_file ".py" in
if backend = Cli.Scalc then begin
let fmt, at_end =
match output_file with
| Some f ->
let oc = open_out f in
(Format.formatter_of_out_channel oc, fun _ -> close_out oc)
| None -> (Format.std_formatter, fun _ -> ())
in
if Option.is_some ex_scope then
Format.fprintf fmt "%a\n"
(Scalc.Print.format_scope ~debug prgm.decl_ctx)
(let body =
List.find
(fun body -> body.Scalc.Ast.scope_body_name = scope_uid)
prgm.scopes
in
body)
else
Format.fprintf fmt "%a\n"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(fun fmt scope -> (Scalc.Print.format_scope prgm.decl_ctx) fmt scope))
prgm.scopes;
at_end ();
exit 0
end;
let output_file = new_output_file ".py" in
Cli.debug_print "Compiling program into Python...";
Cli.debug_print (Printf.sprintf "Writing to %s..." output_file);
let oc = open_out output_file in

View File

@ -88,4 +88,10 @@ let handle_default = Var.make ("handle_default", Pos.no_pos)
type binder = (expr, expr Pos.marked) Bindlib.binder
type program = { decl_ctx : D.decl_ctx; scopes : (Var.t * expr Pos.marked) list }
type scope_body = {
scope_body_name : Dcalc.Ast.ScopeName.t;
scope_body_var : Var.t;
scope_body_expr : expr Pos.marked;
}
type program = { decl_ctx : D.decl_ctx; scopes : scope_body list }

View File

@ -94,4 +94,10 @@ val handle_default : Var.t
type binder = (expr, expr Pos.marked) Bindlib.binder
type program = { decl_ctx : Dcalc.Ast.decl_ctx; scopes : (Var.t * expr Pos.marked) list }
type scope_body = {
scope_body_name : Dcalc.Ast.ScopeName.t;
scope_body_var : Var.t;
scope_body_expr : expr Pos.marked;
}
type program = { decl_ctx : Dcalc.Ast.decl_ctx; scopes : scope_body list }

View File

@ -138,13 +138,17 @@ let translate_program (prgm : D.program) : A.program =
(fun ((acc, ctx) : _ * A.Var.t D.VarMap.t) (scope_name, n, e) ->
let new_n = A.Var.make (Bindlib.name_of n, Pos.no_pos) in
let new_acc =
( new_n,
Bindlib.unbox
(translate_expr
(D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx)
(Bindlib.unbox
(D.build_whole_scope_expr prgm.decl_ctx e
(Pos.get_position (Dcalc.Ast.ScopeName.get_info scope_name))))) )
{
Ast.scope_body_name = scope_name;
scope_body_var = new_n;
scope_body_expr =
Bindlib.unbox
(translate_expr
(D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx)
(Bindlib.unbox
(D.build_whole_scope_expr prgm.decl_ctx e
(Pos.get_position (Dcalc.Ast.ScopeName.get_info scope_name)))));
}
:: acc
in
let new_ctx = D.VarMap.add n new_n ctx in

View File

@ -7,7 +7,7 @@ default term, which has been eliminated through diverse compilation schemes.
The module describing the abstract syntax tree is:
{!modules: Lcalc.Ast}
{!modules: Lcalc.Ast Lcalc.Print}
This intermediate representation corresponds to the lambda calculus
presented in the {{: https://arxiv.org/abs/2103.03198} Catala formalization}.

View File

@ -67,6 +67,16 @@ let rec peephole_expr (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
| ERaise _ | ELit _ | EOp _ -> Bindlib.box e
let peephole_optimizations (p : program) : program =
{ p with scopes = List.map (fun (var, e) -> (var, Bindlib.unbox (peephole_expr e))) p.scopes }
{
p with
scopes =
List.map
(fun scope_body ->
{
scope_body with
scope_body_expr = Bindlib.unbox (peephole_expr scope_body.scope_body_expr);
})
p.scopes;
}
let optimize_program (p : program) : program = peephole_optimizations p

View File

@ -30,33 +30,23 @@ let begins_with_uppercase (s : string) : bool =
(** @note: (EmileRolley) seems to be factorizable with Dcalc.Print.format_lit. *)
let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
match Pos.unmark l with
| LBool b -> Format.fprintf fmt "%b" b
| LInt i -> Format.fprintf fmt "%s" (Runtime.integer_to_string i)
| LUnit -> Format.fprintf fmt "()"
| LBool b -> Dcalc.Print.format_lit_style fmt (string_of_bool b)
| LInt i -> Dcalc.Print.format_lit_style fmt (Runtime.integer_to_string i)
| LUnit -> Dcalc.Print.format_lit_style fmt "()"
| LRat i ->
Format.fprintf fmt "%s"
Dcalc.Print.format_lit_style fmt
(Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i)
| LMoney e -> (
match !Utils.Cli.locale_lang with
| En -> Format.fprintf fmt "$%s" (Runtime.money_to_string e)
| Fr -> Format.fprintf fmt "%s €" (Runtime.money_to_string e)
| Pl -> Format.fprintf fmt "%s PLN" (Runtime.money_to_string e))
| LDate d -> Format.fprintf fmt "%s" (Runtime.date_to_string d)
| LDuration d -> Format.fprintf fmt "%s" (Runtime.duration_to_string d)
let format_uid_list (fmt : Format.formatter) (infos : Uid.MarkedString.info list) : unit =
Format.fprintf fmt "%a"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ".")
(fun fmt info ->
Format.fprintf fmt "%a"
(Utils.Cli.format_with_style
(if begins_with_uppercase (Pos.unmark info) then [ ANSITerminal.red ] else []))
(Format.asprintf "%a" Utils.Uid.MarkedString.format_info info)))
infos
| En -> Dcalc.Print.format_lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e))
| Fr -> Dcalc.Print.format_lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e))
| Pl ->
Dcalc.Print.format_lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e)))
| LDate d -> Dcalc.Print.format_lit_style fmt (Runtime.date_to_string d)
| LDuration d -> Dcalc.Print.format_lit_style fmt (Runtime.duration_to_string d)
let format_exception (fmt : Format.formatter) (exn : except) : unit =
Format.fprintf fmt
Dcalc.Print.format_operator fmt
(match exn with
| EmptyError -> "EmptyError"
| ConflictError -> "ConflictError"
@ -75,9 +65,9 @@ let needs_parens (e : expr Pos.marked) : bool =
let format_var (fmt : Format.formatter) (v : Var.t) : unit =
Format.fprintf fmt "%s" (Bindlib.name_of v)
let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : unit
=
let format_expr = format_expr ctx in
let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter)
(e : expr Pos.marked) : unit =
let format_expr = format_expr ctx ~debug in
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
if needs_parens e then
Format.fprintf fmt "%a%a%a" format_punctuation "(" format_expr e format_punctuation ")"
@ -92,8 +82,8 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es format_punctuation ")"
| ETuple (es, Some s) ->
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]" Dcalc.Ast.StructName.format_t s
format_punctuation "{"
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s format_punctuation
"{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) ->
@ -117,17 +107,17 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
(fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n))
format_punctuation "\"")
| EInj (e, n, en, _ts) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Ast.EnumConstructor.format_t
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor
(fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n))
format_expr e
| EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ @[<hov 2>%a@]@]" format_keyword "match" format_expr e
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword "match" format_expr e
format_keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n| ")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, c) ->
Format.fprintf fmt "@[<hov 2>%a%a@ %a@]" Dcalc.Ast.EnumConstructor.format_t c
format_punctuation ":" format_expr e))
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" format_punctuation "|"
Dcalc.Print.format_enum_constructor c format_punctuation ":" format_expr e))
(List.combine es (List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums)))
| ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e)
| EApp ((EAbs ((binder, _), taus), _), args) ->
@ -138,14 +128,14 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "")
(fun fmt (x, tau, arg) ->
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a@ %a@ %a@]@ %a@ %a@]@ %a@\n" format_keyword
"let" format_var x format_punctuation ":" (Dcalc.Print.format_typ ctx) tau
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" format_keyword "let"
format_var x format_punctuation ":" (Dcalc.Print.format_typ ctx) tau
format_punctuation "=" format_expr arg format_keyword "in"))
xs_tau_arg format_expr body
| EAbs ((binder, _), taus) ->
let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" format_punctuation "λ"
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" format_punctuation "λ"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (x, tau) ->
@ -158,7 +148,7 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop
(op, Pos.no_pos) format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not !Cli.debug_flag ->
| EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos)
@ -174,9 +164,14 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
| EOp (Binop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos)
| ECatch (e1, exn, e2) ->
Format.fprintf fmt "@[<hov 2>try@ %a@ with@ %a ->@ %a@]" format_with_parens e1
format_exception exn format_with_parens e2
| ERaise exn -> Format.fprintf fmt "@[<hov 2>raise@ %a@]" format_exception exn
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" format_keyword "try" format_with_parens
e1 format_keyword "with" format_exception exn format_with_parens e2
| ERaise exn -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_keyword "raise" format_exception exn
| EAssert e' ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" format_keyword "assert" format_punctuation "("
format_expr e' format_punctuation ")"
let format_scope (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter)
(body : scope_body) : unit =
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" format_keyword "let" format_var body.scope_body_var
format_punctuation "=" (format_expr decl_ctx ~debug) body.scope_body_expr

View File

@ -22,10 +22,13 @@ val begins_with_uppercase : string -> bool
(** {1 Formatters} *)
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
val format_lit : Format.formatter -> Ast.lit Pos.marked -> unit
val format_var : Format.formatter -> Ast.Var.t -> unit
val format_expr : Dcalc.Ast.decl_ctx -> Format.formatter -> Ast.expr Pos.marked -> unit
val format_exception : Format.formatter -> Ast.except -> unit
val format_expr :
Dcalc.Ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.expr Pos.marked -> unit
val format_scope : Dcalc.Ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.scope_body -> unit

View File

@ -433,6 +433,7 @@ let format_program (fmt : Format.formatter) (p : Ast.program)
(format_ctx type_ordering) p.decl_ctx
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
(fun fmt (name, e) ->
Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var name (format_expr p.decl_ctx) e))
(fun fmt body ->
Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var body.scope_body_var
(format_expr p.decl_ctx) body.scope_body_expr))
p.scopes

View File

@ -51,4 +51,10 @@ and block = stmt Pos.marked list
and func = { func_params : (LocalName.t Pos.marked * D.typ Pos.marked) list; func_body : block }
type program = { decl_ctx : D.decl_ctx; scopes : (TopLevelName.t * func) list }
type scope_body = {
scope_body_name : Dcalc.Ast.ScopeName.t;
scope_body_var : TopLevelName.t;
scope_body_func : func;
}
type program = { decl_ctx : D.decl_ctx; scopes : scope_body list }

View File

@ -244,14 +244,21 @@ let translate_program (p : L.program) : A.program =
scopes =
(let _, new_scopes =
List.fold_left
(fun (func_dict, new_scopes) (scope_name, scope_expr) ->
(fun (func_dict, new_scopes) body ->
let new_scope_params, new_scope_body =
translate_scope p.decl_ctx func_dict scope_expr
translate_scope p.decl_ctx func_dict body.Lcalc.Ast.scope_body_expr
in
let func_id = A.TopLevelName.fresh (Bindlib.name_of scope_name, Pos.no_pos) in
let func_dict = L.VarMap.add scope_name func_id func_dict in
let func_id =
A.TopLevelName.fresh (Bindlib.name_of body.Lcalc.Ast.scope_body_var, Pos.no_pos)
in
let func_dict = L.VarMap.add body.Lcalc.Ast.scope_body_var func_id func_dict in
( func_dict,
(func_id, { A.func_params = new_scope_params; A.func_body = new_scope_body })
{
Ast.scope_body_name = body.Lcalc.Ast.scope_body_name;
Ast.scope_body_var = func_id;
scope_body_func =
{ A.func_params = new_scope_params; A.func_body = new_scope_body };
}
:: new_scopes ))
( L.VarMap.singleton L.handle_default
(A.TopLevelName.fresh ("handle_default", Pos.no_pos)),

162
compiler/scalc/print.ml Normal file
View File

@ -0,0 +1,162 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits
computation rules. Copyright (C) 2022 Inria, contributor: Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Ast
let needs_parens (_e : expr Pos.marked) : bool = false
let format_local_name (fmt : Format.formatter) (v : LocalName.t) : unit =
Format.fprintf fmt "%a_%s" LocalName.format_t v (string_of_int (LocalName.hash v))
let rec format_expr (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter)
(e : expr Pos.marked) : unit =
let format_expr = format_expr decl_ctx ~debug in
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
if needs_parens e then
Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "(" format_expr e
Dcalc.Print.format_punctuation ")"
else Format.fprintf fmt "%a" format_expr e
in
match Pos.unmark e with
| EVar v -> Format.fprintf fmt "%a" format_local_name v
| EFunc v -> Format.fprintf fmt "%a" TopLevelName.format_t v
| EStruct (es, s) ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s
Dcalc.Print.format_punctuation "{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation "\""
Dcalc.Ast.StructFieldName.format_t struct_field Dcalc.Print.format_punctuation "\""
Dcalc.Print.format_punctuation ":" format_expr e))
(List.combine es (List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs)))
Dcalc.Print.format_punctuation "}"
| EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Dcalc.Print.format_punctuation "["
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es Dcalc.Print.format_punctuation "]"
| EStructFieldAccess (e1, field, s) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Dcalc.Print.format_punctuation "."
Dcalc.Print.format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t
(fst
(List.find
(fun (field', _) -> Dcalc.Ast.StructFieldName.compare field' field = 0)
(Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs)))
Dcalc.Print.format_punctuation "\""
| EInj (e, case, enum) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor
(fst
(List.find
(fun (case', _) -> Dcalc.Ast.EnumConstructor.compare case' case = 0)
(Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums)))
format_expr e
| ELit l -> Format.fprintf fmt "%a" Lcalc.Print.format_lit (Pos.same_pos_as l e)
| EApp ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop (op, Pos.no_pos)
format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop
(op, Pos.no_pos) format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens)
args
| EOp (Ternop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos)
let rec format_statement (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false)
(fmt : Format.formatter) (stmt : stmt Pos.marked) : unit =
if debug then () else ();
match Pos.unmark stmt with
| SInnerFuncDef (name, func) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Dcalc.Print.format_keyword
"let" LocalName.format_t (Pos.unmark name)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation "("
LocalName.format_t name Dcalc.Print.format_punctuation ":"
(Dcalc.Print.format_typ decl_ctx) typ Dcalc.Print.format_punctuation ")"))
func.func_params Dcalc.Print.format_punctuation "=" (format_block decl_ctx ~debug)
func.func_body
| SLocalDecl (name, typ) ->
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Dcalc.Print.format_keyword "decl"
LocalName.format_t (Pos.unmark name) Dcalc.Print.format_punctuation ":"
(Dcalc.Print.format_typ decl_ctx) typ
| SLocalDef (name, expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" LocalName.format_t (Pos.unmark name)
Dcalc.Print.format_punctuation "=" (format_expr decl_ctx ~debug) expr
| STryExcept (b_try, except, b_with) ->
Format.fprintf fmt "@[<v 2>%a%a@ %a@]@\n@[<v 2>%a %a%a@ %a@]" Dcalc.Print.format_keyword "try"
Dcalc.Print.format_punctuation ":" (format_block decl_ctx ~debug) b_try
Dcalc.Print.format_keyword "with" Lcalc.Print.format_exception except
Dcalc.Print.format_punctuation ":" (format_block decl_ctx ~debug) b_with
| SRaise except ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "raise"
Lcalc.Print.format_exception except
| SIfThenElse (e_if, b_true, b_false) ->
Format.fprintf fmt "@[<v 2>%a @[<hov 2>%a@]%a@ %a@ @]@[<v 2>%a%a@ %a@]"
Dcalc.Print.format_keyword "if" (format_expr decl_ctx ~debug) e_if
Dcalc.Print.format_punctuation ":" (format_block decl_ctx ~debug) b_true
Dcalc.Print.format_keyword "else" Dcalc.Print.format_punctuation ":"
(format_block decl_ctx ~debug) b_false
| SReturn ret ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "return"
(format_expr decl_ctx ~debug)
(ret, Pos.get_position stmt)
| SAssert expr ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "assert"
(format_expr decl_ctx ~debug)
(expr, Pos.get_position stmt)
| SSwitch (e_switch, enum, arms) ->
Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@]%a" Dcalc.Print.format_keyword "switch"
(format_expr decl_ctx ~debug) e_switch Dcalc.Print.format_punctuation ":"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt ((case, _), (arm_block, payload_name)) ->
Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Dcalc.Print.format_punctuation "|"
Dcalc.Print.format_enum_constructor case Dcalc.Print.format_punctuation ":"
LocalName.format_t payload_name Dcalc.Print.format_punctuation ""
(format_block decl_ctx ~debug) arm_block))
(List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms)
and format_block (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter)
(block : block) : unit =
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";")
(format_statement decl_ctx ~debug)
fmt block
let format_scope (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter)
(body : scope_body) : unit =
if debug then () else ();
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Dcalc.Print.format_keyword "let"
TopLevelName.format_t body.scope_body_var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation "(" LocalName.format_t
name Dcalc.Print.format_punctuation ":" (Dcalc.Print.format_typ decl_ctx) typ
Dcalc.Print.format_punctuation ")"))
body.scope_body_func.func_params Dcalc.Print.format_punctuation "="
(format_block decl_ctx ~debug) body.scope_body_func.func_body

15
compiler/scalc/print.mli Normal file
View File

@ -0,0 +1,15 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits
computation rules. Copyright (C) 2022 Inria, contributor: Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License. *)
val format_scope : Dcalc.Ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.scope_body -> unit

View File

@ -8,7 +8,7 @@ rules in the language, every local variable has a unique id.
The module describing the abstract syntax tree is:
{!modules: Scalc.Ast}
{!modules: Scalc.Ast Scalc.Print}
{1 Compilation from lambda calculus }

View File

@ -420,8 +420,9 @@ let format_program (fmt : Format.formatter) (p : Ast.program)
(format_ctx type_ordering) p.decl_ctx
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
(fun fmt (name, { Ast.func_params; Ast.func_body }) ->
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_toplevel_name name
(fun fmt body ->
let { Ast.func_params; Ast.func_body } = body.scope_body_func in
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_toplevel_name body.scope_body_var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (var, typ) ->

View File

@ -77,13 +77,14 @@ let rec format_expr (fmt : Format.formatter) (e : expr Pos.marked) : unit =
| EEnumInj (e1, cons, _) ->
Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr e1
| EMatch (e1, _, cases) ->
Format.fprintf fmt "@[<hov 2>@[%a@ %a@ %a@]@ %a@]" Dcalc.Print.format_keyword "match"
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" Dcalc.Print.format_keyword "match"
format_expr e1 Dcalc.Print.format_keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " Dcalc.Print.format_punctuation "|")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (cons_name, case_expr) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Ast.EnumConstructor.format_t cons_name
Dcalc.Print.format_punctuation "" format_expr case_expr))
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" Dcalc.Print.format_punctuation "|"
Dcalc.Print.format_enum_constructor cons_name Dcalc.Print.format_punctuation ""
format_expr case_expr))
(Ast.EnumConstructorMap.bindings cases)
| EApp ((EAbs ((binder, _), taus), _), args) ->
let xs, body = Bindlib.unmbind binder in
@ -104,7 +105,9 @@ let rec format_expr (fmt : Format.formatter) (e : expr Pos.marked) : unit =
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" Dcalc.Print.format_punctuation "λ"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
(fun fmt (x, tau) -> Format.fprintf fmt "@[(%a:@ %a)@]" format_var x format_typ tau))
(fun fmt (x, tau) ->
Format.fprintf fmt "@[%a%a%a@ %a%a@]" Dcalc.Print.format_punctuation "(" format_var x
Dcalc.Print.format_punctuation ":" format_typ tau Dcalc.Print.format_punctuation ")"))
xs_tau Dcalc.Print.format_punctuation "" format_expr body
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop
@ -163,32 +166,27 @@ let format_enum (fmt : Format.formatter)
cases
let format_scope (fmt : Format.formatter) ((name, decl) : ScopeName.t * scope_decl) : unit =
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@ %a@]@\n@[<hov 2> %a@]" Dcalc.Print.format_keyword
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Dcalc.Print.format_keyword
"let" Dcalc.Print.format_keyword "scope" ScopeName.format_t name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (scope_var, (typ, vis)) ->
Format.fprintf fmt "%a%a%a %a%s%s%a" Dcalc.Print.format_punctuation "(" ScopeVar.format_t
Format.fprintf fmt "%a%a%a %a%a%a%a%a" Dcalc.Print.format_punctuation "(" ScopeVar.format_t
scope_var Dcalc.Print.format_punctuation ":" format_typ typ
Dcalc.Print.format_punctuation "|" Dcalc.Print.format_keyword
(match Pos.unmark vis.io_input with
| NoInput ->
Format.asprintf "%a%a" Dcalc.Print.format_punctuation "|" Dcalc.Print.format_keyword
"internal"
| OnlyInput ->
Format.asprintf "%a%a" Dcalc.Print.format_punctuation "|" Dcalc.Print.format_keyword
"input"
| Reentrant ->
Format.asprintf "%a%a" Dcalc.Print.format_punctuation "|" Dcalc.Print.format_keyword
"context")
(if Pos.unmark vis.io_output then
Format.asprintf "%a%a" Dcalc.Print.format_punctuation "|" Dcalc.Print.format_keyword
"output"
else "")
Dcalc.Print.format_punctuation ")"))
| NoInput -> "internal"
| OnlyInput -> "input"
| Reentrant -> "context")
(if Pos.unmark vis.io_output then fun fmt () ->
Format.fprintf fmt "%a@,%a" Dcalc.Print.format_punctuation "|"
Dcalc.Print.format_keyword "output"
else fun fmt () -> Format.fprintf fmt "@<0>")
() Dcalc.Print.format_punctuation ")"))
(ScopeVarMap.bindings decl.scope_sig)
Dcalc.Print.format_punctuation "="
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@\n" Dcalc.Print.format_punctuation ";")
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";")
(fun fmt rule ->
match rule with
| Definition (loc, typ, _, e) ->
@ -216,12 +214,17 @@ let format_scope (fmt : Format.formatter) ((name, decl) : ScopeName.t * scope_de
decl.scope_decl_rules
let format_program (fmt : Format.formatter) (p : program) : unit =
Format.fprintf fmt "%a%s%a%s%a"
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") format_struct)
Format.fprintf fmt "%a%a%a%a%a"
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") format_struct)
(StructMap.bindings p.program_structs)
(if StructMap.is_empty p.program_structs then "" else "\n\n")
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") format_enum)
(fun fmt () ->
if StructMap.is_empty p.program_structs then Format.fprintf fmt ""
else Format.fprintf fmt "\n\n")
()
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") format_enum)
(EnumMap.bindings p.program_enums)
(if EnumMap.is_empty p.program_enums then "" else "\n\n")
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") format_scope)
(fun fmt () ->
if EnumMap.is_empty p.program_enums then Format.fprintf fmt "" else Format.fprintf fmt "\n\n")
()
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") format_scope)
(ScopeMap.bindings p.program_scopes)

View File

@ -81,6 +81,8 @@ type backend_option =
| Typecheck
| OCaml
| Python
| Scalc
| Lcalc
| Dcalc
| Scopelang
| Proof
@ -162,8 +164,17 @@ let info =
Catala program. Use the $(b,-s) option to restrict the output to a particular scope." );
`I
( "$(b,Dcalc)",
"Prints a debugging verbatim of the scope language intermediate representation of the \
"Prints a debugging verbatim of the default calculus intermediate representation of the \
Catala program. Use the $(b,-s) option to restrict the output to a particular scope." );
`I
( "$(b,Lcalc)",
"Prints a debugging verbatim of the lambda calculus intermediate representation of the \
Catala program. Use the $(b,-s) option to restrict the output to a particular scope." );
`I
( "$(b,Scalc)",
"Prints a debugging verbatim of the statement calculus intermediate representation of \
the Catala program. Use the $(b,-s) option to restrict the output to a particular \
scope." );
`S Manpage.s_authors;
`P "The authors are listed by alphabetical order.";
`P "Nicolas Chataing <nicolas.chataing@ens.fr>";

View File

@ -60,6 +60,8 @@ type backend_option =
| Typecheck
| OCaml
| Python
| Scalc
| Lcalc
| Dcalc
| Scopelang
| Proof

View File

@ -88,7 +88,8 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : expr Pos.marked) :
(Print.format_expr ~debug:true ctx.decl)
e)
(Pos.get_position e))
| EApp ((EOp (Unop (Log _)), _), [ ((ErrorOnEmpty (EDefault (_, _, _), _), _) as d) ]) ->
| ErrorOnEmpty (EApp ((EOp (Unop (Log _)), _), [ d ]), _)
| EApp ((EOp (Unop (Log _)), _), [ (ErrorOnEmpty d, _) ]) ->
d (* input subscope variables and non-input scope variable *)
| _ ->
Errors.raise_spanned_error

View File

@ -4,4 +4,5 @@ with pkgs;
ocamlPackages.callPackage ./. {
bindlib = ocamlPackages.callPackage ./.nix/bindlib.nix { };
unionfind = ocamlPackages.callPackage ./.nix/unionfind.nix { };
cmdliner = ocamlPackages.callPackage ./.nix/cmdliner.nix { };
}

View File

@ -5,6 +5,7 @@ let
pkg = ocamlPackages.callPackage ./. {
bindlib = ocamlPackages.callPackage ./.nix/bindlib.nix { };
unionfind = ocamlPackages.callPackage ./.nix/unionfind.nix { };
cmdliner = ocamlPackages.callPackage ./.nix/cmdliner.nix { };
};
in mkShell {
inputsFrom = [ pkg ];