Merge branch 'master' into afromher_334

This commit is contained in:
Denis Merigoux 2023-01-20 14:05:38 -05:00
commit 7cffc53169
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
331 changed files with 33415 additions and 30203 deletions

View File

@ -1,28 +0,0 @@
{ lib, fetchFromGitHub, buildDunePackage }:
# We need the very last version "bleeding edge" since previous versions don't use dune.
buildDunePackage rec {
pname = "bindlib";
version = "5.0.1a";
minimumOCamlVersion = "4.0.8";
useDune2 = true;
src = fetchFromGitHub {
owner = "rlepigre";
repo = "ocaml-${pname}";
rev = "317f195d22c75f556053039cd94b52bd0c423709";
name = pname;
hash = "sha256-uO/Ko9PmQ+wE0d9jfEngd4G014B4nxGgfQyEvB52Pz8=";
};
meta = with lib; {
homepage = "https://rlepigre.github.io/ocaml-bindlib/";
description =
"Bindlib is a library allowing the manipulation of data structures with bound variables";
license = licenses.lgpl3;
maintainers = [ ];
};
}

View File

@ -5,7 +5,7 @@
, bindlib , bindlib
, buildDunePackage , buildDunePackage
, calendar , calendar
, cmdliner_1_1_0 , cmdliner
, cppo , cppo
, dates_calc , dates_calc
, fetchFromGitHub , fetchFromGitHub
@ -42,7 +42,7 @@ buildDunePackage rec {
ansiterminal ansiterminal
benchmark benchmark
bindlib bindlib
cmdliner_1_1_0 cmdliner
cppo cppo
dates_calc dates_calc
js_of_ocaml js_of_ocaml

View File

@ -1,32 +1,13 @@
{ ocamlPackages, fetchurl }: { ocamlPackages, fetchurl }:
ocamlPackages.overrideScope' (self: super: { ocamlPackages.overrideScope' (self: super: {
cmdliner_1_1_0 = super.cmdliner.overrideAttrs (o: rec { alcotest = (super.alcotest.override {}).overrideAttrs (_: {
version = "1.1.0";
src = fetchurl {
url = "https://erratique.ch/software/${o.pname}/releases/${o.pname }-${version}.tbz";
sha256 = "sha256-irWd4HTlJSYuz3HMgi1de2GVL2qus0QjeCe1WdsSs8Q=";
};
});
alcotest = (super.alcotest.override {
cmdliner = self.cmdliner_1_1_0;
}).overrideAttrs (_: {
doCheck = false; doCheck = false;
}); });
# Use a more recent version of `re` than the one packaged in nixpkgs
re = super.re.overrideAttrs (o: rec {
version = "1.10.4";
src = fetchurl {
url = "https://github.com/ocaml/ocaml-${o.pname}/releases/download/${version}/${o.pname}-${version}.tbz";
sha256 = "sha256-g+s+QwCqmx3HggdJAQ9DYuqDUkdCEwUk14wgzpnKdHw=";
};
});
catala = self.callPackage ./catala.nix { }; catala = self.callPackage ./catala.nix { };
bindlib = self.callPackage ./bindlib.nix { };
unionfind = self.callPackage ./unionfind.nix { }; unionfind = self.callPackage ./unionfind.nix { };
ninja_utils = self.callPackage ./ninja_utils.nix { }; ninja_utils = self.callPackage ./ninja_utils.nix { };
clerk = self.callPackage ./clerk.nix { }; clerk = self.callPackage ./clerk.nix { };
ppx_yojson_conv = self.callPackage ./ppx_yojson_conv.nix { };
ubase = self.callPackage ./ubase.nix { }; ubase = self.callPackage ./ubase.nix { };
dates_calc = self.callPackage ./dates_calc.nix { }; dates_calc = self.callPackage ./dates_calc.nix { };
}) })

View File

@ -1,20 +0,0 @@
{ lib, fetchurl, buildDunePackage, ppxlib, ppx_yojson_conv_lib, ppx_js_style }:
buildDunePackage rec {
pname = "ppx_yojson_conv";
version = "0.14.0";
minimumOCamlVersion = "4.0.8";
useDune2 = true;
propagatedBuildInputs = [
ppxlib ppx_yojson_conv_lib ppx_js_style
];
src = fetchurl
{
url = "https://ocaml.janestreet.com/ocaml-core/v0.14/files/ppx_yojson_conv-v0.14.0.tar.gz";
sha256 = "0ls6vzj7k0wrjliifqczs78anbc8b88as5w7a3wixfcs1gjfsp2w";
};
}

View File

@ -104,19 +104,20 @@ need more, here is how one can be added:
- Choose a name wisely. Be ready to patch any code that already used the name - Choose a name wisely. Be ready to patch any code that already used the name
for scope parameters, variables or structure fields, since it won't compile for scope parameters, variables or structure fields, since it won't compile
anymore. anymore.
- Add an element to the `builtin_expression` type in `surface/ast.ml(i)` - Add an element to the `builtin_expression` type in `surface/ast.ml`
- Add your builtin in the `builtins` list in `surface/lexer.cppo.ml`, and with - Add your builtin in the `builtins` list in `surface/lexer.cppo.ml`, and with
proper translations in all of the language-specific modules proper translations in all of the language-specific modules
`surface/lexer_en.cppo.ml`, `surface/lexer_fr.cppo.ml`, etc. Don't forget the `surface/lexer_en.cppo.ml`, `surface/lexer_fr.cppo.ml`, etc. Don't forget the
macro at the beginning of `lexer.cppo.ml`. macro at the beginning of `lexer.cppo.ml`.
- The rest can all be done by following the type errors downstream: - The rest can all be done by following the type errors downstream:
- Add a corresponding element to the lower-level AST in `dcalc/ast.ml(i)`, type `unop` - Add a corresponding element to the lower-level AST in `shared_ast/definitions.ml`, type `Op.t`
- Extend the translation accordingly in `surface/desugaring.ml` - Extend the generic operations on operators in `shared_ast/operators.ml` as well as the type information for the operator
- Extend the printer (`dcalc/print.ml`) and the typer with correct type - Extend the translation accordingly in `desugared/from_surface.ml`
information (`dcalc/typing.ml`) - Extend the printer (`shared_ast/print.ml`)
- Finally, provide the implementations: - Finally, provide the implementations:
- in `lcalc/to_ocaml.ml`, function `format_unop`
- in `dcalc/interpreter.ml`, function `evaluate_operator` - in `dcalc/interpreter.ml`, function `evaluate_operator`
- in `../runtimes/ocaml/runtime.ml`
- in `../runtimes/python/catala/src/catala/runtime.py`
- Update the syntax guide in `doc/syntax/syntax.tex` with your new builtin - Update the syntax guide in `doc/syntax/syntax.tex` with your new builtin
### Internationalization of the Catala syntax ### Internationalization of the Catala syntax

View File

@ -3,7 +3,7 @@
FROM ocamlpro/ocaml:4.14-2022-07-17 AS dev-build-context FROM ocamlpro/ocaml:4.14-2022-07-17 AS dev-build-context
# pandoc is not in alpine stable yet, install it manually with an explicit repository # pandoc is not in alpine stable yet, install it manually with an explicit repository
RUN sudo apk add pandoc --repository=http://dl-cdn.alpinelinux.org/alpine/edge/testing/ RUN sudo apk add pandoc --repository=http://dl-cdn.alpinelinux.org/alpine/edge/community/
RUN mkdir catala RUN mkdir catala
WORKDIR catala WORKDIR catala

View File

@ -22,18 +22,31 @@ Finally, start a shell inside a new container created from the newly built image
The repository provides nix files to build or develop the catala compiler. The repository provides nix files to build or develop the catala compiler.
Once [nix is installed](https://nixos.org/manual/nix/stable/#ch-installing-binary), Once [nix is installed](https://nixos.org/manual/nix/stable/#ch-installing-binary),
it is possible to enter a development shell: with flakes enabled it is possible to enter a development shell:
nix-shell nix develop
or to build the Catala compiler, documentation and runtime library: or to build the Catala compiler, documentation and runtime library:
nix-build release.nix nix build
Dependencies not yet in nixpkgs (`bindlib` and `unionFind` at the moment of writing) Dependencies not yet in nixpkgs (`ubase` and `unionFind` at the moment of writing)
are hardcoded inside the `.nix` directory. The `default.nix` should be compatible with are hardcoded inside the `.nix` directory. The `.nix/catala.nix` should be compatible with
nixpkgs, if it finds a maintainer. nixpkgs, if it finds a maintainer.
To develop catala's compiler using vscode using ocaml's [lsp](https://microsoft.github.io/language-server-protocol/), you can use the [ocaml-platform extension](https://marketplace.visualstudio.com/items?itemName=ocamllabs.ocaml-platform) with the following settings (inside the file `.vscode/settings.json`).
```json
{
"ocaml.sandbox": {
"kind": "custom",
"template": "nix develop --command $prog $args"
},
}
```
The nix build is updated weekly by an automatic github action.
### With opam ### With opam
The Catala compiler is written using OCaml. First, you have to install `opam`, The Catala compiler is written using OCaml. First, you have to install `opam`,

View File

@ -299,8 +299,9 @@ run_french_law_library_benchmark_python: $(PY_VIRTUALENV) \
CATALA_OPTS?= CATALA_OPTS?=
CLERK_OPTS?=--makeflags="$(MAKEFLAGS)" CLERK_OPTS?=--makeflags="$(MAKEFLAGS)"
CATALA_BIN=_build/default/compiler/catala.exe CATALA_BIN=_build/default/$(COMPILER_DIR)/catala.exe
CLERK_BIN=_build/default/build_system/clerk.exe CLERK_BIN=_build/default/$(BUILD_SYSTEM_DIR)/clerk.exe
CATALA_LEGIFRANCE_BIN=_build/default/$(CATALA_LEGIFRANCE_DIR)/catala_legifrance.exe
CLERK=$(CLERK_BIN) --exe $(CATALA_BIN) \ CLERK=$(CLERK_BIN) --exe $(CATALA_BIN) \
$(CLERK_OPTS) $(if $(CATALA_OPTS),--catala-opts=$(CATALA_OPTS),) $(CLERK_OPTS) $(if $(CATALA_OPTS),--catala-opts=$(CATALA_OPTS),)
@ -336,7 +337,7 @@ tests/%: .FORCE
# Website assets # Website assets
########################################## ##########################################
WEBSITE_ASSETS = grammar.html catala.html WEBSITE_ASSETS = grammar.html catala.html clerk.html catala_legifrance.html
$(addprefix _build/default/,$(WEBSITE_ASSETS)): $(addprefix _build/default/,$(WEBSITE_ASSETS)):
dune build $@ dune build $@
@ -386,6 +387,11 @@ help_clerk:
help_catala: help_catala:
$(CATALA_BIN) --help $(CATALA_BIN) --help
#> help_catala_legifrance : Display the catala_legifrance man page
help_catala_legifrance:
$(CATALA_LEGIFRANCE_BIN) --help
########################################## ##########################################
# Special targets # Special targets
########################################## ##########################################

View File

@ -16,7 +16,7 @@
the License. *) the License. *)
open Cmdliner open Cmdliner
open Utils open Catala_utils
open Ninja_utils open Ninja_utils
module Nj = Ninja_utils module Nj = Ninja_utils
@ -524,7 +524,7 @@ let collect_all_ninja_build
(tested_file : string) (tested_file : string)
(reset_test_outputs : bool) : (string * ninja) option = (reset_test_outputs : bool) : (string * ninja) option =
let expected_outputs = search_for_expected_outputs tested_file in let expected_outputs = search_for_expected_outputs tested_file in
if List.length expected_outputs = 0 then ( if expected_outputs = [] then (
Cli.debug_print "No expected outputs were found for test file %s" Cli.debug_print "No expected outputs were found for test file %s"
tested_file; tested_file;
None) None)
@ -890,10 +890,18 @@ let driver
let files_or_folders = List.sort_uniq String.compare files_or_folders let files_or_folders = List.sort_uniq String.compare files_or_folders
and catala_exe = Option.fold ~none:"catala" ~some:Fun.id catala_exe and catala_exe = Option.fold ~none:"catala" ~some:Fun.id catala_exe
and catala_opts = Option.fold ~none:"" ~some:Fun.id catala_opts and catala_opts = Option.fold ~none:"" ~some:Fun.id catala_opts
and ninja_output = and with_ninja_output k =
Option.fold match ninja_output with
~none:(Filename.temp_file "clerk_build_" ".ninja") | Some f -> k f
~some:Fun.id ninja_output | None -> (
let f = Filename.temp_file "clerk_build_" ".ninja" in
match k f with
| exception e ->
if not debug then Sys.remove f;
raise e
| r ->
Sys.remove f;
r)
in in
match String.lowercase_ascii command with match String.lowercase_ascii command with
| "test" -> ( | "test" -> (
@ -919,20 +927,22 @@ let driver
if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then
return_ok return_ok
else else
try with_ninja_output
File.with_formatter_of_file ninja_output (fun fmt -> @@ fun nin ->
Cli.debug_print "writing %s..." ninja_output; match
File.with_formatter_of_file nin (fun fmt ->
Cli.debug_print "writing %s..." nin;
Nj.format fmt Nj.format fmt
(add_root_test_build ninja ctx.all_file_names (add_root_test_build ninja ctx.all_file_names
ctx.all_test_builds)); ctx.all_test_builds))
with
| () ->
let ninja_cmd = let ninja_cmd =
"ninja -k 0 -f " ^ ninja_output ^ " " ^ ninja_flags ^ " test" "ninja -k 0 -f " ^ nin ^ " " ^ ninja_flags ^ " test"
in in
Cli.debug_print "executing '%s'..." ninja_cmd; Cli.debug_print "executing '%s'..." ninja_cmd;
let return = Sys.command ninja_cmd in Sys.command ninja_cmd
if not debug then Sys.remove ninja_output; | exception Sys_error e ->
return
with Sys_error e ->
Cli.error_print "can not write in %s" e; Cli.error_print "can not write in %s" e;
return_err) return_err)
| "run" -> ( | "run" -> (

View File

@ -9,7 +9,7 @@
(public_name clerk.driver) (public_name clerk.driver)
(libraries (libraries
catala.runtime_ocaml catala.runtime_ocaml
catala.utils catala.catala_utils
ninja_utils ninja_utils
cmdliner cmdliner
re re

View File

@ -34,6 +34,7 @@ depends: [
"ppx_yojson_conv" {>= "0.14.0"} "ppx_yojson_conv" {>= "0.14.0"}
"re" {>= "1.9.0"} "re" {>= "1.9.0"}
"sedlex" {>= "2.4"} "sedlex" {>= "2.4"}
"uutf" {>= "1.0.3"}
"ubase" {>= "0.05"} "ubase" {>= "0.05"}
"unionFind" {>= "20200320"} "unionFind" {>= "20200320"}
"visitors" {>= "20200210"} "visitors" {>= "20200210"}
@ -45,6 +46,7 @@ depends: [
"obelisk" {cataladevmode} "obelisk" {cataladevmode}
"conf-npm" {cataladevmode} "conf-npm" {cataladevmode}
"conf-python-3-dev" {cataladevmode} "conf-python-3-dev" {cataladevmode}
"cpdf" {cataladevmode}
"z3" {catalaz3mode} "z3" {catalaz3mode}
] ]
depopts: ["z3"] depopts: ["z3"]

View File

@ -7,12 +7,12 @@ In {{: desugared.html} the desugared representation} or in the
global identifiers. These identifiers use OCaml's type system to statically global identifiers. These identifiers use OCaml's type system to statically
distinguish e.g. a scope identifier from a struct identifier. distinguish e.g. a scope identifier from a struct identifier.
The {!module: Utils.Uid} module provides a generative functor whose output is The {!module: Uid} module provides a generative functor whose output is
a fresh sort of global identifiers. a fresh sort of global identifiers.
Related modules: Related modules:
{!modules: Utils.Uid} {!modules: Uid}
{1 Source code positions} {1 Source code positions}
@ -22,7 +22,7 @@ code. These annotations are critical to produce readable error messages.
Related modules: Related modules:
{!modules: Utils.Pos} {!modules: Pos}
{1 Error messages} {1 Error messages}

View File

@ -172,7 +172,7 @@ let plugins_dirs =
let default = let default =
let ( / ) = Filename.concat in let ( / ) = Filename.concat in
[ [
Sys.executable_name Filename.dirname Sys.executable_name
/ Filename.parent_dir_name / Filename.parent_dir_name
/ "lib" / "lib"
/ "catala" / "catala"

View File

@ -1,8 +1,8 @@
(library (library
(name utils) (name catala_utils)
(public_name catala.utils) (public_name catala.catala_utils)
(libraries cmdliner ubase ANSITerminal re bindlib catala.runtime_ocaml)) (libraries cmdliner ubase ANSITerminal re bindlib catala.runtime_ocaml))
(documentation (documentation
(package catala) (package catala)
(mld_files utils)) (mld_files catala_utils))

View File

@ -26,7 +26,7 @@ exception StructuredError of (string * (string option * Pos.t) list)
let print_structured_error (msg : string) (pos : (string option * Pos.t) list) : let print_structured_error (msg : string) (pos : (string option * Pos.t) list) :
string = string =
Printf.sprintf "%s%s%s" msg Printf.sprintf "%s%s%s" msg
(if List.length pos = 0 then "" else "\n\n") (if pos = [] then "" else "\n\n")
(String.concat "\n\n" (String.concat "\n\n"
(List.map (List.map
(fun (msg, pos) -> (fun (msg, pos) ->

View File

@ -79,11 +79,11 @@ let to_string (pos : t) : string =
let to_string_short (pos : t) : string = let to_string_short (pos : t) : string =
let s, e = pos.code_pos in let s, e = pos.code_pos in
if e.Lexing.pos_lnum = s.Lexing.pos_lnum then if e.Lexing.pos_lnum = s.Lexing.pos_lnum then
Printf.sprintf "%s:%d.%d-%d" s.Lexing.pos_fname s.Lexing.pos_lnum Printf.sprintf "%s:%d.%d-%d:" s.Lexing.pos_fname s.Lexing.pos_lnum
(s.Lexing.pos_cnum - s.Lexing.pos_bol) (s.Lexing.pos_cnum - s.Lexing.pos_bol)
(e.Lexing.pos_cnum - e.Lexing.pos_bol) (e.Lexing.pos_cnum - e.Lexing.pos_bol)
else else
Printf.sprintf "%s:%d.%d-%d.%d" s.Lexing.pos_fname s.Lexing.pos_lnum Printf.sprintf "%s:%d.%d-%d.%d:" s.Lexing.pos_fname s.Lexing.pos_lnum
(s.Lexing.pos_cnum - s.Lexing.pos_bol) (s.Lexing.pos_cnum - s.Lexing.pos_bol)
e.Lexing.pos_lnum e.Lexing.pos_lnum
(e.Lexing.pos_cnum - e.Lexing.pos_bol) (e.Lexing.pos_cnum - e.Lexing.pos_bol)
@ -102,6 +102,27 @@ let string_repeat n s =
done; done;
Bytes.to_string buf Bytes.to_string buf
(* Note: this should do, but remains incorrect for combined unicode characters
that display as one (e.g. `e` + postfix `'`). We should switch to Uuseg at
some poing *)
let string_columns s =
let len = String.length s in
let rec aux ncols i =
if i >= len then ncols
else if s.[i] = '\t' then aux (ncols + 8) (i + 1)
else
aux (ncols + 1) (i + Uchar.utf_decode_length (String.get_utf_8_uchar s i))
in
aux 0 0
let utf8_byte_index s ui0 =
let rec aux bi ui =
if ui >= ui0 then bi
else
aux (bi + Uchar.utf_decode_length (String.get_utf_8_uchar s bi)) (ui + 1)
in
aux 0 0
let retrieve_loc_text (pos : t) : string = let retrieve_loc_text (pos : t) : string =
try try
let filename = get_file pos in let filename = get_file pos in
@ -132,34 +153,32 @@ let retrieve_loc_text (pos : t) : string =
let print_matched_line (line : string) (line_no : int) : string = let print_matched_line (line : string) (line_no : int) : string =
let line_indent = indent_number line in let line_indent = indent_number line in
let error_indicator_style = [ANSITerminal.red; ANSITerminal.Bold] in let error_indicator_style = [ANSITerminal.red; ANSITerminal.Bold] in
line let match_start_index =
^ utf8_byte_index line
if line_no >= sline && line_no <= eline then (if line_no = sline then get_start_column pos - 1 else line_indent)
"\n" in
^ let match_end_index =
if line_no = sline && line_no = eline then if line_no = eline then utf8_byte_index line (get_end_column pos - 1)
Cli.with_style error_indicator_style "%*s%s" else String.length line
(get_start_column pos - 1) in
"" let unmatched_prefix = String.sub line 0 match_start_index in
(string_repeat let matched_substring =
(max (get_end_column pos - get_start_column pos) 0) String.sub line match_start_index
"") (max 0 (match_end_index - match_start_index))
else if line_no = sline && line_no <> eline then in
Cli.with_style error_indicator_style "%*s%s" let match_start_col = string_columns unmatched_prefix in
(get_start_column pos - 1) let match_num_cols = string_columns matched_substring in
"" String.concat ""
(string_repeat (line
(max (String.length line - get_start_column pos) 0) :: "\n"
"") ::
else if line_no <> sline && line_no <> eline then (if line_no >= sline && line_no <= eline then
Cli.with_style error_indicator_style "%*s%s" line_indent "" [
(string_repeat (max (String.length line - line_indent) 0) "") string_repeat match_start_col " ";
else if line_no <> sline && line_no = eline then Cli.with_style error_indicator_style "%s"
Cli.with_style error_indicator_style "%*s%*s" line_indent "" (string_repeat match_num_cols "");
(get_end_column pos - 1 - line_indent) ]
(string_repeat (max (get_end_column pos - line_indent) 0) "") else []))
else assert false (* should not happen *)
else ""
in in
let include_extra_count = 0 in let include_extra_count = 0 in
let rec get_lines (n : int) : string list = let rec get_lines (n : int) : string list =
@ -193,10 +212,8 @@ let retrieve_loc_text (pos : t) : string =
(Cli.with_style blue_style "└%s┐" (string_repeat spaces "")); (Cli.with_style blue_style "└%s┐" (string_repeat spaces ""));
Buffer.add_char buf '\n'; Buffer.add_char buf '\n';
Buffer.add_string buf Buffer.add_string buf
(Cli.add_prefix_to_each_line (Cli.add_prefix_to_each_line (String.concat "\n" pos_lines) (fun i ->
(String.concat "\n" ("" :: pos_lines)) let cur_line = sline - include_extra_count + i in
(fun i ->
let cur_line = sline - include_extra_count + i - 1 in
if if
cur_line >= sline cur_line >= sline
&& cur_line <= sline + (2 * (eline - sline)) && cur_line <= sline + (2 * (eline - sline))

View File

@ -14,39 +14,47 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
let to_ascii : string -> string = Ubase.from_utf8 include Stdlib.String
let is_uppercase_ascii (c : char) : bool = let to_ascii : string -> string = Ubase.from_utf8
let c = Char.code c in let is_uppercase_ascii = function 'A' .. 'Z' -> true | _ -> false
(* 'A' <= c && c <= 'Z' *)
0x41 <= c && c <= 0x5b
let begins_with_uppercase (s : string) : bool = let begins_with_uppercase (s : string) : bool =
if "" = s then false else is_uppercase_ascii (to_ascii s).[0] "" <> s && is_uppercase_ascii (get (to_ascii s) 0)
let to_snake_case (s : string) : string = let to_snake_case (s : string) : string =
let out = ref "" in let out = ref "" in
to_ascii s to_ascii s
|> String.iteri (fun i c -> |> iteri (fun i c ->
out := out :=
!out !out
^ (if is_uppercase_ascii c && 0 <> i then "_" else "") ^ (if is_uppercase_ascii c && 0 <> i then "_" else "")
^ String.lowercase_ascii (String.make 1 c)); ^ lowercase_ascii (make 1 c));
!out !out
let to_camel_case (s : string) : string = let to_camel_case (s : string) : string =
let last_was_underscore = ref false in let last_was_underscore = ref false in
let out = ref "" in let out = ref "" in
to_ascii s to_ascii s
|> String.iteri (fun i c -> |> iteri (fun i c ->
let is_underscore = c = '_' in let is_underscore = c = '_' in
let c_string = String.make 1 c in let c_string = make 1 c in
out := out :=
!out !out
^ ^
if is_underscore then "" if is_underscore then ""
else if !last_was_underscore || 0 = i then else if !last_was_underscore || 0 = i then uppercase_ascii c_string
String.uppercase_ascii c_string
else c_string; else c_string;
last_was_underscore := is_underscore); last_was_underscore := is_underscore);
!out !out
let remove_prefix ~prefix s =
if starts_with ~prefix s then
let plen = length prefix in
sub s plen (length s - plen)
else s
let format_t = Format.pp_print_string
module Set = Set.Make (Stdlib.String)
module Map = Map.Make (Stdlib.String)

View File

@ -14,6 +14,10 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
include module type of Stdlib.String
module Set : Set.S with type elt = string
module Map : Map.S with type key = string
(** Helper functions used for string manipulation. *) (** Helper functions used for string manipulation. *)
val to_ascii : string -> string val to_ascii : string -> string
@ -34,3 +38,11 @@ val to_snake_case : string -> string
val to_camel_case : string -> string val to_camel_case : string -> string
(** Converts snake_case into CamlCase after removing Remove all diacritics on (** Converts snake_case into CamlCase after removing Remove all diacritics on
Latin letters. *) Latin letters. *)
val remove_prefix : prefix:string -> string -> string
(** [remove_prefix ~prefix str] returns
- if [str] starts with [prefix], a string [s] such that [prefix ^ s = str]
- otherwise, [str] unchanged *)
val format_t : Format.formatter -> string -> unit

View File

@ -18,7 +18,7 @@ module type Info = sig
type info type info
val to_string : info -> string val to_string : info -> string
val format_info : Format.formatter -> info -> unit val format : Format.formatter -> info -> unit
val equal : info -> info -> bool val equal : info -> info -> bool
val compare : info -> info -> int val compare : info -> info -> int
end end
@ -33,10 +33,21 @@ module type Id = sig
val equal : t -> t -> bool val equal : t -> t -> bool
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val hash : t -> int val hash : t -> int
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
end end
module Make (X : Info) () : Id with type info = X.info = struct module Make (X : Info) () : Id with type info = X.info = struct
type t = { id : int; info : X.info } module Ordering = struct
type t = { id : int; info : X.info }
let compare (x : t) (y : t) : int = compare x.id y.id
let equal x y = Int.equal x.id y.id
end
include Ordering
type info = X.info type info = X.info
let counter = ref 0 let counter = ref 0
@ -46,20 +57,20 @@ module Make (X : Info) () : Id with type info = X.info = struct
{ id = !counter; info } { id = !counter; info }
let get_info (uid : t) : X.info = uid.info let get_info (uid : t) : X.info = uid.info
let compare (x : t) (y : t) : int = compare x.id y.id let format_t (fmt : Format.formatter) (x : t) : unit = X.format fmt x.info
let equal x y = Int.equal x.id y.id
let format_t (fmt : Format.formatter) (x : t) : unit =
X.format_info fmt x.info
let hash (x : t) : int = x.id let hash (x : t) : int = x.id
module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering)
end end
module MarkedString = struct module MarkedString = struct
type info = string Marked.pos type info = string Marked.pos
let to_string (s, _) = s let to_string (s, _) = s
let format_info fmt i = Format.pp_print_string fmt (to_string i) let format fmt i = Format.pp_print_string fmt (to_string i)
let equal i1 i2 = String.equal (Marked.unmark i1) (Marked.unmark i2) let equal i1 i2 = String.equal (Marked.unmark i1) (Marked.unmark i2)
let compare i1 i2 = String.compare (Marked.unmark i1) (Marked.unmark i2) let compare i1 i2 = String.compare (Marked.unmark i1) (Marked.unmark i2)
end end
module Gen () = Make (MarkedString) ()

View File

@ -21,7 +21,7 @@ module type Info = sig
type info type info
val to_string : info -> string val to_string : info -> string
val format_info : Format.formatter -> info -> unit val format : Format.formatter -> info -> unit
val equal : info -> info -> bool val equal : info -> info -> bool
(** Equality disregards position *) (** Equality disregards position *)
@ -48,9 +48,15 @@ module type Id = sig
val equal : t -> t -> bool val equal : t -> t -> bool
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val hash : t -> int val hash : t -> int
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
end end
(** This is the generative functor that ensures that two modules resulting from (** This is the generative functor that ensures that two modules resulting from
two different calls to [Make] will be viewed as different types [t] by the two different calls to [Make] will be viewed as different types [t] by the
OCaml typechecker. Prevents mixing up different sorts of identifiers. *) OCaml typechecker. Prevents mixing up different sorts of identifiers. *)
module Make (X : Info) () : Id with type info = X.info module Make (X : Info) () : Id with type info = X.info
module Gen () : Id with type info = MarkedString.info
(** Shortcut for creating a kind of uids over marked strings *)

View File

@ -1,3 +1,4 @@
open Catala_utils
open Driver open Driver
open Js_of_ocaml open Js_of_ocaml
@ -12,7 +13,7 @@ let _ =
driver driver
(Contents (Js.to_string contents)) (Contents (Js.to_string contents))
{ {
Utils.Cli.debug = false; Cli.debug = false;
color = Never; color = Never;
wrap_weaved_output = false; wrap_weaved_output = false;
avoid_exceptions = false; avoid_exceptions = false;

View File

@ -1,7 +1,15 @@
(library (library
(name dcalc) (name dcalc)
(public_name catala.dcalc) (public_name catala.dcalc)
(libraries bindlib unionFind utils re ubase catala.runtime_ocaml shared_ast) (libraries
bindlib
unionFind
catala_utils
re
ubase
catala.runtime_ocaml
shared_ast
scopelang)
(preprocess (preprocess
(pps visitors.ppx))) (pps visitors.ppx)))

View File

@ -16,4 +16,4 @@
(** Scope language to default calculus translator *) (** Scope language to default calculus translator *)
val translate_program : 'm Ast.program -> 'm Dcalc.Ast.program val translate_program : 'm Scopelang.Ast.program -> 'm Ast.program

View File

@ -16,7 +16,7 @@
(** Reference interpreter for the default calculus *) (** Reference interpreter for the default calculus *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
module Runtime = Runtime_ocaml.Runtime module Runtime = Runtime_ocaml.Runtime
@ -29,272 +29,117 @@ let log_indent = ref 0
(** {1 Evaluation} *) (** {1 Evaluation} *)
let rec evaluate_operator let print_log ctx entry infos pos e =
(ctx : decl_ctx) if !Cli.trace_flag then
(op : operator) match entry with
(pos : Pos.t) | VarDef _ ->
(args : 'm Ast.expr list) : 'm Ast.naked_expr = (* TODO: this usage of Format is broken, Formatting requires that all is
(* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use formatted in one pass, without going through intermediate "%s" *)
[op] to raise multispanned errors. *) Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
let apply_div_or_raise_err (div : unit -> 'm Ast.naked_expr) : Print.uid_list infos
'm Ast.naked_expr = (match Marked.unmark e with
try div () | EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
with Division_by_zero -> | _ ->
let expr_str =
Format.asprintf "%a" (Expr.format ctx ~debug:false) e
in
let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
expr_str
in
Cli.with_style [ANSITerminal.green] "%s" expr_str)
| PosRecordIfTrueBool -> (
match pos <> Pos.no_pos, Marked.unmark e with
| true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry entry
(Cli.with_style [ANSITerminal.green] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ())
| BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos;
log_indent := !log_indent + 1
| EndCall ->
log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(* Todo: this should be handled early when resolving overloads. Here we have
proper structural equality, but the OCaml backend for example uses the
builtin equality function instead of this. *)
let rec handle_eq ctx pos e1 e2 =
let open Runtime.Oper in
match e1, e2 with
| ELit LUnit, ELit LUnit -> true
| ELit (LBool b1), ELit (LBool b2) -> not (o_xor b1 b2)
| ELit (LInt x1), ELit (LInt x2) -> o_eq_int_int x1 x2
| ELit (LRat x1), ELit (LRat x2) -> o_eq_rat_rat x1 x2
| ELit (LMoney x1), ELit (LMoney x2) -> o_eq_mon_mon x1 x2
| ELit (LDuration x1), ELit (LDuration x2) -> o_eq_dur_dur x1 x2
| ELit (LDate x1), ELit (LDate x2) -> o_eq_dat_dat x1 x2
| EArray es1, EArray es2 -> (
try
List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false)
| EStruct { fields = es1; name = s1 }, EStruct { fields = es2; name = s2 } ->
StructName.equal s1 s2
&& StructField.Map.equal
(fun e1 e2 ->
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
| ( EInj { e = e1; cons = i1; name = en1 },
EInj { e = e2; cons = i2; name = en2 } ) -> (
try
EnumName.equal en1 en2
&& EnumConstructor.equal i1 i2
&&
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *)
with Invalid_argument _ -> false)
| _, _ -> false (* comparing anything else return false *)
(* Call-by-value: the arguments are expected to be already evaluated here *)
and evaluate_operator :
type k.
decl_ctx ->
(dcalc, k) operator ->
Pos.t ->
'm Ast.expr list ->
'm Ast.naked_expr =
fun ctx op pos args ->
let protect f x y =
let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
[None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false
in
try f x y with
| Division_by_zero ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
Some "The division operator:", pos; Some "The division operator:", pos;
Some "The null denominator:", Expr.pos (List.nth args 1); Some "The null denominator:", Expr.pos (List.nth args 1);
] ]
"division by zero at runtime" "division by zero at runtime"
in | Runtime.UncomparableDurations ->
let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
[None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false
in
(* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched,
use [args] to raise multispanned errors. *)
let apply_cmp_or_raise_err
(cmp : unit -> 'm Ast.naked_expr)
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
try cmp ()
with Runtime.UncomparableDurations ->
Errors.raise_multispanned_error (get_binop_args_pos args) Errors.raise_multispanned_error (get_binop_args_pos args)
"Cannot compare together durations that cannot be converted to a \ "Cannot compare together durations that cannot be converted to a \
precise number of days" precise number of days"
in in
match op, List.map Marked.unmark args with let err () =
| Ternop Fold, [_f; _init; EArray es] ->
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp (List.nth args 0, [acc; e'])) e'))
(List.nth args 1) es)
| Binop And, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 && b2))
| Binop Or, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 || b2))
| Binop Xor, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 <> b2))
| Binop (Add KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 +! i2))
| Binop (Sub KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 -! i2))
| Binop (Mult KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 *! i2))
| Binop (Div KInt), [ELit (LInt i1); ELit (LInt i2)] ->
apply_div_or_raise_err (fun _ -> ELit (LInt Runtime.(i1 /! i2)))
| Binop (Add KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 +& i2))
| Binop (Sub KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 -& i2))
| Binop (Mult KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 *& i2))
| Binop (Div KRat), [ELit (LRat i1); ELit (LRat i2)] ->
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(i1 /& i2)))
| Binop (Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LMoney Runtime.(m1 +$ m2))
| Binop (Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LMoney Runtime.(m1 -$ m2))
| Binop (Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] ->
ELit (LMoney Runtime.(m1 *$ m2))
| Binop (Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(m1 /$ m2)))
| Binop (Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LDuration Runtime.(d1 +^ d2))
| Binop (Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LDuration Runtime.(d1 -^ d2))
| Binop (Sub KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LDuration Runtime.(d1 -@ d2))
| Binop (Add KDate), [ELit (LDate d1); ELit (LDuration d2)] ->
ELit (LDate Runtime.(d1 +@ d2))
| Binop (Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] ->
ELit (LDuration Runtime.(d1 *^ i1))
| Binop (Lt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 <! i2))
| Binop (Lte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 <=! i2))
| Binop (Gt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 >! i2))
| Binop (Gte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 >=! i2))
| Binop (Lt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 <& i2))
| Binop (Lte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 <=& i2))
| Binop (Gt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 >& i2))
| Binop (Gte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 >=& i2))
| Binop (Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 <$ m2))
| Binop (Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 <=$ m2))
| Binop (Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 >$ m2))
| Binop (Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 >=$ m2))
| Binop (Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <^ d2))) args
| Binop (Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <=^ d2))) args
| Binop (Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >^ d2))) args
| Binop (Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >=^ d2))) args
| Binop (Lt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 <@ d2))
| Binop (Lte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 <=@ d2))
| Binop (Gt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 >@ d2))
| Binop (Gte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 >=@ d2))
| Binop Eq, [ELit LUnit; ELit LUnit] -> ELit (LBool true)
| Binop Eq, [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LBool Runtime.(d1 =^ d2))
| Binop Eq, [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 =@ d2))
| Binop Eq, [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 =$ m2))
| Binop Eq, [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 =& i2))
| Binop Eq, [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 =! i2))
| Binop Eq, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 = b2))
| Binop Eq, [EArray es1; EArray es2] ->
ELit
(LBool
(try
List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false))
| Binop Eq, [ETuple (es1, s1); ETuple (es2, s2)] ->
ELit
(LBool
(try
s1 = s2
&& List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false))
| Binop Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] ->
ELit
(LBool
(try
en1 = en2
&& i1 = i2
&&
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *)
with Invalid_argument _ -> false))
| Binop Eq, [_; _] ->
ELit (LBool false) (* comparing anything else return false *)
| Binop Neq, [_; _] -> (
match evaluate_operator ctx (Binop Eq) pos args with
| ELit (LBool b) -> ELit (LBool (not b))
| _ -> assert false (*should not happen *))
| Binop Concat, [EArray es1; EArray es2] -> EArray (es1 @ es2)
| Binop Map, [_; EArray es] ->
EArray
(List.map
(fun e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp (List.nth args 0, [e'])) e'))
es)
| Binop Filter, [_; EArray es] ->
EArray
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp (List.nth args 0, [e'])) e')
with
| ELit (LBool b), _ -> b
| _ ->
Errors.raise_spanned_error
(Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a boolean \
(should not happen if the term was well-typed)")
es)
| Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> ELit LEmptyError
| Unop (Minus KInt), [ELit (LInt i)] ->
ELit (LInt Runtime.(integer_of_int 0 -! i))
| Unop (Minus KRat), [ELit (LRat i)] ->
ELit (LRat Runtime.(decimal_of_string "0" -& i))
| Unop (Minus KMoney), [ELit (LMoney i)] ->
ELit (LMoney Runtime.(money_of_units_int 0 -$ i))
| Unop (Minus KDuration), [ELit (LDuration i)] ->
ELit (LDuration Runtime.(~-^i))
| Unop Not, [ELit (LBool b)] -> ELit (LBool (not b))
| Unop Length, [EArray es] ->
ELit (LInt (Runtime.integer_of_int (List.length es)))
| Unop GetDay, [ELit (LDate d)] ->
ELit (LInt Runtime.(day_of_month_of_date d))
| Unop GetMonth, [ELit (LDate d)] ->
ELit (LInt Runtime.(month_number_of_date d))
| Unop GetYear, [ELit (LDate d)] -> ELit (LInt Runtime.(year_of_date d))
| Unop FirstDayOfMonth, [ELit (LDate d)] ->
ELit (LDate Runtime.(first_day_of_month d))
| Unop LastDayOfMonth, [ELit (LDate d)] ->
ELit (LDate Runtime.(first_day_of_month d))
| Unop IntToRat, [ELit (LInt i)] -> ELit (LRat Runtime.(decimal_of_integer i))
| Unop MoneyToRat, [ELit (LMoney i)] ->
ELit (LRat Runtime.(decimal_of_money i))
| Unop RatToMoney, [ELit (LRat i)] ->
ELit (LMoney Runtime.(money_of_decimal i))
| Unop RoundMoney, [ELit (LMoney m)] -> ELit (LMoney Runtime.(money_round m))
| Unop RoundDecimal, [ELit (LRat m)] -> ELit (LRat Runtime.(decimal_round m))
| Unop (Log (entry, infos)), [e'] ->
if !Cli.trace_flag then (
match entry with
| VarDef _ ->
(* TODO: this usage of Format is broken, Formatting requires that all is
formatted in one pass, without going through intermediate "%s" *)
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(match e' with
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
| _ ->
let expr_str =
Format.asprintf "%a" (Expr.format ctx ~debug:false) (List.hd args)
in
let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
expr_str
in
Cli.with_style [ANSITerminal.green] "%s" expr_str)
| PosRecordIfTrueBool -> (
match pos <> Pos.no_pos, e' with
| true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry
entry
(Cli.with_style [ANSITerminal.green] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ())
| BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos;
log_indent := !log_indent + 1
| EndCall ->
log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos)
else ();
e'
| Unop _, [ELit LEmptyError] -> ELit LEmptyError
| _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
([Some "Operator:", pos] ([Some "Operator:", pos]
@ List.mapi @ List.mapi
@ -307,6 +152,162 @@ let rec evaluate_operator
args) args)
"Operator applied to the wrong arguments\n\ "Operator applied to the wrong arguments\n\
(should not happen if the term was well-typed)" (should not happen if the term was well-typed)"
in
let open Runtime.Oper in
if List.exists (function ELit LEmptyError, _ -> true | _ -> false) args then
ELit LEmptyError
else
Operator.kind_dispatch op
~polymorphic:(fun op ->
match op, args with
| Length, [(EArray es, _)] ->
ELit (LInt (Runtime.integer_of_int (List.length es)))
| Log (entry, infos), [e'] ->
print_log ctx entry infos pos e';
Marked.unmark e'
| Eq, [(e1, _); (e2, _)] -> ELit (LBool (handle_eq ctx pos e1 e2))
| Map, [f; (EArray es, _)] ->
EArray
(List.map
(fun e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e'))
es)
| Reduce, [_; default; (EArray [], _)] -> Marked.unmark default
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
Marked.unmark
(List.fold_left
(fun acc x ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; x] }) f))
x0 xn)
| Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2)
| Filter, [f; (EArray es, _)] ->
EArray
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e')
with
| ELit (LBool b), _ -> b
| _ ->
Errors.raise_spanned_error
(Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a \
boolean (should not happen if the term was well-typed)")
es)
| Fold, [f; init; (EArray es, _)] ->
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; e'] }) e'))
init es)
| (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ ->
err ())
~monomorphic:(fun op ->
let rlit =
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
| Not, [LBool b] -> LBool (o_not b)
| GetDay, [LDate d] -> LInt (o_getDay d)
| GetMonth, [LDate d] -> LInt (o_getMonth d)
| GetYear, [LDate d] -> LInt (o_getYear d)
| FirstDayOfMonth, [LDate d] -> LDate (o_firstDayOfMonth d)
| LastDayOfMonth, [LDate d] -> LDate (o_lastDayOfMonth d)
| And, [LBool b1; LBool b2] -> LBool (o_and b1 b2)
| Or, [LBool b1; LBool b2] -> LBool (o_or b1 b2)
| Xor, [LBool b1; LBool b2] -> LBool (o_xor b1 b2)
| ( ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth
| LastDayOfMonth | And | Or | Xor ),
_ ) ->
err ()
in
ELit rlit)
~resolved:(fun op ->
let rlit =
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
| Minus_int, [LInt x] -> LInt (o_minus_int x)
| Minus_rat, [LRat x] -> LRat (o_minus_rat x)
| Minus_mon, [LMoney x] -> LMoney (o_minus_mon x)
| Minus_dur, [LDuration x] -> LDuration (o_minus_dur x)
| ToRat_int, [LInt i] -> LRat (o_torat_int i)
| ToRat_mon, [LMoney i] -> LRat (o_torat_mon i)
| ToMoney_rat, [LRat i] -> LMoney (o_tomoney_rat i)
| Round_mon, [LMoney m] -> LMoney (o_round_mon m)
| Round_rat, [LRat m] -> LRat (o_round_rat m)
| Add_int_int, [LInt x; LInt y] -> LInt (o_add_int_int x y)
| Add_rat_rat, [LRat x; LRat y] -> LRat (o_add_rat_rat x y)
| Add_mon_mon, [LMoney x; LMoney y] -> LMoney (o_add_mon_mon x y)
| Add_dat_dur, [LDate x; LDuration y] -> LDate (o_add_dat_dur x y)
| Add_dur_dur, [LDuration x; LDuration y] ->
LDuration (o_add_dur_dur x y)
| Sub_int_int, [LInt x; LInt y] -> LInt (o_sub_int_int x y)
| Sub_rat_rat, [LRat x; LRat y] -> LRat (o_sub_rat_rat x y)
| Sub_mon_mon, [LMoney x; LMoney y] -> LMoney (o_sub_mon_mon x y)
| Sub_dat_dat, [LDate x; LDate y] -> LDuration (o_sub_dat_dat x y)
| Sub_dat_dur, [LDate x; LDuration y] -> LDate (o_sub_dat_dur x y)
| Sub_dur_dur, [LDuration x; LDuration y] ->
LDuration (o_sub_dur_dur x y)
| Mult_int_int, [LInt x; LInt y] -> LInt (o_mult_int_int x y)
| Mult_rat_rat, [LRat x; LRat y] -> LRat (o_mult_rat_rat x y)
| Mult_mon_rat, [LMoney x; LRat y] -> LMoney (o_mult_mon_rat x y)
| Mult_dur_int, [LDuration x; LInt y] ->
LDuration (o_mult_dur_int x y)
| Div_int_int, [LInt x; LInt y] -> LRat (protect o_div_int_int x y)
| Div_rat_rat, [LRat x; LRat y] -> LRat (protect o_div_rat_rat x y)
| Div_mon_mon, [LMoney x; LMoney y] ->
LRat (protect o_div_mon_mon x y)
| Div_mon_rat, [LMoney x; LRat y] ->
LMoney (protect o_div_mon_rat x y)
| Lt_int_int, [LInt x; LInt y] -> LBool (o_lt_int_int x y)
| Lt_rat_rat, [LRat x; LRat y] -> LBool (o_lt_rat_rat x y)
| Lt_mon_mon, [LMoney x; LMoney y] -> LBool (o_lt_mon_mon x y)
| Lt_dat_dat, [LDate x; LDate y] -> LBool (o_lt_dat_dat x y)
| Lt_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_lt_dur_dur x y)
| Lte_int_int, [LInt x; LInt y] -> LBool (o_lte_int_int x y)
| Lte_rat_rat, [LRat x; LRat y] -> LBool (o_lte_rat_rat x y)
| Lte_mon_mon, [LMoney x; LMoney y] -> LBool (o_lte_mon_mon x y)
| Lte_dat_dat, [LDate x; LDate y] -> LBool (o_lte_dat_dat x y)
| Lte_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_lte_dur_dur x y)
| Gt_int_int, [LInt x; LInt y] -> LBool (o_gt_int_int x y)
| Gt_rat_rat, [LRat x; LRat y] -> LBool (o_gt_rat_rat x y)
| Gt_mon_mon, [LMoney x; LMoney y] -> LBool (o_gt_mon_mon x y)
| Gt_dat_dat, [LDate x; LDate y] -> LBool (o_gt_dat_dat x y)
| Gt_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_gt_dur_dur x y)
| Gte_int_int, [LInt x; LInt y] -> LBool (o_gte_int_int x y)
| Gte_rat_rat, [LRat x; LRat y] -> LBool (o_gte_rat_rat x y)
| Gte_mon_mon, [LMoney x; LMoney y] -> LBool (o_gte_mon_mon x y)
| Gte_dat_dat, [LDate x; LDate y] -> LBool (o_gte_dat_dat x y)
| Gte_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_gte_dur_dur x y)
| Eq_int_int, [LInt x; LInt y] -> LBool (o_eq_int_int x y)
| Eq_rat_rat, [LRat x; LRat y] -> LBool (o_eq_rat_rat x y)
| Eq_mon_mon, [LMoney x; LMoney y] -> LBool (o_eq_mon_mon x y)
| Eq_dat_dat, [LDate x; LDate y] -> LBool (o_eq_dat_dat x y)
| Eq_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_eq_dur_dur x y)
| ( ( Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int
| ToRat_mon | ToMoney_rat | Round_rat | Round_mon | Add_int_int
| Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat
| Sub_dat_dur | Sub_dur_dur | Mult_int_int | Mult_rat_rat
| Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat
| Lte_mon_mon | Lte_dat_dat | Lte_dur_dur | Gt_int_int
| Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur | Gte_int_int
| Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur
),
_ ) ->
err ()
in
ELit rlit)
~overloaded:(fun _ -> assert false)
and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr = and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
match Marked.unmark e with match Marked.unmark e with
@ -314,11 +315,11 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
Errors.raise_spanned_error (Expr.pos e) Errors.raise_spanned_error (Expr.pos e)
"free variable found at evaluation (should not happen if term was \ "free variable found at evaluation (should not happen if term was \
well-typed" well-typed"
| EApp (e1, args) -> ( | EApp { f = e1; args } -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
let args = List.map (evaluate_expr ctx) args in let args = List.map (evaluate_expr ctx) args in
match Marked.unmark e1 with match Marked.unmark e1 with
| EAbs (binder, _) -> | EAbs { binder; _ } ->
if Bindlib.mbinder_arity binder = List.length args then if Bindlib.mbinder_arity binder = List.length args then
evaluate_expr ctx evaluate_expr ctx
(Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args))) (Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args)))
@ -327,7 +328,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
"wrong function call, expected %d arguments, got %d" "wrong function call, expected %d arguments, got %d"
(Bindlib.mbinder_arity binder) (Bindlib.mbinder_arity binder)
(List.length args) (List.length args)
| EOp op -> | EOp { op; _ } ->
Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
@ -335,69 +336,66 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
"function has not been reduced to a lambda at evaluation (should not \ "function has not been reduced to a lambda at evaluation (should not \
happen if the term was well-typed") happen if the term was well-typed")
| EAbs _ | ELit _ | EOp _ -> e (* these are values *) | EAbs _ | ELit _ | EOp _ -> e (* these are values *)
| ETuple (es, s) -> | EStruct { fields = es; name } ->
let new_es = List.map (evaluate_expr ctx) es in let new_es = StructField.Map.map (evaluate_expr ctx) es in
if List.exists is_empty_error new_es then if StructField.Map.exists (fun _ e -> is_empty_error e) new_es then
Marked.same_mark_as (ELit LEmptyError) e Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (ETuple (new_es, s)) e else Marked.same_mark_as (EStruct { fields = new_es; name }) e
| ETupleAccess (e1, n, s, _) -> ( | EStructAccess { e = e1; name = s; field } -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
match Marked.unmark e1 with match Marked.unmark e1 with
| ETuple (es, s') -> ( | EStruct { fields = es; name = s' } -> (
(match s, s' with if not (StructName.equal s s') then
| None, None -> ()
| Some s, Some s' when s = s' -> ()
| _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[None, Expr.pos e; None, Expr.pos e1] [None, Expr.pos e; None, Expr.pos e1]
"Error during tuple access: not the same structs (should not happen \ "Error during struct access: not the same structs (should not happen \
if the term was well-typed)"); if the term was well-typed)";
match List.nth_opt es n with match StructField.Map.find_opt field es with
| Some e' -> e' | Some e' -> e'
| None -> | None ->
Errors.raise_spanned_error (Expr.pos e1) Errors.raise_spanned_error (Expr.pos e1)
"The tuple has %d components but the %i-th element was requested \ "Invalid field access %a in struct %a (should not happen if the term \
(should not happen if the term was well-type)" was well-typed)"
(List.length es) n) StructField.format_t field StructName.format_t s)
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Expr.pos e1) Errors.raise_spanned_error (Expr.pos e1)
"The expression %a should be a tuple with %d components but is not \ "The expression %a should be a struct %a but is not (should not happen \
(should not happen if the term was well-typed)" if the term was well-typed)"
(Expr.format ctx ~debug:true) (Expr.format ctx ~debug:true)
e n) e StructName.format_t s)
| EInj (e1, n, en, ts) -> | EInj { e = e1; name; cons } ->
let e1' = evaluate_expr ctx e1 in let e1' = evaluate_expr ctx e1 in
if is_empty_error e1' then Marked.same_mark_as (ELit LEmptyError) e if is_empty_error e then Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (EInj (e1', n, en, ts)) e else Marked.same_mark_as (EInj { e = e1'; name; cons }) e
| EMatch (e1, es, e_name) -> ( | EMatch { e = e1; cases = es; name } -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
match Marked.unmark e1 with match Marked.unmark e1 with
| EInj (e1, n, e_name', _) -> | EInj { e = e1; cons; name = name' } ->
if e_name <> e_name' then if not (EnumName.equal name name') then
Errors.raise_multispanned_error Errors.raise_multispanned_error
[None, Expr.pos e; None, Expr.pos e1] [None, Expr.pos e; None, Expr.pos e1]
"Error during match: two different enums found (should not happen if \ "Error during match: two different enums found (should not happen if \
the term was well-typed)"; the term was well-typed)";
let es_n = let es_n =
match List.nth_opt es n with match EnumConstructor.Map.find_opt cons es with
| Some es_n -> es_n | Some es_n -> es_n
| None -> | None ->
Errors.raise_spanned_error (Expr.pos e) Errors.raise_spanned_error (Expr.pos e)
"sum type index error (should not happen if the term was \ "sum type index error (should not happen if the term was \
well-typed)" well-typed)"
in in
let new_e = Marked.same_mark_as (EApp (es_n, [e1])) e in let new_e = Marked.same_mark_as (EApp { f = es_n; args = [e1] }) e in
evaluate_expr ctx new_e evaluate_expr ctx new_e
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Expr.pos e1) Errors.raise_spanned_error (Expr.pos e1)
"Expected a term having a sum type as an argument to a match (should \ "Expected a term having a sum type as an argument to a match (should \
not happen if the term was well-typed") not happen if the term was well-typed")
| EDefault (exceptions, just, cons) -> ( | EDefault { excepts; just; cons } -> (
let exceptions = List.map (evaluate_expr ctx) exceptions in let excepts = List.map (evaluate_expr ctx) excepts in
let empty_count = List.length (List.filter is_empty_error exceptions) in let empty_count = List.length (List.filter is_empty_error excepts) in
match List.length exceptions - empty_count with match List.length excepts - empty_count with
| 0 -> ( | 0 -> (
let just = evaluate_expr ctx just in let just = evaluate_expr ctx just in
match Marked.unmark just with match Marked.unmark just with
@ -408,19 +406,19 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
Errors.raise_spanned_error (Expr.pos e) Errors.raise_spanned_error (Expr.pos e)
"Default justification has not been reduced to a boolean at \ "Default justification has not been reduced to a boolean at \
evaluation (should not happen if the term was well-typed") evaluation (should not happen if the term was well-typed")
| 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions | 1 -> List.find (fun sub -> not (is_empty_error sub)) excepts
| _ -> | _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
(List.map (List.map
(fun except -> (fun except ->
Some "This consequence has a valid justification:", Expr.pos except) Some "This consequence has a valid justification:", Expr.pos except)
(List.filter (fun sub -> not (is_empty_error sub)) exceptions)) (List.filter (fun sub -> not (is_empty_error sub)) excepts))
"There is a conflict between multiple valid consequences for assigning \ "There is a conflict between multiple valid consequences for assigning \
the same variable.") the same variable.")
| EIfThenElse (cond, et, ef) -> ( | EIfThenElse { cond; etrue; efalse } -> (
match Marked.unmark (evaluate_expr ctx cond) with match Marked.unmark (evaluate_expr ctx cond) with
| ELit (LBool true) -> evaluate_expr ctx et | ELit (LBool true) -> evaluate_expr ctx etrue
| ELit (LBool false) -> evaluate_expr ctx ef | ELit (LBool false) -> evaluate_expr ctx efalse
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Expr.pos cond) Errors.raise_spanned_error (Expr.pos cond)
@ -431,7 +429,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
if List.exists is_empty_error new_es then if List.exists is_empty_error new_es then
Marked.same_mark_as (ELit LEmptyError) e Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (EArray new_es) e else Marked.same_mark_as (EArray new_es) e
| ErrorOnEmpty e' -> | EErrorOnEmpty e' ->
let e' = evaluate_expr ctx e' in let e' = evaluate_expr ctx e' in
if Marked.unmark e' = ELit LEmptyError then if Marked.unmark e' = ELit LEmptyError then
Errors.raise_spanned_error (Expr.pos e') Errors.raise_spanned_error (Expr.pos e')
@ -443,23 +441,44 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
| ELit (LBool true) -> Marked.same_mark_as (ELit LUnit) e' | ELit (LBool true) -> Marked.same_mark_as (ELit LUnit) e'
| ELit (LBool false) -> ( | ELit (LBool false) -> (
match Marked.unmark e' with match Marked.unmark e' with
| ErrorOnEmpty | EErrorOnEmpty
( EApp ( EApp
((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)]), {
_ ) f = EOp { op; _ }, _;
| EApp args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
( (EOp (Unop (Log _)), _), },
[ _ ) ->
( EApp
( (EOp (Binop op), _),
[((ELit _, _) as e1); ((ELit _, _) as e2)] ),
_ );
] )
| EApp ((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)])
->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a" Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false) (Expr.format ctx ~debug:false)
e1 Print.binop op e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| EApp
{
f = EOp { op = Log _; _ }, _;
args =
[
( EApp
{
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
},
_ );
];
} ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| EApp
{
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
} ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.operator op
(Expr.format ctx ~debug:false) (Expr.format ctx ~debug:false)
e2 e2
| _ -> | _ ->
@ -479,19 +498,22 @@ let interpret_program :
fun (ctx : decl_ctx) (e : 'm Ast.expr) : fun (ctx : decl_ctx) (e : 'm Ast.expr) :
(Uid.MarkedString.info * 'm Ast.expr) list -> (Uid.MarkedString.info * 'm Ast.expr) list ->
match evaluate_expr ctx e with match evaluate_expr ctx e with
| (EAbs (_, [((TStruct s_in, _) as _targs)]), mark_e) as e -> begin | (EAbs { tys = [((TStruct s_in, _) as _targs)]; _ }, mark_e) as e -> begin
(* At this point, the interpreter seeks to execute the scope but does not (* At this point, the interpreter seeks to execute the scope but does not
have a way to retrieve input values from the command line. [taus] contain have a way to retrieve input values from the command line. [taus] contain
the types of the scope arguments. For [context] arguments, we can provide the types of the scope arguments. For [context] arguments, we can provide
an empty thunked term. But for [input] arguments of another type, we an empty thunked term. But for [input] arguments of another type, we
cannot provide anything so we have to fail. *) cannot provide anything so we have to fail. *)
let taus = StructMap.find s_in ctx.ctx_structs in let taus = StructName.Map.find s_in ctx.ctx_structs in
let application_term = let application_term =
List.map StructField.Map.map
(fun (_, ty) -> (fun ty ->
match Marked.unmark ty with match Marked.unmark ty with
| TArrow ((TLit TUnit, _), ty_in) -> | TArrow (ty_in, ty_out) ->
Expr.empty_thunked_term (Expr.with_ty mark_e ty_in) Expr.make_abs
[| Var.make "_" |]
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
[ty_in] (Expr.mark_pos mark_e)
| _ -> | _ ->
Errors.raise_spanned_error (Marked.get_mark ty) Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \ "This scope needs input arguments to be executed. But the Catala \
@ -503,17 +525,14 @@ let interpret_program :
in in
let to_interpret = let to_interpret =
Expr.make_app (Expr.box e) Expr.make_app (Expr.box e)
[Expr.make_tuple application_term (Some s_in) mark_e] [Expr.estruct s_in application_term mark_e]
(Expr.pos e) (Expr.pos e)
in in
match Marked.unmark (evaluate_expr ctx (Expr.unbox to_interpret)) with match Marked.unmark (evaluate_expr ctx (Expr.unbox to_interpret)) with
| ETuple (args, Some s_out) -> | EStruct { fields; _ } ->
let s_out_fields = List.map
List.map (fun (fld, e) -> StructField.get_info fld, e)
(fun (f, _) -> StructFieldName.get_info f) (StructField.Map.bindings fields)
(StructMap.find s_out ctx.ctx_structs)
in
List.map2 (fun arg var -> var, arg) args s_out_fields
| _ -> | _ ->
Errors.raise_spanned_error (Expr.pos e) Errors.raise_spanned_error (Expr.pos e)
"The interpretation of a program should always yield a struct \ "The interpretation of a program should always yield a struct \

View File

@ -16,7 +16,7 @@
(** Reference interpreter for the default calculus *) (** Reference interpreter for the default calculus *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
val evaluate_expr : decl_ctx -> 'm Ast.expr -> 'm Ast.expr val evaluate_expr : decl_ctx -> 'm Ast.expr -> 'm Ast.expr

View File

@ -14,7 +14,7 @@
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
@ -24,179 +24,206 @@ type partial_evaluation_ctx = {
} }
let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) : let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
'm expr Bindlib.box = (dcalc, 'm mark) boxed_gexpr =
(* We proceed bottom-up, first apply on the subterms *)
let e = Expr.map ~f:(partial_evaluation ctx) e in
let mark = Marked.get_mark e in let mark = Marked.get_mark e in
let rec_helper = partial_evaluation ctx in (* Then reduce the parent node *)
match Marked.unmark e with let reduce e =
| EApp (* Todo: improve the handling of eapp(log,elit) cases here, it obfuscates
( (( EOp (Unop Not), _ the matches and the log calls are not preserved, which would be a good
| EApp ((EOp (Unop (Log _)), _), [(EOp (Unop Not), _)]), _ ) as op), property *)
[e1] ) -> match Marked.unmark e with
(* reduction of logical not *) | EApp
(Bindlib.box_apply (fun e1 -> {
match e1 with f =
| ELit (LBool false), _ -> ELit (LBool true), mark ( EOp { op = Not; _ }, _
| ELit (LBool true), _ -> ELit (LBool false), mark | ( EApp
| _ -> EApp (op, [e1]), mark)) {
(rec_helper e1) f = EOp { op = Log _; _ }, _;
| EApp args = [(EOp { op = Not; _ }, _)];
( (( EOp (Binop Or), _ },
| EApp ((EOp (Unop (Log _)), _), [(EOp (Binop Or), _)]), _ ) as op), _ ) ) as op;
[e1; e2] ) -> args = [e1];
(* reduction of logical or *) } -> (
(Bindlib.box_apply2 (fun e1 e2 -> (* reduction of logical not *)
match e1, e2 with match e1 with
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) -> | ELit (LBool false), _ -> ELit (LBool true)
new_e | ELit (LBool true), _ -> ELit (LBool false)
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) -> | e1 -> EApp { f = op; args = [e1] })
ELit (LBool true), mark | EApp
| _ -> EApp (op, [e1; e2]), mark)) {
(rec_helper e1) (rec_helper e2) f =
| EApp ( EOp { op = Or; _ }, _
( (( EOp (Binop And), _ | ( EApp
| EApp ((EOp (Unop (Log _)), _), [(EOp (Binop And), _)]), _ ) as op), {
[e1; e2] ) -> f = EOp { op = Log _; _ }, _;
(* reduction of logical and *) args = [(EOp { op = Or; _ }, _)];
(Bindlib.box_apply2 (fun e1 e2 -> },
match e1, e2 with _ ) ) as op;
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) -> args = [e1; e2];
new_e } -> (
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) -> (* reduction of logical or *)
ELit (LBool false), mark match e1, e2 with
| _ -> EApp (op, [e1; e2]), mark)) | (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) ->
(rec_helper e1) (rec_helper e2) Marked.unmark new_e
| EVar x -> Bindlib.box_apply (fun x -> x, mark) (Bindlib.box_var x) | (ELit (LBool true), _), _ | _, (ELit (LBool true), _) ->
| ETuple (args, s_name) -> ELit (LBool true)
Bindlib.box_apply | _ -> EApp { f = op; args = [e1; e2] })
(fun args -> ETuple (args, s_name), mark) | EApp
(List.map rec_helper args |> Bindlib.box_list) {
| ETupleAccess (arg, i, s_name, typs) -> f =
Bindlib.box_apply ( EOp { op = And; _ }, _
(fun arg -> ETupleAccess (arg, i, s_name, typs), mark) | ( EApp
(rec_helper arg) {
| EInj (arg, i, e_name, typs) -> f = EOp { op = Log _; _ }, _;
Bindlib.box_apply args = [(EOp { op = And; _ }, _)];
(fun arg -> EInj (arg, i, e_name, typs), mark) },
(rec_helper arg) _ ) ) as op;
| EMatch (arg, arms, e_name) -> args = [e1; e2];
Bindlib.box_apply2 } -> (
(fun arg arms -> (* reduction of logical and *)
match arg, arms with match e1, e2 with
| (EInj (e1, i, e_name', _ts), _), _ | (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) ->
when EnumName.compare e_name e_name' = 0 -> Marked.unmark new_e
(* iota reduction *) | (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
EApp (List.nth arms i, [e1]), mark ELit (LBool false)
| _ -> EMatch (arg, arms, e_name), mark) | _ -> EApp { f = op; args = [e1; e2] })
(rec_helper arg) | EMatch { e = EInj { e; name = name1; cons }, _; cases; name }
(List.map rec_helper arms |> Bindlib.box_list) when EnumName.equal name name1 ->
| EArray args -> (* iota reduction *)
Bindlib.box_apply EApp { f = EnumConstructor.Map.find cons cases; args = [e] }
(fun args -> EArray args, mark) | EApp { f = EAbs { binder; _ }, _; args } ->
(List.map rec_helper args |> Bindlib.box_list) (* beta reduction *)
| ELit l -> Bindlib.box (ELit l, mark) Marked.unmark (Bindlib.msubst binder (List.map fst args |> Array.of_list))
| EAbs (binder, typs) -> | EDefault { excepts; just; cons } -> (
let vars, body = Bindlib.unmbind binder in (* TODO: mechanically prove each of these optimizations correct :) *)
let new_body = rec_helper body in let excepts =
let new_binder = Bindlib.bind_mvar vars new_body in List.filter
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) new_binder (fun except -> Marked.unmark except <> ELit LEmptyError)
| EApp (f, args) -> excepts
Bindlib.box_apply2 (* we can discard the exceptions that are always empty error *)
(fun f args -> in
match Marked.unmark f with let value_except_count =
| EAbs (binder, _ts) -> List.fold_left
(* beta reduction *) (fun nb except -> if Expr.is_value except then nb + 1 else nb)
Bindlib.msubst binder (List.map fst args |> Array.of_list) 0 excepts
| _ -> EApp (f, args), mark) in
(rec_helper f) if value_except_count > 1 then
(List.map rec_helper args |> Bindlib.box_list) (* at this point we know a conflict error will be triggered so we just
| EAssert e1 -> Bindlib.box_apply (fun e1 -> EAssert e1, mark) (rec_helper e1) feed the expression to the interpreter that will print the beautiful
| EOp op -> Bindlib.box (EOp op, mark) right error message *)
| EDefault (exceptions, just, cons) -> Marked.unmark (Interpreter.evaluate_expr ctx.decl_ctx e)
Bindlib.box_apply3 else
(fun exceptions just cons -> match excepts, just with
(* TODO: mechanically prove each of these optimizations correct :) *) | [except], _ when Expr.is_value except ->
match
( List.filter
(fun except ->
match Marked.unmark except with
| ELit LEmptyError -> false
| _ -> true)
exceptions
(* we can discard the exceptions that are always empty error *),
just,
cons )
with
| exceptions, just, cons
when List.fold_left
(fun nb except -> if Expr.is_value except then nb + 1 else nb)
0 exceptions
> 1 ->
(* at this point we know a conflict error will be triggered so we just
feed the expression to the interpreter that will print the
beautiful right error message *)
Interpreter.evaluate_expr ctx.decl_ctx
(EDefault (exceptions, just, cons), mark)
| [except], _, _ when Expr.is_value except ->
(* if there is only one exception and it is a non-empty value it is (* if there is only one exception and it is a non-empty value it is
always chosen *) always chosen *)
except Marked.unmark except
| ( [], | ( [],
( ( ELit (LBool true) ( ( ELit (LBool true)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ), | EApp
_ ), {
cons ) -> f = EOp { op = Log _; _ }, _;
cons args = [(ELit (LBool true), _)];
} ),
_ ) ) ->
Marked.unmark cons
| ( [], | ( [],
( ( ELit (LBool false) ( ( ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ), | EApp
_ ), {
_ ) -> f = EOp { op = Log _; _ }, _;
ELit LEmptyError, mark args = [(ELit (LBool false), _)];
| [], just, cons when not !Cli.avoid_exceptions_flag -> } ),
_ ) ) ->
ELit LEmptyError
| [], just when not !Cli.avoid_exceptions_flag ->
(* without exceptions, a default is just an [if then else] raising an (* without exceptions, a default is just an [if then else] raising an
error in the else case. This exception is only valid in the context error in the else case. This exception is only valid in the context
of compilation_with_exceptions, so we desactivate with a global of compilation_with_exceptions, so we desactivate with a global
flag to know if we will be compiling using exceptions or the option flag to know if we will be compiling using exceptions or the option
monad. *) monad. FIXME: move this optimisation somewhere else to avoid this
EIfThenElse (just, cons, (ELit LEmptyError, mark)), mark check *)
| exceptions, just, cons -> EDefault (exceptions, just, cons), mark) EIfThenElse
(List.map rec_helper exceptions |> Bindlib.box_list) { cond = just; etrue = cons; efalse = ELit LEmptyError, mark }
(rec_helper just) (rec_helper cons) | excepts, just -> EDefault { excepts; just; cons })
| EIfThenElse (e1, e2, e3) -> | EIfThenElse
Bindlib.box_apply3 {
(fun e1 e2 e3 -> cond =
match Marked.unmark e1, Marked.unmark e2, Marked.unmark e3 with ( ELit (LBool true), _
| ELit (LBool true), _, _ | ( EApp
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]), _, _ -> {
e2 f = EOp { op = Log _; _ }, _;
| ELit (LBool false), _, _ args = [(ELit (LBool true), _)];
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]), _, _ -> },
e3 _ ) );
| ( _, etrue;
( ELit (LBool true) _;
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ), } ->
( ELit (LBool false) Marked.unmark etrue
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ) ) -> | EIfThenElse
e1 {
| _ when Expr.equal e2 e3 -> e2 cond =
| _ -> EIfThenElse (e1, e2, e3), mark) ( ( ELit (LBool false)
(rec_helper e1) (rec_helper e2) (rec_helper e3) | EApp
| ErrorOnEmpty e1 -> {
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) (rec_helper e1) f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool false), _)];
} ),
_ );
efalse;
_;
} ->
Marked.unmark efalse
| EIfThenElse
{
cond;
etrue =
( ( ELit (LBool btrue)
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool btrue), _)];
} ),
_ );
efalse =
( ( ELit (LBool bfalse)
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool bfalse), _)];
} ),
_ );
} ->
if btrue && not bfalse then Marked.unmark cond
else if (not btrue) && bfalse then
EApp
{
f = EOp { op = Not; tys = [TLit TBool, Expr.mark_pos mark] }, mark;
args = [cond];
}
(* note: this last call eliminates the condition & might skip log calls
as well *)
else (* btrue = bfalse *) ELit (LBool btrue)
| e -> e
in
Expr.Box.app1 e reduce mark
let optimize_expr (decl_ctx : decl_ctx) (e : 'm expr) = let optimize_expr (decl_ctx : decl_ctx) (e : 'm expr) =
partial_evaluation { var_values = Var.Map.empty; decl_ctx } e partial_evaluation { var_values = Var.Map.empty; decl_ctx } e
let rec scope_lets_map let rec scope_lets_map
(t : 'a -> 'm expr -> 'm expr Bindlib.box) (t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
(ctx : 'a) (ctx : 'a)
(scope_body_expr : 'm expr scope_body_expr) : (scope_body_expr : 'm expr scope_body_expr) :
'm expr scope_body_expr Bindlib.box = 'm expr scope_body_expr Bindlib.box =
match scope_body_expr with match scope_body_expr with
| Result e -> Bindlib.box_apply (fun e' -> Result e') (t ctx e) | Result e ->
Bindlib.box_apply (fun e' -> Result e') (Expr.Box.lift (t ctx e))
| ScopeLet scope_let -> | ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in let var, next = Bindlib.unbind scope_let.scope_let_next in
let new_scope_let_expr = t ctx scope_let.scope_let_expr in let new_scope_let_expr = Expr.Box.lift (t ctx scope_let.scope_let_expr) in
let new_next = scope_lets_map t ctx next in let new_next = scope_lets_map t ctx next in
let new_next = Bindlib.bind_var var new_next in let new_next = Bindlib.bind_var var new_next in
Bindlib.box_apply2 Bindlib.box_apply2
@ -210,7 +237,7 @@ let rec scope_lets_map
new_scope_let_expr new_next new_scope_let_expr new_next
let rec scopes_map let rec scopes_map
(t : 'a -> 'm expr -> 'm expr Bindlib.box) (t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
(ctx : 'a) (ctx : 'a)
(scopes : 'm expr scopes) : 'm expr scopes Bindlib.box = (scopes : 'm expr scopes) : 'm expr scopes Bindlib.box =
match scopes with match scopes with
@ -241,7 +268,7 @@ let rec scopes_map
new_scope_body_expr new_scope_next new_scope_body_expr new_scope_next
let program_map let program_map
(t : 'a -> 'm expr -> 'm expr Bindlib.box) (t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
(ctx : 'a) (ctx : 'a)
(p : 'm program) : 'm program Bindlib.box = (p : 'm program) : 'm program Bindlib.box =
Bindlib.box_apply Bindlib.box_apply

View File

@ -20,5 +20,5 @@
open Shared_ast open Shared_ast
open Ast open Ast
val optimize_expr : decl_ctx -> 'm expr -> 'm expr Bindlib.box val optimize_expr : decl_ctx -> 'm expr -> (dcalc, 'm mark) boxed_gexpr
val optimize_program : 'm program -> 'm program val optimize_program : 'm program -> 'm program

View File

@ -16,25 +16,11 @@
(** Abstract syntax tree of the desugared representation *) (** Abstract syntax tree of the desugared representation *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
(** {1 Names, Maps and Keys} *) (** {1 Names, Maps and Keys} *)
module IdentMap : Map.S with type key = String.t = Map.Make (String)
module RuleName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName)
module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName)
module LabelName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName)
module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName)
(** Inside a scope, a definition can refer either to a scope def, or a subscope (** Inside a scope, a definition can refer either to a scope def, or a subscope
def *) def *)
module ScopeDef = struct module ScopeDef = struct
@ -103,6 +89,9 @@ module ExprMap = Map.Make (struct
let compare = Expr.compare let compare = Expr.compare
end) end)
type io_input = NoInput | OnlyInput | Reentrant
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
type exception_situation = type exception_situation =
| BaseCase | BaseCase
| ExceptionToLabel of LabelName.t Marked.pos | ExceptionToLabel of LabelName.t Marked.pos
@ -136,7 +125,7 @@ module Rule = struct
Expr.compare c1 c2 Expr.compare c1 c2
| n -> n) | n -> n)
| Some (v1, t1), Some (v2, t2) -> ( | Some (v1, t1), Some (v2, t2) -> (
match Shared_ast.Expr.compare_typ t1 t2 with match Type.compare t1 t2 with
| 0 -> ( | 0 -> (
let open Bindlib in let open Bindlib in
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in
@ -189,29 +178,32 @@ type meta_assertion =
| VariesWith of unit * variation_typ Marked.pos option | VariesWith of unit * variation_typ Marked.pos option
type scope_def = { type scope_def = {
scope_def_rules : rule RuleMap.t; scope_def_rules : rule RuleName.Map.t;
scope_def_typ : typ; scope_def_typ : typ;
scope_def_is_condition : bool; scope_def_is_condition : bool;
scope_def_io : Scopelang.Ast.io; scope_def_io : io;
} }
type var_or_states = WholeVar | States of StateName.t list type var_or_states = WholeVar | States of StateName.t list
type scope = { type scope = {
scope_vars : var_or_states ScopeVarMap.t; scope_vars : var_or_states ScopeVar.Map.t;
scope_sub_scopes : ScopeName.t SubScopeMap.t; scope_sub_scopes : ScopeName.t SubScopeName.Map.t;
scope_uid : ScopeName.t; scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDefMap.t; scope_defs : scope_def ScopeDefMap.t;
scope_assertions : assertion list; scope_assertions : assertion list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;
} }
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx } type program = {
program_scopes : scope ScopeName.Map.t;
program_ctx : decl_ctx;
}
let rec locations_used e : LocationSet.t = let rec locations_used e : LocationSet.t =
match e with match e with
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m) | ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
| EAbs (binder, _), _ -> | EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
locations_used body locations_used body
| e -> | e ->
@ -219,7 +211,7 @@ let rec locations_used e : LocationSet.t =
(fun e -> LocationSet.union (locations_used e)) (fun e -> LocationSet.union (locations_used e))
e LocationSet.empty e LocationSet.empty
let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t =
let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) : let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) :
Pos.t ScopeDefMap.t = Pos.t ScopeDefMap.t =
LocationSet.fold LocationSet.fold
@ -235,7 +227,7 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
loc_pos acc) loc_pos acc)
locs acc locs acc
in in
RuleMap.fold RuleName.Map.fold
(fun _ rule acc -> (fun _ rule acc ->
let locs = let locs =
LocationSet.union LocationSet.union

View File

@ -16,19 +16,9 @@
(** Abstract syntax tree of the desugared representation *) (** Abstract syntax tree of the desugared representation *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
(** {1 Names, Maps and Keys} *)
module IdentMap : Map.S with type key = String.t
module RuleName : Uid.Id with type info = Uid.MarkedString.info
module RuleMap : Map.S with type key = RuleName.t
module RuleSet : Set.S with type elt = RuleName.t
module LabelName : Uid.Id with type info = Uid.MarkedString.info
module LabelMap : Map.S with type key = LabelName.t
module LabelSet : Set.S with type elt = LabelName.t
(** Inside a scope, a definition can refer either to a scope def, or a subscope (** Inside a scope, a definition can refer either to a scope def, or a subscope
def *) def *)
module ScopeDef : sig module ScopeDef : sig
@ -88,27 +78,51 @@ type meta_assertion =
| FixedBy of reference_typ Marked.pos | FixedBy of reference_typ Marked.pos
| VariesWith of unit * variation_typ Marked.pos option | VariesWith of unit * variation_typ Marked.pos option
(** This type characterizes the three levels of visibility for a given scope
variable with regards to the scope's input and possible redefinitions inside
the scope.. *)
type io_input =
| NoInput
(** For an internal variable defined only in the scope, and does not
appear in the input. *)
| OnlyInput
(** For variables that should not be redefined in the scope, because they
appear in the input. *)
| Reentrant
(** For variables defined in the scope that can also be redefined by the
caller as they appear in the input. *)
type io = {
io_output : bool Marked.pos;
(** [true] is present in the output of the scope. *)
io_input : io_input Marked.pos;
}
(** Characterization of the input/output status of a scope variable. *)
type scope_def = { type scope_def = {
scope_def_rules : rule RuleMap.t; scope_def_rules : rule RuleName.Map.t;
scope_def_typ : typ; scope_def_typ : typ;
scope_def_is_condition : bool; scope_def_is_condition : bool;
scope_def_io : Scopelang.Ast.io; scope_def_io : io;
} }
type var_or_states = WholeVar | States of StateName.t list type var_or_states = WholeVar | States of StateName.t list
type scope = { type scope = {
scope_vars : var_or_states ScopeVarMap.t; scope_vars : var_or_states ScopeVar.Map.t;
scope_sub_scopes : ScopeName.t SubScopeMap.t; scope_sub_scopes : ScopeName.t SubScopeName.Map.t;
scope_uid : ScopeName.t; scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDefMap.t; scope_defs : scope_def ScopeDefMap.t;
scope_assertions : assertion list; scope_assertions : assertion list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;
} }
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx } type program = {
program_scopes : scope ScopeName.Map.t;
program_ctx : decl_ctx;
}
(** {1 Helpers} *) (** {1 Helpers} *)
val locations_used : expr -> LocationSet.t val locations_used : expr -> LocationSet.t
val free_variables : rule RuleMap.t -> Pos.t ScopeDefMap.t val free_variables : rule RuleName.Map.t -> Pos.t ScopeDefMap.t

View File

@ -17,7 +17,7 @@
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/} (** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
OCamlgraph} *) OCamlgraph} *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
(** {1 Scope variables dependency graph} *) (** {1 Scope variables dependency graph} *)
@ -143,7 +143,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
let g = ScopeDependencies.empty in let g = ScopeDependencies.empty in
(* Add all the vertices to the graph *) (* Add all the vertices to the graph *)
let g = let g =
ScopeVarMap.fold ScopeVar.Map.fold
(fun (v : ScopeVar.t) var_or_state g -> (fun (v : ScopeVar.t) var_or_state g ->
match var_or_state with match var_or_state with
| Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None)) | Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None))
@ -155,7 +155,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
scope.scope_vars g scope.scope_vars g
in in
let g = let g =
SubScopeMap.fold SubScopeName.Map.fold
(fun (v : SubScopeName.t) _ g -> (fun (v : SubScopeName.t) _ g ->
ScopeDependencies.add_vertex g (Vertex.SubScope v)) ScopeDependencies.add_vertex g (Vertex.SubScope v))
scope.scope_sub_scopes g scope.scope_sub_scopes g
@ -229,10 +229,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
(** {2 Graph declaration} *) (** {2 Graph declaration} *)
module ExceptionVertex = struct module ExceptionVertex = struct
include Ast.RuleSet include RuleName.Set
let hash (x : t) : int = let hash (x : t) : int =
Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0 RuleName.Set.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0
let equal x y = compare x y = 0 let equal x y = compare x y = 0
end end
@ -257,13 +257,13 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies)
(** {2 Graph computations} *) (** {2 Graph computations} *)
type exception_edge = { type exception_edge = {
label_from : Ast.LabelName.t; label_from : LabelName.t;
label_to : Ast.LabelName.t; label_to : LabelName.t;
edge_positions : Pos.t list; edge_positions : Pos.t list;
} }
let build_exceptions_graph let build_exceptions_graph
(def : Ast.rule Ast.RuleMap.t) (def : Ast.rule RuleName.Map.t)
(def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t = (def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t =
(* First we partition the definitions into groups bearing the same label. To (* First we partition the definitions into groups bearing the same label. To
handle the rules that were not labeled by the user, we create implicit handle the rules that were not labeled by the user, we create implicit
@ -271,63 +271,59 @@ let build_exceptions_graph
(* All the rules of the form [definition x ...] are base case with no explicit (* All the rules of the form [definition x ...] are base case with no explicit
label, so they should share this implicit label. *) label, so they should share this implicit label. *)
let base_case_implicit_label = let base_case_implicit_label = LabelName.fresh ("base_case", Pos.no_pos) in
Ast.LabelName.fresh ("base_case", Pos.no_pos)
in
(* When declaring [exception definition x ...], it means there is a unique (* When declaring [exception definition x ...], it means there is a unique
rule [R] to which this can be an exception to. So we give a unique label to rule [R] to which this can be an exception to. So we give a unique label to
all the rules that are implicitly exceptions to rule [R]. *) all the rules that are implicitly exceptions to rule [R]. *)
let exception_to_rule_implicit_labels : Ast.LabelName.t Ast.RuleMap.t = let exception_to_rule_implicit_labels : LabelName.t RuleName.Map.t =
Ast.RuleMap.fold RuleName.Map.fold
(fun _ rule_from exception_to_rule_implicit_labels -> (fun _ rule_from exception_to_rule_implicit_labels ->
match rule_from.Ast.rule_exception with match rule_from.Ast.rule_exception with
| Ast.ExceptionToRule (rule_to, _) -> ( | Ast.ExceptionToRule (rule_to, _) -> (
match match
Ast.RuleMap.find_opt rule_to exception_to_rule_implicit_labels RuleName.Map.find_opt rule_to exception_to_rule_implicit_labels
with with
| Some _ -> | Some _ ->
(* we already created the label *) exception_to_rule_implicit_labels (* we already created the label *) exception_to_rule_implicit_labels
| None -> | None ->
Ast.RuleMap.add rule_to RuleName.Map.add rule_to
(Ast.LabelName.fresh (LabelName.fresh
( "exception_to_" ( "exception_to_" ^ Marked.unmark (RuleName.get_info rule_to),
^ Marked.unmark (Ast.RuleName.get_info rule_to),
Pos.no_pos )) Pos.no_pos ))
exception_to_rule_implicit_labels) exception_to_rule_implicit_labels)
| _ -> exception_to_rule_implicit_labels) | _ -> exception_to_rule_implicit_labels)
def Ast.RuleMap.empty def RuleName.Map.empty
in in
(* When declaring [exception foo_l definition x ...], the rule is exception to (* When declaring [exception foo_l definition x ...], the rule is exception to
all the rules sharing label [foo_l]. So we give a unique label to all the all the rules sharing label [foo_l]. So we give a unique label to all the
rules that are implicitly exceptions to rule [foo_l]. *) rules that are implicitly exceptions to rule [foo_l]. *)
let exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t = let exception_to_label_implicit_labels : LabelName.t LabelName.Map.t =
Ast.RuleMap.fold RuleName.Map.fold
(fun _ rule_from (fun _ rule_from
(exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t) -> (exception_to_label_implicit_labels : LabelName.t LabelName.Map.t) ->
match rule_from.Ast.rule_exception with match rule_from.Ast.rule_exception with
| Ast.ExceptionToLabel (label_to, _) -> ( | Ast.ExceptionToLabel (label_to, _) -> (
match match
Ast.LabelMap.find_opt label_to exception_to_label_implicit_labels LabelName.Map.find_opt label_to exception_to_label_implicit_labels
with with
| Some _ -> | Some _ ->
(* we already created the label *) (* we already created the label *)
exception_to_label_implicit_labels exception_to_label_implicit_labels
| None -> | None ->
Ast.LabelMap.add label_to LabelName.Map.add label_to
(Ast.LabelName.fresh (LabelName.fresh
( "exception_to_" ( "exception_to_" ^ Marked.unmark (LabelName.get_info label_to),
^ Marked.unmark (Ast.LabelName.get_info label_to),
Pos.no_pos )) Pos.no_pos ))
exception_to_label_implicit_labels) exception_to_label_implicit_labels)
| _ -> exception_to_label_implicit_labels) | _ -> exception_to_label_implicit_labels)
def Ast.LabelMap.empty def LabelName.Map.empty
in in
(* Now we have all the labels necessary to partition our rules into sets, each (* Now we have all the labels necessary to partition our rules into sets, each
one corresponding to a label relating to the structure of the exception one corresponding to a label relating to the structure of the exception
DAG. *) DAG. *)
let label_to_rule_sets = let label_to_rule_sets =
Ast.RuleMap.fold RuleName.Map.fold
(fun rule_name rule rule_sets -> (fun rule_name rule rule_sets ->
let label_of_rule = let label_of_rule =
match rule.Ast.rule_label with match rule.Ast.rule_label with
@ -336,23 +332,23 @@ let build_exceptions_graph
match rule.Ast.rule_exception with match rule.Ast.rule_exception with
| BaseCase -> base_case_implicit_label | BaseCase -> base_case_implicit_label
| ExceptionToRule (r, _) -> | ExceptionToRule (r, _) ->
Ast.RuleMap.find r exception_to_rule_implicit_labels RuleName.Map.find r exception_to_rule_implicit_labels
| ExceptionToLabel (l', _) -> | ExceptionToLabel (l', _) ->
Ast.LabelMap.find l' exception_to_label_implicit_labels) LabelName.Map.find l' exception_to_label_implicit_labels)
in in
Ast.LabelMap.update label_of_rule LabelName.Map.update label_of_rule
(fun rule_set -> (fun rule_set ->
match rule_set with match rule_set with
| None -> Some (Ast.RuleSet.singleton rule_name) | None -> Some (RuleName.Set.singleton rule_name)
| Some rule_set -> Some (Ast.RuleSet.add rule_name rule_set)) | Some rule_set -> Some (RuleName.Set.add rule_name rule_set))
rule_sets) rule_sets)
def Ast.LabelMap.empty def LabelName.Map.empty
in in
let find_label_of_rule (r : Ast.RuleName.t) : Ast.LabelName.t = let find_label_of_rule (r : RuleName.t) : LabelName.t =
fst fst
(Ast.LabelMap.choose (LabelName.Map.choose
(Ast.LabelMap.filter (LabelName.Map.filter
(fun _ rule_set -> Ast.RuleSet.mem r rule_set) (fun _ rule_set -> RuleName.Set.mem r rule_set)
label_to_rule_sets)) label_to_rule_sets))
in in
(* Next, we collect the exception edges between those groups of rules referred (* Next, we collect the exception edges between those groups of rules referred
@ -360,7 +356,7 @@ let build_exceptions_graph
edges as they are declared at each rule but should be the same for all the edges as they are declared at each rule but should be the same for all the
rules of the same group. *) rules of the same group. *)
let exception_edges : exception_edge list = let exception_edges : exception_edge list =
Ast.RuleMap.fold RuleName.Map.fold
(fun rule_name rule exception_edges -> (fun rule_name rule exception_edges ->
let label_from = find_label_of_rule rule_name in let label_from = find_label_of_rule rule_name in
let label_to_and_pos = let label_to_and_pos =
@ -374,16 +370,16 @@ let build_exceptions_graph
| Some (label_to, edge_pos) -> ( | Some (label_to, edge_pos) -> (
let other_edges_originating_from_same_label = let other_edges_originating_from_same_label =
List.filter List.filter
(fun edge -> Ast.LabelName.compare edge.label_from label_from = 0) (fun edge -> LabelName.compare edge.label_from label_from = 0)
exception_edges exception_edges
in in
(* We check the consistency*) (* We check the consistency*)
if Ast.LabelName.compare label_from label_to = 0 then if LabelName.compare label_from label_to = 0 then
Errors.raise_spanned_error edge_pos Errors.raise_spanned_error edge_pos
"Cannot define rule as an exception to itself"; "Cannot define rule as an exception to itself";
List.iter List.iter
(fun edge -> (fun edge ->
if Ast.LabelName.compare edge.label_to label_to <> 0 then if LabelName.compare edge.label_to label_to <> 0 then
Errors.raise_multispanned_error Errors.raise_multispanned_error
(( Some (( Some
"This declaration contradicts another exception \ "This declaration contradicts another exception \
@ -401,8 +397,8 @@ let build_exceptions_graph
let existing_edge = let existing_edge =
List.find_opt List.find_opt
(fun edge -> (fun edge ->
Ast.LabelName.compare edge.label_from label_from = 0 LabelName.compare edge.label_from label_from = 0
&& Ast.LabelName.compare edge.label_to label_to = 0) && LabelName.compare edge.label_to label_to = 0)
exception_edges exception_edges
in in
match existing_edge with match existing_edge with
@ -420,7 +416,7 @@ let build_exceptions_graph
in in
(* We've got the vertices and the edges, let's build the graph! *) (* We've got the vertices and the edges, let's build the graph! *)
let g = let g =
Ast.LabelMap.fold LabelName.Map.fold
(fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set) (fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set)
label_to_rule_sets ExceptionsDependencies.empty label_to_rule_sets ExceptionsDependencies.empty
in in
@ -429,10 +425,10 @@ let build_exceptions_graph
List.fold_left List.fold_left
(fun g edge -> (fun g edge ->
let rule_group_from = let rule_group_from =
Ast.LabelMap.find edge.label_from label_to_rule_sets LabelName.Map.find edge.label_from label_to_rule_sets
in in
let rule_group_to = let rule_group_to =
Ast.LabelMap.find edge.label_to label_to_rule_sets LabelName.Map.find edge.label_to label_to_rule_sets
in in
let edge = let edge =
ExceptionsDependencies.E.create rule_group_from edge.edge_positions ExceptionsDependencies.E.create rule_group_from edge.edge_positions
@ -453,11 +449,10 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
let spans = let spans =
List.flatten List.flatten
(List.map (List.map
(fun (vs : Ast.RuleSet.t) -> (fun (vs : RuleName.Set.t) ->
let v = Ast.RuleSet.choose vs in let v = RuleName.Set.choose vs in
let var_str, var_info = let var_str, var_info =
( Format.asprintf "%a" Ast.RuleName.format_t v, Format.asprintf "%a" RuleName.format_t v, RuleName.get_info v
Ast.RuleName.get_info v )
in in
let succs = ExceptionsDependencies.succ_e g vs in let succs = ExceptionsDependencies.succ_e g vs in
let _, edge_pos, _ = let _, edge_pos, _ =

View File

@ -17,7 +17,8 @@
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/} (** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
OCamlgraph} *) OCamlgraph} *)
open Utils open Catala_utils
open Shared_ast
(** {1 Scope variables dependency graph} *) (** {1 Scope variables dependency graph} *)
@ -71,9 +72,9 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t
module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list
module ExceptionsDependencies : module ExceptionsDependencies :
Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = EdgeExceptions.t Graph.Sig.P with type V.t = RuleName.Set.t and type E.label = EdgeExceptions.t
val build_exceptions_graph : val build_exceptions_graph :
Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t Ast.rule RuleName.Map.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
val check_for_exception_cycle : ExceptionsDependencies.t -> unit val check_for_exception_cycle : ExceptionsDependencies.t -> unit

View File

@ -1,670 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
open Utils
open Shared_ast
(** {1 Expression translation}*)
type target_scope_vars =
| WholeVar of ScopeVar.t
| States of (StateName.t * ScopeVar.t) list
type ctx = {
scope_var_mapping : target_scope_vars ScopeVarMap.t;
var_mapping : (Ast.expr, untyped Scopelang.Ast.expr Var.t) Var.Map.t;
}
let tag_with_log_entry
(e : untyped Scopelang.Ast.expr boxed)
(l : log_entry)
(markings : Utils.Uid.MarkedString.info list) :
untyped Scopelang.Ast.expr boxed =
Expr.eapp
(Expr.eop (Unop (Log (l, markings))) (Marked.get_mark e))
[e] (Marked.get_mark e)
let rec translate_expr (ctx : ctx) (e : Ast.expr) :
untyped Scopelang.Ast.expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
(* When referring to a subscope variable in an expression, we are referring
to the output, hence we take the last state. *)
let new_s_var =
match ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States states ->
Marked.same_mark_as (snd (List.hd (List.rev states))) s_var
in
Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m
| ELocation (DesugaredScopeVar (s_var, None)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States _ -> failwith "should not happen"))
m
| ELocation (DesugaredScopeVar (s_var, Some state)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar _ -> failwith "should not happen"
| States states -> Marked.same_mark_as (List.assoc state states) s_var))
m
| EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m
| EStruct (s_name, fields) ->
Expr.estruct s_name (StructFieldMap.map (translate_expr ctx) fields) m
| EStructAccess (e1, f_name, s_name) ->
Expr.estructaccess (translate_expr ctx e1) f_name s_name m
| EEnumInj (e1, cons, e_name) ->
Expr.eenuminj (translate_expr ctx e1) cons e_name m
| EMatchS (e1, e_name, arms) ->
Expr.ematchs (translate_expr ctx e1) e_name
(EnumConstructorMap.map (translate_expr ctx) arms)
m
| EScopeCall (sc_name, fields) ->
Expr.escopecall sc_name
(ScopeVarMap.fold
(fun v e fields' ->
let v' =
match ScopeVarMap.find v ctx.scope_var_mapping with
| WholeVar v' -> v'
| States ((_, v') :: _) ->
(* When there are multiple states, the input is always the first
one *)
v'
| States [] -> assert false
in
ScopeVarMap.add v' (translate_expr ctx e) fields')
fields ScopeVarMap.empty)
m
| ELit
(( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ) as l) ->
Expr.elit l m
| EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in
let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in
let ctx =
List.fold_left2
(fun ctx var new_var ->
{ ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping })
ctx (Array.to_list vars) (Array.to_list new_vars)
in
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) typs m
| EApp (e1, args) ->
Expr.eapp (translate_expr ctx e1) (List.map (translate_expr ctx) args) m
| EOp op -> Expr.eop op m
| EDefault (excepts, just, cons) ->
Expr.edefault
(List.map (translate_expr ctx) excepts)
(translate_expr ctx just) (translate_expr ctx cons) m
| EIfThenElse (e1, e2, e3) ->
Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2)
(translate_expr ctx e3) m
| EArray args -> Expr.earray (List.map (translate_expr ctx) args) m
| ErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m
(** {1 Rule tree construction} *)
(** Intermediate representation for the exception tree of rules for a particular
scope definition. *)
type rule_tree =
| Leaf of Ast.rule list
(** Rules defining a base case piecewise. List is non-empty. *)
| Node of rule_tree list * Ast.rule list
(** [Node (exceptions, base_case)] is a list of exceptions to a non-empty
list of rules defining a base case piecewise. *)
(** Transforms a flat list of rules into a tree, taking into account the
priorities declared between rules *)
let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) :
rule_tree list =
let exc_graph = Dependency.build_exceptions_graph def def_info in
Dependency.check_for_exception_cycle exc_graph;
(* we start by the base cases: they are the vertices which have no
successors *)
let base_cases =
Dependency.ExceptionsDependencies.fold_vertex
(fun v base_cases ->
if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then
v :: base_cases
else base_cases)
exc_graph []
in
let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree =
let exceptions =
Dependency.ExceptionsDependencies.pred exc_graph base_cases
in
let base_case_as_rule_list =
List.map
(fun r -> Ast.RuleMap.find r def)
(Ast.RuleSet.elements base_cases)
in
match exceptions with
| [] -> Leaf base_case_as_rule_list
| _ -> Node (List.map build_tree exceptions, base_case_as_rule_list)
in
List.map build_tree base_cases
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault}
expression in the scope language. The [~toplevel] parameter is used to know
when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr
~(toplevel : bool)
(ctx : ctx)
(def_pos : Pos.t)
(is_func : Ast.expr Var.t option)
(tree : rule_tree) : untyped Scopelang.Ast.expr boxed =
let emark = Untyped { pos = def_pos } in
let exceptions, base_rules =
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
in
(* because each rule has its own variable parameter and we want to convert the
whole rule tree into a function, we need to perform some alpha-renaming of
all the expressions *)
let substitute_parameter (e : Ast.expr boxed) (rule : Ast.rule) :
Ast.expr boxed =
match is_func, rule.Ast.rule_parameter with
| Some new_param, Some (old_param, _) ->
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
Marked.mark (Marked.get_mark e)
@@ Bindlib.box_apply2
(fun binder new_param -> Bindlib.subst binder new_param)
binder
(Bindlib.box_var new_param)
| None, None -> e
| _ -> assert false
(* should not happen *)
in
let ctx =
match is_func with
| None -> ctx
| Some new_param -> (
match Var.Map.find_opt new_param ctx.var_mapping with
| None ->
let new_param_scope = Var.make (Bindlib.name_of new_param) in
{
ctx with
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
}
| Some _ ->
(* We only create a mapping if none exists because [rule_tree_to_expr]
is called recursively on the exceptions of the tree and we don't want
to create a new Scopelang variable for the parameter at each tree
level. *)
ctx)
in
let base_just_list =
List.map
(fun rule -> substitute_parameter rule.Ast.rule_just rule)
base_rules
in
let base_cons_list =
List.map
(fun rule -> substitute_parameter rule.Ast.rule_cons rule)
base_rules
in
let translate_and_unbox_list (list : Ast.expr boxed list) :
untyped Scopelang.Ast.expr boxed list =
List.map
(fun e ->
(* There are two levels of boxing here, the outermost is introduced by
the [translate_expr] function for which all of the bindings should
have been closed by now, so we can safely unbox. *)
translate_expr ctx (Expr.unbox e))
list
in
let default_containing_base_cases =
Expr.make_default
(List.map2
(fun base_just base_cons ->
Expr.make_default []
(* Here we insert the logging command that records when a decision
is taken for the value of a variable. *)
(tag_with_log_entry base_just PosRecordIfTrueBool [])
base_cons emark)
(translate_and_unbox_list base_just_list)
(translate_and_unbox_list base_cons_list))
(Expr.elit (LBool false) emark)
(Expr.elit LEmptyError emark)
emark
in
let exceptions =
List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions
in
let default =
Expr.make_default exceptions
(Expr.elit (LBool true) emark)
default_containing_base_cases emark
in
match is_func, (List.hd base_rules).Ast.rule_parameter with
| None, None -> default
| Some new_param, Some (_, typ) ->
if toplevel then
(* When we're creating a function from multiple defaults, we must check
that the result returned by the function is not empty *)
let default = Expr.eerroronempty default emark in
Expr.make_abs
[| Var.Map.find new_param ctx.var_mapping |]
default [typ] def_pos
else default
| _ -> (* should not happen *) assert false
(** {1 AST translation} *)
(** Translates a definition inside a scope, the resulting expression should be
an {!constructor: Dcalc.EDefault} *)
let translate_def
(ctx : ctx)
(def_info : Ast.ScopeDef.t)
(def : Ast.rule Ast.RuleMap.t)
(typ : typ)
(io : Scopelang.Ast.io)
~(is_cond : bool)
~(is_subscope_var : bool) : untyped Scopelang.Ast.expr boxed =
(* Here, we have to transform this list of rules into a default tree. *)
let is_def_func =
match Marked.unmark typ with TArrow (_, _) -> true | _ -> false
in
let is_rule_func _ (r : Ast.rule) : bool =
Option.is_some r.Ast.rule_parameter
in
let all_rules_func = Ast.RuleMap.for_all is_rule_func def in
let all_rules_not_func =
Ast.RuleMap.for_all (fun n r -> not (is_rule_func n r)) def
in
let is_def_func_param_typ : typ option =
if is_def_func && all_rules_func then
match Marked.unmark typ with
| TArrow (t_param, _) -> Some t_param
| _ ->
Errors.raise_spanned_error (Marked.get_mark typ)
"The definitions of %a are function but it doesn't have a function \
type"
Ast.ScopeDef.format_t def_info
else if (not is_def_func) && all_rules_not_func then None
else
let spans =
List.map
(fun (_, r) ->
Some "This definition is a function:", Expr.pos r.Ast.rule_cons)
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_rule_func def))
@ List.map
(fun (_, r) ->
( Some "This definition is not a function:",
Expr.pos r.Ast.rule_cons ))
(Ast.RuleMap.bindings
(Ast.RuleMap.filter (fun n r -> not (is_rule_func n r)) def))
in
Errors.raise_multispanned_error spans
"some definitions of the same variable are functions while others \
aren't"
in
let top_list = def_map_to_tree def_info def in
let is_input =
match Marked.unmark io.Scopelang.Ast.io_input with
| OnlyInput -> true
| _ -> false
in
let top_value =
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
(* We add the bottom [false] value for conditions, only for the scope
where the condition is declared. Except when the variable is an input,
where we want the [false] to be added at each caller parent scope. *)
Some
(Ast.always_false_rule
(Ast.ScopeDef.get_position def_info)
is_def_func_param_typ)
else None
in
if
Ast.RuleMap.cardinal def = 0
&& is_subscope_var
(* Here we have a special case for the empty definitions. Indeed, we could
use the code for the regular case below that would create a convoluted
default always returning empty error, and this would be correct. But it
gets more complicated with functions. Indeed, if we create an empty
definition for a subscope argument whose type is a function, we get
something like [fun () -> (fun real_param -> < ... >)] that is passed as
an argument to the subscope. The sub-scope de-thunks but the de-thunking
does not return empty error, signalling there is not reentrant variable,
because functions are values! So the subscope does not see that there is
not reentrant variable and does not pick its internal definition instead.
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
test case exercising that subtlety.
To avoid this complication we special case here and put an empty error
for all subscope variables that are not defined. It covers the subtlety
with functions described above but also conditions with the false default
value. *)
&& not (is_cond && is_input)
(* However, this special case suffers from an exception: when a condition is
defined as an OnlyInput to a subscope, since the [false] default value
will not be provided by the calee scope, it has to be placed in the
caller. *)
then
Expr.elit LEmptyError (Untyped { pos = Ast.ScopeDef.get_position def_info })
else
rule_tree_to_expr ~toplevel:true ctx
(Ast.ScopeDef.get_position def_info)
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
(match top_list, top_value with
| [], None ->
(* In this case, there are no rules to define the expression and no
default value so we put an empty rule. *)
Leaf [Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ]
| [], Some top_value ->
(* In this case, there are no rules to define the expression but a
default value so we put it. *)
Leaf [top_value]
| _, Some top_value ->
(* When there are rules + a default value, we put the rules as
exceptions to the default value *)
Node (top_list, [top_value])
| [top_tree], None -> top_tree
| _, None ->
Node
( top_list,
[Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ] ))
(** Translates a scope *)
let translate_scope (ctx : ctx) (scope : Ast.scope) :
untyped Scopelang.Ast.scope_decl =
let scope_dependencies = Dependency.build_scope_dependencies scope in
Dependency.check_for_cycle scope scope_dependencies;
let scope_ordering =
Dependency.correct_computation_ordering scope_dependencies
in
let scope_decl_rules =
List.flatten
(List.map
(fun vertex ->
match vertex with
| Dependency.Vertex.Var (var, state) -> (
let scope_def =
Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, state))
scope.scope_defs
in
let var_def = scope_def.scope_def_rules in
let var_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match Marked.unmark scope_def.Ast.scope_def_io.io_input with
| OnlyInput when not (Ast.RuleMap.is_empty var_def) ->
(* If the variable is tagged as input, then it shall not be
redefined. *)
Errors.raise_multispanned_error
(( Some "Incriminated variable:",
Marked.get_mark (ScopeVar.get_info var) )
:: List.map
(fun (rule, _) ->
( Some "Incriminated variable definition:",
Marked.get_mark (Ast.RuleName.get_info rule) ))
(Ast.RuleMap.bindings var_def))
"It is impossible to give a definition to a scope variable \
tagged as input."
| OnlyInput ->
[]
(* we do not provide any definition for an input-only variable *)
| _ ->
let expr_def =
translate_def ctx
(Ast.ScopeDef.Var (var, state))
var_def var_typ scope_def.Ast.scope_def_io ~is_cond
~is_subscope_var:false
in
let scope_var =
match ScopeVarMap.find var ctx.scope_var_mapping, state with
| WholeVar v, None -> v
| States states, Some state -> List.assoc state states
| _ -> failwith "should not happen"
in
[
Scopelang.Ast.Definition
( ( ScopelangScopeVar
( scope_var,
Marked.get_mark (ScopeVar.get_info scope_var) ),
Marked.get_mark (ScopeVar.get_info scope_var) ),
var_typ,
scope_def.Ast.scope_def_io,
Expr.unbox expr_def );
])
| Dependency.Vertex.SubScope sub_scope_index ->
(* Before calling the sub_scope, we need to include all the
re-definitions of subscope parameters*)
let sub_scope =
SubScopeMap.find sub_scope_index scope.scope_sub_scopes
in
let sub_scope_vars_redefs_candidates =
Ast.ScopeDefMap.filter
(fun def_key scope_def ->
match def_key with
| Ast.ScopeDef.Var _ -> false
| Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) ->
sub_scope_index = sub_scope_index'
(* We exclude subscope variables that have 0 re-definitions
and are not visible in the input of the subscope *)
&& not
((match
Marked.unmark scope_def.Ast.scope_def_io.io_input
with
| Scopelang.Ast.NoInput -> true
| _ -> false)
&& Ast.RuleMap.is_empty scope_def.scope_def_rules))
scope.scope_defs
in
let sub_scope_vars_redefs =
Ast.ScopeDefMap.mapi
(fun def_key scope_def ->
let def = scope_def.Ast.scope_def_rules in
let def_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match def_key with
| Ast.ScopeDef.Var _ -> assert false (* should not happen *)
| Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) ->
(* This definition redefines a variable of the correct
subscope. But we have to check that this redefinition is
allowed with respect to the io parameters of that
subscope variable. *)
(match
Marked.unmark scope_def.Ast.scope_def_io.io_input
with
| Scopelang.Ast.NoInput ->
Errors.raise_multispanned_error
(( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) )
:: ( Some "Incriminated variable:",
Marked.get_mark (ScopeVar.get_info sub_scope_var)
)
:: List.map
(fun (rule, _) ->
( Some
"Incriminated subscope variable definition:",
Marked.get_mark (Ast.RuleName.get_info rule) ))
(Ast.RuleMap.bindings def))
"It is impossible to give a definition to a subscope \
variable not tagged as input or context."
| OnlyInput when Ast.RuleMap.is_empty def && not is_cond ->
(* If the subscope variable is tagged as input, then it
shall be defined. *)
Errors.raise_multispanned_error
[
( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) );
Some "Incriminated variable:", pos;
]
"This subscope variable is a mandatory input but no \
definition was provided."
| _ -> ());
(* Now that all is good, we can proceed with translating
this redefinition to a proper Scopelang term. *)
let expr_def =
translate_def ctx def_key def def_typ
scope_def.Ast.scope_def_io ~is_cond
~is_subscope_var:true
in
let subscop_real_name =
SubScopeMap.find sub_scope_index scope.scope_sub_scopes
in
let var_pos = Ast.ScopeDef.get_position def_key in
Scopelang.Ast.Definition
( ( SubScopeVar
( subscop_real_name,
(sub_scope_index, var_pos),
match
ScopeVarMap.find sub_scope_var
ctx.scope_var_mapping
with
| WholeVar v -> v, var_pos
| States states ->
(* When defining a sub-scope variable, we
always define its first state in the
sub-scope. *)
snd (List.hd states), var_pos ),
var_pos ),
def_typ,
scope_def.Ast.scope_def_io,
Expr.unbox expr_def ))
sub_scope_vars_redefs_candidates
in
let sub_scope_vars_redefs =
List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
in
sub_scope_vars_redefs
@ [
Scopelang.Ast.Call
( sub_scope,
sub_scope_index,
Untyped
{
pos =
Marked.get_mark
(SubScopeName.get_info sub_scope_index);
} );
])
scope_ordering)
in
(* Then, after having computed all the scopes variables, we add the
assertions. TODO: the assertions should be interleaved with the
definitions! *)
let scope_decl_rules =
scope_decl_rules
@ List.map
(fun e ->
let scope_e = translate_expr ctx (Expr.unbox e) in
Scopelang.Ast.Assertion (Expr.unbox scope_e))
scope.Ast.scope_assertions
in
let scope_sig =
ScopeVarMap.fold
(fun var (states : Ast.var_or_states) acc ->
match states with
| WholeVar ->
let scope_def =
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs
in
let typ = scope_def.scope_def_typ in
ScopeVarMap.add
(match ScopeVarMap.find var ctx.scope_var_mapping with
| WholeVar v -> v
| States _ -> failwith "should not happen")
(typ, scope_def.scope_def_io)
acc
| States states ->
(* What happens in the case of variables with multiple states is
interesting. We need to create as many Scopelang.Var entries in the
scope signature as there are states. *)
List.fold_left
(fun acc (state : StateName.t) ->
let scope_def =
Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, Some state))
scope.scope_defs
in
ScopeVarMap.add
(match ScopeVarMap.find var ctx.scope_var_mapping with
| WholeVar _ -> failwith "should not happen"
| States states' -> List.assoc state states')
(scope_def.scope_def_typ, scope_def.scope_def_io)
acc)
acc states)
scope.scope_vars ScopeVarMap.empty
in
let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in
{
Scopelang.Ast.scope_decl_name = scope.scope_uid;
Scopelang.Ast.scope_decl_rules;
Scopelang.Ast.scope_sig;
Scopelang.Ast.scope_mark = Untyped { pos };
}
(** {1 API} *)
let translate_program (pgrm : Ast.program) : untyped Scopelang.Ast.program =
(* First we give mappings to all the locations between Desugared and
Scopelang. This involves creating a new Scopelang scope variable for every
state of a Desugared variable. *)
let ctx =
ScopeMap.fold
(fun _scope scope_decl ctx ->
ScopeVarMap.fold
(fun scope_var (states : Ast.var_or_states) ctx ->
match states with
| Ast.WholeVar ->
{
ctx with
scope_var_mapping =
ScopeVarMap.add scope_var
(WholeVar (ScopeVar.fresh (ScopeVar.get_info scope_var)))
ctx.scope_var_mapping;
}
| States states ->
{
ctx with
scope_var_mapping =
ScopeVarMap.add scope_var
(States
(List.map
(fun state ->
( state,
ScopeVar.fresh
(let state_name, state_pos =
StateName.get_info state
in
( Marked.unmark (ScopeVar.get_info scope_var)
^ "_"
^ state_name,
state_pos )) ))
states))
ctx.scope_var_mapping;
})
scope_decl.Ast.scope_vars ctx)
pgrm.Ast.program_scopes
{ scope_var_mapping = ScopeVarMap.empty; var_mapping = Var.Map.empty }
in
{
Scopelang.Ast.program_scopes =
ScopeMap.map (translate_scope ctx) pgrm.program_scopes;
Scopelang.Ast.program_ctx = pgrm.program_ctx;
}

View File

@ -0,0 +1,78 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Shared_ast
open Ast
let expr ctx env e =
(* The typer takes care of disambiguating: this consists in: - ensuring
[EAbs.tys] doesn't contain any [TAny] - [EDStructAccess.name_opt] is always
[Some] *)
(* Intermediate unboxings are fine since the last [untype] will rebox in
depth *)
Typing.check_expr ctx ~env (Expr.unbox e)
let rule ctx env rule =
let env =
match rule.rule_parameter with
| None -> env
| Some (v, ty) -> Typing.Env.add_var v ty env
in
(* Note: we could use the known rule type here to direct typing. We choose not
to because it shouldn't be needed for disambiguation, and we prefer to
focus on local type errors first. *)
{
rule with
rule_just = expr ctx env rule.rule_just;
rule_cons = expr ctx env rule.rule_cons;
}
let scope ctx env scope =
let env = Typing.Env.open_scope scope.scope_uid env in
let scope_defs =
ScopeDefMap.map
(fun def ->
let scope_def_rules =
(* Note: ordering in file order might be better for error reporting ?
When we gather errors, the ordering could be done afterwards,
though *)
RuleName.Map.map (rule ctx env) def.scope_def_rules
in
{ def with scope_def_rules })
scope.scope_defs
in
let scope_assertions = List.map (expr ctx env) scope.scope_assertions in
{ scope with scope_defs; scope_assertions }
let program prg =
let env =
ScopeName.Map.fold
(fun scope_name scope env ->
let vars =
ScopeDefMap.fold
(fun var def vars ->
match var with
| Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars
| SubScopeVar _ -> vars)
scope.scope_defs ScopeVar.Map.empty
in
Typing.Env.add_scope scope_name ~vars env)
prg.program_scopes Typing.Env.empty
in
let program_scopes =
ScopeName.Map.map (scope prg.program_ctx env) prg.program_scopes
in
{ prg with program_scopes }

View File

@ -0,0 +1,24 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** This module does local typing in order to fill some missing type information
in the AST:
- it fills the types of arguments in [EAbs] nodes, (untyped ones are
inserted during desugaring, e.g. by `let-in` constructs),
- it resolves the structure names of [EDStructAccess] nodes. *)
val program : Ast.program -> Ast.program

View File

@ -1,7 +1,7 @@
(library (library
(name desugared) (name desugared)
(public_name catala.desugared) (public_name catala.desugared)
(libraries utils dcalc scopelang ocamlgraph)) (libraries ocamlgraph catala_utils shared_ast surface))
(documentation (documentation
(package catala) (package catala)

View File

@ -20,6 +20,6 @@
- Removes syntactic sugars - Removes syntactic sugars
- Separate code from legislation *) - Separate code from legislation *)
val desugar_program : val translate_program :
Name_resolution.context -> Ast.program -> Desugared.Ast.program Name_resolution.context -> Surface.Ast.program -> Ast.program
(** Main function of this module *) (** Main function of this module *)

View File

@ -18,20 +18,18 @@
(** Builds a context that allows for mapping each name to a precise uid, taking (** Builds a context that allows for mapping each name to a precise uid, taking
lexical scopes into account *) lexical scopes into account *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
(** {1 Name resolution context} *) (** {1 Name resolution context} *)
type ident = string
type unique_rulename = type unique_rulename =
| Ambiguous of Pos.t list | Ambiguous of Pos.t list
| Unique of Desugared.Ast.RuleName.t Marked.pos | Unique of RuleName.t Marked.pos
type scope_def_context = { type scope_def_context = {
default_exception_rulename : unique_rulename option; default_exception_rulename : unique_rulename option;
label_idmap : Desugared.Ast.LabelName.t Desugared.Ast.IdentMap.t; label_idmap : LabelName.t IdentName.Map.t;
} }
type scope_var_or_subscope = type scope_var_or_subscope =
@ -39,26 +37,26 @@ type scope_var_or_subscope =
| SubScope of SubScopeName.t * ScopeName.t | SubScope of SubScopeName.t * ScopeName.t
type scope_context = { type scope_context = {
var_idmap : scope_var_or_subscope Desugared.Ast.IdentMap.t; var_idmap : scope_var_or_subscope IdentName.Map.t;
(** All variables, including scope variables and subscopes *) (** All variables, including scope variables and subscopes *)
scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t; scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t;
(** What is the default rule to refer to for unnamed exceptions, if any *) (** What is the default rule to refer to for unnamed exceptions, if any *)
sub_scopes : ScopeSet.t; sub_scopes : ScopeName.Set.t;
(** Other scopes referred to by this scope. Used for dependency analysis *) (** Other scopes referred to by this scope. Used for dependency analysis *)
} }
(** Inside a scope, we distinguish between the variables and the subscopes. *) (** Inside a scope, we distinguish between the variables and the subscopes. *)
type struct_context = typ StructFieldMap.t type struct_context = typ StructField.Map.t
(** Types of the fields of a struct *) (** Types of the fields of a struct *)
type enum_context = typ EnumConstructorMap.t type enum_context = typ EnumConstructor.Map.t
(** Types of the payloads of the cases of an enum *) (** Types of the payloads of the cases of an enum *)
type var_sig = { type var_sig = {
var_sig_typ : typ; var_sig_typ : typ;
var_sig_is_condition : bool; var_sig_is_condition : bool;
var_sig_io : Ast.scope_decl_context_io; var_sig_io : Surface.Ast.scope_decl_context_io;
var_sig_states_idmap : StateName.t Desugared.Ast.IdentMap.t; var_sig_states_idmap : StateName.t IdentName.Map.t;
var_sig_states_list : StateName.t list; var_sig_states_list : StateName.t list;
} }
@ -67,25 +65,26 @@ type var_sig = {
type typedef = type typedef =
| TStruct of StructName.t | TStruct of StructName.t
| TEnum of EnumName.t | TEnum of EnumName.t
| TScope of ScopeName.t * StructName.t | TScope of ScopeName.t * scope_out_struct
(** Implicitly defined output struct *) (** Implicitly defined output struct *)
type context = { type context = {
local_var_idmap : Desugared.Ast.expr Var.t Desugared.Ast.IdentMap.t; local_var_idmap : Ast.expr Var.t IdentName.Map.t;
(** Inside a definition, local variables can be introduced by functions (** Inside a definition, local variables can be introduced by functions
arguments or pattern matching *) arguments or pattern matching *)
typedefs : typedef Desugared.Ast.IdentMap.t; typedefs : typedef IdentName.Map.t;
(** Gathers the names of the scopes, structs and enums *) (** Gathers the names of the scopes, structs and enums *)
field_idmap : StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t; field_idmap : StructField.t StructName.Map.t IdentName.Map.t;
(** The names of the struct fields. Names of fields can be shared between (** The names of the struct fields. Names of fields can be shared between
different structs *) different structs *)
constructor_idmap : EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t; constructor_idmap : EnumConstructor.t EnumName.Map.t IdentName.Map.t;
(** The names of the enum constructors. Constructor names can be shared (** The names of the enum constructors. Constructor names can be shared
between different enums *) between different enums *)
scopes : scope_context ScopeMap.t; (** For each scope, its context *) scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
structs : struct_context StructMap.t; (** For each struct, its context *) structs : struct_context StructName.Map.t;
enums : enum_context EnumMap.t; (** For each enum, its context *) (** For each struct, its context *)
var_typs : var_sig ScopeVarMap.t; enums : enum_context EnumName.Map.t; (** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
} }
(** Main context used throughout {!module: Surface.Desugaring} *) (** Main context used throughout {!module: Surface.Desugaring} *)
@ -96,7 +95,7 @@ val raise_unsupported_feature : string -> Pos.t -> 'a
(** Temporary function raising an error message saying that a feature is not (** Temporary function raising an error message saying that a feature is not
supported yet *) supported yet *)
val raise_unknown_identifier : string -> ident Marked.pos -> 'a val raise_unknown_identifier : string -> IdentName.t Marked.pos -> 'a
(** Function to call whenever an identifier used somewhere has not been declared (** Function to call whenever an identifier used somewhere has not been declared
in the program previously *) in the program previously *)
@ -104,53 +103,53 @@ val get_var_typ : context -> ScopeVar.t -> typ
(** Gets the type associated to an uid *) (** Gets the type associated to an uid *)
val is_var_cond : context -> ScopeVar.t -> bool val is_var_cond : context -> ScopeVar.t -> bool
val get_var_io : context -> ScopeVar.t -> Ast.scope_decl_context_io val get_var_io : context -> ScopeVar.t -> Surface.Ast.scope_decl_context_io
val get_var_uid : ScopeName.t -> context -> ident Marked.pos -> ScopeVar.t val get_var_uid : ScopeName.t -> context -> IdentName.t Marked.pos -> ScopeVar.t
(** Get the variable uid inside the scope given in argument *) (** Get the variable uid inside the scope given in argument *)
val get_subscope_uid : val get_subscope_uid :
ScopeName.t -> context -> ident Marked.pos -> SubScopeName.t ScopeName.t -> context -> IdentName.t Marked.pos -> SubScopeName.t
(** Get the subscope uid inside the scope given in argument *) (** Get the subscope uid inside the scope given in argument *)
val is_subscope_uid : ScopeName.t -> context -> ident -> bool val is_subscope_uid : ScopeName.t -> context -> IdentName.t -> bool
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
subscopes of [scope_uid]. *) subscopes of [scope_uid]. *)
val belongs_to : context -> ScopeVar.t -> ScopeName.t -> bool val belongs_to : context -> ScopeVar.t -> ScopeName.t -> bool
(** Checks if the var_uid belongs to the scope scope_uid *) (** Checks if the var_uid belongs to the scope scope_uid *)
val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ val get_def_typ : context -> Ast.ScopeDef.t -> typ
(** Retrieves the type of a scope definition from the context *) (** Retrieves the type of a scope definition from the context *)
val is_def_cond : context -> Desugared.Ast.ScopeDef.t -> bool val is_def_cond : context -> Ast.ScopeDef.t -> bool
val is_type_cond : Ast.typ -> bool val is_type_cond : Surface.Ast.typ -> bool
val add_def_local_var : context -> ident -> context * Desugared.Ast.expr Var.t val add_def_local_var : context -> IdentName.t -> context * Ast.expr Var.t
(** Adds a binding to the context *) (** Adds a binding to the context *)
val get_def_key : val get_def_key :
Ast.qident -> Surface.Ast.scope_var ->
Ast.ident Marked.pos option -> Surface.Ast.lident Marked.pos option ->
ScopeName.t -> ScopeName.t ->
context -> context ->
Pos.t -> Pos.t ->
Desugared.Ast.ScopeDef.t Ast.ScopeDef.t
(** Usage: [get_def_key var_name var_state scope_uid ctxt pos]*) (** Usage: [get_def_key var_name var_state scope_uid ctxt pos]*)
val get_enum : context -> ident Marked.pos -> EnumName.t val get_enum : context -> IdentName.t Marked.pos -> EnumName.t
(** Find an enum definition from the typedefs, failing if there is none or it (** Find an enum definition from the typedefs, failing if there is none or it
has a different kind *) has a different kind *)
val get_struct : context -> ident Marked.pos -> StructName.t val get_struct : context -> IdentName.t Marked.pos -> StructName.t
(** Find a struct definition from the typedefs (possibly an implicit output (** Find a struct definition from the typedefs (possibly an implicit output
struct from a scope), failing if there is none or it has a different kind *) struct from a scope), failing if there is none or it has a different kind *)
val get_scope : context -> ident Marked.pos -> ScopeName.t val get_scope : context -> IdentName.t Marked.pos -> ScopeName.t
(** Find a scope definition from the typedefs, failing if there is none or it (** Find a scope definition from the typedefs, failing if there is none or it
has a different kind *) has a different kind *)
(** {1 API} *) (** {1 API} *)
val form_context : Ast.program -> context val form_context : Surface.Ast.program -> context
(** Derive the context from metadata, in one pass over the declarations *) (** Derive the context from metadata, in one pass over the declarations *)

View File

@ -15,10 +15,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
module Cli = Utils.Cli open Catala_utils
module File = Utils.File
module Errors = Utils.Errors
module Pos = Utils.Pos
(** Associates a {!type: Cli.backend_lang} with its string represtation. *) (** Associates a {!type: Cli.backend_lang} with its string represtation. *)
let languages = ["en", Cli.En; "fr", Cli.Fr; "pl", Cli.Pl] let languages = ["en", Cli.En; "fr", Cli.Fr; "pl", Cli.Pl]
@ -76,7 +73,15 @@ let driver source_file (options : Cli.options) : int =
try `Plugin (Plugin.find s) try `Plugin (Plugin.find s)
with Not_found -> with Not_found ->
Errors.raise_error Errors.raise_error
"The selected backend (%s) is not supported by Catala" backend) "The selected backend (%s) is not supported by Catala, nor was a \
plugin by this name found under %a"
backend
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or @ ")
(fun ppf dir ->
Format.pp_print_string ppf
(try Unix.readlink dir with _ -> dir)))
options.plugins_dirs)
in in
let prgm = let prgm =
Surface.Parser_driver.parse_top_level_file source_file language Surface.Parser_driver.parse_top_level_file source_file language
@ -143,7 +148,7 @@ let driver source_file (options : Cli.options) : int =
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Dcalc | ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Dcalc
| `Scopelang | `Proof | `Plugin _ ) as backend -> ( | `Scopelang | `Proof | `Plugin _ ) as backend -> (
Cli.debug_print "Name resolution..."; Cli.debug_print "Name resolution...";
let ctxt = Surface.Name_resolution.form_context prgm in let ctxt = Desugared.Name_resolution.form_context prgm in
let scope_uid = let scope_uid =
match options.ex_scope, backend with match options.ex_scope, backend with
| None, `Interpret -> | None, `Interpret ->
@ -151,27 +156,29 @@ let driver source_file (options : Cli.options) : int =
| None, _ -> | None, _ ->
let _, scope = let _, scope =
try try
Desugared.Ast.IdentMap.filter_map Shared_ast.IdentName.Map.filter_map
(fun _ -> function (fun _ -> function
| Surface.Name_resolution.TScope (uid, _) -> Some uid | Desugared.Name_resolution.TScope (uid, _) -> Some uid
| _ -> None) | _ -> None)
ctxt.typedefs ctxt.typedefs
|> Desugared.Ast.IdentMap.choose |> Shared_ast.IdentName.Map.choose
with Not_found -> with Not_found ->
Errors.raise_error "There isn't any scope inside the program." Errors.raise_error "There isn't any scope inside the program."
in in
scope scope
| Some name, _ -> ( | Some name, _ -> (
match Desugared.Ast.IdentMap.find_opt name ctxt.typedefs with match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with
| Some (Surface.Name_resolution.TScope (uid, _)) -> uid | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid
| _ -> | _ ->
Errors.raise_error "There is no scope \"%s\" inside the program." Errors.raise_error "There is no scope \"%s\" inside the program."
name) name)
in in
Cli.debug_print "Desugaring..."; Cli.debug_print "Desugaring...";
let prgm = Surface.Desugaring.desugar_program ctxt prgm in let prgm = Desugared.From_surface.translate_program ctxt prgm in
Cli.debug_print "Disambiguating...";
let prgm = Desugared.Disambiguate.program prgm in
Cli.debug_print "Collecting rules..."; Cli.debug_print "Collecting rules...";
let prgm = Desugared.Desugared_to_scope.translate_program prgm in let prgm = Scopelang.From_desugared.translate_program prgm in
match backend with match backend with
| `Scopelang -> | `Scopelang ->
let _output_file, with_output = get_output_format () in let _output_file, with_output = get_output_format () in
@ -180,7 +187,8 @@ let driver source_file (options : Cli.options) : int =
if Option.is_some options.ex_scope then if Option.is_some options.ex_scope then
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Scopelang.Print.scope prgm.program_ctx ~debug:options.debug) (Scopelang.Print.scope prgm.program_ctx ~debug:options.debug)
(scope_uid, Shared_ast.ScopeMap.find scope_uid prgm.program_scopes) ( scope_uid,
Shared_ast.ScopeName.Map.find scope_uid prgm.program_scopes )
else else
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Scopelang.Print.program ~debug:options.debug) (Scopelang.Print.program ~debug:options.debug)
@ -194,7 +202,7 @@ let driver source_file (options : Cli.options) : int =
in in
let prgm = Scopelang.Ast.type_program prgm in let prgm = Scopelang.Ast.type_program prgm in
Cli.debug_print "Translating to default calculus..."; Cli.debug_print "Translating to default calculus...";
let prgm = Scopelang.Scope_to_dcalc.translate_program prgm in let prgm = Dcalc.From_scopelang.translate_program prgm in
let prgm = let prgm =
if options.optimize then begin if options.optimize then begin
Cli.debug_print "Optimizing default calculus..."; Cli.debug_print "Optimizing default calculus...";
@ -202,8 +210,21 @@ let driver source_file (options : Cli.options) : int =
end end
else prgm else prgm
in in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
(Print.typ prgm.decl_ctx) typ); *)
match backend with match backend with
| `Typecheck -> | `Typecheck ->
Cli.debug_print "Typechecking again...";
let _ =
try Shared_ast.Typing.program prgm
with Errors.StructuredError (msg, details) ->
let msg =
"Typing error occured during re-typing on the 'default \
calculus'. This is a bug in the Catala compiler.\n"
^ msg
in
raise (Errors.StructuredError (msg, details))
in
(* That's it! *) (* That's it! *)
Cli.result_print "Typechecking successful!" Cli.result_print "Typechecking successful!"
| `Dcalc -> | `Dcalc ->
@ -229,7 +250,7 @@ let driver source_file (options : Cli.options) : int =
Shared_ast.Expr.unbox (Shared_ast.Program.to_expr prgm scope_uid) Shared_ast.Expr.unbox (Shared_ast.Program.to_expr prgm scope_uid)
in in
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Shared_ast.Expr.format prgm.decl_ctx) (Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx)
prgrm_dcalc_expr prgrm_dcalc_expr
| (`Interpret | `OCaml | `Python | `Scalc | `Lcalc | `Proof | `Plugin _) | (`Interpret | `OCaml | `Python | `Scalc | `Lcalc | `Proof | `Plugin _)
as backend -> ( as backend -> (
@ -244,8 +265,6 @@ let driver source_file (options : Cli.options) : int =
in in
raise (Errors.StructuredError (msg, details)) raise (Errors.StructuredError (msg, details))
in in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
(Print.typ prgm.decl_ctx) typ); *)
match backend with match backend with
| `Proof -> | `Proof ->
let vcs = let vcs =
@ -308,24 +327,14 @@ let driver source_file (options : Cli.options) : int =
if Option.is_some options.ex_scope then if Option.is_some options.ex_scope then
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Shared_ast.Scope.format ~debug:options.debug prgm.decl_ctx) (Shared_ast.Scope.format ~debug:options.debug prgm.decl_ctx)
( scope_uid, (scope_uid, Shared_ast.Program.get_scope_body prgm scope_uid)
Option.get
(Shared_ast.Scope.fold_left ~init:None
~f:(fun acc scope_def _ ->
if
Shared_ast.ScopeName.compare scope_def.scope_name
scope_uid
= 0
then Some scope_def.scope_body
else acc)
prgm.scopes) )
else else
let prgrm_lcalc_expr = let prgrm_lcalc_expr =
Shared_ast.Expr.unbox Shared_ast.Expr.unbox
(Shared_ast.Program.to_expr prgm scope_uid) (Shared_ast.Program.to_expr prgm scope_uid)
in in
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Shared_ast.Expr.format prgm.decl_ctx) (Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx)
prgrm_lcalc_expr prgrm_lcalc_expr
| (`OCaml | `Python | `Scalc | `Plugin _) as backend -> ( | (`OCaml | `Python | `Scalc | `Plugin _) as backend -> (
match backend with match backend with

View File

@ -15,9 +15,10 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Catala_utils
module Plugin = Plugin.PluginAPI module Plugin = Plugin.PluginAPI
val driver : Utils.Pos.input_file -> Utils.Cli.options -> int val driver : Pos.input_file -> Cli.options -> int
(** Entry function for the executable. Returns a negative number in case of (** Entry function for the executable. Returns a negative number in case of
error. *) error. *)

View File

@ -3,7 +3,7 @@
(public_name catala.driver) (public_name catala.driver)
(libraries (libraries
dynlink dynlink
utils catala_utils
surface surface
desugared desugared
literate literate
@ -50,3 +50,7 @@
(documentation (documentation
(package catala) (package catala)
(mld_files index)) (mld_files index))
(alias
(name catala)
(deps catala.exe))

View File

@ -103,7 +103,7 @@ Two more modules contain additional features for the compiler:
{ul {ul
{li {{: literate.html} Literate programming}} {li {{: literate.html} Literate programming}}
{li {{: utils.html} Compiler utilities}} {li {{: catala_utils.html} Compiler utilities}}
} }
The Catala runtimes documentation is available here: The Catala runtimes documentation is available here:

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
include Shared_ast include Shared_ast
type lit = lcalc glit type lit = lcalc glit
@ -28,31 +28,32 @@ let option_enum : EnumName.t = EnumName.fresh ("eoption", Pos.no_pos)
let none_constr : EnumConstructor.t = EnumConstructor.fresh ("ENone", Pos.no_pos) let none_constr : EnumConstructor.t = EnumConstructor.fresh ("ENone", Pos.no_pos)
let some_constr : EnumConstructor.t = EnumConstructor.fresh ("ESome", Pos.no_pos) let some_constr : EnumConstructor.t = EnumConstructor.fresh ("ESome", Pos.no_pos)
let option_enum_config : (EnumConstructor.t * typ) list = let option_enum_config : typ EnumConstructor.Map.t =
[none_constr, (TLit TUnit, Pos.no_pos); some_constr, (TAny, Pos.no_pos)] EnumConstructor.Map.empty
|> EnumConstructor.Map.add none_constr (TLit TUnit, Pos.no_pos)
|> EnumConstructor.Map.add some_constr (TAny, Pos.no_pos)
(* FIXME: proper typing in all the constructors below *) (* FIXME: proper typing in all the constructors below *)
let make_none m = let make_none m =
let tunit = TLit TUnit, Expr.mark_pos m in let tunit = TLit TUnit, Expr.mark_pos m in
Expr.einj Expr.einj (Expr.elit LUnit (Expr.with_ty m tunit)) none_constr option_enum m
(Expr.elit LUnit (Expr.with_ty m tunit))
0 option_enum
[TLit TUnit, Pos.no_pos; TAny, Pos.no_pos]
m
let make_some e = let make_some e =
let m = Marked.get_mark e in let m = Marked.get_mark e in
Expr.einj e 1 option_enum Expr.einj e some_constr option_enum m
[TLit TUnit, Expr.mark_pos m; TAny, Expr.mark_pos m]
m
(** [make_matchopt_with_abs_arms arg e_none e_some] build an expression (** [make_matchopt_with_abs_arms arg e_none e_some] build an expression
[match arg with |None -> e_none | Some -> e_some] and requires e_some and [match arg with |None -> e_none | Some -> e_some] and requires e_some and
e_none to be in the form [EAbs ...].*) e_none to be in the form [EAbs ...].*)
let make_matchopt_with_abs_arms arg e_none e_some = let make_matchopt_with_abs_arms arg e_none e_some =
let m = Marked.get_mark arg in let m = Marked.get_mark arg in
Expr.ematch arg [e_none; e_some] option_enum m let cases =
EnumConstructor.Map.empty
|> EnumConstructor.Map.add none_constr e_none
|> EnumConstructor.Map.add some_constr e_some
in
Expr.ematch arg option_enum cases m
(** [make_matchopt pos v tau arg e_none e_some] builds an expression (** [make_matchopt pos v tau arg e_none e_some] builds an expression
[match arg with | None () -> e_none | Some v -> e_some]. It binds v to [match arg with | None () -> e_none | Some v -> e_some]. It binds v to

View File

@ -14,6 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Catala_utils
open Shared_ast open Shared_ast
(** Abstract syntax tree for the lambda calculus *) (** Abstract syntax tree for the lambda calculus *)
@ -32,7 +33,7 @@ type 'm program = 'm expr Shared_ast.program
val option_enum : EnumName.t val option_enum : EnumName.t
val none_constr : EnumConstructor.t val none_constr : EnumConstructor.t
val some_constr : EnumConstructor.t val some_constr : EnumConstructor.t
val option_enum_config : (EnumConstructor.t * typ) list val option_enum_config : typ EnumConstructor.Map.t
val make_none : 'm mark -> 'm expr boxed val make_none : 'm mark -> 'm expr boxed
val make_some : 'm expr boxed -> 'm expr boxed val make_some : 'm expr boxed -> 'm expr boxed
@ -40,7 +41,7 @@ val make_matchopt_with_abs_arms :
'm expr boxed -> 'm expr boxed -> 'm expr boxed -> 'm expr boxed 'm expr boxed -> 'm expr boxed -> 'm expr boxed -> 'm expr boxed
val make_matchopt : val make_matchopt :
Utils.Pos.t -> Pos.t ->
'm expr Var.t -> 'm expr Var.t ->
typ -> typ ->
'm expr boxed -> 'm expr boxed ->

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
module D = Dcalc.Ast module D = Dcalc.Ast
@ -31,74 +31,56 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
let rec aux e = let rec aux e =
let m = Marked.get_mark e in let m = Marked.get_mark e in
match Marked.unmark e with match Marked.unmark e with
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _
| ECatch _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
| EVar v -> | EVar v ->
( (Bindlib.box_var v, m), ( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty else Var.Set.singleton v),
else Var.Set.singleton v ) (Bindlib.box_var v, m) )
| ETuple (args, s) -> | EMatch { e; cases; name } ->
let new_args, free_vars = let free_vars, new_e = aux e in
List.fold_left
(fun (new_args, free_vars) arg ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union new_free_vars free_vars)
([], Var.Set.empty) args
in
Expr.etuple (List.rev new_args) s m, free_vars
| ETupleAccess (e1, n, s, typs) ->
let new_e1, free_vars = aux e1 in
Expr.etupleaccess new_e1 n s typs m, free_vars
| EInj (e1, n, e_name, typs) ->
let new_e1, free_vars = aux e1 in
Expr.einj new_e1 n e_name typs m, free_vars
| EMatch (e1, arms, e_name) ->
let new_e1, free_vars = aux e1 in
(* We do not close the clotures inside the arms of the match expression, (* We do not close the clotures inside the arms of the match expression,
since they get a special treatment at compilation to Scalc. *) since they get a special treatment at compilation to Scalc. *)
let new_arms, free_vars = let free_vars, new_cases =
List.fold_right EnumConstructor.Map.fold
(fun arm (new_arms, free_vars) -> (fun cons e1 (free_vars, new_cases) ->
match Marked.unmark arm with match Marked.unmark e1 with
| EAbs (binder, typs) -> | EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let new_body, new_free_vars = aux body in let new_free_vars, new_body = aux body in
let new_binder = Expr.bind vars new_body in let new_binder = Expr.bind vars new_body in
( Expr.eabs new_binder typs (Marked.get_mark arm) :: new_arms, ( Var.Set.union free_vars new_free_vars,
Var.Set.union free_vars new_free_vars ) EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Marked.get_mark e1))
new_cases )
| _ -> failwith "should not happen") | _ -> failwith "should not happen")
arms ([], free_vars) cases
(free_vars, EnumConstructor.Map.empty)
in in
Expr.ematch new_e1 new_arms e_name m, free_vars free_vars, Expr.ematch new_e name new_cases m
| EArray args -> | EApp { f = EAbs { binder; tys }, e1_pos; args } ->
let new_args, free_vars =
List.fold_right
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], Var.Set.empty)
in
Expr.earray new_args m, free_vars
| ELit l -> Expr.elit l m, Var.Set.empty
| EApp ((EAbs (binder, typs_abs), e1_pos), args) ->
(* let-binding, we should not close these *) (* let-binding, we should not close these *)
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let new_body, free_vars = aux body in let free_vars, new_body = aux body in
let new_binder = Expr.bind vars new_body in let new_binder = Expr.bind vars new_body in
let new_args, free_vars = let free_vars, new_args =
List.fold_right List.fold_right
(fun arg (new_args, free_vars) -> (fun arg (free_vars, new_args) ->
let new_arg, new_free_vars = aux arg in let new_free_vars, new_arg = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars) Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args ([], free_vars) args (free_vars, [])
in in
Expr.eapp (Expr.eabs new_binder typs_abs e1_pos) new_args m, free_vars free_vars, Expr.eapp (Expr.eabs new_binder tys e1_pos) new_args m
| EAbs (binder, typs) -> | EAbs { binder; tys } ->
(* λ x.t *) (* λ x.t *)
let binder_mark = m in let binder_mark = m in
let binder_pos = Expr.mark_pos binder_mark in let binder_pos = Expr.mark_pos binder_mark in
(* Converting the closure. *) (* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
(* t *) (* t *)
let new_body, body_vars = aux body in let body_vars, new_body = aux body in
(* [[t]] *) (* [[t]] *)
let extra_vars = let extra_vars =
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars)) Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars))
@ -117,8 +99,8 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
(fun i _ -> (fun i _ ->
Expr.etupleaccess Expr.etupleaccess
(Expr.evar inner_c_var binder_mark) (Expr.evar inner_c_var binder_mark)
(i + 1) None (i + 1)
(List.map (fun _ -> any_ty) extra_vars_list) (List.length extra_vars_list)
binder_mark) binder_mark)
extra_vars_list) extra_vars_list)
new_body new_body
@ -128,10 +110,11 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
Expr.make_abs Expr.make_abs
(Array.concat [Array.make 1 inner_c_var; vars]) (Array.concat [Array.make 1 inner_c_var; vars])
new_closure_body new_closure_body
((TAny, binder_pos) :: typs) ((TAny, binder_pos) :: tys)
(Expr.pos e) (Expr.pos e)
in in
( Expr.make_let_in code_var ( extra_vars,
Expr.make_let_in code_var
(TAny, Expr.pos e) (TAny, Expr.pos e)
new_closure new_closure
(Expr.etuple (Expr.etuple
@ -139,40 +122,25 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
:: List.map :: List.map
(fun extra_var -> Bindlib.box_var extra_var, binder_mark) (fun extra_var -> Bindlib.box_var extra_var, binder_mark)
extra_vars_list) extra_vars_list)
None m) m)
(Expr.pos e), (Expr.pos e) )
extra_vars ) | EApp { f = EOp _, _; _ } ->
| EApp ((EOp op, pos_op), args) ->
(* This corresponds to an operator call, which we don't want to (* This corresponds to an operator call, which we don't want to
transform*) transform*)
let new_args, free_vars = Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
List.fold_right | EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars ->
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], Var.Set.empty)
in
Expr.eapp (Expr.eop op pos_op) new_args m, free_vars
| EApp ((EVar v, v_pos), args) when Var.Set.mem v ctx.globally_bound_vars ->
(* This corresponds to a scope call, which we don't want to transform*) (* This corresponds to a scope call, which we don't want to transform*)
let new_args, free_vars = Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
List.fold_right | EApp { f = e1; args } ->
(fun arg (new_args, free_vars) -> let free_vars, new_e1 = aux e1 in
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], Var.Set.empty)
in
Expr.eapp (Bindlib.box_var v, v_pos) new_args m, free_vars
| EApp (e1, args) ->
let new_e1, free_vars = aux e1 in
let env_var = Var.make "env" in let env_var = Var.make "env" in
let code_var = Var.make "code" in let code_var = Var.make "code" in
let new_args, free_vars = let free_vars, new_args =
List.fold_right List.fold_right
(fun arg (new_args, free_vars) -> (fun arg (free_vars, new_args) ->
let new_arg, new_free_vars = aux arg in let new_free_vars, new_arg = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars) Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args ([], free_vars) args (free_vars, [])
in in
let call_expr = let call_expr =
let m1 = Marked.get_mark e1 in let m1 = Marked.get_mark e1 in
@ -180,7 +148,8 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
(TAny, Expr.pos e) (TAny, Expr.pos e)
(Expr.etupleaccess (Expr.etupleaccess
(Bindlib.box_var env_var, m1) (Bindlib.box_var env_var, m1)
0 None [ (*TODO: fill?*) ] 0
(List.length new_args + 1)
m) m)
(Expr.eapp (Expr.eapp
(Bindlib.box_var code_var, m1) (Bindlib.box_var code_var, m1)
@ -188,25 +157,12 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
m) m)
(Expr.pos e) (Expr.pos e)
in in
( Expr.make_let_in env_var (TAny, Expr.pos e) new_e1 call_expr (Expr.pos e), ( free_vars,
free_vars ) Expr.make_let_in env_var
| EAssert e1 -> (TAny, Expr.pos e)
let new_e1, free_vars = aux e1 in new_e1 call_expr (Expr.pos e) )
Expr.eassert new_e1 m, free_vars
| EOp op -> Expr.eop op m, Var.Set.empty
| EIfThenElse (e1, e2, e3) ->
let new_e1, free_vars1 = aux e1 in
let new_e2, free_vars2 = aux e2 in
let new_e3, free_vars3 = aux e3 in
( Expr.eifthenelse new_e1 new_e2 new_e3 m,
Var.Set.union (Var.Set.union free_vars1 free_vars2) free_vars3 )
| ERaise except -> Expr.eraise except m, Var.Set.empty
| ECatch (e1, except, e2) ->
let new_e1, free_vars1 = aux e1 in
let new_e2, free_vars2 = aux e2 in
Expr.ecatch new_e1 except new_e2 m, Var.Set.union free_vars1 free_vars2
in in
let e', _vars = aux e in let _vars, e' = aux e in
e' e'
let closure_conversion (p : 'm program) : 'm program Bindlib.box = let closure_conversion (p : 'm program) : 'm program Bindlib.box =

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
module D = Dcalc.Ast module D = Dcalc.Ast
module A = Ast module A = Ast
@ -43,7 +43,7 @@ let rec translate_default
Expr.make_app Expr.make_app
(Expr.make_var (Expr.make_var
(Var.translate A.handle_default) (Var.translate A.handle_default)
(Expr.with_ty mark_default (Utils.Marked.mark pos TAny))) (Expr.with_ty mark_default (Marked.mark pos TAny)))
[ [
Expr.earray exceptions mark_default; Expr.earray exceptions mark_default;
thunk_expr (translate_expr ctx just); thunk_expr (translate_expr ctx just);
@ -54,39 +54,39 @@ let rec translate_default
exceptions exceptions
and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed = and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with match Marked.unmark e with
| EVar v -> Expr.make_var (Var.Map.find v ctx) (Marked.get_mark e) | EVar v -> Expr.make_var (Var.Map.find v ctx) m
| ETuple (args, s) -> | EStruct { name; fields } ->
Expr.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e) Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m
| ETupleAccess (e1, i, s, ts) -> | EStructAccess { name; e; field } ->
Expr.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e) Expr.estructaccess (translate_expr ctx e) field name m
| EInj (e1, i, en, ts) -> | EInj { name; e; cons } -> Expr.einj (translate_expr ctx e) cons name m
Expr.einj (translate_expr ctx e1) i en ts (Marked.get_mark e) | EMatch { name; e; cases } ->
| EMatch (e1, cases, en) -> Expr.ematch (translate_expr ctx e) name
Expr.ematch (translate_expr ctx e1) (EnumConstructor.Map.map (translate_expr ctx) cases)
(List.map (translate_expr ctx) cases) m
en (Marked.get_mark e) | EArray es -> Expr.earray (List.map (translate_expr ctx) es) m
| EArray es ->
Expr.earray (List.map (translate_expr ctx) es) (Marked.get_mark e)
| ELit | ELit
((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as ((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as
l) -> l) ->
Expr.elit l (Marked.get_mark e) Expr.elit l m
| ELit LEmptyError -> Expr.eraise EmptyError (Marked.get_mark e) | ELit LEmptyError -> Expr.eraise EmptyError m
| EOp op -> Expr.eop op (Marked.get_mark e) | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m
| EIfThenElse (e1, e2, e3) -> | EIfThenElse { cond; etrue; efalse } ->
Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2) Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
(translate_expr ctx e3) (Marked.get_mark e) (translate_expr ctx efalse)
| EAssert e1 -> Expr.eassert (translate_expr ctx e1) (Marked.get_mark e) m
| ErrorOnEmpty arg -> | EAssert e1 -> Expr.eassert (translate_expr ctx e1) m
| EErrorOnEmpty arg ->
Expr.ecatch (translate_expr ctx arg) EmptyError Expr.ecatch (translate_expr ctx arg) EmptyError
(Expr.eraise NoValueProvided (Marked.get_mark e)) (Expr.eraise NoValueProvided m)
(Marked.get_mark e) m
| EApp (e1, args) -> | EApp { f; args } ->
Expr.eapp (translate_expr ctx e1) Expr.eapp (translate_expr ctx f)
(List.map (translate_expr ctx) args) (List.map (translate_expr ctx) args)
(Marked.get_mark e) (Marked.get_mark e)
| EAbs (binder, ts) -> | EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars = let ctx, lc_vars =
Array.fold_right Array.fold_right
@ -98,15 +98,16 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
let lc_vars = Array.of_list lc_vars in let lc_vars = Array.of_list lc_vars in
let new_body = translate_expr ctx body in let new_body = translate_expr ctx body in
let new_binder = Expr.bind lc_vars new_body in let new_binder = Expr.bind lc_vars new_body in
Expr.eabs new_binder ts (Marked.get_mark e) Expr.eabs new_binder tys (Marked.get_mark e)
| EDefault ([exn], just, cons) when !Cli.optimize_flag -> | EDefault { excepts = [exn]; just; cons } when !Cli.optimize_flag ->
(* FIXME: bad place to rely on a global flag *)
Expr.ecatch (translate_expr ctx exn) EmptyError Expr.ecatch (translate_expr ctx exn) EmptyError
(Expr.eifthenelse (translate_expr ctx just) (translate_expr ctx cons) (Expr.eifthenelse (translate_expr ctx just) (translate_expr ctx cons)
(Expr.eraise EmptyError (Marked.get_mark e)) (Expr.eraise EmptyError (Marked.get_mark e))
(Marked.get_mark e)) (Marked.get_mark e))
(Marked.get_mark e) (Marked.get_mark e)
| EDefault (exceptions, just, cons) -> | EDefault { excepts; just; cons } ->
translate_default ctx exceptions just cons (Marked.get_mark e) translate_default ctx excepts just cons (Marked.get_mark e)
let rec translate_scope_lets let rec translate_scope_lets
(decl_ctx : decl_ctx) (decl_ctx : decl_ctx)

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
module D = Dcalc.Ast module D = Dcalc.Ast
module A = Ast module A = Ast
@ -170,7 +170,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
created a variable %a to replace it" Print.var v Print.var v'; *) created a variable %a to replace it" Print.var v Print.var v'; *)
Expr.make_var v' mark, Var.Map.singleton v' e Expr.make_var v' mark, Var.Map.singleton v' e
else (find ~info:"should never happen" v ctx).expr, Var.Map.empty else (find ~info:"should never happen" v ctx).expr, Var.Map.empty
| EApp ((EVar v, p), [(ELit LUnit, _)]) -> | EApp { f = EVar v, p; args = [(ELit LUnit, _)] } ->
if not (find ~info:"search for a variable" v ctx).is_pure then if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = Var.make (Bindlib.name_of v) in let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
@ -179,7 +179,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
else else
Errors.raise_spanned_error (Expr.pos e) Errors.raise_spanned_error (Expr.pos e)
"Internal error: an pure variable was found in an unpure environment." "Internal error: an pure variable was found in an unpure environment."
| EDefault (_exceptions, _just, _cons) -> | EDefault _ ->
let v' = Var.make "default_term" in let v' = Var.make "default_term" in
Expr.make_var v' mark, Var.Map.singleton v' e Expr.make_var v' mark, Var.Map.singleton v' e
| ELit LEmptyError -> | ELit LEmptyError ->
@ -187,7 +187,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
Expr.make_var v' mark, Var.Map.singleton v' e Expr.make_var v' mark, Var.Map.singleton v' e
(* This one is a very special case. It transform an unpure expression (* This one is a very special case. It transform an unpure expression
environement to a pure expression. *) environement to a pure expression. *)
| ErrorOnEmpty arg -> | EErrorOnEmpty arg ->
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *) (* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
let silent_var = Var.make "_" in let silent_var = Var.make "_" in
let x = Var.make "non_empty_argument" in let x = Var.make "non_empty_argument" in
@ -206,22 +206,23 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as ((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as
l) -> l) ->
Expr.elit l mark, Var.Map.empty Expr.elit l mark, Var.Map.empty
| EIfThenElse (e1, e2, e3) -> | EIfThenElse { cond; etrue; efalse } ->
let e1', h1 = translate_and_hoist ctx e1 in let cond', h1 = translate_and_hoist ctx cond in
let e2', h2 = translate_and_hoist ctx e2 in let etrue', h2 = translate_and_hoist ctx etrue in
let e3', h3 = translate_and_hoist ctx e3 in let efalse', h3 = translate_and_hoist ctx efalse in
let e' = Expr.eifthenelse e1' e2' e3' mark in let e' = Expr.eifthenelse cond' etrue' efalse' mark in
(*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' = (*(* equivalent code : *) let e' = let+ cond' = cond' and+ etrue' = etrue'
e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *) and+ efalse' = efalse' in (A.EIfThenElse (cond', etrue', efalse'), pos)
in *)
e', disjoint_union_maps (Expr.pos e) [h1; h2; h3] e', disjoint_union_maps (Expr.pos e) [h1; h2; h3]
| EAssert e1 -> | EAssert e1 ->
(* same behavior as in the ICFP paper: if e1 is empty, then no error is (* same behavior as in the ICFP paper: if e1 is empty, then no error is
raised. *) raised. *)
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
Expr.eassert e1' mark, h1 Expr.eassert e1' mark, h1
| EAbs (binder, ts) -> | EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars = let ctx, lc_vars =
ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) -> ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
@ -242,8 +243,8 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let new_body, hoists = translate_and_hoist ctx body in let new_body, hoists = translate_and_hoist ctx body in
let new_binder = Expr.bind lc_vars new_body in let new_binder = Expr.bind lc_vars new_body in
Expr.eabs new_binder (List.map translate_typ ts) mark, hoists Expr.eabs new_binder (List.map translate_typ tys) mark, hoists
| EApp (e1, args) -> | EApp { f = e1; args } ->
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
let args', h_args = let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split args |> List.map (translate_and_hoist ctx) |> List.split
@ -252,35 +253,43 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in
let e' = Expr.eapp e1' args' mark in let e' = Expr.eapp e1' args' mark in
e', hoists e', hoists
| ETuple (args, s) -> | EStruct { name; fields } ->
let args', h_args = let fields', h_fields =
args |> List.map (translate_and_hoist ctx) |> List.split StructField.Map.fold
(fun field e (fields, hoists) ->
let e, h = translate_and_hoist ctx e in
StructField.Map.add field e fields, h :: hoists)
fields
(StructField.Map.empty, [])
in in
let hoists = disjoint_union_maps (Expr.pos e) h_fields in
let hoists = disjoint_union_maps (Expr.pos e) h_args in Expr.estruct name fields' mark, hoists
Expr.etuple args' s mark, hoists | EStructAccess { name; e = e1; field } ->
| ETupleAccess (e1, i, s, ts) ->
let e1', hoists = translate_and_hoist ctx e1 in let e1', hoists = translate_and_hoist ctx e1 in
let e1' = Expr.etupleaccess e1' i s ts mark in let e1' = Expr.estructaccess e1' field name mark in
e1', hoists e1', hoists
| EInj (e1, i, en, ts) -> | EInj { name; e = e1; cons } ->
let e1', hoists = translate_and_hoist ctx e1 in let e1', hoists = translate_and_hoist ctx e1 in
let e1' = Expr.einj e1' i en ts mark in let e1' = Expr.einj e1' cons name mark in
e1', hoists e1', hoists
| EMatch (e1, cases, en) -> | EMatch { name; e = e1; cases } ->
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
let cases', h_cases = let cases', h_cases =
cases |> List.map (translate_and_hoist ctx) |> List.split EnumConstructor.Map.fold
(fun cons e (cases, hoists) ->
let e', h = translate_and_hoist ctx e in
EnumConstructor.Map.add cons e' cases, h :: hoists)
cases
(EnumConstructor.Map.empty, [])
in in
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in
let e' = Expr.ematch e1' cases' en mark in let e' = Expr.ematch e1' name cases' mark in
e', hoists e', hoists
| EArray es -> | EArray es ->
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
Expr.earray es' mark, disjoint_union_maps (Expr.pos e) hoists Expr.earray es' mark, disjoint_union_maps (Expr.pos e) hoists
| EOp op -> Expr.eop op mark, Var.Map.empty | EOp { op; tys } -> Expr.eop (Operator.translate op) tys mark, Var.Map.empty
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) : and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
'm A.expr boxed = 'm A.expr boxed =
@ -302,14 +311,14 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
(* Here we have to handle only the cases appearing in hoists, as defined (* Here we have to handle only the cases appearing in hoists, as defined
the [translate_and_hoist] function. *) the [translate_and_hoist] function. *)
| EVar v -> (find ~info:"should never happen" v ctx).expr | EVar v -> (find ~info:"should never happen" v ctx).expr
| EDefault (excep, just, cons) -> | EDefault { excepts; just; cons } ->
let excep' = List.map (translate_expr ctx) excep in let excepts' = List.map (translate_expr ctx) excepts in
let just' = translate_expr ctx just in let just' = translate_expr ctx just in
let cons' = translate_expr ctx cons in let cons' = translate_expr ctx cons in
(* calls handle_option. *) (* calls handle_option. *)
Expr.make_app Expr.make_app
(Expr.make_var (Var.translate A.handle_default_opt) mark_hoist) (Expr.make_var (Var.translate A.handle_default_opt) mark_hoist)
[Expr.earray excep' mark_hoist; just'; cons'] [Expr.earray excepts' mark_hoist; just'; cons']
pos pos
| ELit LEmptyError -> A.make_none mark_hoist | ELit LEmptyError -> A.make_none mark_hoist
| EAssert arg -> | EAssert arg ->
@ -354,7 +363,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ; scope_let_typ = typ;
scope_let_expr = EAbs (binder, _), emark; scope_let_expr = EAbs { binder; _ }, emark;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
@ -385,7 +394,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
{ {
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ; scope_let_typ = typ;
scope_let_expr = (ErrorOnEmpty _, emark) as expr; scope_let_expr = (EErrorOnEmpty _, emark) as expr;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
@ -529,7 +538,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
prgm.decl_ctx with prgm.decl_ctx with
ctx_enums = ctx_enums =
prgm.decl_ctx.ctx_enums prgm.decl_ctx.ctx_enums
|> EnumMap.add A.option_enum A.option_enum_config; |> EnumName.Map.add A.option_enum A.option_enum_config;
} }
in in
let decl_ctx = let decl_ctx =
@ -537,15 +546,14 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
decl_ctx with decl_ctx with
ctx_structs = ctx_structs =
prgm.decl_ctx.ctx_structs prgm.decl_ctx.ctx_structs
|> StructMap.mapi (fun n l -> |> StructName.Map.mapi (fun n str ->
if List.mem n inputs_structs then if List.mem n inputs_structs then
ListLabels.map l ~f:(fun (n, tau) -> StructField.Map.map translate_typ str
(* Cli.debug_print @@ Format.asprintf "Input type: %a" (* Cli.debug_print @@ Format.asprintf "Input type: %a"
(Print.typ decl_ctx) tau; Cli.debug_print @@ (Print.typ decl_ctx) tau; Cli.debug_print @@ Format.asprintf
Format.asprintf "Output type: %a" (Print.typ decl_ctx) "Output type: %a" (Print.typ decl_ctx) (translate_typ
(translate_typ tau); *) tau); *)
n, translate_typ tau) else str);
else l);
} }
in in

View File

@ -0,0 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
let translate_program_with_exceptions =
Compile_with_exceptions.translate_program
let translate_program_without_exceptions =
Compile_without_exceptions.translate_program

View File

@ -0,0 +1,26 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
val translate_program_with_exceptions : 'm Dcalc.Ast.program -> 'm Ast.program
(** Translation from the default calculus to the lambda calculus. This
translation uses exceptions to handle empty default terms. *)
val translate_program_without_exceptions :
'm Dcalc.Ast.program -> 'm Ast.program
(** Translation from the default calculus to the lambda calculus. This
translation uses an option monad to handle empty defaults terms. This
transformation is one piece to permit to compile toward legacy languages
that does not contains exceptions. *)

View File

@ -13,50 +13,47 @@
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
module D = Dcalc.Ast module D = Dcalc.Ast
let visitor_map (t : 'a -> 'm expr -> 'm expr boxed) (ctx : 'a) (e : 'm expr) : let visitor_map (t : 'm expr -> 'm expr boxed) (e : 'm expr) : 'm expr boxed =
'm expr boxed = Expr.map ~f:t e
Expr.map ctx ~f:t e
let rec iota_expr (_ : unit) (e : 'm expr) : 'm expr boxed = let rec iota_expr (e : 'm expr) : 'm expr boxed =
let m = Marked.get_mark e in let m = Marked.get_mark e in
match Marked.unmark e with match Marked.unmark e with
| EMatch ((EInj (e1, i, n', _ts), _), cases, n) when EnumName.compare n n' = 0 | EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n }
-> when EnumName.equal n n' ->
let e1 = visitor_map iota_expr () e1 in let e1 = visitor_map iota_expr e' in
let case = visitor_map iota_expr () (List.nth cases i) in let case = visitor_map iota_expr (EnumConstructor.Map.find cons cases) in
Expr.eapp case [e1] m Expr.eapp case [e1] m
| EMatch (e', cases, n) | EMatch { e = e'; cases; name = n }
when cases when cases
|> List.mapi (fun i (case, _pos) -> |> EnumConstructor.Map.mapi (fun i case ->
match case with match Marked.unmark case with
| EInj (_ei, i', n', _ts') -> | EInj { cons = i'; name = n'; _ } ->
i = i' && (* n = n' *) EnumName.compare n n' = 0 EnumConstructor.equal i i' && EnumName.equal n n'
| _ -> false) | _ -> false)
|> List.for_all Fun.id -> |> EnumConstructor.Map.for_all (fun _ b -> b) ->
visitor_map iota_expr () e' visitor_map iota_expr e'
| _ -> visitor_map iota_expr () e | _ -> visitor_map iota_expr e
let rec beta_expr (e : 'm expr) : 'm expr boxed = let rec beta_expr (e : 'm expr) : 'm expr boxed =
let m = Marked.get_mark e in let m = Marked.get_mark e in
match Marked.unmark e with match Marked.unmark e with
| EApp (e1, args) -> | EApp { f = e1; args } ->
Expr.Box.app1n (beta_expr e1) (List.map beta_expr args) Expr.Box.app1n (beta_expr e1) (List.map beta_expr args)
(fun e1 args -> (fun e1 args ->
match Marked.unmark e1 with match Marked.unmark e1 with
| EAbs (binder, _) -> Marked.unmark (Expr.subst binder args) | EAbs { binder; _ } -> Marked.unmark (Expr.subst binder args)
| _ -> EApp (e1, args)) | _ -> EApp { f = e1; args })
m m
| _ -> visitor_map (fun () -> beta_expr) () e | _ -> visitor_map beta_expr e
let iota_optimizations (p : 'm program) : 'm program = let iota_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes = Scope.map_exprs ~f:iota_expr ~varf:(fun v -> v) p.scopes in
Scope.map_exprs ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes
in
{ p with scopes = Bindlib.unbox new_scopes } { p with scopes = Bindlib.unbox new_scopes }
(* TODO: beta optimizations apply inlining of the program. We left the inclusion (* TODO: beta optimizations apply inlining of the program. We left the inclusion
@ -70,30 +67,32 @@ let _beta_optimizations (p : 'm program) : 'm program =
let rec peephole_expr (e : 'm expr) : 'm expr boxed = let rec peephole_expr (e : 'm expr) : 'm expr boxed =
let m = Marked.get_mark e in let m = Marked.get_mark e in
match Marked.unmark e with match Marked.unmark e with
| EIfThenElse (e1, e2, e3) -> | EIfThenElse { cond; etrue; efalse } ->
Expr.Box.app3 (peephole_expr e1) (peephole_expr e2) (peephole_expr e3) Expr.Box.app3 (peephole_expr cond) (peephole_expr etrue)
(fun e1 e2 e3 -> (peephole_expr efalse)
match Marked.unmark e1 with (fun cond etrue efalse ->
match Marked.unmark cond with
| ELit (LBool true) | ELit (LBool true)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) -> | EApp { f = EOp { op = Log _; _ }, _; args = [(ELit (LBool true), _)] }
Marked.unmark e2 ->
Marked.unmark etrue
| ELit (LBool false) | ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) -> | EApp
Marked.unmark e3 { f = EOp { op = Log _; _ }, _; args = [(ELit (LBool false), _)] }
| _ -> EIfThenElse (e1, e2, e3)) ->
Marked.unmark efalse
| _ -> EIfThenElse { cond; etrue; efalse })
m m
| ECatch (e1, except, e2) -> | ECatch { body; exn; handler } ->
Expr.Box.app2 (peephole_expr e1) (peephole_expr e2) Expr.Box.app2 (peephole_expr body) (peephole_expr handler)
(fun e1 e2 -> (fun body handler ->
match Marked.unmark e1, Marked.unmark e2 with match Marked.unmark body, Marked.unmark handler with
| ERaise except', ERaise except'' | ERaise exn', ERaise exn'' when exn' = exn && exn = exn'' -> ERaise exn
when except' = except && except = except'' -> | ERaise exn', _ when exn' = exn -> Marked.unmark handler
ERaise except | _, ERaise exn' when exn' = exn -> Marked.unmark body
| ERaise except', _ when except' = except -> Marked.unmark e2 | _ -> ECatch { body; exn; handler })
| _, ERaise except' when except' = except -> Marked.unmark e1
| _ -> ECatch (e1, except, e2))
m m
| _ -> visitor_map (fun () -> peephole_expr) () e | _ -> visitor_map peephole_expr e
let peephole_optimizations (p : 'm program) : 'm program = let peephole_optimizations (p : 'm program) : 'm program =
let new_scopes = let new_scopes =

View File

@ -14,24 +14,21 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
open String_common
module D = Dcalc.Ast module D = Dcalc.Ast
let find_struct (s : StructName.t) (ctx : decl_ctx) : let find_struct (s : StructName.t) (ctx : decl_ctx) : typ StructField.Map.t =
(StructFieldName.t * typ) list = try StructName.Map.find s ctx.ctx_structs
try StructMap.find s ctx.ctx_structs
with Not_found -> with Not_found ->
let s_name, pos = StructName.get_info s in let s_name, pos = StructName.get_info s in
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: Structure %s was not found in the current environment." "Internal Error: Structure %s was not found in the current environment."
s_name s_name
let find_enum (en : EnumName.t) (ctx : decl_ctx) : let find_enum (en : EnumName.t) (ctx : decl_ctx) : typ EnumConstructor.Map.t =
(EnumConstructor.t * typ) list = try EnumName.Map.find en ctx.ctx_enums
try EnumMap.find en ctx.ctx_enums
with Not_found -> with Not_found ->
let en_name, pos = EnumName.get_info en in let en_name, pos = EnumName.get_info en in
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
@ -57,43 +54,13 @@ let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days
let format_op_kind (fmt : Format.formatter) (k : op_kind) =
Format.fprintf fmt "%s"
(match k with
| KInt -> "!"
| KRat -> "&"
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit =
match Marked.unmark op with
| Add k -> Format.fprintf fmt "+%a" format_op_kind k
| Sub k -> Format.fprintf fmt "-%a" format_op_kind k
| Mult k -> Format.fprintf fmt "*%a" format_op_kind k
| Div k -> Format.fprintf fmt "/%a" format_op_kind k
| And -> Format.fprintf fmt "%s" "&&"
| Or -> Format.fprintf fmt "%s" "||"
| Eq -> Format.fprintf fmt "%s" "="
| Neq | Xor -> Format.fprintf fmt "%s" "<>"
| Lt k -> Format.fprintf fmt "%s%a" "<" format_op_kind k
| Lte k -> Format.fprintf fmt "%s%a" "<=" format_op_kind k
| Gt k -> Format.fprintf fmt "%s%a" ">" format_op_kind k
| Gte k -> Format.fprintf fmt "%s%a" ">=" format_op_kind k
| Concat -> Format.fprintf fmt "@"
| Map -> Format.fprintf fmt "Array.map"
| Filter -> Format.fprintf fmt "array_filter"
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit = : unit =
Format.fprintf fmt "@[<hov 2>[%a]@]" Format.fprintf fmt "@[<hov 2>[%a]@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt info -> (fun fmt info ->
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info)) Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info))
uids uids
let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -106,26 +73,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids uids
let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit =
match Marked.unmark op with
| Minus k -> Format.fprintf fmt "~-%a" format_op_kind k
| Not -> Format.fprintf fmt "%s" "not"
| Log (_entry, _infos) ->
Errors.raise_spanned_error (Marked.get_mark op)
"Internal error: a log operator has not been caught by the expression \
match"
| Length -> Format.fprintf fmt "%s" "array_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
let avoid_keywords (s : string) : string = let avoid_keywords (s : string) : string =
match s with match s with
(* list taken from (* list taken from
@ -137,14 +84,14 @@ let avoid_keywords (s : string) : string =
| "match" | "method" | "mod" | "module" | "mutable" | "new" | "nonrec" | "match" | "method" | "mod" | "module" | "mutable" | "new" | "nonrec"
| "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct" | "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct"
| "then" | "to" | "true" | "try" | "type" | "val" | "virtual" | "when" | "then" | "to" | "true" | "try" | "type" | "val" | "virtual" | "when"
| "while" | "with" -> | "while" | "with" | "Stdlib" | "Runtime" | "Oper" ->
s ^ "_user" s ^ "_user"
| _ -> s | _ -> s
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
Format.asprintf "%a" StructName.format_t v Format.asprintf "%a" StructName.format_t v
|> to_ascii |> String.to_ascii
|> to_snake_case |> String.to_snake_case
|> avoid_keywords |> avoid_keywords
|> Format.fprintf fmt "%s" |> Format.fprintf fmt "%s"
@ -154,8 +101,8 @@ let format_to_module_name
(match name with (match name with
| `Ename v -> Format.asprintf "%a" EnumName.format_t v | `Ename v -> Format.asprintf "%a" EnumName.format_t v
| `Sname v -> Format.asprintf "%a" StructName.format_t v) | `Sname v -> Format.asprintf "%a" StructName.format_t v)
|> to_ascii |> String.to_ascii
|> to_snake_case |> String.to_snake_case
|> avoid_keywords |> avoid_keywords
|> String.split_on_char '_' |> String.split_on_char '_'
|> List.map String.capitalize_ascii |> List.map String.capitalize_ascii
@ -164,24 +111,25 @@ let format_to_module_name
let format_struct_field_name let format_struct_field_name
(fmt : Format.formatter) (fmt : Format.formatter)
((sname_opt, v) : StructName.t option * StructFieldName.t) : unit = ((sname_opt, v) : StructName.t option * StructField.t) : unit =
(match sname_opt with (match sname_opt with
| Some sname -> | Some sname ->
Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname) Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname)
| None -> Format.fprintf fmt "%s") | None -> Format.fprintf fmt "%s")
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" StructFieldName.format_t v))) (String.to_ascii (Format.asprintf "%a" StructField.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_snake_case (to_ascii (Format.asprintf "%a" EnumName.format_t v)))) (String.to_snake_case
(String.to_ascii (Format.asprintf "%a" EnumName.format_t v))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit = unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" EnumConstructor.format_t v))) (String.to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit = let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit =
match Marked.unmark ty with match Marked.unmark ty with
@ -225,25 +173,27 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
| TAny -> Format.fprintf fmt "_" | TAny -> Format.fprintf fmt "_"
let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit =
let lowercase_name = to_snake_case (to_ascii (Bindlib.name_of v)) in let lowercase_name =
String.to_snake_case (String.to_ascii (Bindlib.name_of v))
in
let lowercase_name = let lowercase_name =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.")
~subst:(fun _ -> "_dot_") ~subst:(fun _ -> "_dot_")
lowercase_name lowercase_name
in in
let lowercase_name = avoid_keywords (to_ascii lowercase_name) in let lowercase_name = avoid_keywords (String.to_ascii lowercase_name) in
if if
List.mem lowercase_name ["handle_default"; "handle_default_opt"] List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|| begins_with_uppercase (Bindlib.name_of v) || String.begins_with_uppercase (Bindlib.name_of v)
then Format.fprintf fmt "%s" lowercase_name then Format.pp_print_string fmt lowercase_name
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name
else ( else (
Cli.debug_print "lowercase_name: %s " lowercase_name; Cli.debug_print "lowercase_name: %s " lowercase_name;
Format.fprintf fmt "%s_" lowercase_name) Format.fprintf fmt "%s_" lowercase_name)
let needs_parens (e : 'm expr) : bool = let needs_parens (e : 'm expr) : bool =
match Marked.unmark e with match Marked.unmark e with
| EApp ((EAbs (_, _), _), _) | EApp { f = EAbs _, _; _ }
| ELit (LBool _ | LUnit) | ELit (LBool _ | LUnit)
| EVar _ | ETuple _ | EOp _ -> | EVar _ | ETuple _ | EOp _ ->
false false
@ -279,56 +229,52 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
in in
match Marked.unmark e with match Marked.unmark e with
| EVar v -> Format.fprintf fmt "%a" format_var v | EVar v -> Format.fprintf fmt "%a" format_var v
| ETuple (es, None) -> | ETuple es ->
Format.fprintf fmt "@[<hov 2>(%a)@]" Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e))
es es
| ETuple (es, Some s) -> | EStruct { name = s; fields = es } ->
if List.length es = 0 then Format.fprintf fmt "()" if StructField.Map.is_empty es then Format.fprintf fmt "()"
else else
Format.fprintf fmt "{@[<hov 2>%a@]}" Format.fprintf fmt "{@[<hov 2>%a@]}"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt (e, struct_field) -> (fun fmt (struct_field, e) ->
Format.fprintf fmt "@[<hov 2>%a =@ %a@]" format_struct_field_name Format.fprintf fmt "@[<hov 2>%a =@ %a@]" format_struct_field_name
(Some s, struct_field) format_with_parens e)) (Some s, struct_field) format_with_parens e))
(List.combine es (List.map fst (find_struct s ctx))) (StructField.Map.bindings es)
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>[|%a|]@]" Format.fprintf fmt "@[<hov 2>[|%a|]@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e))
es es
| ETupleAccess (e1, n, s, ts) -> ( | ETupleAccess { e; index; size } ->
match s with Format.fprintf fmt "let@ %a@ = %a@ in@ x"
| None -> (Format.pp_print_list
Format.fprintf fmt "let@ %a@ = %a@ in@ x" ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(Format.pp_print_list (fun fmt i ->
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") Format.pp_print_string fmt (if i = index then "x" else "_")))
(fun fmt i -> Format.fprintf fmt "%s" (if i = n then "x" else "_"))) (List.init size Fun.id) format_with_parens e
(List.mapi (fun i _ -> i) ts) | EStructAccess { e; field; name } ->
format_with_parens e1 Format.fprintf fmt "%a.%a" format_with_parens e format_struct_field_name
| Some s -> (Some name, field)
Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name | EInj { e; cons; name } ->
(Some s, fst (List.nth (find_struct s ctx) n))) Format.fprintf fmt "@[<hov 2>%a.%a@ %a@]" format_to_module_name
| EInj (e, n, en, _ts) -> (`Ename name) format_enum_cons_name cons format_with_parens e
Format.fprintf fmt "@[<hov 2>%a.%a@ %a@]" format_to_module_name (`Ename en) | EMatch { e; cases; name } ->
format_enum_cons_name
(fst (List.nth (find_enum en ctx) n))
format_with_parens e
| EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hv>@[<hov 2>match@ %a@]@ with@\n| %a@]" Format.fprintf fmt "@[<hv>@[<hov 2>match@ %a@]@ with@\n| %a@]"
format_with_parens e format_with_parens e
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ | ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ | ")
(fun fmt (e, c) -> (fun fmt (c, e) ->
Format.fprintf fmt "@[<hov 2>%a.%a %a@]" format_to_module_name Format.fprintf fmt "@[<hov 2>%a.%a %a@]" format_to_module_name
(`Ename e_name) format_enum_cons_name c (`Ename name) format_enum_cons_name c
(fun fmt e -> (fun fmt e ->
match Marked.unmark e with match Marked.unmark e with
| EAbs (binder, _) -> | EAbs { binder; _ } ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
Format.fprintf fmt "%a ->@ %a" Format.fprintf fmt "%a ->@ %a"
(Format.pp_print_list (Format.pp_print_list
@ -338,11 +284,11 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
| _ -> assert false | _ -> assert false
(* should not happen *)) (* should not happen *))
e)) e))
(List.combine es (List.map fst (find_enum e_name ctx))) (EnumConstructor.Map.bindings cases)
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l) | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l)
| EApp ((EAbs (binder, taus), _), args) -> | EApp { f = EAbs { binder; tys }, _; args } ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
Format.fprintf fmt "(%a%a)" Format.fprintf fmt "(%a%a)"
(Format.pp_print_list (Format.pp_print_list
@ -351,30 +297,28 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
Format.fprintf fmt "@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n" Format.fprintf fmt "@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n"
format_var x format_typ tau format_with_parens arg)) format_var x format_typ tau format_with_parens arg))
xs_tau_arg format_with_parens body xs_tau_arg format_with_parens body
| EAbs (binder, taus) -> | EAbs { binder; tys } ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in
Format.fprintf fmt "@[<hov 2>fun@ %a ->@ %a@]" Format.fprintf fmt "@[<hov 2>fun@ %a ->@ %a@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (x, tau) -> (fun fmt (x, tau) ->
Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau)) Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau))
xs_tau format_expr body xs_tau format_expr body
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> | EApp
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos) {
format_with_parens arg1 format_with_parens arg2 f = EApp { f = EOp { op = Log (BeginCall, info); _ }, _; args = [f] }, _;
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> args = [arg];
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 }
format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info
format_with_parens f format_with_parens arg format_with_parens f format_with_parens arg
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag | EApp { f = EOp { op = Log (VarDef tau, info); _ }, _; args = [arg1] }
-> when !Cli.trace_flag ->
Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list
info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1 info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), m), [arg1]) | EApp { f = EOp { op = Log (PosRecordIfTrueBool, _); _ }, m; args = [arg1] }
when !Cli.trace_flag -> when !Cli.trace_flag ->
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
Format.fprintf fmt Format.fprintf fmt
@ -383,15 +327,13 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) format_with_parens arg1 (Pos.get_law_info pos) format_with_parens arg1
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag -> | EApp { f = EOp { op = Log (EndCall, info); _ }, _; args = [arg1] }
when !Cli.trace_flag ->
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info
format_with_parens arg1 format_with_parens arg1
| EApp ((EOp (Unop (Log _)), _), [arg1]) -> | EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } ->
Format.fprintf fmt "%a" format_with_parens arg1 Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [arg1]) -> | EApp { f = EVar x, pos; args }
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp ((EVar x, pos), args)
when Var.compare x (Var.translate Ast.handle_default) = 0 when Var.compare x (Var.translate Ast.handle_default) = 0
|| Var.compare x (Var.translate Ast.handle_default_opt) = 0 -> || Var.compare x (Var.translate Ast.handle_default_opt) = 0 ->
Format.fprintf fmt Format.fprintf fmt
@ -409,19 +351,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens) format_with_parens)
args args
| EApp (f, args) -> | EApp { f; args } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens) format_with_parens)
args args
| EIfThenElse (e1, e2, e3) -> | EIfThenElse { cond; etrue; efalse } ->
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]@]" "@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]@]"
format_with_parens e1 format_with_parens e2 format_with_parens e3 format_with_parens cond format_with_parens etrue format_with_parens efalse
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
| EAssert e' -> | EAssert e' ->
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>if@ %a@ then@ ()@ else@ raise (AssertionFailed @[<hov \ "@[<hov 2>if@ %a@ then@ ()@ else@ raise (AssertionFailed @[<hov \
@ -437,18 +377,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
(Pos.get_law_info (Expr.pos e')) (Pos.get_law_info (Expr.pos e'))
| ERaise exc -> | ERaise exc ->
Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e) Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e)
| ECatch (e1, exc, e2) -> | ECatch { body; exn; handler } ->
Format.fprintf fmt Format.fprintf fmt
"@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]" "@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]"
format_with_parens e1 format_exception format_with_parens body format_exception
(exc, Expr.pos e) (exn, Expr.pos e)
format_with_parens e2 format_with_parens handler
let format_struct_embedding let format_struct_embedding
(fmt : Format.formatter) (fmt : Format.formatter)
((struct_name, struct_fields) : ((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) =
StructName.t * (StructFieldName.t * typ) list) = if StructField.Map.is_empty struct_fields then
if List.length struct_fields = 0 then
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
format_struct_name struct_name format_to_module_name (`Sname struct_name) format_struct_name struct_name format_to_module_name (`Sname struct_name)
else else
@ -461,16 +400,16 @@ let format_struct_embedding
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n")
(fun _fmt (struct_field, struct_field_type) -> (fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructFieldName.format_t Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructField.format_t
struct_field typ_embedding_name struct_field_type struct_field typ_embedding_name struct_field_type
format_struct_field_name format_struct_field_name
(Some struct_name, struct_field))) (Some struct_name, struct_field)))
struct_fields (StructField.Map.bindings struct_fields)
let format_enum_embedding let format_enum_embedding
(fmt : Format.formatter) (fmt : Format.formatter)
((enum_name, enum_cases) : EnumName.t * (EnumConstructor.t * typ) list) = ((enum_name, enum_cases) : EnumName.t * typ EnumConstructor.Map.t) =
if List.length enum_cases = 0 then if EnumConstructor.Map.is_empty enum_cases then
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
format_to_module_name (`Ename enum_name) format_enum_name enum_name format_to_module_name (`Ename enum_name) format_enum_name enum_name
else else
@ -486,14 +425,14 @@ let format_enum_embedding
Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]" Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]"
format_enum_cons_name enum_cons EnumConstructor.format_t enum_cons format_enum_cons_name enum_cons EnumConstructor.format_t enum_cons
typ_embedding_name enum_cons_type)) typ_embedding_name enum_cons_type))
enum_cases (EnumConstructor.Map.bindings enum_cases)
let format_ctx let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list) (type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter) (fmt : Format.formatter)
(ctx : decl_ctx) : unit = (ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
if List.length struct_fields = 0 then if StructField.Map.is_empty struct_fields then
Format.fprintf fmt Format.fprintf fmt
"@[<v 2>module %a = struct@\n@[<hov 2>type t = unit@]@]@\nend@\n" "@[<v 2>module %a = struct@\n@[<hov 2>type t = unit@]@]@\nend@\n"
format_to_module_name (`Sname struct_name) format_to_module_name (`Sname struct_name)
@ -508,7 +447,7 @@ let format_ctx
(fun _fmt (struct_field, struct_field_type) -> (fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "@[<hov 2>%a:@ %a@]" format_struct_field_name Format.fprintf fmt "@[<hov 2>%a:@ %a@]" format_struct_field_name
(None, struct_field) format_typ struct_field_type)) (None, struct_field) format_typ struct_field_type))
struct_fields; (StructField.Map.bindings struct_fields);
if !Cli.trace_flag then if !Cli.trace_flag then
format_struct_embedding fmt (struct_name, struct_fields) format_struct_embedding fmt (struct_name, struct_fields)
in in
@ -521,7 +460,7 @@ let format_ctx
(fun _fmt (enum_cons, enum_cons_type) -> (fun _fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "@[<hov 2>| %a@ of@ %a@]" format_enum_cons_name Format.fprintf fmt "@[<hov 2>| %a@ of@ %a@]" format_enum_cons_name
enum_cons format_typ enum_cons_type)) enum_cons format_typ enum_cons_type))
enum_cons; (EnumConstructor.Map.bindings enum_cons);
if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons) if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons)
in in
let is_in_type_ordering s = let is_in_type_ordering s =
@ -535,8 +474,8 @@ let format_ctx
let scope_structs = let scope_structs =
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructMap.bindings (StructName.Map.bindings
(StructMap.filter (StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s)) (fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs)) ctx.ctx_structs))
in in

View File

@ -14,29 +14,28 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils
open Shared_ast open Shared_ast
open Ast
(** Formats a lambda calculus program into a valid OCaml program *) (** Formats a lambda calculus program into a valid OCaml program *)
val avoid_keywords : string -> string val avoid_keywords : string -> string
val find_struct : StructName.t -> decl_ctx -> (StructFieldName.t * typ) list val find_struct : StructName.t -> decl_ctx -> typ StructField.Map.t
val find_enum : EnumName.t -> decl_ctx -> (EnumConstructor.t * typ) list val find_enum : EnumName.t -> decl_ctx -> typ EnumConstructor.Map.t
val typ_needs_parens : typ -> bool val typ_needs_parens : typ -> bool
val needs_parens : 'm expr -> bool
(* val needs_parens : 'm expr -> bool *)
val format_enum_name : Format.formatter -> EnumName.t -> unit val format_enum_name : Format.formatter -> EnumName.t -> unit
val format_enum_cons_name : Format.formatter -> EnumConstructor.t -> unit val format_enum_cons_name : Format.formatter -> EnumConstructor.t -> unit
val format_struct_name : Format.formatter -> StructName.t -> unit val format_struct_name : Format.formatter -> StructName.t -> unit
val format_struct_field_name : val format_struct_field_name :
Format.formatter -> StructName.t option * StructFieldName.t -> unit Format.formatter -> StructName.t option * StructField.t -> unit
val format_to_module_name : val format_to_module_name :
Format.formatter -> [< `Ename of EnumName.t | `Sname of StructName.t ] -> unit Format.formatter -> [< `Ename of EnumName.t | `Sname of StructName.t ] -> unit
(* * val format_lit : Format.formatter -> lit Marked.pos -> unit * val
format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit *)
val format_lit : Format.formatter -> lit Marked.pos -> unit
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
val format_var : Format.formatter -> 'm Var.t -> unit val format_var : Format.formatter -> 'm Var.t -> unit
val format_program : val format_program :

View File

@ -1,7 +1,7 @@
(library (library
(name literate) (name literate)
(public_name catala.literate) (public_name catala.literate)
(libraries re utils surface ubase)) (libraries re catala_utils surface ubase uutf))
(documentation (documentation
(package catala) (package catala)

View File

@ -18,7 +18,7 @@
(** This modules weaves the source code and the legislative text together into a (** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *) document that law professionals can understand. *)
open Utils open Catala_utils
open Literate_common open Literate_common
module A = Surface.Ast module A = Surface.Ast
module P = Printf module P = Printf
@ -91,7 +91,7 @@ let wrap_html
</ul>\n" </ul>\n"
css_as_string (literal_title language) css_as_string (literal_title language)
(literal_generated_by language) (literal_generated_by language)
Utils.Cli.version Cli.version
(pre_html (literal_disclaimer_and_link language)) (pre_html (literal_disclaimer_and_link language))
(literal_source_files language) (literal_source_files language)
(String.concat "\n" (String.concat "\n"
@ -133,7 +133,7 @@ let pygmentize_code (c : string Marked.pos) (language : C.backend_lang) : string
"html"; "html";
"-O"; "-O";
"style=colorful,anchorlinenos=True,lineanchors=\"" "style=colorful,anchorlinenos=True,lineanchors=\""
^ String_common.to_ascii (Pos.get_file (Marked.get_mark c)) ^ String.to_ascii (Pos.get_file (Marked.get_mark c))
^ "\",linenos=table,linenostart=" ^ "\",linenos=table,linenostart="
^ string_of_int (Pos.get_start_line (Marked.get_mark c)); ^ string_of_int (Pos.get_start_line (Marked.get_mark c));
"-o"; "-o";
@ -160,7 +160,7 @@ let pygmentize_code (c : string Marked.pos) (language : C.backend_lang) : string
let sanitize_html_href str = let sanitize_html_href str =
str str
|> String_common.to_ascii |> String.to_ascii
|> R.substitute ~rex:(R.regexp "[' '°\"]") ~subst:(function _ -> "%20") |> R.substitute ~rex:(R.regexp "[' '°\"]") ~subst:(function _ -> "%20")
let rec law_structure_to_html let rec law_structure_to_html

View File

@ -17,7 +17,7 @@
(** This modules weaves the source code and the legislative text together into a (** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *) document that law professionals can understand. *)
open Utils open Catala_utils
(** {1 Helpers} *) (** {1 Helpers} *)

View File

@ -18,7 +18,7 @@
(** This modules weaves the source code and the legislative text together into a (** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *) document that law professionals can understand. *)
open Utils open Catala_utils
open Literate_common open Literate_common
module A = Surface.Ast module A = Surface.Ast
module R = Re.Pcre module R = Re.Pcre
@ -61,7 +61,7 @@ let wrap_latex
%s %s
\usepackage{minted} \usepackage{minted}
\usepackage{longtable} \usepackage{longtable}
\usepackage{booktabs} \usepackage{booktabs,tabularx}
\usepackage{newunicodechar} \usepackage{newunicodechar}
\usepackage{textcomp} \usepackage{textcomp}
\usepackage[hidelinks]{hyperref} \usepackage[hidelinks]{hyperref}
@ -122,8 +122,8 @@ let wrap_latex
\newunicodechar{}{$\rightarrow$} \newunicodechar{}{$\rightarrow$}
\newunicodechar{}{$\neq$} \newunicodechar{}{$\neq$}
\newcommand*\FancyVerbStartString{```catala} \newcommand*\FancyVerbStartString{\PYG{l+s}{```catala}}
\newcommand*\FancyVerbStopString{```} \newcommand*\FancyVerbStopString{\PYG{l+s}{```}}
\fvset{ \fvset{
numbers=left, numbers=left,
@ -151,14 +151,15 @@ codes={\catcode`\$=3\catcode`\^=7}
\tableofcontents \tableofcontents
\[\star\star\star\] \[\star\star\star\]
\clearpage|latex} \clearpage
|latex}
(match language with Fr -> "french" | En -> "english" | Pl -> "polish") (match language with Fr -> "french" | En -> "english" | Pl -> "polish")
(match language with Fr -> "\\setmainfont{Marianne}" | _ -> "") (match language with Fr -> "\\setmainfont{Marianne}" | _ -> "")
(* for France, we use the official font of the French state design system (* for France, we use the official font of the French state design system
https://gouvfr.atlassian.net/wiki/spaces/DB/pages/223019527/Typographie+-+Typography *) https://gouvfr.atlassian.net/wiki/spaces/DB/pages/223019527/Typographie+-+Typography *)
(literal_title language) (literal_title language)
(literal_generated_by language) (literal_generated_by language)
Utils.Cli.version Cli.version
(pre_latexify (literal_disclaimer_and_link language)) (pre_latexify (literal_disclaimer_and_link language))
(literal_source_files language) (literal_source_files language)
(String.concat (String.concat
@ -243,7 +244,7 @@ let rec law_structure_to_latex
| En -> "Metadata" | En -> "Metadata"
| Pl -> "Metadane" | Pl -> "Metadane"
in in
let start_line = Pos.get_start_line (Marked.get_mark c) - 1 in let start_line = Pos.get_start_line (Marked.get_mark c) + 1 in
let filename = Filename.basename (Pos.get_file (Marked.get_mark c)) in let filename = Filename.basename (Pos.get_file (Marked.get_mark c)) in
let block_content = Marked.unmark c in let block_content = Marked.unmark c in
check_exceeding_lines start_line filename block_content; check_exceeding_lines start_line filename block_content;
@ -252,7 +253,7 @@ let rec law_structure_to_latex
"\\begin{tcolorbox}[colframe=OliveGreen, breakable, \ "\\begin{tcolorbox}[colframe=OliveGreen, breakable, \
title=\\textcolor{black}{\\texttt{%s}},title after \ title=\\textcolor{black}{\\texttt{%s}},title after \
break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\ break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\
\\begin{minted}[numbersep=9mm, firstnumber=%d, breaklines, \ \\begin{minted}[numbersep=9mm, firstnumber=%d, \
label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\ label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\
```catala\n\ ```catala\n\
%s```\n\ %s```\n\

View File

@ -17,7 +17,7 @@
(** This modules weaves the source code and the legislative text together into a (** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *) document that law professionals can understand. *)
open Utils open Catala_utils
(** {1 Helpers} *) (** {1 Helpers} *)

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Cli open Cli
let literal_title = function let literal_title = function

View File

@ -14,32 +14,30 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
val literal_title : Cli.backend_lang -> string val literal_title : Cli.backend_lang -> string
(** Return the title traduction according the given (** Return the title traduction according the given {!type:Cli.backend_lang}. *)
{!type:Utils.Cli.backend_lang}. *)
val literal_generated_by : Cli.backend_lang -> string val literal_generated_by : Cli.backend_lang -> string
(** Return the 'generated by' traduction according the given (** Return the 'generated by' traduction according the given
{!type:Utils.Cli.backend_lang}. *) {!type:Cli.backend_lang}. *)
val literal_source_files : Cli.backend_lang -> string val literal_source_files : Cli.backend_lang -> string
(** Return the 'source files weaved' traduction according the given (** Return the 'source files weaved' traduction according the given
{!type:Utils.Cli.backend_lang}. *) {!type:Cli.backend_lang}. *)
val literal_disclaimer_and_link : Cli.backend_lang -> string val literal_disclaimer_and_link : Cli.backend_lang -> string
(** Return the traduction of a paragraph giving a basic disclaimer about Catala (** Return the traduction of a paragraph giving a basic disclaimer about Catala
and a link to the website according the given {!type: and a link to the website according the given {!type: Cli.backend_lang}. *)
Utils.Cli.backend_lang}. *)
val literal_last_modification : Cli.backend_lang -> string val literal_last_modification : Cli.backend_lang -> string
(** Return the 'last modification' traduction according the given (** Return the 'last modification' traduction according the given
{!type:Utils.Cli.backend_lang}. *) {!type:Cli.backend_lang}. *)
val get_language_extension : Cli.backend_lang -> string val get_language_extension : Cli.backend_lang -> string
(** Return the file extension corresponding to the given (** Return the file extension corresponding to the given
{!type:Utils.Cli.backend_lang}. *) {!type:Cli.backend_lang}. *)
val run_pandoc : string -> [ `Html | `Latex ] -> string val run_pandoc : string -> [ `Html | `Latex ] -> string
(** Runs the [pandoc] on a string to pretty-print markdown features into the (** Runs the [pandoc] on a string to pretty-print markdown features into the

View File

@ -14,8 +14,10 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Catala_utils
type 'ast plugin_apply_fun_typ = type 'ast plugin_apply_fun_typ =
source_file:Utils.Pos.input_file -> source_file:Pos.input_file ->
output_file:string option -> output_file:string option ->
scope:string option -> scope:string option ->
'ast -> 'ast ->
@ -51,17 +53,21 @@ let find name = Hashtbl.find backend_plugins (String.lowercase_ascii name)
let load_file f = let load_file f =
try try
Dynlink.loadfile f; Dynlink.loadfile f;
Utils.Cli.debug_print "Plugin %S loaded" f Cli.debug_print "Plugin %S loaded" f
with e -> with e ->
Utils.Errors.format_warning "Could not load plugin %S: %s" f Errors.format_warning "Could not load plugin %S: %s" f
(Printexc.to_string e) (Printexc.to_string e)
let load_dir d = let rec load_dir d =
let dynlink_exts = let dynlink_exts =
if Dynlink.is_native then [".cmxs"] else [".cmo"; ".cma"] if Dynlink.is_native then [".cmxs"] else [".cmo"; ".cma"]
in in
Array.iter Array.iter
(fun f -> (fun f ->
if List.exists (Filename.check_suffix f) dynlink_exts then if f.[0] = '.' then ()
load_file (Filename.concat d f)) else
let f = Filename.concat d f in
if Sys.is_directory f then load_dir f
else if List.exists (Filename.check_suffix f) dynlink_exts then
load_file f)
(Sys.readdir d) (Sys.readdir d)

View File

@ -16,8 +16,10 @@
(** {2 catala-facing API} *) (** {2 catala-facing API} *)
open Catala_utils
type 'ast plugin_apply_fun_typ = type 'ast plugin_apply_fun_typ =
source_file:Utils.Pos.input_file -> source_file:Pos.input_file ->
output_file:string option -> output_file:string option ->
scope:string option -> scope:string option ->
'ast -> 'ast ->

View File

@ -18,9 +18,8 @@
(** Catala plugin for generating web APIs. It generates OCaml code before the (** Catala plugin for generating web APIs. It generates OCaml code before the
the associated [js_of_ocaml] wrapper. *) the associated [js_of_ocaml] wrapper. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open String_common
open Lcalc open Lcalc
open Lcalc.Ast open Lcalc.Ast
open Lcalc.To_ocaml open Lcalc.To_ocaml
@ -40,11 +39,11 @@ module To_jsoo = struct
let format_struct_field_name_camel_case let format_struct_field_name_camel_case
(fmt : Format.formatter) (fmt : Format.formatter)
(v : StructFieldName.t) : unit = (v : StructField.t) : unit =
let s = let s =
Format.asprintf "%a" StructFieldName.format_t v Format.asprintf "%a" StructField.format_t v
|> to_ascii |> String.to_ascii
|> to_snake_case |> String.to_snake_case
|> avoid_keywords |> avoid_keywords
|> to_camel_case |> to_camel_case
in in
@ -118,17 +117,17 @@ module To_jsoo = struct
let format_var_camel_case (fmt : Format.formatter) (v : 'm Var.t) : unit = let format_var_camel_case (fmt : Format.formatter) (v : 'm Var.t) : unit =
let lowercase_name = let lowercase_name =
Bindlib.name_of v Bindlib.name_of v
|> to_ascii |> String.to_ascii
|> to_snake_case |> String.to_snake_case
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ ->
"_dot_") "_dot_")
|> to_ascii |> String.to_ascii
|> avoid_keywords |> avoid_keywords
|> to_camel_case |> to_camel_case
in in
if if
List.mem lowercase_name ["handle_default"; "handle_default_opt"] List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|| begins_with_uppercase (Bindlib.name_of v) || String.begins_with_uppercase (Bindlib.name_of v)
then Format.fprintf fmt "%s" lowercase_name then Format.fprintf fmt "%s" lowercase_name
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name
else Format.fprintf fmt "%s_" lowercase_name else Format.fprintf fmt "%s_" lowercase_name
@ -166,7 +165,7 @@ module To_jsoo = struct
format_struct_field_name_camel_case struct_field format_struct_field_name_camel_case struct_field
format_typ_to_jsoo struct_field_type fmt_struct_name () format_typ_to_jsoo struct_field_type fmt_struct_name ()
format_struct_field_name (None, struct_field))) format_struct_field_name (None, struct_field)))
struct_fields (StructField.Map.bindings struct_fields)
in in
let fmt_of_jsoo fmt _ = let fmt_of_jsoo fmt _ =
Format.fprintf fmt "%a" Format.fprintf fmt "%a"
@ -186,7 +185,7 @@ module To_jsoo = struct
format_struct_field_name (None, struct_field) format_struct_field_name (None, struct_field)
format_typ_of_jsoo struct_field_type fmt_struct_name () format_typ_of_jsoo struct_field_type fmt_struct_name ()
format_struct_field_name_camel_case struct_field)) format_struct_field_name_camel_case struct_field))
struct_fields (StructField.Map.bindings struct_fields)
in in
let fmt_conv_funs fmt _ = let fmt_conv_funs fmt _ =
Format.fprintf fmt Format.fprintf fmt
@ -203,7 +202,7 @@ module To_jsoo = struct
() fmt_struct_name () fmt_module_struct_name () fmt_of_jsoo () () fmt_struct_name () fmt_module_struct_name () fmt_of_jsoo ()
in in
if List.length struct_fields = 0 then if StructField.Map.is_empty struct_fields then
Format.fprintf fmt Format.fprintf fmt
"class type %a =@ object end@\n\ "class type %a =@ object end@\n\
let %a_to_jsoo (_ : %a.t) : %a Js.t = object%%js end@\n\ let %a_to_jsoo (_ : %a.t) : %a Js.t = object%%js end@\n\
@ -220,11 +219,11 @@ module To_jsoo = struct
Format.fprintf fmt "@[<hov 2>method %a:@ %a %a@]" Format.fprintf fmt "@[<hov 2>method %a:@ %a %a@]"
format_struct_field_name_camel_case struct_field format_typ format_struct_field_name_camel_case struct_field format_typ
struct_field_type format_prop_or_meth struct_field_type)) struct_field_type format_prop_or_meth struct_field_type))
struct_fields fmt_conv_funs () (StructField.Map.bindings struct_fields)
fmt_conv_funs ()
in in
let format_enum_decl let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructor.Map.t))
fmt =
(enum_name, (enum_cons : (EnumConstructor.t * typ) list)) =
let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_enum_name fmt _ = format_enum_name fmt enum_name in
let fmt_module_enum_name fmt _ = let fmt_module_enum_name fmt _ =
To_ocaml.format_to_module_name fmt (`Ename enum_name) To_ocaml.format_to_module_name fmt (`Ename enum_name)
@ -247,7 +246,7 @@ module To_jsoo = struct
end@]" end@]"
format_enum_cons_name cname format_enum_cons_name cname format_enum_cons_name cname format_enum_cons_name cname
format_typ_to_jsoo typ)) format_typ_to_jsoo typ))
enum_cons (EnumConstructor.Map.bindings enum_cons)
in in
let fmt_of_jsoo fmt _ = let fmt_of_jsoo fmt _ =
Format.fprintf fmt Format.fprintf fmt
@ -273,7 +272,8 @@ module To_jsoo = struct
format_enum_cons_name cname fmt_module_enum_name () format_enum_cons_name cname fmt_module_enum_name ()
format_enum_cons_name cname format_typ_of_jsoo typ format_enum_cons_name cname format_typ_of_jsoo typ
fmt_enum_name ())) fmt_enum_name ()))
enum_cons fmt_module_enum_name () (EnumConstructor.Map.bindings enum_cons)
fmt_module_enum_name ()
in in
let fmt_conv_funs fmt _ = let fmt_conv_funs fmt _ =
@ -301,7 +301,8 @@ module To_jsoo = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (enum_cons, _) -> (fun fmt (enum_cons, _) ->
Format.fprintf fmt "- \"%a\"" format_enum_cons_name enum_cons)) Format.fprintf fmt "- \"%a\"" format_enum_cons_name enum_cons))
enum_cons fmt_conv_funs () (EnumConstructor.Map.bindings enum_cons)
fmt_conv_funs ()
in in
let is_in_type_ordering s = let is_in_type_ordering s =
List.exists List.exists
@ -314,8 +315,8 @@ module To_jsoo = struct
let scope_structs = let scope_structs =
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructMap.bindings (StructName.Map.bindings
(StructMap.filter (StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s)) (fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs)) ctx.ctx_structs))
in in

View File

@ -1,18 +1,22 @@
(executable (library
(name python) (name python)
(modes plugin) (public_name catala.plugins.python)
(synopsis
"Demonstration Catala plugin that reproduces the behaviour of the built-in python backend")
(modules python) (modules python)
(libraries catala.driver)) (libraries catala.driver))
(executable (library
(name api_web) (name api_web)
(modes plugin) (public_name catala.plugins.api_web)
(synopsis "Catala plugin for interaction with a web interface")
(modules api_web) (modules api_web)
(libraries catala.driver)) (libraries catala.driver))
(executable (library
(name json_schema) (name json_schema)
(modes plugin) (public_name catala.plugins.json_schema)
(synopsis "Catala plugin generating JSON schemas useful to build web-forms")
(modules json_schema) (modules json_schema)
(libraries catala.driver)) (libraries catala.driver))

View File

@ -20,8 +20,7 @@
let name = "json_schema" let name = "json_schema"
let extension = "_schema.json" let extension = "_schema.json"
open Utils open Catala_utils
open String_common
open Shared_ast open Shared_ast
open Lcalc.Ast open Lcalc.Ast
open Lcalc.To_ocaml open Lcalc.To_ocaml
@ -38,11 +37,11 @@ module To_json = struct
let format_struct_field_name_camel_case let format_struct_field_name_camel_case
(fmt : Format.formatter) (fmt : Format.formatter)
(v : StructFieldName.t) : unit = (v : StructField.t) : unit =
let s = let s =
Format.asprintf "%a" StructFieldName.format_t v Format.asprintf "%a" StructField.format_t v
|> to_ascii |> String.to_ascii
|> to_snake_case |> String.to_snake_case
|> avoid_keywords |> avoid_keywords
|> to_camel_case |> to_camel_case
in in
@ -97,7 +96,7 @@ module To_json = struct
(fun fmt (field_name, field_type) -> (fun fmt (field_name, field_type) ->
Format.fprintf fmt "@[<hov 2>\"%a\": {@\n%a@]@\n}" Format.fprintf fmt "@[<hov 2>\"%a\": {@\n%a@]@\n}"
format_struct_field_name_camel_case field_name fmt_type field_type)) format_struct_field_name_camel_case field_name fmt_type field_type))
(find_struct sname ctx) (StructField.Map.bindings (find_struct sname ctx))
let fmt_definitions let fmt_definitions
(ctx : decl_ctx) (ctx : decl_ctx)
@ -118,11 +117,14 @@ module To_json = struct
(t :: acc) @ collect_required_type_defs_from_scope_input s (t :: acc) @ collect_required_type_defs_from_scope_input s
| TEnum e -> | TEnum e ->
List.fold_left collect (t :: acc) List.fold_left collect (t :: acc)
(List.map snd (EnumMap.find e ctx.ctx_enums)) (List.map snd
(EnumConstructor.Map.bindings
(EnumName.Map.find e ctx.ctx_enums)))
| TArray t -> collect acc t | TArray t -> collect acc t
| _ -> acc | _ -> acc
in in
find_struct input_struct ctx find_struct input_struct ctx
|> StructField.Map.bindings
|> List.fold_left (fun acc (_, field_typ) -> collect acc field_typ) [] |> List.fold_left (fun acc (_, field_typ) -> collect acc field_typ) []
|> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t')) |> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t'))
in in
@ -146,7 +148,7 @@ module To_json = struct
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>{@\n\"type\": \"string\",@\n\"enum\": [\"%a\"]@]@\n}" "@[<hov 2>{@\n\"type\": \"string\",@\n\"enum\": [\"%a\"]@]@\n}"
format_enum_cons_name enum_cons)) format_enum_cons_name enum_cons))
enum_def (EnumConstructor.Map.bindings enum_def)
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n")
(fun fmt (enum_cons, payload_type) -> (fun fmt (enum_cons, payload_type) ->
@ -168,7 +170,7 @@ module To_json = struct
}@]@\n\ }@]@\n\
}" }"
format_enum_cons_name enum_cons fmt_type payload_type)) format_enum_cons_name enum_cons fmt_type payload_type))
enum_def (EnumConstructor.Map.bindings enum_def)
in in
Format.fprintf fmt "@\n%a" Format.fprintf fmt "@\n%a"

View File

@ -20,13 +20,15 @@
The code for the Python backend already has first-class support, so there The code for the Python backend already has first-class support, so there
would be no reason to use this plugin instead *) would be no reason to use this plugin instead *)
open Catala_utils
let name = "python-plugin" let name = "python-plugin"
let extension = ".py" let extension = ".py"
let apply ~source_file ~output_file ~scope prgm type_ordering = let apply ~source_file ~output_file ~scope prgm type_ordering =
ignore source_file; ignore source_file;
ignore scope; ignore scope;
Utils.File.with_formatter_of_opt_file output_file File.with_formatter_of_opt_file output_file
@@ fun fmt -> Scalc.To_python.format_program fmt prgm type_ordering @@ fun fmt -> Scalc.To_python.format_program fmt prgm type_ordering
let () = Driver.Plugin.register_scalc ~name ~extension apply let () = Driver.Plugin.register_scalc ~name ~extension apply

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
module D = Dcalc.Ast module D = Dcalc.Ast
module L = Lcalc.Ast module L = Lcalc.Ast
@ -28,15 +28,15 @@ let handle_default_opt = TopLevelName.fresh ("handle_default_opt", Pos.no_pos)
type expr = naked_expr Marked.pos type expr = naked_expr Marked.pos
and naked_expr = and naked_expr =
| EVar of LocalName.t | EVar : LocalName.t -> naked_expr
| EFunc of TopLevelName.t | EFunc : TopLevelName.t -> naked_expr
| EStruct of expr list * StructName.t | EStruct : expr list * StructName.t -> naked_expr
| EStructFieldAccess of expr * StructFieldName.t * StructName.t | EStructFieldAccess : expr * StructField.t * StructName.t -> naked_expr
| EInj of expr * EnumConstructor.t * EnumName.t | EInj : expr * EnumConstructor.t * EnumName.t -> naked_expr
| EArray of expr list | EArray : expr list -> naked_expr
| ELit of L.lit | ELit : L.lit -> naked_expr
| EApp of expr * expr list | EApp : expr * expr list -> naked_expr
| EOp of operator | EOp : (lcalc, _) operator -> naked_expr
type stmt = type stmt =
| SInnerFuncDef of LocalName.t Marked.pos * func | SInnerFuncDef of LocalName.t Marked.pos * func

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
module A = Ast module A = Ast
module L = Lcalc.Ast module L = Lcalc.Ast
@ -35,36 +35,37 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
| EVar v -> | EVar v ->
let local_var = let local_var =
try A.EVar (Var.Map.find v ctxt.var_dict) try A.EVar (Var.Map.find v ctxt.var_dict)
with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict) with Not_found -> (
try A.EFunc (Var.Map.find v ctxt.func_dict)
with Not_found ->
Errors.raise_spanned_error (Expr.pos expr)
"Var not found in lambda→scalc: %a@\nknown: @[<hov>%a@]@\n"
Print.var_debug v
(Format.pp_print_list ~pp_sep:Format.pp_print_space
(fun ppf (v, _) -> Print.var_debug ppf v))
(Var.Map.bindings ctxt.var_dict))
in in
[], (local_var, Expr.pos expr) [], (local_var, Expr.pos expr)
| ETuple (args, Some s_name) -> | EStruct { fields; name } ->
let args_stmts, new_args = let args_stmts, new_args =
List.fold_left StructField.Map.fold
(fun (args_stmts, new_args) arg -> (fun _ arg (args_stmts, new_args) ->
let arg_stmts, new_arg = translate_expr ctxt arg in let arg_stmts, new_arg = translate_expr ctxt arg in
arg_stmts @ args_stmts, new_arg :: new_args) arg_stmts @ args_stmts, new_arg :: new_args)
([], []) args fields ([], [])
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in let args_stmts = List.rev args_stmts in
args_stmts, (A.EStruct (new_args, s_name), Expr.pos expr) args_stmts, (A.EStruct (new_args, name), Expr.pos expr)
| ETuple (_, None) -> failwith "Non-struct tuples cannot be compiled to scalc" | ETuple _ -> failwith "Tuples cannot be compiled to scalc"
| ETupleAccess (e1, num_field, Some s_name, _) -> | EStructAccess { e = e1; field; name } ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let field_name = e1_stmts, (A.EStructFieldAccess (new_e1, field, name), Expr.pos expr)
fst (List.nth (StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field) | ETupleAccess _ -> failwith "Non-struct tuples cannot be compiled to scalc"
in | EInj { e = e1; cons; name } ->
e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), Expr.pos expr)
| ETupleAccess (_, _, None, _) ->
failwith "Non-struct tuples cannot be compiled to scalc"
| EInj (e1, num_cons, e_name, _) ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let cons_name = e1_stmts, (A.EInj (new_e1, cons, name), Expr.pos expr)
fst (List.nth (EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) | EApp { f; args } ->
in
e1_stmts, (A.EInj (new_e1, cons_name, e_name), Expr.pos expr)
| EApp (f, args) ->
let f_stmts, new_f = translate_expr ctxt f in let f_stmts, new_f = translate_expr ctxt f in
let args_stmts, new_args = let args_stmts, new_args =
List.fold_left List.fold_left
@ -85,7 +86,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
args_stmts, (A.EArray new_args, Expr.pos expr) args_stmts, (A.EArray new_args, Expr.pos expr)
| EOp op -> [], (A.EOp op, Expr.pos expr) | EOp { op; _ } -> [], (A.EOp op, Expr.pos expr)
| ELit l -> [], (A.ELit l, Expr.pos expr) | ELit l -> [], (A.ELit l, Expr.pos expr)
| _ -> | _ ->
let tmp_var = let tmp_var =
@ -120,11 +121,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
(* Assertions are always encapsulated in a unit-typed let binding *) (* Assertions are always encapsulated in a unit-typed let binding *)
let e_stmts, new_e = translate_expr ctxt e in let e_stmts, new_e = translate_expr ctxt e in
e_stmts @ [A.SAssert (Marked.unmark new_e), Expr.pos block_expr] e_stmts @ [A.SAssert (Marked.unmark new_e), Expr.pos block_expr]
| EApp ((EAbs (binder, taus), binder_mark), args) -> | EApp { f = EAbs { binder; tys }, binder_mark; args } ->
(* This defines multiple local variables at the time *) (* This defines multiple local variables at the time *)
let binder_pos = Expr.mark_pos binder_mark in let binder_pos = Expr.mark_pos binder_mark in
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in
let ctxt = let ctxt =
{ {
ctxt with ctxt with
@ -167,10 +168,10 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
in in
let rest_of_block = translate_statements ctxt body in let rest_of_block = translate_statements ctxt body in
local_decls @ List.flatten def_blocks @ rest_of_block local_decls @ List.flatten def_blocks @ rest_of_block
| EAbs (binder, taus) -> | EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let binder_pos = Expr.pos block_expr in let binder_pos = Expr.pos block_expr in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in
let closure_name = let closure_name =
match ctxt.inside_definition_of with match ctxt.inside_definition_of with
| None -> A.LocalName.fresh (ctxt.context_name, Expr.pos block_expr) | None -> A.LocalName.fresh (ctxt.context_name, Expr.pos block_expr)
@ -203,13 +204,13 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
} ), } ),
binder_pos ); binder_pos );
] ]
| EMatch (e1, args, e_name) -> | EMatch { e = e1; cases; name } ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let new_args = let new_cases =
List.fold_left EnumConstructor.Map.fold
(fun new_args arg -> (fun _ arg new_args ->
match Marked.unmark arg with match Marked.unmark arg with
| EAbs (binder, _) -> | EAbs { binder; _ } ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
assert (Array.length vars = 1); assert (Array.length vars = 1);
let var = vars.(0) in let var = vars.(0) in
@ -223,20 +224,20 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
(new_arg, scalc_var) :: new_args (new_arg, scalc_var) :: new_args
| _ -> assert false | _ -> assert false
(* should not happen *)) (* should not happen *))
[] args cases []
in in
let new_args = List.rev new_args in let new_args = List.rev new_cases in
e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), Expr.pos block_expr] e1_stmts @ [A.SSwitch (new_e1, name, new_args), Expr.pos block_expr]
| EIfThenElse (cond, e_true, e_false) -> | EIfThenElse { cond; etrue; efalse } ->
let cond_stmts, s_cond = translate_expr ctxt cond in let cond_stmts, s_cond = translate_expr ctxt cond in
let s_e_true = translate_statements ctxt e_true in let s_e_true = translate_statements ctxt etrue in
let s_e_false = translate_statements ctxt e_false in let s_e_false = translate_statements ctxt efalse in
cond_stmts cond_stmts
@ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr] @ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr]
| ECatch (e_try, except, e_catch) -> | ECatch { body; exn; handler } ->
let s_e_try = translate_statements ctxt e_try in let s_e_try = translate_statements ctxt body in
let s_e_catch = translate_statements ctxt e_catch in let s_e_catch = translate_statements ctxt handler in
[A.STryExcept (s_e_try, except, s_e_catch), Expr.pos block_expr] [A.STryExcept (s_e_try, exn, s_e_catch), Expr.pos block_expr]
| ERaise except -> | ERaise except ->
(* Before raising the exception, we still give a dummy definition to the (* Before raising the exception, we still give a dummy definition to the
current variable so that tools like mypy don't complain. *) current variable so that tools like mypy don't complain. *)

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
@ -44,11 +44,12 @@ let rec format_expr
Print.punctuation "{" Print.punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, (struct_field, _)) ->
Format.fprintf fmt "%a%a%a%a %a" Print.punctuation "\"" Format.fprintf fmt "%a%a%a%a %a" Print.punctuation "\""
StructFieldName.format_t struct_field Print.punctuation "\"" StructField.format_t struct_field Print.punctuation "\""
Print.punctuation ":" format_expr e)) Print.punctuation ":" format_expr e))
(List.combine es (List.map fst (StructMap.find s decl_ctx.ctx_structs))) (List.combine es
(StructField.Map.bindings (StructName.Map.find s decl_ctx.ctx_structs)))
Print.punctuation "}" Print.punctuation "}"
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "["
@ -56,41 +57,31 @@ let rec format_expr
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_expr e)) (fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es Print.punctuation "]" es Print.punctuation "]"
| EStructFieldAccess (e1, field, s) -> | EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Print.punctuation "." Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Print.punctuation "."
Print.punctuation "\"" StructFieldName.format_t Print.punctuation "\"" StructField.format_t field Print.punctuation "\""
(fst | EInj (e, cons, _) ->
(List.find Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor cons
(fun (field', _) -> StructFieldName.compare field' field = 0)
(StructMap.find s decl_ctx.ctx_structs)))
Print.punctuation "\""
| EInj (e, case, enum) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor
(fst
(List.find
(fun (case', _) -> EnumConstructor.compare case' case = 0)
(EnumMap.find enum decl_ctx.ctx_enums)))
format_expr e format_expr e
| ELit l -> Print.lit fmt l | ELit l -> Print.lit fmt l
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> | EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.binop op format_with_parens Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.operator op
arg1 format_with_parens arg2 format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> | EApp ((EOp op, _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
Print.binop op format_with_parens arg2 Print.operator op format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> | EApp ((EOp (Log _), _), [arg1]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1 Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [arg1]) -> | EApp ((EOp op, _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.unop op format_with_parens arg1 Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.operator op format_with_parens
arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens) format_with_parens)
args args
| EOp (Ternop op) -> Format.fprintf fmt "%a" Print.ternop op | EOp op -> Format.fprintf fmt "%a" Print.operator op
| EOp (Binop op) -> Format.fprintf fmt "%a" Print.binop op
| EOp (Unop op) -> Format.fprintf fmt "%a" Print.unop op
let rec format_statement let rec format_statement
(decl_ctx : decl_ctx) (decl_ctx : decl_ctx)
@ -101,22 +92,22 @@ let rec format_statement
match Marked.unmark stmt with match Marked.unmark stmt with
| SInnerFuncDef (name, func) -> | SInnerFuncDef (name, func) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Print.keyword Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Print.keyword
"let" LocalName.format_t (Marked.unmark name) "let" format_local_name (Marked.unmark name)
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) -> (fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "(" Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "("
LocalName.format_t name Print.punctuation ":" (Print.typ decl_ctx) format_local_name name Print.punctuation ":" (Print.typ decl_ctx)
typ Print.punctuation ")")) typ Print.punctuation ")"))
func.func_params Print.punctuation "=" func.func_params Print.punctuation "="
(format_block decl_ctx ~debug) (format_block decl_ctx ~debug)
func.func_body func.func_body
| SLocalDecl (name, typ) -> | SLocalDecl (name, typ) ->
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Print.keyword "decl" Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Print.keyword "decl"
LocalName.format_t (Marked.unmark name) Print.punctuation ":" format_local_name (Marked.unmark name) Print.punctuation ":"
(Print.typ decl_ctx) typ (Print.typ decl_ctx) typ
| SLocalDef (name, naked_expr) -> | SLocalDef (name, naked_expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" LocalName.format_t Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_local_name
(Marked.unmark name) Print.punctuation "=" (Marked.unmark name) Print.punctuation "="
(format_expr decl_ctx ~debug) (format_expr decl_ctx ~debug)
naked_expr naked_expr
@ -156,10 +147,13 @@ let rec format_statement
(fun fmt ((case, _), (arm_block, payload_name)) -> (fun fmt ((case, _), (arm_block, payload_name)) ->
Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Print.punctuation Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Print.punctuation
"|" Print.enum_constructor case Print.punctuation ":" "|" Print.enum_constructor case Print.punctuation ":"
LocalName.format_t payload_name Print.punctuation "" format_local_name payload_name Print.punctuation ""
(format_block decl_ctx ~debug) (format_block decl_ctx ~debug)
arm_block)) arm_block))
(List.combine (EnumMap.find enum decl_ctx.ctx_enums) arms) (List.combine
(EnumConstructor.Map.bindings
(EnumName.Map.find enum decl_ctx.ctx_enums))
arms)
and format_block and format_block
(decl_ctx : decl_ctx) (decl_ctx : decl_ctx)
@ -183,8 +177,8 @@ let format_scope
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) -> (fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "(" Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "("
LocalName.format_t name Print.punctuation ":" (Print.typ decl_ctx) format_local_name name Print.punctuation ":" (Print.typ decl_ctx) typ
typ Print.punctuation ")")) Print.punctuation ")"))
body.scope_body_func.func_params Print.punctuation "=" body.scope_body_func.func_params Print.punctuation "="
(format_block decl_ctx ~debug) (format_block decl_ctx ~debug)
body.scope_body_func.func_body body.scope_body_func.func_body

View File

@ -15,21 +15,20 @@
the License. *) the License. *)
[@@@warning "-32-27"] [@@@warning "-32-27"]
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
open String_common
module Runtime = Runtime_ocaml.Runtime module Runtime = Runtime_ocaml.Runtime
module D = Dcalc.Ast module D = Dcalc.Ast
module L = Lcalc.Ast module L = Lcalc.Ast
let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit = let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
match Marked.unmark l with match Marked.unmark l with
| LBool true -> Format.fprintf fmt "True" | LBool true -> Format.pp_print_string fmt "True"
| LBool false -> Format.fprintf fmt "False" | LBool false -> Format.pp_print_string fmt "False"
| LInt i -> | LInt i ->
Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i) Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i)
| LUnit -> Format.fprintf fmt "Unit()" | LUnit -> Format.pp_print_string fmt "Unit()"
| LRat i -> Format.fprintf fmt "decimal_of_string(\"%a\")" Print.lit (LRat i) | LRat i -> Format.fprintf fmt "decimal_of_string(\"%a\")" Print.lit (LRat i)
| LMoney e -> | LMoney e ->
Format.fprintf fmt "money_of_cents_string(\"%s\")" Format.fprintf fmt "money_of_cents_string(\"%s\")"
@ -45,31 +44,60 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit = let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
match entry with match entry with
| VarDef _ -> Format.fprintf fmt ":=" | VarDef _ -> Format.pp_print_string fmt ":="
| BeginCall -> Format.fprintf fmt "" | BeginCall -> Format.pp_print_string fmt ""
| EndCall -> Format.fprintf fmt "%s" "" | EndCall -> Format.fprintf fmt "%s" ""
| PosRecordIfTrueBool -> Format.fprintf fmt "" | PosRecordIfTrueBool -> Format.pp_print_string fmt ""
let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit = let format_op
(type k)
(fmt : Format.formatter)
(op : (lcalc, k) operator Marked.pos) : unit =
match Marked.unmark op with match Marked.unmark op with
| Add _ | Concat -> Format.fprintf fmt "+" | Log (entry, infos) -> assert false
| Sub _ -> Format.fprintf fmt "-" | Minus_int | Minus_rat | Minus_mon | Minus_dur ->
| Mult _ -> Format.fprintf fmt "*" Format.pp_print_string fmt "-"
| Div KInt -> Format.fprintf fmt "//" (* Todo: use the names from [Operator.name] *)
| Div _ -> Format.fprintf fmt "/" | Not -> Format.pp_print_string fmt "not"
| And -> Format.fprintf fmt "and" | Length -> Format.pp_print_string fmt "list_length"
| Or -> Format.fprintf fmt "or" | ToRat_int -> Format.pp_print_string fmt "decimal_of_integer"
| Eq -> Format.fprintf fmt "==" | ToRat_mon -> Format.pp_print_string fmt "decimal_of_money"
| Neq | Xor -> Format.fprintf fmt "!=" | ToMoney_rat -> Format.pp_print_string fmt "money_of_decimal"
| Lt _ -> Format.fprintf fmt "<" | GetDay -> Format.pp_print_string fmt "day_of_month_of_date"
| Lte _ -> Format.fprintf fmt "<=" | GetMonth -> Format.pp_print_string fmt "month_number_of_date"
| Gt _ -> Format.fprintf fmt ">" | GetYear -> Format.pp_print_string fmt "year_of_date"
| Gte _ -> Format.fprintf fmt ">=" | FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
| Map -> Format.fprintf fmt "list_map" | LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
| Filter -> Format.fprintf fmt "list_filter" | Round_mon -> Format.pp_print_string fmt "money_round"
| Round_rat -> Format.pp_print_string fmt "decimal_round"
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit = | Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur | Concat
match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left" ->
Format.pp_print_string fmt "+"
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
| Sub_dur_dur ->
Format.pp_print_string fmt "-"
| Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int ->
Format.pp_print_string fmt "*"
| Div_int_int -> Format.pp_print_string fmt "//"
| Div_rat_rat | Div_mon_mon | Div_mon_rat -> Format.pp_print_string fmt "/"
| And -> Format.pp_print_string fmt "and"
| Or -> Format.pp_print_string fmt "or"
| Eq -> Format.pp_print_string fmt "=="
| Xor -> Format.pp_print_string fmt "!="
| Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur ->
Format.pp_print_string fmt "<"
| Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur ->
Format.pp_print_string fmt "<="
| Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur ->
Format.pp_print_string fmt ">"
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur ->
Format.pp_print_string fmt ">="
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
Format.pp_print_string fmt "=="
| Map -> Format.pp_print_string fmt "list_map"
| Reduce -> Format.pp_print_string fmt "list_reduce"
| Filter -> Format.pp_print_string fmt "list_filter"
| Fold -> Format.pp_print_string fmt "list_fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit = : unit =
@ -77,7 +105,7 @@ let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt info -> (fun fmt info ->
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info)) Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info))
uids uids
let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -90,23 +118,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info))) (Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids uids
let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit =
match Marked.unmark op with
| Minus _ -> Format.fprintf fmt "-"
| Not -> Format.fprintf fmt "not"
| Log (entry, infos) -> assert false (* should not happen *)
| Length -> Format.fprintf fmt "%s" "list_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
let avoid_keywords (s : string) : string = let avoid_keywords (s : string) : string =
if if
match s with match s with
@ -125,24 +136,26 @@ let avoid_keywords (s : string) : string =
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_camel_case (to_ascii (Format.asprintf "%a" StructName.format_t v)))) (String.to_camel_case
(String.to_ascii (Format.asprintf "%a" StructName.format_t v))))
let format_struct_field_name (fmt : Format.formatter) (v : StructFieldName.t) : let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
unit = =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" StructFieldName.format_t v))) (String.to_ascii (Format.asprintf "%a" StructField.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_camel_case (to_ascii (Format.asprintf "%a" EnumName.format_t v)))) (String.to_camel_case
(String.to_ascii (Format.asprintf "%a" EnumName.format_t v))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) : let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit = unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_ascii (Format.asprintf "%a" EnumConstructor.format_t v))) (String.to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
let typ_needs_parens (e : typ) : bool = let typ_needs_parens (e : typ) : bool =
match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false
@ -180,10 +193,10 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
s s
|> to_ascii |> String.to_ascii
|> to_snake_case |> String.to_snake_case
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") |> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_")
|> to_ascii |> String.to_ascii
|> avoid_keywords |> avoid_keywords
|> Format.fprintf fmt "%s" |> Format.fprintf fmt "%s"
@ -268,10 +281,11 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
Format.fprintf fmt "%a(%a)" format_struct_name s Format.fprintf fmt "%a(%a)" format_struct_name s
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, (struct_field, _)) ->
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_expression ctx) e)) (format_expression ctx) e))
(List.combine es (List.map fst (StructMap.find s ctx.ctx_structs))) (List.combine es
(StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs)))
| EStructFieldAccess (e1, field, _) -> | EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1 Format.fprintf fmt "%a.%a" (format_expression ctx) e1
format_struct_field_name field format_struct_field_name field
@ -296,21 +310,20 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e)) (fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
es es
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e) | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e)
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> | EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) Format.fprintf fmt "%a(%a,@ %a)" format_op (op, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg2 (format_expression ctx) arg1 (format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> | EApp ((EOp op, _), [arg1; arg2]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_op
(op, Pos.no_pos) (format_expression ctx) arg2 (op, Pos.no_pos) (format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg]) | EApp ((EApp ((EOp (Log (BeginCall, info)), _), [f]), _), [arg])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
(format_expression ctx) f (format_expression ctx) arg (format_expression ctx) f (format_expression ctx) arg
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag | EApp ((EOp (Log (VarDef tau, info)), _), [arg1]) when !Cli.trace_flag ->
->
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), pos), [arg1]) | EApp ((EOp (Log (PosRecordIfTrueBool, _)), pos), [arg1])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt Format.fprintf fmt
"log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \ "log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \
@ -318,16 +331,21 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) (format_expression ctx) arg1 (Pos.get_law_info pos) (format_expression ctx) arg1
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag -> | EApp ((EOp (Log (EndCall, info)), _), [arg1]) when !Cli.trace_flag ->
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop (Log _)), _), [arg1]) -> | EApp ((EOp (Log _), _), [arg1]) ->
Format.fprintf fmt "%a" (format_expression ctx) arg1 Format.fprintf fmt "%a" (format_expression ctx) arg1
| EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) -> | EApp ((EOp Not, _), [arg1]) ->
Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos) Format.fprintf fmt "%a %a" format_op (Not, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop op), _), [arg1]) -> | EApp
Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos) ((EOp ((Minus_int | Minus_rat | Minus_mon | Minus_dur) as op), _), [arg1])
->
Format.fprintf fmt "%a %a" format_op (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EOp op, _), [arg1]) ->
Format.fprintf fmt "%a(%a)" format_op (op, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EFunc x, pos), args) | EApp ((EFunc x, pos), args)
when Ast.TopLevelName.compare x Ast.handle_default = 0 when Ast.TopLevelName.compare x Ast.handle_default = 0
@ -348,9 +366,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx)) (format_expression ctx))
args args
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
let rec format_statement let rec format_statement
(ctx : decl_ctx) (ctx : decl_ctx)
@ -400,7 +416,7 @@ let rec format_statement
List.map2 List.map2
(fun (x, y) (cons, _) -> x, y, cons) (fun (x, y) (cons, _) -> x, y, cons)
cases cases
(EnumMap.find e_name ctx.ctx_enums) (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
in in
let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var
@ -442,6 +458,7 @@ let format_ctx
(fmt : Format.formatter) (fmt : Format.formatter)
(ctx : decl_ctx) : unit = (ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
let fields = StructField.Map.bindings struct_fields in
Format.fprintf fmt Format.fprintf fmt
"class %a:@\n\ "class %a:@\n\
\ def __init__(self, %a) -> None:@\n\ \ def __init__(self, %a) -> None:@\n\
@ -461,40 +478,41 @@ let format_ctx
struct_name struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun _fmt (struct_field, struct_field_type) -> (fun fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "%a: %a" format_struct_field_name struct_field Format.fprintf fmt "%a: %a" format_struct_field_name struct_field
format_typ struct_field_type)) format_typ struct_field_type))
struct_fields fields
(if List.length struct_fields = 0 then fun fmt _ -> (if StructField.Map.is_empty struct_fields then fun fmt _ ->
Format.fprintf fmt " pass" Format.fprintf fmt " pass"
else else
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (struct_field, _) -> (fun fmt (struct_field, _) ->
Format.fprintf fmt " self.%a = %a" format_struct_field_name Format.fprintf fmt " self.%a = %a" format_struct_field_name
struct_field format_struct_field_name struct_field)) struct_field format_struct_field_name struct_field))
struct_fields format_struct_name struct_name fields format_struct_name struct_name
(if List.length struct_fields > 0 then (if not (StructField.Map.is_empty struct_fields) then
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ")
(fun _fmt (struct_field, _) -> (fun fmt (struct_field, _) ->
Format.fprintf fmt "self.%a == other.%a" format_struct_field_name Format.fprintf fmt "self.%a == other.%a" format_struct_field_name
struct_field format_struct_field_name struct_field) struct_field format_struct_field_name struct_field)
else fun fmt _ -> Format.fprintf fmt "True") else fun fmt _ -> Format.fprintf fmt "True")
struct_fields format_struct_name struct_name fields format_struct_name struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",")
(fun _fmt (struct_field, _) -> (fun fmt (struct_field, _) ->
Format.fprintf fmt "%a={}" format_struct_field_name struct_field)) Format.fprintf fmt "%a={}" format_struct_field_name struct_field))
struct_fields fields
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun _fmt (struct_field, _) -> (fun fmt (struct_field, _) ->
Format.fprintf fmt "self.%a" format_struct_field_name struct_field)) Format.fprintf fmt "self.%a" format_struct_field_name struct_field))
struct_fields fields
in in
let format_enum_decl fmt (enum_name, enum_cons) = let format_enum_decl fmt (enum_name, enum_cons) =
if List.length enum_cons = 0 then failwith "no constructors in the enum" if EnumConstructor.Map.is_empty enum_cons then
failwith "no constructors in the enum"
else else
Format.fprintf fmt Format.fprintf fmt
"@[<hov 4>class %a_Code(Enum):@\n\ "@[<hov 4>class %a_Code(Enum):@\n\
@ -522,9 +540,11 @@ let format_ctx
format_enum_name enum_name format_enum_name enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (i, enum_cons, enum_cons_type) -> (fun fmt (i, enum_cons, enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i)) Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi (fun i (x, y) -> i, x, y) enum_cons) (List.mapi
(fun i (x, y) -> i, x, y)
(EnumConstructor.Map.bindings enum_cons))
format_enum_name enum_name format_enum_name enum_name format_enum_name format_enum_name enum_name format_enum_name enum_name format_enum_name
enum_name enum_name
in in
@ -540,8 +560,8 @@ let format_ctx
let scope_structs = let scope_structs =
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructMap.bindings (StructName.Map.bindings
(StructMap.filter (StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s)) (fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs)) ctx.ctx_structs))
in in
@ -550,10 +570,10 @@ let format_ctx
match struct_or_enum with match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s -> | Scopelang.Dependency.TVertex.Struct s ->
Format.fprintf fmt "%a@\n@\n" format_struct_decl Format.fprintf fmt "%a@\n@\n" format_struct_decl
(s, StructMap.find s ctx.ctx_structs) (s, StructName.Map.find s ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e -> | Scopelang.Dependency.TVertex.Enum e ->
Format.fprintf fmt "%a@\n@\n" format_enum_decl Format.fprintf fmt "%a@\n@\n" format_enum_decl
(e, EnumMap.find e ctx.ctx_enums)) (e, EnumName.Map.find e ctx.ctx_enums))
(type_ordering @ scope_structs) (type_ordering @ scope_structs)
let format_program let format_program

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
type location = scopelang glocation type location = scopelang glocation
@ -31,7 +31,7 @@ type 'm expr = (scopelang, 'm mark) gexpr
let rec locations_used (e : 'm expr) : LocationSet.t = let rec locations_used (e : 'm expr) : LocationSet.t =
match e with match e with
| ELocation l, pos -> LocationSet.singleton (l, Expr.mark_pos pos) | ELocation l, pos -> LocationSet.singleton (l, Expr.mark_pos pos)
| EAbs (binder, _), _ -> | EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
locations_used body locations_used body
| e -> | e ->
@ -39,23 +39,20 @@ let rec locations_used (e : 'm expr) : LocationSet.t =
(fun e -> LocationSet.union (locations_used e)) (fun e -> LocationSet.union (locations_used e))
e LocationSet.empty e LocationSet.empty
type io_input = NoInput | OnlyInput | Reentrant
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
type 'm rule = type 'm rule =
| Definition of location Marked.pos * typ * io * 'm expr | Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr
| Assertion of 'm expr | Assertion of 'm expr
| Call of ScopeName.t * SubScopeName.t * 'm mark | Call of ScopeName.t * SubScopeName.t * 'm mark
type 'm scope_decl = { type 'm scope_decl = {
scope_decl_name : ScopeName.t; scope_decl_name : ScopeName.t;
scope_sig : (typ * io) ScopeVarMap.t; scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t;
scope_decl_rules : 'm rule list; scope_decl_rules : 'm rule list;
scope_mark : 'm mark; scope_mark : 'm mark;
} }
type 'm program = { type 'm program = {
program_scopes : 'm scope_decl ScopeMap.t; program_scopes : 'm scope_decl ScopeName.Map.t;
program_ctx : decl_ctx; program_ctx : decl_ctx;
} }
@ -73,17 +70,17 @@ let type_rule decl_ctx env = function
let type_program (prg : 'm program) : typed program = let type_program (prg : 'm program) : typed program =
let typing_env = let typing_env =
ScopeMap.fold ScopeName.Map.fold
(fun scope_name scope_decl -> (fun scope_name scope_decl ->
let vars = ScopeVarMap.map fst scope_decl.scope_sig in let vars = ScopeVar.Map.map fst scope_decl.scope_sig in
Typing.Env.add_scope scope_name ~vars) Typing.Env.add_scope scope_name ~vars)
prg.program_scopes Typing.Env.empty prg.program_scopes Typing.Env.empty
in in
let program_scopes = let program_scopes =
ScopeMap.map ScopeName.Map.map
(fun scope_decl -> (fun scope_decl ->
let typing_env = let typing_env =
ScopeVarMap.fold ScopeVar.Map.fold
(fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env) (fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env)
scope_decl.scope_sig typing_env scope_decl.scope_sig typing_env
in in

View File

@ -16,7 +16,7 @@
(** Abstract syntax tree of the scope language *) (** Abstract syntax tree of the scope language *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
(** {1 Identifiers} *) (** {1 Identifiers} *)
@ -31,41 +31,20 @@ type 'm expr = (scopelang, 'm mark) gexpr
val locations_used : 'm expr -> LocationSet.t val locations_used : 'm expr -> LocationSet.t
(** This type characterizes the three levels of visibility for a given scope
variable with regards to the scope's input and possible redefinitions inside
the scope.. *)
type io_input =
| NoInput
(** For an internal variable defined only in the scope, and does not
appear in the input. *)
| OnlyInput
(** For variables that should not be redefined in the scope, because they
appear in the input. *)
| Reentrant
(** For variables defined in the scope that can also be redefined by the
caller as they appear in the input. *)
type io = {
io_output : bool Marked.pos;
(** [true] is present in the output of the scope. *)
io_input : io_input Marked.pos;
}
(** Characterization of the input/output status of a scope variable. *)
type 'm rule = type 'm rule =
| Definition of location Marked.pos * typ * io * 'm expr | Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr
| Assertion of 'm expr | Assertion of 'm expr
| Call of ScopeName.t * SubScopeName.t * 'm mark | Call of ScopeName.t * SubScopeName.t * 'm mark
type 'm scope_decl = { type 'm scope_decl = {
scope_decl_name : ScopeName.t; scope_decl_name : ScopeName.t;
scope_sig : (typ * io) ScopeVarMap.t; scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t;
scope_decl_rules : 'm rule list; scope_decl_rules : 'm rule list;
scope_mark : 'm mark; scope_mark : 'm mark;
} }
type 'm program = { type 'm program = {
program_scopes : 'm scope_decl ScopeMap.t; program_scopes : 'm scope_decl ScopeName.Map.t;
program_ctx : decl_ctx; program_ctx : decl_ctx;
} }

View File

@ -17,7 +17,7 @@
(** Graph representation of the dependencies between scopes in the Catala (** Graph representation of the dependencies between scopes in the Catala
program. Vertices are functions, x -> y if x is used in the definition of y. *) program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
module SVertex = ScopeName module SVertex = ScopeName
@ -41,13 +41,13 @@ module SSCC = Graph.Components.Make (SDependencies)
let rec expr_used_scopes e = let rec expr_used_scopes e =
let recurse_subterms e = let recurse_subterms e =
Expr.shallow_fold Expr.shallow_fold
(fun e -> ScopeMap.union (fun _ x _ -> Some x) (expr_used_scopes e)) (fun e -> ScopeName.Map.union (fun _ x _ -> Some x) (expr_used_scopes e))
e ScopeMap.empty e ScopeName.Map.empty
in in
match e with match e with
| (EScopeCall (scope, _), m) as e -> | (EScopeCall { scope; _ }, m) as e ->
ScopeMap.add scope (Expr.mark_pos m) (recurse_subterms e) ScopeName.Map.add scope (Expr.mark_pos m) (recurse_subterms e)
| EAbs (binder, _), _ -> | EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
expr_used_scopes body expr_used_scopes body
| e -> recurse_subterms e | e -> recurse_subterms e
@ -58,28 +58,28 @@ let rule_used_scopes = function
walking through all exprs again *) walking through all exprs again *)
expr_used_scopes e expr_used_scopes e
| Ast.Call (subscope, subindex, _) -> | Ast.Call (subscope, subindex, _) ->
ScopeMap.singleton subscope ScopeName.Map.singleton subscope
(Marked.get_mark (SubScopeName.get_info subindex)) (Marked.get_mark (SubScopeName.get_info subindex))
let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t =
let g = SDependencies.empty in let g = SDependencies.empty in
let g = let g =
ScopeMap.fold ScopeName.Map.fold
(fun v _ g -> SDependencies.add_vertex g v) (fun v _ g -> SDependencies.add_vertex g v)
prgm.program_scopes g prgm.program_scopes g
in in
ScopeMap.fold ScopeName.Map.fold
(fun scope_name scope g -> (fun scope_name scope g ->
List.fold_left List.fold_left
(fun g rule -> (fun g rule ->
let used_scopes = rule_used_scopes rule in let used_scopes = rule_used_scopes rule in
if ScopeMap.mem scope_name used_scopes then if ScopeName.Map.mem scope_name used_scopes then
Errors.raise_spanned_error Errors.raise_spanned_error
(Marked.get_mark (ScopeName.get_info scope.Ast.scope_decl_name)) (Marked.get_mark (ScopeName.get_info scope.Ast.scope_decl_name))
"The scope %a is calling into itself as a subscope, which is \ "The scope %a is calling into itself as a subscope, which is \
forbidden since Catala does not provide recursion" forbidden since Catala does not provide recursion"
ScopeName.format_t scope.Ast.scope_decl_name; ScopeName.format_t scope.Ast.scope_decl_name;
ScopeMap.fold ScopeName.Map.fold
(fun used_scope pos g -> (fun used_scope pos g ->
let edge = SDependencies.E.create used_scope pos scope_name in let edge = SDependencies.E.create used_scope pos scope_name in
SDependencies.add_edge_e g edge) SDependencies.add_edge_e g edge)
@ -190,10 +190,10 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
= =
let g = TDependencies.empty in let g = TDependencies.empty in
let g = let g =
StructMap.fold StructName.Map.fold
(fun s fields g -> (fun s fields g ->
List.fold_left StructField.Map.fold
(fun g (_, typ) -> (fun _ typ g ->
let def = TVertex.Struct s in let def = TVertex.Struct s in
let g = TDependencies.add_vertex g def in let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in let used = get_structs_or_enums_in_type typ in
@ -210,14 +210,14 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
in in
TDependencies.add_edge_e g edge) TDependencies.add_edge_e g edge)
used g) used g)
g fields) fields g)
structs g structs g
in in
let g = let g =
EnumMap.fold EnumName.Map.fold
(fun e cases g -> (fun e cases g ->
List.fold_left EnumConstructor.Map.fold
(fun g (_, typ) -> (fun _ typ g ->
let def = TVertex.Enum e in let def = TVertex.Enum e in
let g = TDependencies.add_vertex g def in let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in let used = get_structs_or_enums_in_type typ in
@ -234,7 +234,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
in in
TDependencies.add_edge_e g edge) TDependencies.add_edge_e g edge)
used g) used g)
g cases) cases g)
enums g enums g
in in
g g

View File

@ -17,7 +17,7 @@
(** Graph representation of the dependencies between scopes in the Catala (** Graph representation of the dependencies between scopes in the Catala
program. Vertices are functions, x -> y if x is used in the definition of y. *) program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
(** {1 Scope dependencies} *) (** {1 Scope dependencies} *)

View File

@ -1,7 +1,7 @@
(library (library
(name scopelang) (name scopelang)
(public_name catala.scopelang) (public_name catala.scopelang)
(libraries utils dcalc ocamlgraph) (libraries catala_utils ocamlgraph desugared)
(flags (flags
(:standard -short-paths))) (:standard -short-paths)))

View File

@ -0,0 +1,730 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
open Catala_utils
open Shared_ast
(** {1 Expression translation}*)
type target_scope_vars =
| WholeVar of ScopeVar.t
| States of (StateName.t * ScopeVar.t) list
type ctx = {
decl_ctx : decl_ctx;
scope_var_mapping : target_scope_vars ScopeVar.Map.t;
var_mapping : (Desugared.Ast.expr, untyped Ast.expr Var.t) Var.Map.t;
}
let tag_with_log_entry
(e : untyped Ast.expr boxed)
(l : log_entry)
(markings : Uid.MarkedString.info list) : untyped Ast.expr boxed =
Expr.eapp
(Expr.eop (Log (l, markings)) [TAny, Expr.pos e] (Marked.get_mark e))
[e] (Marked.get_mark e)
let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) :
untyped Ast.expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
(* When referring to a subscope variable in an expression, we are referring
to the output, hence we take the last state. *)
let new_s_var =
match ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States states ->
Marked.same_mark_as (snd (List.hd (List.rev states))) s_var
in
Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m
| ELocation (DesugaredScopeVar (s_var, None)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States _ -> failwith "should not happen"))
m
| ELocation (DesugaredScopeVar (s_var, Some state)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar _ -> failwith "should not happen"
| States states -> Marked.same_mark_as (List.assoc state states) s_var))
m
| EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m
| EStruct { name; fields } ->
Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m
| EDStructAccess { name_opt = None; _ } ->
(* Note: this could only happen if disambiguation was disabled. If we want
to support it, we should still allow this case when the field has only
one possible matching structure *)
Errors.raise_spanned_error (Expr.mark_pos m)
"Ambiguous structure field access"
| EDStructAccess { e; field; name_opt = Some name } ->
let e' = translate_expr ctx e in
let field =
try
StructName.Map.find name
(IdentName.Map.find field ctx.decl_ctx.ctx_struct_fields)
with Not_found ->
(* Should not happen after disambiguation *)
Errors.raise_spanned_error (Expr.mark_pos m)
"Field %s does not belong to structure %a" field StructName.format_t
name
in
Expr.estructaccess e' field name m
| EInj { e; cons; name } -> Expr.einj (translate_expr ctx e) cons name m
| EMatch { e; name; cases } ->
Expr.ematch (translate_expr ctx e) name
(EnumConstructor.Map.map (translate_expr ctx) cases)
m
| EScopeCall { scope; args } ->
Expr.escopecall scope
(ScopeVar.Map.fold
(fun v e args' ->
let v' =
match ScopeVar.Map.find v ctx.scope_var_mapping with
| WholeVar v' -> v'
| States ((_, v') :: _) ->
(* When there are multiple states, the input is always the first
one *)
v'
| States [] -> assert false
in
ScopeVar.Map.add v' (translate_expr ctx e) args')
args ScopeVar.Map.empty)
m
| ELit
(( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ) as l) ->
Expr.elit l m
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in
let ctx =
List.fold_left2
(fun ctx var new_var ->
{ ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping })
ctx (Array.to_list vars) (Array.to_list new_vars)
in
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m
| EApp { f = EOp { op; tys }, m1; args } ->
let args = List.map (translate_expr ctx) args in
Operator.kind_dispatch op
~monomorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
~polymorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
~overloaded:(fun op ->
match
Operator.resolve_overload ctx.decl_ctx
(Marked.mark (Expr.pos e) op)
tys
with
| op, `Straight -> Expr.eapp (Expr.eop op tys m1) args m
| op, `Reversed ->
Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m)
| EOp _ -> assert false (* Only allowed within [EApp] *)
| EApp { f; args } ->
Expr.eapp (translate_expr ctx f) (List.map (translate_expr ctx) args) m
| EDefault { excepts; just; cons } ->
Expr.edefault
(List.map (translate_expr ctx) excepts)
(translate_expr ctx just) (translate_expr ctx cons) m
| EIfThenElse { cond; etrue; efalse } ->
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
(translate_expr ctx efalse)
m
| EArray args -> Expr.earray (List.map (translate_expr ctx) args) m
| EErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m
(** {1 Rule tree construction} *)
(** Intermediate representation for the exception tree of rules for a particular
scope definition. *)
type rule_tree =
| Leaf of Desugared.Ast.rule list
(** Rules defining a base case piecewise. List is non-empty. *)
| Node of rule_tree list * Desugared.Ast.rule list
(** [Node (exceptions, base_case)] is a list of exceptions to a non-empty
list of rules defining a base case piecewise. *)
(** Transforms a flat list of rules into a tree, taking into account the
priorities declared between rules *)
let def_map_to_tree
(def_info : Desugared.Ast.ScopeDef.t)
(def : Desugared.Ast.rule RuleName.Map.t) : rule_tree list =
let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in
Desugared.Dependency.check_for_exception_cycle exc_graph;
(* we start by the base cases: they are the vertices which have no
successors *)
let base_cases =
Desugared.Dependency.ExceptionsDependencies.fold_vertex
(fun v base_cases ->
if
Desugared.Dependency.ExceptionsDependencies.out_degree exc_graph v = 0
then v :: base_cases
else base_cases)
exc_graph []
in
let rec build_tree (base_cases : RuleName.Set.t) : rule_tree =
let exceptions =
Desugared.Dependency.ExceptionsDependencies.pred exc_graph base_cases
in
let base_case_as_rule_list =
List.map
(fun r -> RuleName.Map.find r def)
(RuleName.Set.elements base_cases)
in
match exceptions with
| [] -> Leaf base_case_as_rule_list
| _ -> Node (List.map build_tree exceptions, base_case_as_rule_list)
in
List.map build_tree base_cases
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault}
expression in the scope language. The [~toplevel] parameter is used to know
when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr
~(toplevel : bool)
~(is_reentrant_var : bool)
(ctx : ctx)
(def_pos : Pos.t)
(is_func : Desugared.Ast.expr Var.t option)
(tree : rule_tree) : untyped Ast.expr boxed =
let emark = Untyped { pos = def_pos } in
let exceptions, base_rules =
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
in
(* because each rule has its own variable parameter and we want to convert the
whole rule tree into a function, we need to perform some alpha-renaming of
all the expressions *)
let substitute_parameter
(e : Desugared.Ast.expr boxed)
(rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed =
match is_func, rule.Desugared.Ast.rule_parameter with
| Some new_param, Some (old_param, _) ->
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
Marked.mark (Marked.get_mark e)
@@ Bindlib.box_apply2
(fun binder new_param -> Bindlib.subst binder new_param)
binder
(Bindlib.box_var new_param)
| None, None -> e
| _ -> assert false
(* should not happen *)
in
let ctx =
match is_func with
| None -> ctx
| Some new_param -> (
match Var.Map.find_opt new_param ctx.var_mapping with
| None ->
let new_param_scope = Var.make (Bindlib.name_of new_param) in
{
ctx with
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
}
| Some _ ->
(* We only create a mapping if none exists because [rule_tree_to_expr]
is called recursively on the exceptions of the tree and we don't want
to create a new Scopelang variable for the parameter at each tree
level. *)
ctx)
in
let base_just_list =
List.map
(fun rule -> substitute_parameter rule.Desugared.Ast.rule_just rule)
base_rules
in
let base_cons_list =
List.map
(fun rule -> substitute_parameter rule.Desugared.Ast.rule_cons rule)
base_rules
in
let translate_and_unbox_list (list : Desugared.Ast.expr boxed list) :
untyped Ast.expr boxed list =
List.map
(fun e ->
(* There are two levels of boxing here, the outermost is introduced by
the [translate_expr] function for which all of the bindings should
have been closed by now, so we can safely unbox. *)
translate_expr ctx (Expr.unbox e))
list
in
let default_containing_base_cases =
Expr.make_default
(List.map2
(fun base_just base_cons ->
Expr.make_default []
(* Here we insert the logging command that records when a decision
is taken for the value of a variable. *)
(tag_with_log_entry base_just PosRecordIfTrueBool [])
base_cons emark)
(translate_and_unbox_list base_just_list)
(translate_and_unbox_list base_cons_list))
(Expr.elit (LBool false) emark)
(Expr.elit LEmptyError emark)
emark
in
let exceptions =
List.map
(rule_tree_to_expr ~toplevel:false ~is_reentrant_var ctx def_pos is_func)
exceptions
in
let default =
Expr.make_default exceptions
(Expr.elit (LBool true) emark)
default_containing_base_cases emark
in
match is_func, (List.hd base_rules).Desugared.Ast.rule_parameter with
| None, None -> default
| Some new_param, Some (_, typ) ->
if toplevel then
(* When we're creating a function from multiple defaults, we must check
that the result returned by the function is not empty, unless we're
dealing with a context variable which is reentrant (either in the
caller or callee). In this case the ErrorOnEmpty will be added later in
the scopelang->dcalc translation. *)
let default =
if is_reentrant_var then default else Expr.eerroronempty default emark
in
Expr.make_abs
[| Var.Map.find new_param ctx.var_mapping |]
default [typ] def_pos
else default
| _ -> (* should not happen *) assert false
(** {1 AST translation} *)
(** Translates a definition inside a scope, the resulting expression should be
an {!constructor: Dcalc.EDefault} *)
let translate_def
(ctx : ctx)
(def_info : Desugared.Ast.ScopeDef.t)
(def : Desugared.Ast.rule RuleName.Map.t)
(typ : typ)
(io : Desugared.Ast.io)
~(is_cond : bool)
~(is_subscope_var : bool) : untyped Ast.expr boxed =
(* Here, we have to transform this list of rules into a default tree. *)
let is_def_func =
match Marked.unmark typ with TArrow (_, _) -> true | _ -> false
in
let is_rule_func _ (r : Desugared.Ast.rule) : bool =
Option.is_some r.Desugared.Ast.rule_parameter
in
let all_rules_func = RuleName.Map.for_all is_rule_func def in
let all_rules_not_func =
RuleName.Map.for_all (fun n r -> not (is_rule_func n r)) def
in
let is_def_func_param_typ : typ option =
if is_def_func && all_rules_func then
match Marked.unmark typ with
| TArrow (t_param, _) -> Some t_param
| _ ->
Errors.raise_spanned_error (Marked.get_mark typ)
"The definitions of %a are function but it doesn't have a function \
type"
Desugared.Ast.ScopeDef.format_t def_info
else if (not is_def_func) && all_rules_not_func then None
else
let spans =
List.map
(fun (_, r) ->
( Some "This definition is a function:",
Expr.pos r.Desugared.Ast.rule_cons ))
(RuleName.Map.bindings (RuleName.Map.filter is_rule_func def))
@ List.map
(fun (_, r) ->
( Some "This definition is not a function:",
Expr.pos r.Desugared.Ast.rule_cons ))
(RuleName.Map.bindings
(RuleName.Map.filter (fun n r -> not (is_rule_func n r)) def))
in
Errors.raise_multispanned_error spans
"some definitions of the same variable are functions while others \
aren't"
in
let top_list = def_map_to_tree def_info def in
let is_input =
match Marked.unmark io.Desugared.Ast.io_input with
| OnlyInput -> true
| _ -> false
in
let is_reentrant =
match Marked.unmark io.Desugared.Ast.io_input with
| Reentrant -> true
| _ -> false
in
let top_value =
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
(* We add the bottom [false] value for conditions, only for the scope
where the condition is declared. Except when the variable is an input,
where we want the [false] to be added at each caller parent scope. *)
Some
(Desugared.Ast.always_false_rule
(Desugared.Ast.ScopeDef.get_position def_info)
is_def_func_param_typ)
else None
in
if
RuleName.Map.cardinal def = 0
&& is_subscope_var
(* Here we have a special case for the empty definitions. Indeed, we could
use the code for the regular case below that would create a convoluted
default always returning empty error, and this would be correct. But it
gets more complicated with functions. Indeed, if we create an empty
definition for a subscope argument whose type is a function, we get
something like [fun () -> (fun real_param -> < ... >)] that is passed as
an argument to the subscope. The sub-scope de-thunks but the de-thunking
does not return empty error, signalling there is not reentrant variable,
because functions are values! So the subscope does not see that there is
not reentrant variable and does not pick its internal definition instead.
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
test case exercising that subtlety.
To avoid this complication we special case here and put an empty error
for all subscope variables that are not defined. It covers the subtlety
with functions described above but also conditions with the false default
value. *)
&& not (is_cond && is_input)
(* However, this special case suffers from an exception: when a condition is
defined as an OnlyInput to a subscope, since the [false] default value
will not be provided by the calee scope, it has to be placed in the
caller. *)
then
let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
let empty_error = Expr.elit LEmptyError m in
match is_def_func_param_typ with
| Some ty ->
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
| _ -> empty_error
else
rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
(Desugared.Ast.ScopeDef.get_position def_info)
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
(match top_list, top_value with
| [], None ->
(* In this case, there are no rules to define the expression and no
default value so we put an empty rule. *)
Leaf
[Desugared.Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ]
| [], Some top_value ->
(* In this case, there are no rules to define the expression but a
default value so we put it. *)
Leaf [top_value]
| _, Some top_value ->
(* When there are rules + a default value, we put the rules as
exceptions to the default value *)
Node (top_list, [top_value])
| [top_tree], None -> top_tree
| _, None ->
Node
( top_list,
[
Desugared.Ast.empty_rule (Marked.get_mark typ)
is_def_func_param_typ;
] ))
let translate_rule ctx (scope : Desugared.Ast.scope) = function
| Desugared.Dependency.Vertex.Var (var, state) -> (
let scope_def =
Desugared.Ast.ScopeDefMap.find
(Desugared.Ast.ScopeDef.Var (var, state))
scope.scope_defs
in
let var_def = scope_def.scope_def_rules in
let var_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with
| OnlyInput when not (RuleName.Map.is_empty var_def) ->
(* If the variable is tagged as input, then it shall not be redefined. *)
Errors.raise_multispanned_error
((Some "Incriminated variable:", Marked.get_mark (ScopeVar.get_info var))
:: List.map
(fun (rule, _) ->
( Some "Incriminated variable definition:",
Marked.get_mark (RuleName.get_info rule) ))
(RuleName.Map.bindings var_def))
"It is impossible to give a definition to a scope variable tagged as \
input."
| OnlyInput -> []
(* we do not provide any definition for an input-only variable *)
| _ ->
let expr_def =
translate_def ctx
(Desugared.Ast.ScopeDef.Var (var, state))
var_def var_typ scope_def.Desugared.Ast.scope_def_io ~is_cond
~is_subscope_var:false
in
let scope_var =
match ScopeVar.Map.find var ctx.scope_var_mapping, state with
| WholeVar v, None -> v
| States states, Some state -> List.assoc state states
| _ -> failwith "should not happen"
in
[
Ast.Definition
( ( ScopelangScopeVar
(scope_var, Marked.get_mark (ScopeVar.get_info scope_var)),
Marked.get_mark (ScopeVar.get_info scope_var) ),
var_typ,
scope_def.Desugared.Ast.scope_def_io,
Expr.unbox expr_def );
])
| Desugared.Dependency.Vertex.SubScope sub_scope_index ->
(* Before calling the sub_scope, we need to include all the re-definitions
of subscope parameters*)
let sub_scope =
SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes
in
let sub_scope_vars_redefs_candidates =
Desugared.Ast.ScopeDefMap.filter
(fun def_key scope_def ->
match def_key with
| Desugared.Ast.ScopeDef.Var _ -> false
| Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) ->
sub_scope_index = sub_scope_index'
(* We exclude subscope variables that have 0 re-definitions and are
not visible in the input of the subscope *)
&& not
((match
Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input
with
| Desugared.Ast.NoInput -> true
| _ -> false)
&& RuleName.Map.is_empty scope_def.scope_def_rules))
scope.scope_defs
in
let sub_scope_vars_redefs =
Desugared.Ast.ScopeDefMap.mapi
(fun def_key scope_def ->
let def = scope_def.Desugared.Ast.scope_def_rules in
let def_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match def_key with
| Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *)
| Desugared.Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) ->
(* This definition redefines a variable of the correct subscope. But
we have to check that this redefinition is allowed with respect
to the io parameters of that subscope variable. *)
(match
Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input
with
| Desugared.Ast.NoInput ->
Errors.raise_multispanned_error
(( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) )
:: ( Some "Incriminated variable:",
Marked.get_mark (ScopeVar.get_info sub_scope_var) )
:: List.map
(fun (rule, _) ->
( Some "Incriminated subscope variable definition:",
Marked.get_mark (RuleName.get_info rule) ))
(RuleName.Map.bindings def))
"It is impossible to give a definition to a subscope variable \
not tagged as input or context."
| OnlyInput when RuleName.Map.is_empty def && not is_cond ->
(* If the subscope variable is tagged as input, then it shall be
defined. *)
Errors.raise_multispanned_error
[
( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) );
Some "Incriminated variable:", pos;
]
"This subscope variable is a mandatory input but no definition \
was provided."
| _ -> ());
(* Now that all is good, we can proceed with translating this
redefinition to a proper Scopelang term. *)
let expr_def =
translate_def ctx def_key def def_typ
scope_def.Desugared.Ast.scope_def_io ~is_cond
~is_subscope_var:true
in
let subscop_real_name =
SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes
in
let var_pos = Desugared.Ast.ScopeDef.get_position def_key in
Ast.Definition
( ( SubScopeVar
( subscop_real_name,
(sub_scope_index, var_pos),
match
ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping
with
| WholeVar v -> v, var_pos
| States states ->
(* When defining a sub-scope variable, we always define
its first state in the sub-scope. *)
snd (List.hd states), var_pos ),
var_pos ),
def_typ,
scope_def.Desugared.Ast.scope_def_io,
Expr.unbox expr_def ))
sub_scope_vars_redefs_candidates
in
let sub_scope_vars_redefs =
List.map snd (Desugared.Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
in
sub_scope_vars_redefs
@ [
Ast.Call
( sub_scope,
sub_scope_index,
Untyped
{ pos = Marked.get_mark (SubScopeName.get_info sub_scope_index) }
);
]
(** Translates a scope *)
let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) :
untyped Ast.scope_decl =
let scope_dependencies =
Desugared.Dependency.build_scope_dependencies scope
in
Desugared.Dependency.check_for_cycle scope scope_dependencies;
let scope_ordering =
Desugared.Dependency.correct_computation_ordering scope_dependencies
in
let scope_decl_rules =
List.flatten (List.map (translate_rule ctx scope) scope_ordering)
in
(* Then, after having computed all the scopes variables, we add the
assertions. TODO: the assertions should be interleaved with the
definitions! *)
let scope_decl_rules =
scope_decl_rules
@ List.map
(fun e ->
let scope_e = translate_expr ctx (Expr.unbox e) in
Ast.Assertion (Expr.unbox scope_e))
scope.Desugared.Ast.scope_assertions
in
let scope_sig =
ScopeVar.Map.fold
(fun var (states : Desugared.Ast.var_or_states) acc ->
match states with
| WholeVar ->
let scope_def =
Desugared.Ast.ScopeDefMap.find
(Desugared.Ast.ScopeDef.Var (var, None))
scope.scope_defs
in
let typ = scope_def.scope_def_typ in
ScopeVar.Map.add
(match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar v -> v
| States _ -> failwith "should not happen")
(typ, scope_def.scope_def_io)
acc
| States states ->
(* What happens in the case of variables with multiple states is
interesting. We need to create as many Var entries in the scope
signature as there are states. *)
List.fold_left
(fun acc (state : StateName.t) ->
let scope_def =
Desugared.Ast.ScopeDefMap.find
(Desugared.Ast.ScopeDef.Var (var, Some state))
scope.scope_defs
in
ScopeVar.Map.add
(match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar _ -> failwith "should not happen"
| States states' -> List.assoc state states')
(scope_def.scope_def_typ, scope_def.scope_def_io)
acc)
acc states)
scope.scope_vars ScopeVar.Map.empty
in
let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in
{
Ast.scope_decl_name = scope.scope_uid;
Ast.scope_decl_rules;
Ast.scope_sig;
Ast.scope_mark = Untyped { pos };
}
(** {1 API} *)
let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program =
(* First we give mappings to all the locations between Desugared and This
involves creating a new Scopelang scope variable for every state of a
Desugared variable. *)
let ctx =
(* Todo: since we rename all scope vars at this point, it would be better to
have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *)
ScopeName.Map.fold
(fun _scope scope_decl ctx ->
ScopeVar.Map.fold
(fun scope_var (states : Desugared.Ast.var_or_states) ctx ->
let var_name, var_pos = ScopeVar.get_info scope_var in
let new_var =
match states with
| Desugared.Ast.WholeVar ->
WholeVar (ScopeVar.fresh (var_name, var_pos))
| States states ->
let var_prefix = var_name ^ "_" in
let state_var state =
ScopeVar.fresh
(Marked.map_under_mark (( ^ ) var_prefix)
(StateName.get_info state))
in
States (List.map (fun state -> state, state_var state) states)
in
{
ctx with
scope_var_mapping =
ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping;
})
scope_decl.Desugared.Ast.scope_vars ctx)
pgrm.Desugared.Ast.program_scopes
{
scope_var_mapping = ScopeVar.Map.empty;
var_mapping = Var.Map.empty;
decl_ctx = pgrm.program_ctx;
}
in
let ctx_scopes =
ScopeName.Map.map
(fun out_str ->
let out_struct_fields =
ScopeVar.Map.fold
(fun var fld out_map ->
let var' =
match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar v -> v
| States l -> snd (List.hd (List.rev l))
in
ScopeVar.Map.add var' fld out_map)
out_str.out_struct_fields ScopeVar.Map.empty
in
{ out_str with out_struct_fields })
pgrm.Desugared.Ast.program_ctx.ctx_scopes
in
{
Ast.program_scopes =
ScopeName.Map.map (translate_scope ctx) pgrm.program_scopes;
program_ctx = { pgrm.program_ctx with ctx_scopes };
}

View File

@ -16,4 +16,4 @@
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
val translate_program : Ast.program -> Shared_ast.untyped Scopelang.Ast.program val translate_program : Desugared.Ast.program -> Shared_ast.untyped Ast.program

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Shared_ast open Shared_ast
open Ast open Ast
@ -22,21 +22,22 @@ let struc
ctx ctx
(fmt : Format.formatter) (fmt : Format.formatter)
(name : StructName.t) (name : StructName.t)
(fields : (StructFieldName.t * typ) list) : unit = (fields : typ StructField.Map.t) : unit =
Format.fprintf fmt "%a %a %a %a@\n@[<hov 2> %a@]@\n%a" Print.keyword "struct" Format.fprintf fmt "%a %a %a %a@\n@[<hov 2> %a@]@\n%a" Print.keyword "struct"
StructName.format_t name Print.punctuation "=" Print.punctuation "{" StructName.format_t name Print.punctuation "=" Print.punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (field_name, typ) -> (fun fmt (field_name, typ) ->
Format.fprintf fmt "%a%a %a" StructFieldName.format_t field_name Format.fprintf fmt "%a%a %a" StructField.format_t field_name
Print.punctuation ":" (Print.typ ctx) typ)) Print.punctuation ":" (Print.typ ctx) typ))
fields Print.punctuation "}" (StructField.Map.bindings fields)
Print.punctuation "}"
let enum let enum
ctx ctx
(fmt : Format.formatter) (fmt : Format.formatter)
(name : EnumName.t) (name : EnumName.t)
(cases : (EnumConstructor.t * typ) list) : unit = (cases : typ EnumConstructor.Map.t) : unit =
Format.fprintf fmt "%a %a %a @\n@[<hov 2> %a@]" Print.keyword "enum" Format.fprintf fmt "%a %a %a @\n@[<hov 2> %a@]" Print.keyword "enum"
EnumName.format_t name Print.punctuation "=" EnumName.format_t name Print.punctuation "="
(Format.pp_print_list (Format.pp_print_list
@ -45,7 +46,7 @@ let enum
Format.fprintf fmt "%a %a%a %a" Print.punctuation "|" Format.fprintf fmt "%a %a%a %a" Print.punctuation "|"
EnumConstructor.format_t field_name Print.punctuation ":" EnumConstructor.format_t field_name Print.punctuation ":"
(Print.typ ctx) typ)) (Print.typ ctx) typ))
cases (EnumConstructor.Map.bindings cases)
let scope ?(debug = false) ctx fmt (name, decl) = let scope ?(debug = false) ctx fmt (name, decl) =
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]"
@ -55,16 +56,16 @@ let scope ?(debug = false) ctx fmt (name, decl) =
Format.fprintf fmt "%a%a%a %a%a%a%a%a" Print.punctuation "(" Format.fprintf fmt "%a%a%a %a%a%a%a%a" Print.punctuation "("
ScopeVar.format_t scope_var Print.punctuation ":" (Print.typ ctx) typ ScopeVar.format_t scope_var Print.punctuation ":" (Print.typ ctx) typ
Print.punctuation "|" Print.keyword Print.punctuation "|" Print.keyword
(match Marked.unmark vis.io_input with (match Marked.unmark vis.Desugared.Ast.io_input with
| NoInput -> "internal" | NoInput -> "internal"
| OnlyInput -> "input" | OnlyInput -> "input"
| Reentrant -> "context") | Reentrant -> "context")
(if Marked.unmark vis.io_output then fun fmt () -> (if Marked.unmark vis.Desugared.Ast.io_output then fun fmt () ->
Format.fprintf fmt "%a@,%a" Print.punctuation "|" Print.keyword Format.fprintf fmt "%a@,%a" Print.punctuation "|" Print.keyword
"output" "output"
else fun fmt () -> Format.fprintf fmt "@<0>") else fun fmt () -> Format.fprintf fmt "@<0>")
() Print.punctuation ")")) () Print.punctuation ")"))
(ScopeVarMap.bindings decl.scope_sig) (ScopeVar.Map.bindings decl.scope_sig)
Print.punctuation "=" Print.punctuation "="
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";") ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";")
@ -80,11 +81,11 @@ let scope ?(debug = false) ctx fmt (name, decl) =
| ScopelangScopeVar v -> ( | ScopelangScopeVar v -> (
match match
Marked.unmark Marked.unmark
(snd (ScopeVarMap.find (Marked.unmark v) decl.scope_sig)) (snd (ScopeVar.Map.find (Marked.unmark v) decl.scope_sig))
.io_input .io_input
with with
| Reentrant -> | Reentrant ->
Format.fprintf fmt "%a@ %a" Print.operator Format.fprintf fmt "%a@ %a" Print.op_style
"reentrant or by default" (Print.expr ~debug ctx) e "reentrant or by default" (Print.expr ~debug ctx) e
| _ -> Format.fprintf fmt "%a" (Print.expr ~debug ctx) e)) | _ -> Format.fprintf fmt "%a" (Print.expr ~debug ctx) e))
e e
@ -105,16 +106,16 @@ let program ?(debug : bool = false) (fmt : Format.formatter) (p : 'm program) :
Format.pp_print_cut fmt () Format.pp_print_cut fmt ()
in in
Format.pp_open_vbox fmt 0; Format.pp_open_vbox fmt 0;
StructMap.iter StructName.Map.iter
(fun n s -> (fun n s ->
struc ctx fmt n s; struc ctx fmt n s;
pp_sep fmt ()) pp_sep fmt ())
ctx.ctx_structs; ctx.ctx_structs;
EnumMap.iter EnumName.Map.iter
(fun n e -> (fun n e ->
enum ctx fmt n e; enum ctx fmt n e;
pp_sep fmt ()) pp_sep fmt ())
ctx.ctx_enums; ctx.ctx_enums;
Format.pp_print_list ~pp_sep (scope ~debug ctx) fmt Format.pp_print_list ~pp_sep (scope ~debug ctx) fmt
(ScopeMap.bindings p.program_scopes); (ScopeName.Map.bindings p.program_scopes);
Format.pp_close_box fmt () Format.pp_close_box fmt ()

View File

@ -20,59 +20,47 @@
(* Doesn't define values, so OK to have without an mli *) (* Doesn't define values, so OK to have without an mli *)
open Utils open Catala_utils
module Runtime = Runtime_ocaml.Runtime module Runtime = Runtime_ocaml.Runtime
module ScopeName = Uid.Gen ()
module StructName = Uid.Gen ()
module StructField = Uid.Gen ()
module EnumName = Uid.Gen ()
module EnumConstructor = Uid.Gen ()
module ScopeName : Uid.Id with type info = Uid.MarkedString.info = (** Only used by surface *)
Uid.Make (Uid.MarkedString) ()
module ScopeSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName) module RuleName = Uid.Gen ()
module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName) module LabelName = Uid.Gen ()
module StructName : Uid.Id with type info = Uid.MarkedString.info = (** Used for unresolved structs/maps in desugared *)
Uid.Make (Uid.MarkedString) ()
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info = module IdentName = String
Uid.Make (Uid.MarkedString) ()
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
(** Only used by desugared/scopelang *) (** Only used by desugared/scopelang *)
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info = module ScopeVar = Uid.Gen ()
Uid.Make (Uid.MarkedString) () module SubScopeName = Uid.Gen ()
module StateName = Uid.Gen ()
module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar)
module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar)
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module SubScopeNameSet : Set.S with type elt = SubScopeName.t =
Set.Make (SubScopeName)
module SubScopeMap : Map.S with type key = SubScopeName.t =
Map.Make (SubScopeName)
module StructFieldMap : Map.S with type key = StructFieldName.t =
Map.Make (StructFieldName)
module EnumConstructorMap : Map.S with type key = EnumConstructor.t =
Map.Make (EnumConstructor)
module StateName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
(** {1 Abstract syntax tree} *) (** {1 Abstract syntax tree} *)
(** Define a common base type for the expressions in most passes of the compiler *)
type desugared = [ `Desugared ]
(** {2 Phantom types used to select relevant cases on the generic AST}
we instantiate them with a polymorphic variant to take advantage of
sub-typing. The values aren't actually used. *)
type scopelang = [ `Scopelang ]
type dcalc = [ `Dcalc ]
type lcalc = [ `Lcalc ]
type 'a any = [< desugared | scopelang | dcalc | lcalc ] as 'a
(** ['a any] is 'a, but adds the constraint that it should be restricted to
valid AST kinds *)
(** {2 Types} *) (** {2 Types} *)
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
@ -94,33 +82,6 @@ and naked_typ =
type date = Runtime.date type date = Runtime.date
type duration = Runtime.duration type duration = Runtime.duration
type op_kind =
| KInt
| KRat
| KMoney
| KDate
| KDuration (** All ops don't have a KDate and KDuration. *)
type ternop = Fold
type binop =
| And
| Or
| Xor
| Add of op_kind
| Sub of op_kind
| Mult of op_kind
| Div of op_kind
| Lt of op_kind
| Lte of op_kind
| Gt of op_kind
| Gte of op_kind
| Eq
| Neq
| Map
| Concat
| Filter
type log_entry = type log_entry =
| VarDef of naked_typ | VarDef of naked_typ
(** During code generation, we need to know the type of the variable being (** During code generation, we need to know the type of the variable being
@ -129,35 +90,140 @@ type log_entry =
| EndCall | EndCall
| PosRecordIfTrueBool | PosRecordIfTrueBool
type unop = module Op = struct
| Not (** Classification of operators on how they should be typed *)
| Minus of op_kind
| Log of log_entry * Uid.MarkedString.info list
| Length
| IntToRat
| MoneyToRat
| RatToMoney
| GetDay
| GetMonth
| GetYear
| FirstDayOfMonth
| LastDayOfMonth
| RoundMoney
| RoundDecimal
type operator = Ternop of ternop | Binop of binop | Unop of unop type monomorphic =
| Monomorphic (** Operands and return types of the operator are fixed *)
type polymorphic =
| Polymorphic
(** The operator is truly polymorphic: it's the same runtime function
that may work on multiple types. We require that resolving the
argument types from right to left trivially resolves all type
variables declared in the operator type. *)
type overloaded =
| Overloaded
(** The operator is ambiguous and requires the types of its arguments to
be known before it can be typed, using a pre-defined table *)
type resolved =
| Resolved (** Explicit monomorphic versions of the overloaded operators *)
(** Classification of operators. This could be inlined in the definition of
[t] but is more concise this way *)
type (_, _) kind =
| Monomorphic : ('a any, monomorphic) kind
| Polymorphic : ('a any, polymorphic) kind
| Overloaded : ([< desugared ], overloaded) kind
| Resolved : ([< scopelang | dcalc | lcalc ], resolved) kind
type (_, _) t =
(* unary *)
(* * monomorphic *)
| Not : ('a any, monomorphic) t
| GetDay : ('a any, monomorphic) t
| GetMonth : ('a any, monomorphic) t
| GetYear : ('a any, monomorphic) t
| FirstDayOfMonth : ('a any, monomorphic) t
| LastDayOfMonth : ('a any, monomorphic) t
(* * polymorphic *)
| Length : ('a any, polymorphic) t
| Log : log_entry * Uid.MarkedString.info list -> ('a any, polymorphic) t
(* * overloaded *)
| Minus : (desugared, overloaded) t
| Minus_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| ToRat : (desugared, overloaded) t
| ToRat_int : ([< scopelang | dcalc | lcalc ], resolved) t
| ToRat_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| ToMoney : (desugared, overloaded) t
| ToMoney_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Round : (desugared, overloaded) t
| Round_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Round_mon : ([< scopelang | dcalc | lcalc ], resolved) t
(* binary *)
(* * monomorphic *)
| And : ('a any, monomorphic) t
| Or : ('a any, monomorphic) t
| Xor : ('a any, monomorphic) t
(* * polymorphic *)
| Eq : ('a any, polymorphic) t
| Map : ('a any, polymorphic) t
| Concat : ('a any, polymorphic) t
| Filter : ('a any, polymorphic) t
| Reduce : ('a any, polymorphic) t
(* * overloaded *)
| Add : (desugared, overloaded) t
| Add_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub : (desugared, overloaded) t
| Sub_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult : (desugared, overloaded) t
| Mult_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_dur_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Div : (desugared, overloaded) t
| Div_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt : (desugared, overloaded) t
| Lt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte : (desugared, overloaded) t
| Lte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt : (desugared, overloaded) t
| Gt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte : (desugared, overloaded) t
| Gte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
(* Todo: Eq is not an overload at the moment, but it should be one. The
trick is that it needs generation of specific code for arrays, every
struct and enum: operators [Eq_structs of StructName.t], etc. *)
| Eq_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
(* ternary *)
(* * polymorphic *)
| Fold : ('a any, polymorphic) t
end
type ('a, 'k) operator = ('a any, 'k) Op.t
type except = ConflictError | EmptyError | NoValueProvided | Crash type except = ConflictError | EmptyError | NoValueProvided | Crash
(** {2 Generic expressions} *) (** {2 Generic expressions} *)
(** Define a common base type for the expressions in most passes of the compiler *) (** Define a common base type for the expressions in most passes of the compiler *)
type desugared = [ `Desugared ]
type scopelang = [ `Scopelang ]
type dcalc = [ `Dcalc ]
type lcalc = [ `Lcalc ]
type 'a any = [< desugared | scopelang | dcalc | lcalc ] as 'a
(** Literals are the same throughout compilation except for the [LEmptyError] (** Literals are the same throughout compilation except for the [LEmptyError]
case which is eliminated midway through. *) case which is eliminated midway through. *)
type 'a glit = type 'a glit =
@ -192,65 +258,98 @@ type ('a, 't) gexpr = (('a, 't) naked_gexpr, 't) Marked.t
- To write a function that handles cases from different ASTs, explicit the - To write a function that handles cases from different ASTs, explicit the
type variables: [fun (type a) (x: a naked_gexpr) -> ...] type variables: [fun (type a) (x: a naked_gexpr) -> ...]
- For recursive functions, you may need to additionally explicit the - For recursive functions, you may need to additionally explicit the
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...] *) generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...]
- Always think of using the pre-defined map/fold functions in [Expr] rather
than completely defining your recursion manually. *)
and ('a, 't) naked_gexpr = and ('a, 't) naked_gexpr =
(* Constructors common to all ASTs *) (* Constructors common to all ASTs *)
| ELit : 'a glit -> ('a any, 't) naked_gexpr | ELit : 'a glit -> ('a any, 't) naked_gexpr
| EApp : ('a, 't) gexpr * ('a, 't) gexpr list -> ('a any, 't) naked_gexpr | EApp : {
| EOp : operator -> ('a any, 't) naked_gexpr f : ('a, 't) gexpr;
args : ('a, 't) gexpr list;
}
-> ('a any, 't) naked_gexpr
| EOp : { op : ('a, _) operator; tys : typ list } -> ('a any, 't) naked_gexpr
| EArray : ('a, 't) gexpr list -> ('a any, 't) naked_gexpr | EArray : ('a, 't) gexpr list -> ('a any, 't) naked_gexpr
| EVar : ('a, 't) naked_gexpr Bindlib.var -> ('a any, 't) naked_gexpr | EVar : ('a, 't) naked_gexpr Bindlib.var -> ('a any, 't) naked_gexpr
| EAbs : | EAbs : {
(('a, 't) naked_gexpr, ('a, 't) gexpr) Bindlib.mbinder * typ list binder : (('a, 't) naked_gexpr, ('a, 't) gexpr) Bindlib.mbinder;
tys : typ list;
}
-> ('a any, 't) naked_gexpr -> ('a any, 't) naked_gexpr
| EIfThenElse : | EIfThenElse : {
('a, 't) gexpr * ('a, 't) gexpr * ('a, 't) gexpr cond : ('a, 't) gexpr;
etrue : ('a, 't) gexpr;
efalse : ('a, 't) gexpr;
}
-> ('a any, 't) naked_gexpr
| EStruct : {
name : StructName.t;
fields : ('a, 't) gexpr StructField.Map.t;
}
-> ('a any, 't) naked_gexpr
| EInj : {
name : EnumName.t;
e : ('a, 't) gexpr;
cons : EnumConstructor.t;
}
-> ('a any, 't) naked_gexpr
| EMatch : {
name : EnumName.t;
e : ('a, 't) gexpr;
cases : ('a, 't) gexpr EnumConstructor.Map.t;
}
-> ('a any, 't) naked_gexpr -> ('a any, 't) naked_gexpr
(* Early stages *) (* Early stages *)
| ELocation : | ELocation :
'a glocation 'a glocation
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr -> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EStruct : | EScopeCall : {
StructName.t * ('a, 't) gexpr StructFieldMap.t scope : ScopeName.t;
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr args : ('a, 't) gexpr ScopeVar.Map.t;
| EStructAccess : }
('a, 't) gexpr * StructFieldName.t * StructName.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EEnumInj :
('a, 't) gexpr * EnumConstructor.t * EnumName.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EMatchS :
('a, 't) gexpr * EnumName.t * ('a, 't) gexpr EnumConstructorMap.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EScopeCall :
ScopeName.t * ('a, 't) gexpr ScopeVarMap.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr -> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EDStructAccess : {
name_opt : StructName.t option;
e : ('a, 't) gexpr;
field : IdentName.t;
}
-> ((desugared as 'a), 't) naked_gexpr
(** [desugared] has ambiguous struct fields *)
| EStructAccess : {
name : StructName.t;
e : ('a, 't) gexpr;
field : StructField.t;
}
-> (([< scopelang | dcalc | lcalc ] as 'a), 't) naked_gexpr
(** Resolved struct/enums, after [desugared] *)
(* Lambda-like *) (* Lambda-like *)
| ETuple :
('a, 't) gexpr list * StructName.t option
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| ETupleAccess :
('a, 't) gexpr * int * StructName.t option * typ list
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| EInj :
('a, 't) gexpr * int * EnumName.t * typ list
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| EMatch :
('a, 't) gexpr * ('a, 't) gexpr list * EnumName.t
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| EAssert : ('a, 't) gexpr -> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr | EAssert : ('a, 't) gexpr -> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
(* Default terms *) (* Default terms *)
| EDefault : | EDefault : {
('a, 't) gexpr list * ('a, 't) gexpr * ('a, 't) gexpr excepts : ('a, 't) gexpr list;
just : ('a, 't) gexpr;
cons : ('a, 't) gexpr;
}
-> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr -> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr
| ErrorOnEmpty : | EErrorOnEmpty :
('a, 't) gexpr ('a, 't) gexpr
-> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr -> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr
(* Lambda calculus with exceptions *) (* Lambda calculus with exceptions *)
| ETuple : ('a, 't) gexpr list -> ((lcalc as 'a), 't) naked_gexpr
| ETupleAccess : {
e : ('a, 't) gexpr;
index : int;
size : int;
}
-> ((lcalc as 'a), 't) naked_gexpr
| ERaise : except -> ((lcalc as 'a), 't) naked_gexpr | ERaise : except -> ((lcalc as 'a), 't) naked_gexpr
| ECatch : | ECatch : {
('a, 't) gexpr * except * ('a, 't) gexpr body : ('a, 't) gexpr;
exn : except;
handler : ('a, 't) gexpr;
}
-> ((lcalc as 'a), 't) naked_gexpr -> ((lcalc as 'a), 't) naked_gexpr
type ('a, 't) boxed_gexpr = (('a, 't) naked_gexpr Bindlib.box, 't) Marked.t type ('a, 't) boxed_gexpr = (('a, 't) naked_gexpr Bindlib.box, 't) Marked.t
@ -276,9 +375,9 @@ type typed = { pos : Pos.t; ty : typ }
(** The generic type of AST markings. Using a GADT allows functions to be (** The generic type of AST markings. Using a GADT allows functions to be
polymorphic in the marking, but still do transformations on types when polymorphic in the marking, but still do transformations on types when
appropriate. Expected to fill the ['t] parameter of [naked_gexpr] and appropriate. Expected to fill the ['t] parameter of [gexpr] and [gexpr] (a
[gexpr] (a ['t] annotation different from this type is used in the middle of ['t] annotation different from this type is used in the middle of the typing
the typing processing, but all visible ASTs should otherwise use this. *) processing, but all visible ASTs should otherwise use this. *)
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
(** Useful for errors and printing, for example *) (** Useful for errors and printing, for example *)
@ -287,11 +386,10 @@ type any_expr = AnyExpr : (_, _ mark) gexpr -> any_expr
(** {2 Higher-level program structure} *) (** {2 Higher-level program structure} *)
(** Constructs scopes and programs on top of expressions. The ['e] type (** Constructs scopes and programs on top of expressions. The ['e] type
parameter throughout is expected to match instances of the [naked_gexpr] parameter throughout is expected to match instances of the [gexpr] type
type defined above. Markings are constrained to the [mark] GADT defined defined above. Markings are constrained to the [mark] GADT defined above.
above. Note that this structure is at the moment only relevant for [dcalc] Note that this structure is at the moment only relevant for [dcalc] and
and [lcalc], as [scopelang] has its own scope structure, as the name [lcalc], as [scopelang] has its own scope structure, as the name implies. *)
implies. *)
(** This kind annotation signals that the let-binding respects a structural (** This kind annotation signals that the let-binding respects a structural
invariant. These invariants concern the shape of the expression in the invariant. These invariants concern the shape of the expression in the
@ -350,14 +448,20 @@ and 'e scopes =
| ScopeDef of 'e scope_def | ScopeDef of 'e scope_def
constraint 'e = (_ any, _ mark) gexpr constraint 'e = (_ any, _ mark) gexpr
type struct_ctx = (StructFieldName.t * typ) list StructMap.t type struct_ctx = typ StructField.Map.t StructName.Map.t
type enum_ctx = (EnumConstructor.t * typ) list EnumMap.t type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t
type scope_out_struct = {
out_struct_name : StructName.t;
out_struct_fields : StructField.t ScopeVar.Map.t;
}
type decl_ctx = { type decl_ctx = {
ctx_enums : enum_ctx; ctx_enums : enum_ctx;
ctx_structs : struct_ctx; ctx_structs : struct_ctx;
ctx_scopes : StructName.t ScopeMap.t; ctx_struct_fields : StructField.t StructName.Map.t IdentName.Map.t;
(** The output structure type of every scope *) (** needed for disambiguation (desugared -> scope) *)
ctx_scopes : scope_out_struct ScopeName.Map.t;
} }
type 'e program = { decl_ctx : decl_ctx; scopes : 'e scopes } type 'e program = { decl_ctx : decl_ctx; scopes : 'e scopes }

View File

@ -3,4 +3,4 @@
(public_name catala.shared_ast) (public_name catala.shared_ast)
(flags (flags
(:standard -short-paths)) (:standard -short-paths))
(libraries bindlib unionFind utils catala.runtime_ocaml)) (libraries bindlib unionFind catala_utils catala.runtime_ocaml))

View File

@ -15,7 +15,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Definitions open Definitions
(** Functions handling the types of [shared_ast] *) (** Functions handling the types of [shared_ast] *)
@ -57,15 +57,15 @@ module Box = struct
fun em -> fun em ->
B.box_apply (fun e -> Marked.mark (Marked.get_mark em) e) (Marked.unmark em) B.box_apply (fun e -> Marked.mark (Marked.get_mark em) e) (Marked.unmark em)
module LiftStruct = Bindlib.Lift (StructFieldMap) module LiftStruct = Bindlib.Lift (StructField.Map)
let lift_struct = LiftStruct.lift_box let lift_struct = LiftStruct.lift_box
module LiftEnum = Bindlib.Lift (EnumConstructorMap) module LiftEnum = Bindlib.Lift (EnumConstructor.Map)
let lift_enum = LiftEnum.lift_box let lift_enum = LiftEnum.lift_box
module LiftScopeVars = Bindlib.Lift (ScopeVarMap) module LiftScopeVars = Bindlib.Lift (ScopeVar.Map)
let lift_scope_vars = LiftScopeVars.lift_box let lift_scope_vars = LiftScopeVars.lift_box
end end
@ -76,61 +76,64 @@ let subst binder vars =
Bindlib.msubst binder (Array.of_list (List.map Marked.unmark vars)) Bindlib.msubst binder (Array.of_list (List.map Marked.unmark vars))
let evar v mark = Marked.mark mark (Bindlib.box_var v) let evar v mark = Marked.mark mark (Bindlib.box_var v)
let etuple args s = Box.appn args @@ fun args -> ETuple (args, s) let etuple args = Box.appn args @@ fun args -> ETuple args
let etupleaccess e1 i s typs = let etupleaccess e index size =
Box.app1 e1 @@ fun e1 -> ETupleAccess (e1, i, s, typs) assert (index < size);
Box.app1 e @@ fun e -> ETupleAccess { e; index; size }
let einj e1 i e_name typs = Box.app1 e1 @@ fun e1 -> EInj (e1, i, e_name, typs)
let ematch arg arms e_name =
Box.app1n arg arms @@ fun arg arms -> EMatch (arg, arms, e_name)
let earray args = Box.appn args @@ fun args -> EArray args let earray args = Box.appn args @@ fun args -> EArray args
let elit l mark = Marked.mark mark (Bindlib.box (ELit l)) let elit l mark = Marked.mark mark (Bindlib.box (ELit l))
let eabs binder typs mark = let eabs binder tys mark =
Bindlib.box_apply (fun binder -> EAbs (binder, typs)) binder, mark Bindlib.box_apply (fun binder -> EAbs { binder; tys }) binder, mark
let eapp e1 args = Box.app1n e1 args @@ fun e1 args -> EApp (e1, args) let eapp f args = Box.app1n f args @@ fun f args -> EApp { f; args }
let eassert e1 = Box.app1 e1 @@ fun e1 -> EAssert e1 let eassert e1 = Box.app1 e1 @@ fun e1 -> EAssert e1
let eop op = Box.app0 @@ EOp op let eop op tys = Box.app0 @@ EOp { op; tys }
let edefault excepts just cons = let edefault excepts just cons =
Box.app2n just cons excepts Box.app2n just cons excepts
@@ fun just cons excepts -> EDefault (excepts, just, cons) @@ fun just cons excepts -> EDefault { excepts; just; cons }
let eifthenelse e1 e2 e3 = let eifthenelse cond etrue efalse =
Box.app3 e1 e2 e3 @@ fun e1 e2 e3 -> EIfThenElse (e1, e2, e3) Box.app3 cond etrue efalse
@@ fun cond etrue efalse -> EIfThenElse { cond; etrue; efalse }
let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> ErrorOnEmpty e1 let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> EErrorOnEmpty e1
let eraise e1 = Box.app0 @@ ERaise e1 let eraise e1 = Box.app0 @@ ERaise e1
let ecatch e1 exn e2 = Box.app2 e1 e2 @@ fun e1 e2 -> ECatch (e1, exn, e2)
let ecatch body exn handler =
Box.app2 body handler @@ fun body handler -> ECatch { body; exn; handler }
let elocation loc = Box.app0 @@ ELocation loc let elocation loc = Box.app0 @@ ELocation loc
let estruct name (fields : ('a, 't) boxed_gexpr StructFieldMap.t) mark = let estruct name (fields : ('a, 't) boxed_gexpr StructField.Map.t) mark =
Marked.mark mark Marked.mark mark
@@ Bindlib.box_apply @@ Bindlib.box_apply
(fun fields -> EStruct (name, fields)) (fun fields -> EStruct { name; fields })
(Box.lift_struct (StructFieldMap.map Box.lift fields)) (Box.lift_struct (StructField.Map.map Box.lift fields))
let estructaccess e1 field struc = let edstructaccess e field name_opt =
Box.app1 e1 @@ fun e1 -> EStructAccess (e1, field, struc) Box.app1 e @@ fun e -> EDStructAccess { name_opt; e; field }
let eenuminj e1 cons enum = Box.app1 e1 @@ fun e1 -> EEnumInj (e1, cons, enum) let estructaccess e field name =
Box.app1 e @@ fun e -> EStructAccess { name; e; field }
let ematchs e1 enum cases mark = let einj e cons name = Box.app1 e @@ fun e -> EInj { name; e; cons }
let ematch e name cases mark =
Marked.mark mark Marked.mark mark
@@ Bindlib.box_apply2 @@ Bindlib.box_apply2
(fun e1 cases -> EMatchS (e1, enum, cases)) (fun e cases -> EMatch { name; e; cases })
(Box.lift e1) (Box.lift e)
(Box.lift_enum (EnumConstructorMap.map Box.lift cases)) (Box.lift_enum (EnumConstructor.Map.map Box.lift cases))
let escopecall scope_name fields mark = let escopecall scope args mark =
Marked.mark mark Marked.mark mark
@@ Bindlib.box_apply @@ Bindlib.box_apply
(fun fields -> EScopeCall (scope_name, fields)) (fun args -> EScopeCall { scope; args })
(Box.lift_scope_vars (ScopeVarMap.map Box.lift fields)) (Box.lift_scope_vars (ScopeVar.Map.map Box.lift args))
(* - Manipulation of marks - *) (* - Manipulation of marks - *)
@ -203,49 +206,46 @@ let maybe_ty (type m) ?(typ = TAny) (m : m mark) : typ =
(* shallow map *) (* shallow map *)
let map let map
(type a) (type a)
(ctx : 'ctx) ~(f : (a, 'm1) gexpr -> (a, 'm2) boxed_gexpr)
~(f : 'ctx -> (a, 'm1) gexpr -> (a, 'm2) boxed_gexpr)
(e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : (a, 'm2) boxed_gexpr = (e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : (a, 'm2) boxed_gexpr =
let m = Marked.get_mark e in let m = Marked.get_mark e in
match Marked.unmark e with match Marked.unmark e with
| ELit l -> elit l m | ELit l -> elit l m
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m | EApp { f = e1; args } -> eapp (f e1) (List.map f args) m
| EOp op -> eop op m | EOp { op; tys } -> eop op tys m
| EArray args -> earray (List.map (f ctx) args) m | EArray args -> earray (List.map f args) m
| EVar v -> evar (Var.translate v) m | EVar v -> evar (Var.translate v) m
| EAbs (binder, typs) -> | EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let body = f ctx body in let body = f body in
let binder = bind (Array.map Var.translate vars) body in let binder = bind (Array.map Var.translate vars) body in
eabs binder typs m eabs binder tys m
| EIfThenElse (e1, e2, e3) -> | EIfThenElse { cond; etrue; efalse } ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m eifthenelse (f cond) (f etrue) (f efalse) m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m | ETuple args -> etuple (List.map f args) m
| ETupleAccess (e1, n, s_name, typs) -> | ETupleAccess { e; index; size } -> etupleaccess (f e) index size m
etupleaccess ((f ctx) e1) n s_name typs m | EInj { e; name; cons } -> einj (f e) cons name m
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m | EAssert e1 -> eassert (f e1) m
| EMatch (arg, arms, e_name) -> | EDefault { excepts; just; cons } ->
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m edefault (List.map f excepts) (f just) (f cons) m
| EAssert e1 -> eassert ((f ctx) e1) m | EErrorOnEmpty e1 -> eerroronempty (f e1) m
| EDefault (excepts, just, cons) -> | ECatch { body; exn; handler } -> ecatch (f body) exn (f handler) m
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) m
| ERaise exn -> eraise exn m | ERaise exn -> eraise exn m
| ELocation loc -> elocation loc m | ELocation loc -> elocation loc m
| EStruct (name, fields) -> | EStruct { name; fields } ->
let fields = StructFieldMap.map (f ctx) fields in let fields = StructField.Map.map f fields in
estruct name fields m estruct name fields m
| EStructAccess (e1, field, struc) -> estructaccess (f ctx e1) field struc m | EDStructAccess { e; field; name_opt } ->
| EEnumInj (e1, cons, enum) -> eenuminj (f ctx e1) cons enum m edstructaccess (f e) field name_opt m
| EMatchS (e1, enum, cases) -> | EStructAccess { e; field; name } -> estructaccess (f e) field name m
let cases = EnumConstructorMap.map (f ctx) cases in | EMatch { e; name; cases } ->
ematchs (f ctx e1) enum cases m let cases = EnumConstructor.Map.map f cases in
| EScopeCall (scope_name, fields) -> ematch (f e) name cases m
let fields = ScopeVarMap.map (f ctx) fields in | EScopeCall { scope; args } ->
escopecall scope_name fields m let fields = ScopeVar.Map.map f args in
escopecall scope fields m
let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e) let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e)
let map_marks ~f e = let map_marks ~f e =
map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
@ -260,31 +260,130 @@ let shallow_fold
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
match Marked.unmark e with match Marked.unmark e with
| ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ -> acc | ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ -> acc
| EApp (e1, args) -> acc |> f e1 |> lfold args | EApp { f = e; args } -> acc |> f e |> lfold args
| EArray args -> acc |> lfold args | EArray args -> acc |> lfold args
| EAbs _ -> acc | EAbs _ -> acc
| EIfThenElse (e1, e2, e3) -> acc |> f e1 |> f e2 |> f e3 | EIfThenElse { cond; etrue; efalse } -> acc |> f cond |> f etrue |> f efalse
| ETuple (args, _) -> acc |> lfold args | ETuple args -> acc |> lfold args
| ETupleAccess (e1, _, _, _) -> acc |> f e1 | ETupleAccess { e; _ } -> acc |> f e
| EInj (e1, _, _, _) -> acc |> f e1 | EInj { e; _ } -> acc |> f e
| EMatch (arg, arms, _) -> acc |> f arg |> lfold arms | EAssert e -> acc |> f e
| EAssert e1 -> acc |> f e1 | EDefault { excepts; just; cons } -> acc |> lfold excepts |> f just |> f cons
| EDefault (excepts, just, cons) -> acc |> lfold excepts |> f just |> f cons | EErrorOnEmpty e -> acc |> f e
| ErrorOnEmpty e1 -> acc |> f e1 | ECatch { body; handler; _ } -> acc |> f body |> f handler
| ECatch (e1, _, e2) -> acc |> f e1 |> f e2 | EStruct { fields; _ } -> acc |> StructField.Map.fold (fun _ -> f) fields
| EStruct (_, fields) -> acc |> StructFieldMap.fold (fun _ -> f) fields | EDStructAccess { e; _ } -> acc |> f e
| EStructAccess (e1, _, _) -> acc |> f e1 | EStructAccess { e; _ } -> acc |> f e
| EEnumInj (e1, _, _) -> acc |> f e1 | EMatch { e; cases; _ } ->
| EMatchS (e1, _, cases) -> acc |> f e |> EnumConstructor.Map.fold (fun _ -> f) cases
acc |> f e1 |> EnumConstructorMap.fold (fun _ -> f) cases | EScopeCall { args; _ } -> acc |> ScopeVar.Map.fold (fun _ -> f) args
| EScopeCall (_, fields) -> acc |> ScopeVarMap.fold (fun _ -> f) fields
(* Like [map], but also allows to gather a result bottom-up. *)
let map_gather
(type a)
~(acc : 'acc)
~(join : 'acc -> 'acc -> 'acc)
~(f : (a, 'm1) gexpr -> 'acc * (a, 'm2) boxed_gexpr)
(e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : 'acc * (a, 'm2) boxed_gexpr =
let m = Marked.get_mark e in
let lfoldmap es =
let acc, r_es =
List.fold_left
(fun (acc, es) e ->
let acc1, e = f e in
join acc acc1, e :: es)
(acc, []) es
in
acc, List.rev r_es
in
match Marked.unmark e with
| ELit l -> acc, elit l m
| EApp { f = e1; args } ->
let acc1, f = f e1 in
let acc2, args = lfoldmap args in
join acc1 acc2, eapp f args m
| EOp { op; tys } -> acc, eop op tys m
| EArray args ->
let acc, args = lfoldmap args in
acc, earray args m
| EVar v -> acc, evar (Var.translate v) m
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let acc, body = f body in
let binder = bind (Array.map Var.translate vars) body in
acc, eabs binder tys m
| EIfThenElse { cond; etrue; efalse } ->
let acc1, cond = f cond in
let acc2, etrue = f etrue in
let acc3, efalse = f efalse in
join (join acc1 acc2) acc3, eifthenelse cond etrue efalse m
| ETuple args ->
let acc, args = lfoldmap args in
acc, etuple args m
| ETupleAccess { e; index; size } ->
let acc, e = f e in
acc, etupleaccess e index size m
| EInj { e; name; cons } ->
let acc, e = f e in
acc, einj e cons name m
| EAssert e ->
let acc, e = f e in
acc, eassert e m
| EDefault { excepts; just; cons } ->
let acc1, excepts = lfoldmap excepts in
let acc2, just = f just in
let acc3, cons = f cons in
join (join acc1 acc2) acc3, edefault excepts just cons m
| EErrorOnEmpty e ->
let acc, e = f e in
acc, eerroronempty e m
| ECatch { body; exn; handler } ->
let acc1, body = f body in
let acc2, handler = f handler in
join acc1 acc2, ecatch body exn handler m
| ERaise exn -> acc, eraise exn m
| ELocation loc -> acc, elocation loc m
| EStruct { name; fields } ->
let acc, fields =
StructField.Map.fold
(fun cons e (acc, fields) ->
let acc1, e = f e in
join acc acc1, StructField.Map.add cons e fields)
fields
(acc, StructField.Map.empty)
in
acc, estruct name fields m
| EDStructAccess { e; field; name_opt } ->
let acc, e = f e in
acc, edstructaccess e field name_opt m
| EStructAccess { e; field; name } ->
let acc, e = f e in
acc, estructaccess e field name m
| EMatch { e; name; cases } ->
let acc, e = f e in
let acc, cases =
EnumConstructor.Map.fold
(fun cons e (acc, cases) ->
let acc1, e = f e in
join acc acc1, EnumConstructor.Map.add cons e cases)
cases
(acc, EnumConstructor.Map.empty)
in
acc, ematch e name cases m
| EScopeCall { scope; args } ->
let acc, args =
ScopeVar.Map.fold
(fun var e (acc, args) ->
let acc1, e = f e in
join acc acc1, ScopeVar.Map.add var e args)
args (acc, ScopeVar.Map.empty)
in
acc, escopecall scope args m
(* - *) (* - *)
(** See [Bindlib.box_term] documentation for why we are doing that. *) (** See [Bindlib.box_term] documentation for why we are doing that. *)
let rebox e = let rec rebox e = map ~f:rebox e
let rec id_t () e = map () ~f:id_t e in
id_t () e
let box e = Marked.same_mark_as (Bindlib.box (Marked.unmark e)) e let box e = Marked.same_mark_as (Bindlib.box (Marked.unmark e)) e
let unbox (e, m) = Bindlib.unbox e, m let unbox (e, m) = Bindlib.unbox e, m
@ -297,99 +396,36 @@ let is_value (type a) (e : (a, _) gexpr) =
| ELit _ | EAbs _ | EOp _ | ERaise _ -> true | ELit _ | EAbs _ | EOp _ | ERaise _ -> true
| _ -> false | _ -> false
let equal_tlit l1 l2 = l1 = l2
let compare_tlit l1 l2 = Stdlib.compare l1 l2
let rec equal_typ ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> equal_typ_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> equal_typ t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> equal_typ t1 t2 && equal_typ t1' t2'
| TArray t1, TArray t2 -> equal_typ t1 t2
| TAny, TAny -> true
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
| TArray _ | TAny ),
_ ) ->
false
and equal_typ_list tys1 tys2 =
try List.for_all2 equal_typ tys1 tys2 with Invalid_argument _ -> false
(* Similar to [equal_typ], but allows TAny holes *)
let rec unifiable ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TAny, _ | _, TAny -> true
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> unifiable t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
| TArray t1, TArray t2 -> unifiable t1 t2
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
_ ) ->
false
and unifiable_list tys1 tys2 =
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
let rec compare_typ ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> compare_tlit l1 l2
| TTuple tys1, TTuple tys2 -> List.compare compare_typ tys1 tys2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TOption t1, TOption t2 -> compare_typ t1 t2
| TArrow (a1, b1), TArrow (a2, b2) -> (
match compare_typ a1 a2 with 0 -> compare_typ b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare_typ t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TTuple _, _ -> -1
| _, TTuple _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TOption _, _ -> -1
| _, TOption _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
let equal_lit (type a) (l1 : a glit) (l2 : a glit) = let equal_lit (type a) (l1 : a glit) (l2 : a glit) =
let open Runtime.Oper in
match l1, l2 with match l1, l2 with
| LBool b1, LBool b2 -> Bool.equal b1 b2 | LBool b1, LBool b2 -> not (o_xor b1 b2)
| LEmptyError, LEmptyError -> true | LEmptyError, LEmptyError -> true
| LInt n1, LInt n2 -> Runtime.( =! ) n1 n2 | LInt n1, LInt n2 -> o_eq_int_int n1 n2
| LRat r1, LRat r2 -> Runtime.( =& ) r1 r2 | LRat r1, LRat r2 -> o_eq_rat_rat r1 r2
| LMoney m1, LMoney m2 -> Runtime.( =$ ) m1 m2 | LMoney m1, LMoney m2 -> o_eq_mon_mon m1 m2
| LUnit, LUnit -> true | LUnit, LUnit -> true
| LDate d1, LDate d2 -> Runtime.( =@ ) d1 d2 | LDate d1, LDate d2 -> o_eq_dat_dat d1 d2
| LDuration d1, LDuration d2 -> Runtime.( =^ ) d1 d2 | LDuration d1, LDuration d2 -> o_eq_dur_dur d1 d2
| ( ( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | ( ( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ), | LDuration _ ),
_ ) -> _ ) ->
false false
let compare_lit (type a) (l1 : a glit) (l2 : a glit) = let compare_lit (type a) (l1 : a glit) (l2 : a glit) =
let open Runtime.Oper in
match l1, l2 with match l1, l2 with
| LBool b1, LBool b2 -> Bool.compare b1 b2 | LBool b1, LBool b2 -> Bool.compare b1 b2
| LEmptyError, LEmptyError -> 0 | LEmptyError, LEmptyError -> 0
| LInt n1, LInt n2 -> | LInt n1, LInt n2 ->
if Runtime.( <! ) n1 n2 then -1 else if Runtime.( =! ) n1 n2 then 0 else 1 if o_lt_int_int n1 n2 then -1 else if o_eq_int_int n1 n2 then 0 else 1
| LRat r1, LRat r2 -> | LRat r1, LRat r2 ->
if Runtime.( <& ) r1 r2 then -1 else if Runtime.( =& ) r1 r2 then 0 else 1 if o_lt_rat_rat r1 r2 then -1 else if o_eq_rat_rat r1 r2 then 0 else 1
| LMoney m1, LMoney m2 -> | LMoney m1, LMoney m2 ->
if Runtime.( <$ ) m1 m2 then -1 else if Runtime.( =$ ) m1 m2 then 0 else 1 if o_lt_mon_mon m1 m2 then -1 else if o_eq_mon_mon m1 m2 then 0 else 1
| LUnit, LUnit -> 0 | LUnit, LUnit -> 0
| LDate d1, LDate d2 -> | LDate d1, LDate d2 ->
if Runtime.( <@ ) d1 d2 then -1 else if Runtime.( =@ ) d1 d2 then 0 else 1 if o_lt_dat_dat d1 d2 then -1 else if o_eq_dat_dat d1 d2 then 0 else 1
| LDuration d1, LDuration d2 -> ( | LDuration d1, LDuration d2 -> (
(* Duration comparison in the runtime may fail, so rely on a basic (* Duration comparison in the runtime may fail, so rely on a basic
lexicographic comparison instead *) lexicographic comparison instead *)
@ -441,119 +477,6 @@ let compare_location
| _, SubScopeVar _ -> . | _, SubScopeVar _ -> .
let equal_location a b = compare_location a b = 0 let equal_location a b = compare_location a b = 0
let equal_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> equal_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
| x, y -> x = y
let compare_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> compare_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
| BeginCall, BeginCall
| EndCall, EndCall
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
0
| VarDef _, _ -> -1
| _, VarDef _ -> 1
| BeginCall, _ -> -1
| _, BeginCall -> 1
| EndCall, _ -> -1
| _, EndCall -> 1
| PosRecordIfTrueBool, _ -> .
| _, PosRecordIfTrueBool -> .
(* let equal_op_kind = Stdlib.(=) *)
let compare_op_kind = Stdlib.compare
let equal_unops op1 op2 =
match op1, op2 with
(* Log entries contain a typ which contain position information, we thus need
to descend into them *)
| Log (l1, info1), Log (l2, info2) ->
equal_log_entries l1 l2 && List.equal Uid.MarkedString.equal info1 info2
| Log _, _ | _, Log _ -> false
(* All the other cases can be discharged through equality *)
| ( ( Not | Minus _ | Length | IntToRat | MoneyToRat | RatToMoney | GetDay
| GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | RoundMoney
| RoundDecimal ),
_ ) ->
op1 = op2
let compare_unops op1 op2 =
match op1, op2 with
| Not, Not -> 0
| Minus k1, Minus k2 -> compare_op_kind k1 k2
| Log (l1, info1), Log (l2, info2) -> (
match compare_log_entries l1 l2 with
| 0 -> List.compare Uid.MarkedString.compare info1 info2
| n -> n)
| Length, Length
| IntToRat, IntToRat
| MoneyToRat, MoneyToRat
| RatToMoney, RatToMoney
| GetDay, GetDay
| GetMonth, GetMonth
| GetYear, GetYear
| FirstDayOfMonth, FirstDayOfMonth
| LastDayOfMonth, LastDayOfMonth
| RoundMoney, RoundMoney
| RoundDecimal, RoundDecimal ->
0
| Not, _ -> -1
| _, Not -> 1
| Minus _, _ -> -1
| _, Minus _ -> 1
| Log _, _ -> -1
| _, Log _ -> 1
| Length, _ -> -1
| _, Length -> 1
| IntToRat, _ -> -1
| _, IntToRat -> 1
| MoneyToRat, _ -> -1
| _, MoneyToRat -> 1
| RatToMoney, _ -> -1
| _, RatToMoney -> 1
| GetDay, _ -> -1
| _, GetDay -> 1
| GetMonth, _ -> -1
| _, GetMonth -> 1
| GetYear, _ -> -1
| _, GetYear -> 1
| FirstDayOfMonth, _ -> -1
| _, FirstDayOfMonth -> 1
| LastDayOfMonth, _ -> -1
| _, LastDayOfMonth -> 1
| RoundMoney, _ -> -1
| _, RoundMoney -> 1
| RoundDecimal, _ -> .
| _, RoundDecimal -> .
let equal_binop = Stdlib.( = )
let compare_binop = Stdlib.compare
let equal_ternop = Stdlib.( = )
let compare_ternop = Stdlib.compare
let equal_ops op1 op2 =
match op1, op2 with
| Ternop op1, Ternop op2 -> equal_ternop op1 op2
| Binop op1, Binop op2 -> equal_binop op1 op2
| Unop op1, Unop op2 -> equal_unops op1 op2
| _, _ -> false
let compare_op op1 op2 =
match op1, op2 with
| Ternop op1, Ternop op2 -> compare_ternop op1 op2
| Binop op1, Binop op2 -> compare_binop op1 op2
| Unop op1, Unop op2 -> compare_unops op1 op2
| Ternop _, _ -> -1
| _, Ternop _ -> 1
| Binop _, _ -> -1
| _, Binop _ -> 1
| Unop _, _ -> .
| _, Unop _ -> .
let equal_except ex1 ex2 = ex1 = ex2 let equal_except ex1 ex2 = ex1 = ex2
let compare_except ex1 ex2 = Stdlib.compare ex1 ex2 let compare_except ex1 ex2 = Stdlib.compare ex1 ex2
@ -567,50 +490,60 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
fun e1 e2 -> fun e1 e2 ->
match Marked.unmark e1, Marked.unmark e2 with match Marked.unmark e1, Marked.unmark e2 with
| EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2 | EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2
| ETuple (es1, n1), ETuple (es2, n2) -> n1 = n2 && equal_list es1 es2 | ETuple es1, ETuple es2 -> equal_list es1 es2
| ETupleAccess (e1, id1, n1, tys1), ETupleAccess (e2, id2, n2, tys2) -> | ( ETupleAccess { e = e1; index = id1; size = s1 },
equal e1 e2 && id1 = id2 && n1 = n2 && equal_typ_list tys1 tys2 ETupleAccess { e = e2; index = id2; size = s2 } ) ->
| EInj (e1, id1, n1, tys1), EInj (e2, id2, n2, tys2) -> s1 = s2 && equal e1 e2 && id1 = id2
equal e1 e2 && id1 = id2 && n1 = n2 && equal_typ_list tys1 tys2
| EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) ->
n1 = n2 && equal e1 e2 && equal_list cases1 cases2
| EArray es1, EArray es2 -> equal_list es1 es2 | EArray es1, EArray es2 -> equal_list es1 es2
| ELit l1, ELit l2 -> l1 = l2 | ELit l1, ELit l2 -> l1 = l2
| EAbs (b1, tys1), EAbs (b2, tys2) -> | EAbs { binder = b1; tys = tys1 }, EAbs { binder = b2; tys = tys2 } ->
equal_typ_list tys1 tys2 Type.equal_list tys1 tys2
&& &&
let vars1, body1 = Bindlib.unmbind b1 in let vars1, body1 = Bindlib.unmbind b1 in
let body2 = Bindlib.msubst b2 (Array.map (fun x -> EVar x) vars1) in let body2 = Bindlib.msubst b2 (Array.map (fun x -> EVar x) vars1) in
equal body1 body2 equal body1 body2
| EApp (e1, args1), EApp (e2, args2) -> equal e1 e2 && equal_list args1 args2 | EApp { f = e1; args = args1 }, EApp { f = e2; args = args2 } ->
equal e1 e2 && equal_list args1 args2
| EAssert e1, EAssert e2 -> equal e1 e2 | EAssert e1, EAssert e2 -> equal e1 e2
| EOp op1, EOp op2 -> equal_ops op1 op2 | EOp { op = op1; tys = tys1 }, EOp { op = op2; tys = tys2 } ->
| EDefault (exc1, def1, cons1), EDefault (exc2, def2, cons2) -> Operator.equal op1 op2 && Type.equal_list tys1 tys2
| ( EDefault { excepts = exc1; just = def1; cons = cons1 },
EDefault { excepts = exc2; just = def2; cons = cons2 } ) ->
equal def1 def2 && equal cons1 cons2 && equal_list exc1 exc2 equal def1 def2 && equal cons1 cons2 && equal_list exc1 exc2
| EIfThenElse (if1, then1, else1), EIfThenElse (if2, then2, else2) -> | ( EIfThenElse { cond = if1; etrue = then1; efalse = else1 },
EIfThenElse { cond = if2; etrue = then2; efalse = else2 } ) ->
equal if1 if2 && equal then1 then2 && equal else1 else2 equal if1 if2 && equal then1 then2 && equal else1 else2
| ErrorOnEmpty e1, ErrorOnEmpty e2 -> equal e1 e2 | EErrorOnEmpty e1, EErrorOnEmpty e2 -> equal e1 e2
| ERaise ex1, ERaise ex2 -> equal_except ex1 ex2 | ERaise ex1, ERaise ex2 -> equal_except ex1 ex2
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) -> | ( ECatch { body = etry1; exn = ex1; handler = ewith1 },
ECatch { body = etry2; exn = ex2; handler = ewith2 } ) ->
equal etry1 etry2 && equal_except ex1 ex2 && equal ewith1 ewith2 equal etry1 etry2 && equal_except ex1 ex2 && equal ewith1 ewith2
| ELocation l1, ELocation l2 -> | ELocation l1, ELocation l2 ->
equal_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2) equal_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2)
| EStruct (s1, fields1), EStruct (s2, fields2) -> | ( EStruct { name = s1; fields = fields1 },
StructName.equal s1 s2 && StructFieldMap.equal equal fields1 fields2 EStruct { name = s2; fields = fields2 } ) ->
| EStructAccess (e1, f1, s1), EStructAccess (e2, f2, s2) -> StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2
StructName.equal s1 s2 && StructFieldName.equal f1 f2 && equal e1 e2 | ( EDStructAccess { e = e1; field = f1; name_opt = s1 },
| EEnumInj (e1, c1, n1), EEnumInj (e2, c2, n2) -> EDStructAccess { e = e2; field = f2; name_opt = s2 } ) ->
Option.equal StructName.equal s1 s2 && IdentName.equal f1 f2 && equal e1 e2
| ( EStructAccess { e = e1; field = f1; name = s1 },
EStructAccess { e = e2; field = f2; name = s2 } ) ->
StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2
| EInj { e = e1; cons = c1; name = n1 }, EInj { e = e2; cons = c2; name = n2 }
->
EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2 EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2
| EMatchS (e1, n1, cases1), EMatchS (e2, n2, cases2) -> | ( EMatch { e = e1; name = n1; cases = cases1 },
EMatch { e = e2; name = n2; cases = cases2 } ) ->
EnumName.equal n1 n2 EnumName.equal n1 n2
&& equal e1 e2 && equal e1 e2
&& EnumConstructorMap.equal equal cases1 cases2 && EnumConstructor.Map.equal equal cases1 cases2
| EScopeCall (s1, fields1), EScopeCall (s2, fields2) -> | ( EScopeCall { scope = s1; args = fields1 },
ScopeName.equal s1 s2 && ScopeVarMap.equal equal fields1 fields2 EScopeCall { scope = s2; args = fields2 } ) ->
| ( ( EVar _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | EArray _ ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2
| ELit _ | EAbs _ | EApp _ | EAssert _ | EOp _ | EDefault _ | ( ( EVar _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _ | EAbs _ | EApp _
| EIfThenElse _ | ErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _ | EAssert _ | EOp _ | EDefault _ | EIfThenElse _ | EErrorOnEmpty _
| EStruct _ | EStructAccess _ | EEnumInj _ | EMatchS _ | EScopeCall _ ), | ERaise _ | ECatch _ | ELocation _ | EStruct _ | EDStructAccess _
| EStructAccess _ | EInj _ | EMatch _ | EScopeCall _ ),
_ ) -> _ ) ->
false false
@ -623,72 +556,76 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
match[@ocamlformat "disable"] Marked.unmark e1, Marked.unmark e2 with match[@ocamlformat "disable"] Marked.unmark e1, Marked.unmark e2 with
| ELit l1, ELit l2 -> | ELit l1, ELit l2 ->
compare_lit l1 l2 compare_lit l1 l2
| EApp (f1, args1), EApp (f2, args2) -> | EApp {f=f1; args=args1}, EApp {f=f2; args=args2} ->
compare f1 f2 @@< fun () -> compare f1 f2 @@< fun () ->
List.compare compare args1 args2 List.compare compare args1 args2
| EOp op1, EOp op2 -> | EOp {op=op1; tys=tys1}, EOp {op=op2; tys=tys2} ->
compare_op op1 op2 Operator.compare op1 op2 @@< fun () ->
List.compare Type.compare tys1 tys2
| EArray a1, EArray a2 -> | EArray a1, EArray a2 ->
List.compare compare a1 a2 List.compare compare a1 a2
| EVar v1, EVar v2 -> | EVar v1, EVar v2 ->
Bindlib.compare_vars v1 v2 Bindlib.compare_vars v1 v2
| EAbs (binder1, typs1), EAbs (binder2, typs2) -> | EAbs {binder=binder1; tys=typs1},
List.compare compare_typ typs1 typs2 @@< fun () -> EAbs {binder=binder2; tys=typs2} ->
List.compare Type.compare typs1 typs2 @@< fun () ->
let _, e1, e2 = Bindlib.unmbind2 binder1 binder2 in let _, e1, e2 = Bindlib.unmbind2 binder1 binder2 in
compare e1 e2 compare e1 e2
| EIfThenElse (i1, t1, e1), EIfThenElse (i2, t2, e2) -> | EIfThenElse {cond=i1; etrue=t1; efalse=e1},
EIfThenElse {cond=i2; etrue=t2; efalse=e2} ->
compare i1 i2 @@< fun () -> compare i1 i2 @@< fun () ->
compare t1 t2 @@< fun () -> compare t1 t2 @@< fun () ->
compare e1 e2 compare e1 e2
| ELocation l1, ELocation l2 -> | ELocation l1, ELocation l2 ->
compare_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2) compare_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2)
| EStruct (name1, field_map1), EStruct (name2, field_map2) -> | EStruct {name=name1; fields=field_map1},
EStruct {name=name2; fields=field_map2} ->
StructName.compare name1 name2 @@< fun () -> StructName.compare name1 name2 @@< fun () ->
StructFieldMap.compare compare field_map1 field_map2 StructField.Map.compare compare field_map1 field_map2
| EStructAccess (e1, field_name1, struct_name1), | EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1},
EStructAccess (e2, field_name2, struct_name2) -> EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2} ->
compare e1 e2 @@< fun () -> compare e1 e2 @@< fun () ->
StructFieldName.compare field_name1 field_name2 @@< fun () -> IdentName.compare field_name1 field_name2 @@< fun () ->
Option.compare StructName.compare struct_name1 struct_name2
| EStructAccess {e=e1; field=field_name1; name=struct_name1},
EStructAccess {e=e2; field=field_name2; name=struct_name2} ->
compare e1 e2 @@< fun () ->
StructField.compare field_name1 field_name2 @@< fun () ->
StructName.compare struct_name1 struct_name2 StructName.compare struct_name1 struct_name2
| EEnumInj (e1, cstr1, name1), EEnumInj (e2, cstr2, name2) -> | EMatch {e=e1; name=name1; cases=emap1},
compare e1 e2 @@< fun () -> EMatch {e=e2; name=name2; cases=emap2} ->
EnumName.compare name1 name2 @@< fun () -> EnumName.compare name1 name2 @@< fun () ->
EnumConstructor.compare cstr1 cstr2
| EMatchS (e1, name1, emap1), EMatchS (e2, name2, emap2) ->
compare e1 e2 @@< fun () -> compare e1 e2 @@< fun () ->
EnumName.compare name1 name2 @@< fun () -> EnumConstructor.Map.compare compare emap1 emap2
EnumConstructorMap.compare compare emap1 emap2 | EScopeCall {scope=name1; args=field_map1},
| EScopeCall (name1, field_map1), EScopeCall (name2, field_map2) -> EScopeCall {scope=name2; args=field_map2} ->
ScopeName.compare name1 name2 @@< fun () -> ScopeName.compare name1 name2 @@< fun () ->
ScopeVarMap.compare compare field_map1 field_map2 ScopeVar.Map.compare compare field_map1 field_map2
| ETuple (es1, s1), ETuple (es2, s2) -> | ETuple es1, ETuple es2 ->
Option.compare StructName.compare s1 s2 @@< fun () ->
List.compare compare es1 es2 List.compare compare es1 es2
| ETupleAccess (e1, n1, s1, tys1), ETupleAccess (e2, n2, s2, tys2) -> | ETupleAccess {e=e1; index=n1; size=s1},
Option.compare StructName.compare s1 s2 @@< fun () -> ETupleAccess {e=e2; index=n2; size=s2} ->
Int.compare s1 s2 @@< fun () ->
Int.compare n1 n2 @@< fun () -> Int.compare n1 n2 @@< fun () ->
List.compare compare_typ tys1 tys2 @@< fun () ->
compare e1 e2 compare e1 e2
| EInj (e1, n1, name1, ts1), EInj (e2, n2, name2, ts2) -> | EInj {e=e1; name=name1; cons=cons1},
EInj {e=e2; name=name2; cons=cons2} ->
EnumName.compare name1 name2 @@< fun () -> EnumName.compare name1 name2 @@< fun () ->
Int.compare n1 n2 @@< fun () -> EnumConstructor.compare cons1 cons2 @@< fun () ->
List.compare compare_typ ts1 ts2 @@< fun () ->
compare e1 e2 compare e1 e2
| EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) ->
EnumName.compare n1 n2 @@< fun () ->
compare e1 e2 @@< fun () ->
List.compare compare cases1 cases2
| EAssert e1, EAssert e2 -> | EAssert e1, EAssert e2 ->
compare e1 e2 compare e1 e2
| EDefault (exs1, just1, cons1), EDefault (exs2, just2, cons2) -> | EDefault {excepts=exs1; just=just1; cons=cons1},
EDefault {excepts=exs2; just=just2; cons=cons2} ->
compare just1 just2 @@< fun () -> compare just1 just2 @@< fun () ->
compare cons1 cons2 @@< fun () -> compare cons1 cons2 @@< fun () ->
List.compare compare exs1 exs2 List.compare compare exs1 exs2
| ErrorOnEmpty e1, ErrorOnEmpty e2 -> | EErrorOnEmpty e1, EErrorOnEmpty e2 ->
compare e1 e2 compare e1 e2
| ERaise ex1, ERaise ex2 -> | ERaise ex1, ERaise ex2 ->
compare_except ex1 ex2 compare_except ex1 ex2
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) -> | ECatch {body=etry1; exn=ex1; handler=ewith1},
ECatch {body=etry2; exn=ex2; handler=ewith2} ->
compare_except ex1 ex2 @@< fun () -> compare_except ex1 ex2 @@< fun () ->
compare etry1 etry2 @@< fun () -> compare etry1 etry2 @@< fun () ->
compare ewith1 ewith2 compare ewith1 ewith2
@ -701,34 +638,33 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
| EIfThenElse _, _ -> -1 | _, EIfThenElse _ -> 1 | EIfThenElse _, _ -> -1 | _, EIfThenElse _ -> 1
| ELocation _, _ -> -1 | _, ELocation _ -> 1 | ELocation _, _ -> -1 | _, ELocation _ -> 1
| EStruct _, _ -> -1 | _, EStruct _ -> 1 | EStruct _, _ -> -1 | _, EStruct _ -> 1
| EDStructAccess _, _ -> -1 | _, EDStructAccess _ -> 1
| EStructAccess _, _ -> -1 | _, EStructAccess _ -> 1 | EStructAccess _, _ -> -1 | _, EStructAccess _ -> 1
| EEnumInj _, _ -> -1 | _, EEnumInj _ -> 1 | EMatch _, _ -> -1 | _, EMatch _ -> 1
| EMatchS _, _ -> -1 | _, EMatchS _ -> 1
| EScopeCall _, _ -> -1 | _, EScopeCall _ -> 1 | EScopeCall _, _ -> -1 | _, EScopeCall _ -> 1
| ETuple _, _ -> -1 | _, ETuple _ -> 1 | ETuple _, _ -> -1 | _, ETuple _ -> 1
| ETupleAccess _, _ -> -1 | _, ETupleAccess _ -> 1 | ETupleAccess _, _ -> -1 | _, ETupleAccess _ -> 1
| EInj _, _ -> -1 | _, EInj _ -> 1 | EInj _, _ -> -1 | _, EInj _ -> 1
| EMatch _, _ -> -1 | _, EMatch _ -> 1
| EAssert _, _ -> -1 | _, EAssert _ -> 1 | EAssert _, _ -> -1 | _, EAssert _ -> 1
| EDefault _, _ -> -1 | _, EDefault _ -> 1 | EDefault _, _ -> -1 | _, EDefault _ -> 1
| ErrorOnEmpty _, _ -> . | _, ErrorOnEmpty _ -> . | EErrorOnEmpty _, _ -> . | _, EErrorOnEmpty _ -> .
| ERaise _, _ -> -1 | _, ERaise _ -> 1 | ERaise _, _ -> -1 | _, ERaise _ -> 1
| ECatch _, _ -> . | _, ECatch _ -> . | ECatch _, _ -> . | _, ECatch _ -> .
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function
| EVar v, _ -> Var.Set.singleton v | EVar v, _ -> Var.Set.singleton v
| EAbs (binder, _), _ -> | EAbs { binder; _ }, _ ->
let vs, body = Bindlib.unmbind binder in let vs, body = Bindlib.unmbind binder in
Array.fold_right Var.Set.remove vs (free_vars body) Array.fold_right Var.Set.remove vs (free_vars body)
| e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty | e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
let remove_logging_calls e = let remove_logging_calls e =
let rec f () e = let rec f e =
match Marked.unmark e with match Marked.unmark e with
| EApp ((EOp (Unop (Log _)), _), [arg]) -> map () ~f arg | EApp { f = EOp { op = Log _; _ }, _; args = [arg] } -> map ~f arg
| _ -> map () ~f e | _ -> map ~f e
in in
f () e f e
let format ?debug decl_ctx ppf e = Print.expr ?debug decl_ctx ppf e let format ?debug decl_ctx ppf e = Print.expr ?debug decl_ctx ppf e
@ -736,36 +672,35 @@ let rec size : type a. (a, 't) gexpr -> int =
fun e -> fun e ->
match Marked.unmark e with match Marked.unmark e with
| EVar _ | ELit _ | EOp _ -> 1 | EVar _ | ELit _ | EOp _ -> 1
| ETuple (args, _) -> List.fold_left (fun acc arg -> acc + size arg) 1 args | ETuple args -> List.fold_left (fun acc arg -> acc + size arg) 1 args
| EArray args -> List.fold_left (fun acc arg -> acc + size arg) 1 args | EArray args -> List.fold_left (fun acc arg -> acc + size arg) 1 args
| ETupleAccess (e1, _, _, _) -> size e1 + 1 | ETupleAccess { e; _ } -> size e + 1
| EInj (e1, _, _, _) -> size e1 + 1 | EInj { e; _ } -> size e + 1
| EAssert e1 -> size e1 + 1 | EAssert e -> size e + 1
| ErrorOnEmpty e1 -> size e1 + 1 | EErrorOnEmpty e -> size e + 1
| EMatch (arg, args, _) -> | EApp { f; args } ->
List.fold_left (fun acc arg -> acc + size arg) (1 + size arg) args List.fold_left (fun acc arg -> acc + size arg) (1 + size f) args
| EApp (arg, args) -> | EAbs { binder; _ } ->
List.fold_left (fun acc arg -> acc + size arg) (1 + size arg) args
| EAbs (binder, _) ->
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
1 + size body 1 + size body
| EIfThenElse (e1, e2, e3) -> 1 + size e1 + size e2 + size e3 | EIfThenElse { cond; etrue; efalse } ->
| EDefault (exceptions, just, cons) -> 1 + size cond + size etrue + size efalse
| EDefault { excepts; just; cons } ->
List.fold_left List.fold_left
(fun acc except -> acc + size except) (fun acc except -> acc + size except)
(1 + size just + size cons) (1 + size just + size cons)
exceptions excepts
| ERaise _ -> 1 | ERaise _ -> 1
| ECatch (etry, _, ewith) -> 1 + size etry + size ewith | ECatch { body; handler; _ } -> 1 + size body + size handler
| ELocation _ -> 1 | ELocation _ -> 1
| EStruct (_, fields) -> | EStruct { fields; _ } ->
StructFieldMap.fold (fun _ e acc -> acc + 1 + size e) fields 0 StructField.Map.fold (fun _ e acc -> acc + 1 + size e) fields 0
| EStructAccess (e1, _, _) -> 1 + size e1 | EDStructAccess { e; _ } -> 1 + size e
| EEnumInj (e1, _, _) -> 1 + size e1 | EStructAccess { e; _ } -> 1 + size e
| EMatchS (e1, _, cases) -> | EMatch { e; cases; _ } ->
EnumConstructorMap.fold (fun _ e acc -> acc + 1 + size e) cases (size e1) EnumConstructor.Map.fold (fun _ e acc -> acc + 1 + size e) cases (size e)
| EScopeCall (_, fields) -> | EScopeCall { args; _ } ->
ScopeVarMap.fold (fun _ e acc -> acc + 1 + size e) fields 1 ScopeVar.Map.fold (fun _ e acc -> acc + 1 + size e) args 1
(* - Expression building helpers - *) (* - Expression building helpers - *)
@ -794,7 +729,7 @@ let make_app e u pos =
(fun tf tx -> (fun tf tx ->
match Marked.unmark tf with match Marked.unmark tf with
| TArrow (tx', tr) -> | TArrow (tx', tr) ->
assert (unifiable tx.ty tx'); assert (Type.unifiable tx.ty tx');
(* wrong arg type *) (* wrong arg type *)
tr tr
| TAny -> tf | TAny -> tf
@ -818,50 +753,35 @@ let make_let_in x tau e1 e2 mpos =
let make_multiple_let_in xs taus e1s e2 mpos = let make_multiple_let_in xs taus e1s e2 mpos =
make_app (make_abs xs e2 taus mpos) e1s (pos e2) make_app (make_abs xs e2 taus mpos) e1s (pos e2)
let make_default_unboxed exceptions just cons = let make_default_unboxed excepts just cons =
let rec bool_value = function let rec bool_value = function
| ELit (LBool b), _ -> Some b | ELit (LBool b), _ -> Some b
| EApp ((EOp (Unop (Log (l, _))), _), [e]), _ | EApp { f = EOp { op = Log (l, _); _ }, _; args = [e]; _ }, _
when l <> PosRecordIfTrueBool when l <> PosRecordIfTrueBool
(* we don't remove the log calls corresponding to source code (* we don't remove the log calls corresponding to source code
definitions !*) -> definitions !*) ->
bool_value e bool_value e
| _ -> None | _ -> None
in in
match exceptions, bool_value just, cons with match excepts, bool_value just, cons with
| [], Some true, cons -> Marked.unmark cons | [], Some true, cons -> Marked.unmark cons
| exceptions, Some true, (EDefault ([], just, cons), _) -> | excepts, Some true, (EDefault { excepts = []; just; cons }, _) ->
EDefault (exceptions, just, cons) EDefault { excepts; just; cons }
| [except], Some false, _ -> Marked.unmark except | [except], Some false, _ -> Marked.unmark except
| exceptions, _, cons -> EDefault (exceptions, just, cons) | excepts, _, cons -> EDefault { excepts; just; cons }
let make_default exceptions just cons = let make_default exceptions just cons =
Box.app2n just cons exceptions Box.app2n just cons exceptions
@@ fun just cons exceptions -> make_default_unboxed exceptions just cons @@ fun just cons exceptions -> make_default_unboxed exceptions just cons
let make_tuple el structname m0 = let make_tuple el m0 =
match el with match el with
| [] -> | [] -> etuple [] (with_ty m0 (TTuple [], mark_pos m0))
etuple [] structname
(with_ty m0
(match structname with
| Some n -> TStruct n, mark_pos m0
| None -> TTuple [], mark_pos m0))
| el -> | el ->
let m = let m =
fold_marks fold_marks
(fun posl -> List.hd posl) (fun posl -> List.hd posl)
(fun ml -> (fun ml -> TTuple (List.map (fun t -> t.ty) ml), (List.hd ml).pos)
let pos = (List.hd ml).pos in
match structname with
| Some n -> TStruct n, pos
| None -> TTuple (List.map (fun t -> t.ty) ml), pos)
(List.map (fun e -> Marked.get_mark e) el) (List.map (fun e -> Marked.get_mark e) el)
in in
etuple el structname m etuple el m
let make_struct fieldmap structname m =
let fields =
List.rev (StructFieldMap.fold (fun _ e acc -> e :: acc) fieldmap [])
in
make_tuple fields (Some structname) m

View File

@ -17,7 +17,7 @@
(** Functions handling the expressions of [shared_ast] *) (** Functions handling the expressions of [shared_ast] *)
open Utils open Catala_utils
open Definitions open Definitions
(** {2 Boxed constructors} *) (** {2 Boxed constructors} *)
@ -43,34 +43,10 @@ val subst :
('a, 't) gexpr list -> ('a, 't) gexpr list ->
('a, 't) gexpr ('a, 't) gexpr
val etuple : val etuple : (lcalc, 't) boxed_gexpr list -> 't -> (lcalc, 't) boxed_gexpr
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr list ->
StructName.t option ->
't ->
('a, 't) boxed_gexpr
val etupleaccess : val etupleaccess :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr -> (lcalc, 't) boxed_gexpr -> int -> int -> 't -> (lcalc, 't) boxed_gexpr
int ->
StructName.t option ->
typ list ->
't ->
('a, 't) boxed_gexpr
val einj :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
int ->
EnumName.t ->
typ list ->
't ->
('a, 't) boxed_gexpr
val ematch :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
('a, 't) boxed_gexpr list ->
EnumName.t ->
't ->
('a, 't) boxed_gexpr
val earray : ('a any, 't) boxed_gexpr list -> 't -> ('a, 't) boxed_gexpr val earray : ('a any, 't) boxed_gexpr list -> 't -> ('a, 't) boxed_gexpr
val elit : 'a any glit -> 't -> ('a, 't) boxed_gexpr val elit : 'a any glit -> 't -> ('a, 't) boxed_gexpr
@ -90,7 +66,7 @@ val eapp :
val eassert : val eassert :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr -> 't -> ('a, 't) boxed_gexpr (([< dcalc | lcalc ] as 'a), 't) boxed_gexpr -> 't -> ('a, 't) boxed_gexpr
val eop : operator -> 't -> (_ any, 't) boxed_gexpr val eop : ('a any, 'k) operator -> typ list -> 't -> ('a, 't) boxed_gexpr
val edefault : val edefault :
(([< desugared | scopelang | dcalc ] as 'a), 't) boxed_gexpr list -> (([< desugared | scopelang | dcalc ] as 'a), 't) boxed_gexpr list ->
@ -125,34 +101,41 @@ val elocation :
val estruct : val estruct :
StructName.t -> StructName.t ->
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr StructFieldMap.t -> ('a any, 't) boxed_gexpr StructField.Map.t ->
't -> 't ->
('a, 't) boxed_gexpr ('a, 't) boxed_gexpr
val edstructaccess :
(desugared, 't) boxed_gexpr ->
IdentName.t ->
StructName.t option ->
't ->
(desugared, 't) boxed_gexpr
val estructaccess : val estructaccess :
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr -> (([< scopelang | dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
StructFieldName.t -> StructField.t ->
StructName.t -> StructName.t ->
't -> 't ->
('a, 't) boxed_gexpr ('a, 't) boxed_gexpr
val eenuminj : val einj :
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr -> ('a any, 't) boxed_gexpr ->
EnumConstructor.t -> EnumConstructor.t ->
EnumName.t -> EnumName.t ->
't -> 't ->
('a, 't) boxed_gexpr ('a, 't) boxed_gexpr
val ematchs : val ematch :
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr -> ('a any, 't) boxed_gexpr ->
EnumName.t -> EnumName.t ->
('a, 't) boxed_gexpr EnumConstructorMap.t -> ('a, 't) boxed_gexpr EnumConstructor.Map.t ->
't -> 't ->
('a, 't) boxed_gexpr ('a, 't) boxed_gexpr
val escopecall : val escopecall :
ScopeName.t -> ScopeName.t ->
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVarMap.t -> (([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVar.Map.t ->
't -> 't ->
('a, 't) boxed_gexpr ('a, 't) boxed_gexpr
@ -194,28 +177,25 @@ val untype : ('a, 'm mark) gexpr -> ('a, untyped mark) boxed_gexpr
(** {2 Traversal functions} *) (** {2 Traversal functions} *)
val map : val map :
'ctx -> f:(('a, 't1) gexpr -> ('a, 't2) boxed_gexpr) ->
f:('ctx -> ('a, 't1) gexpr -> ('a, 't2) boxed_gexpr) ->
(('a, 't1) naked_gexpr, 't2) Marked.t -> (('a, 't1) naked_gexpr, 't2) Marked.t ->
('a, 't2) boxed_gexpr ('a, 't2) boxed_gexpr
(** Flat (non-recursive) mapping on expressions. (** Shallow mapping on expressions (non recursive): applies the given function
to all sub-terms of the given expression, and rebuilds the node.
If you want to apply a map transform to an expression, you can save up When applying a map transform to an expression, this avoids expliciting all
writing a painful match over all the cases of the AST. For instance, if you cases that remain unchanged. For instance, if you want to remove all errors
want to remove all errors on empty, you can write on empty, you can write
{[ {[
let remove_error_empty = let remove_error_empty =
let rec f () e = let rec f e =
match Marked.unmark e with match Marked.unmark e with
| ErrorOnEmpty e1 -> Expr.map () f e1 | ErrorOnEmpty e1 -> Expr.map f e1
| _ -> Expr.map () f e | _ -> Expr.map f e
in in
f () e f e
]} ]} *)
The first argument of map_expr is an optional context that you can carry
around during your map traversal. *)
val map_top_down : val map_top_down :
f:(('a, 't1) gexpr -> (('a, 't1) naked_gexpr, 't2) Marked.t) -> f:(('a, 't1) gexpr -> (('a, 't1) naked_gexpr, 't2) Marked.t) ->
@ -231,7 +211,42 @@ val shallow_fold :
(('a, 't) gexpr -> 'acc -> 'acc) -> ('a, 't) gexpr -> 'acc -> 'acc (('a, 't) gexpr -> 'acc -> 'acc) -> ('a, 't) gexpr -> 'acc -> 'acc
(** Applies a function on all sub-terms of the given expression. Does not (** Applies a function on all sub-terms of the given expression. Does not
recurse, and doesn't open binders. Useful as helper for recursive calls recurse, and doesn't open binders. Useful as helper for recursive calls
within traversal functions *) within traversal functions. This can be used to compute free variables with
e.g.:
{[
let rec free_vars = function
| EVar v, _ -> Var.Set.singleton v
| EAbs { binder; _ }, _ ->
let vs, body = Bindlib.unmbind binder in
Array.fold_right Var.Set.remove vs (free_vars body)
| e ->
shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
]} *)
val map_gather :
acc:'acc ->
join:('acc -> 'acc -> 'acc) ->
f:(('a, 't1) gexpr -> 'acc * ('a, 't2) boxed_gexpr) ->
(('a, 't1) naked_gexpr, 't2) Marked.t ->
'acc * ('a, 't2) boxed_gexpr
(** Shallow mapping similar to [map], but additionally allows to gather an
accumulator bottom-up. [acc] is the accumulator value returned on terminal
nodes, and [join] is used to merge accumulators from the different sub-terms
of an expression. [acc] is assumed to be a neutral element for [join].
Typically used with a set of variables used in the rewrite:
{[
let rec rewrite e =
match Marked.unmark e with
| Specific_case ->
Var.Set.singleton x, some_rewrite_fun e
| _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:rewrite e
}]
See [Lcalc.closure_conversion] for a real-world example. *)
(** {2 Expression building helpers} *) (** {2 Expression building helpers} *)
@ -289,21 +304,10 @@ val make_default :
- [<ex | false :- _>], when [ex] is a single exception, is rewritten as [ex] *) - [<ex | false :- _>], when [ex] is a single exception, is rewritten as [ex] *)
val make_tuple : val make_tuple :
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr list -> (lcalc, 'm mark) boxed_gexpr list -> 'm mark -> (lcalc, 'm mark) boxed_gexpr
StructName.t option ->
'm mark ->
('a, 'm mark) boxed_gexpr
(** Builds a tuple; the mark argument is only used as witness and for position (** Builds a tuple; the mark argument is only used as witness and for position
when building 0-uples *) when building 0-uples *)
val make_struct :
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr StructFieldMap.t ->
StructName.t ->
'm mark ->
('a, 'm mark) boxed_gexpr
(** Builds the tuple of values for the given struct with proper ordering,
assuming the structfieldmap contains the fields defined for structname *)
(** {2 Transformations} *) (** {2 Transformations} *)
val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr
@ -331,8 +335,6 @@ val compare : ('a, 't) gexpr -> ('a, 't) gexpr -> int
(** Standard comparison function, suitable for e.g. [Set.Make]. Ignores position (** Standard comparison function, suitable for e.g. [Set.Make]. Ignores position
information *) information *)
val equal_typ : typ -> typ -> bool
val compare_typ : typ -> typ -> int
val is_value : ('a any, 't) gexpr -> bool val is_value : ('a any, 't) gexpr -> bool
val free_vars : ('a any, 't) gexpr -> ('a, 't) gexpr Var.Set.t val free_vars : ('a any, 't) gexpr -> ('a, 't) gexpr Var.Set.t
@ -363,10 +365,10 @@ module Box : sig
a separate argument. *) a separate argument. *)
val app1 : val app1 :
('a, 't) boxed_gexpr -> ('a, 't1) boxed_gexpr ->
(('a, 't) gexpr -> ('a, 't) naked_gexpr) -> (('a, 't1) gexpr -> ('a, 't2) naked_gexpr) ->
't -> 't2 ->
('a, 't) boxed_gexpr ('a, 't2) boxed_gexpr
val app2 : val app2 :
('a, 't) boxed_gexpr -> ('a, 't) boxed_gexpr ->

View File

@ -0,0 +1,582 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Definitions
include Definitions.Op
let name : type a k. (a, k) t -> string = function
| Not -> "o_not"
| Length -> "o_length"
| GetDay -> "o_getDay"
| GetMonth -> "o_getMonth"
| GetYear -> "o_getYear"
| FirstDayOfMonth -> "o_firstDayOfMonth"
| LastDayOfMonth -> "o_lastDayOfMonth"
| Log _ -> "o_log"
| Minus -> "o_minus"
| Minus_int -> "o_minus_int"
| Minus_rat -> "o_minus_rat"
| Minus_mon -> "o_minus_mon"
| Minus_dur -> "o_minus_dur"
| ToRat -> "o_torat"
| ToRat_int -> "o_torat_int"
| ToRat_mon -> "o_torat_mon"
| ToMoney -> "o_tomoney"
| ToMoney_rat -> "o_tomoney_rat"
| Round -> "o_round"
| Round_rat -> "o_round_rat"
| Round_mon -> "o_round_mon"
| And -> "o_and"
| Or -> "o_or"
| Xor -> "o_xor"
| Eq -> "o_eq"
| Map -> "o_map"
| Concat -> "o_concat"
| Filter -> "o_filter"
| Reduce -> "o_reduce"
| Add -> "o_add"
| Add_int_int -> "o_add_int_int"
| Add_rat_rat -> "o_add_rat_rat"
| Add_mon_mon -> "o_add_mon_mon"
| Add_dat_dur -> "o_add_dat_dur"
| Add_dur_dur -> "o_add_dur_dur"
| Sub -> "o_sub"
| Sub_int_int -> "o_sub_int_int"
| Sub_rat_rat -> "o_sub_rat_rat"
| Sub_mon_mon -> "o_sub_mon_mon"
| Sub_dat_dat -> "o_sub_dat_dat"
| Sub_dat_dur -> "o_sub_dat_dur"
| Sub_dur_dur -> "o_sub_dur_dur"
| Mult -> "o_mult"
| Mult_int_int -> "o_mult_int_int"
| Mult_rat_rat -> "o_mult_rat_rat"
| Mult_mon_rat -> "o_mult_mon_rat"
| Mult_dur_int -> "o_mult_dur_int"
| Div -> "o_div"
| Div_int_int -> "o_div_int_int"
| Div_rat_rat -> "o_div_rat_rat"
| Div_mon_mon -> "o_div_mon_mon"
| Div_mon_rat -> "o_div_mon_mon"
| Lt -> "o_lt"
| Lt_int_int -> "o_lt_int_int"
| Lt_rat_rat -> "o_lt_rat_rat"
| Lt_mon_mon -> "o_lt_mon_mon"
| Lt_dur_dur -> "o_lt_dur_dur"
| Lt_dat_dat -> "o_lt_dat_dat"
| Lte -> "o_lte"
| Lte_int_int -> "o_lte_int_int"
| Lte_rat_rat -> "o_lte_rat_rat"
| Lte_mon_mon -> "o_lte_mon_mon"
| Lte_dur_dur -> "o_lte_dur_dur"
| Lte_dat_dat -> "o_lte_dat_dat"
| Gt -> "o_gt"
| Gt_int_int -> "o_gt_int_int"
| Gt_rat_rat -> "o_gt_rat_rat"
| Gt_mon_mon -> "o_gt_mon_mon"
| Gt_dur_dur -> "o_gt_dur_dur"
| Gt_dat_dat -> "o_gt_dat_dat"
| Gte -> "o_gte"
| Gte_int_int -> "o_gte_int_int"
| Gte_rat_rat -> "o_gte_rat_rat"
| Gte_mon_mon -> "o_gte_mon_mon"
| Gte_dur_dur -> "o_gte_dur_dur"
| Gte_dat_dat -> "o_gte_dat_dat"
| Eq_int_int -> "o_eq_int_int"
| Eq_rat_rat -> "o_eq_rat_rat"
| Eq_mon_mon -> "o_eq_mon_mon"
| Eq_dur_dur -> "o_eq_dur_dur"
| Eq_dat_dat -> "o_eq_dat_dat"
| Fold -> "o_fold"
let compare_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> Type.compare (t1, Pos.no_pos) (t2, Pos.no_pos)
| BeginCall, BeginCall
| EndCall, EndCall
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
0
| VarDef _, _ -> -1
| _, VarDef _ -> 1
| BeginCall, _ -> -1
| _, BeginCall -> 1
| EndCall, _ -> -1
| _, EndCall -> 1
| PosRecordIfTrueBool, _ -> .
| _, PosRecordIfTrueBool -> .
let compare (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) =
match[@ocamlformat "disable"] t1, t2 with
| Log (l1, info1), Log (l2, info2) -> (
match compare_log_entries l1 l2 with
| 0 -> List.compare Uid.MarkedString.compare info1 info2
| n -> n)
| Not, Not
| Length, Length
| GetDay, GetDay
| GetMonth, GetMonth
| GetYear, GetYear
| FirstDayOfMonth, FirstDayOfMonth
| LastDayOfMonth, LastDayOfMonth
| Minus, Minus
| Minus_int, Minus_int
| Minus_rat, Minus_rat
| Minus_mon, Minus_mon
| Minus_dur, Minus_dur
| ToRat, ToRat
| ToRat_int, ToRat_int
| ToRat_mon, ToRat_mon
| ToMoney, ToMoney
| ToMoney_rat, ToMoney_rat
| Round, Round
| Round_rat, Round_rat
| Round_mon, Round_mon
| And, And
| Or, Or
| Xor, Xor
| Eq, Eq
| Map, Map
| Concat, Concat
| Filter, Filter
| Reduce, Reduce
| Add, Add
| Add_int_int, Add_int_int
| Add_rat_rat, Add_rat_rat
| Add_mon_mon, Add_mon_mon
| Add_dat_dur, Add_dat_dur
| Add_dur_dur, Add_dur_dur
| Sub, Sub
| Sub_int_int, Sub_int_int
| Sub_rat_rat, Sub_rat_rat
| Sub_mon_mon, Sub_mon_mon
| Sub_dat_dat, Sub_dat_dat
| Sub_dat_dur, Sub_dat_dur
| Sub_dur_dur, Sub_dur_dur
| Mult, Mult
| Mult_int_int, Mult_int_int
| Mult_rat_rat, Mult_rat_rat
| Mult_mon_rat, Mult_mon_rat
| Mult_dur_int, Mult_dur_int
| Div, Div
| Div_int_int, Div_int_int
| Div_rat_rat, Div_rat_rat
| Div_mon_mon, Div_mon_mon
| Div_mon_rat, Div_mon_rat
| Lt, Lt
| Lt_int_int, Lt_int_int
| Lt_rat_rat, Lt_rat_rat
| Lt_mon_mon, Lt_mon_mon
| Lt_dat_dat, Lt_dat_dat
| Lt_dur_dur, Lt_dur_dur
| Lte, Lte
| Lte_int_int, Lte_int_int
| Lte_rat_rat, Lte_rat_rat
| Lte_mon_mon, Lte_mon_mon
| Lte_dat_dat, Lte_dat_dat
| Lte_dur_dur, Lte_dur_dur
| Gt, Gt
| Gt_int_int, Gt_int_int
| Gt_rat_rat, Gt_rat_rat
| Gt_mon_mon, Gt_mon_mon
| Gt_dat_dat, Gt_dat_dat
| Gt_dur_dur, Gt_dur_dur
| Gte, Gte
| Gte_int_int, Gte_int_int
| Gte_rat_rat, Gte_rat_rat
| Gte_mon_mon, Gte_mon_mon
| Gte_dat_dat, Gte_dat_dat
| Gte_dur_dur, Gte_dur_dur
| Eq_int_int, Eq_int_int
| Eq_rat_rat, Eq_rat_rat
| Eq_mon_mon, Eq_mon_mon
| Eq_dat_dat, Eq_dat_dat
| Eq_dur_dur, Eq_dur_dur
| Fold, Fold -> 0
| Not, _ -> -1 | _, Not -> 1
| Length, _ -> -1 | _, Length -> 1
| GetDay, _ -> -1 | _, GetDay -> 1
| GetMonth, _ -> -1 | _, GetMonth -> 1
| GetYear, _ -> -1 | _, GetYear -> 1
| FirstDayOfMonth, _ -> -1 | _, FirstDayOfMonth -> 1
| LastDayOfMonth, _ -> -1 | _, LastDayOfMonth -> 1
| Log _, _ -> -1 | _, Log _ -> 1
| Minus, _ -> -1 | _, Minus -> 1
| Minus_int, _ -> -1 | _, Minus_int -> 1
| Minus_rat, _ -> -1 | _, Minus_rat -> 1
| Minus_mon, _ -> -1 | _, Minus_mon -> 1
| Minus_dur, _ -> -1 | _, Minus_dur -> 1
| ToRat, _ -> -1 | _, ToRat -> 1
| ToRat_int, _ -> -1 | _, ToRat_int -> 1
| ToRat_mon, _ -> -1 | _, ToRat_mon -> 1
| ToMoney, _ -> -1 | _, ToMoney -> 1
| ToMoney_rat, _ -> -1 | _, ToMoney_rat -> 1
| Round, _ -> -1 | _, Round -> 1
| Round_rat, _ -> -1 | _, Round_rat -> 1
| Round_mon, _ -> -1 | _, Round_mon -> 1
| And, _ -> -1 | _, And -> 1
| Or, _ -> -1 | _, Or -> 1
| Xor, _ -> -1 | _, Xor -> 1
| Eq, _ -> -1 | _, Eq -> 1
| Map, _ -> -1 | _, Map -> 1
| Concat, _ -> -1 | _, Concat -> 1
| Filter, _ -> -1 | _, Filter -> 1
| Reduce, _ -> -1 | _, Reduce -> 1
| Add, _ -> -1 | _, Add -> 1
| Add_int_int, _ -> -1 | _, Add_int_int -> 1
| Add_rat_rat, _ -> -1 | _, Add_rat_rat -> 1
| Add_mon_mon, _ -> -1 | _, Add_mon_mon -> 1
| Add_dat_dur, _ -> -1 | _, Add_dat_dur -> 1
| Add_dur_dur, _ -> -1 | _, Add_dur_dur -> 1
| Sub, _ -> -1 | _, Sub -> 1
| Sub_int_int, _ -> -1 | _, Sub_int_int -> 1
| Sub_rat_rat, _ -> -1 | _, Sub_rat_rat -> 1
| Sub_mon_mon, _ -> -1 | _, Sub_mon_mon -> 1
| Sub_dat_dat, _ -> -1 | _, Sub_dat_dat -> 1
| Sub_dat_dur, _ -> -1 | _, Sub_dat_dur -> 1
| Sub_dur_dur, _ -> -1 | _, Sub_dur_dur -> 1
| Mult, _ -> -1 | _, Mult -> 1
| Mult_int_int, _ -> -1 | _, Mult_int_int -> 1
| Mult_rat_rat, _ -> -1 | _, Mult_rat_rat -> 1
| Mult_mon_rat, _ -> -1 | _, Mult_mon_rat -> 1
| Mult_dur_int, _ -> -1 | _, Mult_dur_int -> 1
| Div, _ -> -1 | _, Div -> 1
| Div_int_int, _ -> -1 | _, Div_int_int -> 1
| Div_rat_rat, _ -> -1 | _, Div_rat_rat -> 1
| Div_mon_mon, _ -> -1 | _, Div_mon_mon -> 1
| Div_mon_rat, _ -> -1 | _, Div_mon_rat -> 1
| Lt, _ -> -1 | _, Lt -> 1
| Lt_int_int, _ -> -1 | _, Lt_int_int -> 1
| Lt_rat_rat, _ -> -1 | _, Lt_rat_rat -> 1
| Lt_mon_mon, _ -> -1 | _, Lt_mon_mon -> 1
| Lt_dat_dat, _ -> -1 | _, Lt_dat_dat -> 1
| Lt_dur_dur, _ -> -1 | _, Lt_dur_dur -> 1
| Lte, _ -> -1 | _, Lte -> 1
| Lte_int_int, _ -> -1 | _, Lte_int_int -> 1
| Lte_rat_rat, _ -> -1 | _, Lte_rat_rat -> 1
| Lte_mon_mon, _ -> -1 | _, Lte_mon_mon -> 1
| Lte_dat_dat, _ -> -1 | _, Lte_dat_dat -> 1
| Lte_dur_dur, _ -> -1 | _, Lte_dur_dur -> 1
| Gt, _ -> -1 | _, Gt -> 1
| Gt_int_int, _ -> -1 | _, Gt_int_int -> 1
| Gt_rat_rat, _ -> -1 | _, Gt_rat_rat -> 1
| Gt_mon_mon, _ -> -1 | _, Gt_mon_mon -> 1
| Gt_dat_dat, _ -> -1 | _, Gt_dat_dat -> 1
| Gt_dur_dur, _ -> -1 | _, Gt_dur_dur -> 1
| Gte, _ -> -1 | _, Gte -> 1
| Gte_int_int, _ -> -1 | _, Gte_int_int -> 1
| Gte_rat_rat, _ -> -1 | _, Gte_rat_rat -> 1
| Gte_mon_mon, _ -> -1 | _, Gte_mon_mon -> 1
| Gte_dat_dat, _ -> -1 | _, Gte_dat_dat -> 1
| Gte_dur_dur, _ -> -1 | _, Gte_dur_dur -> 1
| Eq_int_int, _ -> -1 | _, Eq_int_int -> 1
| Eq_rat_rat, _ -> -1 | _, Eq_rat_rat -> 1
| Eq_mon_mon, _ -> -1 | _, Eq_mon_mon -> 1
| Eq_dat_dat, _ -> -1 | _, Eq_dat_dat -> 1
| Eq_dur_dur, _ -> -1 | _, Eq_dur_dur -> 1
| Fold, _ | _, Fold -> .
let equal (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) = compare t1 t2 = 0
(* Classification of operators *)
let kind_dispatch :
type a b k.
polymorphic:((_, polymorphic) t -> b) ->
monomorphic:((_, monomorphic) t -> b) ->
?overloaded:((_, overloaded) t -> b) ->
?resolved:((_, resolved) t -> b) ->
(a, k) t ->
b =
fun ~polymorphic ~monomorphic ?(overloaded = fun _ -> assert false)
?(resolved = fun _ -> assert false) op ->
match op with
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
| Or | Xor ) as op ->
monomorphic op
| (Log _ | Length | Eq | Map | Concat | Filter | Reduce | Fold) as op ->
polymorphic op
| ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
| Gte ) as op ->
overloaded op
| ( Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int | ToRat_mon
| ToMoney_rat | Round_rat | Round_mon | Add_int_int | Add_rat_rat
| Add_mon_mon | Add_dat_dur | Add_dur_dur | Sub_int_int | Sub_rat_rat
| Sub_mon_mon | Sub_dat_dat | Sub_dat_dur | Sub_dur_dur | Mult_int_int
| Mult_rat_rat | Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat | Lte_mon_mon
| Lte_dat_dat | Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon
| Gt_dat_dat | Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon
| Gte_dat_dat | Gte_dur_dur | Eq_int_int | Eq_rat_rat | Eq_mon_mon
| Eq_dat_dat | Eq_dur_dur ) as op ->
resolved op
(* Glorified identity... allowed operators are the same in scopelang, dcalc,
lcalc *)
let translate :
type k.
([< scopelang | dcalc | lcalc ], k) t ->
([< scopelang | dcalc | lcalc ], k) t =
fun op ->
match op with
| Length -> Length
| Log (i, l) -> Log (i, l)
| Eq -> Eq
| Map -> Map
| Concat -> Concat
| Filter -> Filter
| Reduce -> Reduce
| Fold -> Fold
| Not -> Not
| GetDay -> GetDay
| GetMonth -> GetMonth
| GetYear -> GetYear
| FirstDayOfMonth -> FirstDayOfMonth
| LastDayOfMonth -> LastDayOfMonth
| And -> And
| Or -> Or
| Xor -> Xor
| Minus_int -> Minus_int
| Minus_rat -> Minus_rat
| Minus_mon -> Minus_mon
| Minus_dur -> Minus_dur
| ToRat_int -> ToRat_int
| ToRat_mon -> ToRat_mon
| ToMoney_rat -> ToMoney_rat
| Round_rat -> Round_rat
| Round_mon -> Round_mon
| Add_int_int -> Add_int_int
| Add_rat_rat -> Add_rat_rat
| Add_mon_mon -> Add_mon_mon
| Add_dat_dur -> Add_dat_dur
| Add_dur_dur -> Add_dur_dur
| Sub_int_int -> Sub_int_int
| Sub_rat_rat -> Sub_rat_rat
| Sub_mon_mon -> Sub_mon_mon
| Sub_dat_dat -> Sub_dat_dat
| Sub_dat_dur -> Sub_dat_dur
| Sub_dur_dur -> Sub_dur_dur
| Mult_int_int -> Mult_int_int
| Mult_rat_rat -> Mult_rat_rat
| Mult_mon_rat -> Mult_mon_rat
| Mult_dur_int -> Mult_dur_int
| Div_int_int -> Div_int_int
| Div_rat_rat -> Div_rat_rat
| Div_mon_mon -> Div_mon_mon
| Div_mon_rat -> Div_mon_rat
| Lt_int_int -> Lt_int_int
| Lt_rat_rat -> Lt_rat_rat
| Lt_mon_mon -> Lt_mon_mon
| Lt_dat_dat -> Lt_dat_dat
| Lt_dur_dur -> Lt_dur_dur
| Lte_int_int -> Lte_int_int
| Lte_rat_rat -> Lte_rat_rat
| Lte_mon_mon -> Lte_mon_mon
| Lte_dat_dat -> Lte_dat_dat
| Lte_dur_dur -> Lte_dur_dur
| Gt_int_int -> Gt_int_int
| Gt_rat_rat -> Gt_rat_rat
| Gt_mon_mon -> Gt_mon_mon
| Gt_dat_dat -> Gt_dat_dat
| Gt_dur_dur -> Gt_dur_dur
| Gte_int_int -> Gte_int_int
| Gte_rat_rat -> Gte_rat_rat
| Gte_mon_mon -> Gte_mon_mon
| Gte_dat_dat -> Gte_dat_dat
| Gte_dur_dur -> Gte_dur_dur
| Eq_int_int -> Eq_int_int
| Eq_rat_rat -> Eq_rat_rat
| Eq_mon_mon -> Eq_mon_mon
| Eq_dat_dat -> Eq_dat_dat
| Eq_dur_dur -> Eq_dur_dur
let monomorphic_type (op, pos) =
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
match op with
| Not -> TBool @-> TBool
| GetDay -> TDate @-> TInt
| GetMonth -> TDate @-> TInt
| GetYear -> TDate @-> TInt
| FirstDayOfMonth -> TDate @-> TDate
| LastDayOfMonth -> TDate @-> TDate
| And -> TBool @- TBool @-> TBool
| Or -> TBool @- TBool @-> TBool
| Xor -> TBool @- TBool @-> TBool
(** Rules for overloads definitions:
- the concrete operator, including its return type, is uniquely determined
by the type of the operands
- no resolved version of an operator should be the redefinition of another
one with an added conversion. For example, [int + rat -> rat] is not
acceptable (that would amount to implicit casts).
These two points can be generalised for binary operators as: when
considering an operator with type ['a -> 'b -> 'c], for any given two among
['a], ['b] and ['c], there should be a unique solution for the third. *)
let resolved_type (op, pos) =
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
match op with
| Minus_int -> TInt @-> TInt
| Minus_rat -> TRat @-> TRat
| Minus_mon -> TMoney @-> TMoney
| Minus_dur -> TDuration @-> TDuration
| ToRat_int -> TInt @-> TRat
| ToRat_mon -> TMoney @-> TRat
| ToMoney_rat -> TRat @-> TMoney
| Round_rat -> TRat @-> TRat
| Round_mon -> TMoney @-> TMoney
| Add_int_int -> TInt @- TInt @-> TInt
| Add_rat_rat -> TRat @- TRat @-> TRat
| Add_mon_mon -> TMoney @- TMoney @-> TMoney
| Add_dat_dur -> TDate @- TDuration @-> TDate
| Add_dur_dur -> TDuration @- TDuration @-> TDuration
| Sub_int_int -> TInt @- TInt @-> TInt
| Sub_rat_rat -> TRat @- TRat @-> TRat
| Sub_mon_mon -> TMoney @- TMoney @-> TMoney
| Sub_dat_dat -> TDate @- TDate @-> TDuration
| Sub_dat_dur -> TDate @- TDuration @-> TDuration
| Sub_dur_dur -> TDuration @- TDuration @-> TDuration
| Mult_int_int -> TInt @- TInt @-> TInt
| Mult_rat_rat -> TRat @- TRat @-> TRat
| Mult_mon_rat -> TMoney @- TRat @-> TMoney
| Mult_dur_int -> TDuration @- TInt @-> TDuration
| Div_int_int -> TInt @- TInt @-> TRat
| Div_rat_rat -> TRat @- TRat @-> TRat
| Div_mon_mon -> TMoney @- TMoney @-> TRat
| Div_mon_rat -> TMoney @- TRat @-> TMoney
| Lt_int_int -> TInt @- TInt @-> TBool
| Lt_rat_rat -> TRat @- TRat @-> TBool
| Lt_mon_mon -> TMoney @- TMoney @-> TBool
| Lt_dat_dat -> TDate @- TDate @-> TBool
| Lt_dur_dur -> TDuration @- TDuration @-> TBool
| Lte_int_int -> TInt @- TInt @-> TBool
| Lte_rat_rat -> TRat @- TRat @-> TBool
| Lte_mon_mon -> TMoney @- TMoney @-> TBool
| Lte_dat_dat -> TDate @- TDate @-> TBool
| Lte_dur_dur -> TDuration @- TDuration @-> TBool
| Gt_int_int -> TInt @- TInt @-> TBool
| Gt_rat_rat -> TRat @- TRat @-> TBool
| Gt_mon_mon -> TMoney @- TMoney @-> TBool
| Gt_dat_dat -> TDate @- TDate @-> TBool
| Gt_dur_dur -> TDuration @- TDuration @-> TBool
| Gte_int_int -> TInt @- TInt @-> TBool
| Gte_rat_rat -> TRat @- TRat @-> TBool
| Gte_mon_mon -> TMoney @- TMoney @-> TBool
| Gte_dat_dat -> TDate @- TDate @-> TBool
| Gte_dur_dur -> TDuration @- TDuration @-> TBool
| Eq_int_int -> TInt @- TInt @-> TBool
| Eq_rat_rat -> TRat @- TRat @-> TBool
| Eq_mon_mon -> TMoney @- TMoney @-> TBool
| Eq_dat_dat -> TDate @- TDate @-> TBool
| Eq_dur_dur -> TDuration @- TDuration @-> TBool
let resolve_overload_aux (op : ('a, overloaded) t) (operands : typ_lit list) :
('b, resolved) t * [ `Straight | `Reversed ] =
match op, operands with
| Minus, [TInt] -> Minus_int, `Straight
| Minus, [TRat] -> Minus_rat, `Straight
| Minus, [TMoney] -> Minus_mon, `Straight
| Minus, [TDuration] -> Minus_dur, `Straight
| ToRat, [TInt] -> ToRat_int, `Straight
| ToRat, [TMoney] -> ToRat_mon, `Straight
| ToMoney, [TRat] -> ToMoney_rat, `Straight
| Round, [TRat] -> Round_rat, `Straight
| Round, [TMoney] -> Round_mon, `Straight
| Add, [TInt; TInt] -> Add_int_int, `Straight
| Add, [TRat; TRat] -> Add_rat_rat, `Straight
| Add, [TMoney; TMoney] -> Add_mon_mon, `Straight
| Add, [TDuration; TDuration] -> Add_dur_dur, `Straight
| Add, [TDate; TDuration] -> Add_dat_dur, `Straight
| Add, [TDuration; TDate] -> Add_dat_dur, `Reversed
| Sub, [TInt; TInt] -> Sub_int_int, `Straight
| Sub, [TRat; TRat] -> Sub_rat_rat, `Straight
| Sub, [TMoney; TMoney] -> Sub_mon_mon, `Straight
| Sub, [TDuration; TDuration] -> Sub_dur_dur, `Straight
| Sub, [TDate; TDate] -> Sub_dat_dat, `Straight
| Sub, [TDate; TDuration] -> Sub_dat_dur, `Straight
| Mult, [TInt; TInt] -> Mult_int_int, `Straight
| Mult, [TRat; TRat] -> Mult_rat_rat, `Straight
| Mult, [TMoney; TRat] -> Mult_mon_rat, `Straight
| Mult, [TRat; TMoney] -> Mult_mon_rat, `Reversed
| Mult, [TDuration; TInt] -> Mult_dur_int, `Straight
| Mult, [TInt; TDuration] -> Mult_dur_int, `Reversed
| Div, [TInt; TInt] -> Div_int_int, `Straight
| Div, [TRat; TRat] -> Div_rat_rat, `Straight
| Div, [TMoney; TMoney] -> Div_mon_mon, `Straight
| Div, [TMoney; TRat] -> Div_mon_rat, `Straight
| Lt, [TInt; TInt] -> Lt_int_int, `Straight
| Lt, [TRat; TRat] -> Lt_rat_rat, `Straight
| Lt, [TMoney; TMoney] -> Lt_mon_mon, `Straight
| Lt, [TDuration; TDuration] -> Lt_dur_dur, `Straight
| Lt, [TDate; TDate] -> Lt_dat_dat, `Straight
| Lte, [TInt; TInt] -> Lte_int_int, `Straight
| Lte, [TRat; TRat] -> Lte_rat_rat, `Straight
| Lte, [TMoney; TMoney] -> Lte_mon_mon, `Straight
| Lte, [TDuration; TDuration] -> Lte_dur_dur, `Straight
| Lte, [TDate; TDate] -> Lte_dat_dat, `Straight
| Gt, [TInt; TInt] -> Gt_int_int, `Straight
| Gt, [TRat; TRat] -> Gt_rat_rat, `Straight
| Gt, [TMoney; TMoney] -> Gt_mon_mon, `Straight
| Gt, [TDuration; TDuration] -> Gt_dur_dur, `Straight
| Gt, [TDate; TDate] -> Gt_dat_dat, `Straight
| Gte, [TInt; TInt] -> Gte_int_int, `Straight
| Gte, [TRat; TRat] -> Gte_rat_rat, `Straight
| Gte, [TMoney; TMoney] -> Gte_mon_mon, `Straight
| Gte, [TDuration; TDuration] -> Gte_dur_dur, `Straight
| Gte, [TDate; TDate] -> Gte_dat_dat, `Straight
| ( ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
| Gte ),
_ ) ->
raise Not_found
let resolve_overload
ctx
(op : ('a, overloaded) t Marked.pos)
(operands : typ list) : ('b, resolved) t * [ `Straight | `Reversed ] =
try
let operands =
List.map
(fun t ->
match Marked.unmark t with TLit tl -> tl | _ -> raise Not_found)
operands
in
resolve_overload_aux (Marked.unmark op) operands
with Not_found ->
Errors.raise_multispanned_error
((None, Marked.get_mark op)
:: List.map
(fun ty ->
( Some
(Format.asprintf "Type %a coming from expression:"
(Print.typ ctx) ty),
Marked.get_mark ty ))
operands)
"I don't know how to apply operator %a on types %a" Print.operator
(Marked.unmark op)
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf " and@ ")
(Print.typ ctx))
operands
let overload_type ctx (op : ('a, overloaded) t Marked.pos) (operands : typ list)
: typ =
let rop = fst (resolve_overload ctx op operands) in
resolved_type (Marked.same_mark_as rop op)

View File

@ -0,0 +1,85 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** {1 Catala operator utilities} *)
(** Resolving operators from the surface syntax proceeds in three steps:
- During desugaring, the operators may remain untyped (with [TAny]) or, if
they have an explicit type suffix (e.g. the [$] for "money" in [+$]),
their operands types are already explicited in the [EOp] expression node.
- {!modules:Shared_ast.Typing} will then enforce these constraints in
addition to the known built-in type for each operator (e.g.
[Eq: 'a -> 'a -> 'a] isn't encoded in the first-order AST types).
- Finally, during {!modules:Scopelang.From_desugared}, these types are
leveraged to resolve the overloaded operators to their concrete,
monomorphic counterparts
*)
open Catala_utils
open Definitions
include module type of Definitions.Op
val equal : ('a1, 'k1) t -> ('a2, 'k2) t -> bool
val compare : ('a1, 'k1) t -> ('a2, 'k2) t -> int
val name : ('a, 'k) t -> string
(** Returns the operator name as a valid ident starting with a lowercase
character. This is different from Print.operator which returns operator
symbols, e.g. [+$]. *)
val kind_dispatch :
polymorphic:((_ any, polymorphic) t -> 'b) ->
monomorphic:((_ any, monomorphic) t -> 'b) ->
?overloaded:((desugared, overloaded) t -> 'b) ->
?resolved:(([< scopelang | dcalc | lcalc ], resolved) t -> 'b) ->
('a, 'k) t ->
'b
(** Calls one of the supplied functions depending on the kind of the operator *)
val translate :
([< scopelang | dcalc | lcalc ], 'k) t ->
([< scopelang | dcalc | lcalc ], 'k) t
(** An identity function that allows translating an operator between different
passes that don't change operator types *)
(** {2 Getting the types of operators} *)
val monomorphic_type : ('a any, monomorphic) t Marked.pos -> typ
val resolved_type :
([< scopelang | dcalc | lcalc ], resolved) t Marked.pos -> typ
val overload_type :
decl_ctx -> (desugared, overloaded) t Marked.pos -> typ list -> typ
(** The type for typing overloads is different since the types of the operands
are required in advance.
@raise a detailed user error if no matching operator can be found *)
(** Polymorphic operators are typed directly within [Typing], since their types
may contain type variables that can't be expressed outside of it*)
(** {2 Overload handling} *)
val resolve_overload :
decl_ctx ->
(desugared, overloaded) t Marked.pos ->
typ list ->
([< scopelang | dcalc | lcalc ], resolved) t * [ `Straight | `Reversed ]
(** Some overloads are sugar for an operation with reversed operands, e.g.
[TRat * TMoney] is using [mult_mon_rat]. [`Reversed] is returned to signify
this case. *)

View File

@ -14,8 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open String_common
open Definitions open Definitions
let typ_needs_parens (ty : typ) : bool = let typ_needs_parens (ty : typ) : bool =
@ -26,27 +25,28 @@ let uid_list (fmt : Format.formatter) (infos : Uid.MarkedString.info list) :
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.pp_print_char fmt '.') ~pp_sep:(fun fmt () -> Format.pp_print_char fmt '.')
(fun fmt info -> (fun fmt info ->
Utils.Cli.format_with_style Cli.format_with_style
(if begins_with_uppercase (Marked.unmark info) then [ANSITerminal.red] (if String.begins_with_uppercase (Marked.unmark info) then
[ANSITerminal.red]
else []) else [])
fmt fmt
(Utils.Uid.MarkedString.to_string info)) (Uid.MarkedString.to_string info))
fmt infos fmt infos
let keyword (fmt : Format.formatter) (s : string) : unit = let keyword (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.red] fmt s Cli.format_with_style [ANSITerminal.red] fmt s
let base_type (fmt : Format.formatter) (s : string) : unit = let base_type (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.yellow] fmt s Cli.format_with_style [ANSITerminal.yellow] fmt s
let punctuation (fmt : Format.formatter) (s : string) : unit = let punctuation (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.cyan] fmt s Cli.format_with_style [ANSITerminal.cyan] fmt s
let operator (fmt : Format.formatter) (s : string) : unit = let op_style (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.green] fmt s Cli.format_with_style [ANSITerminal.green] fmt s
let lit_style (fmt : Format.formatter) (s : string) : unit = let lit_style (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.yellow] fmt s Cli.format_with_style [ANSITerminal.yellow] fmt s
let tlit (fmt : Format.formatter) (l : typ_lit) : unit = let tlit (fmt : Format.formatter) (l : typ_lit) : unit =
base_type fmt base_type fmt
@ -68,7 +68,7 @@ let location (type a) (fmt : Format.formatter) (l : a glocation) : unit =
ScopeVar.format_t (Marked.unmark subvar) ScopeVar.format_t (Marked.unmark subvar)
let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit =
Utils.Cli.format_with_style [ANSITerminal.magenta] fmt Cli.format_with_style [ANSITerminal.magenta] fmt
(Format.asprintf "%a" EnumConstructor.format_t c) (Format.asprintf "%a" EnumConstructor.format_t c)
let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit = let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
@ -81,7 +81,7 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
| TTuple ts -> | TTuple ts ->
Format.fprintf fmt "@[<hov 2>(%a)@]" Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " operator "*") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " op_style "*")
typ) typ)
ts ts
| TStruct s -> ( | TStruct s -> (
@ -94,9 +94,9 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";") ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
(fun fmt (field, mty) -> (fun fmt (field, mty) ->
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
StructFieldName.format_t field punctuation "\"" punctuation ":" StructField.format_t field punctuation "\"" punctuation ":" typ
typ mty)) mty))
(StructMap.find s ctx.ctx_structs) (StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs))
punctuation "}") punctuation "}")
| TEnum e -> ( | TEnum e -> (
match ctx with match ctx with
@ -109,11 +109,11 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
(fun fmt (case, mty) -> (fun fmt (case, mty) ->
Format.fprintf fmt "%a%a@ %a" enum_constructor case punctuation ":" Format.fprintf fmt "%a%a@ %a" enum_constructor case punctuation ":"
typ mty)) typ mty))
(EnumMap.find e ctx.ctx_enums) (EnumConstructor.Map.bindings (EnumName.Map.find e ctx.ctx_enums))
punctuation "]") punctuation "]")
| TOption t -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" typ t | TOption t -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" typ t
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 operator "" Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 op_style ""
typ t2 typ t2
| TArray t1 -> | TArray t1 ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "collection" typ t1 Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "collection" typ t1
@ -127,9 +127,9 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
| LUnit -> lit_style fmt "()" | LUnit -> lit_style fmt "()"
| LRat i -> | LRat i ->
lit_style fmt lit_style fmt
(Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) (Runtime.decimal_to_string ~max_prec_digits:!Cli.max_prec_digits i)
| LMoney e -> ( | LMoney e -> (
match !Utils.Cli.locale_lang with match !Cli.locale_lang with
| En -> lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e)) | En -> lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e))
| Fr -> lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e)) | Fr -> lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e))
| Pl -> lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e)) | Pl -> lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e))
@ -137,72 +137,112 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
| LDate d -> lit_style fmt (Runtime.date_to_string d) | LDate d -> lit_style fmt (Runtime.date_to_string d)
| LDuration d -> lit_style fmt (Runtime.duration_to_string d) | LDuration d -> lit_style fmt (Runtime.duration_to_string d)
let op_kind (fmt : Format.formatter) (k : op_kind) =
Format.fprintf fmt "%s"
(match k with
| KInt -> ""
| KRat -> "."
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let binop (fmt : Format.formatter) (op : binop) : unit =
operator fmt
(match op with
| Add k -> Format.asprintf "+%a" op_kind k
| Sub k -> Format.asprintf "-%a" op_kind k
| Mult k -> Format.asprintf "*%a" op_kind k
| Div k -> Format.asprintf "/%a" op_kind k
| And -> "&&"
| Or -> "||"
| Xor -> "xor"
| Eq -> "="
| Neq -> "!="
| Lt k -> Format.asprintf "%s%a" "<" op_kind k
| Lte k -> Format.asprintf "%s%a" "<=" op_kind k
| Gt k -> Format.asprintf "%s%a" ">" op_kind k
| Gte k -> Format.asprintf "%s%a" ">=" op_kind k
| Concat -> "++"
| Map -> "map"
| Filter -> "filter")
let ternop (fmt : Format.formatter) (op : ternop) : unit =
match op with Fold -> keyword fmt "fold"
let log_entry (fmt : Format.formatter) (entry : log_entry) : unit = let log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
Format.fprintf fmt "@<2>%a" Format.fprintf fmt "@<2>%a"
(fun fmt -> function (fun fmt -> function
| VarDef _ -> Utils.Cli.format_with_style [ANSITerminal.blue] fmt "" | VarDef _ -> Cli.format_with_style [ANSITerminal.blue] fmt ""
| BeginCall -> Utils.Cli.format_with_style [ANSITerminal.yellow] fmt "" | BeginCall -> Cli.format_with_style [ANSITerminal.yellow] fmt ""
| EndCall -> Utils.Cli.format_with_style [ANSITerminal.yellow] fmt "" | EndCall -> Cli.format_with_style [ANSITerminal.yellow] fmt ""
| PosRecordIfTrueBool -> | PosRecordIfTrueBool ->
Utils.Cli.format_with_style [ANSITerminal.green] fmt "") Cli.format_with_style [ANSITerminal.green] fmt "")
entry entry
let unop (fmt : Format.formatter) (op : unop) : unit = let operator_to_string : type a k. (a, k) Op.t -> string = function
| Not -> "~"
| Length -> "length"
| GetDay -> "get_day"
| GetMonth -> "get_month"
| GetYear -> "get_year"
| FirstDayOfMonth -> "first_day_of_month"
| LastDayOfMonth -> "last_day_of_month"
| ToRat -> "to_rat"
| ToRat_int -> "to_rat_int"
| ToRat_mon -> "to_rat_mon"
| ToMoney -> "to_mon"
| ToMoney_rat -> "to_mon_rat"
| Round -> "round"
| Round_rat -> "round_rat"
| Round_mon -> "round_mon"
| Log _ -> "Log"
| Minus -> "-"
| Minus_int -> "-!"
| Minus_rat -> "-."
| Minus_mon -> "-$"
| Minus_dur -> "-^"
| And -> "&&"
| Or -> "||"
| Xor -> "xor"
| Eq -> "="
| Map -> "map"
| Reduce -> "reduce"
| Concat -> "++"
| Filter -> "filter"
| Add -> "+"
| Add_int_int -> "+!"
| Add_rat_rat -> "+."
| Add_mon_mon -> "+$"
| Add_dat_dur -> "+@"
| Add_dur_dur -> "+^"
| Sub -> "-"
| Sub_int_int -> "-!"
| Sub_rat_rat -> "-."
| Sub_mon_mon -> "-$"
| Sub_dat_dat -> "-@"
| Sub_dat_dur -> "-@^"
| Sub_dur_dur -> "-^"
| Mult -> "*"
| Mult_int_int -> "*!"
| Mult_rat_rat -> "*."
| Mult_mon_rat -> "*$"
| Mult_dur_int -> "*^"
| Div -> "/"
| Div_int_int -> "/!"
| Div_rat_rat -> "/."
| Div_mon_mon -> "/$"
| Div_mon_rat -> "/$."
| Lt -> "<"
| Lt_int_int -> "<!"
| Lt_rat_rat -> "<."
| Lt_mon_mon -> "<$"
| Lt_dur_dur -> "<^"
| Lt_dat_dat -> "<@"
| Lte -> "<="
| Lte_int_int -> "<=!"
| Lte_rat_rat -> "<=."
| Lte_mon_mon -> "<=$"
| Lte_dur_dur -> "<=^"
| Lte_dat_dat -> "<=@"
| Gt -> ">"
| Gt_int_int -> ">!"
| Gt_rat_rat -> ">."
| Gt_mon_mon -> ">$"
| Gt_dur_dur -> ">^"
| Gt_dat_dat -> ">@"
| Gte -> ">="
| Gte_int_int -> ">=!"
| Gte_rat_rat -> ">=."
| Gte_mon_mon -> ">=$"
| Gte_dur_dur -> ">=^"
| Gte_dat_dat -> ">=@"
| Eq_int_int -> "=!"
| Eq_rat_rat -> "=."
| Eq_mon_mon -> "=$"
| Eq_dur_dur -> "=^"
| Eq_dat_dat -> "=@"
| Fold -> "fold"
let operator (type k) (fmt : Format.formatter) (op : ('a, k) operator) : unit =
match op with match op with
| Minus _ -> Format.pp_print_string fmt "-"
| Not -> Format.pp_print_string fmt "~"
| Log (entry, infos) -> | Log (entry, infos) ->
Format.fprintf fmt "log@[<hov 2>[%a|%a]@]" log_entry entry Format.fprintf fmt "%a@[<hov 2>[%a|%a]@]" op_style "log" log_entry entry
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ".") ~pp_sep:(fun fmt () -> Format.fprintf fmt ".")
(fun fmt info -> Utils.Uid.MarkedString.format_info fmt info)) (fun fmt info -> Uid.MarkedString.format fmt info))
infos infos
| Length -> Format.pp_print_string fmt "length" | op -> Format.fprintf fmt "%a" op_style (operator_to_string op)
| IntToRat -> Format.pp_print_string fmt "int_to_rat"
| MoneyToRat -> Format.pp_print_string fmt "money_to_rat"
| RatToMoney -> Format.pp_print_string fmt "rat_to_money"
| GetDay -> Format.pp_print_string fmt "get_day"
| GetMonth -> Format.pp_print_string fmt "get_month"
| GetYear -> Format.pp_print_string fmt "get_year"
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
| RoundMoney -> Format.pp_print_string fmt "round_money"
| RoundDecimal -> Format.pp_print_string fmt "round_decimal"
let except (fmt : Format.formatter) (exn : except) : unit = let except (fmt : Format.formatter) (exn : except) : unit =
operator fmt op_style fmt
(match exn with (match exn with
| EmptyError -> "EmptyError" | EmptyError -> "EmptyError"
| ConflictError -> "ConflictError" | ConflictError -> "ConflictError"
@ -215,7 +255,7 @@ let var_debug fmt v =
let var fmt v = Format.pp_print_string fmt (Bindlib.name_of v) let var fmt v = Format.pp_print_string fmt (Bindlib.name_of v)
let needs_parens (type a) (e : (a, _) gexpr) : bool = let needs_parens (type a) (e : (a, _) gexpr) : bool =
match Marked.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false match Marked.unmark e with EAbs _ | EStruct _ -> true | _ -> false
let rec expr_aux : let rec expr_aux :
type a. type a.
@ -228,6 +268,7 @@ let rec expr_aux :
fun ?(debug = false) ctx bnd_ctx fmt e -> fun ?(debug = false) ctx bnd_ctx fmt e ->
let exprb bnd_ctx e = expr_aux ~debug ctx bnd_ctx e in let exprb bnd_ctx e = expr_aux ~debug ctx bnd_ctx e in
let expr e = exprb bnd_ctx e in let expr e = exprb bnd_ctx e in
let var = if debug then var_debug else var in
let with_parens fmt e = let with_parens fmt e =
if needs_parens e then ( if needs_parens e then (
punctuation fmt "("; punctuation fmt "(";
@ -236,79 +277,28 @@ let rec expr_aux :
else expr fmt e else expr fmt e
in in
match Marked.unmark e with match Marked.unmark e with
| EVar v -> if debug then var_debug fmt v else var fmt v | EVar v -> var fmt v
| ETuple (es, None) -> | ETuple es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "(" Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "("
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> expr fmt e)) (fun fmt e -> expr fmt e))
es punctuation ")" es punctuation ")"
| ETuple (es, Some s) -> (
match ctx with
| None -> expr fmt (Marked.same_mark_as (ETuple (es, None)) e)
| Some ctx ->
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]" StructName.format_t
s punctuation "{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
(fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
StructFieldName.format_t struct_field punctuation "\""
punctuation "=" expr e))
(List.combine es (List.map fst (StructMap.find s ctx.ctx_structs)))
punctuation "}")
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "["
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> expr fmt e)) (fun fmt e -> expr fmt e))
es punctuation "]" es punctuation "]"
| ETupleAccess (e1, n, s, _ts) -> ( | ETupleAccess { e; index; _ } ->
match s, ctx with expr fmt e;
| None, _ | _, None -> punctuation fmt ".";
expr fmt e1; Format.pp_print_int fmt index
punctuation fmt ".";
Format.pp_print_int fmt n
| Some s, Some ctx ->
expr fmt e1;
operator fmt ".";
punctuation fmt "\"";
StructFieldName.format_t fmt
(fst (List.nth (StructMap.find s ctx.ctx_structs) n));
punctuation fmt "\"")
| EInj (e, n, en, _ts) -> (
match ctx with
| None ->
Format.fprintf fmt "@[<hov 2>%a[%d]@ %a@]" EnumName.format_t en n expr e
| Some ctx ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" enum_constructor
(fst (List.nth (EnumMap.find en ctx.ctx_enums) n))
expr e)
| EMatch (e, es, e_name) -> (
match ctx with
| None ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
expr e keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, i) ->
Format.fprintf fmt "@[<hov 2>%a %a[%d]%a@ %a@]" punctuation "|"
EnumName.format_t e_name i punctuation ":" expr e))
(List.mapi (fun i e -> e, i) es)
| Some ctx ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
expr e keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, c) ->
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" punctuation "|"
enum_constructor c punctuation ":" expr e))
(List.combine es (List.map fst (EnumMap.find e_name ctx.ctx_enums))))
| ELit l -> lit fmt l | ELit l -> lit fmt l
| EApp ((EAbs (binder, taus), _), args) -> | EApp { f = EAbs { binder; tys }, _; args } ->
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
let expr = exprb bnd_ctx in let expr = exprb bnd_ctx in
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) taus in let xs_tau = List.mapi (fun i tau -> xs.(i), tau) tys in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
Format.fprintf fmt "%a%a" Format.fprintf fmt "%a%a"
(Format.pp_print_list (Format.pp_print_list
@ -318,10 +308,10 @@ let rec expr_aux :
"let" var x punctuation ":" (typ ctx) tau punctuation "=" expr arg "let" var x punctuation ":" (typ ctx) tau punctuation "=" expr arg
keyword "in")) keyword "in"))
xs_tau_arg expr body xs_tau_arg expr body
| EAbs (binder, taus) -> | EAbs { binder; tys } ->
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
let expr = exprb bnd_ctx in let expr = exprb bnd_ctx in
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) taus in let xs_tau = List.mapi (fun i tau -> xs.(i), tau) tys in
Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" punctuation "λ" Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" punctuation "λ"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
@ -329,29 +319,28 @@ let rec expr_aux :
Format.fprintf fmt "%a%a%a %a%a" punctuation "(" var x punctuation Format.fprintf fmt "%a%a%a %a%a" punctuation "(" var x punctuation
":" (typ ctx) tau punctuation ")")) ":" (typ ctx) tau punctuation ")"))
xs_tau punctuation "" expr body xs_tau punctuation "" expr body
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) -> | EApp { f = EOp { op = (Map | Filter) as op; _ }, _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" binop op with_parens arg1 Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" operator op with_parens arg1
with_parens arg2 with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) -> | EApp { f = EOp { op; _ }, _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 binop op Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 operator op
with_parens arg2 with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> expr fmt arg1 | EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } when not debug ->
| EApp ((EOp (Unop op), _), [arg1]) -> expr fmt arg1
Format.fprintf fmt "@[<hov 2>%a@ %a@]" unop op with_parens arg1 | EApp { f = EOp { op; _ }, _; args = [arg1] } ->
| EApp (f, args) -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" operator op with_parens arg1
| EApp { f; args } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" expr f Format.fprintf fmt "@[<hov 2>%a@ %a@]" expr f
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
with_parens) with_parens)
args args
| EIfThenElse (e1, e2, e3) -> | EIfThenElse { cond; etrue; efalse } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr e1 Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr
keyword "then" expr e2 keyword "else" expr e3 cond keyword "then" expr etrue keyword "else" expr efalse
| EOp (Ternop op) -> ternop fmt op | EOp { op; _ } -> operator fmt op
| EOp (Binop op) -> binop fmt op | EDefault { excepts; just; cons } ->
| EOp (Unop op) -> unop fmt op if List.length excepts = 0 then
| EDefault (exceptions, just, cons) ->
if List.length exceptions = 0 then
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" punctuation "" expr just Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" punctuation "" expr just
punctuation "" expr cons punctuation "" punctuation "" expr cons punctuation ""
else else
@ -359,45 +348,48 @@ let rec expr_aux :
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ",") ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ",")
expr) expr)
exceptions punctuation "|" expr just punctuation "" expr cons excepts punctuation "|" expr just punctuation "" expr cons punctuation
punctuation "" ""
| ErrorOnEmpty e' -> | EErrorOnEmpty e' ->
Format.fprintf fmt "%a@ %a" operator "error_empty" with_parens e' Format.fprintf fmt "%a@ %a" op_style "error_empty" with_parens e'
| EAssert e' -> | EAssert e' ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" keyword "assert" punctuation "(" Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" keyword "assert" punctuation "("
expr e' punctuation ")" expr e' punctuation ")"
| ECatch (e1, exn, e2) -> | ECatch { body; exn; handler } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" keyword "try" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" keyword "try"
with_parens e1 keyword "with" except exn with_parens e2 with_parens body keyword "with" except exn with_parens handler
| ERaise exn -> | ERaise exn ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" keyword "raise" except exn Format.fprintf fmt "@[<hov 2>%a@ %a@]" keyword "raise" except exn
| ELocation loc -> location fmt loc | ELocation loc -> location fmt loc
| EStruct (name, fields) -> | EDStructAccess { e; field; _ } ->
Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\""
IdentName.format_t field punctuation "\""
| EStruct { name; fields } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
punctuation "{" punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";") ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
(fun fmt (field_name, field_expr) -> (fun fmt (field_name, field_expr) ->
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
StructFieldName.format_t field_name punctuation "\"" punctuation StructField.format_t field_name punctuation "\"" punctuation "="
"=" expr field_expr)) expr field_expr))
(StructFieldMap.bindings fields) (StructField.Map.bindings fields)
punctuation "}" punctuation "}"
| EStructAccess (e1, field, _) -> | EStructAccess { e; field; _ } ->
Format.fprintf fmt "%a%a%a%a%a" expr e1 punctuation "." punctuation "\"" Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\""
StructFieldName.format_t field punctuation "\"" StructField.format_t field punctuation "\""
| EEnumInj (e1, cons, _) -> | EInj { e; cons; _ } ->
Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e1 Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e
| EMatchS (e1, _, cases) -> | EMatch { e; cases; _ } ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match" Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
expr e1 keyword "with" expr e keyword "with"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (cons_name, case_expr) -> (fun fmt (cons_name, case_expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" punctuation "|" Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" punctuation "|"
enum_constructor cons_name punctuation "" expr case_expr)) enum_constructor cons_name punctuation "" expr case_expr))
(EnumConstructorMap.bindings cases) (EnumConstructor.Map.bindings cases)
| EScopeCall (scope, fields) -> | EScopeCall { scope; args } ->
Format.pp_open_hovbox fmt 2; Format.pp_open_hovbox fmt 2;
ScopeName.format_t fmt scope; ScopeName.format_t fmt scope;
Format.pp_print_space fmt (); Format.pp_print_space fmt ();
@ -411,7 +403,7 @@ let rec expr_aux :
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" ScopeVar.format_t Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" ScopeVar.format_t
field_name punctuation "\"" punctuation "=" expr field_expr) field_name punctuation "\"" punctuation "=" expr field_expr)
fmt fmt
(ScopeVarMap.bindings fields); (ScopeVar.Map.bindings args);
Format.pp_close_box fmt (); Format.pp_close_box fmt ();
punctuation fmt "}"; punctuation fmt "}";
Format.pp_close_box fmt () Format.pp_close_box fmt ()

View File

@ -16,7 +16,7 @@
(** Printing functions for the default calculus AST *) (** Printing functions for the default calculus AST *)
open Utils open Catala_utils
open Definitions open Definitions
(** {1 Common syntax highlighting helpers}*) (** {1 Common syntax highlighting helpers}*)
@ -24,7 +24,7 @@ open Definitions
val base_type : Format.formatter -> string -> unit val base_type : Format.formatter -> string -> unit
val keyword : Format.formatter -> string -> unit val keyword : Format.formatter -> string -> unit
val punctuation : Format.formatter -> string -> unit val punctuation : Format.formatter -> string -> unit
val operator : Format.formatter -> string -> unit val op_style : Format.formatter -> string -> unit
val lit_style : Format.formatter -> string -> unit val lit_style : Format.formatter -> string -> unit
(** {1 Formatters} *) (** {1 Formatters} *)
@ -35,13 +35,11 @@ val tlit : Format.formatter -> typ_lit -> unit
val location : Format.formatter -> 'a glocation -> unit val location : Format.formatter -> 'a glocation -> unit
val typ : decl_ctx -> Format.formatter -> typ -> unit val typ : decl_ctx -> Format.formatter -> typ -> unit
val lit : Format.formatter -> 'a glit -> unit val lit : Format.formatter -> 'a glit -> unit
val op_kind : Format.formatter -> op_kind -> unit val operator : Format.formatter -> ('a any, 'k) operator -> unit
val binop : Format.formatter -> binop -> unit
val ternop : Format.formatter -> ternop -> unit
val log_entry : Format.formatter -> log_entry -> unit val log_entry : Format.formatter -> log_entry -> unit
val unop : Format.formatter -> unop -> unit
val except : Format.formatter -> except -> unit val except : Format.formatter -> except -> unit
val var : Format.formatter -> 'e Var.t -> unit val var : Format.formatter -> 'e Var.t -> unit
val var_debug : Format.formatter -> 'e Var.t -> unit
val expr : val expr :
?debug:bool (** [true] for debug printing *) -> ?debug:bool (** [true] for debug printing *) ->

View File

@ -22,6 +22,18 @@ let map_exprs ~f ~varf { scopes; decl_ctx } =
(fun scopes -> { scopes; decl_ctx }) (fun scopes -> { scopes; decl_ctx })
(Scope.map_exprs ~f ~varf scopes) (Scope.map_exprs ~f ~varf scopes)
let get_scope_body { scopes; _ } scope =
match
Scope.fold_left ~init:None
~f:(fun acc scope_def _ ->
if ScopeName.equal scope_def.scope_name scope then
Some scope_def.scope_body
else acc)
scopes
with
| None -> raise Not_found
| Some body -> body
let untype : 'm. ('a, 'm mark) gexpr program -> ('a, untyped mark) gexpr program let untype : 'm. ('a, 'm mark) gexpr program -> ('a, untyped mark) gexpr program
= =
fun prg -> Bindlib.unbox (map_exprs ~f:Expr.untype ~varf:Var.translate prg) fun prg -> Bindlib.unbox (map_exprs ~f:Expr.untype ~varf:Var.translate prg)

View File

@ -25,6 +25,9 @@ val map_exprs :
'expr1 program -> 'expr1 program ->
'expr2 program Bindlib.box 'expr2 program Bindlib.box
val get_scope_body :
(([< dcalc | lcalc ], _) gexpr as 'e) program -> ScopeName.t -> 'e scope_body
val untype : val untype :
(([< dcalc | lcalc ] as 'a), 'm mark) gexpr program -> (([< dcalc | lcalc ] as 'a), 'm mark) gexpr program ->
('a, untyped mark) gexpr program ('a, untyped mark) gexpr program

View File

@ -15,7 +15,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Catala_utils
open Definitions open Definitions
let rec fold_left_lets ~f ~init scope_body_expr = let rec fold_left_lets ~f ~init scope_body_expr =
@ -106,7 +106,7 @@ let rec get_body_expr_mark = function
get_body_expr_mark e get_body_expr_mark e
| Result e -> | Result e ->
let m = Marked.get_mark e in let m = Marked.get_mark e in
Expr.with_ty m (Utils.Marked.mark (Expr.mark_pos m) TAny) Expr.with_ty m (Marked.mark (Expr.mark_pos m) TAny)
let get_body_mark scope_body = let get_body_mark scope_body =
let _, e = Bindlib.unbind scope_body.scope_body_expr in let _, e = Bindlib.unbind scope_body.scope_body_expr in

View File

@ -17,7 +17,7 @@
(** Functions handling the scope structures of [shared_ast] *) (** Functions handling the scope structures of [shared_ast] *)
open Utils open Catala_utils
open Definitions open Definitions
(** {2 Traversal functions} *) (** {2 Traversal functions} *)

View File

@ -16,6 +16,8 @@
include Definitions include Definitions
module Var = Var module Var = Var
module Type = Type
module Operator = Operator
module Expr = Expr module Expr = Expr
module Scope = Scope module Scope = Scope
module Program = Program module Program = Program

View File

@ -0,0 +1,87 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Definitions
type t = typ
let equal_tlit l1 l2 = l1 = l2
let compare_tlit l1 l2 = Stdlib.compare l1 l2
let rec equal ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> equal_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> equal t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> equal t1 t2 && equal t1' t2'
| TArray t1, TArray t2 -> equal t1 t2
| TAny, TAny -> true
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
| TArray _ | TAny ),
_ ) ->
false
and equal_list tys1 tys2 =
try List.for_all2 equal tys1 tys2 with Invalid_argument _ -> false
(* Similar to [equal], but allows TAny holes *)
let rec unifiable ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TAny, _ | _, TAny -> true
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> unifiable t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
| TArray t1, TArray t2 -> unifiable t1 t2
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
_ ) ->
false
and unifiable_list tys1 tys2 =
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
let rec compare ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> compare_tlit l1 l2
| TTuple tys1, TTuple tys2 -> List.compare compare tys1 tys2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TOption t1, TOption t2 -> compare t1 t2
| TArrow (a1, b1), TArrow (a2, b2) -> (
match compare a1 a2 with 0 -> compare b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TTuple _, _ -> -1
| _, TTuple _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TOption _, _ -> -1
| _, TOption _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t

Some files were not shown because too many files have changed in this diff Show More