mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Merge branch 'master' into afromher_334
This commit is contained in:
commit
7cffc53169
@ -1,28 +0,0 @@
|
||||
{ lib, fetchFromGitHub, buildDunePackage }:
|
||||
|
||||
# We need the very last version "bleeding edge" since previous versions don't use dune.
|
||||
|
||||
buildDunePackage rec {
|
||||
pname = "bindlib";
|
||||
version = "5.0.1a";
|
||||
|
||||
minimumOCamlVersion = "4.0.8";
|
||||
|
||||
useDune2 = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "rlepigre";
|
||||
repo = "ocaml-${pname}";
|
||||
rev = "317f195d22c75f556053039cd94b52bd0c423709";
|
||||
name = pname;
|
||||
hash = "sha256-uO/Ko9PmQ+wE0d9jfEngd4G014B4nxGgfQyEvB52Pz8=";
|
||||
};
|
||||
|
||||
meta = with lib; {
|
||||
homepage = "https://rlepigre.github.io/ocaml-bindlib/";
|
||||
description =
|
||||
"Bindlib is a library allowing the manipulation of data structures with bound variables";
|
||||
license = licenses.lgpl3;
|
||||
maintainers = [ ];
|
||||
};
|
||||
}
|
@ -5,7 +5,7 @@
|
||||
, bindlib
|
||||
, buildDunePackage
|
||||
, calendar
|
||||
, cmdliner_1_1_0
|
||||
, cmdliner
|
||||
, cppo
|
||||
, dates_calc
|
||||
, fetchFromGitHub
|
||||
@ -42,7 +42,7 @@ buildDunePackage rec {
|
||||
ansiterminal
|
||||
benchmark
|
||||
bindlib
|
||||
cmdliner_1_1_0
|
||||
cmdliner
|
||||
cppo
|
||||
dates_calc
|
||||
js_of_ocaml
|
||||
|
@ -1,32 +1,13 @@
|
||||
{ ocamlPackages, fetchurl }:
|
||||
|
||||
ocamlPackages.overrideScope' (self: super: {
|
||||
cmdliner_1_1_0 = super.cmdliner.overrideAttrs (o: rec {
|
||||
version = "1.1.0";
|
||||
src = fetchurl {
|
||||
url = "https://erratique.ch/software/${o.pname}/releases/${o.pname }-${version}.tbz";
|
||||
sha256 = "sha256-irWd4HTlJSYuz3HMgi1de2GVL2qus0QjeCe1WdsSs8Q=";
|
||||
};
|
||||
});
|
||||
alcotest = (super.alcotest.override {
|
||||
cmdliner = self.cmdliner_1_1_0;
|
||||
}).overrideAttrs (_: {
|
||||
alcotest = (super.alcotest.override {}).overrideAttrs (_: {
|
||||
doCheck = false;
|
||||
});
|
||||
# Use a more recent version of `re` than the one packaged in nixpkgs
|
||||
re = super.re.overrideAttrs (o: rec {
|
||||
version = "1.10.4";
|
||||
src = fetchurl {
|
||||
url = "https://github.com/ocaml/ocaml-${o.pname}/releases/download/${version}/${o.pname}-${version}.tbz";
|
||||
sha256 = "sha256-g+s+QwCqmx3HggdJAQ9DYuqDUkdCEwUk14wgzpnKdHw=";
|
||||
};
|
||||
});
|
||||
catala = self.callPackage ./catala.nix { };
|
||||
bindlib = self.callPackage ./bindlib.nix { };
|
||||
unionfind = self.callPackage ./unionfind.nix { };
|
||||
ninja_utils = self.callPackage ./ninja_utils.nix { };
|
||||
clerk = self.callPackage ./clerk.nix { };
|
||||
ppx_yojson_conv = self.callPackage ./ppx_yojson_conv.nix { };
|
||||
ubase = self.callPackage ./ubase.nix { };
|
||||
dates_calc = self.callPackage ./dates_calc.nix { };
|
||||
})
|
||||
|
@ -1,20 +0,0 @@
|
||||
{ lib, fetchurl, buildDunePackage, ppxlib, ppx_yojson_conv_lib, ppx_js_style }:
|
||||
|
||||
buildDunePackage rec {
|
||||
pname = "ppx_yojson_conv";
|
||||
version = "0.14.0";
|
||||
|
||||
minimumOCamlVersion = "4.0.8";
|
||||
|
||||
useDune2 = true;
|
||||
|
||||
propagatedBuildInputs = [
|
||||
ppxlib ppx_yojson_conv_lib ppx_js_style
|
||||
];
|
||||
|
||||
src = fetchurl
|
||||
{
|
||||
url = "https://ocaml.janestreet.com/ocaml-core/v0.14/files/ppx_yojson_conv-v0.14.0.tar.gz";
|
||||
sha256 = "0ls6vzj7k0wrjliifqczs78anbc8b88as5w7a3wixfcs1gjfsp2w";
|
||||
};
|
||||
}
|
@ -104,19 +104,20 @@ need more, here is how one can be added:
|
||||
- Choose a name wisely. Be ready to patch any code that already used the name
|
||||
for scope parameters, variables or structure fields, since it won't compile
|
||||
anymore.
|
||||
- Add an element to the `builtin_expression` type in `surface/ast.ml(i)`
|
||||
- Add an element to the `builtin_expression` type in `surface/ast.ml`
|
||||
- Add your builtin in the `builtins` list in `surface/lexer.cppo.ml`, and with
|
||||
proper translations in all of the language-specific modules
|
||||
`surface/lexer_en.cppo.ml`, `surface/lexer_fr.cppo.ml`, etc. Don't forget the
|
||||
macro at the beginning of `lexer.cppo.ml`.
|
||||
- The rest can all be done by following the type errors downstream:
|
||||
- Add a corresponding element to the lower-level AST in `dcalc/ast.ml(i)`, type `unop`
|
||||
- Extend the translation accordingly in `surface/desugaring.ml`
|
||||
- Extend the printer (`dcalc/print.ml`) and the typer with correct type
|
||||
information (`dcalc/typing.ml`)
|
||||
- Add a corresponding element to the lower-level AST in `shared_ast/definitions.ml`, type `Op.t`
|
||||
- Extend the generic operations on operators in `shared_ast/operators.ml` as well as the type information for the operator
|
||||
- Extend the translation accordingly in `desugared/from_surface.ml`
|
||||
- Extend the printer (`shared_ast/print.ml`)
|
||||
- Finally, provide the implementations:
|
||||
- in `lcalc/to_ocaml.ml`, function `format_unop`
|
||||
- in `dcalc/interpreter.ml`, function `evaluate_operator`
|
||||
- in `../runtimes/ocaml/runtime.ml`
|
||||
- in `../runtimes/python/catala/src/catala/runtime.py`
|
||||
- Update the syntax guide in `doc/syntax/syntax.tex` with your new builtin
|
||||
|
||||
### Internationalization of the Catala syntax
|
||||
|
@ -3,7 +3,7 @@
|
||||
FROM ocamlpro/ocaml:4.14-2022-07-17 AS dev-build-context
|
||||
|
||||
# pandoc is not in alpine stable yet, install it manually with an explicit repository
|
||||
RUN sudo apk add pandoc --repository=http://dl-cdn.alpinelinux.org/alpine/edge/testing/
|
||||
RUN sudo apk add pandoc --repository=http://dl-cdn.alpinelinux.org/alpine/edge/community/
|
||||
|
||||
RUN mkdir catala
|
||||
WORKDIR catala
|
||||
|
23
INSTALL.md
23
INSTALL.md
@ -22,18 +22,31 @@ Finally, start a shell inside a new container created from the newly built image
|
||||
The repository provides nix files to build or develop the catala compiler.
|
||||
|
||||
Once [nix is installed](https://nixos.org/manual/nix/stable/#ch-installing-binary),
|
||||
it is possible to enter a development shell:
|
||||
with flakes enabled it is possible to enter a development shell:
|
||||
|
||||
nix-shell
|
||||
nix develop
|
||||
|
||||
or to build the Catala compiler, documentation and runtime library:
|
||||
|
||||
nix-build release.nix
|
||||
nix build
|
||||
|
||||
Dependencies not yet in nixpkgs (`bindlib` and `unionFind` at the moment of writing)
|
||||
are hardcoded inside the `.nix` directory. The `default.nix` should be compatible with
|
||||
Dependencies not yet in nixpkgs (`ubase` and `unionFind` at the moment of writing)
|
||||
are hardcoded inside the `.nix` directory. The `.nix/catala.nix` should be compatible with
|
||||
nixpkgs, if it finds a maintainer.
|
||||
|
||||
To develop catala's compiler using vscode using ocaml's [lsp](https://microsoft.github.io/language-server-protocol/), you can use the [ocaml-platform extension](https://marketplace.visualstudio.com/items?itemName=ocamllabs.ocaml-platform) with the following settings (inside the file `.vscode/settings.json`).
|
||||
|
||||
```json
|
||||
{
|
||||
"ocaml.sandbox": {
|
||||
"kind": "custom",
|
||||
"template": "nix develop --command $prog $args"
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
The nix build is updated weekly by an automatic github action.
|
||||
|
||||
### With opam
|
||||
|
||||
The Catala compiler is written using OCaml. First, you have to install `opam`,
|
||||
|
12
Makefile
12
Makefile
@ -299,8 +299,9 @@ run_french_law_library_benchmark_python: $(PY_VIRTUALENV) \
|
||||
CATALA_OPTS?=
|
||||
CLERK_OPTS?=--makeflags="$(MAKEFLAGS)"
|
||||
|
||||
CATALA_BIN=_build/default/compiler/catala.exe
|
||||
CLERK_BIN=_build/default/build_system/clerk.exe
|
||||
CATALA_BIN=_build/default/$(COMPILER_DIR)/catala.exe
|
||||
CLERK_BIN=_build/default/$(BUILD_SYSTEM_DIR)/clerk.exe
|
||||
CATALA_LEGIFRANCE_BIN=_build/default/$(CATALA_LEGIFRANCE_DIR)/catala_legifrance.exe
|
||||
|
||||
CLERK=$(CLERK_BIN) --exe $(CATALA_BIN) \
|
||||
$(CLERK_OPTS) $(if $(CATALA_OPTS),--catala-opts=$(CATALA_OPTS),)
|
||||
@ -336,7 +337,7 @@ tests/%: .FORCE
|
||||
# Website assets
|
||||
##########################################
|
||||
|
||||
WEBSITE_ASSETS = grammar.html catala.html
|
||||
WEBSITE_ASSETS = grammar.html catala.html clerk.html catala_legifrance.html
|
||||
|
||||
$(addprefix _build/default/,$(WEBSITE_ASSETS)):
|
||||
dune build $@
|
||||
@ -386,6 +387,11 @@ help_clerk:
|
||||
help_catala:
|
||||
$(CATALA_BIN) --help
|
||||
|
||||
#> help_catala_legifrance : Display the catala_legifrance man page
|
||||
help_catala_legifrance:
|
||||
$(CATALA_LEGIFRANCE_BIN) --help
|
||||
|
||||
|
||||
##########################################
|
||||
# Special targets
|
||||
##########################################
|
||||
|
@ -16,7 +16,7 @@
|
||||
the License. *)
|
||||
|
||||
open Cmdliner
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Ninja_utils
|
||||
module Nj = Ninja_utils
|
||||
|
||||
@ -524,7 +524,7 @@ let collect_all_ninja_build
|
||||
(tested_file : string)
|
||||
(reset_test_outputs : bool) : (string * ninja) option =
|
||||
let expected_outputs = search_for_expected_outputs tested_file in
|
||||
if List.length expected_outputs = 0 then (
|
||||
if expected_outputs = [] then (
|
||||
Cli.debug_print "No expected outputs were found for test file %s"
|
||||
tested_file;
|
||||
None)
|
||||
@ -890,10 +890,18 @@ let driver
|
||||
let files_or_folders = List.sort_uniq String.compare files_or_folders
|
||||
and catala_exe = Option.fold ~none:"catala" ~some:Fun.id catala_exe
|
||||
and catala_opts = Option.fold ~none:"" ~some:Fun.id catala_opts
|
||||
and ninja_output =
|
||||
Option.fold
|
||||
~none:(Filename.temp_file "clerk_build_" ".ninja")
|
||||
~some:Fun.id ninja_output
|
||||
and with_ninja_output k =
|
||||
match ninja_output with
|
||||
| Some f -> k f
|
||||
| None -> (
|
||||
let f = Filename.temp_file "clerk_build_" ".ninja" in
|
||||
match k f with
|
||||
| exception e ->
|
||||
if not debug then Sys.remove f;
|
||||
raise e
|
||||
| r ->
|
||||
Sys.remove f;
|
||||
r)
|
||||
in
|
||||
match String.lowercase_ascii command with
|
||||
| "test" -> (
|
||||
@ -919,20 +927,22 @@ let driver
|
||||
if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then
|
||||
return_ok
|
||||
else
|
||||
try
|
||||
File.with_formatter_of_file ninja_output (fun fmt ->
|
||||
Cli.debug_print "writing %s..." ninja_output;
|
||||
with_ninja_output
|
||||
@@ fun nin ->
|
||||
match
|
||||
File.with_formatter_of_file nin (fun fmt ->
|
||||
Cli.debug_print "writing %s..." nin;
|
||||
Nj.format fmt
|
||||
(add_root_test_build ninja ctx.all_file_names
|
||||
ctx.all_test_builds));
|
||||
ctx.all_test_builds))
|
||||
with
|
||||
| () ->
|
||||
let ninja_cmd =
|
||||
"ninja -k 0 -f " ^ ninja_output ^ " " ^ ninja_flags ^ " test"
|
||||
"ninja -k 0 -f " ^ nin ^ " " ^ ninja_flags ^ " test"
|
||||
in
|
||||
Cli.debug_print "executing '%s'..." ninja_cmd;
|
||||
let return = Sys.command ninja_cmd in
|
||||
if not debug then Sys.remove ninja_output;
|
||||
return
|
||||
with Sys_error e ->
|
||||
Sys.command ninja_cmd
|
||||
| exception Sys_error e ->
|
||||
Cli.error_print "can not write in %s" e;
|
||||
return_err)
|
||||
| "run" -> (
|
||||
|
@ -9,7 +9,7 @@
|
||||
(public_name clerk.driver)
|
||||
(libraries
|
||||
catala.runtime_ocaml
|
||||
catala.utils
|
||||
catala.catala_utils
|
||||
ninja_utils
|
||||
cmdliner
|
||||
re
|
||||
|
@ -34,6 +34,7 @@ depends: [
|
||||
"ppx_yojson_conv" {>= "0.14.0"}
|
||||
"re" {>= "1.9.0"}
|
||||
"sedlex" {>= "2.4"}
|
||||
"uutf" {>= "1.0.3"}
|
||||
"ubase" {>= "0.05"}
|
||||
"unionFind" {>= "20200320"}
|
||||
"visitors" {>= "20200210"}
|
||||
@ -45,6 +46,7 @@ depends: [
|
||||
"obelisk" {cataladevmode}
|
||||
"conf-npm" {cataladevmode}
|
||||
"conf-python-3-dev" {cataladevmode}
|
||||
"cpdf" {cataladevmode}
|
||||
"z3" {catalaz3mode}
|
||||
]
|
||||
depopts: ["z3"]
|
||||
|
@ -7,12 +7,12 @@ In {{: desugared.html} the desugared representation} or in the
|
||||
global identifiers. These identifiers use OCaml's type system to statically
|
||||
distinguish e.g. a scope identifier from a struct identifier.
|
||||
|
||||
The {!module: Utils.Uid} module provides a generative functor whose output is
|
||||
The {!module: Uid} module provides a generative functor whose output is
|
||||
a fresh sort of global identifiers.
|
||||
|
||||
Related modules:
|
||||
|
||||
{!modules: Utils.Uid}
|
||||
{!modules: Uid}
|
||||
|
||||
{1 Source code positions}
|
||||
|
||||
@ -22,7 +22,7 @@ code. These annotations are critical to produce readable error messages.
|
||||
|
||||
Related modules:
|
||||
|
||||
{!modules: Utils.Pos}
|
||||
{!modules: Pos}
|
||||
|
||||
{1 Error messages}
|
||||
|
@ -172,7 +172,7 @@ let plugins_dirs =
|
||||
let default =
|
||||
let ( / ) = Filename.concat in
|
||||
[
|
||||
Sys.executable_name
|
||||
Filename.dirname Sys.executable_name
|
||||
/ Filename.parent_dir_name
|
||||
/ "lib"
|
||||
/ "catala"
|
@ -1,8 +1,8 @@
|
||||
(library
|
||||
(name utils)
|
||||
(public_name catala.utils)
|
||||
(name catala_utils)
|
||||
(public_name catala.catala_utils)
|
||||
(libraries cmdliner ubase ANSITerminal re bindlib catala.runtime_ocaml))
|
||||
|
||||
(documentation
|
||||
(package catala)
|
||||
(mld_files utils))
|
||||
(mld_files catala_utils))
|
@ -26,7 +26,7 @@ exception StructuredError of (string * (string option * Pos.t) list)
|
||||
let print_structured_error (msg : string) (pos : (string option * Pos.t) list) :
|
||||
string =
|
||||
Printf.sprintf "%s%s%s" msg
|
||||
(if List.length pos = 0 then "" else "\n\n")
|
||||
(if pos = [] then "" else "\n\n")
|
||||
(String.concat "\n\n"
|
||||
(List.map
|
||||
(fun (msg, pos) ->
|
@ -79,11 +79,11 @@ let to_string (pos : t) : string =
|
||||
let to_string_short (pos : t) : string =
|
||||
let s, e = pos.code_pos in
|
||||
if e.Lexing.pos_lnum = s.Lexing.pos_lnum then
|
||||
Printf.sprintf "%s:%d.%d-%d" s.Lexing.pos_fname s.Lexing.pos_lnum
|
||||
Printf.sprintf "%s:%d.%d-%d:" s.Lexing.pos_fname s.Lexing.pos_lnum
|
||||
(s.Lexing.pos_cnum - s.Lexing.pos_bol)
|
||||
(e.Lexing.pos_cnum - e.Lexing.pos_bol)
|
||||
else
|
||||
Printf.sprintf "%s:%d.%d-%d.%d" s.Lexing.pos_fname s.Lexing.pos_lnum
|
||||
Printf.sprintf "%s:%d.%d-%d.%d:" s.Lexing.pos_fname s.Lexing.pos_lnum
|
||||
(s.Lexing.pos_cnum - s.Lexing.pos_bol)
|
||||
e.Lexing.pos_lnum
|
||||
(e.Lexing.pos_cnum - e.Lexing.pos_bol)
|
||||
@ -102,6 +102,27 @@ let string_repeat n s =
|
||||
done;
|
||||
Bytes.to_string buf
|
||||
|
||||
(* Note: this should do, but remains incorrect for combined unicode characters
|
||||
that display as one (e.g. `e` + postfix `'`). We should switch to Uuseg at
|
||||
some poing *)
|
||||
let string_columns s =
|
||||
let len = String.length s in
|
||||
let rec aux ncols i =
|
||||
if i >= len then ncols
|
||||
else if s.[i] = '\t' then aux (ncols + 8) (i + 1)
|
||||
else
|
||||
aux (ncols + 1) (i + Uchar.utf_decode_length (String.get_utf_8_uchar s i))
|
||||
in
|
||||
aux 0 0
|
||||
|
||||
let utf8_byte_index s ui0 =
|
||||
let rec aux bi ui =
|
||||
if ui >= ui0 then bi
|
||||
else
|
||||
aux (bi + Uchar.utf_decode_length (String.get_utf_8_uchar s bi)) (ui + 1)
|
||||
in
|
||||
aux 0 0
|
||||
|
||||
let retrieve_loc_text (pos : t) : string =
|
||||
try
|
||||
let filename = get_file pos in
|
||||
@ -132,34 +153,32 @@ let retrieve_loc_text (pos : t) : string =
|
||||
let print_matched_line (line : string) (line_no : int) : string =
|
||||
let line_indent = indent_number line in
|
||||
let error_indicator_style = [ANSITerminal.red; ANSITerminal.Bold] in
|
||||
line
|
||||
^
|
||||
if line_no >= sline && line_no <= eline then
|
||||
"\n"
|
||||
^
|
||||
if line_no = sline && line_no = eline then
|
||||
Cli.with_style error_indicator_style "%*s%s"
|
||||
(get_start_column pos - 1)
|
||||
""
|
||||
(string_repeat
|
||||
(max (get_end_column pos - get_start_column pos) 0)
|
||||
"‾")
|
||||
else if line_no = sline && line_no <> eline then
|
||||
Cli.with_style error_indicator_style "%*s%s"
|
||||
(get_start_column pos - 1)
|
||||
""
|
||||
(string_repeat
|
||||
(max (String.length line - get_start_column pos) 0)
|
||||
"‾")
|
||||
else if line_no <> sline && line_no <> eline then
|
||||
Cli.with_style error_indicator_style "%*s%s" line_indent ""
|
||||
(string_repeat (max (String.length line - line_indent) 0) "‾")
|
||||
else if line_no <> sline && line_no = eline then
|
||||
Cli.with_style error_indicator_style "%*s%*s" line_indent ""
|
||||
(get_end_column pos - 1 - line_indent)
|
||||
(string_repeat (max (get_end_column pos - line_indent) 0) "‾")
|
||||
else assert false (* should not happen *)
|
||||
else ""
|
||||
let match_start_index =
|
||||
utf8_byte_index line
|
||||
(if line_no = sline then get_start_column pos - 1 else line_indent)
|
||||
in
|
||||
let match_end_index =
|
||||
if line_no = eline then utf8_byte_index line (get_end_column pos - 1)
|
||||
else String.length line
|
||||
in
|
||||
let unmatched_prefix = String.sub line 0 match_start_index in
|
||||
let matched_substring =
|
||||
String.sub line match_start_index
|
||||
(max 0 (match_end_index - match_start_index))
|
||||
in
|
||||
let match_start_col = string_columns unmatched_prefix in
|
||||
let match_num_cols = string_columns matched_substring in
|
||||
String.concat ""
|
||||
(line
|
||||
:: "\n"
|
||||
::
|
||||
(if line_no >= sline && line_no <= eline then
|
||||
[
|
||||
string_repeat match_start_col " ";
|
||||
Cli.with_style error_indicator_style "%s"
|
||||
(string_repeat match_num_cols "‾");
|
||||
]
|
||||
else []))
|
||||
in
|
||||
let include_extra_count = 0 in
|
||||
let rec get_lines (n : int) : string list =
|
||||
@ -193,10 +212,8 @@ let retrieve_loc_text (pos : t) : string =
|
||||
(Cli.with_style blue_style "└%s┐" (string_repeat spaces "─"));
|
||||
Buffer.add_char buf '\n';
|
||||
Buffer.add_string buf
|
||||
(Cli.add_prefix_to_each_line
|
||||
(String.concat "\n" ("" :: pos_lines))
|
||||
(fun i ->
|
||||
let cur_line = sline - include_extra_count + i - 1 in
|
||||
(Cli.add_prefix_to_each_line (String.concat "\n" pos_lines) (fun i ->
|
||||
let cur_line = sline - include_extra_count + i in
|
||||
if
|
||||
cur_line >= sline
|
||||
&& cur_line <= sline + (2 * (eline - sline))
|
@ -14,39 +14,47 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
let to_ascii : string -> string = Ubase.from_utf8
|
||||
include Stdlib.String
|
||||
|
||||
let is_uppercase_ascii (c : char) : bool =
|
||||
let c = Char.code c in
|
||||
(* 'A' <= c && c <= 'Z' *)
|
||||
0x41 <= c && c <= 0x5b
|
||||
let to_ascii : string -> string = Ubase.from_utf8
|
||||
let is_uppercase_ascii = function 'A' .. 'Z' -> true | _ -> false
|
||||
|
||||
let begins_with_uppercase (s : string) : bool =
|
||||
if "" = s then false else is_uppercase_ascii (to_ascii s).[0]
|
||||
"" <> s && is_uppercase_ascii (get (to_ascii s) 0)
|
||||
|
||||
let to_snake_case (s : string) : string =
|
||||
let out = ref "" in
|
||||
to_ascii s
|
||||
|> String.iteri (fun i c ->
|
||||
|> iteri (fun i c ->
|
||||
out :=
|
||||
!out
|
||||
^ (if is_uppercase_ascii c && 0 <> i then "_" else "")
|
||||
^ String.lowercase_ascii (String.make 1 c));
|
||||
^ lowercase_ascii (make 1 c));
|
||||
!out
|
||||
|
||||
let to_camel_case (s : string) : string =
|
||||
let last_was_underscore = ref false in
|
||||
let out = ref "" in
|
||||
to_ascii s
|
||||
|> String.iteri (fun i c ->
|
||||
|> iteri (fun i c ->
|
||||
let is_underscore = c = '_' in
|
||||
let c_string = String.make 1 c in
|
||||
let c_string = make 1 c in
|
||||
out :=
|
||||
!out
|
||||
^
|
||||
if is_underscore then ""
|
||||
else if !last_was_underscore || 0 = i then
|
||||
String.uppercase_ascii c_string
|
||||
else if !last_was_underscore || 0 = i then uppercase_ascii c_string
|
||||
else c_string;
|
||||
last_was_underscore := is_underscore);
|
||||
!out
|
||||
|
||||
let remove_prefix ~prefix s =
|
||||
if starts_with ~prefix s then
|
||||
let plen = length prefix in
|
||||
sub s plen (length s - plen)
|
||||
else s
|
||||
|
||||
let format_t = Format.pp_print_string
|
||||
|
||||
module Set = Set.Make (Stdlib.String)
|
||||
module Map = Map.Make (Stdlib.String)
|
@ -14,6 +14,10 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
include module type of Stdlib.String
|
||||
module Set : Set.S with type elt = string
|
||||
module Map : Map.S with type key = string
|
||||
|
||||
(** Helper functions used for string manipulation. *)
|
||||
|
||||
val to_ascii : string -> string
|
||||
@ -34,3 +38,11 @@ val to_snake_case : string -> string
|
||||
val to_camel_case : string -> string
|
||||
(** Converts snake_case into CamlCase after removing Remove all diacritics on
|
||||
Latin letters. *)
|
||||
|
||||
val remove_prefix : prefix:string -> string -> string
|
||||
(** [remove_prefix ~prefix str] returns
|
||||
|
||||
- if [str] starts with [prefix], a string [s] such that [prefix ^ s = str]
|
||||
- otherwise, [str] unchanged *)
|
||||
|
||||
val format_t : Format.formatter -> string -> unit
|
@ -18,7 +18,7 @@ module type Info = sig
|
||||
type info
|
||||
|
||||
val to_string : info -> string
|
||||
val format_info : Format.formatter -> info -> unit
|
||||
val format : Format.formatter -> info -> unit
|
||||
val equal : info -> info -> bool
|
||||
val compare : info -> info -> int
|
||||
end
|
||||
@ -33,10 +33,21 @@ module type Id = sig
|
||||
val equal : t -> t -> bool
|
||||
val format_t : Format.formatter -> t -> unit
|
||||
val hash : t -> int
|
||||
|
||||
module Set : Set.S with type elt = t
|
||||
module Map : Map.S with type key = t
|
||||
end
|
||||
|
||||
module Make (X : Info) () : Id with type info = X.info = struct
|
||||
type t = { id : int; info : X.info }
|
||||
module Ordering = struct
|
||||
type t = { id : int; info : X.info }
|
||||
|
||||
let compare (x : t) (y : t) : int = compare x.id y.id
|
||||
let equal x y = Int.equal x.id y.id
|
||||
end
|
||||
|
||||
include Ordering
|
||||
|
||||
type info = X.info
|
||||
|
||||
let counter = ref 0
|
||||
@ -46,20 +57,20 @@ module Make (X : Info) () : Id with type info = X.info = struct
|
||||
{ id = !counter; info }
|
||||
|
||||
let get_info (uid : t) : X.info = uid.info
|
||||
let compare (x : t) (y : t) : int = compare x.id y.id
|
||||
let equal x y = Int.equal x.id y.id
|
||||
|
||||
let format_t (fmt : Format.formatter) (x : t) : unit =
|
||||
X.format_info fmt x.info
|
||||
|
||||
let format_t (fmt : Format.formatter) (x : t) : unit = X.format fmt x.info
|
||||
let hash (x : t) : int = x.id
|
||||
|
||||
module Set = Set.Make (Ordering)
|
||||
module Map = Map.Make (Ordering)
|
||||
end
|
||||
|
||||
module MarkedString = struct
|
||||
type info = string Marked.pos
|
||||
|
||||
let to_string (s, _) = s
|
||||
let format_info fmt i = Format.pp_print_string fmt (to_string i)
|
||||
let format fmt i = Format.pp_print_string fmt (to_string i)
|
||||
let equal i1 i2 = String.equal (Marked.unmark i1) (Marked.unmark i2)
|
||||
let compare i1 i2 = String.compare (Marked.unmark i1) (Marked.unmark i2)
|
||||
end
|
||||
|
||||
module Gen () = Make (MarkedString) ()
|
@ -21,7 +21,7 @@ module type Info = sig
|
||||
type info
|
||||
|
||||
val to_string : info -> string
|
||||
val format_info : Format.formatter -> info -> unit
|
||||
val format : Format.formatter -> info -> unit
|
||||
|
||||
val equal : info -> info -> bool
|
||||
(** Equality disregards position *)
|
||||
@ -48,9 +48,15 @@ module type Id = sig
|
||||
val equal : t -> t -> bool
|
||||
val format_t : Format.formatter -> t -> unit
|
||||
val hash : t -> int
|
||||
|
||||
module Set : Set.S with type elt = t
|
||||
module Map : Map.S with type key = t
|
||||
end
|
||||
|
||||
(** This is the generative functor that ensures that two modules resulting from
|
||||
two different calls to [Make] will be viewed as different types [t] by the
|
||||
OCaml typechecker. Prevents mixing up different sorts of identifiers. *)
|
||||
module Make (X : Info) () : Id with type info = X.info
|
||||
|
||||
module Gen () : Id with type info = MarkedString.info
|
||||
(** Shortcut for creating a kind of uids over marked strings *)
|
@ -1,3 +1,4 @@
|
||||
open Catala_utils
|
||||
open Driver
|
||||
open Js_of_ocaml
|
||||
|
||||
@ -12,7 +13,7 @@ let _ =
|
||||
driver
|
||||
(Contents (Js.to_string contents))
|
||||
{
|
||||
Utils.Cli.debug = false;
|
||||
Cli.debug = false;
|
||||
color = Never;
|
||||
wrap_weaved_output = false;
|
||||
avoid_exceptions = false;
|
||||
|
@ -1,7 +1,15 @@
|
||||
(library
|
||||
(name dcalc)
|
||||
(public_name catala.dcalc)
|
||||
(libraries bindlib unionFind utils re ubase catala.runtime_ocaml shared_ast)
|
||||
(libraries
|
||||
bindlib
|
||||
unionFind
|
||||
catala_utils
|
||||
re
|
||||
ubase
|
||||
catala.runtime_ocaml
|
||||
shared_ast
|
||||
scopelang)
|
||||
(preprocess
|
||||
(pps visitors.ppx)))
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -16,4 +16,4 @@
|
||||
|
||||
(** Scope language to default calculus translator *)
|
||||
|
||||
val translate_program : 'm Ast.program -> 'm Dcalc.Ast.program
|
||||
val translate_program : 'm Scopelang.Ast.program -> 'm Ast.program
|
@ -16,7 +16,7 @@
|
||||
|
||||
(** Reference interpreter for the default calculus *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
|
||||
@ -29,272 +29,117 @@ let log_indent = ref 0
|
||||
|
||||
(** {1 Evaluation} *)
|
||||
|
||||
let rec evaluate_operator
|
||||
(ctx : decl_ctx)
|
||||
(op : operator)
|
||||
(pos : Pos.t)
|
||||
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
|
||||
(* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use
|
||||
[op] to raise multispanned errors. *)
|
||||
let apply_div_or_raise_err (div : unit -> 'm Ast.naked_expr) :
|
||||
'm Ast.naked_expr =
|
||||
try div ()
|
||||
with Division_by_zero ->
|
||||
let print_log ctx entry infos pos e =
|
||||
if !Cli.trace_flag then
|
||||
match entry with
|
||||
| VarDef _ ->
|
||||
(* TODO: this usage of Format is broken, Formatting requires that all is
|
||||
formatted in one pass, without going through intermediate "%s" *)
|
||||
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
|
||||
Print.uid_list infos
|
||||
(match Marked.unmark e with
|
||||
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
|
||||
| _ ->
|
||||
let expr_str =
|
||||
Format.asprintf "%a" (Expr.format ctx ~debug:false) e
|
||||
in
|
||||
let expr_str =
|
||||
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
|
||||
~subst:(fun _ -> " ")
|
||||
expr_str
|
||||
in
|
||||
Cli.with_style [ANSITerminal.green] "%s" expr_str)
|
||||
| PosRecordIfTrueBool -> (
|
||||
match pos <> Pos.no_pos, Marked.unmark e with
|
||||
| true, ELit (LBool true) ->
|
||||
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry entry
|
||||
(Cli.with_style [ANSITerminal.green] "Definition applied")
|
||||
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
|
||||
Format.asprintf "%*s" (!log_indent * 2) ""))
|
||||
| _ -> ())
|
||||
| BeginCall ->
|
||||
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
|
||||
Print.uid_list infos;
|
||||
log_indent := !log_indent + 1
|
||||
| EndCall ->
|
||||
log_indent := !log_indent - 1;
|
||||
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
|
||||
Print.uid_list infos
|
||||
|
||||
(* Todo: this should be handled early when resolving overloads. Here we have
|
||||
proper structural equality, but the OCaml backend for example uses the
|
||||
builtin equality function instead of this. *)
|
||||
let rec handle_eq ctx pos e1 e2 =
|
||||
let open Runtime.Oper in
|
||||
match e1, e2 with
|
||||
| ELit LUnit, ELit LUnit -> true
|
||||
| ELit (LBool b1), ELit (LBool b2) -> not (o_xor b1 b2)
|
||||
| ELit (LInt x1), ELit (LInt x2) -> o_eq_int_int x1 x2
|
||||
| ELit (LRat x1), ELit (LRat x2) -> o_eq_rat_rat x1 x2
|
||||
| ELit (LMoney x1), ELit (LMoney x2) -> o_eq_mon_mon x1 x2
|
||||
| ELit (LDuration x1), ELit (LDuration x2) -> o_eq_dur_dur x1 x2
|
||||
| ELit (LDate x1), ELit (LDate x2) -> o_eq_dat_dat x1 x2
|
||||
| EArray es1, EArray es2 -> (
|
||||
try
|
||||
List.for_all2
|
||||
(fun e1 e2 ->
|
||||
match evaluate_operator ctx Eq pos [e1; e2] with
|
||||
| ELit (LBool b) -> b
|
||||
| _ -> assert false
|
||||
(* should not happen *))
|
||||
es1 es2
|
||||
with Invalid_argument _ -> false)
|
||||
| EStruct { fields = es1; name = s1 }, EStruct { fields = es2; name = s2 } ->
|
||||
StructName.equal s1 s2
|
||||
&& StructField.Map.equal
|
||||
(fun e1 e2 ->
|
||||
match evaluate_operator ctx Eq pos [e1; e2] with
|
||||
| ELit (LBool b) -> b
|
||||
| _ -> assert false
|
||||
(* should not happen *))
|
||||
es1 es2
|
||||
| ( EInj { e = e1; cons = i1; name = en1 },
|
||||
EInj { e = e2; cons = i2; name = en2 } ) -> (
|
||||
try
|
||||
EnumName.equal en1 en2
|
||||
&& EnumConstructor.equal i1 i2
|
||||
&&
|
||||
match evaluate_operator ctx Eq pos [e1; e2] with
|
||||
| ELit (LBool b) -> b
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
with Invalid_argument _ -> false)
|
||||
| _, _ -> false (* comparing anything else return false *)
|
||||
|
||||
(* Call-by-value: the arguments are expected to be already evaluated here *)
|
||||
and evaluate_operator :
|
||||
type k.
|
||||
decl_ctx ->
|
||||
(dcalc, k) operator ->
|
||||
Pos.t ->
|
||||
'm Ast.expr list ->
|
||||
'm Ast.naked_expr =
|
||||
fun ctx op pos args ->
|
||||
let protect f x y =
|
||||
let get_binop_args_pos = function
|
||||
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
|
||||
[None, Expr.pos arg0; None, Expr.pos arg1]
|
||||
| _ -> assert false
|
||||
in
|
||||
try f x y with
|
||||
| Division_by_zero ->
|
||||
Errors.raise_multispanned_error
|
||||
[
|
||||
Some "The division operator:", pos;
|
||||
Some "The null denominator:", Expr.pos (List.nth args 1);
|
||||
]
|
||||
"division by zero at runtime"
|
||||
in
|
||||
let get_binop_args_pos = function
|
||||
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
|
||||
[None, Expr.pos arg0; None, Expr.pos arg1]
|
||||
| _ -> assert false
|
||||
in
|
||||
(* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched,
|
||||
use [args] to raise multispanned errors. *)
|
||||
let apply_cmp_or_raise_err
|
||||
(cmp : unit -> 'm Ast.naked_expr)
|
||||
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
|
||||
try cmp ()
|
||||
with Runtime.UncomparableDurations ->
|
||||
| Runtime.UncomparableDurations ->
|
||||
Errors.raise_multispanned_error (get_binop_args_pos args)
|
||||
"Cannot compare together durations that cannot be converted to a \
|
||||
precise number of days"
|
||||
in
|
||||
match op, List.map Marked.unmark args with
|
||||
| Ternop Fold, [_f; _init; EArray es] ->
|
||||
Marked.unmark
|
||||
(List.fold_left
|
||||
(fun acc e' ->
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp (List.nth args 0, [acc; e'])) e'))
|
||||
(List.nth args 1) es)
|
||||
| Binop And, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 && b2))
|
||||
| Binop Or, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 || b2))
|
||||
| Binop Xor, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 <> b2))
|
||||
| Binop (Add KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LInt Runtime.(i1 +! i2))
|
||||
| Binop (Sub KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LInt Runtime.(i1 -! i2))
|
||||
| Binop (Mult KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LInt Runtime.(i1 *! i2))
|
||||
| Binop (Div KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
apply_div_or_raise_err (fun _ -> ELit (LInt Runtime.(i1 /! i2)))
|
||||
| Binop (Add KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LRat Runtime.(i1 +& i2))
|
||||
| Binop (Sub KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LRat Runtime.(i1 -& i2))
|
||||
| Binop (Mult KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LRat Runtime.(i1 *& i2))
|
||||
| Binop (Div KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(i1 /& i2)))
|
||||
| Binop (Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LMoney Runtime.(m1 +$ m2))
|
||||
| Binop (Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LMoney Runtime.(m1 -$ m2))
|
||||
| Binop (Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] ->
|
||||
ELit (LMoney Runtime.(m1 *$ m2))
|
||||
| Binop (Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(m1 /$ m2)))
|
||||
| Binop (Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
ELit (LDuration Runtime.(d1 +^ d2))
|
||||
| Binop (Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
ELit (LDuration Runtime.(d1 -^ d2))
|
||||
| Binop (Sub KDate), [ELit (LDate d1); ELit (LDate d2)] ->
|
||||
ELit (LDuration Runtime.(d1 -@ d2))
|
||||
| Binop (Add KDate), [ELit (LDate d1); ELit (LDuration d2)] ->
|
||||
ELit (LDate Runtime.(d1 +@ d2))
|
||||
| Binop (Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] ->
|
||||
ELit (LDuration Runtime.(d1 *^ i1))
|
||||
| Binop (Lt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LBool Runtime.(i1 <! i2))
|
||||
| Binop (Lte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LBool Runtime.(i1 <=! i2))
|
||||
| Binop (Gt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LBool Runtime.(i1 >! i2))
|
||||
| Binop (Gte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LBool Runtime.(i1 >=! i2))
|
||||
| Binop (Lt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LBool Runtime.(i1 <& i2))
|
||||
| Binop (Lte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LBool Runtime.(i1 <=& i2))
|
||||
| Binop (Gt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LBool Runtime.(i1 >& i2))
|
||||
| Binop (Gte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LBool Runtime.(i1 >=& i2))
|
||||
| Binop (Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LBool Runtime.(m1 <$ m2))
|
||||
| Binop (Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LBool Runtime.(m1 <=$ m2))
|
||||
| Binop (Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LBool Runtime.(m1 >$ m2))
|
||||
| Binop (Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LBool Runtime.(m1 >=$ m2))
|
||||
| Binop (Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <^ d2))) args
|
||||
| Binop (Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <=^ d2))) args
|
||||
| Binop (Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >^ d2))) args
|
||||
| Binop (Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >=^ d2))) args
|
||||
| Binop (Lt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
|
||||
ELit (LBool Runtime.(d1 <@ d2))
|
||||
| Binop (Lte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
|
||||
ELit (LBool Runtime.(d1 <=@ d2))
|
||||
| Binop (Gt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
|
||||
ELit (LBool Runtime.(d1 >@ d2))
|
||||
| Binop (Gte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
|
||||
ELit (LBool Runtime.(d1 >=@ d2))
|
||||
| Binop Eq, [ELit LUnit; ELit LUnit] -> ELit (LBool true)
|
||||
| Binop Eq, [ELit (LDuration d1); ELit (LDuration d2)] ->
|
||||
ELit (LBool Runtime.(d1 =^ d2))
|
||||
| Binop Eq, [ELit (LDate d1); ELit (LDate d2)] ->
|
||||
ELit (LBool Runtime.(d1 =@ d2))
|
||||
| Binop Eq, [ELit (LMoney m1); ELit (LMoney m2)] ->
|
||||
ELit (LBool Runtime.(m1 =$ m2))
|
||||
| Binop Eq, [ELit (LRat i1); ELit (LRat i2)] ->
|
||||
ELit (LBool Runtime.(i1 =& i2))
|
||||
| Binop Eq, [ELit (LInt i1); ELit (LInt i2)] ->
|
||||
ELit (LBool Runtime.(i1 =! i2))
|
||||
| Binop Eq, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 = b2))
|
||||
| Binop Eq, [EArray es1; EArray es2] ->
|
||||
ELit
|
||||
(LBool
|
||||
(try
|
||||
List.for_all2
|
||||
(fun e1 e2 ->
|
||||
match evaluate_operator ctx op pos [e1; e2] with
|
||||
| ELit (LBool b) -> b
|
||||
| _ -> assert false
|
||||
(* should not happen *))
|
||||
es1 es2
|
||||
with Invalid_argument _ -> false))
|
||||
| Binop Eq, [ETuple (es1, s1); ETuple (es2, s2)] ->
|
||||
ELit
|
||||
(LBool
|
||||
(try
|
||||
s1 = s2
|
||||
&& List.for_all2
|
||||
(fun e1 e2 ->
|
||||
match evaluate_operator ctx op pos [e1; e2] with
|
||||
| ELit (LBool b) -> b
|
||||
| _ -> assert false
|
||||
(* should not happen *))
|
||||
es1 es2
|
||||
with Invalid_argument _ -> false))
|
||||
| Binop Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] ->
|
||||
ELit
|
||||
(LBool
|
||||
(try
|
||||
en1 = en2
|
||||
&& i1 = i2
|
||||
&&
|
||||
match evaluate_operator ctx op pos [e1; e2] with
|
||||
| ELit (LBool b) -> b
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
with Invalid_argument _ -> false))
|
||||
| Binop Eq, [_; _] ->
|
||||
ELit (LBool false) (* comparing anything else return false *)
|
||||
| Binop Neq, [_; _] -> (
|
||||
match evaluate_operator ctx (Binop Eq) pos args with
|
||||
| ELit (LBool b) -> ELit (LBool (not b))
|
||||
| _ -> assert false (*should not happen *))
|
||||
| Binop Concat, [EArray es1; EArray es2] -> EArray (es1 @ es2)
|
||||
| Binop Map, [_; EArray es] ->
|
||||
EArray
|
||||
(List.map
|
||||
(fun e' ->
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp (List.nth args 0, [e'])) e'))
|
||||
es)
|
||||
| Binop Filter, [_; EArray es] ->
|
||||
EArray
|
||||
(List.filter
|
||||
(fun e' ->
|
||||
match
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp (List.nth args 0, [e'])) e')
|
||||
with
|
||||
| ELit (LBool b), _ -> b
|
||||
| _ ->
|
||||
Errors.raise_spanned_error
|
||||
(Expr.pos (List.nth args 0))
|
||||
"This predicate evaluated to something else than a boolean \
|
||||
(should not happen if the term was well-typed)")
|
||||
es)
|
||||
| Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> ELit LEmptyError
|
||||
| Unop (Minus KInt), [ELit (LInt i)] ->
|
||||
ELit (LInt Runtime.(integer_of_int 0 -! i))
|
||||
| Unop (Minus KRat), [ELit (LRat i)] ->
|
||||
ELit (LRat Runtime.(decimal_of_string "0" -& i))
|
||||
| Unop (Minus KMoney), [ELit (LMoney i)] ->
|
||||
ELit (LMoney Runtime.(money_of_units_int 0 -$ i))
|
||||
| Unop (Minus KDuration), [ELit (LDuration i)] ->
|
||||
ELit (LDuration Runtime.(~-^i))
|
||||
| Unop Not, [ELit (LBool b)] -> ELit (LBool (not b))
|
||||
| Unop Length, [EArray es] ->
|
||||
ELit (LInt (Runtime.integer_of_int (List.length es)))
|
||||
| Unop GetDay, [ELit (LDate d)] ->
|
||||
ELit (LInt Runtime.(day_of_month_of_date d))
|
||||
| Unop GetMonth, [ELit (LDate d)] ->
|
||||
ELit (LInt Runtime.(month_number_of_date d))
|
||||
| Unop GetYear, [ELit (LDate d)] -> ELit (LInt Runtime.(year_of_date d))
|
||||
| Unop FirstDayOfMonth, [ELit (LDate d)] ->
|
||||
ELit (LDate Runtime.(first_day_of_month d))
|
||||
| Unop LastDayOfMonth, [ELit (LDate d)] ->
|
||||
ELit (LDate Runtime.(first_day_of_month d))
|
||||
| Unop IntToRat, [ELit (LInt i)] -> ELit (LRat Runtime.(decimal_of_integer i))
|
||||
| Unop MoneyToRat, [ELit (LMoney i)] ->
|
||||
ELit (LRat Runtime.(decimal_of_money i))
|
||||
| Unop RatToMoney, [ELit (LRat i)] ->
|
||||
ELit (LMoney Runtime.(money_of_decimal i))
|
||||
| Unop RoundMoney, [ELit (LMoney m)] -> ELit (LMoney Runtime.(money_round m))
|
||||
| Unop RoundDecimal, [ELit (LRat m)] -> ELit (LRat Runtime.(decimal_round m))
|
||||
| Unop (Log (entry, infos)), [e'] ->
|
||||
if !Cli.trace_flag then (
|
||||
match entry with
|
||||
| VarDef _ ->
|
||||
(* TODO: this usage of Format is broken, Formatting requires that all is
|
||||
formatted in one pass, without going through intermediate "%s" *)
|
||||
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
|
||||
Print.uid_list infos
|
||||
(match e' with
|
||||
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
|
||||
| _ ->
|
||||
let expr_str =
|
||||
Format.asprintf "%a" (Expr.format ctx ~debug:false) (List.hd args)
|
||||
in
|
||||
let expr_str =
|
||||
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
|
||||
~subst:(fun _ -> " ")
|
||||
expr_str
|
||||
in
|
||||
Cli.with_style [ANSITerminal.green] "%s" expr_str)
|
||||
| PosRecordIfTrueBool -> (
|
||||
match pos <> Pos.no_pos, e' with
|
||||
| true, ELit (LBool true) ->
|
||||
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry
|
||||
entry
|
||||
(Cli.with_style [ANSITerminal.green] "Definition applied")
|
||||
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
|
||||
Format.asprintf "%*s" (!log_indent * 2) ""))
|
||||
| _ -> ())
|
||||
| BeginCall ->
|
||||
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
|
||||
Print.uid_list infos;
|
||||
log_indent := !log_indent + 1
|
||||
| EndCall ->
|
||||
log_indent := !log_indent - 1;
|
||||
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
|
||||
Print.uid_list infos)
|
||||
else ();
|
||||
e'
|
||||
| Unop _, [ELit LEmptyError] -> ELit LEmptyError
|
||||
| _ ->
|
||||
let err () =
|
||||
Errors.raise_multispanned_error
|
||||
([Some "Operator:", pos]
|
||||
@ List.mapi
|
||||
@ -307,6 +152,162 @@ let rec evaluate_operator
|
||||
args)
|
||||
"Operator applied to the wrong arguments\n\
|
||||
(should not happen if the term was well-typed)"
|
||||
in
|
||||
let open Runtime.Oper in
|
||||
if List.exists (function ELit LEmptyError, _ -> true | _ -> false) args then
|
||||
ELit LEmptyError
|
||||
else
|
||||
Operator.kind_dispatch op
|
||||
~polymorphic:(fun op ->
|
||||
match op, args with
|
||||
| Length, [(EArray es, _)] ->
|
||||
ELit (LInt (Runtime.integer_of_int (List.length es)))
|
||||
| Log (entry, infos), [e'] ->
|
||||
print_log ctx entry infos pos e';
|
||||
Marked.unmark e'
|
||||
| Eq, [(e1, _); (e2, _)] -> ELit (LBool (handle_eq ctx pos e1 e2))
|
||||
| Map, [f; (EArray es, _)] ->
|
||||
EArray
|
||||
(List.map
|
||||
(fun e' ->
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp { f; args = [e'] }) e'))
|
||||
es)
|
||||
| Reduce, [_; default; (EArray [], _)] -> Marked.unmark default
|
||||
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
|
||||
Marked.unmark
|
||||
(List.fold_left
|
||||
(fun acc x ->
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp { f; args = [acc; x] }) f))
|
||||
x0 xn)
|
||||
| Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2)
|
||||
| Filter, [f; (EArray es, _)] ->
|
||||
EArray
|
||||
(List.filter
|
||||
(fun e' ->
|
||||
match
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp { f; args = [e'] }) e')
|
||||
with
|
||||
| ELit (LBool b), _ -> b
|
||||
| _ ->
|
||||
Errors.raise_spanned_error
|
||||
(Expr.pos (List.nth args 0))
|
||||
"This predicate evaluated to something else than a \
|
||||
boolean (should not happen if the term was well-typed)")
|
||||
es)
|
||||
| Fold, [f; init; (EArray es, _)] ->
|
||||
Marked.unmark
|
||||
(List.fold_left
|
||||
(fun acc e' ->
|
||||
evaluate_expr ctx
|
||||
(Marked.same_mark_as (EApp { f; args = [acc; e'] }) e'))
|
||||
init es)
|
||||
| (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ ->
|
||||
err ())
|
||||
~monomorphic:(fun op ->
|
||||
let rlit =
|
||||
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
|
||||
| Not, [LBool b] -> LBool (o_not b)
|
||||
| GetDay, [LDate d] -> LInt (o_getDay d)
|
||||
| GetMonth, [LDate d] -> LInt (o_getMonth d)
|
||||
| GetYear, [LDate d] -> LInt (o_getYear d)
|
||||
| FirstDayOfMonth, [LDate d] -> LDate (o_firstDayOfMonth d)
|
||||
| LastDayOfMonth, [LDate d] -> LDate (o_lastDayOfMonth d)
|
||||
| And, [LBool b1; LBool b2] -> LBool (o_and b1 b2)
|
||||
| Or, [LBool b1; LBool b2] -> LBool (o_or b1 b2)
|
||||
| Xor, [LBool b1; LBool b2] -> LBool (o_xor b1 b2)
|
||||
| ( ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth
|
||||
| LastDayOfMonth | And | Or | Xor ),
|
||||
_ ) ->
|
||||
err ()
|
||||
in
|
||||
ELit rlit)
|
||||
~resolved:(fun op ->
|
||||
let rlit =
|
||||
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
|
||||
| Minus_int, [LInt x] -> LInt (o_minus_int x)
|
||||
| Minus_rat, [LRat x] -> LRat (o_minus_rat x)
|
||||
| Minus_mon, [LMoney x] -> LMoney (o_minus_mon x)
|
||||
| Minus_dur, [LDuration x] -> LDuration (o_minus_dur x)
|
||||
| ToRat_int, [LInt i] -> LRat (o_torat_int i)
|
||||
| ToRat_mon, [LMoney i] -> LRat (o_torat_mon i)
|
||||
| ToMoney_rat, [LRat i] -> LMoney (o_tomoney_rat i)
|
||||
| Round_mon, [LMoney m] -> LMoney (o_round_mon m)
|
||||
| Round_rat, [LRat m] -> LRat (o_round_rat m)
|
||||
| Add_int_int, [LInt x; LInt y] -> LInt (o_add_int_int x y)
|
||||
| Add_rat_rat, [LRat x; LRat y] -> LRat (o_add_rat_rat x y)
|
||||
| Add_mon_mon, [LMoney x; LMoney y] -> LMoney (o_add_mon_mon x y)
|
||||
| Add_dat_dur, [LDate x; LDuration y] -> LDate (o_add_dat_dur x y)
|
||||
| Add_dur_dur, [LDuration x; LDuration y] ->
|
||||
LDuration (o_add_dur_dur x y)
|
||||
| Sub_int_int, [LInt x; LInt y] -> LInt (o_sub_int_int x y)
|
||||
| Sub_rat_rat, [LRat x; LRat y] -> LRat (o_sub_rat_rat x y)
|
||||
| Sub_mon_mon, [LMoney x; LMoney y] -> LMoney (o_sub_mon_mon x y)
|
||||
| Sub_dat_dat, [LDate x; LDate y] -> LDuration (o_sub_dat_dat x y)
|
||||
| Sub_dat_dur, [LDate x; LDuration y] -> LDate (o_sub_dat_dur x y)
|
||||
| Sub_dur_dur, [LDuration x; LDuration y] ->
|
||||
LDuration (o_sub_dur_dur x y)
|
||||
| Mult_int_int, [LInt x; LInt y] -> LInt (o_mult_int_int x y)
|
||||
| Mult_rat_rat, [LRat x; LRat y] -> LRat (o_mult_rat_rat x y)
|
||||
| Mult_mon_rat, [LMoney x; LRat y] -> LMoney (o_mult_mon_rat x y)
|
||||
| Mult_dur_int, [LDuration x; LInt y] ->
|
||||
LDuration (o_mult_dur_int x y)
|
||||
| Div_int_int, [LInt x; LInt y] -> LRat (protect o_div_int_int x y)
|
||||
| Div_rat_rat, [LRat x; LRat y] -> LRat (protect o_div_rat_rat x y)
|
||||
| Div_mon_mon, [LMoney x; LMoney y] ->
|
||||
LRat (protect o_div_mon_mon x y)
|
||||
| Div_mon_rat, [LMoney x; LRat y] ->
|
||||
LMoney (protect o_div_mon_rat x y)
|
||||
| Lt_int_int, [LInt x; LInt y] -> LBool (o_lt_int_int x y)
|
||||
| Lt_rat_rat, [LRat x; LRat y] -> LBool (o_lt_rat_rat x y)
|
||||
| Lt_mon_mon, [LMoney x; LMoney y] -> LBool (o_lt_mon_mon x y)
|
||||
| Lt_dat_dat, [LDate x; LDate y] -> LBool (o_lt_dat_dat x y)
|
||||
| Lt_dur_dur, [LDuration x; LDuration y] ->
|
||||
LBool (protect o_lt_dur_dur x y)
|
||||
| Lte_int_int, [LInt x; LInt y] -> LBool (o_lte_int_int x y)
|
||||
| Lte_rat_rat, [LRat x; LRat y] -> LBool (o_lte_rat_rat x y)
|
||||
| Lte_mon_mon, [LMoney x; LMoney y] -> LBool (o_lte_mon_mon x y)
|
||||
| Lte_dat_dat, [LDate x; LDate y] -> LBool (o_lte_dat_dat x y)
|
||||
| Lte_dur_dur, [LDuration x; LDuration y] ->
|
||||
LBool (protect o_lte_dur_dur x y)
|
||||
| Gt_int_int, [LInt x; LInt y] -> LBool (o_gt_int_int x y)
|
||||
| Gt_rat_rat, [LRat x; LRat y] -> LBool (o_gt_rat_rat x y)
|
||||
| Gt_mon_mon, [LMoney x; LMoney y] -> LBool (o_gt_mon_mon x y)
|
||||
| Gt_dat_dat, [LDate x; LDate y] -> LBool (o_gt_dat_dat x y)
|
||||
| Gt_dur_dur, [LDuration x; LDuration y] ->
|
||||
LBool (protect o_gt_dur_dur x y)
|
||||
| Gte_int_int, [LInt x; LInt y] -> LBool (o_gte_int_int x y)
|
||||
| Gte_rat_rat, [LRat x; LRat y] -> LBool (o_gte_rat_rat x y)
|
||||
| Gte_mon_mon, [LMoney x; LMoney y] -> LBool (o_gte_mon_mon x y)
|
||||
| Gte_dat_dat, [LDate x; LDate y] -> LBool (o_gte_dat_dat x y)
|
||||
| Gte_dur_dur, [LDuration x; LDuration y] ->
|
||||
LBool (protect o_gte_dur_dur x y)
|
||||
| Eq_int_int, [LInt x; LInt y] -> LBool (o_eq_int_int x y)
|
||||
| Eq_rat_rat, [LRat x; LRat y] -> LBool (o_eq_rat_rat x y)
|
||||
| Eq_mon_mon, [LMoney x; LMoney y] -> LBool (o_eq_mon_mon x y)
|
||||
| Eq_dat_dat, [LDate x; LDate y] -> LBool (o_eq_dat_dat x y)
|
||||
| Eq_dur_dur, [LDuration x; LDuration y] ->
|
||||
LBool (protect o_eq_dur_dur x y)
|
||||
| ( ( Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int
|
||||
| ToRat_mon | ToMoney_rat | Round_rat | Round_mon | Add_int_int
|
||||
| Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur
|
||||
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat
|
||||
| Sub_dat_dur | Sub_dur_dur | Mult_int_int | Mult_rat_rat
|
||||
| Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
|
||||
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
|
||||
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat
|
||||
| Lte_mon_mon | Lte_dat_dat | Lte_dur_dur | Gt_int_int
|
||||
| Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur | Gte_int_int
|
||||
| Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur
|
||||
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur
|
||||
),
|
||||
_ ) ->
|
||||
err ()
|
||||
in
|
||||
ELit rlit)
|
||||
~overloaded:(fun _ -> assert false)
|
||||
|
||||
and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
match Marked.unmark e with
|
||||
@ -314,11 +315,11 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"free variable found at evaluation (should not happen if term was \
|
||||
well-typed"
|
||||
| EApp (e1, args) -> (
|
||||
| EApp { f = e1; args } -> (
|
||||
let e1 = evaluate_expr ctx e1 in
|
||||
let args = List.map (evaluate_expr ctx) args in
|
||||
match Marked.unmark e1 with
|
||||
| EAbs (binder, _) ->
|
||||
| EAbs { binder; _ } ->
|
||||
if Bindlib.mbinder_arity binder = List.length args then
|
||||
evaluate_expr ctx
|
||||
(Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args)))
|
||||
@ -327,7 +328,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
"wrong function call, expected %d arguments, got %d"
|
||||
(Bindlib.mbinder_arity binder)
|
||||
(List.length args)
|
||||
| EOp op ->
|
||||
| EOp { op; _ } ->
|
||||
Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e
|
||||
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
|
||||
| _ ->
|
||||
@ -335,69 +336,66 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
"function has not been reduced to a lambda at evaluation (should not \
|
||||
happen if the term was well-typed")
|
||||
| EAbs _ | ELit _ | EOp _ -> e (* these are values *)
|
||||
| ETuple (es, s) ->
|
||||
let new_es = List.map (evaluate_expr ctx) es in
|
||||
if List.exists is_empty_error new_es then
|
||||
| EStruct { fields = es; name } ->
|
||||
let new_es = StructField.Map.map (evaluate_expr ctx) es in
|
||||
if StructField.Map.exists (fun _ e -> is_empty_error e) new_es then
|
||||
Marked.same_mark_as (ELit LEmptyError) e
|
||||
else Marked.same_mark_as (ETuple (new_es, s)) e
|
||||
| ETupleAccess (e1, n, s, _) -> (
|
||||
else Marked.same_mark_as (EStruct { fields = new_es; name }) e
|
||||
| EStructAccess { e = e1; name = s; field } -> (
|
||||
let e1 = evaluate_expr ctx e1 in
|
||||
match Marked.unmark e1 with
|
||||
| ETuple (es, s') -> (
|
||||
(match s, s' with
|
||||
| None, None -> ()
|
||||
| Some s, Some s' when s = s' -> ()
|
||||
| _ ->
|
||||
| EStruct { fields = es; name = s' } -> (
|
||||
if not (StructName.equal s s') then
|
||||
Errors.raise_multispanned_error
|
||||
[None, Expr.pos e; None, Expr.pos e1]
|
||||
"Error during tuple access: not the same structs (should not happen \
|
||||
if the term was well-typed)");
|
||||
match List.nth_opt es n with
|
||||
"Error during struct access: not the same structs (should not happen \
|
||||
if the term was well-typed)";
|
||||
match StructField.Map.find_opt field es with
|
||||
| Some e' -> e'
|
||||
| None ->
|
||||
Errors.raise_spanned_error (Expr.pos e1)
|
||||
"The tuple has %d components but the %i-th element was requested \
|
||||
(should not happen if the term was well-type)"
|
||||
(List.length es) n)
|
||||
"Invalid field access %a in struct %a (should not happen if the term \
|
||||
was well-typed)"
|
||||
StructField.format_t field StructName.format_t s)
|
||||
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos e1)
|
||||
"The expression %a should be a tuple with %d components but is not \
|
||||
(should not happen if the term was well-typed)"
|
||||
"The expression %a should be a struct %a but is not (should not happen \
|
||||
if the term was well-typed)"
|
||||
(Expr.format ctx ~debug:true)
|
||||
e n)
|
||||
| EInj (e1, n, en, ts) ->
|
||||
e StructName.format_t s)
|
||||
| EInj { e = e1; name; cons } ->
|
||||
let e1' = evaluate_expr ctx e1 in
|
||||
if is_empty_error e1' then Marked.same_mark_as (ELit LEmptyError) e
|
||||
else Marked.same_mark_as (EInj (e1', n, en, ts)) e
|
||||
| EMatch (e1, es, e_name) -> (
|
||||
if is_empty_error e then Marked.same_mark_as (ELit LEmptyError) e
|
||||
else Marked.same_mark_as (EInj { e = e1'; name; cons }) e
|
||||
| EMatch { e = e1; cases = es; name } -> (
|
||||
let e1 = evaluate_expr ctx e1 in
|
||||
match Marked.unmark e1 with
|
||||
| EInj (e1, n, e_name', _) ->
|
||||
if e_name <> e_name' then
|
||||
| EInj { e = e1; cons; name = name' } ->
|
||||
if not (EnumName.equal name name') then
|
||||
Errors.raise_multispanned_error
|
||||
[None, Expr.pos e; None, Expr.pos e1]
|
||||
"Error during match: two different enums found (should not happen if \
|
||||
the term was well-typed)";
|
||||
let es_n =
|
||||
match List.nth_opt es n with
|
||||
match EnumConstructor.Map.find_opt cons es with
|
||||
| Some es_n -> es_n
|
||||
| None ->
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"sum type index error (should not happen if the term was \
|
||||
well-typed)"
|
||||
in
|
||||
let new_e = Marked.same_mark_as (EApp (es_n, [e1])) e in
|
||||
let new_e = Marked.same_mark_as (EApp { f = es_n; args = [e1] }) e in
|
||||
evaluate_expr ctx new_e
|
||||
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos e1)
|
||||
"Expected a term having a sum type as an argument to a match (should \
|
||||
not happen if the term was well-typed")
|
||||
| EDefault (exceptions, just, cons) -> (
|
||||
let exceptions = List.map (evaluate_expr ctx) exceptions in
|
||||
let empty_count = List.length (List.filter is_empty_error exceptions) in
|
||||
match List.length exceptions - empty_count with
|
||||
| EDefault { excepts; just; cons } -> (
|
||||
let excepts = List.map (evaluate_expr ctx) excepts in
|
||||
let empty_count = List.length (List.filter is_empty_error excepts) in
|
||||
match List.length excepts - empty_count with
|
||||
| 0 -> (
|
||||
let just = evaluate_expr ctx just in
|
||||
match Marked.unmark just with
|
||||
@ -408,19 +406,19 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"Default justification has not been reduced to a boolean at \
|
||||
evaluation (should not happen if the term was well-typed")
|
||||
| 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions
|
||||
| 1 -> List.find (fun sub -> not (is_empty_error sub)) excepts
|
||||
| _ ->
|
||||
Errors.raise_multispanned_error
|
||||
(List.map
|
||||
(fun except ->
|
||||
Some "This consequence has a valid justification:", Expr.pos except)
|
||||
(List.filter (fun sub -> not (is_empty_error sub)) exceptions))
|
||||
(List.filter (fun sub -> not (is_empty_error sub)) excepts))
|
||||
"There is a conflict between multiple valid consequences for assigning \
|
||||
the same variable.")
|
||||
| EIfThenElse (cond, et, ef) -> (
|
||||
| EIfThenElse { cond; etrue; efalse } -> (
|
||||
match Marked.unmark (evaluate_expr ctx cond) with
|
||||
| ELit (LBool true) -> evaluate_expr ctx et
|
||||
| ELit (LBool false) -> evaluate_expr ctx ef
|
||||
| ELit (LBool true) -> evaluate_expr ctx etrue
|
||||
| ELit (LBool false) -> evaluate_expr ctx efalse
|
||||
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos cond)
|
||||
@ -431,7 +429,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
if List.exists is_empty_error new_es then
|
||||
Marked.same_mark_as (ELit LEmptyError) e
|
||||
else Marked.same_mark_as (EArray new_es) e
|
||||
| ErrorOnEmpty e' ->
|
||||
| EErrorOnEmpty e' ->
|
||||
let e' = evaluate_expr ctx e' in
|
||||
if Marked.unmark e' = ELit LEmptyError then
|
||||
Errors.raise_spanned_error (Expr.pos e')
|
||||
@ -443,23 +441,44 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
|
||||
| ELit (LBool true) -> Marked.same_mark_as (ELit LUnit) e'
|
||||
| ELit (LBool false) -> (
|
||||
match Marked.unmark e' with
|
||||
| ErrorOnEmpty
|
||||
| EErrorOnEmpty
|
||||
( EApp
|
||||
((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)]),
|
||||
_ )
|
||||
| EApp
|
||||
( (EOp (Unop (Log _)), _),
|
||||
[
|
||||
( EApp
|
||||
( (EOp (Binop op), _),
|
||||
[((ELit _, _) as e1); ((ELit _, _) as e2)] ),
|
||||
_ );
|
||||
] )
|
||||
| EApp ((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)])
|
||||
->
|
||||
{
|
||||
f = EOp { op; _ }, _;
|
||||
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
|
||||
},
|
||||
_ ) ->
|
||||
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
|
||||
(Expr.format ctx ~debug:false)
|
||||
e1 Print.binop op
|
||||
e1 Print.operator op
|
||||
(Expr.format ctx ~debug:false)
|
||||
e2
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args =
|
||||
[
|
||||
( EApp
|
||||
{
|
||||
f = EOp { op; _ }, _;
|
||||
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
|
||||
},
|
||||
_ );
|
||||
];
|
||||
} ->
|
||||
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
|
||||
(Expr.format ctx ~debug:false)
|
||||
e1 Print.operator op
|
||||
(Expr.format ctx ~debug:false)
|
||||
e2
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op; _ }, _;
|
||||
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
|
||||
} ->
|
||||
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
|
||||
(Expr.format ctx ~debug:false)
|
||||
e1 Print.operator op
|
||||
(Expr.format ctx ~debug:false)
|
||||
e2
|
||||
| _ ->
|
||||
@ -479,19 +498,22 @@ let interpret_program :
|
||||
fun (ctx : decl_ctx) (e : 'm Ast.expr) :
|
||||
(Uid.MarkedString.info * 'm Ast.expr) list ->
|
||||
match evaluate_expr ctx e with
|
||||
| (EAbs (_, [((TStruct s_in, _) as _targs)]), mark_e) as e -> begin
|
||||
| (EAbs { tys = [((TStruct s_in, _) as _targs)]; _ }, mark_e) as e -> begin
|
||||
(* At this point, the interpreter seeks to execute the scope but does not
|
||||
have a way to retrieve input values from the command line. [taus] contain
|
||||
the types of the scope arguments. For [context] arguments, we can provide
|
||||
an empty thunked term. But for [input] arguments of another type, we
|
||||
cannot provide anything so we have to fail. *)
|
||||
let taus = StructMap.find s_in ctx.ctx_structs in
|
||||
let taus = StructName.Map.find s_in ctx.ctx_structs in
|
||||
let application_term =
|
||||
List.map
|
||||
(fun (_, ty) ->
|
||||
StructField.Map.map
|
||||
(fun ty ->
|
||||
match Marked.unmark ty with
|
||||
| TArrow ((TLit TUnit, _), ty_in) ->
|
||||
Expr.empty_thunked_term (Expr.with_ty mark_e ty_in)
|
||||
| TArrow (ty_in, ty_out) ->
|
||||
Expr.make_abs
|
||||
[| Var.make "_" |]
|
||||
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
|
||||
[ty_in] (Expr.mark_pos mark_e)
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Marked.get_mark ty)
|
||||
"This scope needs input arguments to be executed. But the Catala \
|
||||
@ -503,17 +525,14 @@ let interpret_program :
|
||||
in
|
||||
let to_interpret =
|
||||
Expr.make_app (Expr.box e)
|
||||
[Expr.make_tuple application_term (Some s_in) mark_e]
|
||||
[Expr.estruct s_in application_term mark_e]
|
||||
(Expr.pos e)
|
||||
in
|
||||
match Marked.unmark (evaluate_expr ctx (Expr.unbox to_interpret)) with
|
||||
| ETuple (args, Some s_out) ->
|
||||
let s_out_fields =
|
||||
List.map
|
||||
(fun (f, _) -> StructFieldName.get_info f)
|
||||
(StructMap.find s_out ctx.ctx_structs)
|
||||
in
|
||||
List.map2 (fun arg var -> var, arg) args s_out_fields
|
||||
| EStruct { fields; _ } ->
|
||||
List.map
|
||||
(fun (fld, e) -> StructField.get_info fld, e)
|
||||
(StructField.Map.bindings fields)
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"The interpretation of a program should always yield a struct \
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
(** Reference interpreter for the default calculus *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
val evaluate_expr : decl_ctx -> 'm Ast.expr -> 'm Ast.expr
|
||||
|
@ -14,7 +14,7 @@
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
|
||||
@ -24,179 +24,206 @@ type partial_evaluation_ctx = {
|
||||
}
|
||||
|
||||
let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
|
||||
'm expr Bindlib.box =
|
||||
(dcalc, 'm mark) boxed_gexpr =
|
||||
(* We proceed bottom-up, first apply on the subterms *)
|
||||
let e = Expr.map ~f:(partial_evaluation ctx) e in
|
||||
let mark = Marked.get_mark e in
|
||||
let rec_helper = partial_evaluation ctx in
|
||||
match Marked.unmark e with
|
||||
| EApp
|
||||
( (( EOp (Unop Not), _
|
||||
| EApp ((EOp (Unop (Log _)), _), [(EOp (Unop Not), _)]), _ ) as op),
|
||||
[e1] ) ->
|
||||
(* reduction of logical not *)
|
||||
(Bindlib.box_apply (fun e1 ->
|
||||
match e1 with
|
||||
| ELit (LBool false), _ -> ELit (LBool true), mark
|
||||
| ELit (LBool true), _ -> ELit (LBool false), mark
|
||||
| _ -> EApp (op, [e1]), mark))
|
||||
(rec_helper e1)
|
||||
| EApp
|
||||
( (( EOp (Binop Or), _
|
||||
| EApp ((EOp (Unop (Log _)), _), [(EOp (Binop Or), _)]), _ ) as op),
|
||||
[e1; e2] ) ->
|
||||
(* reduction of logical or *)
|
||||
(Bindlib.box_apply2 (fun e1 e2 ->
|
||||
match e1, e2 with
|
||||
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) ->
|
||||
new_e
|
||||
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) ->
|
||||
ELit (LBool true), mark
|
||||
| _ -> EApp (op, [e1; e2]), mark))
|
||||
(rec_helper e1) (rec_helper e2)
|
||||
| EApp
|
||||
( (( EOp (Binop And), _
|
||||
| EApp ((EOp (Unop (Log _)), _), [(EOp (Binop And), _)]), _ ) as op),
|
||||
[e1; e2] ) ->
|
||||
(* reduction of logical and *)
|
||||
(Bindlib.box_apply2 (fun e1 e2 ->
|
||||
match e1, e2 with
|
||||
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) ->
|
||||
new_e
|
||||
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
|
||||
ELit (LBool false), mark
|
||||
| _ -> EApp (op, [e1; e2]), mark))
|
||||
(rec_helper e1) (rec_helper e2)
|
||||
| EVar x -> Bindlib.box_apply (fun x -> x, mark) (Bindlib.box_var x)
|
||||
| ETuple (args, s_name) ->
|
||||
Bindlib.box_apply
|
||||
(fun args -> ETuple (args, s_name), mark)
|
||||
(List.map rec_helper args |> Bindlib.box_list)
|
||||
| ETupleAccess (arg, i, s_name, typs) ->
|
||||
Bindlib.box_apply
|
||||
(fun arg -> ETupleAccess (arg, i, s_name, typs), mark)
|
||||
(rec_helper arg)
|
||||
| EInj (arg, i, e_name, typs) ->
|
||||
Bindlib.box_apply
|
||||
(fun arg -> EInj (arg, i, e_name, typs), mark)
|
||||
(rec_helper arg)
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
Bindlib.box_apply2
|
||||
(fun arg arms ->
|
||||
match arg, arms with
|
||||
| (EInj (e1, i, e_name', _ts), _), _
|
||||
when EnumName.compare e_name e_name' = 0 ->
|
||||
(* iota reduction *)
|
||||
EApp (List.nth arms i, [e1]), mark
|
||||
| _ -> EMatch (arg, arms, e_name), mark)
|
||||
(rec_helper arg)
|
||||
(List.map rec_helper arms |> Bindlib.box_list)
|
||||
| EArray args ->
|
||||
Bindlib.box_apply
|
||||
(fun args -> EArray args, mark)
|
||||
(List.map rec_helper args |> Bindlib.box_list)
|
||||
| ELit l -> Bindlib.box (ELit l, mark)
|
||||
| EAbs (binder, typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_body = rec_helper body in
|
||||
let new_binder = Bindlib.bind_mvar vars new_body in
|
||||
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) new_binder
|
||||
| EApp (f, args) ->
|
||||
Bindlib.box_apply2
|
||||
(fun f args ->
|
||||
match Marked.unmark f with
|
||||
| EAbs (binder, _ts) ->
|
||||
(* beta reduction *)
|
||||
Bindlib.msubst binder (List.map fst args |> Array.of_list)
|
||||
| _ -> EApp (f, args), mark)
|
||||
(rec_helper f)
|
||||
(List.map rec_helper args |> Bindlib.box_list)
|
||||
| EAssert e1 -> Bindlib.box_apply (fun e1 -> EAssert e1, mark) (rec_helper e1)
|
||||
| EOp op -> Bindlib.box (EOp op, mark)
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
Bindlib.box_apply3
|
||||
(fun exceptions just cons ->
|
||||
(* TODO: mechanically prove each of these optimizations correct :) *)
|
||||
match
|
||||
( List.filter
|
||||
(fun except ->
|
||||
match Marked.unmark except with
|
||||
| ELit LEmptyError -> false
|
||||
| _ -> true)
|
||||
exceptions
|
||||
(* we can discard the exceptions that are always empty error *),
|
||||
just,
|
||||
cons )
|
||||
with
|
||||
| exceptions, just, cons
|
||||
when List.fold_left
|
||||
(fun nb except -> if Expr.is_value except then nb + 1 else nb)
|
||||
0 exceptions
|
||||
> 1 ->
|
||||
(* at this point we know a conflict error will be triggered so we just
|
||||
feed the expression to the interpreter that will print the
|
||||
beautiful right error message *)
|
||||
Interpreter.evaluate_expr ctx.decl_ctx
|
||||
(EDefault (exceptions, just, cons), mark)
|
||||
| [except], _, _ when Expr.is_value except ->
|
||||
(* Then reduce the parent node *)
|
||||
let reduce e =
|
||||
(* Todo: improve the handling of eapp(log,elit) cases here, it obfuscates
|
||||
the matches and the log calls are not preserved, which would be a good
|
||||
property *)
|
||||
match Marked.unmark e with
|
||||
| EApp
|
||||
{
|
||||
f =
|
||||
( EOp { op = Not; _ }, _
|
||||
| ( EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(EOp { op = Not; _ }, _)];
|
||||
},
|
||||
_ ) ) as op;
|
||||
args = [e1];
|
||||
} -> (
|
||||
(* reduction of logical not *)
|
||||
match e1 with
|
||||
| ELit (LBool false), _ -> ELit (LBool true)
|
||||
| ELit (LBool true), _ -> ELit (LBool false)
|
||||
| e1 -> EApp { f = op; args = [e1] })
|
||||
| EApp
|
||||
{
|
||||
f =
|
||||
( EOp { op = Or; _ }, _
|
||||
| ( EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(EOp { op = Or; _ }, _)];
|
||||
},
|
||||
_ ) ) as op;
|
||||
args = [e1; e2];
|
||||
} -> (
|
||||
(* reduction of logical or *)
|
||||
match e1, e2 with
|
||||
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) ->
|
||||
Marked.unmark new_e
|
||||
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) ->
|
||||
ELit (LBool true)
|
||||
| _ -> EApp { f = op; args = [e1; e2] })
|
||||
| EApp
|
||||
{
|
||||
f =
|
||||
( EOp { op = And; _ }, _
|
||||
| ( EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(EOp { op = And; _ }, _)];
|
||||
},
|
||||
_ ) ) as op;
|
||||
args = [e1; e2];
|
||||
} -> (
|
||||
(* reduction of logical and *)
|
||||
match e1, e2 with
|
||||
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) ->
|
||||
Marked.unmark new_e
|
||||
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
|
||||
ELit (LBool false)
|
||||
| _ -> EApp { f = op; args = [e1; e2] })
|
||||
| EMatch { e = EInj { e; name = name1; cons }, _; cases; name }
|
||||
when EnumName.equal name name1 ->
|
||||
(* iota reduction *)
|
||||
EApp { f = EnumConstructor.Map.find cons cases; args = [e] }
|
||||
| EApp { f = EAbs { binder; _ }, _; args } ->
|
||||
(* beta reduction *)
|
||||
Marked.unmark (Bindlib.msubst binder (List.map fst args |> Array.of_list))
|
||||
| EDefault { excepts; just; cons } -> (
|
||||
(* TODO: mechanically prove each of these optimizations correct :) *)
|
||||
let excepts =
|
||||
List.filter
|
||||
(fun except -> Marked.unmark except <> ELit LEmptyError)
|
||||
excepts
|
||||
(* we can discard the exceptions that are always empty error *)
|
||||
in
|
||||
let value_except_count =
|
||||
List.fold_left
|
||||
(fun nb except -> if Expr.is_value except then nb + 1 else nb)
|
||||
0 excepts
|
||||
in
|
||||
if value_except_count > 1 then
|
||||
(* at this point we know a conflict error will be triggered so we just
|
||||
feed the expression to the interpreter that will print the beautiful
|
||||
right error message *)
|
||||
Marked.unmark (Interpreter.evaluate_expr ctx.decl_ctx e)
|
||||
else
|
||||
match excepts, just with
|
||||
| [except], _ when Expr.is_value except ->
|
||||
(* if there is only one exception and it is a non-empty value it is
|
||||
always chosen *)
|
||||
except
|
||||
Marked.unmark except
|
||||
| ( [],
|
||||
( ( ELit (LBool true)
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ),
|
||||
_ ),
|
||||
cons ) ->
|
||||
cons
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(ELit (LBool true), _)];
|
||||
} ),
|
||||
_ ) ) ->
|
||||
Marked.unmark cons
|
||||
| ( [],
|
||||
( ( ELit (LBool false)
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ),
|
||||
_ ),
|
||||
_ ) ->
|
||||
ELit LEmptyError, mark
|
||||
| [], just, cons when not !Cli.avoid_exceptions_flag ->
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(ELit (LBool false), _)];
|
||||
} ),
|
||||
_ ) ) ->
|
||||
ELit LEmptyError
|
||||
| [], just when not !Cli.avoid_exceptions_flag ->
|
||||
(* without exceptions, a default is just an [if then else] raising an
|
||||
error in the else case. This exception is only valid in the context
|
||||
of compilation_with_exceptions, so we desactivate with a global
|
||||
flag to know if we will be compiling using exceptions or the option
|
||||
monad. *)
|
||||
EIfThenElse (just, cons, (ELit LEmptyError, mark)), mark
|
||||
| exceptions, just, cons -> EDefault (exceptions, just, cons), mark)
|
||||
(List.map rec_helper exceptions |> Bindlib.box_list)
|
||||
(rec_helper just) (rec_helper cons)
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Bindlib.box_apply3
|
||||
(fun e1 e2 e3 ->
|
||||
match Marked.unmark e1, Marked.unmark e2, Marked.unmark e3 with
|
||||
| ELit (LBool true), _, _
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]), _, _ ->
|
||||
e2
|
||||
| ELit (LBool false), _, _
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]), _, _ ->
|
||||
e3
|
||||
| ( _,
|
||||
( ELit (LBool true)
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ),
|
||||
( ELit (LBool false)
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ) ) ->
|
||||
e1
|
||||
| _ when Expr.equal e2 e3 -> e2
|
||||
| _ -> EIfThenElse (e1, e2, e3), mark)
|
||||
(rec_helper e1) (rec_helper e2) (rec_helper e3)
|
||||
| ErrorOnEmpty e1 ->
|
||||
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) (rec_helper e1)
|
||||
monad. FIXME: move this optimisation somewhere else to avoid this
|
||||
check *)
|
||||
EIfThenElse
|
||||
{ cond = just; etrue = cons; efalse = ELit LEmptyError, mark }
|
||||
| excepts, just -> EDefault { excepts; just; cons })
|
||||
| EIfThenElse
|
||||
{
|
||||
cond =
|
||||
( ELit (LBool true), _
|
||||
| ( EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(ELit (LBool true), _)];
|
||||
},
|
||||
_ ) );
|
||||
etrue;
|
||||
_;
|
||||
} ->
|
||||
Marked.unmark etrue
|
||||
| EIfThenElse
|
||||
{
|
||||
cond =
|
||||
( ( ELit (LBool false)
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(ELit (LBool false), _)];
|
||||
} ),
|
||||
_ );
|
||||
efalse;
|
||||
_;
|
||||
} ->
|
||||
Marked.unmark efalse
|
||||
| EIfThenElse
|
||||
{
|
||||
cond;
|
||||
etrue =
|
||||
( ( ELit (LBool btrue)
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(ELit (LBool btrue), _)];
|
||||
} ),
|
||||
_ );
|
||||
efalse =
|
||||
( ( ELit (LBool bfalse)
|
||||
| EApp
|
||||
{
|
||||
f = EOp { op = Log _; _ }, _;
|
||||
args = [(ELit (LBool bfalse), _)];
|
||||
} ),
|
||||
_ );
|
||||
} ->
|
||||
if btrue && not bfalse then Marked.unmark cond
|
||||
else if (not btrue) && bfalse then
|
||||
EApp
|
||||
{
|
||||
f = EOp { op = Not; tys = [TLit TBool, Expr.mark_pos mark] }, mark;
|
||||
args = [cond];
|
||||
}
|
||||
(* note: this last call eliminates the condition & might skip log calls
|
||||
as well *)
|
||||
else (* btrue = bfalse *) ELit (LBool btrue)
|
||||
| e -> e
|
||||
in
|
||||
Expr.Box.app1 e reduce mark
|
||||
|
||||
let optimize_expr (decl_ctx : decl_ctx) (e : 'm expr) =
|
||||
partial_evaluation { var_values = Var.Map.empty; decl_ctx } e
|
||||
|
||||
let rec scope_lets_map
|
||||
(t : 'a -> 'm expr -> 'm expr Bindlib.box)
|
||||
(t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
|
||||
(ctx : 'a)
|
||||
(scope_body_expr : 'm expr scope_body_expr) :
|
||||
'm expr scope_body_expr Bindlib.box =
|
||||
match scope_body_expr with
|
||||
| Result e -> Bindlib.box_apply (fun e' -> Result e') (t ctx e)
|
||||
| Result e ->
|
||||
Bindlib.box_apply (fun e' -> Result e') (Expr.Box.lift (t ctx e))
|
||||
| ScopeLet scope_let ->
|
||||
let var, next = Bindlib.unbind scope_let.scope_let_next in
|
||||
let new_scope_let_expr = t ctx scope_let.scope_let_expr in
|
||||
let new_scope_let_expr = Expr.Box.lift (t ctx scope_let.scope_let_expr) in
|
||||
let new_next = scope_lets_map t ctx next in
|
||||
let new_next = Bindlib.bind_var var new_next in
|
||||
Bindlib.box_apply2
|
||||
@ -210,7 +237,7 @@ let rec scope_lets_map
|
||||
new_scope_let_expr new_next
|
||||
|
||||
let rec scopes_map
|
||||
(t : 'a -> 'm expr -> 'm expr Bindlib.box)
|
||||
(t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
|
||||
(ctx : 'a)
|
||||
(scopes : 'm expr scopes) : 'm expr scopes Bindlib.box =
|
||||
match scopes with
|
||||
@ -241,7 +268,7 @@ let rec scopes_map
|
||||
new_scope_body_expr new_scope_next
|
||||
|
||||
let program_map
|
||||
(t : 'a -> 'm expr -> 'm expr Bindlib.box)
|
||||
(t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
|
||||
(ctx : 'a)
|
||||
(p : 'm program) : 'm program Bindlib.box =
|
||||
Bindlib.box_apply
|
||||
|
@ -20,5 +20,5 @@
|
||||
open Shared_ast
|
||||
open Ast
|
||||
|
||||
val optimize_expr : decl_ctx -> 'm expr -> 'm expr Bindlib.box
|
||||
val optimize_expr : decl_ctx -> 'm expr -> (dcalc, 'm mark) boxed_gexpr
|
||||
val optimize_program : 'm program -> 'm program
|
||||
|
@ -16,25 +16,11 @@
|
||||
|
||||
(** Abstract syntax tree of the desugared representation *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Names, Maps and Keys} *)
|
||||
|
||||
module IdentMap : Map.S with type key = String.t = Map.Make (String)
|
||||
|
||||
module RuleName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName)
|
||||
module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName)
|
||||
|
||||
module LabelName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName)
|
||||
module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName)
|
||||
|
||||
(** Inside a scope, a definition can refer either to a scope def, or a subscope
|
||||
def *)
|
||||
module ScopeDef = struct
|
||||
@ -103,6 +89,9 @@ module ExprMap = Map.Make (struct
|
||||
let compare = Expr.compare
|
||||
end)
|
||||
|
||||
type io_input = NoInput | OnlyInput | Reentrant
|
||||
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
|
||||
|
||||
type exception_situation =
|
||||
| BaseCase
|
||||
| ExceptionToLabel of LabelName.t Marked.pos
|
||||
@ -136,7 +125,7 @@ module Rule = struct
|
||||
Expr.compare c1 c2
|
||||
| n -> n)
|
||||
| Some (v1, t1), Some (v2, t2) -> (
|
||||
match Shared_ast.Expr.compare_typ t1 t2 with
|
||||
match Type.compare t1 t2 with
|
||||
| 0 -> (
|
||||
let open Bindlib in
|
||||
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in
|
||||
@ -189,29 +178,32 @@ type meta_assertion =
|
||||
| VariesWith of unit * variation_typ Marked.pos option
|
||||
|
||||
type scope_def = {
|
||||
scope_def_rules : rule RuleMap.t;
|
||||
scope_def_rules : rule RuleName.Map.t;
|
||||
scope_def_typ : typ;
|
||||
scope_def_is_condition : bool;
|
||||
scope_def_io : Scopelang.Ast.io;
|
||||
scope_def_io : io;
|
||||
}
|
||||
|
||||
type var_or_states = WholeVar | States of StateName.t list
|
||||
|
||||
type scope = {
|
||||
scope_vars : var_or_states ScopeVarMap.t;
|
||||
scope_sub_scopes : ScopeName.t SubScopeMap.t;
|
||||
scope_vars : var_or_states ScopeVar.Map.t;
|
||||
scope_sub_scopes : ScopeName.t SubScopeName.Map.t;
|
||||
scope_uid : ScopeName.t;
|
||||
scope_defs : scope_def ScopeDefMap.t;
|
||||
scope_assertions : assertion list;
|
||||
scope_meta_assertions : meta_assertion list;
|
||||
}
|
||||
|
||||
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx }
|
||||
type program = {
|
||||
program_scopes : scope ScopeName.Map.t;
|
||||
program_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
let rec locations_used e : LocationSet.t =
|
||||
match e with
|
||||
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
|
||||
| EAbs (binder, _), _ ->
|
||||
| EAbs { binder; _ }, _ ->
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
locations_used body
|
||||
| e ->
|
||||
@ -219,7 +211,7 @@ let rec locations_used e : LocationSet.t =
|
||||
(fun e -> LocationSet.union (locations_used e))
|
||||
e LocationSet.empty
|
||||
|
||||
let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
|
||||
let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t =
|
||||
let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) :
|
||||
Pos.t ScopeDefMap.t =
|
||||
LocationSet.fold
|
||||
@ -235,7 +227,7 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
|
||||
loc_pos acc)
|
||||
locs acc
|
||||
in
|
||||
RuleMap.fold
|
||||
RuleName.Map.fold
|
||||
(fun _ rule acc ->
|
||||
let locs =
|
||||
LocationSet.union
|
||||
|
@ -16,19 +16,9 @@
|
||||
|
||||
(** Abstract syntax tree of the desugared representation *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Names, Maps and Keys} *)
|
||||
|
||||
module IdentMap : Map.S with type key = String.t
|
||||
module RuleName : Uid.Id with type info = Uid.MarkedString.info
|
||||
module RuleMap : Map.S with type key = RuleName.t
|
||||
module RuleSet : Set.S with type elt = RuleName.t
|
||||
module LabelName : Uid.Id with type info = Uid.MarkedString.info
|
||||
module LabelMap : Map.S with type key = LabelName.t
|
||||
module LabelSet : Set.S with type elt = LabelName.t
|
||||
|
||||
(** Inside a scope, a definition can refer either to a scope def, or a subscope
|
||||
def *)
|
||||
module ScopeDef : sig
|
||||
@ -88,27 +78,51 @@ type meta_assertion =
|
||||
| FixedBy of reference_typ Marked.pos
|
||||
| VariesWith of unit * variation_typ Marked.pos option
|
||||
|
||||
(** This type characterizes the three levels of visibility for a given scope
|
||||
variable with regards to the scope's input and possible redefinitions inside
|
||||
the scope.. *)
|
||||
type io_input =
|
||||
| NoInput
|
||||
(** For an internal variable defined only in the scope, and does not
|
||||
appear in the input. *)
|
||||
| OnlyInput
|
||||
(** For variables that should not be redefined in the scope, because they
|
||||
appear in the input. *)
|
||||
| Reentrant
|
||||
(** For variables defined in the scope that can also be redefined by the
|
||||
caller as they appear in the input. *)
|
||||
|
||||
type io = {
|
||||
io_output : bool Marked.pos;
|
||||
(** [true] is present in the output of the scope. *)
|
||||
io_input : io_input Marked.pos;
|
||||
}
|
||||
(** Characterization of the input/output status of a scope variable. *)
|
||||
|
||||
type scope_def = {
|
||||
scope_def_rules : rule RuleMap.t;
|
||||
scope_def_rules : rule RuleName.Map.t;
|
||||
scope_def_typ : typ;
|
||||
scope_def_is_condition : bool;
|
||||
scope_def_io : Scopelang.Ast.io;
|
||||
scope_def_io : io;
|
||||
}
|
||||
|
||||
type var_or_states = WholeVar | States of StateName.t list
|
||||
|
||||
type scope = {
|
||||
scope_vars : var_or_states ScopeVarMap.t;
|
||||
scope_sub_scopes : ScopeName.t SubScopeMap.t;
|
||||
scope_vars : var_or_states ScopeVar.Map.t;
|
||||
scope_sub_scopes : ScopeName.t SubScopeName.Map.t;
|
||||
scope_uid : ScopeName.t;
|
||||
scope_defs : scope_def ScopeDefMap.t;
|
||||
scope_assertions : assertion list;
|
||||
scope_meta_assertions : meta_assertion list;
|
||||
}
|
||||
|
||||
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx }
|
||||
type program = {
|
||||
program_scopes : scope ScopeName.Map.t;
|
||||
program_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
(** {1 Helpers} *)
|
||||
|
||||
val locations_used : expr -> LocationSet.t
|
||||
val free_variables : rule RuleMap.t -> Pos.t ScopeDefMap.t
|
||||
val free_variables : rule RuleName.Map.t -> Pos.t ScopeDefMap.t
|
||||
|
@ -17,7 +17,7 @@
|
||||
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
|
||||
OCamlgraph} *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Scope variables dependency graph} *)
|
||||
@ -143,7 +143,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
|
||||
let g = ScopeDependencies.empty in
|
||||
(* Add all the vertices to the graph *)
|
||||
let g =
|
||||
ScopeVarMap.fold
|
||||
ScopeVar.Map.fold
|
||||
(fun (v : ScopeVar.t) var_or_state g ->
|
||||
match var_or_state with
|
||||
| Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None))
|
||||
@ -155,7 +155,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
|
||||
scope.scope_vars g
|
||||
in
|
||||
let g =
|
||||
SubScopeMap.fold
|
||||
SubScopeName.Map.fold
|
||||
(fun (v : SubScopeName.t) _ g ->
|
||||
ScopeDependencies.add_vertex g (Vertex.SubScope v))
|
||||
scope.scope_sub_scopes g
|
||||
@ -229,10 +229,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
|
||||
(** {2 Graph declaration} *)
|
||||
|
||||
module ExceptionVertex = struct
|
||||
include Ast.RuleSet
|
||||
include RuleName.Set
|
||||
|
||||
let hash (x : t) : int =
|
||||
Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0
|
||||
RuleName.Set.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0
|
||||
|
||||
let equal x y = compare x y = 0
|
||||
end
|
||||
@ -257,13 +257,13 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies)
|
||||
(** {2 Graph computations} *)
|
||||
|
||||
type exception_edge = {
|
||||
label_from : Ast.LabelName.t;
|
||||
label_to : Ast.LabelName.t;
|
||||
label_from : LabelName.t;
|
||||
label_to : LabelName.t;
|
||||
edge_positions : Pos.t list;
|
||||
}
|
||||
|
||||
let build_exceptions_graph
|
||||
(def : Ast.rule Ast.RuleMap.t)
|
||||
(def : Ast.rule RuleName.Map.t)
|
||||
(def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t =
|
||||
(* First we partition the definitions into groups bearing the same label. To
|
||||
handle the rules that were not labeled by the user, we create implicit
|
||||
@ -271,63 +271,59 @@ let build_exceptions_graph
|
||||
|
||||
(* All the rules of the form [definition x ...] are base case with no explicit
|
||||
label, so they should share this implicit label. *)
|
||||
let base_case_implicit_label =
|
||||
Ast.LabelName.fresh ("base_case", Pos.no_pos)
|
||||
in
|
||||
let base_case_implicit_label = LabelName.fresh ("base_case", Pos.no_pos) in
|
||||
(* When declaring [exception definition x ...], it means there is a unique
|
||||
rule [R] to which this can be an exception to. So we give a unique label to
|
||||
all the rules that are implicitly exceptions to rule [R]. *)
|
||||
let exception_to_rule_implicit_labels : Ast.LabelName.t Ast.RuleMap.t =
|
||||
Ast.RuleMap.fold
|
||||
let exception_to_rule_implicit_labels : LabelName.t RuleName.Map.t =
|
||||
RuleName.Map.fold
|
||||
(fun _ rule_from exception_to_rule_implicit_labels ->
|
||||
match rule_from.Ast.rule_exception with
|
||||
| Ast.ExceptionToRule (rule_to, _) -> (
|
||||
match
|
||||
Ast.RuleMap.find_opt rule_to exception_to_rule_implicit_labels
|
||||
RuleName.Map.find_opt rule_to exception_to_rule_implicit_labels
|
||||
with
|
||||
| Some _ ->
|
||||
(* we already created the label *) exception_to_rule_implicit_labels
|
||||
| None ->
|
||||
Ast.RuleMap.add rule_to
|
||||
(Ast.LabelName.fresh
|
||||
( "exception_to_"
|
||||
^ Marked.unmark (Ast.RuleName.get_info rule_to),
|
||||
RuleName.Map.add rule_to
|
||||
(LabelName.fresh
|
||||
( "exception_to_" ^ Marked.unmark (RuleName.get_info rule_to),
|
||||
Pos.no_pos ))
|
||||
exception_to_rule_implicit_labels)
|
||||
| _ -> exception_to_rule_implicit_labels)
|
||||
def Ast.RuleMap.empty
|
||||
def RuleName.Map.empty
|
||||
in
|
||||
(* When declaring [exception foo_l definition x ...], the rule is exception to
|
||||
all the rules sharing label [foo_l]. So we give a unique label to all the
|
||||
rules that are implicitly exceptions to rule [foo_l]. *)
|
||||
let exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t =
|
||||
Ast.RuleMap.fold
|
||||
let exception_to_label_implicit_labels : LabelName.t LabelName.Map.t =
|
||||
RuleName.Map.fold
|
||||
(fun _ rule_from
|
||||
(exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t) ->
|
||||
(exception_to_label_implicit_labels : LabelName.t LabelName.Map.t) ->
|
||||
match rule_from.Ast.rule_exception with
|
||||
| Ast.ExceptionToLabel (label_to, _) -> (
|
||||
match
|
||||
Ast.LabelMap.find_opt label_to exception_to_label_implicit_labels
|
||||
LabelName.Map.find_opt label_to exception_to_label_implicit_labels
|
||||
with
|
||||
| Some _ ->
|
||||
(* we already created the label *)
|
||||
exception_to_label_implicit_labels
|
||||
| None ->
|
||||
Ast.LabelMap.add label_to
|
||||
(Ast.LabelName.fresh
|
||||
( "exception_to_"
|
||||
^ Marked.unmark (Ast.LabelName.get_info label_to),
|
||||
LabelName.Map.add label_to
|
||||
(LabelName.fresh
|
||||
( "exception_to_" ^ Marked.unmark (LabelName.get_info label_to),
|
||||
Pos.no_pos ))
|
||||
exception_to_label_implicit_labels)
|
||||
| _ -> exception_to_label_implicit_labels)
|
||||
def Ast.LabelMap.empty
|
||||
def LabelName.Map.empty
|
||||
in
|
||||
|
||||
(* Now we have all the labels necessary to partition our rules into sets, each
|
||||
one corresponding to a label relating to the structure of the exception
|
||||
DAG. *)
|
||||
let label_to_rule_sets =
|
||||
Ast.RuleMap.fold
|
||||
RuleName.Map.fold
|
||||
(fun rule_name rule rule_sets ->
|
||||
let label_of_rule =
|
||||
match rule.Ast.rule_label with
|
||||
@ -336,23 +332,23 @@ let build_exceptions_graph
|
||||
match rule.Ast.rule_exception with
|
||||
| BaseCase -> base_case_implicit_label
|
||||
| ExceptionToRule (r, _) ->
|
||||
Ast.RuleMap.find r exception_to_rule_implicit_labels
|
||||
RuleName.Map.find r exception_to_rule_implicit_labels
|
||||
| ExceptionToLabel (l', _) ->
|
||||
Ast.LabelMap.find l' exception_to_label_implicit_labels)
|
||||
LabelName.Map.find l' exception_to_label_implicit_labels)
|
||||
in
|
||||
Ast.LabelMap.update label_of_rule
|
||||
LabelName.Map.update label_of_rule
|
||||
(fun rule_set ->
|
||||
match rule_set with
|
||||
| None -> Some (Ast.RuleSet.singleton rule_name)
|
||||
| Some rule_set -> Some (Ast.RuleSet.add rule_name rule_set))
|
||||
| None -> Some (RuleName.Set.singleton rule_name)
|
||||
| Some rule_set -> Some (RuleName.Set.add rule_name rule_set))
|
||||
rule_sets)
|
||||
def Ast.LabelMap.empty
|
||||
def LabelName.Map.empty
|
||||
in
|
||||
let find_label_of_rule (r : Ast.RuleName.t) : Ast.LabelName.t =
|
||||
let find_label_of_rule (r : RuleName.t) : LabelName.t =
|
||||
fst
|
||||
(Ast.LabelMap.choose
|
||||
(Ast.LabelMap.filter
|
||||
(fun _ rule_set -> Ast.RuleSet.mem r rule_set)
|
||||
(LabelName.Map.choose
|
||||
(LabelName.Map.filter
|
||||
(fun _ rule_set -> RuleName.Set.mem r rule_set)
|
||||
label_to_rule_sets))
|
||||
in
|
||||
(* Next, we collect the exception edges between those groups of rules referred
|
||||
@ -360,7 +356,7 @@ let build_exceptions_graph
|
||||
edges as they are declared at each rule but should be the same for all the
|
||||
rules of the same group. *)
|
||||
let exception_edges : exception_edge list =
|
||||
Ast.RuleMap.fold
|
||||
RuleName.Map.fold
|
||||
(fun rule_name rule exception_edges ->
|
||||
let label_from = find_label_of_rule rule_name in
|
||||
let label_to_and_pos =
|
||||
@ -374,16 +370,16 @@ let build_exceptions_graph
|
||||
| Some (label_to, edge_pos) -> (
|
||||
let other_edges_originating_from_same_label =
|
||||
List.filter
|
||||
(fun edge -> Ast.LabelName.compare edge.label_from label_from = 0)
|
||||
(fun edge -> LabelName.compare edge.label_from label_from = 0)
|
||||
exception_edges
|
||||
in
|
||||
(* We check the consistency*)
|
||||
if Ast.LabelName.compare label_from label_to = 0 then
|
||||
if LabelName.compare label_from label_to = 0 then
|
||||
Errors.raise_spanned_error edge_pos
|
||||
"Cannot define rule as an exception to itself";
|
||||
List.iter
|
||||
(fun edge ->
|
||||
if Ast.LabelName.compare edge.label_to label_to <> 0 then
|
||||
if LabelName.compare edge.label_to label_to <> 0 then
|
||||
Errors.raise_multispanned_error
|
||||
(( Some
|
||||
"This declaration contradicts another exception \
|
||||
@ -401,8 +397,8 @@ let build_exceptions_graph
|
||||
let existing_edge =
|
||||
List.find_opt
|
||||
(fun edge ->
|
||||
Ast.LabelName.compare edge.label_from label_from = 0
|
||||
&& Ast.LabelName.compare edge.label_to label_to = 0)
|
||||
LabelName.compare edge.label_from label_from = 0
|
||||
&& LabelName.compare edge.label_to label_to = 0)
|
||||
exception_edges
|
||||
in
|
||||
match existing_edge with
|
||||
@ -420,7 +416,7 @@ let build_exceptions_graph
|
||||
in
|
||||
(* We've got the vertices and the edges, let's build the graph! *)
|
||||
let g =
|
||||
Ast.LabelMap.fold
|
||||
LabelName.Map.fold
|
||||
(fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set)
|
||||
label_to_rule_sets ExceptionsDependencies.empty
|
||||
in
|
||||
@ -429,10 +425,10 @@ let build_exceptions_graph
|
||||
List.fold_left
|
||||
(fun g edge ->
|
||||
let rule_group_from =
|
||||
Ast.LabelMap.find edge.label_from label_to_rule_sets
|
||||
LabelName.Map.find edge.label_from label_to_rule_sets
|
||||
in
|
||||
let rule_group_to =
|
||||
Ast.LabelMap.find edge.label_to label_to_rule_sets
|
||||
LabelName.Map.find edge.label_to label_to_rule_sets
|
||||
in
|
||||
let edge =
|
||||
ExceptionsDependencies.E.create rule_group_from edge.edge_positions
|
||||
@ -453,11 +449,10 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
|
||||
let spans =
|
||||
List.flatten
|
||||
(List.map
|
||||
(fun (vs : Ast.RuleSet.t) ->
|
||||
let v = Ast.RuleSet.choose vs in
|
||||
(fun (vs : RuleName.Set.t) ->
|
||||
let v = RuleName.Set.choose vs in
|
||||
let var_str, var_info =
|
||||
( Format.asprintf "%a" Ast.RuleName.format_t v,
|
||||
Ast.RuleName.get_info v )
|
||||
Format.asprintf "%a" RuleName.format_t v, RuleName.get_info v
|
||||
in
|
||||
let succs = ExceptionsDependencies.succ_e g vs in
|
||||
let _, edge_pos, _ =
|
||||
|
@ -17,7 +17,8 @@
|
||||
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
|
||||
OCamlgraph} *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Scope variables dependency graph} *)
|
||||
|
||||
@ -71,9 +72,9 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t
|
||||
module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list
|
||||
|
||||
module ExceptionsDependencies :
|
||||
Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = EdgeExceptions.t
|
||||
Graph.Sig.P with type V.t = RuleName.Set.t and type E.label = EdgeExceptions.t
|
||||
|
||||
val build_exceptions_graph :
|
||||
Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
|
||||
Ast.rule RuleName.Map.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
|
||||
|
||||
val check_for_exception_cycle : ExceptionsDependencies.t -> unit
|
||||
|
@ -1,670 +0,0 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Denis Merigoux <denis.merigoux@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
|
||||
|
||||
open Utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Expression translation}*)
|
||||
|
||||
type target_scope_vars =
|
||||
| WholeVar of ScopeVar.t
|
||||
| States of (StateName.t * ScopeVar.t) list
|
||||
|
||||
type ctx = {
|
||||
scope_var_mapping : target_scope_vars ScopeVarMap.t;
|
||||
var_mapping : (Ast.expr, untyped Scopelang.Ast.expr Var.t) Var.Map.t;
|
||||
}
|
||||
|
||||
let tag_with_log_entry
|
||||
(e : untyped Scopelang.Ast.expr boxed)
|
||||
(l : log_entry)
|
||||
(markings : Utils.Uid.MarkedString.info list) :
|
||||
untyped Scopelang.Ast.expr boxed =
|
||||
Expr.eapp
|
||||
(Expr.eop (Unop (Log (l, markings))) (Marked.get_mark e))
|
||||
[e] (Marked.get_mark e)
|
||||
|
||||
let rec translate_expr (ctx : ctx) (e : Ast.expr) :
|
||||
untyped Scopelang.Ast.expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
|
||||
(* When referring to a subscope variable in an expression, we are referring
|
||||
to the output, hence we take the last state. *)
|
||||
let new_s_var =
|
||||
match ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping with
|
||||
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
|
||||
| States states ->
|
||||
Marked.same_mark_as (snd (List.hd (List.rev states))) s_var
|
||||
in
|
||||
Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m
|
||||
| ELocation (DesugaredScopeVar (s_var, None)) ->
|
||||
Expr.elocation
|
||||
(ScopelangScopeVar
|
||||
(match
|
||||
ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
|
||||
| States _ -> failwith "should not happen"))
|
||||
m
|
||||
| ELocation (DesugaredScopeVar (s_var, Some state)) ->
|
||||
Expr.elocation
|
||||
(ScopelangScopeVar
|
||||
(match
|
||||
ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar _ -> failwith "should not happen"
|
||||
| States states -> Marked.same_mark_as (List.assoc state states) s_var))
|
||||
m
|
||||
| EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m
|
||||
| EStruct (s_name, fields) ->
|
||||
Expr.estruct s_name (StructFieldMap.map (translate_expr ctx) fields) m
|
||||
| EStructAccess (e1, f_name, s_name) ->
|
||||
Expr.estructaccess (translate_expr ctx e1) f_name s_name m
|
||||
| EEnumInj (e1, cons, e_name) ->
|
||||
Expr.eenuminj (translate_expr ctx e1) cons e_name m
|
||||
| EMatchS (e1, e_name, arms) ->
|
||||
Expr.ematchs (translate_expr ctx e1) e_name
|
||||
(EnumConstructorMap.map (translate_expr ctx) arms)
|
||||
m
|
||||
| EScopeCall (sc_name, fields) ->
|
||||
Expr.escopecall sc_name
|
||||
(ScopeVarMap.fold
|
||||
(fun v e fields' ->
|
||||
let v' =
|
||||
match ScopeVarMap.find v ctx.scope_var_mapping with
|
||||
| WholeVar v' -> v'
|
||||
| States ((_, v') :: _) ->
|
||||
(* When there are multiple states, the input is always the first
|
||||
one *)
|
||||
v'
|
||||
| States [] -> assert false
|
||||
in
|
||||
ScopeVarMap.add v' (translate_expr ctx e) fields')
|
||||
fields ScopeVarMap.empty)
|
||||
m
|
||||
| ELit
|
||||
(( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
|
||||
| LDuration _ ) as l) ->
|
||||
Expr.elit l m
|
||||
| EAbs (binder, typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in
|
||||
let ctx =
|
||||
List.fold_left2
|
||||
(fun ctx var new_var ->
|
||||
{ ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping })
|
||||
ctx (Array.to_list vars) (Array.to_list new_vars)
|
||||
in
|
||||
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) typs m
|
||||
| EApp (e1, args) ->
|
||||
Expr.eapp (translate_expr ctx e1) (List.map (translate_expr ctx) args) m
|
||||
| EOp op -> Expr.eop op m
|
||||
| EDefault (excepts, just, cons) ->
|
||||
Expr.edefault
|
||||
(List.map (translate_expr ctx) excepts)
|
||||
(translate_expr ctx just) (translate_expr ctx cons) m
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2)
|
||||
(translate_expr ctx e3) m
|
||||
| EArray args -> Expr.earray (List.map (translate_expr ctx) args) m
|
||||
| ErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m
|
||||
|
||||
(** {1 Rule tree construction} *)
|
||||
|
||||
(** Intermediate representation for the exception tree of rules for a particular
|
||||
scope definition. *)
|
||||
type rule_tree =
|
||||
| Leaf of Ast.rule list
|
||||
(** Rules defining a base case piecewise. List is non-empty. *)
|
||||
| Node of rule_tree list * Ast.rule list
|
||||
(** [Node (exceptions, base_case)] is a list of exceptions to a non-empty
|
||||
list of rules defining a base case piecewise. *)
|
||||
|
||||
(** Transforms a flat list of rules into a tree, taking into account the
|
||||
priorities declared between rules *)
|
||||
let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) :
|
||||
rule_tree list =
|
||||
let exc_graph = Dependency.build_exceptions_graph def def_info in
|
||||
Dependency.check_for_exception_cycle exc_graph;
|
||||
(* we start by the base cases: they are the vertices which have no
|
||||
successors *)
|
||||
let base_cases =
|
||||
Dependency.ExceptionsDependencies.fold_vertex
|
||||
(fun v base_cases ->
|
||||
if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then
|
||||
v :: base_cases
|
||||
else base_cases)
|
||||
exc_graph []
|
||||
in
|
||||
let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree =
|
||||
let exceptions =
|
||||
Dependency.ExceptionsDependencies.pred exc_graph base_cases
|
||||
in
|
||||
let base_case_as_rule_list =
|
||||
List.map
|
||||
(fun r -> Ast.RuleMap.find r def)
|
||||
(Ast.RuleSet.elements base_cases)
|
||||
in
|
||||
match exceptions with
|
||||
| [] -> Leaf base_case_as_rule_list
|
||||
| _ -> Node (List.map build_tree exceptions, base_case_as_rule_list)
|
||||
in
|
||||
List.map build_tree base_cases
|
||||
|
||||
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault}
|
||||
expression in the scope language. The [~toplevel] parameter is used to know
|
||||
when to place the toplevel binding in the case of functions. *)
|
||||
let rec rule_tree_to_expr
|
||||
~(toplevel : bool)
|
||||
(ctx : ctx)
|
||||
(def_pos : Pos.t)
|
||||
(is_func : Ast.expr Var.t option)
|
||||
(tree : rule_tree) : untyped Scopelang.Ast.expr boxed =
|
||||
let emark = Untyped { pos = def_pos } in
|
||||
let exceptions, base_rules =
|
||||
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
|
||||
in
|
||||
(* because each rule has its own variable parameter and we want to convert the
|
||||
whole rule tree into a function, we need to perform some alpha-renaming of
|
||||
all the expressions *)
|
||||
let substitute_parameter (e : Ast.expr boxed) (rule : Ast.rule) :
|
||||
Ast.expr boxed =
|
||||
match is_func, rule.Ast.rule_parameter with
|
||||
| Some new_param, Some (old_param, _) ->
|
||||
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
|
||||
Marked.mark (Marked.get_mark e)
|
||||
@@ Bindlib.box_apply2
|
||||
(fun binder new_param -> Bindlib.subst binder new_param)
|
||||
binder
|
||||
(Bindlib.box_var new_param)
|
||||
| None, None -> e
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
in
|
||||
let ctx =
|
||||
match is_func with
|
||||
| None -> ctx
|
||||
| Some new_param -> (
|
||||
match Var.Map.find_opt new_param ctx.var_mapping with
|
||||
| None ->
|
||||
let new_param_scope = Var.make (Bindlib.name_of new_param) in
|
||||
{
|
||||
ctx with
|
||||
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
|
||||
}
|
||||
| Some _ ->
|
||||
(* We only create a mapping if none exists because [rule_tree_to_expr]
|
||||
is called recursively on the exceptions of the tree and we don't want
|
||||
to create a new Scopelang variable for the parameter at each tree
|
||||
level. *)
|
||||
ctx)
|
||||
in
|
||||
let base_just_list =
|
||||
List.map
|
||||
(fun rule -> substitute_parameter rule.Ast.rule_just rule)
|
||||
base_rules
|
||||
in
|
||||
let base_cons_list =
|
||||
List.map
|
||||
(fun rule -> substitute_parameter rule.Ast.rule_cons rule)
|
||||
base_rules
|
||||
in
|
||||
let translate_and_unbox_list (list : Ast.expr boxed list) :
|
||||
untyped Scopelang.Ast.expr boxed list =
|
||||
List.map
|
||||
(fun e ->
|
||||
(* There are two levels of boxing here, the outermost is introduced by
|
||||
the [translate_expr] function for which all of the bindings should
|
||||
have been closed by now, so we can safely unbox. *)
|
||||
translate_expr ctx (Expr.unbox e))
|
||||
list
|
||||
in
|
||||
let default_containing_base_cases =
|
||||
Expr.make_default
|
||||
(List.map2
|
||||
(fun base_just base_cons ->
|
||||
Expr.make_default []
|
||||
(* Here we insert the logging command that records when a decision
|
||||
is taken for the value of a variable. *)
|
||||
(tag_with_log_entry base_just PosRecordIfTrueBool [])
|
||||
base_cons emark)
|
||||
(translate_and_unbox_list base_just_list)
|
||||
(translate_and_unbox_list base_cons_list))
|
||||
(Expr.elit (LBool false) emark)
|
||||
(Expr.elit LEmptyError emark)
|
||||
emark
|
||||
in
|
||||
let exceptions =
|
||||
List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions
|
||||
in
|
||||
let default =
|
||||
Expr.make_default exceptions
|
||||
(Expr.elit (LBool true) emark)
|
||||
default_containing_base_cases emark
|
||||
in
|
||||
match is_func, (List.hd base_rules).Ast.rule_parameter with
|
||||
| None, None -> default
|
||||
| Some new_param, Some (_, typ) ->
|
||||
if toplevel then
|
||||
(* When we're creating a function from multiple defaults, we must check
|
||||
that the result returned by the function is not empty *)
|
||||
let default = Expr.eerroronempty default emark in
|
||||
Expr.make_abs
|
||||
[| Var.Map.find new_param ctx.var_mapping |]
|
||||
default [typ] def_pos
|
||||
else default
|
||||
| _ -> (* should not happen *) assert false
|
||||
|
||||
(** {1 AST translation} *)
|
||||
|
||||
(** Translates a definition inside a scope, the resulting expression should be
|
||||
an {!constructor: Dcalc.EDefault} *)
|
||||
let translate_def
|
||||
(ctx : ctx)
|
||||
(def_info : Ast.ScopeDef.t)
|
||||
(def : Ast.rule Ast.RuleMap.t)
|
||||
(typ : typ)
|
||||
(io : Scopelang.Ast.io)
|
||||
~(is_cond : bool)
|
||||
~(is_subscope_var : bool) : untyped Scopelang.Ast.expr boxed =
|
||||
(* Here, we have to transform this list of rules into a default tree. *)
|
||||
let is_def_func =
|
||||
match Marked.unmark typ with TArrow (_, _) -> true | _ -> false
|
||||
in
|
||||
let is_rule_func _ (r : Ast.rule) : bool =
|
||||
Option.is_some r.Ast.rule_parameter
|
||||
in
|
||||
let all_rules_func = Ast.RuleMap.for_all is_rule_func def in
|
||||
let all_rules_not_func =
|
||||
Ast.RuleMap.for_all (fun n r -> not (is_rule_func n r)) def
|
||||
in
|
||||
let is_def_func_param_typ : typ option =
|
||||
if is_def_func && all_rules_func then
|
||||
match Marked.unmark typ with
|
||||
| TArrow (t_param, _) -> Some t_param
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Marked.get_mark typ)
|
||||
"The definitions of %a are function but it doesn't have a function \
|
||||
type"
|
||||
Ast.ScopeDef.format_t def_info
|
||||
else if (not is_def_func) && all_rules_not_func then None
|
||||
else
|
||||
let spans =
|
||||
List.map
|
||||
(fun (_, r) ->
|
||||
Some "This definition is a function:", Expr.pos r.Ast.rule_cons)
|
||||
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_rule_func def))
|
||||
@ List.map
|
||||
(fun (_, r) ->
|
||||
( Some "This definition is not a function:",
|
||||
Expr.pos r.Ast.rule_cons ))
|
||||
(Ast.RuleMap.bindings
|
||||
(Ast.RuleMap.filter (fun n r -> not (is_rule_func n r)) def))
|
||||
in
|
||||
Errors.raise_multispanned_error spans
|
||||
"some definitions of the same variable are functions while others \
|
||||
aren't"
|
||||
in
|
||||
let top_list = def_map_to_tree def_info def in
|
||||
let is_input =
|
||||
match Marked.unmark io.Scopelang.Ast.io_input with
|
||||
| OnlyInput -> true
|
||||
| _ -> false
|
||||
in
|
||||
let top_value =
|
||||
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
|
||||
(* We add the bottom [false] value for conditions, only for the scope
|
||||
where the condition is declared. Except when the variable is an input,
|
||||
where we want the [false] to be added at each caller parent scope. *)
|
||||
Some
|
||||
(Ast.always_false_rule
|
||||
(Ast.ScopeDef.get_position def_info)
|
||||
is_def_func_param_typ)
|
||||
else None
|
||||
in
|
||||
if
|
||||
Ast.RuleMap.cardinal def = 0
|
||||
&& is_subscope_var
|
||||
(* Here we have a special case for the empty definitions. Indeed, we could
|
||||
use the code for the regular case below that would create a convoluted
|
||||
default always returning empty error, and this would be correct. But it
|
||||
gets more complicated with functions. Indeed, if we create an empty
|
||||
definition for a subscope argument whose type is a function, we get
|
||||
something like [fun () -> (fun real_param -> < ... >)] that is passed as
|
||||
an argument to the subscope. The sub-scope de-thunks but the de-thunking
|
||||
does not return empty error, signalling there is not reentrant variable,
|
||||
because functions are values! So the subscope does not see that there is
|
||||
not reentrant variable and does not pick its internal definition instead.
|
||||
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
|
||||
test case exercising that subtlety.
|
||||
|
||||
To avoid this complication we special case here and put an empty error
|
||||
for all subscope variables that are not defined. It covers the subtlety
|
||||
with functions described above but also conditions with the false default
|
||||
value. *)
|
||||
&& not (is_cond && is_input)
|
||||
(* However, this special case suffers from an exception: when a condition is
|
||||
defined as an OnlyInput to a subscope, since the [false] default value
|
||||
will not be provided by the calee scope, it has to be placed in the
|
||||
caller. *)
|
||||
then
|
||||
Expr.elit LEmptyError (Untyped { pos = Ast.ScopeDef.get_position def_info })
|
||||
else
|
||||
rule_tree_to_expr ~toplevel:true ctx
|
||||
(Ast.ScopeDef.get_position def_info)
|
||||
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
|
||||
(match top_list, top_value with
|
||||
| [], None ->
|
||||
(* In this case, there are no rules to define the expression and no
|
||||
default value so we put an empty rule. *)
|
||||
Leaf [Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ]
|
||||
| [], Some top_value ->
|
||||
(* In this case, there are no rules to define the expression but a
|
||||
default value so we put it. *)
|
||||
Leaf [top_value]
|
||||
| _, Some top_value ->
|
||||
(* When there are rules + a default value, we put the rules as
|
||||
exceptions to the default value *)
|
||||
Node (top_list, [top_value])
|
||||
| [top_tree], None -> top_tree
|
||||
| _, None ->
|
||||
Node
|
||||
( top_list,
|
||||
[Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ] ))
|
||||
|
||||
(** Translates a scope *)
|
||||
let translate_scope (ctx : ctx) (scope : Ast.scope) :
|
||||
untyped Scopelang.Ast.scope_decl =
|
||||
let scope_dependencies = Dependency.build_scope_dependencies scope in
|
||||
Dependency.check_for_cycle scope scope_dependencies;
|
||||
let scope_ordering =
|
||||
Dependency.correct_computation_ordering scope_dependencies
|
||||
in
|
||||
let scope_decl_rules =
|
||||
List.flatten
|
||||
(List.map
|
||||
(fun vertex ->
|
||||
match vertex with
|
||||
| Dependency.Vertex.Var (var, state) -> (
|
||||
let scope_def =
|
||||
Ast.ScopeDefMap.find
|
||||
(Ast.ScopeDef.Var (var, state))
|
||||
scope.scope_defs
|
||||
in
|
||||
let var_def = scope_def.scope_def_rules in
|
||||
let var_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
match Marked.unmark scope_def.Ast.scope_def_io.io_input with
|
||||
| OnlyInput when not (Ast.RuleMap.is_empty var_def) ->
|
||||
(* If the variable is tagged as input, then it shall not be
|
||||
redefined. *)
|
||||
Errors.raise_multispanned_error
|
||||
(( Some "Incriminated variable:",
|
||||
Marked.get_mark (ScopeVar.get_info var) )
|
||||
:: List.map
|
||||
(fun (rule, _) ->
|
||||
( Some "Incriminated variable definition:",
|
||||
Marked.get_mark (Ast.RuleName.get_info rule) ))
|
||||
(Ast.RuleMap.bindings var_def))
|
||||
"It is impossible to give a definition to a scope variable \
|
||||
tagged as input."
|
||||
| OnlyInput ->
|
||||
[]
|
||||
(* we do not provide any definition for an input-only variable *)
|
||||
| _ ->
|
||||
let expr_def =
|
||||
translate_def ctx
|
||||
(Ast.ScopeDef.Var (var, state))
|
||||
var_def var_typ scope_def.Ast.scope_def_io ~is_cond
|
||||
~is_subscope_var:false
|
||||
in
|
||||
let scope_var =
|
||||
match ScopeVarMap.find var ctx.scope_var_mapping, state with
|
||||
| WholeVar v, None -> v
|
||||
| States states, Some state -> List.assoc state states
|
||||
| _ -> failwith "should not happen"
|
||||
in
|
||||
[
|
||||
Scopelang.Ast.Definition
|
||||
( ( ScopelangScopeVar
|
||||
( scope_var,
|
||||
Marked.get_mark (ScopeVar.get_info scope_var) ),
|
||||
Marked.get_mark (ScopeVar.get_info scope_var) ),
|
||||
var_typ,
|
||||
scope_def.Ast.scope_def_io,
|
||||
Expr.unbox expr_def );
|
||||
])
|
||||
| Dependency.Vertex.SubScope sub_scope_index ->
|
||||
(* Before calling the sub_scope, we need to include all the
|
||||
re-definitions of subscope parameters*)
|
||||
let sub_scope =
|
||||
SubScopeMap.find sub_scope_index scope.scope_sub_scopes
|
||||
in
|
||||
let sub_scope_vars_redefs_candidates =
|
||||
Ast.ScopeDefMap.filter
|
||||
(fun def_key scope_def ->
|
||||
match def_key with
|
||||
| Ast.ScopeDef.Var _ -> false
|
||||
| Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) ->
|
||||
sub_scope_index = sub_scope_index'
|
||||
(* We exclude subscope variables that have 0 re-definitions
|
||||
and are not visible in the input of the subscope *)
|
||||
&& not
|
||||
((match
|
||||
Marked.unmark scope_def.Ast.scope_def_io.io_input
|
||||
with
|
||||
| Scopelang.Ast.NoInput -> true
|
||||
| _ -> false)
|
||||
&& Ast.RuleMap.is_empty scope_def.scope_def_rules))
|
||||
scope.scope_defs
|
||||
in
|
||||
let sub_scope_vars_redefs =
|
||||
Ast.ScopeDefMap.mapi
|
||||
(fun def_key scope_def ->
|
||||
let def = scope_def.Ast.scope_def_rules in
|
||||
let def_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
match def_key with
|
||||
| Ast.ScopeDef.Var _ -> assert false (* should not happen *)
|
||||
| Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) ->
|
||||
(* This definition redefines a variable of the correct
|
||||
subscope. But we have to check that this redefinition is
|
||||
allowed with respect to the io parameters of that
|
||||
subscope variable. *)
|
||||
(match
|
||||
Marked.unmark scope_def.Ast.scope_def_io.io_input
|
||||
with
|
||||
| Scopelang.Ast.NoInput ->
|
||||
Errors.raise_multispanned_error
|
||||
(( Some "Incriminated subscope:",
|
||||
Marked.get_mark (SubScopeName.get_info sscope) )
|
||||
:: ( Some "Incriminated variable:",
|
||||
Marked.get_mark (ScopeVar.get_info sub_scope_var)
|
||||
)
|
||||
:: List.map
|
||||
(fun (rule, _) ->
|
||||
( Some
|
||||
"Incriminated subscope variable definition:",
|
||||
Marked.get_mark (Ast.RuleName.get_info rule) ))
|
||||
(Ast.RuleMap.bindings def))
|
||||
"It is impossible to give a definition to a subscope \
|
||||
variable not tagged as input or context."
|
||||
| OnlyInput when Ast.RuleMap.is_empty def && not is_cond ->
|
||||
(* If the subscope variable is tagged as input, then it
|
||||
shall be defined. *)
|
||||
Errors.raise_multispanned_error
|
||||
[
|
||||
( Some "Incriminated subscope:",
|
||||
Marked.get_mark (SubScopeName.get_info sscope) );
|
||||
Some "Incriminated variable:", pos;
|
||||
]
|
||||
"This subscope variable is a mandatory input but no \
|
||||
definition was provided."
|
||||
| _ -> ());
|
||||
(* Now that all is good, we can proceed with translating
|
||||
this redefinition to a proper Scopelang term. *)
|
||||
let expr_def =
|
||||
translate_def ctx def_key def def_typ
|
||||
scope_def.Ast.scope_def_io ~is_cond
|
||||
~is_subscope_var:true
|
||||
in
|
||||
let subscop_real_name =
|
||||
SubScopeMap.find sub_scope_index scope.scope_sub_scopes
|
||||
in
|
||||
let var_pos = Ast.ScopeDef.get_position def_key in
|
||||
Scopelang.Ast.Definition
|
||||
( ( SubScopeVar
|
||||
( subscop_real_name,
|
||||
(sub_scope_index, var_pos),
|
||||
match
|
||||
ScopeVarMap.find sub_scope_var
|
||||
ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar v -> v, var_pos
|
||||
| States states ->
|
||||
(* When defining a sub-scope variable, we
|
||||
always define its first state in the
|
||||
sub-scope. *)
|
||||
snd (List.hd states), var_pos ),
|
||||
var_pos ),
|
||||
def_typ,
|
||||
scope_def.Ast.scope_def_io,
|
||||
Expr.unbox expr_def ))
|
||||
sub_scope_vars_redefs_candidates
|
||||
in
|
||||
let sub_scope_vars_redefs =
|
||||
List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
|
||||
in
|
||||
sub_scope_vars_redefs
|
||||
@ [
|
||||
Scopelang.Ast.Call
|
||||
( sub_scope,
|
||||
sub_scope_index,
|
||||
Untyped
|
||||
{
|
||||
pos =
|
||||
Marked.get_mark
|
||||
(SubScopeName.get_info sub_scope_index);
|
||||
} );
|
||||
])
|
||||
scope_ordering)
|
||||
in
|
||||
(* Then, after having computed all the scopes variables, we add the
|
||||
assertions. TODO: the assertions should be interleaved with the
|
||||
definitions! *)
|
||||
let scope_decl_rules =
|
||||
scope_decl_rules
|
||||
@ List.map
|
||||
(fun e ->
|
||||
let scope_e = translate_expr ctx (Expr.unbox e) in
|
||||
Scopelang.Ast.Assertion (Expr.unbox scope_e))
|
||||
scope.Ast.scope_assertions
|
||||
in
|
||||
let scope_sig =
|
||||
ScopeVarMap.fold
|
||||
(fun var (states : Ast.var_or_states) acc ->
|
||||
match states with
|
||||
| WholeVar ->
|
||||
let scope_def =
|
||||
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs
|
||||
in
|
||||
let typ = scope_def.scope_def_typ in
|
||||
ScopeVarMap.add
|
||||
(match ScopeVarMap.find var ctx.scope_var_mapping with
|
||||
| WholeVar v -> v
|
||||
| States _ -> failwith "should not happen")
|
||||
(typ, scope_def.scope_def_io)
|
||||
acc
|
||||
| States states ->
|
||||
(* What happens in the case of variables with multiple states is
|
||||
interesting. We need to create as many Scopelang.Var entries in the
|
||||
scope signature as there are states. *)
|
||||
List.fold_left
|
||||
(fun acc (state : StateName.t) ->
|
||||
let scope_def =
|
||||
Ast.ScopeDefMap.find
|
||||
(Ast.ScopeDef.Var (var, Some state))
|
||||
scope.scope_defs
|
||||
in
|
||||
ScopeVarMap.add
|
||||
(match ScopeVarMap.find var ctx.scope_var_mapping with
|
||||
| WholeVar _ -> failwith "should not happen"
|
||||
| States states' -> List.assoc state states')
|
||||
(scope_def.scope_def_typ, scope_def.scope_def_io)
|
||||
acc)
|
||||
acc states)
|
||||
scope.scope_vars ScopeVarMap.empty
|
||||
in
|
||||
let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in
|
||||
{
|
||||
Scopelang.Ast.scope_decl_name = scope.scope_uid;
|
||||
Scopelang.Ast.scope_decl_rules;
|
||||
Scopelang.Ast.scope_sig;
|
||||
Scopelang.Ast.scope_mark = Untyped { pos };
|
||||
}
|
||||
|
||||
(** {1 API} *)
|
||||
|
||||
let translate_program (pgrm : Ast.program) : untyped Scopelang.Ast.program =
|
||||
(* First we give mappings to all the locations between Desugared and
|
||||
Scopelang. This involves creating a new Scopelang scope variable for every
|
||||
state of a Desugared variable. *)
|
||||
let ctx =
|
||||
ScopeMap.fold
|
||||
(fun _scope scope_decl ctx ->
|
||||
ScopeVarMap.fold
|
||||
(fun scope_var (states : Ast.var_or_states) ctx ->
|
||||
match states with
|
||||
| Ast.WholeVar ->
|
||||
{
|
||||
ctx with
|
||||
scope_var_mapping =
|
||||
ScopeVarMap.add scope_var
|
||||
(WholeVar (ScopeVar.fresh (ScopeVar.get_info scope_var)))
|
||||
ctx.scope_var_mapping;
|
||||
}
|
||||
| States states ->
|
||||
{
|
||||
ctx with
|
||||
scope_var_mapping =
|
||||
ScopeVarMap.add scope_var
|
||||
(States
|
||||
(List.map
|
||||
(fun state ->
|
||||
( state,
|
||||
ScopeVar.fresh
|
||||
(let state_name, state_pos =
|
||||
StateName.get_info state
|
||||
in
|
||||
( Marked.unmark (ScopeVar.get_info scope_var)
|
||||
^ "_"
|
||||
^ state_name,
|
||||
state_pos )) ))
|
||||
states))
|
||||
ctx.scope_var_mapping;
|
||||
})
|
||||
scope_decl.Ast.scope_vars ctx)
|
||||
pgrm.Ast.program_scopes
|
||||
{ scope_var_mapping = ScopeVarMap.empty; var_mapping = Var.Map.empty }
|
||||
in
|
||||
{
|
||||
Scopelang.Ast.program_scopes =
|
||||
ScopeMap.map (translate_scope ctx) pgrm.program_scopes;
|
||||
Scopelang.Ast.program_ctx = pgrm.program_ctx;
|
||||
}
|
78
compiler/desugared/disambiguate.ml
Normal file
78
compiler/desugared/disambiguate.ml
Normal file
@ -0,0 +1,78 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Shared_ast
|
||||
open Ast
|
||||
|
||||
let expr ctx env e =
|
||||
(* The typer takes care of disambiguating: this consists in: - ensuring
|
||||
[EAbs.tys] doesn't contain any [TAny] - [EDStructAccess.name_opt] is always
|
||||
[Some] *)
|
||||
(* Intermediate unboxings are fine since the last [untype] will rebox in
|
||||
depth *)
|
||||
Typing.check_expr ctx ~env (Expr.unbox e)
|
||||
|
||||
let rule ctx env rule =
|
||||
let env =
|
||||
match rule.rule_parameter with
|
||||
| None -> env
|
||||
| Some (v, ty) -> Typing.Env.add_var v ty env
|
||||
in
|
||||
(* Note: we could use the known rule type here to direct typing. We choose not
|
||||
to because it shouldn't be needed for disambiguation, and we prefer to
|
||||
focus on local type errors first. *)
|
||||
{
|
||||
rule with
|
||||
rule_just = expr ctx env rule.rule_just;
|
||||
rule_cons = expr ctx env rule.rule_cons;
|
||||
}
|
||||
|
||||
let scope ctx env scope =
|
||||
let env = Typing.Env.open_scope scope.scope_uid env in
|
||||
let scope_defs =
|
||||
ScopeDefMap.map
|
||||
(fun def ->
|
||||
let scope_def_rules =
|
||||
(* Note: ordering in file order might be better for error reporting ?
|
||||
When we gather errors, the ordering could be done afterwards,
|
||||
though *)
|
||||
RuleName.Map.map (rule ctx env) def.scope_def_rules
|
||||
in
|
||||
{ def with scope_def_rules })
|
||||
scope.scope_defs
|
||||
in
|
||||
let scope_assertions = List.map (expr ctx env) scope.scope_assertions in
|
||||
{ scope with scope_defs; scope_assertions }
|
||||
|
||||
let program prg =
|
||||
let env =
|
||||
ScopeName.Map.fold
|
||||
(fun scope_name scope env ->
|
||||
let vars =
|
||||
ScopeDefMap.fold
|
||||
(fun var def vars ->
|
||||
match var with
|
||||
| Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars
|
||||
| SubScopeVar _ -> vars)
|
||||
scope.scope_defs ScopeVar.Map.empty
|
||||
in
|
||||
Typing.Env.add_scope scope_name ~vars env)
|
||||
prg.program_scopes Typing.Env.empty
|
||||
in
|
||||
let program_scopes =
|
||||
ScopeName.Map.map (scope prg.program_ctx env) prg.program_scopes
|
||||
in
|
||||
{ prg with program_scopes }
|
24
compiler/desugared/disambiguate.mli
Normal file
24
compiler/desugared/disambiguate.mli
Normal file
@ -0,0 +1,24 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
(** This module does local typing in order to fill some missing type information
|
||||
in the AST:
|
||||
|
||||
- it fills the types of arguments in [EAbs] nodes, (untyped ones are
|
||||
inserted during desugaring, e.g. by `let-in` constructs),
|
||||
- it resolves the structure names of [EDStructAccess] nodes. *)
|
||||
|
||||
val program : Ast.program -> Ast.program
|
@ -1,7 +1,7 @@
|
||||
(library
|
||||
(name desugared)
|
||||
(public_name catala.desugared)
|
||||
(libraries utils dcalc scopelang ocamlgraph))
|
||||
(libraries ocamlgraph catala_utils shared_ast surface))
|
||||
|
||||
(documentation
|
||||
(package catala)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -20,6 +20,6 @@
|
||||
- Removes syntactic sugars
|
||||
- Separate code from legislation *)
|
||||
|
||||
val desugar_program :
|
||||
Name_resolution.context -> Ast.program -> Desugared.Ast.program
|
||||
val translate_program :
|
||||
Name_resolution.context -> Surface.Ast.program -> Ast.program
|
||||
(** Main function of this module *)
|
File diff suppressed because it is too large
Load Diff
@ -18,20 +18,18 @@
|
||||
(** Builds a context that allows for mapping each name to a precise uid, taking
|
||||
lexical scopes into account *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Name resolution context} *)
|
||||
|
||||
type ident = string
|
||||
|
||||
type unique_rulename =
|
||||
| Ambiguous of Pos.t list
|
||||
| Unique of Desugared.Ast.RuleName.t Marked.pos
|
||||
| Unique of RuleName.t Marked.pos
|
||||
|
||||
type scope_def_context = {
|
||||
default_exception_rulename : unique_rulename option;
|
||||
label_idmap : Desugared.Ast.LabelName.t Desugared.Ast.IdentMap.t;
|
||||
label_idmap : LabelName.t IdentName.Map.t;
|
||||
}
|
||||
|
||||
type scope_var_or_subscope =
|
||||
@ -39,26 +37,26 @@ type scope_var_or_subscope =
|
||||
| SubScope of SubScopeName.t * ScopeName.t
|
||||
|
||||
type scope_context = {
|
||||
var_idmap : scope_var_or_subscope Desugared.Ast.IdentMap.t;
|
||||
var_idmap : scope_var_or_subscope IdentName.Map.t;
|
||||
(** All variables, including scope variables and subscopes *)
|
||||
scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t;
|
||||
scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t;
|
||||
(** What is the default rule to refer to for unnamed exceptions, if any *)
|
||||
sub_scopes : ScopeSet.t;
|
||||
sub_scopes : ScopeName.Set.t;
|
||||
(** Other scopes referred to by this scope. Used for dependency analysis *)
|
||||
}
|
||||
(** Inside a scope, we distinguish between the variables and the subscopes. *)
|
||||
|
||||
type struct_context = typ StructFieldMap.t
|
||||
type struct_context = typ StructField.Map.t
|
||||
(** Types of the fields of a struct *)
|
||||
|
||||
type enum_context = typ EnumConstructorMap.t
|
||||
type enum_context = typ EnumConstructor.Map.t
|
||||
(** Types of the payloads of the cases of an enum *)
|
||||
|
||||
type var_sig = {
|
||||
var_sig_typ : typ;
|
||||
var_sig_is_condition : bool;
|
||||
var_sig_io : Ast.scope_decl_context_io;
|
||||
var_sig_states_idmap : StateName.t Desugared.Ast.IdentMap.t;
|
||||
var_sig_io : Surface.Ast.scope_decl_context_io;
|
||||
var_sig_states_idmap : StateName.t IdentName.Map.t;
|
||||
var_sig_states_list : StateName.t list;
|
||||
}
|
||||
|
||||
@ -67,25 +65,26 @@ type var_sig = {
|
||||
type typedef =
|
||||
| TStruct of StructName.t
|
||||
| TEnum of EnumName.t
|
||||
| TScope of ScopeName.t * StructName.t
|
||||
| TScope of ScopeName.t * scope_out_struct
|
||||
(** Implicitly defined output struct *)
|
||||
|
||||
type context = {
|
||||
local_var_idmap : Desugared.Ast.expr Var.t Desugared.Ast.IdentMap.t;
|
||||
local_var_idmap : Ast.expr Var.t IdentName.Map.t;
|
||||
(** Inside a definition, local variables can be introduced by functions
|
||||
arguments or pattern matching *)
|
||||
typedefs : typedef Desugared.Ast.IdentMap.t;
|
||||
typedefs : typedef IdentName.Map.t;
|
||||
(** Gathers the names of the scopes, structs and enums *)
|
||||
field_idmap : StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t;
|
||||
field_idmap : StructField.t StructName.Map.t IdentName.Map.t;
|
||||
(** The names of the struct fields. Names of fields can be shared between
|
||||
different structs *)
|
||||
constructor_idmap : EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t;
|
||||
constructor_idmap : EnumConstructor.t EnumName.Map.t IdentName.Map.t;
|
||||
(** The names of the enum constructors. Constructor names can be shared
|
||||
between different enums *)
|
||||
scopes : scope_context ScopeMap.t; (** For each scope, its context *)
|
||||
structs : struct_context StructMap.t; (** For each struct, its context *)
|
||||
enums : enum_context EnumMap.t; (** For each enum, its context *)
|
||||
var_typs : var_sig ScopeVarMap.t;
|
||||
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
|
||||
structs : struct_context StructName.Map.t;
|
||||
(** For each struct, its context *)
|
||||
enums : enum_context EnumName.Map.t; (** For each enum, its context *)
|
||||
var_typs : var_sig ScopeVar.Map.t;
|
||||
(** The signatures of each scope variable declared *)
|
||||
}
|
||||
(** Main context used throughout {!module: Surface.Desugaring} *)
|
||||
@ -96,7 +95,7 @@ val raise_unsupported_feature : string -> Pos.t -> 'a
|
||||
(** Temporary function raising an error message saying that a feature is not
|
||||
supported yet *)
|
||||
|
||||
val raise_unknown_identifier : string -> ident Marked.pos -> 'a
|
||||
val raise_unknown_identifier : string -> IdentName.t Marked.pos -> 'a
|
||||
(** Function to call whenever an identifier used somewhere has not been declared
|
||||
in the program previously *)
|
||||
|
||||
@ -104,53 +103,53 @@ val get_var_typ : context -> ScopeVar.t -> typ
|
||||
(** Gets the type associated to an uid *)
|
||||
|
||||
val is_var_cond : context -> ScopeVar.t -> bool
|
||||
val get_var_io : context -> ScopeVar.t -> Ast.scope_decl_context_io
|
||||
val get_var_io : context -> ScopeVar.t -> Surface.Ast.scope_decl_context_io
|
||||
|
||||
val get_var_uid : ScopeName.t -> context -> ident Marked.pos -> ScopeVar.t
|
||||
val get_var_uid : ScopeName.t -> context -> IdentName.t Marked.pos -> ScopeVar.t
|
||||
(** Get the variable uid inside the scope given in argument *)
|
||||
|
||||
val get_subscope_uid :
|
||||
ScopeName.t -> context -> ident Marked.pos -> SubScopeName.t
|
||||
ScopeName.t -> context -> IdentName.t Marked.pos -> SubScopeName.t
|
||||
(** Get the subscope uid inside the scope given in argument *)
|
||||
|
||||
val is_subscope_uid : ScopeName.t -> context -> ident -> bool
|
||||
val is_subscope_uid : ScopeName.t -> context -> IdentName.t -> bool
|
||||
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
|
||||
subscopes of [scope_uid]. *)
|
||||
|
||||
val belongs_to : context -> ScopeVar.t -> ScopeName.t -> bool
|
||||
(** Checks if the var_uid belongs to the scope scope_uid *)
|
||||
|
||||
val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ
|
||||
val get_def_typ : context -> Ast.ScopeDef.t -> typ
|
||||
(** Retrieves the type of a scope definition from the context *)
|
||||
|
||||
val is_def_cond : context -> Desugared.Ast.ScopeDef.t -> bool
|
||||
val is_type_cond : Ast.typ -> bool
|
||||
val is_def_cond : context -> Ast.ScopeDef.t -> bool
|
||||
val is_type_cond : Surface.Ast.typ -> bool
|
||||
|
||||
val add_def_local_var : context -> ident -> context * Desugared.Ast.expr Var.t
|
||||
val add_def_local_var : context -> IdentName.t -> context * Ast.expr Var.t
|
||||
(** Adds a binding to the context *)
|
||||
|
||||
val get_def_key :
|
||||
Ast.qident ->
|
||||
Ast.ident Marked.pos option ->
|
||||
Surface.Ast.scope_var ->
|
||||
Surface.Ast.lident Marked.pos option ->
|
||||
ScopeName.t ->
|
||||
context ->
|
||||
Pos.t ->
|
||||
Desugared.Ast.ScopeDef.t
|
||||
Ast.ScopeDef.t
|
||||
(** Usage: [get_def_key var_name var_state scope_uid ctxt pos]*)
|
||||
|
||||
val get_enum : context -> ident Marked.pos -> EnumName.t
|
||||
val get_enum : context -> IdentName.t Marked.pos -> EnumName.t
|
||||
(** Find an enum definition from the typedefs, failing if there is none or it
|
||||
has a different kind *)
|
||||
|
||||
val get_struct : context -> ident Marked.pos -> StructName.t
|
||||
val get_struct : context -> IdentName.t Marked.pos -> StructName.t
|
||||
(** Find a struct definition from the typedefs (possibly an implicit output
|
||||
struct from a scope), failing if there is none or it has a different kind *)
|
||||
|
||||
val get_scope : context -> ident Marked.pos -> ScopeName.t
|
||||
val get_scope : context -> IdentName.t Marked.pos -> ScopeName.t
|
||||
(** Find a scope definition from the typedefs, failing if there is none or it
|
||||
has a different kind *)
|
||||
|
||||
(** {1 API} *)
|
||||
|
||||
val form_context : Ast.program -> context
|
||||
val form_context : Surface.Ast.program -> context
|
||||
(** Derive the context from metadata, in one pass over the declarations *)
|
@ -15,10 +15,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
module Cli = Utils.Cli
|
||||
module File = Utils.File
|
||||
module Errors = Utils.Errors
|
||||
module Pos = Utils.Pos
|
||||
open Catala_utils
|
||||
|
||||
(** Associates a {!type: Cli.backend_lang} with its string represtation. *)
|
||||
let languages = ["en", Cli.En; "fr", Cli.Fr; "pl", Cli.Pl]
|
||||
@ -76,7 +73,15 @@ let driver source_file (options : Cli.options) : int =
|
||||
try `Plugin (Plugin.find s)
|
||||
with Not_found ->
|
||||
Errors.raise_error
|
||||
"The selected backend (%s) is not supported by Catala" backend)
|
||||
"The selected backend (%s) is not supported by Catala, nor was a \
|
||||
plugin by this name found under %a"
|
||||
backend
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or @ ")
|
||||
(fun ppf dir ->
|
||||
Format.pp_print_string ppf
|
||||
(try Unix.readlink dir with _ -> dir)))
|
||||
options.plugins_dirs)
|
||||
in
|
||||
let prgm =
|
||||
Surface.Parser_driver.parse_top_level_file source_file language
|
||||
@ -143,7 +148,7 @@ let driver source_file (options : Cli.options) : int =
|
||||
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Dcalc
|
||||
| `Scopelang | `Proof | `Plugin _ ) as backend -> (
|
||||
Cli.debug_print "Name resolution...";
|
||||
let ctxt = Surface.Name_resolution.form_context prgm in
|
||||
let ctxt = Desugared.Name_resolution.form_context prgm in
|
||||
let scope_uid =
|
||||
match options.ex_scope, backend with
|
||||
| None, `Interpret ->
|
||||
@ -151,27 +156,29 @@ let driver source_file (options : Cli.options) : int =
|
||||
| None, _ ->
|
||||
let _, scope =
|
||||
try
|
||||
Desugared.Ast.IdentMap.filter_map
|
||||
Shared_ast.IdentName.Map.filter_map
|
||||
(fun _ -> function
|
||||
| Surface.Name_resolution.TScope (uid, _) -> Some uid
|
||||
| Desugared.Name_resolution.TScope (uid, _) -> Some uid
|
||||
| _ -> None)
|
||||
ctxt.typedefs
|
||||
|> Desugared.Ast.IdentMap.choose
|
||||
|> Shared_ast.IdentName.Map.choose
|
||||
with Not_found ->
|
||||
Errors.raise_error "There isn't any scope inside the program."
|
||||
in
|
||||
scope
|
||||
| Some name, _ -> (
|
||||
match Desugared.Ast.IdentMap.find_opt name ctxt.typedefs with
|
||||
| Some (Surface.Name_resolution.TScope (uid, _)) -> uid
|
||||
match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with
|
||||
| Some (Desugared.Name_resolution.TScope (uid, _)) -> uid
|
||||
| _ ->
|
||||
Errors.raise_error "There is no scope \"%s\" inside the program."
|
||||
name)
|
||||
in
|
||||
Cli.debug_print "Desugaring...";
|
||||
let prgm = Surface.Desugaring.desugar_program ctxt prgm in
|
||||
let prgm = Desugared.From_surface.translate_program ctxt prgm in
|
||||
Cli.debug_print "Disambiguating...";
|
||||
let prgm = Desugared.Disambiguate.program prgm in
|
||||
Cli.debug_print "Collecting rules...";
|
||||
let prgm = Desugared.Desugared_to_scope.translate_program prgm in
|
||||
let prgm = Scopelang.From_desugared.translate_program prgm in
|
||||
match backend with
|
||||
| `Scopelang ->
|
||||
let _output_file, with_output = get_output_format () in
|
||||
@ -180,7 +187,8 @@ let driver source_file (options : Cli.options) : int =
|
||||
if Option.is_some options.ex_scope then
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Scopelang.Print.scope prgm.program_ctx ~debug:options.debug)
|
||||
(scope_uid, Shared_ast.ScopeMap.find scope_uid prgm.program_scopes)
|
||||
( scope_uid,
|
||||
Shared_ast.ScopeName.Map.find scope_uid prgm.program_scopes )
|
||||
else
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Scopelang.Print.program ~debug:options.debug)
|
||||
@ -194,7 +202,7 @@ let driver source_file (options : Cli.options) : int =
|
||||
in
|
||||
let prgm = Scopelang.Ast.type_program prgm in
|
||||
Cli.debug_print "Translating to default calculus...";
|
||||
let prgm = Scopelang.Scope_to_dcalc.translate_program prgm in
|
||||
let prgm = Dcalc.From_scopelang.translate_program prgm in
|
||||
let prgm =
|
||||
if options.optimize then begin
|
||||
Cli.debug_print "Optimizing default calculus...";
|
||||
@ -202,8 +210,21 @@ let driver source_file (options : Cli.options) : int =
|
||||
end
|
||||
else prgm
|
||||
in
|
||||
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
|
||||
(Print.typ prgm.decl_ctx) typ); *)
|
||||
match backend with
|
||||
| `Typecheck ->
|
||||
Cli.debug_print "Typechecking again...";
|
||||
let _ =
|
||||
try Shared_ast.Typing.program prgm
|
||||
with Errors.StructuredError (msg, details) ->
|
||||
let msg =
|
||||
"Typing error occured during re-typing on the 'default \
|
||||
calculus'. This is a bug in the Catala compiler.\n"
|
||||
^ msg
|
||||
in
|
||||
raise (Errors.StructuredError (msg, details))
|
||||
in
|
||||
(* That's it! *)
|
||||
Cli.result_print "Typechecking successful!"
|
||||
| `Dcalc ->
|
||||
@ -229,7 +250,7 @@ let driver source_file (options : Cli.options) : int =
|
||||
Shared_ast.Expr.unbox (Shared_ast.Program.to_expr prgm scope_uid)
|
||||
in
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Shared_ast.Expr.format prgm.decl_ctx)
|
||||
(Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx)
|
||||
prgrm_dcalc_expr
|
||||
| (`Interpret | `OCaml | `Python | `Scalc | `Lcalc | `Proof | `Plugin _)
|
||||
as backend -> (
|
||||
@ -244,8 +265,6 @@ let driver source_file (options : Cli.options) : int =
|
||||
in
|
||||
raise (Errors.StructuredError (msg, details))
|
||||
in
|
||||
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
|
||||
(Print.typ prgm.decl_ctx) typ); *)
|
||||
match backend with
|
||||
| `Proof ->
|
||||
let vcs =
|
||||
@ -308,24 +327,14 @@ let driver source_file (options : Cli.options) : int =
|
||||
if Option.is_some options.ex_scope then
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Shared_ast.Scope.format ~debug:options.debug prgm.decl_ctx)
|
||||
( scope_uid,
|
||||
Option.get
|
||||
(Shared_ast.Scope.fold_left ~init:None
|
||||
~f:(fun acc scope_def _ ->
|
||||
if
|
||||
Shared_ast.ScopeName.compare scope_def.scope_name
|
||||
scope_uid
|
||||
= 0
|
||||
then Some scope_def.scope_body
|
||||
else acc)
|
||||
prgm.scopes) )
|
||||
(scope_uid, Shared_ast.Program.get_scope_body prgm scope_uid)
|
||||
else
|
||||
let prgrm_lcalc_expr =
|
||||
Shared_ast.Expr.unbox
|
||||
(Shared_ast.Program.to_expr prgm scope_uid)
|
||||
in
|
||||
Format.fprintf fmt "%a\n"
|
||||
(Shared_ast.Expr.format prgm.decl_ctx)
|
||||
(Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx)
|
||||
prgrm_lcalc_expr
|
||||
| (`OCaml | `Python | `Scalc | `Plugin _) as backend -> (
|
||||
match backend with
|
||||
|
@ -15,9 +15,10 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Catala_utils
|
||||
module Plugin = Plugin.PluginAPI
|
||||
|
||||
val driver : Utils.Pos.input_file -> Utils.Cli.options -> int
|
||||
val driver : Pos.input_file -> Cli.options -> int
|
||||
(** Entry function for the executable. Returns a negative number in case of
|
||||
error. *)
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
(public_name catala.driver)
|
||||
(libraries
|
||||
dynlink
|
||||
utils
|
||||
catala_utils
|
||||
surface
|
||||
desugared
|
||||
literate
|
||||
@ -50,3 +50,7 @@
|
||||
(documentation
|
||||
(package catala)
|
||||
(mld_files index))
|
||||
|
||||
(alias
|
||||
(name catala)
|
||||
(deps catala.exe))
|
||||
|
@ -103,7 +103,7 @@ Two more modules contain additional features for the compiler:
|
||||
|
||||
{ul
|
||||
{li {{: literate.html} Literate programming}}
|
||||
{li {{: utils.html} Compiler utilities}}
|
||||
{li {{: catala_utils.html} Compiler utilities}}
|
||||
}
|
||||
|
||||
The Catala runtimes documentation is available here:
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
include Shared_ast
|
||||
|
||||
type lit = lcalc glit
|
||||
@ -28,31 +28,32 @@ let option_enum : EnumName.t = EnumName.fresh ("eoption", Pos.no_pos)
|
||||
let none_constr : EnumConstructor.t = EnumConstructor.fresh ("ENone", Pos.no_pos)
|
||||
let some_constr : EnumConstructor.t = EnumConstructor.fresh ("ESome", Pos.no_pos)
|
||||
|
||||
let option_enum_config : (EnumConstructor.t * typ) list =
|
||||
[none_constr, (TLit TUnit, Pos.no_pos); some_constr, (TAny, Pos.no_pos)]
|
||||
let option_enum_config : typ EnumConstructor.Map.t =
|
||||
EnumConstructor.Map.empty
|
||||
|> EnumConstructor.Map.add none_constr (TLit TUnit, Pos.no_pos)
|
||||
|> EnumConstructor.Map.add some_constr (TAny, Pos.no_pos)
|
||||
|
||||
(* FIXME: proper typing in all the constructors below *)
|
||||
|
||||
let make_none m =
|
||||
let tunit = TLit TUnit, Expr.mark_pos m in
|
||||
Expr.einj
|
||||
(Expr.elit LUnit (Expr.with_ty m tunit))
|
||||
0 option_enum
|
||||
[TLit TUnit, Pos.no_pos; TAny, Pos.no_pos]
|
||||
m
|
||||
Expr.einj (Expr.elit LUnit (Expr.with_ty m tunit)) none_constr option_enum m
|
||||
|
||||
let make_some e =
|
||||
let m = Marked.get_mark e in
|
||||
Expr.einj e 1 option_enum
|
||||
[TLit TUnit, Expr.mark_pos m; TAny, Expr.mark_pos m]
|
||||
m
|
||||
Expr.einj e some_constr option_enum m
|
||||
|
||||
(** [make_matchopt_with_abs_arms arg e_none e_some] build an expression
|
||||
[match arg with |None -> e_none | Some -> e_some] and requires e_some and
|
||||
e_none to be in the form [EAbs ...].*)
|
||||
let make_matchopt_with_abs_arms arg e_none e_some =
|
||||
let m = Marked.get_mark arg in
|
||||
Expr.ematch arg [e_none; e_some] option_enum m
|
||||
let cases =
|
||||
EnumConstructor.Map.empty
|
||||
|> EnumConstructor.Map.add none_constr e_none
|
||||
|> EnumConstructor.Map.add some_constr e_some
|
||||
in
|
||||
Expr.ematch arg option_enum cases m
|
||||
|
||||
(** [make_matchopt pos v tau arg e_none e_some] builds an expression
|
||||
[match arg with | None () -> e_none | Some v -> e_some]. It binds v to
|
||||
|
@ -14,6 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** Abstract syntax tree for the lambda calculus *)
|
||||
@ -32,7 +33,7 @@ type 'm program = 'm expr Shared_ast.program
|
||||
val option_enum : EnumName.t
|
||||
val none_constr : EnumConstructor.t
|
||||
val some_constr : EnumConstructor.t
|
||||
val option_enum_config : (EnumConstructor.t * typ) list
|
||||
val option_enum_config : typ EnumConstructor.Map.t
|
||||
val make_none : 'm mark -> 'm expr boxed
|
||||
val make_some : 'm expr boxed -> 'm expr boxed
|
||||
|
||||
@ -40,7 +41,7 @@ val make_matchopt_with_abs_arms :
|
||||
'm expr boxed -> 'm expr boxed -> 'm expr boxed -> 'm expr boxed
|
||||
|
||||
val make_matchopt :
|
||||
Utils.Pos.t ->
|
||||
Pos.t ->
|
||||
'm expr Var.t ->
|
||||
typ ->
|
||||
'm expr boxed ->
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
module D = Dcalc.Ast
|
||||
@ -31,74 +31,56 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
|
||||
let rec aux e =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
|
||||
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _
|
||||
| ECatch _ ->
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
|
||||
| EVar v ->
|
||||
( (Bindlib.box_var v, m),
|
||||
if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
|
||||
else Var.Set.singleton v )
|
||||
| ETuple (args, s) ->
|
||||
let new_args, free_vars =
|
||||
List.fold_left
|
||||
(fun (new_args, free_vars) arg ->
|
||||
let new_arg, new_free_vars = aux arg in
|
||||
new_arg :: new_args, Var.Set.union new_free_vars free_vars)
|
||||
([], Var.Set.empty) args
|
||||
in
|
||||
Expr.etuple (List.rev new_args) s m, free_vars
|
||||
| ETupleAccess (e1, n, s, typs) ->
|
||||
let new_e1, free_vars = aux e1 in
|
||||
Expr.etupleaccess new_e1 n s typs m, free_vars
|
||||
| EInj (e1, n, e_name, typs) ->
|
||||
let new_e1, free_vars = aux e1 in
|
||||
Expr.einj new_e1 n e_name typs m, free_vars
|
||||
| EMatch (e1, arms, e_name) ->
|
||||
let new_e1, free_vars = aux e1 in
|
||||
( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
|
||||
else Var.Set.singleton v),
|
||||
(Bindlib.box_var v, m) )
|
||||
| EMatch { e; cases; name } ->
|
||||
let free_vars, new_e = aux e in
|
||||
(* We do not close the clotures inside the arms of the match expression,
|
||||
since they get a special treatment at compilation to Scalc. *)
|
||||
let new_arms, free_vars =
|
||||
List.fold_right
|
||||
(fun arm (new_arms, free_vars) ->
|
||||
match Marked.unmark arm with
|
||||
| EAbs (binder, typs) ->
|
||||
let free_vars, new_cases =
|
||||
EnumConstructor.Map.fold
|
||||
(fun cons e1 (free_vars, new_cases) ->
|
||||
match Marked.unmark e1 with
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_body, new_free_vars = aux body in
|
||||
let new_free_vars, new_body = aux body in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
( Expr.eabs new_binder typs (Marked.get_mark arm) :: new_arms,
|
||||
Var.Set.union free_vars new_free_vars )
|
||||
( Var.Set.union free_vars new_free_vars,
|
||||
EnumConstructor.Map.add cons
|
||||
(Expr.eabs new_binder tys (Marked.get_mark e1))
|
||||
new_cases )
|
||||
| _ -> failwith "should not happen")
|
||||
arms ([], free_vars)
|
||||
cases
|
||||
(free_vars, EnumConstructor.Map.empty)
|
||||
in
|
||||
Expr.ematch new_e1 new_arms e_name m, free_vars
|
||||
| EArray args ->
|
||||
let new_args, free_vars =
|
||||
List.fold_right
|
||||
(fun arg (new_args, free_vars) ->
|
||||
let new_arg, new_free_vars = aux arg in
|
||||
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
|
||||
args ([], Var.Set.empty)
|
||||
in
|
||||
Expr.earray new_args m, free_vars
|
||||
| ELit l -> Expr.elit l m, Var.Set.empty
|
||||
| EApp ((EAbs (binder, typs_abs), e1_pos), args) ->
|
||||
free_vars, Expr.ematch new_e name new_cases m
|
||||
| EApp { f = EAbs { binder; tys }, e1_pos; args } ->
|
||||
(* let-binding, we should not close these *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_body, free_vars = aux body in
|
||||
let free_vars, new_body = aux body in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
let new_args, free_vars =
|
||||
let free_vars, new_args =
|
||||
List.fold_right
|
||||
(fun arg (new_args, free_vars) ->
|
||||
let new_arg, new_free_vars = aux arg in
|
||||
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
|
||||
args ([], free_vars)
|
||||
(fun arg (free_vars, new_args) ->
|
||||
let new_free_vars, new_arg = aux arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
args (free_vars, [])
|
||||
in
|
||||
Expr.eapp (Expr.eabs new_binder typs_abs e1_pos) new_args m, free_vars
|
||||
| EAbs (binder, typs) ->
|
||||
free_vars, Expr.eapp (Expr.eabs new_binder tys e1_pos) new_args m
|
||||
| EAbs { binder; tys } ->
|
||||
(* λ x.t *)
|
||||
let binder_mark = m in
|
||||
let binder_pos = Expr.mark_pos binder_mark in
|
||||
(* Converting the closure. *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
(* t *)
|
||||
let new_body, body_vars = aux body in
|
||||
let body_vars, new_body = aux body in
|
||||
(* [[t]] *)
|
||||
let extra_vars =
|
||||
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars))
|
||||
@ -117,8 +99,8 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
|
||||
(fun i _ ->
|
||||
Expr.etupleaccess
|
||||
(Expr.evar inner_c_var binder_mark)
|
||||
(i + 1) None
|
||||
(List.map (fun _ -> any_ty) extra_vars_list)
|
||||
(i + 1)
|
||||
(List.length extra_vars_list)
|
||||
binder_mark)
|
||||
extra_vars_list)
|
||||
new_body
|
||||
@ -128,10 +110,11 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
|
||||
Expr.make_abs
|
||||
(Array.concat [Array.make 1 inner_c_var; vars])
|
||||
new_closure_body
|
||||
((TAny, binder_pos) :: typs)
|
||||
((TAny, binder_pos) :: tys)
|
||||
(Expr.pos e)
|
||||
in
|
||||
( Expr.make_let_in code_var
|
||||
( extra_vars,
|
||||
Expr.make_let_in code_var
|
||||
(TAny, Expr.pos e)
|
||||
new_closure
|
||||
(Expr.etuple
|
||||
@ -139,40 +122,25 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
|
||||
:: List.map
|
||||
(fun extra_var -> Bindlib.box_var extra_var, binder_mark)
|
||||
extra_vars_list)
|
||||
None m)
|
||||
(Expr.pos e),
|
||||
extra_vars )
|
||||
| EApp ((EOp op, pos_op), args) ->
|
||||
m)
|
||||
(Expr.pos e) )
|
||||
| EApp { f = EOp _, _; _ } ->
|
||||
(* This corresponds to an operator call, which we don't want to
|
||||
transform*)
|
||||
let new_args, free_vars =
|
||||
List.fold_right
|
||||
(fun arg (new_args, free_vars) ->
|
||||
let new_arg, new_free_vars = aux arg in
|
||||
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
|
||||
args ([], Var.Set.empty)
|
||||
in
|
||||
Expr.eapp (Expr.eop op pos_op) new_args m, free_vars
|
||||
| EApp ((EVar v, v_pos), args) when Var.Set.mem v ctx.globally_bound_vars ->
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
|
||||
| EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars ->
|
||||
(* This corresponds to a scope call, which we don't want to transform*)
|
||||
let new_args, free_vars =
|
||||
List.fold_right
|
||||
(fun arg (new_args, free_vars) ->
|
||||
let new_arg, new_free_vars = aux arg in
|
||||
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
|
||||
args ([], Var.Set.empty)
|
||||
in
|
||||
Expr.eapp (Bindlib.box_var v, v_pos) new_args m, free_vars
|
||||
| EApp (e1, args) ->
|
||||
let new_e1, free_vars = aux e1 in
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
|
||||
| EApp { f = e1; args } ->
|
||||
let free_vars, new_e1 = aux e1 in
|
||||
let env_var = Var.make "env" in
|
||||
let code_var = Var.make "code" in
|
||||
let new_args, free_vars =
|
||||
let free_vars, new_args =
|
||||
List.fold_right
|
||||
(fun arg (new_args, free_vars) ->
|
||||
let new_arg, new_free_vars = aux arg in
|
||||
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
|
||||
args ([], free_vars)
|
||||
(fun arg (free_vars, new_args) ->
|
||||
let new_free_vars, new_arg = aux arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
args (free_vars, [])
|
||||
in
|
||||
let call_expr =
|
||||
let m1 = Marked.get_mark e1 in
|
||||
@ -180,7 +148,8 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
|
||||
(TAny, Expr.pos e)
|
||||
(Expr.etupleaccess
|
||||
(Bindlib.box_var env_var, m1)
|
||||
0 None [ (*TODO: fill?*) ]
|
||||
0
|
||||
(List.length new_args + 1)
|
||||
m)
|
||||
(Expr.eapp
|
||||
(Bindlib.box_var code_var, m1)
|
||||
@ -188,25 +157,12 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
|
||||
m)
|
||||
(Expr.pos e)
|
||||
in
|
||||
( Expr.make_let_in env_var (TAny, Expr.pos e) new_e1 call_expr (Expr.pos e),
|
||||
free_vars )
|
||||
| EAssert e1 ->
|
||||
let new_e1, free_vars = aux e1 in
|
||||
Expr.eassert new_e1 m, free_vars
|
||||
| EOp op -> Expr.eop op m, Var.Set.empty
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
let new_e1, free_vars1 = aux e1 in
|
||||
let new_e2, free_vars2 = aux e2 in
|
||||
let new_e3, free_vars3 = aux e3 in
|
||||
( Expr.eifthenelse new_e1 new_e2 new_e3 m,
|
||||
Var.Set.union (Var.Set.union free_vars1 free_vars2) free_vars3 )
|
||||
| ERaise except -> Expr.eraise except m, Var.Set.empty
|
||||
| ECatch (e1, except, e2) ->
|
||||
let new_e1, free_vars1 = aux e1 in
|
||||
let new_e2, free_vars2 = aux e2 in
|
||||
Expr.ecatch new_e1 except new_e2 m, Var.Set.union free_vars1 free_vars2
|
||||
( free_vars,
|
||||
Expr.make_let_in env_var
|
||||
(TAny, Expr.pos e)
|
||||
new_e1 call_expr (Expr.pos e) )
|
||||
in
|
||||
let e', _vars = aux e in
|
||||
let _vars, e' = aux e in
|
||||
e'
|
||||
|
||||
let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
module D = Dcalc.Ast
|
||||
module A = Ast
|
||||
@ -43,7 +43,7 @@ let rec translate_default
|
||||
Expr.make_app
|
||||
(Expr.make_var
|
||||
(Var.translate A.handle_default)
|
||||
(Expr.with_ty mark_default (Utils.Marked.mark pos TAny)))
|
||||
(Expr.with_ty mark_default (Marked.mark pos TAny)))
|
||||
[
|
||||
Expr.earray exceptions mark_default;
|
||||
thunk_expr (translate_expr ctx just);
|
||||
@ -54,39 +54,39 @@ let rec translate_default
|
||||
exceptions
|
||||
|
||||
and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EVar v -> Expr.make_var (Var.Map.find v ctx) (Marked.get_mark e)
|
||||
| ETuple (args, s) ->
|
||||
Expr.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e)
|
||||
| ETupleAccess (e1, i, s, ts) ->
|
||||
Expr.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e)
|
||||
| EInj (e1, i, en, ts) ->
|
||||
Expr.einj (translate_expr ctx e1) i en ts (Marked.get_mark e)
|
||||
| EMatch (e1, cases, en) ->
|
||||
Expr.ematch (translate_expr ctx e1)
|
||||
(List.map (translate_expr ctx) cases)
|
||||
en (Marked.get_mark e)
|
||||
| EArray es ->
|
||||
Expr.earray (List.map (translate_expr ctx) es) (Marked.get_mark e)
|
||||
| EVar v -> Expr.make_var (Var.Map.find v ctx) m
|
||||
| EStruct { name; fields } ->
|
||||
Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m
|
||||
| EStructAccess { name; e; field } ->
|
||||
Expr.estructaccess (translate_expr ctx e) field name m
|
||||
| EInj { name; e; cons } -> Expr.einj (translate_expr ctx e) cons name m
|
||||
| EMatch { name; e; cases } ->
|
||||
Expr.ematch (translate_expr ctx e) name
|
||||
(EnumConstructor.Map.map (translate_expr ctx) cases)
|
||||
m
|
||||
| EArray es -> Expr.earray (List.map (translate_expr ctx) es) m
|
||||
| ELit
|
||||
((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as
|
||||
l) ->
|
||||
Expr.elit l (Marked.get_mark e)
|
||||
| ELit LEmptyError -> Expr.eraise EmptyError (Marked.get_mark e)
|
||||
| EOp op -> Expr.eop op (Marked.get_mark e)
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2)
|
||||
(translate_expr ctx e3) (Marked.get_mark e)
|
||||
| EAssert e1 -> Expr.eassert (translate_expr ctx e1) (Marked.get_mark e)
|
||||
| ErrorOnEmpty arg ->
|
||||
Expr.elit l m
|
||||
| ELit LEmptyError -> Expr.eraise EmptyError m
|
||||
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys m
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
|
||||
(translate_expr ctx efalse)
|
||||
m
|
||||
| EAssert e1 -> Expr.eassert (translate_expr ctx e1) m
|
||||
| EErrorOnEmpty arg ->
|
||||
Expr.ecatch (translate_expr ctx arg) EmptyError
|
||||
(Expr.eraise NoValueProvided (Marked.get_mark e))
|
||||
(Marked.get_mark e)
|
||||
| EApp (e1, args) ->
|
||||
Expr.eapp (translate_expr ctx e1)
|
||||
(Expr.eraise NoValueProvided m)
|
||||
m
|
||||
| EApp { f; args } ->
|
||||
Expr.eapp (translate_expr ctx f)
|
||||
(List.map (translate_expr ctx) args)
|
||||
(Marked.get_mark e)
|
||||
| EAbs (binder, ts) ->
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let ctx, lc_vars =
|
||||
Array.fold_right
|
||||
@ -98,15 +98,16 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
|
||||
let lc_vars = Array.of_list lc_vars in
|
||||
let new_body = translate_expr ctx body in
|
||||
let new_binder = Expr.bind lc_vars new_body in
|
||||
Expr.eabs new_binder ts (Marked.get_mark e)
|
||||
| EDefault ([exn], just, cons) when !Cli.optimize_flag ->
|
||||
Expr.eabs new_binder tys (Marked.get_mark e)
|
||||
| EDefault { excepts = [exn]; just; cons } when !Cli.optimize_flag ->
|
||||
(* FIXME: bad place to rely on a global flag *)
|
||||
Expr.ecatch (translate_expr ctx exn) EmptyError
|
||||
(Expr.eifthenelse (translate_expr ctx just) (translate_expr ctx cons)
|
||||
(Expr.eraise EmptyError (Marked.get_mark e))
|
||||
(Marked.get_mark e))
|
||||
(Marked.get_mark e)
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
translate_default ctx exceptions just cons (Marked.get_mark e)
|
||||
| EDefault { excepts; just; cons } ->
|
||||
translate_default ctx excepts just cons (Marked.get_mark e)
|
||||
|
||||
let rec translate_scope_lets
|
||||
(decl_ctx : decl_ctx)
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
module D = Dcalc.Ast
|
||||
module A = Ast
|
||||
|
||||
@ -170,7 +170,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
created a variable %a to replace it" Print.var v Print.var v'; *)
|
||||
Expr.make_var v' mark, Var.Map.singleton v' e
|
||||
else (find ~info:"should never happen" v ctx).expr, Var.Map.empty
|
||||
| EApp ((EVar v, p), [(ELit LUnit, _)]) ->
|
||||
| EApp { f = EVar v, p; args = [(ELit LUnit, _)] } ->
|
||||
if not (find ~info:"search for a variable" v ctx).is_pure then
|
||||
let v' = Var.make (Bindlib.name_of v) in
|
||||
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
|
||||
@ -179,7 +179,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
else
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"Internal error: an pure variable was found in an unpure environment."
|
||||
| EDefault (_exceptions, _just, _cons) ->
|
||||
| EDefault _ ->
|
||||
let v' = Var.make "default_term" in
|
||||
Expr.make_var v' mark, Var.Map.singleton v' e
|
||||
| ELit LEmptyError ->
|
||||
@ -187,7 +187,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
Expr.make_var v' mark, Var.Map.singleton v' e
|
||||
(* This one is a very special case. It transform an unpure expression
|
||||
environement to a pure expression. *)
|
||||
| ErrorOnEmpty arg ->
|
||||
| EErrorOnEmpty arg ->
|
||||
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
|
||||
let silent_var = Var.make "_" in
|
||||
let x = Var.make "non_empty_argument" in
|
||||
@ -206,22 +206,23 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as
|
||||
l) ->
|
||||
Expr.elit l mark, Var.Map.empty
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
let e1', h1 = translate_and_hoist ctx e1 in
|
||||
let e2', h2 = translate_and_hoist ctx e2 in
|
||||
let e3', h3 = translate_and_hoist ctx e3 in
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
let cond', h1 = translate_and_hoist ctx cond in
|
||||
let etrue', h2 = translate_and_hoist ctx etrue in
|
||||
let efalse', h3 = translate_and_hoist ctx efalse in
|
||||
|
||||
let e' = Expr.eifthenelse e1' e2' e3' mark in
|
||||
let e' = Expr.eifthenelse cond' etrue' efalse' mark in
|
||||
|
||||
(*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' =
|
||||
e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *)
|
||||
(*(* equivalent code : *) let e' = let+ cond' = cond' and+ etrue' = etrue'
|
||||
and+ efalse' = efalse' in (A.EIfThenElse (cond', etrue', efalse'), pos)
|
||||
in *)
|
||||
e', disjoint_union_maps (Expr.pos e) [h1; h2; h3]
|
||||
| EAssert e1 ->
|
||||
(* same behavior as in the ICFP paper: if e1 is empty, then no error is
|
||||
raised. *)
|
||||
let e1', h1 = translate_and_hoist ctx e1 in
|
||||
Expr.eassert e1' mark, h1
|
||||
| EAbs (binder, ts) ->
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let ctx, lc_vars =
|
||||
ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
|
||||
@ -242,8 +243,8 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
let new_body, hoists = translate_and_hoist ctx body in
|
||||
let new_binder = Expr.bind lc_vars new_body in
|
||||
|
||||
Expr.eabs new_binder (List.map translate_typ ts) mark, hoists
|
||||
| EApp (e1, args) ->
|
||||
Expr.eabs new_binder (List.map translate_typ tys) mark, hoists
|
||||
| EApp { f = e1; args } ->
|
||||
let e1', h1 = translate_and_hoist ctx e1 in
|
||||
let args', h_args =
|
||||
args |> List.map (translate_and_hoist ctx) |> List.split
|
||||
@ -252,35 +253,43 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in
|
||||
let e' = Expr.eapp e1' args' mark in
|
||||
e', hoists
|
||||
| ETuple (args, s) ->
|
||||
let args', h_args =
|
||||
args |> List.map (translate_and_hoist ctx) |> List.split
|
||||
| EStruct { name; fields } ->
|
||||
let fields', h_fields =
|
||||
StructField.Map.fold
|
||||
(fun field e (fields, hoists) ->
|
||||
let e, h = translate_and_hoist ctx e in
|
||||
StructField.Map.add field e fields, h :: hoists)
|
||||
fields
|
||||
(StructField.Map.empty, [])
|
||||
in
|
||||
|
||||
let hoists = disjoint_union_maps (Expr.pos e) h_args in
|
||||
Expr.etuple args' s mark, hoists
|
||||
| ETupleAccess (e1, i, s, ts) ->
|
||||
let hoists = disjoint_union_maps (Expr.pos e) h_fields in
|
||||
Expr.estruct name fields' mark, hoists
|
||||
| EStructAccess { name; e = e1; field } ->
|
||||
let e1', hoists = translate_and_hoist ctx e1 in
|
||||
let e1' = Expr.etupleaccess e1' i s ts mark in
|
||||
let e1' = Expr.estructaccess e1' field name mark in
|
||||
e1', hoists
|
||||
| EInj (e1, i, en, ts) ->
|
||||
| EInj { name; e = e1; cons } ->
|
||||
let e1', hoists = translate_and_hoist ctx e1 in
|
||||
let e1' = Expr.einj e1' i en ts mark in
|
||||
let e1' = Expr.einj e1' cons name mark in
|
||||
e1', hoists
|
||||
| EMatch (e1, cases, en) ->
|
||||
| EMatch { name; e = e1; cases } ->
|
||||
let e1', h1 = translate_and_hoist ctx e1 in
|
||||
let cases', h_cases =
|
||||
cases |> List.map (translate_and_hoist ctx) |> List.split
|
||||
EnumConstructor.Map.fold
|
||||
(fun cons e (cases, hoists) ->
|
||||
let e', h = translate_and_hoist ctx e in
|
||||
EnumConstructor.Map.add cons e' cases, h :: hoists)
|
||||
cases
|
||||
(EnumConstructor.Map.empty, [])
|
||||
in
|
||||
|
||||
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in
|
||||
let e' = Expr.ematch e1' cases' en mark in
|
||||
let e' = Expr.ematch e1' name cases' mark in
|
||||
e', hoists
|
||||
| EArray es ->
|
||||
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
|
||||
|
||||
Expr.earray es' mark, disjoint_union_maps (Expr.pos e) hoists
|
||||
| EOp op -> Expr.eop op mark, Var.Map.empty
|
||||
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys mark, Var.Map.empty
|
||||
|
||||
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
'm A.expr boxed =
|
||||
@ -302,14 +311,14 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
|
||||
(* Here we have to handle only the cases appearing in hoists, as defined
|
||||
the [translate_and_hoist] function. *)
|
||||
| EVar v -> (find ~info:"should never happen" v ctx).expr
|
||||
| EDefault (excep, just, cons) ->
|
||||
let excep' = List.map (translate_expr ctx) excep in
|
||||
| EDefault { excepts; just; cons } ->
|
||||
let excepts' = List.map (translate_expr ctx) excepts in
|
||||
let just' = translate_expr ctx just in
|
||||
let cons' = translate_expr ctx cons in
|
||||
(* calls handle_option. *)
|
||||
Expr.make_app
|
||||
(Expr.make_var (Var.translate A.handle_default_opt) mark_hoist)
|
||||
[Expr.earray excep' mark_hoist; just'; cons']
|
||||
[Expr.earray excepts' mark_hoist; just'; cons']
|
||||
pos
|
||||
| ELit LEmptyError -> A.make_none mark_hoist
|
||||
| EAssert arg ->
|
||||
@ -354,7 +363,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
|
||||
{
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
scope_let_typ = typ;
|
||||
scope_let_expr = EAbs (binder, _), emark;
|
||||
scope_let_expr = EAbs { binder; _ }, emark;
|
||||
scope_let_next = next;
|
||||
scope_let_pos = pos;
|
||||
} ->
|
||||
@ -385,7 +394,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
|
||||
{
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
scope_let_typ = typ;
|
||||
scope_let_expr = (ErrorOnEmpty _, emark) as expr;
|
||||
scope_let_expr = (EErrorOnEmpty _, emark) as expr;
|
||||
scope_let_next = next;
|
||||
scope_let_pos = pos;
|
||||
} ->
|
||||
@ -529,7 +538,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
|
||||
prgm.decl_ctx with
|
||||
ctx_enums =
|
||||
prgm.decl_ctx.ctx_enums
|
||||
|> EnumMap.add A.option_enum A.option_enum_config;
|
||||
|> EnumName.Map.add A.option_enum A.option_enum_config;
|
||||
}
|
||||
in
|
||||
let decl_ctx =
|
||||
@ -537,15 +546,14 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
|
||||
decl_ctx with
|
||||
ctx_structs =
|
||||
prgm.decl_ctx.ctx_structs
|
||||
|> StructMap.mapi (fun n l ->
|
||||
|> StructName.Map.mapi (fun n str ->
|
||||
if List.mem n inputs_structs then
|
||||
ListLabels.map l ~f:(fun (n, tau) ->
|
||||
(* Cli.debug_print @@ Format.asprintf "Input type: %a"
|
||||
(Print.typ decl_ctx) tau; Cli.debug_print @@
|
||||
Format.asprintf "Output type: %a" (Print.typ decl_ctx)
|
||||
(translate_typ tau); *)
|
||||
n, translate_typ tau)
|
||||
else l);
|
||||
StructField.Map.map translate_typ str
|
||||
(* Cli.debug_print @@ Format.asprintf "Input type: %a"
|
||||
(Print.typ decl_ctx) tau; Cli.debug_print @@ Format.asprintf
|
||||
"Output type: %a" (Print.typ decl_ctx) (translate_typ
|
||||
tau); *)
|
||||
else str);
|
||||
}
|
||||
in
|
||||
|
||||
|
21
compiler/lcalc/from_dcalc.ml
Normal file
21
compiler/lcalc/from_dcalc.ml
Normal file
@ -0,0 +1,21 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Denis Merigoux <denis.merigoux@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
let translate_program_with_exceptions =
|
||||
Compile_with_exceptions.translate_program
|
||||
|
||||
let translate_program_without_exceptions =
|
||||
Compile_without_exceptions.translate_program
|
26
compiler/lcalc/from_dcalc.mli
Normal file
26
compiler/lcalc/from_dcalc.mli
Normal file
@ -0,0 +1,26 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Denis Merigoux <denis.merigoux@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
val translate_program_with_exceptions : 'm Dcalc.Ast.program -> 'm Ast.program
|
||||
(** Translation from the default calculus to the lambda calculus. This
|
||||
translation uses exceptions to handle empty default terms. *)
|
||||
|
||||
val translate_program_without_exceptions :
|
||||
'm Dcalc.Ast.program -> 'm Ast.program
|
||||
(** Translation from the default calculus to the lambda calculus. This
|
||||
translation uses an option monad to handle empty defaults terms. This
|
||||
transformation is one piece to permit to compile toward legacy languages
|
||||
that does not contains exceptions. *)
|
@ -13,50 +13,47 @@
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
module D = Dcalc.Ast
|
||||
|
||||
let visitor_map (t : 'a -> 'm expr -> 'm expr boxed) (ctx : 'a) (e : 'm expr) :
|
||||
'm expr boxed =
|
||||
Expr.map ctx ~f:t e
|
||||
let visitor_map (t : 'm expr -> 'm expr boxed) (e : 'm expr) : 'm expr boxed =
|
||||
Expr.map ~f:t e
|
||||
|
||||
let rec iota_expr (_ : unit) (e : 'm expr) : 'm expr boxed =
|
||||
let rec iota_expr (e : 'm expr) : 'm expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EMatch ((EInj (e1, i, n', _ts), _), cases, n) when EnumName.compare n n' = 0
|
||||
->
|
||||
let e1 = visitor_map iota_expr () e1 in
|
||||
let case = visitor_map iota_expr () (List.nth cases i) in
|
||||
| EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n }
|
||||
when EnumName.equal n n' ->
|
||||
let e1 = visitor_map iota_expr e' in
|
||||
let case = visitor_map iota_expr (EnumConstructor.Map.find cons cases) in
|
||||
Expr.eapp case [e1] m
|
||||
| EMatch (e', cases, n)
|
||||
| EMatch { e = e'; cases; name = n }
|
||||
when cases
|
||||
|> List.mapi (fun i (case, _pos) ->
|
||||
match case with
|
||||
| EInj (_ei, i', n', _ts') ->
|
||||
i = i' && (* n = n' *) EnumName.compare n n' = 0
|
||||
|> EnumConstructor.Map.mapi (fun i case ->
|
||||
match Marked.unmark case with
|
||||
| EInj { cons = i'; name = n'; _ } ->
|
||||
EnumConstructor.equal i i' && EnumName.equal n n'
|
||||
| _ -> false)
|
||||
|> List.for_all Fun.id ->
|
||||
visitor_map iota_expr () e'
|
||||
| _ -> visitor_map iota_expr () e
|
||||
|> EnumConstructor.Map.for_all (fun _ b -> b) ->
|
||||
visitor_map iota_expr e'
|
||||
| _ -> visitor_map iota_expr e
|
||||
|
||||
let rec beta_expr (e : 'm expr) : 'm expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EApp (e1, args) ->
|
||||
| EApp { f = e1; args } ->
|
||||
Expr.Box.app1n (beta_expr e1) (List.map beta_expr args)
|
||||
(fun e1 args ->
|
||||
match Marked.unmark e1 with
|
||||
| EAbs (binder, _) -> Marked.unmark (Expr.subst binder args)
|
||||
| _ -> EApp (e1, args))
|
||||
| EAbs { binder; _ } -> Marked.unmark (Expr.subst binder args)
|
||||
| _ -> EApp { f = e1; args })
|
||||
m
|
||||
| _ -> visitor_map (fun () -> beta_expr) () e
|
||||
| _ -> visitor_map beta_expr e
|
||||
|
||||
let iota_optimizations (p : 'm program) : 'm program =
|
||||
let new_scopes =
|
||||
Scope.map_exprs ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes
|
||||
in
|
||||
let new_scopes = Scope.map_exprs ~f:iota_expr ~varf:(fun v -> v) p.scopes in
|
||||
{ p with scopes = Bindlib.unbox new_scopes }
|
||||
|
||||
(* TODO: beta optimizations apply inlining of the program. We left the inclusion
|
||||
@ -70,30 +67,32 @@ let _beta_optimizations (p : 'm program) : 'm program =
|
||||
let rec peephole_expr (e : 'm expr) : 'm expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Expr.Box.app3 (peephole_expr e1) (peephole_expr e2) (peephole_expr e3)
|
||||
(fun e1 e2 e3 ->
|
||||
match Marked.unmark e1 with
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
Expr.Box.app3 (peephole_expr cond) (peephole_expr etrue)
|
||||
(peephole_expr efalse)
|
||||
(fun cond etrue efalse ->
|
||||
match Marked.unmark cond with
|
||||
| ELit (LBool true)
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ->
|
||||
Marked.unmark e2
|
||||
| EApp { f = EOp { op = Log _; _ }, _; args = [(ELit (LBool true), _)] }
|
||||
->
|
||||
Marked.unmark etrue
|
||||
| ELit (LBool false)
|
||||
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ->
|
||||
Marked.unmark e3
|
||||
| _ -> EIfThenElse (e1, e2, e3))
|
||||
| EApp
|
||||
{ f = EOp { op = Log _; _ }, _; args = [(ELit (LBool false), _)] }
|
||||
->
|
||||
Marked.unmark efalse
|
||||
| _ -> EIfThenElse { cond; etrue; efalse })
|
||||
m
|
||||
| ECatch (e1, except, e2) ->
|
||||
Expr.Box.app2 (peephole_expr e1) (peephole_expr e2)
|
||||
(fun e1 e2 ->
|
||||
match Marked.unmark e1, Marked.unmark e2 with
|
||||
| ERaise except', ERaise except''
|
||||
when except' = except && except = except'' ->
|
||||
ERaise except
|
||||
| ERaise except', _ when except' = except -> Marked.unmark e2
|
||||
| _, ERaise except' when except' = except -> Marked.unmark e1
|
||||
| _ -> ECatch (e1, except, e2))
|
||||
| ECatch { body; exn; handler } ->
|
||||
Expr.Box.app2 (peephole_expr body) (peephole_expr handler)
|
||||
(fun body handler ->
|
||||
match Marked.unmark body, Marked.unmark handler with
|
||||
| ERaise exn', ERaise exn'' when exn' = exn && exn = exn'' -> ERaise exn
|
||||
| ERaise exn', _ when exn' = exn -> Marked.unmark handler
|
||||
| _, ERaise exn' when exn' = exn -> Marked.unmark body
|
||||
| _ -> ECatch { body; exn; handler })
|
||||
m
|
||||
| _ -> visitor_map (fun () -> peephole_expr) () e
|
||||
| _ -> visitor_map peephole_expr e
|
||||
|
||||
let peephole_optimizations (p : 'm program) : 'm program =
|
||||
let new_scopes =
|
||||
|
@ -14,24 +14,21 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
open String_common
|
||||
module D = Dcalc.Ast
|
||||
|
||||
let find_struct (s : StructName.t) (ctx : decl_ctx) :
|
||||
(StructFieldName.t * typ) list =
|
||||
try StructMap.find s ctx.ctx_structs
|
||||
let find_struct (s : StructName.t) (ctx : decl_ctx) : typ StructField.Map.t =
|
||||
try StructName.Map.find s ctx.ctx_structs
|
||||
with Not_found ->
|
||||
let s_name, pos = StructName.get_info s in
|
||||
Errors.raise_spanned_error pos
|
||||
"Internal Error: Structure %s was not found in the current environment."
|
||||
s_name
|
||||
|
||||
let find_enum (en : EnumName.t) (ctx : decl_ctx) :
|
||||
(EnumConstructor.t * typ) list =
|
||||
try EnumMap.find en ctx.ctx_enums
|
||||
let find_enum (en : EnumName.t) (ctx : decl_ctx) : typ EnumConstructor.Map.t =
|
||||
try EnumName.Map.find en ctx.ctx_enums
|
||||
with Not_found ->
|
||||
let en_name, pos = EnumName.get_info en in
|
||||
Errors.raise_spanned_error pos
|
||||
@ -57,43 +54,13 @@ let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit =
|
||||
let years, months, days = Runtime.duration_to_years_months_days d in
|
||||
Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days
|
||||
|
||||
let format_op_kind (fmt : Format.formatter) (k : op_kind) =
|
||||
Format.fprintf fmt "%s"
|
||||
(match k with
|
||||
| KInt -> "!"
|
||||
| KRat -> "&"
|
||||
| KMoney -> "$"
|
||||
| KDate -> "@"
|
||||
| KDuration -> "^")
|
||||
|
||||
let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit =
|
||||
match Marked.unmark op with
|
||||
| Add k -> Format.fprintf fmt "+%a" format_op_kind k
|
||||
| Sub k -> Format.fprintf fmt "-%a" format_op_kind k
|
||||
| Mult k -> Format.fprintf fmt "*%a" format_op_kind k
|
||||
| Div k -> Format.fprintf fmt "/%a" format_op_kind k
|
||||
| And -> Format.fprintf fmt "%s" "&&"
|
||||
| Or -> Format.fprintf fmt "%s" "||"
|
||||
| Eq -> Format.fprintf fmt "%s" "="
|
||||
| Neq | Xor -> Format.fprintf fmt "%s" "<>"
|
||||
| Lt k -> Format.fprintf fmt "%s%a" "<" format_op_kind k
|
||||
| Lte k -> Format.fprintf fmt "%s%a" "<=" format_op_kind k
|
||||
| Gt k -> Format.fprintf fmt "%s%a" ">" format_op_kind k
|
||||
| Gte k -> Format.fprintf fmt "%s%a" ">=" format_op_kind k
|
||||
| Concat -> Format.fprintf fmt "@"
|
||||
| Map -> Format.fprintf fmt "Array.map"
|
||||
| Filter -> Format.fprintf fmt "array_filter"
|
||||
|
||||
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
|
||||
match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left"
|
||||
|
||||
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
|
||||
: unit =
|
||||
Format.fprintf fmt "@[<hov 2>[%a]@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt info ->
|
||||
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info))
|
||||
Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info))
|
||||
uids
|
||||
|
||||
let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
|
||||
@ -106,26 +73,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
|
||||
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
|
||||
uids
|
||||
|
||||
let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit =
|
||||
match Marked.unmark op with
|
||||
| Minus k -> Format.fprintf fmt "~-%a" format_op_kind k
|
||||
| Not -> Format.fprintf fmt "%s" "not"
|
||||
| Log (_entry, _infos) ->
|
||||
Errors.raise_spanned_error (Marked.get_mark op)
|
||||
"Internal error: a log operator has not been caught by the expression \
|
||||
match"
|
||||
| Length -> Format.fprintf fmt "%s" "array_length"
|
||||
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
|
||||
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
|
||||
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
|
||||
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
|
||||
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
|
||||
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
|
||||
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
|
||||
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
|
||||
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
|
||||
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
|
||||
|
||||
let avoid_keywords (s : string) : string =
|
||||
match s with
|
||||
(* list taken from
|
||||
@ -137,14 +84,14 @@ let avoid_keywords (s : string) : string =
|
||||
| "match" | "method" | "mod" | "module" | "mutable" | "new" | "nonrec"
|
||||
| "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct"
|
||||
| "then" | "to" | "true" | "try" | "type" | "val" | "virtual" | "when"
|
||||
| "while" | "with" ->
|
||||
| "while" | "with" | "Stdlib" | "Runtime" | "Oper" ->
|
||||
s ^ "_user"
|
||||
| _ -> s
|
||||
|
||||
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
|
||||
Format.asprintf "%a" StructName.format_t v
|
||||
|> to_ascii
|
||||
|> to_snake_case
|
||||
|> String.to_ascii
|
||||
|> String.to_snake_case
|
||||
|> avoid_keywords
|
||||
|> Format.fprintf fmt "%s"
|
||||
|
||||
@ -154,8 +101,8 @@ let format_to_module_name
|
||||
(match name with
|
||||
| `Ename v -> Format.asprintf "%a" EnumName.format_t v
|
||||
| `Sname v -> Format.asprintf "%a" StructName.format_t v)
|
||||
|> to_ascii
|
||||
|> to_snake_case
|
||||
|> String.to_ascii
|
||||
|> String.to_snake_case
|
||||
|> avoid_keywords
|
||||
|> String.split_on_char '_'
|
||||
|> List.map String.capitalize_ascii
|
||||
@ -164,24 +111,25 @@ let format_to_module_name
|
||||
|
||||
let format_struct_field_name
|
||||
(fmt : Format.formatter)
|
||||
((sname_opt, v) : StructName.t option * StructFieldName.t) : unit =
|
||||
((sname_opt, v) : StructName.t option * StructField.t) : unit =
|
||||
(match sname_opt with
|
||||
| Some sname ->
|
||||
Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname)
|
||||
| None -> Format.fprintf fmt "%s")
|
||||
(avoid_keywords
|
||||
(to_ascii (Format.asprintf "%a" StructFieldName.format_t v)))
|
||||
(String.to_ascii (Format.asprintf "%a" StructField.format_t v)))
|
||||
|
||||
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(to_snake_case (to_ascii (Format.asprintf "%a" EnumName.format_t v))))
|
||||
(String.to_snake_case
|
||||
(String.to_ascii (Format.asprintf "%a" EnumName.format_t v))))
|
||||
|
||||
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
|
||||
unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
|
||||
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
|
||||
|
||||
let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit =
|
||||
match Marked.unmark ty with
|
||||
@ -225,25 +173,27 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
| TAny -> Format.fprintf fmt "_"
|
||||
|
||||
let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit =
|
||||
let lowercase_name = to_snake_case (to_ascii (Bindlib.name_of v)) in
|
||||
let lowercase_name =
|
||||
String.to_snake_case (String.to_ascii (Bindlib.name_of v))
|
||||
in
|
||||
let lowercase_name =
|
||||
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.")
|
||||
~subst:(fun _ -> "_dot_")
|
||||
lowercase_name
|
||||
in
|
||||
let lowercase_name = avoid_keywords (to_ascii lowercase_name) in
|
||||
let lowercase_name = avoid_keywords (String.to_ascii lowercase_name) in
|
||||
if
|
||||
List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|
||||
|| begins_with_uppercase (Bindlib.name_of v)
|
||||
then Format.fprintf fmt "%s" lowercase_name
|
||||
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name
|
||||
|| String.begins_with_uppercase (Bindlib.name_of v)
|
||||
then Format.pp_print_string fmt lowercase_name
|
||||
else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name
|
||||
else (
|
||||
Cli.debug_print "lowercase_name: %s " lowercase_name;
|
||||
Format.fprintf fmt "%s_" lowercase_name)
|
||||
|
||||
let needs_parens (e : 'm expr) : bool =
|
||||
match Marked.unmark e with
|
||||
| EApp ((EAbs (_, _), _), _)
|
||||
| EApp { f = EAbs _, _; _ }
|
||||
| ELit (LBool _ | LUnit)
|
||||
| EVar _ | ETuple _ | EOp _ ->
|
||||
false
|
||||
@ -279,56 +229,52 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
|
||||
in
|
||||
match Marked.unmark e with
|
||||
| EVar v -> Format.fprintf fmt "%a" format_var v
|
||||
| ETuple (es, None) ->
|
||||
| ETuple es ->
|
||||
Format.fprintf fmt "@[<hov 2>(%a)@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt e -> Format.fprintf fmt "%a" format_with_parens e))
|
||||
es
|
||||
| ETuple (es, Some s) ->
|
||||
if List.length es = 0 then Format.fprintf fmt "()"
|
||||
| EStruct { name = s; fields = es } ->
|
||||
if StructField.Map.is_empty es then Format.fprintf fmt "()"
|
||||
else
|
||||
Format.fprintf fmt "{@[<hov 2>%a@]}"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt (e, struct_field) ->
|
||||
(fun fmt (struct_field, e) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a =@ %a@]" format_struct_field_name
|
||||
(Some s, struct_field) format_with_parens e))
|
||||
(List.combine es (List.map fst (find_struct s ctx)))
|
||||
(StructField.Map.bindings es)
|
||||
| EArray es ->
|
||||
Format.fprintf fmt "@[<hov 2>[|%a|]@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt e -> Format.fprintf fmt "%a" format_with_parens e))
|
||||
es
|
||||
| ETupleAccess (e1, n, s, ts) -> (
|
||||
match s with
|
||||
| None ->
|
||||
Format.fprintf fmt "let@ %a@ = %a@ in@ x"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt i -> Format.fprintf fmt "%s" (if i = n then "x" else "_")))
|
||||
(List.mapi (fun i _ -> i) ts)
|
||||
format_with_parens e1
|
||||
| Some s ->
|
||||
Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name
|
||||
(Some s, fst (List.nth (find_struct s ctx) n)))
|
||||
| EInj (e, n, en, _ts) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a.%a@ %a@]" format_to_module_name (`Ename en)
|
||||
format_enum_cons_name
|
||||
(fst (List.nth (find_enum en ctx) n))
|
||||
format_with_parens e
|
||||
| EMatch (e, es, e_name) ->
|
||||
| ETupleAccess { e; index; size } ->
|
||||
Format.fprintf fmt "let@ %a@ = %a@ in@ x"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt i ->
|
||||
Format.pp_print_string fmt (if i = index then "x" else "_")))
|
||||
(List.init size Fun.id) format_with_parens e
|
||||
| EStructAccess { e; field; name } ->
|
||||
Format.fprintf fmt "%a.%a" format_with_parens e format_struct_field_name
|
||||
(Some name, field)
|
||||
| EInj { e; cons; name } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a.%a@ %a@]" format_to_module_name
|
||||
(`Ename name) format_enum_cons_name cons format_with_parens e
|
||||
| EMatch { e; cases; name } ->
|
||||
Format.fprintf fmt "@[<hv>@[<hov 2>match@ %a@]@ with@\n| %a@]"
|
||||
format_with_parens e
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ | ")
|
||||
(fun fmt (e, c) ->
|
||||
(fun fmt (c, e) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a.%a %a@]" format_to_module_name
|
||||
(`Ename e_name) format_enum_cons_name c
|
||||
(`Ename name) format_enum_cons_name c
|
||||
(fun fmt e ->
|
||||
match Marked.unmark e with
|
||||
| EAbs (binder, _) ->
|
||||
| EAbs { binder; _ } ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
Format.fprintf fmt "%a ->@ %a"
|
||||
(Format.pp_print_list
|
||||
@ -338,11 +284,11 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
|
||||
| _ -> assert false
|
||||
(* should not happen *))
|
||||
e))
|
||||
(List.combine es (List.map fst (find_enum e_name ctx)))
|
||||
(EnumConstructor.Map.bindings cases)
|
||||
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l)
|
||||
| EApp ((EAbs (binder, taus), _), args) ->
|
||||
| EApp { f = EAbs { binder; tys }, _; args } ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
|
||||
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in
|
||||
let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
|
||||
Format.fprintf fmt "(%a%a)"
|
||||
(Format.pp_print_list
|
||||
@ -351,30 +297,28 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
|
||||
Format.fprintf fmt "@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n"
|
||||
format_var x format_typ tau format_with_parens arg))
|
||||
xs_tau_arg format_with_parens body
|
||||
| EAbs (binder, taus) ->
|
||||
| EAbs { binder; tys } ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
|
||||
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in
|
||||
Format.fprintf fmt "@[<hov 2>fun@ %a ->@ %a@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
(fun fmt (x, tau) ->
|
||||
Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau))
|
||||
xs_tau format_expr body
|
||||
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos)
|
||||
format_with_parens arg1 format_with_parens arg2
|
||||
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
|
||||
format_binop (op, Pos.no_pos) format_with_parens arg2
|
||||
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
|
||||
| EApp
|
||||
{
|
||||
f = EApp { f = EOp { op = Log (BeginCall, info); _ }, _; args = [f] }, _;
|
||||
args = [arg];
|
||||
}
|
||||
when !Cli.trace_flag ->
|
||||
Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info
|
||||
format_with_parens f format_with_parens arg
|
||||
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag
|
||||
->
|
||||
| EApp { f = EOp { op = Log (VarDef tau, info); _ }, _; args = [arg1] }
|
||||
when !Cli.trace_flag ->
|
||||
Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list
|
||||
info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1
|
||||
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), m), [arg1])
|
||||
| EApp { f = EOp { op = Log (PosRecordIfTrueBool, _); _ }, m; args = [arg1] }
|
||||
when !Cli.trace_flag ->
|
||||
let pos = Expr.mark_pos m in
|
||||
Format.fprintf fmt
|
||||
@ -383,15 +327,13 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
|
||||
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
|
||||
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
|
||||
(Pos.get_law_info pos) format_with_parens arg1
|
||||
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag ->
|
||||
| EApp { f = EOp { op = Log (EndCall, info); _ }, _; args = [arg1] }
|
||||
when !Cli.trace_flag ->
|
||||
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info
|
||||
format_with_parens arg1
|
||||
| EApp ((EOp (Unop (Log _)), _), [arg1]) ->
|
||||
| EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } ->
|
||||
Format.fprintf fmt "%a" format_with_parens arg1
|
||||
| EApp ((EOp (Unop op), _), [arg1]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
|
||||
format_with_parens arg1
|
||||
| EApp ((EVar x, pos), args)
|
||||
| EApp { f = EVar x, pos; args }
|
||||
when Var.compare x (Var.translate Ast.handle_default) = 0
|
||||
|| Var.compare x (Var.translate Ast.handle_default_opt) = 0 ->
|
||||
Format.fprintf fmt
|
||||
@ -409,19 +351,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
format_with_parens)
|
||||
args
|
||||
| EApp (f, args) ->
|
||||
| EApp { f; args } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
format_with_parens)
|
||||
args
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
Format.fprintf fmt
|
||||
"@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]@]"
|
||||
format_with_parens e1 format_with_parens e2 format_with_parens e3
|
||||
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
|
||||
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
|
||||
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
|
||||
format_with_parens cond format_with_parens etrue format_with_parens efalse
|
||||
| EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op)
|
||||
| EAssert e' ->
|
||||
Format.fprintf fmt
|
||||
"@[<hov 2>if@ %a@ then@ ()@ else@ raise (AssertionFailed @[<hov \
|
||||
@ -437,18 +377,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
|
||||
(Pos.get_law_info (Expr.pos e'))
|
||||
| ERaise exc ->
|
||||
Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e)
|
||||
| ECatch (e1, exc, e2) ->
|
||||
| ECatch { body; exn; handler } ->
|
||||
Format.fprintf fmt
|
||||
"@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]"
|
||||
format_with_parens e1 format_exception
|
||||
(exc, Expr.pos e)
|
||||
format_with_parens e2
|
||||
format_with_parens body format_exception
|
||||
(exn, Expr.pos e)
|
||||
format_with_parens handler
|
||||
|
||||
let format_struct_embedding
|
||||
(fmt : Format.formatter)
|
||||
((struct_name, struct_fields) :
|
||||
StructName.t * (StructFieldName.t * typ) list) =
|
||||
if List.length struct_fields = 0 then
|
||||
((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) =
|
||||
if StructField.Map.is_empty struct_fields then
|
||||
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
|
||||
format_struct_name struct_name format_to_module_name (`Sname struct_name)
|
||||
else
|
||||
@ -461,16 +400,16 @@ let format_struct_embedding
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n")
|
||||
(fun _fmt (struct_field, struct_field_type) ->
|
||||
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructFieldName.format_t
|
||||
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructField.format_t
|
||||
struct_field typ_embedding_name struct_field_type
|
||||
format_struct_field_name
|
||||
(Some struct_name, struct_field)))
|
||||
struct_fields
|
||||
(StructField.Map.bindings struct_fields)
|
||||
|
||||
let format_enum_embedding
|
||||
(fmt : Format.formatter)
|
||||
((enum_name, enum_cases) : EnumName.t * (EnumConstructor.t * typ) list) =
|
||||
if List.length enum_cases = 0 then
|
||||
((enum_name, enum_cases) : EnumName.t * typ EnumConstructor.Map.t) =
|
||||
if EnumConstructor.Map.is_empty enum_cases then
|
||||
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
|
||||
format_to_module_name (`Ename enum_name) format_enum_name enum_name
|
||||
else
|
||||
@ -486,14 +425,14 @@ let format_enum_embedding
|
||||
Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]"
|
||||
format_enum_cons_name enum_cons EnumConstructor.format_t enum_cons
|
||||
typ_embedding_name enum_cons_type))
|
||||
enum_cases
|
||||
(EnumConstructor.Map.bindings enum_cases)
|
||||
|
||||
let format_ctx
|
||||
(type_ordering : Scopelang.Dependency.TVertex.t list)
|
||||
(fmt : Format.formatter)
|
||||
(ctx : decl_ctx) : unit =
|
||||
let format_struct_decl fmt (struct_name, struct_fields) =
|
||||
if List.length struct_fields = 0 then
|
||||
if StructField.Map.is_empty struct_fields then
|
||||
Format.fprintf fmt
|
||||
"@[<v 2>module %a = struct@\n@[<hov 2>type t = unit@]@]@\nend@\n"
|
||||
format_to_module_name (`Sname struct_name)
|
||||
@ -508,7 +447,7 @@ let format_ctx
|
||||
(fun _fmt (struct_field, struct_field_type) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a:@ %a@]" format_struct_field_name
|
||||
(None, struct_field) format_typ struct_field_type))
|
||||
struct_fields;
|
||||
(StructField.Map.bindings struct_fields);
|
||||
if !Cli.trace_flag then
|
||||
format_struct_embedding fmt (struct_name, struct_fields)
|
||||
in
|
||||
@ -521,7 +460,7 @@ let format_ctx
|
||||
(fun _fmt (enum_cons, enum_cons_type) ->
|
||||
Format.fprintf fmt "@[<hov 2>| %a@ of@ %a@]" format_enum_cons_name
|
||||
enum_cons format_typ enum_cons_type))
|
||||
enum_cons;
|
||||
(EnumConstructor.Map.bindings enum_cons);
|
||||
if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons)
|
||||
in
|
||||
let is_in_type_ordering s =
|
||||
@ -535,8 +474,8 @@ let format_ctx
|
||||
let scope_structs =
|
||||
List.map
|
||||
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
|
||||
(StructMap.bindings
|
||||
(StructMap.filter
|
||||
(StructName.Map.bindings
|
||||
(StructName.Map.filter
|
||||
(fun s _ -> not (is_in_type_ordering s))
|
||||
ctx.ctx_structs))
|
||||
in
|
||||
|
@ -14,29 +14,28 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
|
||||
(** Formats a lambda calculus program into a valid OCaml program *)
|
||||
|
||||
val avoid_keywords : string -> string
|
||||
val find_struct : StructName.t -> decl_ctx -> (StructFieldName.t * typ) list
|
||||
val find_enum : EnumName.t -> decl_ctx -> (EnumConstructor.t * typ) list
|
||||
val find_struct : StructName.t -> decl_ctx -> typ StructField.Map.t
|
||||
val find_enum : EnumName.t -> decl_ctx -> typ EnumConstructor.Map.t
|
||||
val typ_needs_parens : typ -> bool
|
||||
val needs_parens : 'm expr -> bool
|
||||
|
||||
(* val needs_parens : 'm expr -> bool *)
|
||||
val format_enum_name : Format.formatter -> EnumName.t -> unit
|
||||
val format_enum_cons_name : Format.formatter -> EnumConstructor.t -> unit
|
||||
val format_struct_name : Format.formatter -> StructName.t -> unit
|
||||
|
||||
val format_struct_field_name :
|
||||
Format.formatter -> StructName.t option * StructFieldName.t -> unit
|
||||
Format.formatter -> StructName.t option * StructField.t -> unit
|
||||
|
||||
val format_to_module_name :
|
||||
Format.formatter -> [< `Ename of EnumName.t | `Sname of StructName.t ] -> unit
|
||||
(* * val format_lit : Format.formatter -> lit Marked.pos -> unit * val
|
||||
format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit *)
|
||||
|
||||
val format_lit : Format.formatter -> lit Marked.pos -> unit
|
||||
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
|
||||
val format_var : Format.formatter -> 'm Var.t -> unit
|
||||
|
||||
val format_program :
|
||||
|
@ -1,7 +1,7 @@
|
||||
(library
|
||||
(name literate)
|
||||
(public_name catala.literate)
|
||||
(libraries re utils surface ubase))
|
||||
(libraries re catala_utils surface ubase uutf))
|
||||
|
||||
(documentation
|
||||
(package catala)
|
||||
|
@ -18,7 +18,7 @@
|
||||
(** This modules weaves the source code and the legislative text together into a
|
||||
document that law professionals can understand. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Literate_common
|
||||
module A = Surface.Ast
|
||||
module P = Printf
|
||||
@ -91,7 +91,7 @@ let wrap_html
|
||||
</ul>\n"
|
||||
css_as_string (literal_title language)
|
||||
(literal_generated_by language)
|
||||
Utils.Cli.version
|
||||
Cli.version
|
||||
(pre_html (literal_disclaimer_and_link language))
|
||||
(literal_source_files language)
|
||||
(String.concat "\n"
|
||||
@ -133,7 +133,7 @@ let pygmentize_code (c : string Marked.pos) (language : C.backend_lang) : string
|
||||
"html";
|
||||
"-O";
|
||||
"style=colorful,anchorlinenos=True,lineanchors=\""
|
||||
^ String_common.to_ascii (Pos.get_file (Marked.get_mark c))
|
||||
^ String.to_ascii (Pos.get_file (Marked.get_mark c))
|
||||
^ "\",linenos=table,linenostart="
|
||||
^ string_of_int (Pos.get_start_line (Marked.get_mark c));
|
||||
"-o";
|
||||
@ -160,7 +160,7 @@ let pygmentize_code (c : string Marked.pos) (language : C.backend_lang) : string
|
||||
|
||||
let sanitize_html_href str =
|
||||
str
|
||||
|> String_common.to_ascii
|
||||
|> String.to_ascii
|
||||
|> R.substitute ~rex:(R.regexp "[' '°\"]") ~subst:(function _ -> "%20")
|
||||
|
||||
let rec law_structure_to_html
|
||||
|
@ -17,7 +17,7 @@
|
||||
(** This modules weaves the source code and the legislative text together into a
|
||||
document that law professionals can understand. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
|
||||
(** {1 Helpers} *)
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
(** This modules weaves the source code and the legislative text together into a
|
||||
document that law professionals can understand. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Literate_common
|
||||
module A = Surface.Ast
|
||||
module R = Re.Pcre
|
||||
@ -61,7 +61,7 @@ let wrap_latex
|
||||
%s
|
||||
\usepackage{minted}
|
||||
\usepackage{longtable}
|
||||
\usepackage{booktabs}
|
||||
\usepackage{booktabs,tabularx}
|
||||
\usepackage{newunicodechar}
|
||||
\usepackage{textcomp}
|
||||
\usepackage[hidelinks]{hyperref}
|
||||
@ -122,8 +122,8 @@ let wrap_latex
|
||||
\newunicodechar{→}{$\rightarrow$}
|
||||
\newunicodechar{≠}{$\neq$}
|
||||
|
||||
\newcommand*\FancyVerbStartString{```catala}
|
||||
\newcommand*\FancyVerbStopString{```}
|
||||
\newcommand*\FancyVerbStartString{\PYG{l+s}{```catala}}
|
||||
\newcommand*\FancyVerbStopString{\PYG{l+s}{```}}
|
||||
|
||||
\fvset{
|
||||
numbers=left,
|
||||
@ -151,14 +151,15 @@ codes={\catcode`\$=3\catcode`\^=7}
|
||||
\tableofcontents
|
||||
|
||||
\[\star\star\star\]
|
||||
\clearpage|latex}
|
||||
\clearpage
|
||||
|latex}
|
||||
(match language with Fr -> "french" | En -> "english" | Pl -> "polish")
|
||||
(match language with Fr -> "\\setmainfont{Marianne}" | _ -> "")
|
||||
(* for France, we use the official font of the French state design system
|
||||
https://gouvfr.atlassian.net/wiki/spaces/DB/pages/223019527/Typographie+-+Typography *)
|
||||
(literal_title language)
|
||||
(literal_generated_by language)
|
||||
Utils.Cli.version
|
||||
Cli.version
|
||||
(pre_latexify (literal_disclaimer_and_link language))
|
||||
(literal_source_files language)
|
||||
(String.concat
|
||||
@ -243,7 +244,7 @@ let rec law_structure_to_latex
|
||||
| En -> "Metadata"
|
||||
| Pl -> "Metadane"
|
||||
in
|
||||
let start_line = Pos.get_start_line (Marked.get_mark c) - 1 in
|
||||
let start_line = Pos.get_start_line (Marked.get_mark c) + 1 in
|
||||
let filename = Filename.basename (Pos.get_file (Marked.get_mark c)) in
|
||||
let block_content = Marked.unmark c in
|
||||
check_exceeding_lines start_line filename block_content;
|
||||
@ -252,7 +253,7 @@ let rec law_structure_to_latex
|
||||
"\\begin{tcolorbox}[colframe=OliveGreen, breakable, \
|
||||
title=\\textcolor{black}{\\texttt{%s}},title after \
|
||||
break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\
|
||||
\\begin{minted}[numbersep=9mm, firstnumber=%d, breaklines, \
|
||||
\\begin{minted}[numbersep=9mm, firstnumber=%d, \
|
||||
label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\
|
||||
```catala\n\
|
||||
%s```\n\
|
||||
|
@ -17,7 +17,7 @@
|
||||
(** This modules weaves the source code and the legislative text together into a
|
||||
document that law professionals can understand. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
|
||||
(** {1 Helpers} *)
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Cli
|
||||
|
||||
let literal_title = function
|
||||
|
@ -14,32 +14,30 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
|
||||
val literal_title : Cli.backend_lang -> string
|
||||
(** Return the title traduction according the given
|
||||
{!type:Utils.Cli.backend_lang}. *)
|
||||
(** Return the title traduction according the given {!type:Cli.backend_lang}. *)
|
||||
|
||||
val literal_generated_by : Cli.backend_lang -> string
|
||||
(** Return the 'generated by' traduction according the given
|
||||
{!type:Utils.Cli.backend_lang}. *)
|
||||
{!type:Cli.backend_lang}. *)
|
||||
|
||||
val literal_source_files : Cli.backend_lang -> string
|
||||
(** Return the 'source files weaved' traduction according the given
|
||||
{!type:Utils.Cli.backend_lang}. *)
|
||||
{!type:Cli.backend_lang}. *)
|
||||
|
||||
val literal_disclaimer_and_link : Cli.backend_lang -> string
|
||||
(** Return the traduction of a paragraph giving a basic disclaimer about Catala
|
||||
and a link to the website according the given {!type:
|
||||
Utils.Cli.backend_lang}. *)
|
||||
and a link to the website according the given {!type: Cli.backend_lang}. *)
|
||||
|
||||
val literal_last_modification : Cli.backend_lang -> string
|
||||
(** Return the 'last modification' traduction according the given
|
||||
{!type:Utils.Cli.backend_lang}. *)
|
||||
{!type:Cli.backend_lang}. *)
|
||||
|
||||
val get_language_extension : Cli.backend_lang -> string
|
||||
(** Return the file extension corresponding to the given
|
||||
{!type:Utils.Cli.backend_lang}. *)
|
||||
{!type:Cli.backend_lang}. *)
|
||||
|
||||
val run_pandoc : string -> [ `Html | `Latex ] -> string
|
||||
(** Runs the [pandoc] on a string to pretty-print markdown features into the
|
||||
|
@ -14,8 +14,10 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Catala_utils
|
||||
|
||||
type 'ast plugin_apply_fun_typ =
|
||||
source_file:Utils.Pos.input_file ->
|
||||
source_file:Pos.input_file ->
|
||||
output_file:string option ->
|
||||
scope:string option ->
|
||||
'ast ->
|
||||
@ -51,17 +53,21 @@ let find name = Hashtbl.find backend_plugins (String.lowercase_ascii name)
|
||||
let load_file f =
|
||||
try
|
||||
Dynlink.loadfile f;
|
||||
Utils.Cli.debug_print "Plugin %S loaded" f
|
||||
Cli.debug_print "Plugin %S loaded" f
|
||||
with e ->
|
||||
Utils.Errors.format_warning "Could not load plugin %S: %s" f
|
||||
Errors.format_warning "Could not load plugin %S: %s" f
|
||||
(Printexc.to_string e)
|
||||
|
||||
let load_dir d =
|
||||
let rec load_dir d =
|
||||
let dynlink_exts =
|
||||
if Dynlink.is_native then [".cmxs"] else [".cmo"; ".cma"]
|
||||
in
|
||||
Array.iter
|
||||
(fun f ->
|
||||
if List.exists (Filename.check_suffix f) dynlink_exts then
|
||||
load_file (Filename.concat d f))
|
||||
if f.[0] = '.' then ()
|
||||
else
|
||||
let f = Filename.concat d f in
|
||||
if Sys.is_directory f then load_dir f
|
||||
else if List.exists (Filename.check_suffix f) dynlink_exts then
|
||||
load_file f)
|
||||
(Sys.readdir d)
|
||||
|
@ -16,8 +16,10 @@
|
||||
|
||||
(** {2 catala-facing API} *)
|
||||
|
||||
open Catala_utils
|
||||
|
||||
type 'ast plugin_apply_fun_typ =
|
||||
source_file:Utils.Pos.input_file ->
|
||||
source_file:Pos.input_file ->
|
||||
output_file:string option ->
|
||||
scope:string option ->
|
||||
'ast ->
|
||||
|
@ -18,9 +18,8 @@
|
||||
(** Catala plugin for generating web APIs. It generates OCaml code before the
|
||||
the associated [js_of_ocaml] wrapper. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open String_common
|
||||
open Lcalc
|
||||
open Lcalc.Ast
|
||||
open Lcalc.To_ocaml
|
||||
@ -40,11 +39,11 @@ module To_jsoo = struct
|
||||
|
||||
let format_struct_field_name_camel_case
|
||||
(fmt : Format.formatter)
|
||||
(v : StructFieldName.t) : unit =
|
||||
(v : StructField.t) : unit =
|
||||
let s =
|
||||
Format.asprintf "%a" StructFieldName.format_t v
|
||||
|> to_ascii
|
||||
|> to_snake_case
|
||||
Format.asprintf "%a" StructField.format_t v
|
||||
|> String.to_ascii
|
||||
|> String.to_snake_case
|
||||
|> avoid_keywords
|
||||
|> to_camel_case
|
||||
in
|
||||
@ -118,17 +117,17 @@ module To_jsoo = struct
|
||||
let format_var_camel_case (fmt : Format.formatter) (v : 'm Var.t) : unit =
|
||||
let lowercase_name =
|
||||
Bindlib.name_of v
|
||||
|> to_ascii
|
||||
|> to_snake_case
|
||||
|> String.to_ascii
|
||||
|> String.to_snake_case
|
||||
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ ->
|
||||
"_dot_")
|
||||
|> to_ascii
|
||||
|> String.to_ascii
|
||||
|> avoid_keywords
|
||||
|> to_camel_case
|
||||
in
|
||||
if
|
||||
List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|
||||
|| begins_with_uppercase (Bindlib.name_of v)
|
||||
|| String.begins_with_uppercase (Bindlib.name_of v)
|
||||
then Format.fprintf fmt "%s" lowercase_name
|
||||
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name
|
||||
else Format.fprintf fmt "%s_" lowercase_name
|
||||
@ -166,7 +165,7 @@ module To_jsoo = struct
|
||||
format_struct_field_name_camel_case struct_field
|
||||
format_typ_to_jsoo struct_field_type fmt_struct_name ()
|
||||
format_struct_field_name (None, struct_field)))
|
||||
struct_fields
|
||||
(StructField.Map.bindings struct_fields)
|
||||
in
|
||||
let fmt_of_jsoo fmt _ =
|
||||
Format.fprintf fmt "%a"
|
||||
@ -186,7 +185,7 @@ module To_jsoo = struct
|
||||
format_struct_field_name (None, struct_field)
|
||||
format_typ_of_jsoo struct_field_type fmt_struct_name ()
|
||||
format_struct_field_name_camel_case struct_field))
|
||||
struct_fields
|
||||
(StructField.Map.bindings struct_fields)
|
||||
in
|
||||
let fmt_conv_funs fmt _ =
|
||||
Format.fprintf fmt
|
||||
@ -203,7 +202,7 @@ module To_jsoo = struct
|
||||
() fmt_struct_name () fmt_module_struct_name () fmt_of_jsoo ()
|
||||
in
|
||||
|
||||
if List.length struct_fields = 0 then
|
||||
if StructField.Map.is_empty struct_fields then
|
||||
Format.fprintf fmt
|
||||
"class type %a =@ object end@\n\
|
||||
let %a_to_jsoo (_ : %a.t) : %a Js.t = object%%js end@\n\
|
||||
@ -220,11 +219,11 @@ module To_jsoo = struct
|
||||
Format.fprintf fmt "@[<hov 2>method %a:@ %a %a@]"
|
||||
format_struct_field_name_camel_case struct_field format_typ
|
||||
struct_field_type format_prop_or_meth struct_field_type))
|
||||
struct_fields fmt_conv_funs ()
|
||||
(StructField.Map.bindings struct_fields)
|
||||
fmt_conv_funs ()
|
||||
in
|
||||
let format_enum_decl
|
||||
fmt
|
||||
(enum_name, (enum_cons : (EnumConstructor.t * typ) list)) =
|
||||
let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructor.Map.t))
|
||||
=
|
||||
let fmt_enum_name fmt _ = format_enum_name fmt enum_name in
|
||||
let fmt_module_enum_name fmt _ =
|
||||
To_ocaml.format_to_module_name fmt (`Ename enum_name)
|
||||
@ -247,7 +246,7 @@ module To_jsoo = struct
|
||||
end@]"
|
||||
format_enum_cons_name cname format_enum_cons_name cname
|
||||
format_typ_to_jsoo typ))
|
||||
enum_cons
|
||||
(EnumConstructor.Map.bindings enum_cons)
|
||||
in
|
||||
let fmt_of_jsoo fmt _ =
|
||||
Format.fprintf fmt
|
||||
@ -273,7 +272,8 @@ module To_jsoo = struct
|
||||
format_enum_cons_name cname fmt_module_enum_name ()
|
||||
format_enum_cons_name cname format_typ_of_jsoo typ
|
||||
fmt_enum_name ()))
|
||||
enum_cons fmt_module_enum_name ()
|
||||
(EnumConstructor.Map.bindings enum_cons)
|
||||
fmt_module_enum_name ()
|
||||
in
|
||||
|
||||
let fmt_conv_funs fmt _ =
|
||||
@ -301,7 +301,8 @@ module To_jsoo = struct
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (enum_cons, _) ->
|
||||
Format.fprintf fmt "- \"%a\"" format_enum_cons_name enum_cons))
|
||||
enum_cons fmt_conv_funs ()
|
||||
(EnumConstructor.Map.bindings enum_cons)
|
||||
fmt_conv_funs ()
|
||||
in
|
||||
let is_in_type_ordering s =
|
||||
List.exists
|
||||
@ -314,8 +315,8 @@ module To_jsoo = struct
|
||||
let scope_structs =
|
||||
List.map
|
||||
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
|
||||
(StructMap.bindings
|
||||
(StructMap.filter
|
||||
(StructName.Map.bindings
|
||||
(StructName.Map.filter
|
||||
(fun s _ -> not (is_in_type_ordering s))
|
||||
ctx.ctx_structs))
|
||||
in
|
||||
|
@ -1,18 +1,22 @@
|
||||
(executable
|
||||
(library
|
||||
(name python)
|
||||
(modes plugin)
|
||||
(public_name catala.plugins.python)
|
||||
(synopsis
|
||||
"Demonstration Catala plugin that reproduces the behaviour of the built-in python backend")
|
||||
(modules python)
|
||||
(libraries catala.driver))
|
||||
|
||||
(executable
|
||||
(library
|
||||
(name api_web)
|
||||
(modes plugin)
|
||||
(public_name catala.plugins.api_web)
|
||||
(synopsis "Catala plugin for interaction with a web interface")
|
||||
(modules api_web)
|
||||
(libraries catala.driver))
|
||||
|
||||
(executable
|
||||
(library
|
||||
(name json_schema)
|
||||
(modes plugin)
|
||||
(public_name catala.plugins.json_schema)
|
||||
(synopsis "Catala plugin generating JSON schemas useful to build web-forms")
|
||||
(modules json_schema)
|
||||
(libraries catala.driver))
|
||||
|
||||
|
@ -20,8 +20,7 @@
|
||||
let name = "json_schema"
|
||||
let extension = "_schema.json"
|
||||
|
||||
open Utils
|
||||
open String_common
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Lcalc.Ast
|
||||
open Lcalc.To_ocaml
|
||||
@ -38,11 +37,11 @@ module To_json = struct
|
||||
|
||||
let format_struct_field_name_camel_case
|
||||
(fmt : Format.formatter)
|
||||
(v : StructFieldName.t) : unit =
|
||||
(v : StructField.t) : unit =
|
||||
let s =
|
||||
Format.asprintf "%a" StructFieldName.format_t v
|
||||
|> to_ascii
|
||||
|> to_snake_case
|
||||
Format.asprintf "%a" StructField.format_t v
|
||||
|> String.to_ascii
|
||||
|> String.to_snake_case
|
||||
|> avoid_keywords
|
||||
|> to_camel_case
|
||||
in
|
||||
@ -97,7 +96,7 @@ module To_json = struct
|
||||
(fun fmt (field_name, field_type) ->
|
||||
Format.fprintf fmt "@[<hov 2>\"%a\": {@\n%a@]@\n}"
|
||||
format_struct_field_name_camel_case field_name fmt_type field_type))
|
||||
(find_struct sname ctx)
|
||||
(StructField.Map.bindings (find_struct sname ctx))
|
||||
|
||||
let fmt_definitions
|
||||
(ctx : decl_ctx)
|
||||
@ -118,11 +117,14 @@ module To_json = struct
|
||||
(t :: acc) @ collect_required_type_defs_from_scope_input s
|
||||
| TEnum e ->
|
||||
List.fold_left collect (t :: acc)
|
||||
(List.map snd (EnumMap.find e ctx.ctx_enums))
|
||||
(List.map snd
|
||||
(EnumConstructor.Map.bindings
|
||||
(EnumName.Map.find e ctx.ctx_enums)))
|
||||
| TArray t -> collect acc t
|
||||
| _ -> acc
|
||||
in
|
||||
find_struct input_struct ctx
|
||||
|> StructField.Map.bindings
|
||||
|> List.fold_left (fun acc (_, field_typ) -> collect acc field_typ) []
|
||||
|> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t'))
|
||||
in
|
||||
@ -146,7 +148,7 @@ module To_json = struct
|
||||
Format.fprintf fmt
|
||||
"@[<hov 2>{@\n\"type\": \"string\",@\n\"enum\": [\"%a\"]@]@\n}"
|
||||
format_enum_cons_name enum_cons))
|
||||
enum_def
|
||||
(EnumConstructor.Map.bindings enum_def)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n")
|
||||
(fun fmt (enum_cons, payload_type) ->
|
||||
@ -168,7 +170,7 @@ module To_json = struct
|
||||
}@]@\n\
|
||||
}"
|
||||
format_enum_cons_name enum_cons fmt_type payload_type))
|
||||
enum_def
|
||||
(EnumConstructor.Map.bindings enum_def)
|
||||
in
|
||||
|
||||
Format.fprintf fmt "@\n%a"
|
||||
|
@ -20,13 +20,15 @@
|
||||
The code for the Python backend already has first-class support, so there
|
||||
would be no reason to use this plugin instead *)
|
||||
|
||||
open Catala_utils
|
||||
|
||||
let name = "python-plugin"
|
||||
let extension = ".py"
|
||||
|
||||
let apply ~source_file ~output_file ~scope prgm type_ordering =
|
||||
ignore source_file;
|
||||
ignore scope;
|
||||
Utils.File.with_formatter_of_opt_file output_file
|
||||
File.with_formatter_of_opt_file output_file
|
||||
@@ fun fmt -> Scalc.To_python.format_program fmt prgm type_ordering
|
||||
|
||||
let () = Driver.Plugin.register_scalc ~name ~extension apply
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
module D = Dcalc.Ast
|
||||
module L = Lcalc.Ast
|
||||
@ -28,15 +28,15 @@ let handle_default_opt = TopLevelName.fresh ("handle_default_opt", Pos.no_pos)
|
||||
type expr = naked_expr Marked.pos
|
||||
|
||||
and naked_expr =
|
||||
| EVar of LocalName.t
|
||||
| EFunc of TopLevelName.t
|
||||
| EStruct of expr list * StructName.t
|
||||
| EStructFieldAccess of expr * StructFieldName.t * StructName.t
|
||||
| EInj of expr * EnumConstructor.t * EnumName.t
|
||||
| EArray of expr list
|
||||
| ELit of L.lit
|
||||
| EApp of expr * expr list
|
||||
| EOp of operator
|
||||
| EVar : LocalName.t -> naked_expr
|
||||
| EFunc : TopLevelName.t -> naked_expr
|
||||
| EStruct : expr list * StructName.t -> naked_expr
|
||||
| EStructFieldAccess : expr * StructField.t * StructName.t -> naked_expr
|
||||
| EInj : expr * EnumConstructor.t * EnumName.t -> naked_expr
|
||||
| EArray : expr list -> naked_expr
|
||||
| ELit : L.lit -> naked_expr
|
||||
| EApp : expr * expr list -> naked_expr
|
||||
| EOp : (lcalc, _) operator -> naked_expr
|
||||
|
||||
type stmt =
|
||||
| SInnerFuncDef of LocalName.t Marked.pos * func
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
module A = Ast
|
||||
module L = Lcalc.Ast
|
||||
@ -35,36 +35,37 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
| EVar v ->
|
||||
let local_var =
|
||||
try A.EVar (Var.Map.find v ctxt.var_dict)
|
||||
with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict)
|
||||
with Not_found -> (
|
||||
try A.EFunc (Var.Map.find v ctxt.func_dict)
|
||||
with Not_found ->
|
||||
Errors.raise_spanned_error (Expr.pos expr)
|
||||
"Var not found in lambda→scalc: %a@\nknown: @[<hov>%a@]@\n"
|
||||
Print.var_debug v
|
||||
(Format.pp_print_list ~pp_sep:Format.pp_print_space
|
||||
(fun ppf (v, _) -> Print.var_debug ppf v))
|
||||
(Var.Map.bindings ctxt.var_dict))
|
||||
in
|
||||
[], (local_var, Expr.pos expr)
|
||||
| ETuple (args, Some s_name) ->
|
||||
| EStruct { fields; name } ->
|
||||
let args_stmts, new_args =
|
||||
List.fold_left
|
||||
(fun (args_stmts, new_args) arg ->
|
||||
StructField.Map.fold
|
||||
(fun _ arg (args_stmts, new_args) ->
|
||||
let arg_stmts, new_arg = translate_expr ctxt arg in
|
||||
arg_stmts @ args_stmts, new_arg :: new_args)
|
||||
([], []) args
|
||||
fields ([], [])
|
||||
in
|
||||
let new_args = List.rev new_args in
|
||||
let args_stmts = List.rev args_stmts in
|
||||
args_stmts, (A.EStruct (new_args, s_name), Expr.pos expr)
|
||||
| ETuple (_, None) -> failwith "Non-struct tuples cannot be compiled to scalc"
|
||||
| ETupleAccess (e1, num_field, Some s_name, _) ->
|
||||
args_stmts, (A.EStruct (new_args, name), Expr.pos expr)
|
||||
| ETuple _ -> failwith "Tuples cannot be compiled to scalc"
|
||||
| EStructAccess { e = e1; field; name } ->
|
||||
let e1_stmts, new_e1 = translate_expr ctxt e1 in
|
||||
let field_name =
|
||||
fst (List.nth (StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field)
|
||||
in
|
||||
e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), Expr.pos expr)
|
||||
| ETupleAccess (_, _, None, _) ->
|
||||
failwith "Non-struct tuples cannot be compiled to scalc"
|
||||
| EInj (e1, num_cons, e_name, _) ->
|
||||
e1_stmts, (A.EStructFieldAccess (new_e1, field, name), Expr.pos expr)
|
||||
| ETupleAccess _ -> failwith "Non-struct tuples cannot be compiled to scalc"
|
||||
| EInj { e = e1; cons; name } ->
|
||||
let e1_stmts, new_e1 = translate_expr ctxt e1 in
|
||||
let cons_name =
|
||||
fst (List.nth (EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons)
|
||||
in
|
||||
e1_stmts, (A.EInj (new_e1, cons_name, e_name), Expr.pos expr)
|
||||
| EApp (f, args) ->
|
||||
e1_stmts, (A.EInj (new_e1, cons, name), Expr.pos expr)
|
||||
| EApp { f; args } ->
|
||||
let f_stmts, new_f = translate_expr ctxt f in
|
||||
let args_stmts, new_args =
|
||||
List.fold_left
|
||||
@ -85,7 +86,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
|
||||
in
|
||||
let new_args = List.rev new_args in
|
||||
args_stmts, (A.EArray new_args, Expr.pos expr)
|
||||
| EOp op -> [], (A.EOp op, Expr.pos expr)
|
||||
| EOp { op; _ } -> [], (A.EOp op, Expr.pos expr)
|
||||
| ELit l -> [], (A.ELit l, Expr.pos expr)
|
||||
| _ ->
|
||||
let tmp_var =
|
||||
@ -120,11 +121,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
(* Assertions are always encapsulated in a unit-typed let binding *)
|
||||
let e_stmts, new_e = translate_expr ctxt e in
|
||||
e_stmts @ [A.SAssert (Marked.unmark new_e), Expr.pos block_expr]
|
||||
| EApp ((EAbs (binder, taus), binder_mark), args) ->
|
||||
| EApp { f = EAbs { binder; tys }, binder_mark; args } ->
|
||||
(* This defines multiple local variables at the time *)
|
||||
let binder_pos = Expr.mark_pos binder_mark in
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in
|
||||
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in
|
||||
let ctxt =
|
||||
{
|
||||
ctxt with
|
||||
@ -167,10 +168,10 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
in
|
||||
let rest_of_block = translate_statements ctxt body in
|
||||
local_decls @ List.flatten def_blocks @ rest_of_block
|
||||
| EAbs (binder, taus) ->
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let binder_pos = Expr.pos block_expr in
|
||||
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in
|
||||
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in
|
||||
let closure_name =
|
||||
match ctxt.inside_definition_of with
|
||||
| None -> A.LocalName.fresh (ctxt.context_name, Expr.pos block_expr)
|
||||
@ -203,13 +204,13 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
} ),
|
||||
binder_pos );
|
||||
]
|
||||
| EMatch (e1, args, e_name) ->
|
||||
| EMatch { e = e1; cases; name } ->
|
||||
let e1_stmts, new_e1 = translate_expr ctxt e1 in
|
||||
let new_args =
|
||||
List.fold_left
|
||||
(fun new_args arg ->
|
||||
let new_cases =
|
||||
EnumConstructor.Map.fold
|
||||
(fun _ arg new_args ->
|
||||
match Marked.unmark arg with
|
||||
| EAbs (binder, _) ->
|
||||
| EAbs { binder; _ } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
assert (Array.length vars = 1);
|
||||
let var = vars.(0) in
|
||||
@ -223,20 +224,20 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
|
||||
(new_arg, scalc_var) :: new_args
|
||||
| _ -> assert false
|
||||
(* should not happen *))
|
||||
[] args
|
||||
cases []
|
||||
in
|
||||
let new_args = List.rev new_args in
|
||||
e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), Expr.pos block_expr]
|
||||
| EIfThenElse (cond, e_true, e_false) ->
|
||||
let new_args = List.rev new_cases in
|
||||
e1_stmts @ [A.SSwitch (new_e1, name, new_args), Expr.pos block_expr]
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
let cond_stmts, s_cond = translate_expr ctxt cond in
|
||||
let s_e_true = translate_statements ctxt e_true in
|
||||
let s_e_false = translate_statements ctxt e_false in
|
||||
let s_e_true = translate_statements ctxt etrue in
|
||||
let s_e_false = translate_statements ctxt efalse in
|
||||
cond_stmts
|
||||
@ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr]
|
||||
| ECatch (e_try, except, e_catch) ->
|
||||
let s_e_try = translate_statements ctxt e_try in
|
||||
let s_e_catch = translate_statements ctxt e_catch in
|
||||
[A.STryExcept (s_e_try, except, s_e_catch), Expr.pos block_expr]
|
||||
| ECatch { body; exn; handler } ->
|
||||
let s_e_try = translate_statements ctxt body in
|
||||
let s_e_catch = translate_statements ctxt handler in
|
||||
[A.STryExcept (s_e_try, exn, s_e_catch), Expr.pos block_expr]
|
||||
| ERaise except ->
|
||||
(* Before raising the exception, we still give a dummy definition to the
|
||||
current variable so that tools like mypy don't complain. *)
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
|
||||
@ -44,11 +44,12 @@ let rec format_expr
|
||||
Print.punctuation "{"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt (e, struct_field) ->
|
||||
(fun fmt (e, (struct_field, _)) ->
|
||||
Format.fprintf fmt "%a%a%a%a %a" Print.punctuation "\""
|
||||
StructFieldName.format_t struct_field Print.punctuation "\""
|
||||
StructField.format_t struct_field Print.punctuation "\""
|
||||
Print.punctuation ":" format_expr e))
|
||||
(List.combine es (List.map fst (StructMap.find s decl_ctx.ctx_structs)))
|
||||
(List.combine es
|
||||
(StructField.Map.bindings (StructName.Map.find s decl_ctx.ctx_structs)))
|
||||
Print.punctuation "}"
|
||||
| EArray es ->
|
||||
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "["
|
||||
@ -56,41 +57,31 @@ let rec format_expr
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
|
||||
es Print.punctuation "]"
|
||||
| EStructFieldAccess (e1, field, s) ->
|
||||
| EStructFieldAccess (e1, field, _) ->
|
||||
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Print.punctuation "."
|
||||
Print.punctuation "\"" StructFieldName.format_t
|
||||
(fst
|
||||
(List.find
|
||||
(fun (field', _) -> StructFieldName.compare field' field = 0)
|
||||
(StructMap.find s decl_ctx.ctx_structs)))
|
||||
Print.punctuation "\""
|
||||
| EInj (e, case, enum) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor
|
||||
(fst
|
||||
(List.find
|
||||
(fun (case', _) -> EnumConstructor.compare case' case = 0)
|
||||
(EnumMap.find enum decl_ctx.ctx_enums)))
|
||||
Print.punctuation "\"" StructField.format_t field Print.punctuation "\""
|
||||
| EInj (e, cons, _) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor cons
|
||||
format_expr e
|
||||
| ELit l -> Print.lit fmt l
|
||||
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.binop op format_with_parens
|
||||
arg1 format_with_parens arg2
|
||||
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
|
||||
| EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.operator op
|
||||
format_with_parens arg1 format_with_parens arg2
|
||||
| EApp ((EOp op, _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
|
||||
Print.binop op format_with_parens arg2
|
||||
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug ->
|
||||
Print.operator op format_with_parens arg2
|
||||
| EApp ((EOp (Log _), _), [arg1]) when not debug ->
|
||||
Format.fprintf fmt "%a" format_with_parens arg1
|
||||
| EApp ((EOp (Unop op), _), [arg1]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.unop op format_with_parens arg1
|
||||
| EApp ((EOp op, _), [arg1]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.operator op format_with_parens
|
||||
arg1
|
||||
| EApp (f, args) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
format_with_parens)
|
||||
args
|
||||
| EOp (Ternop op) -> Format.fprintf fmt "%a" Print.ternop op
|
||||
| EOp (Binop op) -> Format.fprintf fmt "%a" Print.binop op
|
||||
| EOp (Unop op) -> Format.fprintf fmt "%a" Print.unop op
|
||||
| EOp op -> Format.fprintf fmt "%a" Print.operator op
|
||||
|
||||
let rec format_statement
|
||||
(decl_ctx : decl_ctx)
|
||||
@ -101,22 +92,22 @@ let rec format_statement
|
||||
match Marked.unmark stmt with
|
||||
| SInnerFuncDef (name, func) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Print.keyword
|
||||
"let" LocalName.format_t (Marked.unmark name)
|
||||
"let" format_local_name (Marked.unmark name)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
(fun fmt ((name, _), typ) ->
|
||||
Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "("
|
||||
LocalName.format_t name Print.punctuation ":" (Print.typ decl_ctx)
|
||||
format_local_name name Print.punctuation ":" (Print.typ decl_ctx)
|
||||
typ Print.punctuation ")"))
|
||||
func.func_params Print.punctuation "="
|
||||
(format_block decl_ctx ~debug)
|
||||
func.func_body
|
||||
| SLocalDecl (name, typ) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Print.keyword "decl"
|
||||
LocalName.format_t (Marked.unmark name) Print.punctuation ":"
|
||||
format_local_name (Marked.unmark name) Print.punctuation ":"
|
||||
(Print.typ decl_ctx) typ
|
||||
| SLocalDef (name, naked_expr) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" LocalName.format_t
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_local_name
|
||||
(Marked.unmark name) Print.punctuation "="
|
||||
(format_expr decl_ctx ~debug)
|
||||
naked_expr
|
||||
@ -156,10 +147,13 @@ let rec format_statement
|
||||
(fun fmt ((case, _), (arm_block, payload_name)) ->
|
||||
Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Print.punctuation
|
||||
"|" Print.enum_constructor case Print.punctuation ":"
|
||||
LocalName.format_t payload_name Print.punctuation "→"
|
||||
format_local_name payload_name Print.punctuation "→"
|
||||
(format_block decl_ctx ~debug)
|
||||
arm_block))
|
||||
(List.combine (EnumMap.find enum decl_ctx.ctx_enums) arms)
|
||||
(List.combine
|
||||
(EnumConstructor.Map.bindings
|
||||
(EnumName.Map.find enum decl_ctx.ctx_enums))
|
||||
arms)
|
||||
|
||||
and format_block
|
||||
(decl_ctx : decl_ctx)
|
||||
@ -183,8 +177,8 @@ let format_scope
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
(fun fmt ((name, _), typ) ->
|
||||
Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "("
|
||||
LocalName.format_t name Print.punctuation ":" (Print.typ decl_ctx)
|
||||
typ Print.punctuation ")"))
|
||||
format_local_name name Print.punctuation ":" (Print.typ decl_ctx) typ
|
||||
Print.punctuation ")"))
|
||||
body.scope_body_func.func_params Print.punctuation "="
|
||||
(format_block decl_ctx ~debug)
|
||||
body.scope_body_func.func_body
|
||||
|
@ -15,21 +15,20 @@
|
||||
the License. *)
|
||||
[@@@warning "-32-27"]
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
open String_common
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
module D = Dcalc.Ast
|
||||
module L = Lcalc.Ast
|
||||
|
||||
let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
|
||||
match Marked.unmark l with
|
||||
| LBool true -> Format.fprintf fmt "True"
|
||||
| LBool false -> Format.fprintf fmt "False"
|
||||
| LBool true -> Format.pp_print_string fmt "True"
|
||||
| LBool false -> Format.pp_print_string fmt "False"
|
||||
| LInt i ->
|
||||
Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i)
|
||||
| LUnit -> Format.fprintf fmt "Unit()"
|
||||
| LUnit -> Format.pp_print_string fmt "Unit()"
|
||||
| LRat i -> Format.fprintf fmt "decimal_of_string(\"%a\")" Print.lit (LRat i)
|
||||
| LMoney e ->
|
||||
Format.fprintf fmt "money_of_cents_string(\"%s\")"
|
||||
@ -45,31 +44,60 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
|
||||
|
||||
let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
|
||||
match entry with
|
||||
| VarDef _ -> Format.fprintf fmt ":="
|
||||
| BeginCall -> Format.fprintf fmt "→ "
|
||||
| VarDef _ -> Format.pp_print_string fmt ":="
|
||||
| BeginCall -> Format.pp_print_string fmt "→ "
|
||||
| EndCall -> Format.fprintf fmt "%s" "← "
|
||||
| PosRecordIfTrueBool -> Format.fprintf fmt "☛ "
|
||||
| PosRecordIfTrueBool -> Format.pp_print_string fmt "☛ "
|
||||
|
||||
let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit =
|
||||
let format_op
|
||||
(type k)
|
||||
(fmt : Format.formatter)
|
||||
(op : (lcalc, k) operator Marked.pos) : unit =
|
||||
match Marked.unmark op with
|
||||
| Add _ | Concat -> Format.fprintf fmt "+"
|
||||
| Sub _ -> Format.fprintf fmt "-"
|
||||
| Mult _ -> Format.fprintf fmt "*"
|
||||
| Div KInt -> Format.fprintf fmt "//"
|
||||
| Div _ -> Format.fprintf fmt "/"
|
||||
| And -> Format.fprintf fmt "and"
|
||||
| Or -> Format.fprintf fmt "or"
|
||||
| Eq -> Format.fprintf fmt "=="
|
||||
| Neq | Xor -> Format.fprintf fmt "!="
|
||||
| Lt _ -> Format.fprintf fmt "<"
|
||||
| Lte _ -> Format.fprintf fmt "<="
|
||||
| Gt _ -> Format.fprintf fmt ">"
|
||||
| Gte _ -> Format.fprintf fmt ">="
|
||||
| Map -> Format.fprintf fmt "list_map"
|
||||
| Filter -> Format.fprintf fmt "list_filter"
|
||||
|
||||
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
|
||||
match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left"
|
||||
| Log (entry, infos) -> assert false
|
||||
| Minus_int | Minus_rat | Minus_mon | Minus_dur ->
|
||||
Format.pp_print_string fmt "-"
|
||||
(* Todo: use the names from [Operator.name] *)
|
||||
| Not -> Format.pp_print_string fmt "not"
|
||||
| Length -> Format.pp_print_string fmt "list_length"
|
||||
| ToRat_int -> Format.pp_print_string fmt "decimal_of_integer"
|
||||
| ToRat_mon -> Format.pp_print_string fmt "decimal_of_money"
|
||||
| ToMoney_rat -> Format.pp_print_string fmt "money_of_decimal"
|
||||
| GetDay -> Format.pp_print_string fmt "day_of_month_of_date"
|
||||
| GetMonth -> Format.pp_print_string fmt "month_number_of_date"
|
||||
| GetYear -> Format.pp_print_string fmt "year_of_date"
|
||||
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
|
||||
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
|
||||
| Round_mon -> Format.pp_print_string fmt "money_round"
|
||||
| Round_rat -> Format.pp_print_string fmt "decimal_round"
|
||||
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur | Concat
|
||||
->
|
||||
Format.pp_print_string fmt "+"
|
||||
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
|
||||
| Sub_dur_dur ->
|
||||
Format.pp_print_string fmt "-"
|
||||
| Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int ->
|
||||
Format.pp_print_string fmt "*"
|
||||
| Div_int_int -> Format.pp_print_string fmt "//"
|
||||
| Div_rat_rat | Div_mon_mon | Div_mon_rat -> Format.pp_print_string fmt "/"
|
||||
| And -> Format.pp_print_string fmt "and"
|
||||
| Or -> Format.pp_print_string fmt "or"
|
||||
| Eq -> Format.pp_print_string fmt "=="
|
||||
| Xor -> Format.pp_print_string fmt "!="
|
||||
| Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur ->
|
||||
Format.pp_print_string fmt "<"
|
||||
| Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur ->
|
||||
Format.pp_print_string fmt "<="
|
||||
| Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur ->
|
||||
Format.pp_print_string fmt ">"
|
||||
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur ->
|
||||
Format.pp_print_string fmt ">="
|
||||
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
|
||||
Format.pp_print_string fmt "=="
|
||||
| Map -> Format.pp_print_string fmt "list_map"
|
||||
| Reduce -> Format.pp_print_string fmt "list_reduce"
|
||||
| Filter -> Format.pp_print_string fmt "list_filter"
|
||||
| Fold -> Format.pp_print_string fmt "list_fold_left"
|
||||
|
||||
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
|
||||
: unit =
|
||||
@ -77,7 +105,7 @@ let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt info ->
|
||||
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info))
|
||||
Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info))
|
||||
uids
|
||||
|
||||
let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
|
||||
@ -90,23 +118,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
|
||||
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
|
||||
uids
|
||||
|
||||
let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit =
|
||||
match Marked.unmark op with
|
||||
| Minus _ -> Format.fprintf fmt "-"
|
||||
| Not -> Format.fprintf fmt "not"
|
||||
| Log (entry, infos) -> assert false (* should not happen *)
|
||||
| Length -> Format.fprintf fmt "%s" "list_length"
|
||||
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
|
||||
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
|
||||
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
|
||||
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
|
||||
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
|
||||
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
|
||||
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
|
||||
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
|
||||
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
|
||||
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
|
||||
|
||||
let avoid_keywords (s : string) : string =
|
||||
if
|
||||
match s with
|
||||
@ -125,24 +136,26 @@ let avoid_keywords (s : string) : string =
|
||||
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(to_camel_case (to_ascii (Format.asprintf "%a" StructName.format_t v))))
|
||||
(String.to_camel_case
|
||||
(String.to_ascii (Format.asprintf "%a" StructName.format_t v))))
|
||||
|
||||
let format_struct_field_name (fmt : Format.formatter) (v : StructFieldName.t) :
|
||||
unit =
|
||||
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
|
||||
=
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(to_ascii (Format.asprintf "%a" StructFieldName.format_t v)))
|
||||
(String.to_ascii (Format.asprintf "%a" StructField.format_t v)))
|
||||
|
||||
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(to_camel_case (to_ascii (Format.asprintf "%a" EnumName.format_t v))))
|
||||
(String.to_camel_case
|
||||
(String.to_ascii (Format.asprintf "%a" EnumName.format_t v))))
|
||||
|
||||
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
|
||||
unit =
|
||||
Format.fprintf fmt "%s"
|
||||
(avoid_keywords
|
||||
(to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
|
||||
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
|
||||
|
||||
let typ_needs_parens (e : typ) : bool =
|
||||
match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false
|
||||
@ -180,10 +193,10 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
|
||||
|
||||
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
|
||||
s
|
||||
|> to_ascii
|
||||
|> to_snake_case
|
||||
|> String.to_ascii
|
||||
|> String.to_snake_case
|
||||
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_")
|
||||
|> to_ascii
|
||||
|> String.to_ascii
|
||||
|> avoid_keywords
|
||||
|> Format.fprintf fmt "%s"
|
||||
|
||||
@ -268,10 +281,11 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
Format.fprintf fmt "%a(%a)" format_struct_name s
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt (e, struct_field) ->
|
||||
(fun fmt (e, (struct_field, _)) ->
|
||||
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
|
||||
(format_expression ctx) e))
|
||||
(List.combine es (List.map fst (StructMap.find s ctx.ctx_structs)))
|
||||
(List.combine es
|
||||
(StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs)))
|
||||
| EStructFieldAccess (e1, field, _) ->
|
||||
Format.fprintf fmt "%a.%a" (format_expression ctx) e1
|
||||
format_struct_field_name field
|
||||
@ -296,21 +310,20 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
|
||||
es
|
||||
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e)
|
||||
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos)
|
||||
| EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "%a(%a,@ %a)" format_op (op, Pos.no_pos)
|
||||
(format_expression ctx) arg1 (format_expression ctx) arg2
|
||||
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop
|
||||
| EApp ((EOp op, _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_op
|
||||
(op, Pos.no_pos) (format_expression ctx) arg2
|
||||
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
|
||||
| EApp ((EApp ((EOp (Log (BeginCall, info)), _), [f]), _), [arg])
|
||||
when !Cli.trace_flag ->
|
||||
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
|
||||
(format_expression ctx) f (format_expression ctx) arg
|
||||
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag
|
||||
->
|
||||
| EApp ((EOp (Log (VarDef tau, info)), _), [arg1]) when !Cli.trace_flag ->
|
||||
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
|
||||
(format_expression ctx) arg1
|
||||
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), pos), [arg1])
|
||||
| EApp ((EOp (Log (PosRecordIfTrueBool, _)), pos), [arg1])
|
||||
when !Cli.trace_flag ->
|
||||
Format.fprintf fmt
|
||||
"log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \
|
||||
@ -318,16 +331,21 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
|
||||
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
|
||||
(Pos.get_law_info pos) (format_expression ctx) arg1
|
||||
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag ->
|
||||
| EApp ((EOp (Log (EndCall, info)), _), [arg1]) when !Cli.trace_flag ->
|
||||
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info
|
||||
(format_expression ctx) arg1
|
||||
| EApp ((EOp (Unop (Log _)), _), [arg1]) ->
|
||||
| EApp ((EOp (Log _), _), [arg1]) ->
|
||||
Format.fprintf fmt "%a" (format_expression ctx) arg1
|
||||
| EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) ->
|
||||
Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos)
|
||||
| EApp ((EOp Not, _), [arg1]) ->
|
||||
Format.fprintf fmt "%a %a" format_op (Not, Pos.no_pos)
|
||||
(format_expression ctx) arg1
|
||||
| EApp ((EOp (Unop op), _), [arg1]) ->
|
||||
Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos)
|
||||
| EApp
|
||||
((EOp ((Minus_int | Minus_rat | Minus_mon | Minus_dur) as op), _), [arg1])
|
||||
->
|
||||
Format.fprintf fmt "%a %a" format_op (op, Pos.no_pos)
|
||||
(format_expression ctx) arg1
|
||||
| EApp ((EOp op, _), [arg1]) ->
|
||||
Format.fprintf fmt "%a(%a)" format_op (op, Pos.no_pos)
|
||||
(format_expression ctx) arg1
|
||||
| EApp ((EFunc x, pos), args)
|
||||
when Ast.TopLevelName.compare x Ast.handle_default = 0
|
||||
@ -348,9 +366,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(format_expression ctx))
|
||||
args
|
||||
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
|
||||
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
|
||||
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
|
||||
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
|
||||
|
||||
let rec format_statement
|
||||
(ctx : decl_ctx)
|
||||
@ -400,7 +416,7 @@ let rec format_statement
|
||||
List.map2
|
||||
(fun (x, y) (cons, _) -> x, y, cons)
|
||||
cases
|
||||
(EnumMap.find e_name ctx.ctx_enums)
|
||||
(EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
|
||||
in
|
||||
let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in
|
||||
Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var
|
||||
@ -442,6 +458,7 @@ let format_ctx
|
||||
(fmt : Format.formatter)
|
||||
(ctx : decl_ctx) : unit =
|
||||
let format_struct_decl fmt (struct_name, struct_fields) =
|
||||
let fields = StructField.Map.bindings struct_fields in
|
||||
Format.fprintf fmt
|
||||
"class %a:@\n\
|
||||
\ def __init__(self, %a) -> None:@\n\
|
||||
@ -461,40 +478,41 @@ let format_ctx
|
||||
struct_name
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
|
||||
(fun _fmt (struct_field, struct_field_type) ->
|
||||
(fun fmt (struct_field, struct_field_type) ->
|
||||
Format.fprintf fmt "%a: %a" format_struct_field_name struct_field
|
||||
format_typ struct_field_type))
|
||||
struct_fields
|
||||
(if List.length struct_fields = 0 then fun fmt _ ->
|
||||
fields
|
||||
(if StructField.Map.is_empty struct_fields then fun fmt _ ->
|
||||
Format.fprintf fmt " pass"
|
||||
else
|
||||
Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun _fmt (struct_field, _) ->
|
||||
(fun fmt (struct_field, _) ->
|
||||
Format.fprintf fmt " self.%a = %a" format_struct_field_name
|
||||
struct_field format_struct_field_name struct_field))
|
||||
struct_fields format_struct_name struct_name
|
||||
(if List.length struct_fields > 0 then
|
||||
fields format_struct_name struct_name
|
||||
(if not (StructField.Map.is_empty struct_fields) then
|
||||
Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ")
|
||||
(fun _fmt (struct_field, _) ->
|
||||
(fun fmt (struct_field, _) ->
|
||||
Format.fprintf fmt "self.%a == other.%a" format_struct_field_name
|
||||
struct_field format_struct_field_name struct_field)
|
||||
else fun fmt _ -> Format.fprintf fmt "True")
|
||||
struct_fields format_struct_name struct_name
|
||||
fields format_struct_name struct_name
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",")
|
||||
(fun _fmt (struct_field, _) ->
|
||||
(fun fmt (struct_field, _) ->
|
||||
Format.fprintf fmt "%a={}" format_struct_field_name struct_field))
|
||||
struct_fields
|
||||
fields
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun _fmt (struct_field, _) ->
|
||||
(fun fmt (struct_field, _) ->
|
||||
Format.fprintf fmt "self.%a" format_struct_field_name struct_field))
|
||||
struct_fields
|
||||
fields
|
||||
in
|
||||
let format_enum_decl fmt (enum_name, enum_cons) =
|
||||
if List.length enum_cons = 0 then failwith "no constructors in the enum"
|
||||
if EnumConstructor.Map.is_empty enum_cons then
|
||||
failwith "no constructors in the enum"
|
||||
else
|
||||
Format.fprintf fmt
|
||||
"@[<hov 4>class %a_Code(Enum):@\n\
|
||||
@ -522,9 +540,11 @@ let format_ctx
|
||||
format_enum_name enum_name
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun _fmt (i, enum_cons, enum_cons_type) ->
|
||||
(fun fmt (i, enum_cons, enum_cons_type) ->
|
||||
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
|
||||
(List.mapi (fun i (x, y) -> i, x, y) enum_cons)
|
||||
(List.mapi
|
||||
(fun i (x, y) -> i, x, y)
|
||||
(EnumConstructor.Map.bindings enum_cons))
|
||||
format_enum_name enum_name format_enum_name enum_name format_enum_name
|
||||
enum_name
|
||||
in
|
||||
@ -540,8 +560,8 @@ let format_ctx
|
||||
let scope_structs =
|
||||
List.map
|
||||
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
|
||||
(StructMap.bindings
|
||||
(StructMap.filter
|
||||
(StructName.Map.bindings
|
||||
(StructName.Map.filter
|
||||
(fun s _ -> not (is_in_type_ordering s))
|
||||
ctx.ctx_structs))
|
||||
in
|
||||
@ -550,10 +570,10 @@ let format_ctx
|
||||
match struct_or_enum with
|
||||
| Scopelang.Dependency.TVertex.Struct s ->
|
||||
Format.fprintf fmt "%a@\n@\n" format_struct_decl
|
||||
(s, StructMap.find s ctx.ctx_structs)
|
||||
(s, StructName.Map.find s ctx.ctx_structs)
|
||||
| Scopelang.Dependency.TVertex.Enum e ->
|
||||
Format.fprintf fmt "%a@\n@\n" format_enum_decl
|
||||
(e, EnumMap.find e ctx.ctx_enums))
|
||||
(e, EnumName.Map.find e ctx.ctx_enums))
|
||||
(type_ordering @ scope_structs)
|
||||
|
||||
let format_program
|
||||
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
type location = scopelang glocation
|
||||
@ -31,7 +31,7 @@ type 'm expr = (scopelang, 'm mark) gexpr
|
||||
let rec locations_used (e : 'm expr) : LocationSet.t =
|
||||
match e with
|
||||
| ELocation l, pos -> LocationSet.singleton (l, Expr.mark_pos pos)
|
||||
| EAbs (binder, _), _ ->
|
||||
| EAbs { binder; _ }, _ ->
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
locations_used body
|
||||
| e ->
|
||||
@ -39,23 +39,20 @@ let rec locations_used (e : 'm expr) : LocationSet.t =
|
||||
(fun e -> LocationSet.union (locations_used e))
|
||||
e LocationSet.empty
|
||||
|
||||
type io_input = NoInput | OnlyInput | Reentrant
|
||||
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
|
||||
|
||||
type 'm rule =
|
||||
| Definition of location Marked.pos * typ * io * 'm expr
|
||||
| Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr
|
||||
| Assertion of 'm expr
|
||||
| Call of ScopeName.t * SubScopeName.t * 'm mark
|
||||
|
||||
type 'm scope_decl = {
|
||||
scope_decl_name : ScopeName.t;
|
||||
scope_sig : (typ * io) ScopeVarMap.t;
|
||||
scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t;
|
||||
scope_decl_rules : 'm rule list;
|
||||
scope_mark : 'm mark;
|
||||
}
|
||||
|
||||
type 'm program = {
|
||||
program_scopes : 'm scope_decl ScopeMap.t;
|
||||
program_scopes : 'm scope_decl ScopeName.Map.t;
|
||||
program_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
@ -73,17 +70,17 @@ let type_rule decl_ctx env = function
|
||||
|
||||
let type_program (prg : 'm program) : typed program =
|
||||
let typing_env =
|
||||
ScopeMap.fold
|
||||
ScopeName.Map.fold
|
||||
(fun scope_name scope_decl ->
|
||||
let vars = ScopeVarMap.map fst scope_decl.scope_sig in
|
||||
let vars = ScopeVar.Map.map fst scope_decl.scope_sig in
|
||||
Typing.Env.add_scope scope_name ~vars)
|
||||
prg.program_scopes Typing.Env.empty
|
||||
in
|
||||
let program_scopes =
|
||||
ScopeMap.map
|
||||
ScopeName.Map.map
|
||||
(fun scope_decl ->
|
||||
let typing_env =
|
||||
ScopeVarMap.fold
|
||||
ScopeVar.Map.fold
|
||||
(fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env)
|
||||
scope_decl.scope_sig typing_env
|
||||
in
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
(** Abstract syntax tree of the scope language *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Identifiers} *)
|
||||
@ -31,41 +31,20 @@ type 'm expr = (scopelang, 'm mark) gexpr
|
||||
|
||||
val locations_used : 'm expr -> LocationSet.t
|
||||
|
||||
(** This type characterizes the three levels of visibility for a given scope
|
||||
variable with regards to the scope's input and possible redefinitions inside
|
||||
the scope.. *)
|
||||
type io_input =
|
||||
| NoInput
|
||||
(** For an internal variable defined only in the scope, and does not
|
||||
appear in the input. *)
|
||||
| OnlyInput
|
||||
(** For variables that should not be redefined in the scope, because they
|
||||
appear in the input. *)
|
||||
| Reentrant
|
||||
(** For variables defined in the scope that can also be redefined by the
|
||||
caller as they appear in the input. *)
|
||||
|
||||
type io = {
|
||||
io_output : bool Marked.pos;
|
||||
(** [true] is present in the output of the scope. *)
|
||||
io_input : io_input Marked.pos;
|
||||
}
|
||||
(** Characterization of the input/output status of a scope variable. *)
|
||||
|
||||
type 'm rule =
|
||||
| Definition of location Marked.pos * typ * io * 'm expr
|
||||
| Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr
|
||||
| Assertion of 'm expr
|
||||
| Call of ScopeName.t * SubScopeName.t * 'm mark
|
||||
|
||||
type 'm scope_decl = {
|
||||
scope_decl_name : ScopeName.t;
|
||||
scope_sig : (typ * io) ScopeVarMap.t;
|
||||
scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t;
|
||||
scope_decl_rules : 'm rule list;
|
||||
scope_mark : 'm mark;
|
||||
}
|
||||
|
||||
type 'm program = {
|
||||
program_scopes : 'm scope_decl ScopeMap.t;
|
||||
program_scopes : 'm scope_decl ScopeName.Map.t;
|
||||
program_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
(** Graph representation of the dependencies between scopes in the Catala
|
||||
program. Vertices are functions, x -> y if x is used in the definition of y. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
module SVertex = ScopeName
|
||||
|
||||
@ -41,13 +41,13 @@ module SSCC = Graph.Components.Make (SDependencies)
|
||||
let rec expr_used_scopes e =
|
||||
let recurse_subterms e =
|
||||
Expr.shallow_fold
|
||||
(fun e -> ScopeMap.union (fun _ x _ -> Some x) (expr_used_scopes e))
|
||||
e ScopeMap.empty
|
||||
(fun e -> ScopeName.Map.union (fun _ x _ -> Some x) (expr_used_scopes e))
|
||||
e ScopeName.Map.empty
|
||||
in
|
||||
match e with
|
||||
| (EScopeCall (scope, _), m) as e ->
|
||||
ScopeMap.add scope (Expr.mark_pos m) (recurse_subterms e)
|
||||
| EAbs (binder, _), _ ->
|
||||
| (EScopeCall { scope; _ }, m) as e ->
|
||||
ScopeName.Map.add scope (Expr.mark_pos m) (recurse_subterms e)
|
||||
| EAbs { binder; _ }, _ ->
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
expr_used_scopes body
|
||||
| e -> recurse_subterms e
|
||||
@ -58,28 +58,28 @@ let rule_used_scopes = function
|
||||
walking through all exprs again *)
|
||||
expr_used_scopes e
|
||||
| Ast.Call (subscope, subindex, _) ->
|
||||
ScopeMap.singleton subscope
|
||||
ScopeName.Map.singleton subscope
|
||||
(Marked.get_mark (SubScopeName.get_info subindex))
|
||||
|
||||
let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t =
|
||||
let g = SDependencies.empty in
|
||||
let g =
|
||||
ScopeMap.fold
|
||||
ScopeName.Map.fold
|
||||
(fun v _ g -> SDependencies.add_vertex g v)
|
||||
prgm.program_scopes g
|
||||
in
|
||||
ScopeMap.fold
|
||||
ScopeName.Map.fold
|
||||
(fun scope_name scope g ->
|
||||
List.fold_left
|
||||
(fun g rule ->
|
||||
let used_scopes = rule_used_scopes rule in
|
||||
if ScopeMap.mem scope_name used_scopes then
|
||||
if ScopeName.Map.mem scope_name used_scopes then
|
||||
Errors.raise_spanned_error
|
||||
(Marked.get_mark (ScopeName.get_info scope.Ast.scope_decl_name))
|
||||
"The scope %a is calling into itself as a subscope, which is \
|
||||
forbidden since Catala does not provide recursion"
|
||||
ScopeName.format_t scope.Ast.scope_decl_name;
|
||||
ScopeMap.fold
|
||||
ScopeName.Map.fold
|
||||
(fun used_scope pos g ->
|
||||
let edge = SDependencies.E.create used_scope pos scope_name in
|
||||
SDependencies.add_edge_e g edge)
|
||||
@ -190,10 +190,10 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
|
||||
=
|
||||
let g = TDependencies.empty in
|
||||
let g =
|
||||
StructMap.fold
|
||||
StructName.Map.fold
|
||||
(fun s fields g ->
|
||||
List.fold_left
|
||||
(fun g (_, typ) ->
|
||||
StructField.Map.fold
|
||||
(fun _ typ g ->
|
||||
let def = TVertex.Struct s in
|
||||
let g = TDependencies.add_vertex g def in
|
||||
let used = get_structs_or_enums_in_type typ in
|
||||
@ -210,14 +210,14 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
|
||||
in
|
||||
TDependencies.add_edge_e g edge)
|
||||
used g)
|
||||
g fields)
|
||||
fields g)
|
||||
structs g
|
||||
in
|
||||
let g =
|
||||
EnumMap.fold
|
||||
EnumName.Map.fold
|
||||
(fun e cases g ->
|
||||
List.fold_left
|
||||
(fun g (_, typ) ->
|
||||
EnumConstructor.Map.fold
|
||||
(fun _ typ g ->
|
||||
let def = TVertex.Enum e in
|
||||
let g = TDependencies.add_vertex g def in
|
||||
let used = get_structs_or_enums_in_type typ in
|
||||
@ -234,7 +234,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
|
||||
in
|
||||
TDependencies.add_edge_e g edge)
|
||||
used g)
|
||||
g cases)
|
||||
cases g)
|
||||
enums g
|
||||
in
|
||||
g
|
||||
|
@ -17,7 +17,7 @@
|
||||
(** Graph representation of the dependencies between scopes in the Catala
|
||||
program. Vertices are functions, x -> y if x is used in the definition of y. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Scope dependencies} *)
|
||||
|
@ -1,7 +1,7 @@
|
||||
(library
|
||||
(name scopelang)
|
||||
(public_name catala.scopelang)
|
||||
(libraries utils dcalc ocamlgraph)
|
||||
(libraries catala_utils ocamlgraph desugared)
|
||||
(flags
|
||||
(:standard -short-paths)))
|
||||
|
||||
|
730
compiler/scopelang/from_desugared.ml
Normal file
730
compiler/scopelang/from_desugared.ml
Normal file
@ -0,0 +1,730 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Denis Merigoux <denis.merigoux@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
|
||||
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
|
||||
(** {1 Expression translation}*)
|
||||
|
||||
type target_scope_vars =
|
||||
| WholeVar of ScopeVar.t
|
||||
| States of (StateName.t * ScopeVar.t) list
|
||||
|
||||
type ctx = {
|
||||
decl_ctx : decl_ctx;
|
||||
scope_var_mapping : target_scope_vars ScopeVar.Map.t;
|
||||
var_mapping : (Desugared.Ast.expr, untyped Ast.expr Var.t) Var.Map.t;
|
||||
}
|
||||
|
||||
let tag_with_log_entry
|
||||
(e : untyped Ast.expr boxed)
|
||||
(l : log_entry)
|
||||
(markings : Uid.MarkedString.info list) : untyped Ast.expr boxed =
|
||||
Expr.eapp
|
||||
(Expr.eop (Log (l, markings)) [TAny, Expr.pos e] (Marked.get_mark e))
|
||||
[e] (Marked.get_mark e)
|
||||
|
||||
let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) :
|
||||
untyped Ast.expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
|
||||
(* When referring to a subscope variable in an expression, we are referring
|
||||
to the output, hence we take the last state. *)
|
||||
let new_s_var =
|
||||
match ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping with
|
||||
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
|
||||
| States states ->
|
||||
Marked.same_mark_as (snd (List.hd (List.rev states))) s_var
|
||||
in
|
||||
Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m
|
||||
| ELocation (DesugaredScopeVar (s_var, None)) ->
|
||||
Expr.elocation
|
||||
(ScopelangScopeVar
|
||||
(match
|
||||
ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
|
||||
| States _ -> failwith "should not happen"))
|
||||
m
|
||||
| ELocation (DesugaredScopeVar (s_var, Some state)) ->
|
||||
Expr.elocation
|
||||
(ScopelangScopeVar
|
||||
(match
|
||||
ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar _ -> failwith "should not happen"
|
||||
| States states -> Marked.same_mark_as (List.assoc state states) s_var))
|
||||
m
|
||||
| EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m
|
||||
| EStruct { name; fields } ->
|
||||
Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m
|
||||
| EDStructAccess { name_opt = None; _ } ->
|
||||
(* Note: this could only happen if disambiguation was disabled. If we want
|
||||
to support it, we should still allow this case when the field has only
|
||||
one possible matching structure *)
|
||||
Errors.raise_spanned_error (Expr.mark_pos m)
|
||||
"Ambiguous structure field access"
|
||||
| EDStructAccess { e; field; name_opt = Some name } ->
|
||||
let e' = translate_expr ctx e in
|
||||
let field =
|
||||
try
|
||||
StructName.Map.find name
|
||||
(IdentName.Map.find field ctx.decl_ctx.ctx_struct_fields)
|
||||
with Not_found ->
|
||||
(* Should not happen after disambiguation *)
|
||||
Errors.raise_spanned_error (Expr.mark_pos m)
|
||||
"Field %s does not belong to structure %a" field StructName.format_t
|
||||
name
|
||||
in
|
||||
Expr.estructaccess e' field name m
|
||||
| EInj { e; cons; name } -> Expr.einj (translate_expr ctx e) cons name m
|
||||
| EMatch { e; name; cases } ->
|
||||
Expr.ematch (translate_expr ctx e) name
|
||||
(EnumConstructor.Map.map (translate_expr ctx) cases)
|
||||
m
|
||||
| EScopeCall { scope; args } ->
|
||||
Expr.escopecall scope
|
||||
(ScopeVar.Map.fold
|
||||
(fun v e args' ->
|
||||
let v' =
|
||||
match ScopeVar.Map.find v ctx.scope_var_mapping with
|
||||
| WholeVar v' -> v'
|
||||
| States ((_, v') :: _) ->
|
||||
(* When there are multiple states, the input is always the first
|
||||
one *)
|
||||
v'
|
||||
| States [] -> assert false
|
||||
in
|
||||
ScopeVar.Map.add v' (translate_expr ctx e) args')
|
||||
args ScopeVar.Map.empty)
|
||||
m
|
||||
| ELit
|
||||
(( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
|
||||
| LDuration _ ) as l) ->
|
||||
Expr.elit l m
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in
|
||||
let ctx =
|
||||
List.fold_left2
|
||||
(fun ctx var new_var ->
|
||||
{ ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping })
|
||||
ctx (Array.to_list vars) (Array.to_list new_vars)
|
||||
in
|
||||
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m
|
||||
| EApp { f = EOp { op; tys }, m1; args } ->
|
||||
let args = List.map (translate_expr ctx) args in
|
||||
Operator.kind_dispatch op
|
||||
~monomorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
|
||||
~polymorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
|
||||
~overloaded:(fun op ->
|
||||
match
|
||||
Operator.resolve_overload ctx.decl_ctx
|
||||
(Marked.mark (Expr.pos e) op)
|
||||
tys
|
||||
with
|
||||
| op, `Straight -> Expr.eapp (Expr.eop op tys m1) args m
|
||||
| op, `Reversed ->
|
||||
Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m)
|
||||
| EOp _ -> assert false (* Only allowed within [EApp] *)
|
||||
| EApp { f; args } ->
|
||||
Expr.eapp (translate_expr ctx f) (List.map (translate_expr ctx) args) m
|
||||
| EDefault { excepts; just; cons } ->
|
||||
Expr.edefault
|
||||
(List.map (translate_expr ctx) excepts)
|
||||
(translate_expr ctx just) (translate_expr ctx cons) m
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
|
||||
(translate_expr ctx efalse)
|
||||
m
|
||||
| EArray args -> Expr.earray (List.map (translate_expr ctx) args) m
|
||||
| EErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m
|
||||
|
||||
(** {1 Rule tree construction} *)
|
||||
|
||||
(** Intermediate representation for the exception tree of rules for a particular
|
||||
scope definition. *)
|
||||
type rule_tree =
|
||||
| Leaf of Desugared.Ast.rule list
|
||||
(** Rules defining a base case piecewise. List is non-empty. *)
|
||||
| Node of rule_tree list * Desugared.Ast.rule list
|
||||
(** [Node (exceptions, base_case)] is a list of exceptions to a non-empty
|
||||
list of rules defining a base case piecewise. *)
|
||||
|
||||
(** Transforms a flat list of rules into a tree, taking into account the
|
||||
priorities declared between rules *)
|
||||
let def_map_to_tree
|
||||
(def_info : Desugared.Ast.ScopeDef.t)
|
||||
(def : Desugared.Ast.rule RuleName.Map.t) : rule_tree list =
|
||||
let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in
|
||||
Desugared.Dependency.check_for_exception_cycle exc_graph;
|
||||
(* we start by the base cases: they are the vertices which have no
|
||||
successors *)
|
||||
let base_cases =
|
||||
Desugared.Dependency.ExceptionsDependencies.fold_vertex
|
||||
(fun v base_cases ->
|
||||
if
|
||||
Desugared.Dependency.ExceptionsDependencies.out_degree exc_graph v = 0
|
||||
then v :: base_cases
|
||||
else base_cases)
|
||||
exc_graph []
|
||||
in
|
||||
let rec build_tree (base_cases : RuleName.Set.t) : rule_tree =
|
||||
let exceptions =
|
||||
Desugared.Dependency.ExceptionsDependencies.pred exc_graph base_cases
|
||||
in
|
||||
let base_case_as_rule_list =
|
||||
List.map
|
||||
(fun r -> RuleName.Map.find r def)
|
||||
(RuleName.Set.elements base_cases)
|
||||
in
|
||||
match exceptions with
|
||||
| [] -> Leaf base_case_as_rule_list
|
||||
| _ -> Node (List.map build_tree exceptions, base_case_as_rule_list)
|
||||
in
|
||||
List.map build_tree base_cases
|
||||
|
||||
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault}
|
||||
expression in the scope language. The [~toplevel] parameter is used to know
|
||||
when to place the toplevel binding in the case of functions. *)
|
||||
let rec rule_tree_to_expr
|
||||
~(toplevel : bool)
|
||||
~(is_reentrant_var : bool)
|
||||
(ctx : ctx)
|
||||
(def_pos : Pos.t)
|
||||
(is_func : Desugared.Ast.expr Var.t option)
|
||||
(tree : rule_tree) : untyped Ast.expr boxed =
|
||||
let emark = Untyped { pos = def_pos } in
|
||||
let exceptions, base_rules =
|
||||
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
|
||||
in
|
||||
(* because each rule has its own variable parameter and we want to convert the
|
||||
whole rule tree into a function, we need to perform some alpha-renaming of
|
||||
all the expressions *)
|
||||
let substitute_parameter
|
||||
(e : Desugared.Ast.expr boxed)
|
||||
(rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed =
|
||||
match is_func, rule.Desugared.Ast.rule_parameter with
|
||||
| Some new_param, Some (old_param, _) ->
|
||||
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
|
||||
Marked.mark (Marked.get_mark e)
|
||||
@@ Bindlib.box_apply2
|
||||
(fun binder new_param -> Bindlib.subst binder new_param)
|
||||
binder
|
||||
(Bindlib.box_var new_param)
|
||||
| None, None -> e
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
in
|
||||
let ctx =
|
||||
match is_func with
|
||||
| None -> ctx
|
||||
| Some new_param -> (
|
||||
match Var.Map.find_opt new_param ctx.var_mapping with
|
||||
| None ->
|
||||
let new_param_scope = Var.make (Bindlib.name_of new_param) in
|
||||
{
|
||||
ctx with
|
||||
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
|
||||
}
|
||||
| Some _ ->
|
||||
(* We only create a mapping if none exists because [rule_tree_to_expr]
|
||||
is called recursively on the exceptions of the tree and we don't want
|
||||
to create a new Scopelang variable for the parameter at each tree
|
||||
level. *)
|
||||
ctx)
|
||||
in
|
||||
let base_just_list =
|
||||
List.map
|
||||
(fun rule -> substitute_parameter rule.Desugared.Ast.rule_just rule)
|
||||
base_rules
|
||||
in
|
||||
let base_cons_list =
|
||||
List.map
|
||||
(fun rule -> substitute_parameter rule.Desugared.Ast.rule_cons rule)
|
||||
base_rules
|
||||
in
|
||||
let translate_and_unbox_list (list : Desugared.Ast.expr boxed list) :
|
||||
untyped Ast.expr boxed list =
|
||||
List.map
|
||||
(fun e ->
|
||||
(* There are two levels of boxing here, the outermost is introduced by
|
||||
the [translate_expr] function for which all of the bindings should
|
||||
have been closed by now, so we can safely unbox. *)
|
||||
translate_expr ctx (Expr.unbox e))
|
||||
list
|
||||
in
|
||||
let default_containing_base_cases =
|
||||
Expr.make_default
|
||||
(List.map2
|
||||
(fun base_just base_cons ->
|
||||
Expr.make_default []
|
||||
(* Here we insert the logging command that records when a decision
|
||||
is taken for the value of a variable. *)
|
||||
(tag_with_log_entry base_just PosRecordIfTrueBool [])
|
||||
base_cons emark)
|
||||
(translate_and_unbox_list base_just_list)
|
||||
(translate_and_unbox_list base_cons_list))
|
||||
(Expr.elit (LBool false) emark)
|
||||
(Expr.elit LEmptyError emark)
|
||||
emark
|
||||
in
|
||||
let exceptions =
|
||||
List.map
|
||||
(rule_tree_to_expr ~toplevel:false ~is_reentrant_var ctx def_pos is_func)
|
||||
exceptions
|
||||
in
|
||||
let default =
|
||||
Expr.make_default exceptions
|
||||
(Expr.elit (LBool true) emark)
|
||||
default_containing_base_cases emark
|
||||
in
|
||||
match is_func, (List.hd base_rules).Desugared.Ast.rule_parameter with
|
||||
| None, None -> default
|
||||
| Some new_param, Some (_, typ) ->
|
||||
if toplevel then
|
||||
(* When we're creating a function from multiple defaults, we must check
|
||||
that the result returned by the function is not empty, unless we're
|
||||
dealing with a context variable which is reentrant (either in the
|
||||
caller or callee). In this case the ErrorOnEmpty will be added later in
|
||||
the scopelang->dcalc translation. *)
|
||||
let default =
|
||||
if is_reentrant_var then default else Expr.eerroronempty default emark
|
||||
in
|
||||
Expr.make_abs
|
||||
[| Var.Map.find new_param ctx.var_mapping |]
|
||||
default [typ] def_pos
|
||||
else default
|
||||
| _ -> (* should not happen *) assert false
|
||||
|
||||
(** {1 AST translation} *)
|
||||
|
||||
(** Translates a definition inside a scope, the resulting expression should be
|
||||
an {!constructor: Dcalc.EDefault} *)
|
||||
let translate_def
|
||||
(ctx : ctx)
|
||||
(def_info : Desugared.Ast.ScopeDef.t)
|
||||
(def : Desugared.Ast.rule RuleName.Map.t)
|
||||
(typ : typ)
|
||||
(io : Desugared.Ast.io)
|
||||
~(is_cond : bool)
|
||||
~(is_subscope_var : bool) : untyped Ast.expr boxed =
|
||||
(* Here, we have to transform this list of rules into a default tree. *)
|
||||
let is_def_func =
|
||||
match Marked.unmark typ with TArrow (_, _) -> true | _ -> false
|
||||
in
|
||||
let is_rule_func _ (r : Desugared.Ast.rule) : bool =
|
||||
Option.is_some r.Desugared.Ast.rule_parameter
|
||||
in
|
||||
let all_rules_func = RuleName.Map.for_all is_rule_func def in
|
||||
let all_rules_not_func =
|
||||
RuleName.Map.for_all (fun n r -> not (is_rule_func n r)) def
|
||||
in
|
||||
let is_def_func_param_typ : typ option =
|
||||
if is_def_func && all_rules_func then
|
||||
match Marked.unmark typ with
|
||||
| TArrow (t_param, _) -> Some t_param
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Marked.get_mark typ)
|
||||
"The definitions of %a are function but it doesn't have a function \
|
||||
type"
|
||||
Desugared.Ast.ScopeDef.format_t def_info
|
||||
else if (not is_def_func) && all_rules_not_func then None
|
||||
else
|
||||
let spans =
|
||||
List.map
|
||||
(fun (_, r) ->
|
||||
( Some "This definition is a function:",
|
||||
Expr.pos r.Desugared.Ast.rule_cons ))
|
||||
(RuleName.Map.bindings (RuleName.Map.filter is_rule_func def))
|
||||
@ List.map
|
||||
(fun (_, r) ->
|
||||
( Some "This definition is not a function:",
|
||||
Expr.pos r.Desugared.Ast.rule_cons ))
|
||||
(RuleName.Map.bindings
|
||||
(RuleName.Map.filter (fun n r -> not (is_rule_func n r)) def))
|
||||
in
|
||||
Errors.raise_multispanned_error spans
|
||||
"some definitions of the same variable are functions while others \
|
||||
aren't"
|
||||
in
|
||||
let top_list = def_map_to_tree def_info def in
|
||||
let is_input =
|
||||
match Marked.unmark io.Desugared.Ast.io_input with
|
||||
| OnlyInput -> true
|
||||
| _ -> false
|
||||
in
|
||||
let is_reentrant =
|
||||
match Marked.unmark io.Desugared.Ast.io_input with
|
||||
| Reentrant -> true
|
||||
| _ -> false
|
||||
in
|
||||
let top_value =
|
||||
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
|
||||
(* We add the bottom [false] value for conditions, only for the scope
|
||||
where the condition is declared. Except when the variable is an input,
|
||||
where we want the [false] to be added at each caller parent scope. *)
|
||||
Some
|
||||
(Desugared.Ast.always_false_rule
|
||||
(Desugared.Ast.ScopeDef.get_position def_info)
|
||||
is_def_func_param_typ)
|
||||
else None
|
||||
in
|
||||
if
|
||||
RuleName.Map.cardinal def = 0
|
||||
&& is_subscope_var
|
||||
(* Here we have a special case for the empty definitions. Indeed, we could
|
||||
use the code for the regular case below that would create a convoluted
|
||||
default always returning empty error, and this would be correct. But it
|
||||
gets more complicated with functions. Indeed, if we create an empty
|
||||
definition for a subscope argument whose type is a function, we get
|
||||
something like [fun () -> (fun real_param -> < ... >)] that is passed as
|
||||
an argument to the subscope. The sub-scope de-thunks but the de-thunking
|
||||
does not return empty error, signalling there is not reentrant variable,
|
||||
because functions are values! So the subscope does not see that there is
|
||||
not reentrant variable and does not pick its internal definition instead.
|
||||
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
|
||||
test case exercising that subtlety.
|
||||
|
||||
To avoid this complication we special case here and put an empty error
|
||||
for all subscope variables that are not defined. It covers the subtlety
|
||||
with functions described above but also conditions with the false default
|
||||
value. *)
|
||||
&& not (is_cond && is_input)
|
||||
(* However, this special case suffers from an exception: when a condition is
|
||||
defined as an OnlyInput to a subscope, since the [false] default value
|
||||
will not be provided by the calee scope, it has to be placed in the
|
||||
caller. *)
|
||||
then
|
||||
let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
|
||||
let empty_error = Expr.elit LEmptyError m in
|
||||
match is_def_func_param_typ with
|
||||
| Some ty ->
|
||||
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
|
||||
| _ -> empty_error
|
||||
else
|
||||
rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
|
||||
(Desugared.Ast.ScopeDef.get_position def_info)
|
||||
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
|
||||
(match top_list, top_value with
|
||||
| [], None ->
|
||||
(* In this case, there are no rules to define the expression and no
|
||||
default value so we put an empty rule. *)
|
||||
Leaf
|
||||
[Desugared.Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ]
|
||||
| [], Some top_value ->
|
||||
(* In this case, there are no rules to define the expression but a
|
||||
default value so we put it. *)
|
||||
Leaf [top_value]
|
||||
| _, Some top_value ->
|
||||
(* When there are rules + a default value, we put the rules as
|
||||
exceptions to the default value *)
|
||||
Node (top_list, [top_value])
|
||||
| [top_tree], None -> top_tree
|
||||
| _, None ->
|
||||
Node
|
||||
( top_list,
|
||||
[
|
||||
Desugared.Ast.empty_rule (Marked.get_mark typ)
|
||||
is_def_func_param_typ;
|
||||
] ))
|
||||
|
||||
let translate_rule ctx (scope : Desugared.Ast.scope) = function
|
||||
| Desugared.Dependency.Vertex.Var (var, state) -> (
|
||||
let scope_def =
|
||||
Desugared.Ast.ScopeDefMap.find
|
||||
(Desugared.Ast.ScopeDef.Var (var, state))
|
||||
scope.scope_defs
|
||||
in
|
||||
let var_def = scope_def.scope_def_rules in
|
||||
let var_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with
|
||||
| OnlyInput when not (RuleName.Map.is_empty var_def) ->
|
||||
(* If the variable is tagged as input, then it shall not be redefined. *)
|
||||
Errors.raise_multispanned_error
|
||||
((Some "Incriminated variable:", Marked.get_mark (ScopeVar.get_info var))
|
||||
:: List.map
|
||||
(fun (rule, _) ->
|
||||
( Some "Incriminated variable definition:",
|
||||
Marked.get_mark (RuleName.get_info rule) ))
|
||||
(RuleName.Map.bindings var_def))
|
||||
"It is impossible to give a definition to a scope variable tagged as \
|
||||
input."
|
||||
| OnlyInput -> []
|
||||
(* we do not provide any definition for an input-only variable *)
|
||||
| _ ->
|
||||
let expr_def =
|
||||
translate_def ctx
|
||||
(Desugared.Ast.ScopeDef.Var (var, state))
|
||||
var_def var_typ scope_def.Desugared.Ast.scope_def_io ~is_cond
|
||||
~is_subscope_var:false
|
||||
in
|
||||
let scope_var =
|
||||
match ScopeVar.Map.find var ctx.scope_var_mapping, state with
|
||||
| WholeVar v, None -> v
|
||||
| States states, Some state -> List.assoc state states
|
||||
| _ -> failwith "should not happen"
|
||||
in
|
||||
[
|
||||
Ast.Definition
|
||||
( ( ScopelangScopeVar
|
||||
(scope_var, Marked.get_mark (ScopeVar.get_info scope_var)),
|
||||
Marked.get_mark (ScopeVar.get_info scope_var) ),
|
||||
var_typ,
|
||||
scope_def.Desugared.Ast.scope_def_io,
|
||||
Expr.unbox expr_def );
|
||||
])
|
||||
| Desugared.Dependency.Vertex.SubScope sub_scope_index ->
|
||||
(* Before calling the sub_scope, we need to include all the re-definitions
|
||||
of subscope parameters*)
|
||||
let sub_scope =
|
||||
SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes
|
||||
in
|
||||
let sub_scope_vars_redefs_candidates =
|
||||
Desugared.Ast.ScopeDefMap.filter
|
||||
(fun def_key scope_def ->
|
||||
match def_key with
|
||||
| Desugared.Ast.ScopeDef.Var _ -> false
|
||||
| Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) ->
|
||||
sub_scope_index = sub_scope_index'
|
||||
(* We exclude subscope variables that have 0 re-definitions and are
|
||||
not visible in the input of the subscope *)
|
||||
&& not
|
||||
((match
|
||||
Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input
|
||||
with
|
||||
| Desugared.Ast.NoInput -> true
|
||||
| _ -> false)
|
||||
&& RuleName.Map.is_empty scope_def.scope_def_rules))
|
||||
scope.scope_defs
|
||||
in
|
||||
let sub_scope_vars_redefs =
|
||||
Desugared.Ast.ScopeDefMap.mapi
|
||||
(fun def_key scope_def ->
|
||||
let def = scope_def.Desugared.Ast.scope_def_rules in
|
||||
let def_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
match def_key with
|
||||
| Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *)
|
||||
| Desugared.Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) ->
|
||||
(* This definition redefines a variable of the correct subscope. But
|
||||
we have to check that this redefinition is allowed with respect
|
||||
to the io parameters of that subscope variable. *)
|
||||
(match
|
||||
Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input
|
||||
with
|
||||
| Desugared.Ast.NoInput ->
|
||||
Errors.raise_multispanned_error
|
||||
(( Some "Incriminated subscope:",
|
||||
Marked.get_mark (SubScopeName.get_info sscope) )
|
||||
:: ( Some "Incriminated variable:",
|
||||
Marked.get_mark (ScopeVar.get_info sub_scope_var) )
|
||||
:: List.map
|
||||
(fun (rule, _) ->
|
||||
( Some "Incriminated subscope variable definition:",
|
||||
Marked.get_mark (RuleName.get_info rule) ))
|
||||
(RuleName.Map.bindings def))
|
||||
"It is impossible to give a definition to a subscope variable \
|
||||
not tagged as input or context."
|
||||
| OnlyInput when RuleName.Map.is_empty def && not is_cond ->
|
||||
(* If the subscope variable is tagged as input, then it shall be
|
||||
defined. *)
|
||||
Errors.raise_multispanned_error
|
||||
[
|
||||
( Some "Incriminated subscope:",
|
||||
Marked.get_mark (SubScopeName.get_info sscope) );
|
||||
Some "Incriminated variable:", pos;
|
||||
]
|
||||
"This subscope variable is a mandatory input but no definition \
|
||||
was provided."
|
||||
| _ -> ());
|
||||
(* Now that all is good, we can proceed with translating this
|
||||
redefinition to a proper Scopelang term. *)
|
||||
let expr_def =
|
||||
translate_def ctx def_key def def_typ
|
||||
scope_def.Desugared.Ast.scope_def_io ~is_cond
|
||||
~is_subscope_var:true
|
||||
in
|
||||
let subscop_real_name =
|
||||
SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes
|
||||
in
|
||||
let var_pos = Desugared.Ast.ScopeDef.get_position def_key in
|
||||
Ast.Definition
|
||||
( ( SubScopeVar
|
||||
( subscop_real_name,
|
||||
(sub_scope_index, var_pos),
|
||||
match
|
||||
ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping
|
||||
with
|
||||
| WholeVar v -> v, var_pos
|
||||
| States states ->
|
||||
(* When defining a sub-scope variable, we always define
|
||||
its first state in the sub-scope. *)
|
||||
snd (List.hd states), var_pos ),
|
||||
var_pos ),
|
||||
def_typ,
|
||||
scope_def.Desugared.Ast.scope_def_io,
|
||||
Expr.unbox expr_def ))
|
||||
sub_scope_vars_redefs_candidates
|
||||
in
|
||||
let sub_scope_vars_redefs =
|
||||
List.map snd (Desugared.Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
|
||||
in
|
||||
sub_scope_vars_redefs
|
||||
@ [
|
||||
Ast.Call
|
||||
( sub_scope,
|
||||
sub_scope_index,
|
||||
Untyped
|
||||
{ pos = Marked.get_mark (SubScopeName.get_info sub_scope_index) }
|
||||
);
|
||||
]
|
||||
|
||||
(** Translates a scope *)
|
||||
let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) :
|
||||
untyped Ast.scope_decl =
|
||||
let scope_dependencies =
|
||||
Desugared.Dependency.build_scope_dependencies scope
|
||||
in
|
||||
Desugared.Dependency.check_for_cycle scope scope_dependencies;
|
||||
let scope_ordering =
|
||||
Desugared.Dependency.correct_computation_ordering scope_dependencies
|
||||
in
|
||||
let scope_decl_rules =
|
||||
List.flatten (List.map (translate_rule ctx scope) scope_ordering)
|
||||
in
|
||||
(* Then, after having computed all the scopes variables, we add the
|
||||
assertions. TODO: the assertions should be interleaved with the
|
||||
definitions! *)
|
||||
let scope_decl_rules =
|
||||
scope_decl_rules
|
||||
@ List.map
|
||||
(fun e ->
|
||||
let scope_e = translate_expr ctx (Expr.unbox e) in
|
||||
Ast.Assertion (Expr.unbox scope_e))
|
||||
scope.Desugared.Ast.scope_assertions
|
||||
in
|
||||
let scope_sig =
|
||||
ScopeVar.Map.fold
|
||||
(fun var (states : Desugared.Ast.var_or_states) acc ->
|
||||
match states with
|
||||
| WholeVar ->
|
||||
let scope_def =
|
||||
Desugared.Ast.ScopeDefMap.find
|
||||
(Desugared.Ast.ScopeDef.Var (var, None))
|
||||
scope.scope_defs
|
||||
in
|
||||
let typ = scope_def.scope_def_typ in
|
||||
ScopeVar.Map.add
|
||||
(match ScopeVar.Map.find var ctx.scope_var_mapping with
|
||||
| WholeVar v -> v
|
||||
| States _ -> failwith "should not happen")
|
||||
(typ, scope_def.scope_def_io)
|
||||
acc
|
||||
| States states ->
|
||||
(* What happens in the case of variables with multiple states is
|
||||
interesting. We need to create as many Var entries in the scope
|
||||
signature as there are states. *)
|
||||
List.fold_left
|
||||
(fun acc (state : StateName.t) ->
|
||||
let scope_def =
|
||||
Desugared.Ast.ScopeDefMap.find
|
||||
(Desugared.Ast.ScopeDef.Var (var, Some state))
|
||||
scope.scope_defs
|
||||
in
|
||||
ScopeVar.Map.add
|
||||
(match ScopeVar.Map.find var ctx.scope_var_mapping with
|
||||
| WholeVar _ -> failwith "should not happen"
|
||||
| States states' -> List.assoc state states')
|
||||
(scope_def.scope_def_typ, scope_def.scope_def_io)
|
||||
acc)
|
||||
acc states)
|
||||
scope.scope_vars ScopeVar.Map.empty
|
||||
in
|
||||
let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in
|
||||
{
|
||||
Ast.scope_decl_name = scope.scope_uid;
|
||||
Ast.scope_decl_rules;
|
||||
Ast.scope_sig;
|
||||
Ast.scope_mark = Untyped { pos };
|
||||
}
|
||||
|
||||
(** {1 API} *)
|
||||
|
||||
let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program =
|
||||
(* First we give mappings to all the locations between Desugared and This
|
||||
involves creating a new Scopelang scope variable for every state of a
|
||||
Desugared variable. *)
|
||||
let ctx =
|
||||
(* Todo: since we rename all scope vars at this point, it would be better to
|
||||
have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *)
|
||||
ScopeName.Map.fold
|
||||
(fun _scope scope_decl ctx ->
|
||||
ScopeVar.Map.fold
|
||||
(fun scope_var (states : Desugared.Ast.var_or_states) ctx ->
|
||||
let var_name, var_pos = ScopeVar.get_info scope_var in
|
||||
let new_var =
|
||||
match states with
|
||||
| Desugared.Ast.WholeVar ->
|
||||
WholeVar (ScopeVar.fresh (var_name, var_pos))
|
||||
| States states ->
|
||||
let var_prefix = var_name ^ "_" in
|
||||
let state_var state =
|
||||
ScopeVar.fresh
|
||||
(Marked.map_under_mark (( ^ ) var_prefix)
|
||||
(StateName.get_info state))
|
||||
in
|
||||
States (List.map (fun state -> state, state_var state) states)
|
||||
in
|
||||
{
|
||||
ctx with
|
||||
scope_var_mapping =
|
||||
ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping;
|
||||
})
|
||||
scope_decl.Desugared.Ast.scope_vars ctx)
|
||||
pgrm.Desugared.Ast.program_scopes
|
||||
{
|
||||
scope_var_mapping = ScopeVar.Map.empty;
|
||||
var_mapping = Var.Map.empty;
|
||||
decl_ctx = pgrm.program_ctx;
|
||||
}
|
||||
in
|
||||
let ctx_scopes =
|
||||
ScopeName.Map.map
|
||||
(fun out_str ->
|
||||
let out_struct_fields =
|
||||
ScopeVar.Map.fold
|
||||
(fun var fld out_map ->
|
||||
let var' =
|
||||
match ScopeVar.Map.find var ctx.scope_var_mapping with
|
||||
| WholeVar v -> v
|
||||
| States l -> snd (List.hd (List.rev l))
|
||||
in
|
||||
ScopeVar.Map.add var' fld out_map)
|
||||
out_str.out_struct_fields ScopeVar.Map.empty
|
||||
in
|
||||
{ out_str with out_struct_fields })
|
||||
pgrm.Desugared.Ast.program_ctx.ctx_scopes
|
||||
in
|
||||
{
|
||||
Ast.program_scopes =
|
||||
ScopeName.Map.map (translate_scope ctx) pgrm.program_scopes;
|
||||
program_ctx = { pgrm.program_ctx with ctx_scopes };
|
||||
}
|
@ -16,4 +16,4 @@
|
||||
|
||||
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
|
||||
|
||||
val translate_program : Ast.program -> Shared_ast.untyped Scopelang.Ast.program
|
||||
val translate_program : Desugared.Ast.program -> Shared_ast.untyped Ast.program
|
@ -14,7 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Shared_ast
|
||||
open Ast
|
||||
|
||||
@ -22,21 +22,22 @@ let struc
|
||||
ctx
|
||||
(fmt : Format.formatter)
|
||||
(name : StructName.t)
|
||||
(fields : (StructFieldName.t * typ) list) : unit =
|
||||
(fields : typ StructField.Map.t) : unit =
|
||||
Format.fprintf fmt "%a %a %a %a@\n@[<hov 2> %a@]@\n%a" Print.keyword "struct"
|
||||
StructName.format_t name Print.punctuation "=" Print.punctuation "{"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (field_name, typ) ->
|
||||
Format.fprintf fmt "%a%a %a" StructFieldName.format_t field_name
|
||||
Format.fprintf fmt "%a%a %a" StructField.format_t field_name
|
||||
Print.punctuation ":" (Print.typ ctx) typ))
|
||||
fields Print.punctuation "}"
|
||||
(StructField.Map.bindings fields)
|
||||
Print.punctuation "}"
|
||||
|
||||
let enum
|
||||
ctx
|
||||
(fmt : Format.formatter)
|
||||
(name : EnumName.t)
|
||||
(cases : (EnumConstructor.t * typ) list) : unit =
|
||||
(cases : typ EnumConstructor.Map.t) : unit =
|
||||
Format.fprintf fmt "%a %a %a @\n@[<hov 2> %a@]" Print.keyword "enum"
|
||||
EnumName.format_t name Print.punctuation "="
|
||||
(Format.pp_print_list
|
||||
@ -45,7 +46,7 @@ let enum
|
||||
Format.fprintf fmt "%a %a%a %a" Print.punctuation "|"
|
||||
EnumConstructor.format_t field_name Print.punctuation ":"
|
||||
(Print.typ ctx) typ))
|
||||
cases
|
||||
(EnumConstructor.Map.bindings cases)
|
||||
|
||||
let scope ?(debug = false) ctx fmt (name, decl) =
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]"
|
||||
@ -55,16 +56,16 @@ let scope ?(debug = false) ctx fmt (name, decl) =
|
||||
Format.fprintf fmt "%a%a%a %a%a%a%a%a" Print.punctuation "("
|
||||
ScopeVar.format_t scope_var Print.punctuation ":" (Print.typ ctx) typ
|
||||
Print.punctuation "|" Print.keyword
|
||||
(match Marked.unmark vis.io_input with
|
||||
(match Marked.unmark vis.Desugared.Ast.io_input with
|
||||
| NoInput -> "internal"
|
||||
| OnlyInput -> "input"
|
||||
| Reentrant -> "context")
|
||||
(if Marked.unmark vis.io_output then fun fmt () ->
|
||||
(if Marked.unmark vis.Desugared.Ast.io_output then fun fmt () ->
|
||||
Format.fprintf fmt "%a@,%a" Print.punctuation "|" Print.keyword
|
||||
"output"
|
||||
else fun fmt () -> Format.fprintf fmt "@<0>")
|
||||
() Print.punctuation ")"))
|
||||
(ScopeVarMap.bindings decl.scope_sig)
|
||||
(ScopeVar.Map.bindings decl.scope_sig)
|
||||
Print.punctuation "="
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";")
|
||||
@ -80,11 +81,11 @@ let scope ?(debug = false) ctx fmt (name, decl) =
|
||||
| ScopelangScopeVar v -> (
|
||||
match
|
||||
Marked.unmark
|
||||
(snd (ScopeVarMap.find (Marked.unmark v) decl.scope_sig))
|
||||
(snd (ScopeVar.Map.find (Marked.unmark v) decl.scope_sig))
|
||||
.io_input
|
||||
with
|
||||
| Reentrant ->
|
||||
Format.fprintf fmt "%a@ %a" Print.operator
|
||||
Format.fprintf fmt "%a@ %a" Print.op_style
|
||||
"reentrant or by default" (Print.expr ~debug ctx) e
|
||||
| _ -> Format.fprintf fmt "%a" (Print.expr ~debug ctx) e))
|
||||
e
|
||||
@ -105,16 +106,16 @@ let program ?(debug : bool = false) (fmt : Format.formatter) (p : 'm program) :
|
||||
Format.pp_print_cut fmt ()
|
||||
in
|
||||
Format.pp_open_vbox fmt 0;
|
||||
StructMap.iter
|
||||
StructName.Map.iter
|
||||
(fun n s ->
|
||||
struc ctx fmt n s;
|
||||
pp_sep fmt ())
|
||||
ctx.ctx_structs;
|
||||
EnumMap.iter
|
||||
EnumName.Map.iter
|
||||
(fun n e ->
|
||||
enum ctx fmt n e;
|
||||
pp_sep fmt ())
|
||||
ctx.ctx_enums;
|
||||
Format.pp_print_list ~pp_sep (scope ~debug ctx) fmt
|
||||
(ScopeMap.bindings p.program_scopes);
|
||||
(ScopeName.Map.bindings p.program_scopes);
|
||||
Format.pp_close_box fmt ()
|
||||
|
@ -20,59 +20,47 @@
|
||||
|
||||
(* Doesn't define values, so OK to have without an mli *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
module Runtime = Runtime_ocaml.Runtime
|
||||
module ScopeName = Uid.Gen ()
|
||||
module StructName = Uid.Gen ()
|
||||
module StructField = Uid.Gen ()
|
||||
module EnumName = Uid.Gen ()
|
||||
module EnumConstructor = Uid.Gen ()
|
||||
|
||||
module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
(** Only used by surface *)
|
||||
|
||||
module ScopeSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName)
|
||||
module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName)
|
||||
module RuleName = Uid.Gen ()
|
||||
module LabelName = Uid.Gen ()
|
||||
|
||||
module StructName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
(** Used for unresolved structs/maps in desugared *)
|
||||
|
||||
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
|
||||
|
||||
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
|
||||
module IdentName = String
|
||||
|
||||
(** Only used by desugared/scopelang *)
|
||||
|
||||
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar)
|
||||
module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar)
|
||||
|
||||
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
|
||||
module SubScopeNameSet : Set.S with type elt = SubScopeName.t =
|
||||
Set.Make (SubScopeName)
|
||||
|
||||
module SubScopeMap : Map.S with type key = SubScopeName.t =
|
||||
Map.Make (SubScopeName)
|
||||
|
||||
module StructFieldMap : Map.S with type key = StructFieldName.t =
|
||||
Map.Make (StructFieldName)
|
||||
|
||||
module EnumConstructorMap : Map.S with type key = EnumConstructor.t =
|
||||
Map.Make (EnumConstructor)
|
||||
|
||||
module StateName : Uid.Id with type info = Uid.MarkedString.info =
|
||||
Uid.Make (Uid.MarkedString) ()
|
||||
module ScopeVar = Uid.Gen ()
|
||||
module SubScopeName = Uid.Gen ()
|
||||
module StateName = Uid.Gen ()
|
||||
|
||||
(** {1 Abstract syntax tree} *)
|
||||
|
||||
(** Define a common base type for the expressions in most passes of the compiler *)
|
||||
|
||||
type desugared = [ `Desugared ]
|
||||
(** {2 Phantom types used to select relevant cases on the generic AST}
|
||||
|
||||
we instantiate them with a polymorphic variant to take advantage of
|
||||
sub-typing. The values aren't actually used. *)
|
||||
|
||||
type scopelang = [ `Scopelang ]
|
||||
type dcalc = [ `Dcalc ]
|
||||
type lcalc = [ `Lcalc ]
|
||||
|
||||
type 'a any = [< desugared | scopelang | dcalc | lcalc ] as 'a
|
||||
(** ['a any] is 'a, but adds the constraint that it should be restricted to
|
||||
valid AST kinds *)
|
||||
|
||||
(** {2 Types} *)
|
||||
|
||||
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
|
||||
@ -94,33 +82,6 @@ and naked_typ =
|
||||
type date = Runtime.date
|
||||
type duration = Runtime.duration
|
||||
|
||||
type op_kind =
|
||||
| KInt
|
||||
| KRat
|
||||
| KMoney
|
||||
| KDate
|
||||
| KDuration (** All ops don't have a KDate and KDuration. *)
|
||||
|
||||
type ternop = Fold
|
||||
|
||||
type binop =
|
||||
| And
|
||||
| Or
|
||||
| Xor
|
||||
| Add of op_kind
|
||||
| Sub of op_kind
|
||||
| Mult of op_kind
|
||||
| Div of op_kind
|
||||
| Lt of op_kind
|
||||
| Lte of op_kind
|
||||
| Gt of op_kind
|
||||
| Gte of op_kind
|
||||
| Eq
|
||||
| Neq
|
||||
| Map
|
||||
| Concat
|
||||
| Filter
|
||||
|
||||
type log_entry =
|
||||
| VarDef of naked_typ
|
||||
(** During code generation, we need to know the type of the variable being
|
||||
@ -129,35 +90,140 @@ type log_entry =
|
||||
| EndCall
|
||||
| PosRecordIfTrueBool
|
||||
|
||||
type unop =
|
||||
| Not
|
||||
| Minus of op_kind
|
||||
| Log of log_entry * Uid.MarkedString.info list
|
||||
| Length
|
||||
| IntToRat
|
||||
| MoneyToRat
|
||||
| RatToMoney
|
||||
| GetDay
|
||||
| GetMonth
|
||||
| GetYear
|
||||
| FirstDayOfMonth
|
||||
| LastDayOfMonth
|
||||
| RoundMoney
|
||||
| RoundDecimal
|
||||
module Op = struct
|
||||
(** Classification of operators on how they should be typed *)
|
||||
|
||||
type operator = Ternop of ternop | Binop of binop | Unop of unop
|
||||
type monomorphic =
|
||||
| Monomorphic (** Operands and return types of the operator are fixed *)
|
||||
|
||||
type polymorphic =
|
||||
| Polymorphic
|
||||
(** The operator is truly polymorphic: it's the same runtime function
|
||||
that may work on multiple types. We require that resolving the
|
||||
argument types from right to left trivially resolves all type
|
||||
variables declared in the operator type. *)
|
||||
|
||||
type overloaded =
|
||||
| Overloaded
|
||||
(** The operator is ambiguous and requires the types of its arguments to
|
||||
be known before it can be typed, using a pre-defined table *)
|
||||
|
||||
type resolved =
|
||||
| Resolved (** Explicit monomorphic versions of the overloaded operators *)
|
||||
|
||||
(** Classification of operators. This could be inlined in the definition of
|
||||
[t] but is more concise this way *)
|
||||
type (_, _) kind =
|
||||
| Monomorphic : ('a any, monomorphic) kind
|
||||
| Polymorphic : ('a any, polymorphic) kind
|
||||
| Overloaded : ([< desugared ], overloaded) kind
|
||||
| Resolved : ([< scopelang | dcalc | lcalc ], resolved) kind
|
||||
|
||||
type (_, _) t =
|
||||
(* unary *)
|
||||
(* * monomorphic *)
|
||||
| Not : ('a any, monomorphic) t
|
||||
| GetDay : ('a any, monomorphic) t
|
||||
| GetMonth : ('a any, monomorphic) t
|
||||
| GetYear : ('a any, monomorphic) t
|
||||
| FirstDayOfMonth : ('a any, monomorphic) t
|
||||
| LastDayOfMonth : ('a any, monomorphic) t
|
||||
(* * polymorphic *)
|
||||
| Length : ('a any, polymorphic) t
|
||||
| Log : log_entry * Uid.MarkedString.info list -> ('a any, polymorphic) t
|
||||
(* * overloaded *)
|
||||
| Minus : (desugared, overloaded) t
|
||||
| Minus_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Minus_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Minus_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Minus_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| ToRat : (desugared, overloaded) t
|
||||
| ToRat_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| ToRat_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| ToMoney : (desugared, overloaded) t
|
||||
| ToMoney_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Round : (desugared, overloaded) t
|
||||
| Round_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Round_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
(* binary *)
|
||||
(* * monomorphic *)
|
||||
| And : ('a any, monomorphic) t
|
||||
| Or : ('a any, monomorphic) t
|
||||
| Xor : ('a any, monomorphic) t
|
||||
(* * polymorphic *)
|
||||
| Eq : ('a any, polymorphic) t
|
||||
| Map : ('a any, polymorphic) t
|
||||
| Concat : ('a any, polymorphic) t
|
||||
| Filter : ('a any, polymorphic) t
|
||||
| Reduce : ('a any, polymorphic) t
|
||||
(* * overloaded *)
|
||||
| Add : (desugared, overloaded) t
|
||||
| Add_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Add_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Add_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Add_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Add_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Sub : (desugared, overloaded) t
|
||||
| Sub_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Sub_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Sub_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Sub_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Sub_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Sub_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Mult : (desugared, overloaded) t
|
||||
| Mult_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Mult_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Mult_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Mult_dur_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Div : (desugared, overloaded) t
|
||||
| Div_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Div_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Div_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Div_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lt : (desugared, overloaded) t
|
||||
| Lt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lte : (desugared, overloaded) t
|
||||
| Lte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Lte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gt : (desugared, overloaded) t
|
||||
| Gt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gte : (desugared, overloaded) t
|
||||
| Gte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Gte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
(* Todo: Eq is not an overload at the moment, but it should be one. The
|
||||
trick is that it needs generation of specific code for arrays, every
|
||||
struct and enum: operators [Eq_structs of StructName.t], etc. *)
|
||||
| Eq_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Eq_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Eq_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Eq_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
| Eq_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
|
||||
(* ternary *)
|
||||
(* * polymorphic *)
|
||||
| Fold : ('a any, polymorphic) t
|
||||
end
|
||||
|
||||
type ('a, 'k) operator = ('a any, 'k) Op.t
|
||||
type except = ConflictError | EmptyError | NoValueProvided | Crash
|
||||
|
||||
(** {2 Generic expressions} *)
|
||||
|
||||
(** Define a common base type for the expressions in most passes of the compiler *)
|
||||
|
||||
type desugared = [ `Desugared ]
|
||||
type scopelang = [ `Scopelang ]
|
||||
type dcalc = [ `Dcalc ]
|
||||
type lcalc = [ `Lcalc ]
|
||||
type 'a any = [< desugared | scopelang | dcalc | lcalc ] as 'a
|
||||
|
||||
(** Literals are the same throughout compilation except for the [LEmptyError]
|
||||
case which is eliminated midway through. *)
|
||||
type 'a glit =
|
||||
@ -192,65 +258,98 @@ type ('a, 't) gexpr = (('a, 't) naked_gexpr, 't) Marked.t
|
||||
- To write a function that handles cases from different ASTs, explicit the
|
||||
type variables: [fun (type a) (x: a naked_gexpr) -> ...]
|
||||
- For recursive functions, you may need to additionally explicit the
|
||||
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...] *)
|
||||
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...]
|
||||
- Always think of using the pre-defined map/fold functions in [Expr] rather
|
||||
than completely defining your recursion manually. *)
|
||||
|
||||
and ('a, 't) naked_gexpr =
|
||||
(* Constructors common to all ASTs *)
|
||||
| ELit : 'a glit -> ('a any, 't) naked_gexpr
|
||||
| EApp : ('a, 't) gexpr * ('a, 't) gexpr list -> ('a any, 't) naked_gexpr
|
||||
| EOp : operator -> ('a any, 't) naked_gexpr
|
||||
| EApp : {
|
||||
f : ('a, 't) gexpr;
|
||||
args : ('a, 't) gexpr list;
|
||||
}
|
||||
-> ('a any, 't) naked_gexpr
|
||||
| EOp : { op : ('a, _) operator; tys : typ list } -> ('a any, 't) naked_gexpr
|
||||
| EArray : ('a, 't) gexpr list -> ('a any, 't) naked_gexpr
|
||||
| EVar : ('a, 't) naked_gexpr Bindlib.var -> ('a any, 't) naked_gexpr
|
||||
| EAbs :
|
||||
(('a, 't) naked_gexpr, ('a, 't) gexpr) Bindlib.mbinder * typ list
|
||||
| EAbs : {
|
||||
binder : (('a, 't) naked_gexpr, ('a, 't) gexpr) Bindlib.mbinder;
|
||||
tys : typ list;
|
||||
}
|
||||
-> ('a any, 't) naked_gexpr
|
||||
| EIfThenElse :
|
||||
('a, 't) gexpr * ('a, 't) gexpr * ('a, 't) gexpr
|
||||
| EIfThenElse : {
|
||||
cond : ('a, 't) gexpr;
|
||||
etrue : ('a, 't) gexpr;
|
||||
efalse : ('a, 't) gexpr;
|
||||
}
|
||||
-> ('a any, 't) naked_gexpr
|
||||
| EStruct : {
|
||||
name : StructName.t;
|
||||
fields : ('a, 't) gexpr StructField.Map.t;
|
||||
}
|
||||
-> ('a any, 't) naked_gexpr
|
||||
| EInj : {
|
||||
name : EnumName.t;
|
||||
e : ('a, 't) gexpr;
|
||||
cons : EnumConstructor.t;
|
||||
}
|
||||
-> ('a any, 't) naked_gexpr
|
||||
| EMatch : {
|
||||
name : EnumName.t;
|
||||
e : ('a, 't) gexpr;
|
||||
cases : ('a, 't) gexpr EnumConstructor.Map.t;
|
||||
}
|
||||
-> ('a any, 't) naked_gexpr
|
||||
(* Early stages *)
|
||||
| ELocation :
|
||||
'a glocation
|
||||
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
|
||||
| EStruct :
|
||||
StructName.t * ('a, 't) gexpr StructFieldMap.t
|
||||
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
|
||||
| EStructAccess :
|
||||
('a, 't) gexpr * StructFieldName.t * StructName.t
|
||||
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
|
||||
| EEnumInj :
|
||||
('a, 't) gexpr * EnumConstructor.t * EnumName.t
|
||||
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
|
||||
| EMatchS :
|
||||
('a, 't) gexpr * EnumName.t * ('a, 't) gexpr EnumConstructorMap.t
|
||||
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
|
||||
| EScopeCall :
|
||||
ScopeName.t * ('a, 't) gexpr ScopeVarMap.t
|
||||
| EScopeCall : {
|
||||
scope : ScopeName.t;
|
||||
args : ('a, 't) gexpr ScopeVar.Map.t;
|
||||
}
|
||||
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
|
||||
| EDStructAccess : {
|
||||
name_opt : StructName.t option;
|
||||
e : ('a, 't) gexpr;
|
||||
field : IdentName.t;
|
||||
}
|
||||
-> ((desugared as 'a), 't) naked_gexpr
|
||||
(** [desugared] has ambiguous struct fields *)
|
||||
| EStructAccess : {
|
||||
name : StructName.t;
|
||||
e : ('a, 't) gexpr;
|
||||
field : StructField.t;
|
||||
}
|
||||
-> (([< scopelang | dcalc | lcalc ] as 'a), 't) naked_gexpr
|
||||
(** Resolved struct/enums, after [desugared] *)
|
||||
(* Lambda-like *)
|
||||
| ETuple :
|
||||
('a, 't) gexpr list * StructName.t option
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
|
||||
| ETupleAccess :
|
||||
('a, 't) gexpr * int * StructName.t option * typ list
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
|
||||
| EInj :
|
||||
('a, 't) gexpr * int * EnumName.t * typ list
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
|
||||
| EMatch :
|
||||
('a, 't) gexpr * ('a, 't) gexpr list * EnumName.t
|
||||
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
|
||||
| EAssert : ('a, 't) gexpr -> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
|
||||
(* Default terms *)
|
||||
| EDefault :
|
||||
('a, 't) gexpr list * ('a, 't) gexpr * ('a, 't) gexpr
|
||||
| EDefault : {
|
||||
excepts : ('a, 't) gexpr list;
|
||||
just : ('a, 't) gexpr;
|
||||
cons : ('a, 't) gexpr;
|
||||
}
|
||||
-> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr
|
||||
| ErrorOnEmpty :
|
||||
| EErrorOnEmpty :
|
||||
('a, 't) gexpr
|
||||
-> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr
|
||||
(* Lambda calculus with exceptions *)
|
||||
| ETuple : ('a, 't) gexpr list -> ((lcalc as 'a), 't) naked_gexpr
|
||||
| ETupleAccess : {
|
||||
e : ('a, 't) gexpr;
|
||||
index : int;
|
||||
size : int;
|
||||
}
|
||||
-> ((lcalc as 'a), 't) naked_gexpr
|
||||
| ERaise : except -> ((lcalc as 'a), 't) naked_gexpr
|
||||
| ECatch :
|
||||
('a, 't) gexpr * except * ('a, 't) gexpr
|
||||
| ECatch : {
|
||||
body : ('a, 't) gexpr;
|
||||
exn : except;
|
||||
handler : ('a, 't) gexpr;
|
||||
}
|
||||
-> ((lcalc as 'a), 't) naked_gexpr
|
||||
|
||||
type ('a, 't) boxed_gexpr = (('a, 't) naked_gexpr Bindlib.box, 't) Marked.t
|
||||
@ -276,9 +375,9 @@ type typed = { pos : Pos.t; ty : typ }
|
||||
|
||||
(** The generic type of AST markings. Using a GADT allows functions to be
|
||||
polymorphic in the marking, but still do transformations on types when
|
||||
appropriate. Expected to fill the ['t] parameter of [naked_gexpr] and
|
||||
[gexpr] (a ['t] annotation different from this type is used in the middle of
|
||||
the typing processing, but all visible ASTs should otherwise use this. *)
|
||||
appropriate. Expected to fill the ['t] parameter of [gexpr] and [gexpr] (a
|
||||
['t] annotation different from this type is used in the middle of the typing
|
||||
processing, but all visible ASTs should otherwise use this. *)
|
||||
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
|
||||
|
||||
(** Useful for errors and printing, for example *)
|
||||
@ -287,11 +386,10 @@ type any_expr = AnyExpr : (_, _ mark) gexpr -> any_expr
|
||||
(** {2 Higher-level program structure} *)
|
||||
|
||||
(** Constructs scopes and programs on top of expressions. The ['e] type
|
||||
parameter throughout is expected to match instances of the [naked_gexpr]
|
||||
type defined above. Markings are constrained to the [mark] GADT defined
|
||||
above. Note that this structure is at the moment only relevant for [dcalc]
|
||||
and [lcalc], as [scopelang] has its own scope structure, as the name
|
||||
implies. *)
|
||||
parameter throughout is expected to match instances of the [gexpr] type
|
||||
defined above. Markings are constrained to the [mark] GADT defined above.
|
||||
Note that this structure is at the moment only relevant for [dcalc] and
|
||||
[lcalc], as [scopelang] has its own scope structure, as the name implies. *)
|
||||
|
||||
(** This kind annotation signals that the let-binding respects a structural
|
||||
invariant. These invariants concern the shape of the expression in the
|
||||
@ -350,14 +448,20 @@ and 'e scopes =
|
||||
| ScopeDef of 'e scope_def
|
||||
constraint 'e = (_ any, _ mark) gexpr
|
||||
|
||||
type struct_ctx = (StructFieldName.t * typ) list StructMap.t
|
||||
type enum_ctx = (EnumConstructor.t * typ) list EnumMap.t
|
||||
type struct_ctx = typ StructField.Map.t StructName.Map.t
|
||||
type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t
|
||||
|
||||
type scope_out_struct = {
|
||||
out_struct_name : StructName.t;
|
||||
out_struct_fields : StructField.t ScopeVar.Map.t;
|
||||
}
|
||||
|
||||
type decl_ctx = {
|
||||
ctx_enums : enum_ctx;
|
||||
ctx_structs : struct_ctx;
|
||||
ctx_scopes : StructName.t ScopeMap.t;
|
||||
(** The output structure type of every scope *)
|
||||
ctx_struct_fields : StructField.t StructName.Map.t IdentName.Map.t;
|
||||
(** needed for disambiguation (desugared -> scope) *)
|
||||
ctx_scopes : scope_out_struct ScopeName.Map.t;
|
||||
}
|
||||
|
||||
type 'e program = { decl_ctx : decl_ctx; scopes : 'e scopes }
|
||||
|
@ -3,4 +3,4 @@
|
||||
(public_name catala.shared_ast)
|
||||
(flags
|
||||
(:standard -short-paths))
|
||||
(libraries bindlib unionFind utils catala.runtime_ocaml))
|
||||
(libraries bindlib unionFind catala_utils catala.runtime_ocaml))
|
||||
|
@ -15,7 +15,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
(** Functions handling the types of [shared_ast] *)
|
||||
@ -57,15 +57,15 @@ module Box = struct
|
||||
fun em ->
|
||||
B.box_apply (fun e -> Marked.mark (Marked.get_mark em) e) (Marked.unmark em)
|
||||
|
||||
module LiftStruct = Bindlib.Lift (StructFieldMap)
|
||||
module LiftStruct = Bindlib.Lift (StructField.Map)
|
||||
|
||||
let lift_struct = LiftStruct.lift_box
|
||||
|
||||
module LiftEnum = Bindlib.Lift (EnumConstructorMap)
|
||||
module LiftEnum = Bindlib.Lift (EnumConstructor.Map)
|
||||
|
||||
let lift_enum = LiftEnum.lift_box
|
||||
|
||||
module LiftScopeVars = Bindlib.Lift (ScopeVarMap)
|
||||
module LiftScopeVars = Bindlib.Lift (ScopeVar.Map)
|
||||
|
||||
let lift_scope_vars = LiftScopeVars.lift_box
|
||||
end
|
||||
@ -76,61 +76,64 @@ let subst binder vars =
|
||||
Bindlib.msubst binder (Array.of_list (List.map Marked.unmark vars))
|
||||
|
||||
let evar v mark = Marked.mark mark (Bindlib.box_var v)
|
||||
let etuple args s = Box.appn args @@ fun args -> ETuple (args, s)
|
||||
let etuple args = Box.appn args @@ fun args -> ETuple args
|
||||
|
||||
let etupleaccess e1 i s typs =
|
||||
Box.app1 e1 @@ fun e1 -> ETupleAccess (e1, i, s, typs)
|
||||
|
||||
let einj e1 i e_name typs = Box.app1 e1 @@ fun e1 -> EInj (e1, i, e_name, typs)
|
||||
|
||||
let ematch arg arms e_name =
|
||||
Box.app1n arg arms @@ fun arg arms -> EMatch (arg, arms, e_name)
|
||||
let etupleaccess e index size =
|
||||
assert (index < size);
|
||||
Box.app1 e @@ fun e -> ETupleAccess { e; index; size }
|
||||
|
||||
let earray args = Box.appn args @@ fun args -> EArray args
|
||||
let elit l mark = Marked.mark mark (Bindlib.box (ELit l))
|
||||
|
||||
let eabs binder typs mark =
|
||||
Bindlib.box_apply (fun binder -> EAbs (binder, typs)) binder, mark
|
||||
let eabs binder tys mark =
|
||||
Bindlib.box_apply (fun binder -> EAbs { binder; tys }) binder, mark
|
||||
|
||||
let eapp e1 args = Box.app1n e1 args @@ fun e1 args -> EApp (e1, args)
|
||||
let eapp f args = Box.app1n f args @@ fun f args -> EApp { f; args }
|
||||
let eassert e1 = Box.app1 e1 @@ fun e1 -> EAssert e1
|
||||
let eop op = Box.app0 @@ EOp op
|
||||
let eop op tys = Box.app0 @@ EOp { op; tys }
|
||||
|
||||
let edefault excepts just cons =
|
||||
Box.app2n just cons excepts
|
||||
@@ fun just cons excepts -> EDefault (excepts, just, cons)
|
||||
@@ fun just cons excepts -> EDefault { excepts; just; cons }
|
||||
|
||||
let eifthenelse e1 e2 e3 =
|
||||
Box.app3 e1 e2 e3 @@ fun e1 e2 e3 -> EIfThenElse (e1, e2, e3)
|
||||
let eifthenelse cond etrue efalse =
|
||||
Box.app3 cond etrue efalse
|
||||
@@ fun cond etrue efalse -> EIfThenElse { cond; etrue; efalse }
|
||||
|
||||
let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> ErrorOnEmpty e1
|
||||
let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> EErrorOnEmpty e1
|
||||
let eraise e1 = Box.app0 @@ ERaise e1
|
||||
let ecatch e1 exn e2 = Box.app2 e1 e2 @@ fun e1 e2 -> ECatch (e1, exn, e2)
|
||||
|
||||
let ecatch body exn handler =
|
||||
Box.app2 body handler @@ fun body handler -> ECatch { body; exn; handler }
|
||||
|
||||
let elocation loc = Box.app0 @@ ELocation loc
|
||||
|
||||
let estruct name (fields : ('a, 't) boxed_gexpr StructFieldMap.t) mark =
|
||||
let estruct name (fields : ('a, 't) boxed_gexpr StructField.Map.t) mark =
|
||||
Marked.mark mark
|
||||
@@ Bindlib.box_apply
|
||||
(fun fields -> EStruct (name, fields))
|
||||
(Box.lift_struct (StructFieldMap.map Box.lift fields))
|
||||
(fun fields -> EStruct { name; fields })
|
||||
(Box.lift_struct (StructField.Map.map Box.lift fields))
|
||||
|
||||
let estructaccess e1 field struc =
|
||||
Box.app1 e1 @@ fun e1 -> EStructAccess (e1, field, struc)
|
||||
let edstructaccess e field name_opt =
|
||||
Box.app1 e @@ fun e -> EDStructAccess { name_opt; e; field }
|
||||
|
||||
let eenuminj e1 cons enum = Box.app1 e1 @@ fun e1 -> EEnumInj (e1, cons, enum)
|
||||
let estructaccess e field name =
|
||||
Box.app1 e @@ fun e -> EStructAccess { name; e; field }
|
||||
|
||||
let ematchs e1 enum cases mark =
|
||||
let einj e cons name = Box.app1 e @@ fun e -> EInj { name; e; cons }
|
||||
|
||||
let ematch e name cases mark =
|
||||
Marked.mark mark
|
||||
@@ Bindlib.box_apply2
|
||||
(fun e1 cases -> EMatchS (e1, enum, cases))
|
||||
(Box.lift e1)
|
||||
(Box.lift_enum (EnumConstructorMap.map Box.lift cases))
|
||||
(fun e cases -> EMatch { name; e; cases })
|
||||
(Box.lift e)
|
||||
(Box.lift_enum (EnumConstructor.Map.map Box.lift cases))
|
||||
|
||||
let escopecall scope_name fields mark =
|
||||
let escopecall scope args mark =
|
||||
Marked.mark mark
|
||||
@@ Bindlib.box_apply
|
||||
(fun fields -> EScopeCall (scope_name, fields))
|
||||
(Box.lift_scope_vars (ScopeVarMap.map Box.lift fields))
|
||||
(fun args -> EScopeCall { scope; args })
|
||||
(Box.lift_scope_vars (ScopeVar.Map.map Box.lift args))
|
||||
|
||||
(* - Manipulation of marks - *)
|
||||
|
||||
@ -203,49 +206,46 @@ let maybe_ty (type m) ?(typ = TAny) (m : m mark) : typ =
|
||||
(* shallow map *)
|
||||
let map
|
||||
(type a)
|
||||
(ctx : 'ctx)
|
||||
~(f : 'ctx -> (a, 'm1) gexpr -> (a, 'm2) boxed_gexpr)
|
||||
~(f : (a, 'm1) gexpr -> (a, 'm2) boxed_gexpr)
|
||||
(e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : (a, 'm2) boxed_gexpr =
|
||||
let m = Marked.get_mark e in
|
||||
match Marked.unmark e with
|
||||
| ELit l -> elit l m
|
||||
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
|
||||
| EOp op -> eop op m
|
||||
| EArray args -> earray (List.map (f ctx) args) m
|
||||
| EApp { f = e1; args } -> eapp (f e1) (List.map f args) m
|
||||
| EOp { op; tys } -> eop op tys m
|
||||
| EArray args -> earray (List.map f args) m
|
||||
| EVar v -> evar (Var.translate v) m
|
||||
| EAbs (binder, typs) ->
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let body = f ctx body in
|
||||
let body = f body in
|
||||
let binder = bind (Array.map Var.translate vars) body in
|
||||
eabs binder typs m
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
|
||||
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
|
||||
| ETupleAccess (e1, n, s_name, typs) ->
|
||||
etupleaccess ((f ctx) e1) n s_name typs m
|
||||
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
|
||||
| EAssert e1 -> eassert ((f ctx) e1) m
|
||||
| EDefault (excepts, just, cons) ->
|
||||
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
|
||||
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
|
||||
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) m
|
||||
eabs binder tys m
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
eifthenelse (f cond) (f etrue) (f efalse) m
|
||||
| ETuple args -> etuple (List.map f args) m
|
||||
| ETupleAccess { e; index; size } -> etupleaccess (f e) index size m
|
||||
| EInj { e; name; cons } -> einj (f e) cons name m
|
||||
| EAssert e1 -> eassert (f e1) m
|
||||
| EDefault { excepts; just; cons } ->
|
||||
edefault (List.map f excepts) (f just) (f cons) m
|
||||
| EErrorOnEmpty e1 -> eerroronempty (f e1) m
|
||||
| ECatch { body; exn; handler } -> ecatch (f body) exn (f handler) m
|
||||
| ERaise exn -> eraise exn m
|
||||
| ELocation loc -> elocation loc m
|
||||
| EStruct (name, fields) ->
|
||||
let fields = StructFieldMap.map (f ctx) fields in
|
||||
| EStruct { name; fields } ->
|
||||
let fields = StructField.Map.map f fields in
|
||||
estruct name fields m
|
||||
| EStructAccess (e1, field, struc) -> estructaccess (f ctx e1) field struc m
|
||||
| EEnumInj (e1, cons, enum) -> eenuminj (f ctx e1) cons enum m
|
||||
| EMatchS (e1, enum, cases) ->
|
||||
let cases = EnumConstructorMap.map (f ctx) cases in
|
||||
ematchs (f ctx e1) enum cases m
|
||||
| EScopeCall (scope_name, fields) ->
|
||||
let fields = ScopeVarMap.map (f ctx) fields in
|
||||
escopecall scope_name fields m
|
||||
| EDStructAccess { e; field; name_opt } ->
|
||||
edstructaccess (f e) field name_opt m
|
||||
| EStructAccess { e; field; name } -> estructaccess (f e) field name m
|
||||
| EMatch { e; name; cases } ->
|
||||
let cases = EnumConstructor.Map.map f cases in
|
||||
ematch (f e) name cases m
|
||||
| EScopeCall { scope; args } ->
|
||||
let fields = ScopeVar.Map.map f args in
|
||||
escopecall scope fields m
|
||||
|
||||
let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e)
|
||||
let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e)
|
||||
|
||||
let map_marks ~f e =
|
||||
map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
|
||||
@ -260,31 +260,130 @@ let shallow_fold
|
||||
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
|
||||
match Marked.unmark e with
|
||||
| ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ -> acc
|
||||
| EApp (e1, args) -> acc |> f e1 |> lfold args
|
||||
| EApp { f = e; args } -> acc |> f e |> lfold args
|
||||
| EArray args -> acc |> lfold args
|
||||
| EAbs _ -> acc
|
||||
| EIfThenElse (e1, e2, e3) -> acc |> f e1 |> f e2 |> f e3
|
||||
| ETuple (args, _) -> acc |> lfold args
|
||||
| ETupleAccess (e1, _, _, _) -> acc |> f e1
|
||||
| EInj (e1, _, _, _) -> acc |> f e1
|
||||
| EMatch (arg, arms, _) -> acc |> f arg |> lfold arms
|
||||
| EAssert e1 -> acc |> f e1
|
||||
| EDefault (excepts, just, cons) -> acc |> lfold excepts |> f just |> f cons
|
||||
| ErrorOnEmpty e1 -> acc |> f e1
|
||||
| ECatch (e1, _, e2) -> acc |> f e1 |> f e2
|
||||
| EStruct (_, fields) -> acc |> StructFieldMap.fold (fun _ -> f) fields
|
||||
| EStructAccess (e1, _, _) -> acc |> f e1
|
||||
| EEnumInj (e1, _, _) -> acc |> f e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
acc |> f e1 |> EnumConstructorMap.fold (fun _ -> f) cases
|
||||
| EScopeCall (_, fields) -> acc |> ScopeVarMap.fold (fun _ -> f) fields
|
||||
| EIfThenElse { cond; etrue; efalse } -> acc |> f cond |> f etrue |> f efalse
|
||||
| ETuple args -> acc |> lfold args
|
||||
| ETupleAccess { e; _ } -> acc |> f e
|
||||
| EInj { e; _ } -> acc |> f e
|
||||
| EAssert e -> acc |> f e
|
||||
| EDefault { excepts; just; cons } -> acc |> lfold excepts |> f just |> f cons
|
||||
| EErrorOnEmpty e -> acc |> f e
|
||||
| ECatch { body; handler; _ } -> acc |> f body |> f handler
|
||||
| EStruct { fields; _ } -> acc |> StructField.Map.fold (fun _ -> f) fields
|
||||
| EDStructAccess { e; _ } -> acc |> f e
|
||||
| EStructAccess { e; _ } -> acc |> f e
|
||||
| EMatch { e; cases; _ } ->
|
||||
acc |> f e |> EnumConstructor.Map.fold (fun _ -> f) cases
|
||||
| EScopeCall { args; _ } -> acc |> ScopeVar.Map.fold (fun _ -> f) args
|
||||
|
||||
(* Like [map], but also allows to gather a result bottom-up. *)
|
||||
let map_gather
|
||||
(type a)
|
||||
~(acc : 'acc)
|
||||
~(join : 'acc -> 'acc -> 'acc)
|
||||
~(f : (a, 'm1) gexpr -> 'acc * (a, 'm2) boxed_gexpr)
|
||||
(e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : 'acc * (a, 'm2) boxed_gexpr =
|
||||
let m = Marked.get_mark e in
|
||||
let lfoldmap es =
|
||||
let acc, r_es =
|
||||
List.fold_left
|
||||
(fun (acc, es) e ->
|
||||
let acc1, e = f e in
|
||||
join acc acc1, e :: es)
|
||||
(acc, []) es
|
||||
in
|
||||
acc, List.rev r_es
|
||||
in
|
||||
match Marked.unmark e with
|
||||
| ELit l -> acc, elit l m
|
||||
| EApp { f = e1; args } ->
|
||||
let acc1, f = f e1 in
|
||||
let acc2, args = lfoldmap args in
|
||||
join acc1 acc2, eapp f args m
|
||||
| EOp { op; tys } -> acc, eop op tys m
|
||||
| EArray args ->
|
||||
let acc, args = lfoldmap args in
|
||||
acc, earray args m
|
||||
| EVar v -> acc, evar (Var.translate v) m
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let acc, body = f body in
|
||||
let binder = bind (Array.map Var.translate vars) body in
|
||||
acc, eabs binder tys m
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
let acc1, cond = f cond in
|
||||
let acc2, etrue = f etrue in
|
||||
let acc3, efalse = f efalse in
|
||||
join (join acc1 acc2) acc3, eifthenelse cond etrue efalse m
|
||||
| ETuple args ->
|
||||
let acc, args = lfoldmap args in
|
||||
acc, etuple args m
|
||||
| ETupleAccess { e; index; size } ->
|
||||
let acc, e = f e in
|
||||
acc, etupleaccess e index size m
|
||||
| EInj { e; name; cons } ->
|
||||
let acc, e = f e in
|
||||
acc, einj e cons name m
|
||||
| EAssert e ->
|
||||
let acc, e = f e in
|
||||
acc, eassert e m
|
||||
| EDefault { excepts; just; cons } ->
|
||||
let acc1, excepts = lfoldmap excepts in
|
||||
let acc2, just = f just in
|
||||
let acc3, cons = f cons in
|
||||
join (join acc1 acc2) acc3, edefault excepts just cons m
|
||||
| EErrorOnEmpty e ->
|
||||
let acc, e = f e in
|
||||
acc, eerroronempty e m
|
||||
| ECatch { body; exn; handler } ->
|
||||
let acc1, body = f body in
|
||||
let acc2, handler = f handler in
|
||||
join acc1 acc2, ecatch body exn handler m
|
||||
| ERaise exn -> acc, eraise exn m
|
||||
| ELocation loc -> acc, elocation loc m
|
||||
| EStruct { name; fields } ->
|
||||
let acc, fields =
|
||||
StructField.Map.fold
|
||||
(fun cons e (acc, fields) ->
|
||||
let acc1, e = f e in
|
||||
join acc acc1, StructField.Map.add cons e fields)
|
||||
fields
|
||||
(acc, StructField.Map.empty)
|
||||
in
|
||||
acc, estruct name fields m
|
||||
| EDStructAccess { e; field; name_opt } ->
|
||||
let acc, e = f e in
|
||||
acc, edstructaccess e field name_opt m
|
||||
| EStructAccess { e; field; name } ->
|
||||
let acc, e = f e in
|
||||
acc, estructaccess e field name m
|
||||
| EMatch { e; name; cases } ->
|
||||
let acc, e = f e in
|
||||
let acc, cases =
|
||||
EnumConstructor.Map.fold
|
||||
(fun cons e (acc, cases) ->
|
||||
let acc1, e = f e in
|
||||
join acc acc1, EnumConstructor.Map.add cons e cases)
|
||||
cases
|
||||
(acc, EnumConstructor.Map.empty)
|
||||
in
|
||||
acc, ematch e name cases m
|
||||
| EScopeCall { scope; args } ->
|
||||
let acc, args =
|
||||
ScopeVar.Map.fold
|
||||
(fun var e (acc, args) ->
|
||||
let acc1, e = f e in
|
||||
join acc acc1, ScopeVar.Map.add var e args)
|
||||
args (acc, ScopeVar.Map.empty)
|
||||
in
|
||||
acc, escopecall scope args m
|
||||
|
||||
(* - *)
|
||||
|
||||
(** See [Bindlib.box_term] documentation for why we are doing that. *)
|
||||
let rebox e =
|
||||
let rec id_t () e = map () ~f:id_t e in
|
||||
id_t () e
|
||||
let rec rebox e = map ~f:rebox e
|
||||
|
||||
let box e = Marked.same_mark_as (Bindlib.box (Marked.unmark e)) e
|
||||
let unbox (e, m) = Bindlib.unbox e, m
|
||||
@ -297,99 +396,36 @@ let is_value (type a) (e : (a, _) gexpr) =
|
||||
| ELit _ | EAbs _ | EOp _ | ERaise _ -> true
|
||||
| _ -> false
|
||||
|
||||
let equal_tlit l1 l2 = l1 = l2
|
||||
let compare_tlit l1 l2 = Stdlib.compare l1 l2
|
||||
|
||||
let rec equal_typ ty1 ty2 =
|
||||
match Marked.unmark ty1, Marked.unmark ty2 with
|
||||
| TLit l1, TLit l2 -> equal_tlit l1 l2
|
||||
| TTuple tys1, TTuple tys2 -> equal_typ_list tys1 tys2
|
||||
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
|
||||
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
|
||||
| TOption t1, TOption t2 -> equal_typ t1 t2
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> equal_typ t1 t2 && equal_typ t1' t2'
|
||||
| TArray t1, TArray t2 -> equal_typ t1 t2
|
||||
| TAny, TAny -> true
|
||||
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
|
||||
| TArray _ | TAny ),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
and equal_typ_list tys1 tys2 =
|
||||
try List.for_all2 equal_typ tys1 tys2 with Invalid_argument _ -> false
|
||||
|
||||
(* Similar to [equal_typ], but allows TAny holes *)
|
||||
let rec unifiable ty1 ty2 =
|
||||
match Marked.unmark ty1, Marked.unmark ty2 with
|
||||
| TAny, _ | _, TAny -> true
|
||||
| TLit l1, TLit l2 -> equal_tlit l1 l2
|
||||
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
|
||||
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
|
||||
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
|
||||
| TOption t1, TOption t2 -> unifiable t1 t2
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
|
||||
| TArray t1, TArray t2 -> unifiable t1 t2
|
||||
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
and unifiable_list tys1 tys2 =
|
||||
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
|
||||
|
||||
let rec compare_typ ty1 ty2 =
|
||||
match Marked.unmark ty1, Marked.unmark ty2 with
|
||||
| TLit l1, TLit l2 -> compare_tlit l1 l2
|
||||
| TTuple tys1, TTuple tys2 -> List.compare compare_typ tys1 tys2
|
||||
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
|
||||
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
|
||||
| TOption t1, TOption t2 -> compare_typ t1 t2
|
||||
| TArrow (a1, b1), TArrow (a2, b2) -> (
|
||||
match compare_typ a1 a2 with 0 -> compare_typ b1 b2 | n -> n)
|
||||
| TArray t1, TArray t2 -> compare_typ t1 t2
|
||||
| TAny, TAny -> 0
|
||||
| TLit _, _ -> -1
|
||||
| _, TLit _ -> 1
|
||||
| TTuple _, _ -> -1
|
||||
| _, TTuple _ -> 1
|
||||
| TStruct _, _ -> -1
|
||||
| _, TStruct _ -> 1
|
||||
| TEnum _, _ -> -1
|
||||
| _, TEnum _ -> 1
|
||||
| TOption _, _ -> -1
|
||||
| _, TOption _ -> 1
|
||||
| TArrow _, _ -> -1
|
||||
| _, TArrow _ -> 1
|
||||
| TArray _, _ -> -1
|
||||
| _, TArray _ -> 1
|
||||
|
||||
let equal_lit (type a) (l1 : a glit) (l2 : a glit) =
|
||||
let open Runtime.Oper in
|
||||
match l1, l2 with
|
||||
| LBool b1, LBool b2 -> Bool.equal b1 b2
|
||||
| LBool b1, LBool b2 -> not (o_xor b1 b2)
|
||||
| LEmptyError, LEmptyError -> true
|
||||
| LInt n1, LInt n2 -> Runtime.( =! ) n1 n2
|
||||
| LRat r1, LRat r2 -> Runtime.( =& ) r1 r2
|
||||
| LMoney m1, LMoney m2 -> Runtime.( =$ ) m1 m2
|
||||
| LInt n1, LInt n2 -> o_eq_int_int n1 n2
|
||||
| LRat r1, LRat r2 -> o_eq_rat_rat r1 r2
|
||||
| LMoney m1, LMoney m2 -> o_eq_mon_mon m1 m2
|
||||
| LUnit, LUnit -> true
|
||||
| LDate d1, LDate d2 -> Runtime.( =@ ) d1 d2
|
||||
| LDuration d1, LDuration d2 -> Runtime.( =^ ) d1 d2
|
||||
| LDate d1, LDate d2 -> o_eq_dat_dat d1 d2
|
||||
| LDuration d1, LDuration d2 -> o_eq_dur_dur d1 d2
|
||||
| ( ( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
|
||||
| LDuration _ ),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
let compare_lit (type a) (l1 : a glit) (l2 : a glit) =
|
||||
let open Runtime.Oper in
|
||||
match l1, l2 with
|
||||
| LBool b1, LBool b2 -> Bool.compare b1 b2
|
||||
| LEmptyError, LEmptyError -> 0
|
||||
| LInt n1, LInt n2 ->
|
||||
if Runtime.( <! ) n1 n2 then -1 else if Runtime.( =! ) n1 n2 then 0 else 1
|
||||
if o_lt_int_int n1 n2 then -1 else if o_eq_int_int n1 n2 then 0 else 1
|
||||
| LRat r1, LRat r2 ->
|
||||
if Runtime.( <& ) r1 r2 then -1 else if Runtime.( =& ) r1 r2 then 0 else 1
|
||||
if o_lt_rat_rat r1 r2 then -1 else if o_eq_rat_rat r1 r2 then 0 else 1
|
||||
| LMoney m1, LMoney m2 ->
|
||||
if Runtime.( <$ ) m1 m2 then -1 else if Runtime.( =$ ) m1 m2 then 0 else 1
|
||||
if o_lt_mon_mon m1 m2 then -1 else if o_eq_mon_mon m1 m2 then 0 else 1
|
||||
| LUnit, LUnit -> 0
|
||||
| LDate d1, LDate d2 ->
|
||||
if Runtime.( <@ ) d1 d2 then -1 else if Runtime.( =@ ) d1 d2 then 0 else 1
|
||||
if o_lt_dat_dat d1 d2 then -1 else if o_eq_dat_dat d1 d2 then 0 else 1
|
||||
| LDuration d1, LDuration d2 -> (
|
||||
(* Duration comparison in the runtime may fail, so rely on a basic
|
||||
lexicographic comparison instead *)
|
||||
@ -441,119 +477,6 @@ let compare_location
|
||||
| _, SubScopeVar _ -> .
|
||||
|
||||
let equal_location a b = compare_location a b = 0
|
||||
|
||||
let equal_log_entries l1 l2 =
|
||||
match l1, l2 with
|
||||
| VarDef t1, VarDef t2 -> equal_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
|
||||
| x, y -> x = y
|
||||
|
||||
let compare_log_entries l1 l2 =
|
||||
match l1, l2 with
|
||||
| VarDef t1, VarDef t2 -> compare_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
|
||||
| BeginCall, BeginCall
|
||||
| EndCall, EndCall
|
||||
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
|
||||
0
|
||||
| VarDef _, _ -> -1
|
||||
| _, VarDef _ -> 1
|
||||
| BeginCall, _ -> -1
|
||||
| _, BeginCall -> 1
|
||||
| EndCall, _ -> -1
|
||||
| _, EndCall -> 1
|
||||
| PosRecordIfTrueBool, _ -> .
|
||||
| _, PosRecordIfTrueBool -> .
|
||||
|
||||
(* let equal_op_kind = Stdlib.(=) *)
|
||||
|
||||
let compare_op_kind = Stdlib.compare
|
||||
|
||||
let equal_unops op1 op2 =
|
||||
match op1, op2 with
|
||||
(* Log entries contain a typ which contain position information, we thus need
|
||||
to descend into them *)
|
||||
| Log (l1, info1), Log (l2, info2) ->
|
||||
equal_log_entries l1 l2 && List.equal Uid.MarkedString.equal info1 info2
|
||||
| Log _, _ | _, Log _ -> false
|
||||
(* All the other cases can be discharged through equality *)
|
||||
| ( ( Not | Minus _ | Length | IntToRat | MoneyToRat | RatToMoney | GetDay
|
||||
| GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | RoundMoney
|
||||
| RoundDecimal ),
|
||||
_ ) ->
|
||||
op1 = op2
|
||||
|
||||
let compare_unops op1 op2 =
|
||||
match op1, op2 with
|
||||
| Not, Not -> 0
|
||||
| Minus k1, Minus k2 -> compare_op_kind k1 k2
|
||||
| Log (l1, info1), Log (l2, info2) -> (
|
||||
match compare_log_entries l1 l2 with
|
||||
| 0 -> List.compare Uid.MarkedString.compare info1 info2
|
||||
| n -> n)
|
||||
| Length, Length
|
||||
| IntToRat, IntToRat
|
||||
| MoneyToRat, MoneyToRat
|
||||
| RatToMoney, RatToMoney
|
||||
| GetDay, GetDay
|
||||
| GetMonth, GetMonth
|
||||
| GetYear, GetYear
|
||||
| FirstDayOfMonth, FirstDayOfMonth
|
||||
| LastDayOfMonth, LastDayOfMonth
|
||||
| RoundMoney, RoundMoney
|
||||
| RoundDecimal, RoundDecimal ->
|
||||
0
|
||||
| Not, _ -> -1
|
||||
| _, Not -> 1
|
||||
| Minus _, _ -> -1
|
||||
| _, Minus _ -> 1
|
||||
| Log _, _ -> -1
|
||||
| _, Log _ -> 1
|
||||
| Length, _ -> -1
|
||||
| _, Length -> 1
|
||||
| IntToRat, _ -> -1
|
||||
| _, IntToRat -> 1
|
||||
| MoneyToRat, _ -> -1
|
||||
| _, MoneyToRat -> 1
|
||||
| RatToMoney, _ -> -1
|
||||
| _, RatToMoney -> 1
|
||||
| GetDay, _ -> -1
|
||||
| _, GetDay -> 1
|
||||
| GetMonth, _ -> -1
|
||||
| _, GetMonth -> 1
|
||||
| GetYear, _ -> -1
|
||||
| _, GetYear -> 1
|
||||
| FirstDayOfMonth, _ -> -1
|
||||
| _, FirstDayOfMonth -> 1
|
||||
| LastDayOfMonth, _ -> -1
|
||||
| _, LastDayOfMonth -> 1
|
||||
| RoundMoney, _ -> -1
|
||||
| _, RoundMoney -> 1
|
||||
| RoundDecimal, _ -> .
|
||||
| _, RoundDecimal -> .
|
||||
|
||||
let equal_binop = Stdlib.( = )
|
||||
let compare_binop = Stdlib.compare
|
||||
let equal_ternop = Stdlib.( = )
|
||||
let compare_ternop = Stdlib.compare
|
||||
|
||||
let equal_ops op1 op2 =
|
||||
match op1, op2 with
|
||||
| Ternop op1, Ternop op2 -> equal_ternop op1 op2
|
||||
| Binop op1, Binop op2 -> equal_binop op1 op2
|
||||
| Unop op1, Unop op2 -> equal_unops op1 op2
|
||||
| _, _ -> false
|
||||
|
||||
let compare_op op1 op2 =
|
||||
match op1, op2 with
|
||||
| Ternop op1, Ternop op2 -> compare_ternop op1 op2
|
||||
| Binop op1, Binop op2 -> compare_binop op1 op2
|
||||
| Unop op1, Unop op2 -> compare_unops op1 op2
|
||||
| Ternop _, _ -> -1
|
||||
| _, Ternop _ -> 1
|
||||
| Binop _, _ -> -1
|
||||
| _, Binop _ -> 1
|
||||
| Unop _, _ -> .
|
||||
| _, Unop _ -> .
|
||||
|
||||
let equal_except ex1 ex2 = ex1 = ex2
|
||||
let compare_except ex1 ex2 = Stdlib.compare ex1 ex2
|
||||
|
||||
@ -567,50 +490,60 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
|
||||
fun e1 e2 ->
|
||||
match Marked.unmark e1, Marked.unmark e2 with
|
||||
| EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2
|
||||
| ETuple (es1, n1), ETuple (es2, n2) -> n1 = n2 && equal_list es1 es2
|
||||
| ETupleAccess (e1, id1, n1, tys1), ETupleAccess (e2, id2, n2, tys2) ->
|
||||
equal e1 e2 && id1 = id2 && n1 = n2 && equal_typ_list tys1 tys2
|
||||
| EInj (e1, id1, n1, tys1), EInj (e2, id2, n2, tys2) ->
|
||||
equal e1 e2 && id1 = id2 && n1 = n2 && equal_typ_list tys1 tys2
|
||||
| EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) ->
|
||||
n1 = n2 && equal e1 e2 && equal_list cases1 cases2
|
||||
| ETuple es1, ETuple es2 -> equal_list es1 es2
|
||||
| ( ETupleAccess { e = e1; index = id1; size = s1 },
|
||||
ETupleAccess { e = e2; index = id2; size = s2 } ) ->
|
||||
s1 = s2 && equal e1 e2 && id1 = id2
|
||||
| EArray es1, EArray es2 -> equal_list es1 es2
|
||||
| ELit l1, ELit l2 -> l1 = l2
|
||||
| EAbs (b1, tys1), EAbs (b2, tys2) ->
|
||||
equal_typ_list tys1 tys2
|
||||
| EAbs { binder = b1; tys = tys1 }, EAbs { binder = b2; tys = tys2 } ->
|
||||
Type.equal_list tys1 tys2
|
||||
&&
|
||||
let vars1, body1 = Bindlib.unmbind b1 in
|
||||
let body2 = Bindlib.msubst b2 (Array.map (fun x -> EVar x) vars1) in
|
||||
equal body1 body2
|
||||
| EApp (e1, args1), EApp (e2, args2) -> equal e1 e2 && equal_list args1 args2
|
||||
| EApp { f = e1; args = args1 }, EApp { f = e2; args = args2 } ->
|
||||
equal e1 e2 && equal_list args1 args2
|
||||
| EAssert e1, EAssert e2 -> equal e1 e2
|
||||
| EOp op1, EOp op2 -> equal_ops op1 op2
|
||||
| EDefault (exc1, def1, cons1), EDefault (exc2, def2, cons2) ->
|
||||
| EOp { op = op1; tys = tys1 }, EOp { op = op2; tys = tys2 } ->
|
||||
Operator.equal op1 op2 && Type.equal_list tys1 tys2
|
||||
| ( EDefault { excepts = exc1; just = def1; cons = cons1 },
|
||||
EDefault { excepts = exc2; just = def2; cons = cons2 } ) ->
|
||||
equal def1 def2 && equal cons1 cons2 && equal_list exc1 exc2
|
||||
| EIfThenElse (if1, then1, else1), EIfThenElse (if2, then2, else2) ->
|
||||
| ( EIfThenElse { cond = if1; etrue = then1; efalse = else1 },
|
||||
EIfThenElse { cond = if2; etrue = then2; efalse = else2 } ) ->
|
||||
equal if1 if2 && equal then1 then2 && equal else1 else2
|
||||
| ErrorOnEmpty e1, ErrorOnEmpty e2 -> equal e1 e2
|
||||
| EErrorOnEmpty e1, EErrorOnEmpty e2 -> equal e1 e2
|
||||
| ERaise ex1, ERaise ex2 -> equal_except ex1 ex2
|
||||
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) ->
|
||||
| ( ECatch { body = etry1; exn = ex1; handler = ewith1 },
|
||||
ECatch { body = etry2; exn = ex2; handler = ewith2 } ) ->
|
||||
equal etry1 etry2 && equal_except ex1 ex2 && equal ewith1 ewith2
|
||||
| ELocation l1, ELocation l2 ->
|
||||
equal_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2)
|
||||
| EStruct (s1, fields1), EStruct (s2, fields2) ->
|
||||
StructName.equal s1 s2 && StructFieldMap.equal equal fields1 fields2
|
||||
| EStructAccess (e1, f1, s1), EStructAccess (e2, f2, s2) ->
|
||||
StructName.equal s1 s2 && StructFieldName.equal f1 f2 && equal e1 e2
|
||||
| EEnumInj (e1, c1, n1), EEnumInj (e2, c2, n2) ->
|
||||
| ( EStruct { name = s1; fields = fields1 },
|
||||
EStruct { name = s2; fields = fields2 } ) ->
|
||||
StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2
|
||||
| ( EDStructAccess { e = e1; field = f1; name_opt = s1 },
|
||||
EDStructAccess { e = e2; field = f2; name_opt = s2 } ) ->
|
||||
Option.equal StructName.equal s1 s2 && IdentName.equal f1 f2 && equal e1 e2
|
||||
| ( EStructAccess { e = e1; field = f1; name = s1 },
|
||||
EStructAccess { e = e2; field = f2; name = s2 } ) ->
|
||||
StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2
|
||||
| EInj { e = e1; cons = c1; name = n1 }, EInj { e = e2; cons = c2; name = n2 }
|
||||
->
|
||||
EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2
|
||||
| EMatchS (e1, n1, cases1), EMatchS (e2, n2, cases2) ->
|
||||
| ( EMatch { e = e1; name = n1; cases = cases1 },
|
||||
EMatch { e = e2; name = n2; cases = cases2 } ) ->
|
||||
EnumName.equal n1 n2
|
||||
&& equal e1 e2
|
||||
&& EnumConstructorMap.equal equal cases1 cases2
|
||||
| EScopeCall (s1, fields1), EScopeCall (s2, fields2) ->
|
||||
ScopeName.equal s1 s2 && ScopeVarMap.equal equal fields1 fields2
|
||||
| ( ( EVar _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | EArray _
|
||||
| ELit _ | EAbs _ | EApp _ | EAssert _ | EOp _ | EDefault _
|
||||
| EIfThenElse _ | ErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _
|
||||
| EStruct _ | EStructAccess _ | EEnumInj _ | EMatchS _ | EScopeCall _ ),
|
||||
&& EnumConstructor.Map.equal equal cases1 cases2
|
||||
| ( EScopeCall { scope = s1; args = fields1 },
|
||||
EScopeCall { scope = s2; args = fields2 } ) ->
|
||||
ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2
|
||||
| ( ( EVar _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _ | EAbs _ | EApp _
|
||||
| EAssert _ | EOp _ | EDefault _ | EIfThenElse _ | EErrorOnEmpty _
|
||||
| ERaise _ | ECatch _ | ELocation _ | EStruct _ | EDStructAccess _
|
||||
| EStructAccess _ | EInj _ | EMatch _ | EScopeCall _ ),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
@ -623,72 +556,76 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
|
||||
match[@ocamlformat "disable"] Marked.unmark e1, Marked.unmark e2 with
|
||||
| ELit l1, ELit l2 ->
|
||||
compare_lit l1 l2
|
||||
| EApp (f1, args1), EApp (f2, args2) ->
|
||||
| EApp {f=f1; args=args1}, EApp {f=f2; args=args2} ->
|
||||
compare f1 f2 @@< fun () ->
|
||||
List.compare compare args1 args2
|
||||
| EOp op1, EOp op2 ->
|
||||
compare_op op1 op2
|
||||
| EOp {op=op1; tys=tys1}, EOp {op=op2; tys=tys2} ->
|
||||
Operator.compare op1 op2 @@< fun () ->
|
||||
List.compare Type.compare tys1 tys2
|
||||
| EArray a1, EArray a2 ->
|
||||
List.compare compare a1 a2
|
||||
| EVar v1, EVar v2 ->
|
||||
Bindlib.compare_vars v1 v2
|
||||
| EAbs (binder1, typs1), EAbs (binder2, typs2) ->
|
||||
List.compare compare_typ typs1 typs2 @@< fun () ->
|
||||
| EAbs {binder=binder1; tys=typs1},
|
||||
EAbs {binder=binder2; tys=typs2} ->
|
||||
List.compare Type.compare typs1 typs2 @@< fun () ->
|
||||
let _, e1, e2 = Bindlib.unmbind2 binder1 binder2 in
|
||||
compare e1 e2
|
||||
| EIfThenElse (i1, t1, e1), EIfThenElse (i2, t2, e2) ->
|
||||
| EIfThenElse {cond=i1; etrue=t1; efalse=e1},
|
||||
EIfThenElse {cond=i2; etrue=t2; efalse=e2} ->
|
||||
compare i1 i2 @@< fun () ->
|
||||
compare t1 t2 @@< fun () ->
|
||||
compare e1 e2
|
||||
| ELocation l1, ELocation l2 ->
|
||||
compare_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2)
|
||||
| EStruct (name1, field_map1), EStruct (name2, field_map2) ->
|
||||
| EStruct {name=name1; fields=field_map1},
|
||||
EStruct {name=name2; fields=field_map2} ->
|
||||
StructName.compare name1 name2 @@< fun () ->
|
||||
StructFieldMap.compare compare field_map1 field_map2
|
||||
| EStructAccess (e1, field_name1, struct_name1),
|
||||
EStructAccess (e2, field_name2, struct_name2) ->
|
||||
StructField.Map.compare compare field_map1 field_map2
|
||||
| EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1},
|
||||
EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2} ->
|
||||
compare e1 e2 @@< fun () ->
|
||||
StructFieldName.compare field_name1 field_name2 @@< fun () ->
|
||||
IdentName.compare field_name1 field_name2 @@< fun () ->
|
||||
Option.compare StructName.compare struct_name1 struct_name2
|
||||
| EStructAccess {e=e1; field=field_name1; name=struct_name1},
|
||||
EStructAccess {e=e2; field=field_name2; name=struct_name2} ->
|
||||
compare e1 e2 @@< fun () ->
|
||||
StructField.compare field_name1 field_name2 @@< fun () ->
|
||||
StructName.compare struct_name1 struct_name2
|
||||
| EEnumInj (e1, cstr1, name1), EEnumInj (e2, cstr2, name2) ->
|
||||
compare e1 e2 @@< fun () ->
|
||||
| EMatch {e=e1; name=name1; cases=emap1},
|
||||
EMatch {e=e2; name=name2; cases=emap2} ->
|
||||
EnumName.compare name1 name2 @@< fun () ->
|
||||
EnumConstructor.compare cstr1 cstr2
|
||||
| EMatchS (e1, name1, emap1), EMatchS (e2, name2, emap2) ->
|
||||
compare e1 e2 @@< fun () ->
|
||||
EnumName.compare name1 name2 @@< fun () ->
|
||||
EnumConstructorMap.compare compare emap1 emap2
|
||||
| EScopeCall (name1, field_map1), EScopeCall (name2, field_map2) ->
|
||||
EnumConstructor.Map.compare compare emap1 emap2
|
||||
| EScopeCall {scope=name1; args=field_map1},
|
||||
EScopeCall {scope=name2; args=field_map2} ->
|
||||
ScopeName.compare name1 name2 @@< fun () ->
|
||||
ScopeVarMap.compare compare field_map1 field_map2
|
||||
| ETuple (es1, s1), ETuple (es2, s2) ->
|
||||
Option.compare StructName.compare s1 s2 @@< fun () ->
|
||||
ScopeVar.Map.compare compare field_map1 field_map2
|
||||
| ETuple es1, ETuple es2 ->
|
||||
List.compare compare es1 es2
|
||||
| ETupleAccess (e1, n1, s1, tys1), ETupleAccess (e2, n2, s2, tys2) ->
|
||||
Option.compare StructName.compare s1 s2 @@< fun () ->
|
||||
| ETupleAccess {e=e1; index=n1; size=s1},
|
||||
ETupleAccess {e=e2; index=n2; size=s2} ->
|
||||
Int.compare s1 s2 @@< fun () ->
|
||||
Int.compare n1 n2 @@< fun () ->
|
||||
List.compare compare_typ tys1 tys2 @@< fun () ->
|
||||
compare e1 e2
|
||||
| EInj (e1, n1, name1, ts1), EInj (e2, n2, name2, ts2) ->
|
||||
| EInj {e=e1; name=name1; cons=cons1},
|
||||
EInj {e=e2; name=name2; cons=cons2} ->
|
||||
EnumName.compare name1 name2 @@< fun () ->
|
||||
Int.compare n1 n2 @@< fun () ->
|
||||
List.compare compare_typ ts1 ts2 @@< fun () ->
|
||||
EnumConstructor.compare cons1 cons2 @@< fun () ->
|
||||
compare e1 e2
|
||||
| EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) ->
|
||||
EnumName.compare n1 n2 @@< fun () ->
|
||||
compare e1 e2 @@< fun () ->
|
||||
List.compare compare cases1 cases2
|
||||
| EAssert e1, EAssert e2 ->
|
||||
compare e1 e2
|
||||
| EDefault (exs1, just1, cons1), EDefault (exs2, just2, cons2) ->
|
||||
| EDefault {excepts=exs1; just=just1; cons=cons1},
|
||||
EDefault {excepts=exs2; just=just2; cons=cons2} ->
|
||||
compare just1 just2 @@< fun () ->
|
||||
compare cons1 cons2 @@< fun () ->
|
||||
List.compare compare exs1 exs2
|
||||
| ErrorOnEmpty e1, ErrorOnEmpty e2 ->
|
||||
| EErrorOnEmpty e1, EErrorOnEmpty e2 ->
|
||||
compare e1 e2
|
||||
| ERaise ex1, ERaise ex2 ->
|
||||
compare_except ex1 ex2
|
||||
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) ->
|
||||
| ECatch {body=etry1; exn=ex1; handler=ewith1},
|
||||
ECatch {body=etry2; exn=ex2; handler=ewith2} ->
|
||||
compare_except ex1 ex2 @@< fun () ->
|
||||
compare etry1 etry2 @@< fun () ->
|
||||
compare ewith1 ewith2
|
||||
@ -701,34 +638,33 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
|
||||
| EIfThenElse _, _ -> -1 | _, EIfThenElse _ -> 1
|
||||
| ELocation _, _ -> -1 | _, ELocation _ -> 1
|
||||
| EStruct _, _ -> -1 | _, EStruct _ -> 1
|
||||
| EDStructAccess _, _ -> -1 | _, EDStructAccess _ -> 1
|
||||
| EStructAccess _, _ -> -1 | _, EStructAccess _ -> 1
|
||||
| EEnumInj _, _ -> -1 | _, EEnumInj _ -> 1
|
||||
| EMatchS _, _ -> -1 | _, EMatchS _ -> 1
|
||||
| EMatch _, _ -> -1 | _, EMatch _ -> 1
|
||||
| EScopeCall _, _ -> -1 | _, EScopeCall _ -> 1
|
||||
| ETuple _, _ -> -1 | _, ETuple _ -> 1
|
||||
| ETupleAccess _, _ -> -1 | _, ETupleAccess _ -> 1
|
||||
| EInj _, _ -> -1 | _, EInj _ -> 1
|
||||
| EMatch _, _ -> -1 | _, EMatch _ -> 1
|
||||
| EAssert _, _ -> -1 | _, EAssert _ -> 1
|
||||
| EDefault _, _ -> -1 | _, EDefault _ -> 1
|
||||
| ErrorOnEmpty _, _ -> . | _, ErrorOnEmpty _ -> .
|
||||
| EErrorOnEmpty _, _ -> . | _, EErrorOnEmpty _ -> .
|
||||
| ERaise _, _ -> -1 | _, ERaise _ -> 1
|
||||
| ECatch _, _ -> . | _, ECatch _ -> .
|
||||
|
||||
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function
|
||||
| EVar v, _ -> Var.Set.singleton v
|
||||
| EAbs (binder, _), _ ->
|
||||
| EAbs { binder; _ }, _ ->
|
||||
let vs, body = Bindlib.unmbind binder in
|
||||
Array.fold_right Var.Set.remove vs (free_vars body)
|
||||
| e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
|
||||
|
||||
let remove_logging_calls e =
|
||||
let rec f () e =
|
||||
let rec f e =
|
||||
match Marked.unmark e with
|
||||
| EApp ((EOp (Unop (Log _)), _), [arg]) -> map () ~f arg
|
||||
| _ -> map () ~f e
|
||||
| EApp { f = EOp { op = Log _; _ }, _; args = [arg] } -> map ~f arg
|
||||
| _ -> map ~f e
|
||||
in
|
||||
f () e
|
||||
f e
|
||||
|
||||
let format ?debug decl_ctx ppf e = Print.expr ?debug decl_ctx ppf e
|
||||
|
||||
@ -736,36 +672,35 @@ let rec size : type a. (a, 't) gexpr -> int =
|
||||
fun e ->
|
||||
match Marked.unmark e with
|
||||
| EVar _ | ELit _ | EOp _ -> 1
|
||||
| ETuple (args, _) -> List.fold_left (fun acc arg -> acc + size arg) 1 args
|
||||
| ETuple args -> List.fold_left (fun acc arg -> acc + size arg) 1 args
|
||||
| EArray args -> List.fold_left (fun acc arg -> acc + size arg) 1 args
|
||||
| ETupleAccess (e1, _, _, _) -> size e1 + 1
|
||||
| EInj (e1, _, _, _) -> size e1 + 1
|
||||
| EAssert e1 -> size e1 + 1
|
||||
| ErrorOnEmpty e1 -> size e1 + 1
|
||||
| EMatch (arg, args, _) ->
|
||||
List.fold_left (fun acc arg -> acc + size arg) (1 + size arg) args
|
||||
| EApp (arg, args) ->
|
||||
List.fold_left (fun acc arg -> acc + size arg) (1 + size arg) args
|
||||
| EAbs (binder, _) ->
|
||||
| ETupleAccess { e; _ } -> size e + 1
|
||||
| EInj { e; _ } -> size e + 1
|
||||
| EAssert e -> size e + 1
|
||||
| EErrorOnEmpty e -> size e + 1
|
||||
| EApp { f; args } ->
|
||||
List.fold_left (fun acc arg -> acc + size arg) (1 + size f) args
|
||||
| EAbs { binder; _ } ->
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
1 + size body
|
||||
| EIfThenElse (e1, e2, e3) -> 1 + size e1 + size e2 + size e3
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
1 + size cond + size etrue + size efalse
|
||||
| EDefault { excepts; just; cons } ->
|
||||
List.fold_left
|
||||
(fun acc except -> acc + size except)
|
||||
(1 + size just + size cons)
|
||||
exceptions
|
||||
excepts
|
||||
| ERaise _ -> 1
|
||||
| ECatch (etry, _, ewith) -> 1 + size etry + size ewith
|
||||
| ECatch { body; handler; _ } -> 1 + size body + size handler
|
||||
| ELocation _ -> 1
|
||||
| EStruct (_, fields) ->
|
||||
StructFieldMap.fold (fun _ e acc -> acc + 1 + size e) fields 0
|
||||
| EStructAccess (e1, _, _) -> 1 + size e1
|
||||
| EEnumInj (e1, _, _) -> 1 + size e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
EnumConstructorMap.fold (fun _ e acc -> acc + 1 + size e) cases (size e1)
|
||||
| EScopeCall (_, fields) ->
|
||||
ScopeVarMap.fold (fun _ e acc -> acc + 1 + size e) fields 1
|
||||
| EStruct { fields; _ } ->
|
||||
StructField.Map.fold (fun _ e acc -> acc + 1 + size e) fields 0
|
||||
| EDStructAccess { e; _ } -> 1 + size e
|
||||
| EStructAccess { e; _ } -> 1 + size e
|
||||
| EMatch { e; cases; _ } ->
|
||||
EnumConstructor.Map.fold (fun _ e acc -> acc + 1 + size e) cases (size e)
|
||||
| EScopeCall { args; _ } ->
|
||||
ScopeVar.Map.fold (fun _ e acc -> acc + 1 + size e) args 1
|
||||
|
||||
(* - Expression building helpers - *)
|
||||
|
||||
@ -794,7 +729,7 @@ let make_app e u pos =
|
||||
(fun tf tx ->
|
||||
match Marked.unmark tf with
|
||||
| TArrow (tx', tr) ->
|
||||
assert (unifiable tx.ty tx');
|
||||
assert (Type.unifiable tx.ty tx');
|
||||
(* wrong arg type *)
|
||||
tr
|
||||
| TAny -> tf
|
||||
@ -818,50 +753,35 @@ let make_let_in x tau e1 e2 mpos =
|
||||
let make_multiple_let_in xs taus e1s e2 mpos =
|
||||
make_app (make_abs xs e2 taus mpos) e1s (pos e2)
|
||||
|
||||
let make_default_unboxed exceptions just cons =
|
||||
let make_default_unboxed excepts just cons =
|
||||
let rec bool_value = function
|
||||
| ELit (LBool b), _ -> Some b
|
||||
| EApp ((EOp (Unop (Log (l, _))), _), [e]), _
|
||||
| EApp { f = EOp { op = Log (l, _); _ }, _; args = [e]; _ }, _
|
||||
when l <> PosRecordIfTrueBool
|
||||
(* we don't remove the log calls corresponding to source code
|
||||
definitions !*) ->
|
||||
bool_value e
|
||||
| _ -> None
|
||||
in
|
||||
match exceptions, bool_value just, cons with
|
||||
match excepts, bool_value just, cons with
|
||||
| [], Some true, cons -> Marked.unmark cons
|
||||
| exceptions, Some true, (EDefault ([], just, cons), _) ->
|
||||
EDefault (exceptions, just, cons)
|
||||
| excepts, Some true, (EDefault { excepts = []; just; cons }, _) ->
|
||||
EDefault { excepts; just; cons }
|
||||
| [except], Some false, _ -> Marked.unmark except
|
||||
| exceptions, _, cons -> EDefault (exceptions, just, cons)
|
||||
| excepts, _, cons -> EDefault { excepts; just; cons }
|
||||
|
||||
let make_default exceptions just cons =
|
||||
Box.app2n just cons exceptions
|
||||
@@ fun just cons exceptions -> make_default_unboxed exceptions just cons
|
||||
|
||||
let make_tuple el structname m0 =
|
||||
let make_tuple el m0 =
|
||||
match el with
|
||||
| [] ->
|
||||
etuple [] structname
|
||||
(with_ty m0
|
||||
(match structname with
|
||||
| Some n -> TStruct n, mark_pos m0
|
||||
| None -> TTuple [], mark_pos m0))
|
||||
| [] -> etuple [] (with_ty m0 (TTuple [], mark_pos m0))
|
||||
| el ->
|
||||
let m =
|
||||
fold_marks
|
||||
(fun posl -> List.hd posl)
|
||||
(fun ml ->
|
||||
let pos = (List.hd ml).pos in
|
||||
match structname with
|
||||
| Some n -> TStruct n, pos
|
||||
| None -> TTuple (List.map (fun t -> t.ty) ml), pos)
|
||||
(fun ml -> TTuple (List.map (fun t -> t.ty) ml), (List.hd ml).pos)
|
||||
(List.map (fun e -> Marked.get_mark e) el)
|
||||
in
|
||||
etuple el structname m
|
||||
|
||||
let make_struct fieldmap structname m =
|
||||
let fields =
|
||||
List.rev (StructFieldMap.fold (fun _ e acc -> e :: acc) fieldmap [])
|
||||
in
|
||||
make_tuple fields (Some structname) m
|
||||
etuple el m
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
(** Functions handling the expressions of [shared_ast] *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
(** {2 Boxed constructors} *)
|
||||
@ -43,34 +43,10 @@ val subst :
|
||||
('a, 't) gexpr list ->
|
||||
('a, 't) gexpr
|
||||
|
||||
val etuple :
|
||||
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr list ->
|
||||
StructName.t option ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
val etuple : (lcalc, 't) boxed_gexpr list -> 't -> (lcalc, 't) boxed_gexpr
|
||||
|
||||
val etupleaccess :
|
||||
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
|
||||
int ->
|
||||
StructName.t option ->
|
||||
typ list ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
val einj :
|
||||
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
|
||||
int ->
|
||||
EnumName.t ->
|
||||
typ list ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
val ematch :
|
||||
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
|
||||
('a, 't) boxed_gexpr list ->
|
||||
EnumName.t ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
(lcalc, 't) boxed_gexpr -> int -> int -> 't -> (lcalc, 't) boxed_gexpr
|
||||
|
||||
val earray : ('a any, 't) boxed_gexpr list -> 't -> ('a, 't) boxed_gexpr
|
||||
val elit : 'a any glit -> 't -> ('a, 't) boxed_gexpr
|
||||
@ -90,7 +66,7 @@ val eapp :
|
||||
val eassert :
|
||||
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr -> 't -> ('a, 't) boxed_gexpr
|
||||
|
||||
val eop : operator -> 't -> (_ any, 't) boxed_gexpr
|
||||
val eop : ('a any, 'k) operator -> typ list -> 't -> ('a, 't) boxed_gexpr
|
||||
|
||||
val edefault :
|
||||
(([< desugared | scopelang | dcalc ] as 'a), 't) boxed_gexpr list ->
|
||||
@ -125,34 +101,41 @@ val elocation :
|
||||
|
||||
val estruct :
|
||||
StructName.t ->
|
||||
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr StructFieldMap.t ->
|
||||
('a any, 't) boxed_gexpr StructField.Map.t ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
val edstructaccess :
|
||||
(desugared, 't) boxed_gexpr ->
|
||||
IdentName.t ->
|
||||
StructName.t option ->
|
||||
't ->
|
||||
(desugared, 't) boxed_gexpr
|
||||
|
||||
val estructaccess :
|
||||
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ->
|
||||
StructFieldName.t ->
|
||||
(([< scopelang | dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
|
||||
StructField.t ->
|
||||
StructName.t ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
val eenuminj :
|
||||
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ->
|
||||
val einj :
|
||||
('a any, 't) boxed_gexpr ->
|
||||
EnumConstructor.t ->
|
||||
EnumName.t ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
val ematchs :
|
||||
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ->
|
||||
val ematch :
|
||||
('a any, 't) boxed_gexpr ->
|
||||
EnumName.t ->
|
||||
('a, 't) boxed_gexpr EnumConstructorMap.t ->
|
||||
('a, 't) boxed_gexpr EnumConstructor.Map.t ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
val escopecall :
|
||||
ScopeName.t ->
|
||||
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVarMap.t ->
|
||||
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVar.Map.t ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
|
||||
@ -194,28 +177,25 @@ val untype : ('a, 'm mark) gexpr -> ('a, untyped mark) boxed_gexpr
|
||||
(** {2 Traversal functions} *)
|
||||
|
||||
val map :
|
||||
'ctx ->
|
||||
f:('ctx -> ('a, 't1) gexpr -> ('a, 't2) boxed_gexpr) ->
|
||||
f:(('a, 't1) gexpr -> ('a, 't2) boxed_gexpr) ->
|
||||
(('a, 't1) naked_gexpr, 't2) Marked.t ->
|
||||
('a, 't2) boxed_gexpr
|
||||
(** Flat (non-recursive) mapping on expressions.
|
||||
(** Shallow mapping on expressions (non recursive): applies the given function
|
||||
to all sub-terms of the given expression, and rebuilds the node.
|
||||
|
||||
If you want to apply a map transform to an expression, you can save up
|
||||
writing a painful match over all the cases of the AST. For instance, if you
|
||||
want to remove all errors on empty, you can write
|
||||
When applying a map transform to an expression, this avoids expliciting all
|
||||
cases that remain unchanged. For instance, if you want to remove all errors
|
||||
on empty, you can write
|
||||
|
||||
{[
|
||||
let remove_error_empty =
|
||||
let rec f () e =
|
||||
let rec f e =
|
||||
match Marked.unmark e with
|
||||
| ErrorOnEmpty e1 -> Expr.map () f e1
|
||||
| _ -> Expr.map () f e
|
||||
| ErrorOnEmpty e1 -> Expr.map f e1
|
||||
| _ -> Expr.map f e
|
||||
in
|
||||
f () e
|
||||
]}
|
||||
|
||||
The first argument of map_expr is an optional context that you can carry
|
||||
around during your map traversal. *)
|
||||
f e
|
||||
]} *)
|
||||
|
||||
val map_top_down :
|
||||
f:(('a, 't1) gexpr -> (('a, 't1) naked_gexpr, 't2) Marked.t) ->
|
||||
@ -231,7 +211,42 @@ val shallow_fold :
|
||||
(('a, 't) gexpr -> 'acc -> 'acc) -> ('a, 't) gexpr -> 'acc -> 'acc
|
||||
(** Applies a function on all sub-terms of the given expression. Does not
|
||||
recurse, and doesn't open binders. Useful as helper for recursive calls
|
||||
within traversal functions *)
|
||||
within traversal functions. This can be used to compute free variables with
|
||||
e.g.:
|
||||
|
||||
{[
|
||||
let rec free_vars = function
|
||||
| EVar v, _ -> Var.Set.singleton v
|
||||
| EAbs { binder; _ }, _ ->
|
||||
let vs, body = Bindlib.unmbind binder in
|
||||
Array.fold_right Var.Set.remove vs (free_vars body)
|
||||
| e ->
|
||||
shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
|
||||
]} *)
|
||||
|
||||
val map_gather :
|
||||
acc:'acc ->
|
||||
join:('acc -> 'acc -> 'acc) ->
|
||||
f:(('a, 't1) gexpr -> 'acc * ('a, 't2) boxed_gexpr) ->
|
||||
(('a, 't1) naked_gexpr, 't2) Marked.t ->
|
||||
'acc * ('a, 't2) boxed_gexpr
|
||||
(** Shallow mapping similar to [map], but additionally allows to gather an
|
||||
accumulator bottom-up. [acc] is the accumulator value returned on terminal
|
||||
nodes, and [join] is used to merge accumulators from the different sub-terms
|
||||
of an expression. [acc] is assumed to be a neutral element for [join].
|
||||
Typically used with a set of variables used in the rewrite:
|
||||
|
||||
{[
|
||||
let rec rewrite e =
|
||||
match Marked.unmark e with
|
||||
| Specific_case ->
|
||||
Var.Set.singleton x, some_rewrite_fun e
|
||||
| _ ->
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:rewrite e
|
||||
}]
|
||||
|
||||
|
||||
See [Lcalc.closure_conversion] for a real-world example. *)
|
||||
|
||||
(** {2 Expression building helpers} *)
|
||||
|
||||
@ -289,21 +304,10 @@ val make_default :
|
||||
- [<ex | false :- _>], when [ex] is a single exception, is rewritten as [ex] *)
|
||||
|
||||
val make_tuple :
|
||||
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr list ->
|
||||
StructName.t option ->
|
||||
'm mark ->
|
||||
('a, 'm mark) boxed_gexpr
|
||||
(lcalc, 'm mark) boxed_gexpr list -> 'm mark -> (lcalc, 'm mark) boxed_gexpr
|
||||
(** Builds a tuple; the mark argument is only used as witness and for position
|
||||
when building 0-uples *)
|
||||
|
||||
val make_struct :
|
||||
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr StructFieldMap.t ->
|
||||
StructName.t ->
|
||||
'm mark ->
|
||||
('a, 'm mark) boxed_gexpr
|
||||
(** Builds the tuple of values for the given struct with proper ordering,
|
||||
assuming the structfieldmap contains the fields defined for structname *)
|
||||
|
||||
(** {2 Transformations} *)
|
||||
|
||||
val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr
|
||||
@ -331,8 +335,6 @@ val compare : ('a, 't) gexpr -> ('a, 't) gexpr -> int
|
||||
(** Standard comparison function, suitable for e.g. [Set.Make]. Ignores position
|
||||
information *)
|
||||
|
||||
val equal_typ : typ -> typ -> bool
|
||||
val compare_typ : typ -> typ -> int
|
||||
val is_value : ('a any, 't) gexpr -> bool
|
||||
val free_vars : ('a any, 't) gexpr -> ('a, 't) gexpr Var.Set.t
|
||||
|
||||
@ -363,10 +365,10 @@ module Box : sig
|
||||
a separate argument. *)
|
||||
|
||||
val app1 :
|
||||
('a, 't) boxed_gexpr ->
|
||||
(('a, 't) gexpr -> ('a, 't) naked_gexpr) ->
|
||||
't ->
|
||||
('a, 't) boxed_gexpr
|
||||
('a, 't1) boxed_gexpr ->
|
||||
(('a, 't1) gexpr -> ('a, 't2) naked_gexpr) ->
|
||||
't2 ->
|
||||
('a, 't2) boxed_gexpr
|
||||
|
||||
val app2 :
|
||||
('a, 't) boxed_gexpr ->
|
||||
|
582
compiler/shared_ast/operator.ml
Normal file
582
compiler/shared_ast/operator.ml
Normal file
@ -0,0 +1,582 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
include Definitions.Op
|
||||
|
||||
let name : type a k. (a, k) t -> string = function
|
||||
| Not -> "o_not"
|
||||
| Length -> "o_length"
|
||||
| GetDay -> "o_getDay"
|
||||
| GetMonth -> "o_getMonth"
|
||||
| GetYear -> "o_getYear"
|
||||
| FirstDayOfMonth -> "o_firstDayOfMonth"
|
||||
| LastDayOfMonth -> "o_lastDayOfMonth"
|
||||
| Log _ -> "o_log"
|
||||
| Minus -> "o_minus"
|
||||
| Minus_int -> "o_minus_int"
|
||||
| Minus_rat -> "o_minus_rat"
|
||||
| Minus_mon -> "o_minus_mon"
|
||||
| Minus_dur -> "o_minus_dur"
|
||||
| ToRat -> "o_torat"
|
||||
| ToRat_int -> "o_torat_int"
|
||||
| ToRat_mon -> "o_torat_mon"
|
||||
| ToMoney -> "o_tomoney"
|
||||
| ToMoney_rat -> "o_tomoney_rat"
|
||||
| Round -> "o_round"
|
||||
| Round_rat -> "o_round_rat"
|
||||
| Round_mon -> "o_round_mon"
|
||||
| And -> "o_and"
|
||||
| Or -> "o_or"
|
||||
| Xor -> "o_xor"
|
||||
| Eq -> "o_eq"
|
||||
| Map -> "o_map"
|
||||
| Concat -> "o_concat"
|
||||
| Filter -> "o_filter"
|
||||
| Reduce -> "o_reduce"
|
||||
| Add -> "o_add"
|
||||
| Add_int_int -> "o_add_int_int"
|
||||
| Add_rat_rat -> "o_add_rat_rat"
|
||||
| Add_mon_mon -> "o_add_mon_mon"
|
||||
| Add_dat_dur -> "o_add_dat_dur"
|
||||
| Add_dur_dur -> "o_add_dur_dur"
|
||||
| Sub -> "o_sub"
|
||||
| Sub_int_int -> "o_sub_int_int"
|
||||
| Sub_rat_rat -> "o_sub_rat_rat"
|
||||
| Sub_mon_mon -> "o_sub_mon_mon"
|
||||
| Sub_dat_dat -> "o_sub_dat_dat"
|
||||
| Sub_dat_dur -> "o_sub_dat_dur"
|
||||
| Sub_dur_dur -> "o_sub_dur_dur"
|
||||
| Mult -> "o_mult"
|
||||
| Mult_int_int -> "o_mult_int_int"
|
||||
| Mult_rat_rat -> "o_mult_rat_rat"
|
||||
| Mult_mon_rat -> "o_mult_mon_rat"
|
||||
| Mult_dur_int -> "o_mult_dur_int"
|
||||
| Div -> "o_div"
|
||||
| Div_int_int -> "o_div_int_int"
|
||||
| Div_rat_rat -> "o_div_rat_rat"
|
||||
| Div_mon_mon -> "o_div_mon_mon"
|
||||
| Div_mon_rat -> "o_div_mon_mon"
|
||||
| Lt -> "o_lt"
|
||||
| Lt_int_int -> "o_lt_int_int"
|
||||
| Lt_rat_rat -> "o_lt_rat_rat"
|
||||
| Lt_mon_mon -> "o_lt_mon_mon"
|
||||
| Lt_dur_dur -> "o_lt_dur_dur"
|
||||
| Lt_dat_dat -> "o_lt_dat_dat"
|
||||
| Lte -> "o_lte"
|
||||
| Lte_int_int -> "o_lte_int_int"
|
||||
| Lte_rat_rat -> "o_lte_rat_rat"
|
||||
| Lte_mon_mon -> "o_lte_mon_mon"
|
||||
| Lte_dur_dur -> "o_lte_dur_dur"
|
||||
| Lte_dat_dat -> "o_lte_dat_dat"
|
||||
| Gt -> "o_gt"
|
||||
| Gt_int_int -> "o_gt_int_int"
|
||||
| Gt_rat_rat -> "o_gt_rat_rat"
|
||||
| Gt_mon_mon -> "o_gt_mon_mon"
|
||||
| Gt_dur_dur -> "o_gt_dur_dur"
|
||||
| Gt_dat_dat -> "o_gt_dat_dat"
|
||||
| Gte -> "o_gte"
|
||||
| Gte_int_int -> "o_gte_int_int"
|
||||
| Gte_rat_rat -> "o_gte_rat_rat"
|
||||
| Gte_mon_mon -> "o_gte_mon_mon"
|
||||
| Gte_dur_dur -> "o_gte_dur_dur"
|
||||
| Gte_dat_dat -> "o_gte_dat_dat"
|
||||
| Eq_int_int -> "o_eq_int_int"
|
||||
| Eq_rat_rat -> "o_eq_rat_rat"
|
||||
| Eq_mon_mon -> "o_eq_mon_mon"
|
||||
| Eq_dur_dur -> "o_eq_dur_dur"
|
||||
| Eq_dat_dat -> "o_eq_dat_dat"
|
||||
| Fold -> "o_fold"
|
||||
|
||||
let compare_log_entries l1 l2 =
|
||||
match l1, l2 with
|
||||
| VarDef t1, VarDef t2 -> Type.compare (t1, Pos.no_pos) (t2, Pos.no_pos)
|
||||
| BeginCall, BeginCall
|
||||
| EndCall, EndCall
|
||||
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
|
||||
0
|
||||
| VarDef _, _ -> -1
|
||||
| _, VarDef _ -> 1
|
||||
| BeginCall, _ -> -1
|
||||
| _, BeginCall -> 1
|
||||
| EndCall, _ -> -1
|
||||
| _, EndCall -> 1
|
||||
| PosRecordIfTrueBool, _ -> .
|
||||
| _, PosRecordIfTrueBool -> .
|
||||
|
||||
let compare (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) =
|
||||
match[@ocamlformat "disable"] t1, t2 with
|
||||
| Log (l1, info1), Log (l2, info2) -> (
|
||||
match compare_log_entries l1 l2 with
|
||||
| 0 -> List.compare Uid.MarkedString.compare info1 info2
|
||||
| n -> n)
|
||||
| Not, Not
|
||||
| Length, Length
|
||||
| GetDay, GetDay
|
||||
| GetMonth, GetMonth
|
||||
| GetYear, GetYear
|
||||
| FirstDayOfMonth, FirstDayOfMonth
|
||||
| LastDayOfMonth, LastDayOfMonth
|
||||
| Minus, Minus
|
||||
| Minus_int, Minus_int
|
||||
| Minus_rat, Minus_rat
|
||||
| Minus_mon, Minus_mon
|
||||
| Minus_dur, Minus_dur
|
||||
| ToRat, ToRat
|
||||
| ToRat_int, ToRat_int
|
||||
| ToRat_mon, ToRat_mon
|
||||
| ToMoney, ToMoney
|
||||
| ToMoney_rat, ToMoney_rat
|
||||
| Round, Round
|
||||
| Round_rat, Round_rat
|
||||
| Round_mon, Round_mon
|
||||
| And, And
|
||||
| Or, Or
|
||||
| Xor, Xor
|
||||
| Eq, Eq
|
||||
| Map, Map
|
||||
| Concat, Concat
|
||||
| Filter, Filter
|
||||
| Reduce, Reduce
|
||||
| Add, Add
|
||||
| Add_int_int, Add_int_int
|
||||
| Add_rat_rat, Add_rat_rat
|
||||
| Add_mon_mon, Add_mon_mon
|
||||
| Add_dat_dur, Add_dat_dur
|
||||
| Add_dur_dur, Add_dur_dur
|
||||
| Sub, Sub
|
||||
| Sub_int_int, Sub_int_int
|
||||
| Sub_rat_rat, Sub_rat_rat
|
||||
| Sub_mon_mon, Sub_mon_mon
|
||||
| Sub_dat_dat, Sub_dat_dat
|
||||
| Sub_dat_dur, Sub_dat_dur
|
||||
| Sub_dur_dur, Sub_dur_dur
|
||||
| Mult, Mult
|
||||
| Mult_int_int, Mult_int_int
|
||||
| Mult_rat_rat, Mult_rat_rat
|
||||
| Mult_mon_rat, Mult_mon_rat
|
||||
| Mult_dur_int, Mult_dur_int
|
||||
| Div, Div
|
||||
| Div_int_int, Div_int_int
|
||||
| Div_rat_rat, Div_rat_rat
|
||||
| Div_mon_mon, Div_mon_mon
|
||||
| Div_mon_rat, Div_mon_rat
|
||||
| Lt, Lt
|
||||
| Lt_int_int, Lt_int_int
|
||||
| Lt_rat_rat, Lt_rat_rat
|
||||
| Lt_mon_mon, Lt_mon_mon
|
||||
| Lt_dat_dat, Lt_dat_dat
|
||||
| Lt_dur_dur, Lt_dur_dur
|
||||
| Lte, Lte
|
||||
| Lte_int_int, Lte_int_int
|
||||
| Lte_rat_rat, Lte_rat_rat
|
||||
| Lte_mon_mon, Lte_mon_mon
|
||||
| Lte_dat_dat, Lte_dat_dat
|
||||
| Lte_dur_dur, Lte_dur_dur
|
||||
| Gt, Gt
|
||||
| Gt_int_int, Gt_int_int
|
||||
| Gt_rat_rat, Gt_rat_rat
|
||||
| Gt_mon_mon, Gt_mon_mon
|
||||
| Gt_dat_dat, Gt_dat_dat
|
||||
| Gt_dur_dur, Gt_dur_dur
|
||||
| Gte, Gte
|
||||
| Gte_int_int, Gte_int_int
|
||||
| Gte_rat_rat, Gte_rat_rat
|
||||
| Gte_mon_mon, Gte_mon_mon
|
||||
| Gte_dat_dat, Gte_dat_dat
|
||||
| Gte_dur_dur, Gte_dur_dur
|
||||
| Eq_int_int, Eq_int_int
|
||||
| Eq_rat_rat, Eq_rat_rat
|
||||
| Eq_mon_mon, Eq_mon_mon
|
||||
| Eq_dat_dat, Eq_dat_dat
|
||||
| Eq_dur_dur, Eq_dur_dur
|
||||
| Fold, Fold -> 0
|
||||
| Not, _ -> -1 | _, Not -> 1
|
||||
| Length, _ -> -1 | _, Length -> 1
|
||||
| GetDay, _ -> -1 | _, GetDay -> 1
|
||||
| GetMonth, _ -> -1 | _, GetMonth -> 1
|
||||
| GetYear, _ -> -1 | _, GetYear -> 1
|
||||
| FirstDayOfMonth, _ -> -1 | _, FirstDayOfMonth -> 1
|
||||
| LastDayOfMonth, _ -> -1 | _, LastDayOfMonth -> 1
|
||||
| Log _, _ -> -1 | _, Log _ -> 1
|
||||
| Minus, _ -> -1 | _, Minus -> 1
|
||||
| Minus_int, _ -> -1 | _, Minus_int -> 1
|
||||
| Minus_rat, _ -> -1 | _, Minus_rat -> 1
|
||||
| Minus_mon, _ -> -1 | _, Minus_mon -> 1
|
||||
| Minus_dur, _ -> -1 | _, Minus_dur -> 1
|
||||
| ToRat, _ -> -1 | _, ToRat -> 1
|
||||
| ToRat_int, _ -> -1 | _, ToRat_int -> 1
|
||||
| ToRat_mon, _ -> -1 | _, ToRat_mon -> 1
|
||||
| ToMoney, _ -> -1 | _, ToMoney -> 1
|
||||
| ToMoney_rat, _ -> -1 | _, ToMoney_rat -> 1
|
||||
| Round, _ -> -1 | _, Round -> 1
|
||||
| Round_rat, _ -> -1 | _, Round_rat -> 1
|
||||
| Round_mon, _ -> -1 | _, Round_mon -> 1
|
||||
| And, _ -> -1 | _, And -> 1
|
||||
| Or, _ -> -1 | _, Or -> 1
|
||||
| Xor, _ -> -1 | _, Xor -> 1
|
||||
| Eq, _ -> -1 | _, Eq -> 1
|
||||
| Map, _ -> -1 | _, Map -> 1
|
||||
| Concat, _ -> -1 | _, Concat -> 1
|
||||
| Filter, _ -> -1 | _, Filter -> 1
|
||||
| Reduce, _ -> -1 | _, Reduce -> 1
|
||||
| Add, _ -> -1 | _, Add -> 1
|
||||
| Add_int_int, _ -> -1 | _, Add_int_int -> 1
|
||||
| Add_rat_rat, _ -> -1 | _, Add_rat_rat -> 1
|
||||
| Add_mon_mon, _ -> -1 | _, Add_mon_mon -> 1
|
||||
| Add_dat_dur, _ -> -1 | _, Add_dat_dur -> 1
|
||||
| Add_dur_dur, _ -> -1 | _, Add_dur_dur -> 1
|
||||
| Sub, _ -> -1 | _, Sub -> 1
|
||||
| Sub_int_int, _ -> -1 | _, Sub_int_int -> 1
|
||||
| Sub_rat_rat, _ -> -1 | _, Sub_rat_rat -> 1
|
||||
| Sub_mon_mon, _ -> -1 | _, Sub_mon_mon -> 1
|
||||
| Sub_dat_dat, _ -> -1 | _, Sub_dat_dat -> 1
|
||||
| Sub_dat_dur, _ -> -1 | _, Sub_dat_dur -> 1
|
||||
| Sub_dur_dur, _ -> -1 | _, Sub_dur_dur -> 1
|
||||
| Mult, _ -> -1 | _, Mult -> 1
|
||||
| Mult_int_int, _ -> -1 | _, Mult_int_int -> 1
|
||||
| Mult_rat_rat, _ -> -1 | _, Mult_rat_rat -> 1
|
||||
| Mult_mon_rat, _ -> -1 | _, Mult_mon_rat -> 1
|
||||
| Mult_dur_int, _ -> -1 | _, Mult_dur_int -> 1
|
||||
| Div, _ -> -1 | _, Div -> 1
|
||||
| Div_int_int, _ -> -1 | _, Div_int_int -> 1
|
||||
| Div_rat_rat, _ -> -1 | _, Div_rat_rat -> 1
|
||||
| Div_mon_mon, _ -> -1 | _, Div_mon_mon -> 1
|
||||
| Div_mon_rat, _ -> -1 | _, Div_mon_rat -> 1
|
||||
| Lt, _ -> -1 | _, Lt -> 1
|
||||
| Lt_int_int, _ -> -1 | _, Lt_int_int -> 1
|
||||
| Lt_rat_rat, _ -> -1 | _, Lt_rat_rat -> 1
|
||||
| Lt_mon_mon, _ -> -1 | _, Lt_mon_mon -> 1
|
||||
| Lt_dat_dat, _ -> -1 | _, Lt_dat_dat -> 1
|
||||
| Lt_dur_dur, _ -> -1 | _, Lt_dur_dur -> 1
|
||||
| Lte, _ -> -1 | _, Lte -> 1
|
||||
| Lte_int_int, _ -> -1 | _, Lte_int_int -> 1
|
||||
| Lte_rat_rat, _ -> -1 | _, Lte_rat_rat -> 1
|
||||
| Lte_mon_mon, _ -> -1 | _, Lte_mon_mon -> 1
|
||||
| Lte_dat_dat, _ -> -1 | _, Lte_dat_dat -> 1
|
||||
| Lte_dur_dur, _ -> -1 | _, Lte_dur_dur -> 1
|
||||
| Gt, _ -> -1 | _, Gt -> 1
|
||||
| Gt_int_int, _ -> -1 | _, Gt_int_int -> 1
|
||||
| Gt_rat_rat, _ -> -1 | _, Gt_rat_rat -> 1
|
||||
| Gt_mon_mon, _ -> -1 | _, Gt_mon_mon -> 1
|
||||
| Gt_dat_dat, _ -> -1 | _, Gt_dat_dat -> 1
|
||||
| Gt_dur_dur, _ -> -1 | _, Gt_dur_dur -> 1
|
||||
| Gte, _ -> -1 | _, Gte -> 1
|
||||
| Gte_int_int, _ -> -1 | _, Gte_int_int -> 1
|
||||
| Gte_rat_rat, _ -> -1 | _, Gte_rat_rat -> 1
|
||||
| Gte_mon_mon, _ -> -1 | _, Gte_mon_mon -> 1
|
||||
| Gte_dat_dat, _ -> -1 | _, Gte_dat_dat -> 1
|
||||
| Gte_dur_dur, _ -> -1 | _, Gte_dur_dur -> 1
|
||||
| Eq_int_int, _ -> -1 | _, Eq_int_int -> 1
|
||||
| Eq_rat_rat, _ -> -1 | _, Eq_rat_rat -> 1
|
||||
| Eq_mon_mon, _ -> -1 | _, Eq_mon_mon -> 1
|
||||
| Eq_dat_dat, _ -> -1 | _, Eq_dat_dat -> 1
|
||||
| Eq_dur_dur, _ -> -1 | _, Eq_dur_dur -> 1
|
||||
| Fold, _ | _, Fold -> .
|
||||
|
||||
let equal (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) = compare t1 t2 = 0
|
||||
|
||||
(* Classification of operators *)
|
||||
|
||||
let kind_dispatch :
|
||||
type a b k.
|
||||
polymorphic:((_, polymorphic) t -> b) ->
|
||||
monomorphic:((_, monomorphic) t -> b) ->
|
||||
?overloaded:((_, overloaded) t -> b) ->
|
||||
?resolved:((_, resolved) t -> b) ->
|
||||
(a, k) t ->
|
||||
b =
|
||||
fun ~polymorphic ~monomorphic ?(overloaded = fun _ -> assert false)
|
||||
?(resolved = fun _ -> assert false) op ->
|
||||
match op with
|
||||
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
|
||||
| Or | Xor ) as op ->
|
||||
monomorphic op
|
||||
| (Log _ | Length | Eq | Map | Concat | Filter | Reduce | Fold) as op ->
|
||||
polymorphic op
|
||||
| ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
|
||||
| Gte ) as op ->
|
||||
overloaded op
|
||||
| ( Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int | ToRat_mon
|
||||
| ToMoney_rat | Round_rat | Round_mon | Add_int_int | Add_rat_rat
|
||||
| Add_mon_mon | Add_dat_dur | Add_dur_dur | Sub_int_int | Sub_rat_rat
|
||||
| Sub_mon_mon | Sub_dat_dat | Sub_dat_dur | Sub_dur_dur | Mult_int_int
|
||||
| Mult_rat_rat | Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
|
||||
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
|
||||
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat | Lte_mon_mon
|
||||
| Lte_dat_dat | Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon
|
||||
| Gt_dat_dat | Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon
|
||||
| Gte_dat_dat | Gte_dur_dur | Eq_int_int | Eq_rat_rat | Eq_mon_mon
|
||||
| Eq_dat_dat | Eq_dur_dur ) as op ->
|
||||
resolved op
|
||||
|
||||
(* Glorified identity... allowed operators are the same in scopelang, dcalc,
|
||||
lcalc *)
|
||||
let translate :
|
||||
type k.
|
||||
([< scopelang | dcalc | lcalc ], k) t ->
|
||||
([< scopelang | dcalc | lcalc ], k) t =
|
||||
fun op ->
|
||||
match op with
|
||||
| Length -> Length
|
||||
| Log (i, l) -> Log (i, l)
|
||||
| Eq -> Eq
|
||||
| Map -> Map
|
||||
| Concat -> Concat
|
||||
| Filter -> Filter
|
||||
| Reduce -> Reduce
|
||||
| Fold -> Fold
|
||||
| Not -> Not
|
||||
| GetDay -> GetDay
|
||||
| GetMonth -> GetMonth
|
||||
| GetYear -> GetYear
|
||||
| FirstDayOfMonth -> FirstDayOfMonth
|
||||
| LastDayOfMonth -> LastDayOfMonth
|
||||
| And -> And
|
||||
| Or -> Or
|
||||
| Xor -> Xor
|
||||
| Minus_int -> Minus_int
|
||||
| Minus_rat -> Minus_rat
|
||||
| Minus_mon -> Minus_mon
|
||||
| Minus_dur -> Minus_dur
|
||||
| ToRat_int -> ToRat_int
|
||||
| ToRat_mon -> ToRat_mon
|
||||
| ToMoney_rat -> ToMoney_rat
|
||||
| Round_rat -> Round_rat
|
||||
| Round_mon -> Round_mon
|
||||
| Add_int_int -> Add_int_int
|
||||
| Add_rat_rat -> Add_rat_rat
|
||||
| Add_mon_mon -> Add_mon_mon
|
||||
| Add_dat_dur -> Add_dat_dur
|
||||
| Add_dur_dur -> Add_dur_dur
|
||||
| Sub_int_int -> Sub_int_int
|
||||
| Sub_rat_rat -> Sub_rat_rat
|
||||
| Sub_mon_mon -> Sub_mon_mon
|
||||
| Sub_dat_dat -> Sub_dat_dat
|
||||
| Sub_dat_dur -> Sub_dat_dur
|
||||
| Sub_dur_dur -> Sub_dur_dur
|
||||
| Mult_int_int -> Mult_int_int
|
||||
| Mult_rat_rat -> Mult_rat_rat
|
||||
| Mult_mon_rat -> Mult_mon_rat
|
||||
| Mult_dur_int -> Mult_dur_int
|
||||
| Div_int_int -> Div_int_int
|
||||
| Div_rat_rat -> Div_rat_rat
|
||||
| Div_mon_mon -> Div_mon_mon
|
||||
| Div_mon_rat -> Div_mon_rat
|
||||
| Lt_int_int -> Lt_int_int
|
||||
| Lt_rat_rat -> Lt_rat_rat
|
||||
| Lt_mon_mon -> Lt_mon_mon
|
||||
| Lt_dat_dat -> Lt_dat_dat
|
||||
| Lt_dur_dur -> Lt_dur_dur
|
||||
| Lte_int_int -> Lte_int_int
|
||||
| Lte_rat_rat -> Lte_rat_rat
|
||||
| Lte_mon_mon -> Lte_mon_mon
|
||||
| Lte_dat_dat -> Lte_dat_dat
|
||||
| Lte_dur_dur -> Lte_dur_dur
|
||||
| Gt_int_int -> Gt_int_int
|
||||
| Gt_rat_rat -> Gt_rat_rat
|
||||
| Gt_mon_mon -> Gt_mon_mon
|
||||
| Gt_dat_dat -> Gt_dat_dat
|
||||
| Gt_dur_dur -> Gt_dur_dur
|
||||
| Gte_int_int -> Gte_int_int
|
||||
| Gte_rat_rat -> Gte_rat_rat
|
||||
| Gte_mon_mon -> Gte_mon_mon
|
||||
| Gte_dat_dat -> Gte_dat_dat
|
||||
| Gte_dur_dur -> Gte_dur_dur
|
||||
| Eq_int_int -> Eq_int_int
|
||||
| Eq_rat_rat -> Eq_rat_rat
|
||||
| Eq_mon_mon -> Eq_mon_mon
|
||||
| Eq_dat_dat -> Eq_dat_dat
|
||||
| Eq_dur_dur -> Eq_dur_dur
|
||||
|
||||
let monomorphic_type (op, pos) =
|
||||
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
|
||||
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
|
||||
match op with
|
||||
| Not -> TBool @-> TBool
|
||||
| GetDay -> TDate @-> TInt
|
||||
| GetMonth -> TDate @-> TInt
|
||||
| GetYear -> TDate @-> TInt
|
||||
| FirstDayOfMonth -> TDate @-> TDate
|
||||
| LastDayOfMonth -> TDate @-> TDate
|
||||
| And -> TBool @- TBool @-> TBool
|
||||
| Or -> TBool @- TBool @-> TBool
|
||||
| Xor -> TBool @- TBool @-> TBool
|
||||
|
||||
(** Rules for overloads definitions:
|
||||
|
||||
- the concrete operator, including its return type, is uniquely determined
|
||||
by the type of the operands
|
||||
|
||||
- no resolved version of an operator should be the redefinition of another
|
||||
one with an added conversion. For example, [int + rat -> rat] is not
|
||||
acceptable (that would amount to implicit casts).
|
||||
|
||||
These two points can be generalised for binary operators as: when
|
||||
considering an operator with type ['a -> 'b -> 'c], for any given two among
|
||||
['a], ['b] and ['c], there should be a unique solution for the third. *)
|
||||
|
||||
let resolved_type (op, pos) =
|
||||
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
|
||||
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
|
||||
match op with
|
||||
| Minus_int -> TInt @-> TInt
|
||||
| Minus_rat -> TRat @-> TRat
|
||||
| Minus_mon -> TMoney @-> TMoney
|
||||
| Minus_dur -> TDuration @-> TDuration
|
||||
| ToRat_int -> TInt @-> TRat
|
||||
| ToRat_mon -> TMoney @-> TRat
|
||||
| ToMoney_rat -> TRat @-> TMoney
|
||||
| Round_rat -> TRat @-> TRat
|
||||
| Round_mon -> TMoney @-> TMoney
|
||||
| Add_int_int -> TInt @- TInt @-> TInt
|
||||
| Add_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Add_mon_mon -> TMoney @- TMoney @-> TMoney
|
||||
| Add_dat_dur -> TDate @- TDuration @-> TDate
|
||||
| Add_dur_dur -> TDuration @- TDuration @-> TDuration
|
||||
| Sub_int_int -> TInt @- TInt @-> TInt
|
||||
| Sub_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Sub_mon_mon -> TMoney @- TMoney @-> TMoney
|
||||
| Sub_dat_dat -> TDate @- TDate @-> TDuration
|
||||
| Sub_dat_dur -> TDate @- TDuration @-> TDuration
|
||||
| Sub_dur_dur -> TDuration @- TDuration @-> TDuration
|
||||
| Mult_int_int -> TInt @- TInt @-> TInt
|
||||
| Mult_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Mult_mon_rat -> TMoney @- TRat @-> TMoney
|
||||
| Mult_dur_int -> TDuration @- TInt @-> TDuration
|
||||
| Div_int_int -> TInt @- TInt @-> TRat
|
||||
| Div_rat_rat -> TRat @- TRat @-> TRat
|
||||
| Div_mon_mon -> TMoney @- TMoney @-> TRat
|
||||
| Div_mon_rat -> TMoney @- TRat @-> TMoney
|
||||
| Lt_int_int -> TInt @- TInt @-> TBool
|
||||
| Lt_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Lt_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Lt_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Lt_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Lte_int_int -> TInt @- TInt @-> TBool
|
||||
| Lte_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Lte_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Lte_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Lte_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Gt_int_int -> TInt @- TInt @-> TBool
|
||||
| Gt_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Gt_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Gt_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Gt_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Gte_int_int -> TInt @- TInt @-> TBool
|
||||
| Gte_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Gte_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Gte_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Gte_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
| Eq_int_int -> TInt @- TInt @-> TBool
|
||||
| Eq_rat_rat -> TRat @- TRat @-> TBool
|
||||
| Eq_mon_mon -> TMoney @- TMoney @-> TBool
|
||||
| Eq_dat_dat -> TDate @- TDate @-> TBool
|
||||
| Eq_dur_dur -> TDuration @- TDuration @-> TBool
|
||||
|
||||
let resolve_overload_aux (op : ('a, overloaded) t) (operands : typ_lit list) :
|
||||
('b, resolved) t * [ `Straight | `Reversed ] =
|
||||
match op, operands with
|
||||
| Minus, [TInt] -> Minus_int, `Straight
|
||||
| Minus, [TRat] -> Minus_rat, `Straight
|
||||
| Minus, [TMoney] -> Minus_mon, `Straight
|
||||
| Minus, [TDuration] -> Minus_dur, `Straight
|
||||
| ToRat, [TInt] -> ToRat_int, `Straight
|
||||
| ToRat, [TMoney] -> ToRat_mon, `Straight
|
||||
| ToMoney, [TRat] -> ToMoney_rat, `Straight
|
||||
| Round, [TRat] -> Round_rat, `Straight
|
||||
| Round, [TMoney] -> Round_mon, `Straight
|
||||
| Add, [TInt; TInt] -> Add_int_int, `Straight
|
||||
| Add, [TRat; TRat] -> Add_rat_rat, `Straight
|
||||
| Add, [TMoney; TMoney] -> Add_mon_mon, `Straight
|
||||
| Add, [TDuration; TDuration] -> Add_dur_dur, `Straight
|
||||
| Add, [TDate; TDuration] -> Add_dat_dur, `Straight
|
||||
| Add, [TDuration; TDate] -> Add_dat_dur, `Reversed
|
||||
| Sub, [TInt; TInt] -> Sub_int_int, `Straight
|
||||
| Sub, [TRat; TRat] -> Sub_rat_rat, `Straight
|
||||
| Sub, [TMoney; TMoney] -> Sub_mon_mon, `Straight
|
||||
| Sub, [TDuration; TDuration] -> Sub_dur_dur, `Straight
|
||||
| Sub, [TDate; TDate] -> Sub_dat_dat, `Straight
|
||||
| Sub, [TDate; TDuration] -> Sub_dat_dur, `Straight
|
||||
| Mult, [TInt; TInt] -> Mult_int_int, `Straight
|
||||
| Mult, [TRat; TRat] -> Mult_rat_rat, `Straight
|
||||
| Mult, [TMoney; TRat] -> Mult_mon_rat, `Straight
|
||||
| Mult, [TRat; TMoney] -> Mult_mon_rat, `Reversed
|
||||
| Mult, [TDuration; TInt] -> Mult_dur_int, `Straight
|
||||
| Mult, [TInt; TDuration] -> Mult_dur_int, `Reversed
|
||||
| Div, [TInt; TInt] -> Div_int_int, `Straight
|
||||
| Div, [TRat; TRat] -> Div_rat_rat, `Straight
|
||||
| Div, [TMoney; TMoney] -> Div_mon_mon, `Straight
|
||||
| Div, [TMoney; TRat] -> Div_mon_rat, `Straight
|
||||
| Lt, [TInt; TInt] -> Lt_int_int, `Straight
|
||||
| Lt, [TRat; TRat] -> Lt_rat_rat, `Straight
|
||||
| Lt, [TMoney; TMoney] -> Lt_mon_mon, `Straight
|
||||
| Lt, [TDuration; TDuration] -> Lt_dur_dur, `Straight
|
||||
| Lt, [TDate; TDate] -> Lt_dat_dat, `Straight
|
||||
| Lte, [TInt; TInt] -> Lte_int_int, `Straight
|
||||
| Lte, [TRat; TRat] -> Lte_rat_rat, `Straight
|
||||
| Lte, [TMoney; TMoney] -> Lte_mon_mon, `Straight
|
||||
| Lte, [TDuration; TDuration] -> Lte_dur_dur, `Straight
|
||||
| Lte, [TDate; TDate] -> Lte_dat_dat, `Straight
|
||||
| Gt, [TInt; TInt] -> Gt_int_int, `Straight
|
||||
| Gt, [TRat; TRat] -> Gt_rat_rat, `Straight
|
||||
| Gt, [TMoney; TMoney] -> Gt_mon_mon, `Straight
|
||||
| Gt, [TDuration; TDuration] -> Gt_dur_dur, `Straight
|
||||
| Gt, [TDate; TDate] -> Gt_dat_dat, `Straight
|
||||
| Gte, [TInt; TInt] -> Gte_int_int, `Straight
|
||||
| Gte, [TRat; TRat] -> Gte_rat_rat, `Straight
|
||||
| Gte, [TMoney; TMoney] -> Gte_mon_mon, `Straight
|
||||
| Gte, [TDuration; TDuration] -> Gte_dur_dur, `Straight
|
||||
| Gte, [TDate; TDate] -> Gte_dat_dat, `Straight
|
||||
| ( ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
|
||||
| Gte ),
|
||||
_ ) ->
|
||||
raise Not_found
|
||||
|
||||
let resolve_overload
|
||||
ctx
|
||||
(op : ('a, overloaded) t Marked.pos)
|
||||
(operands : typ list) : ('b, resolved) t * [ `Straight | `Reversed ] =
|
||||
try
|
||||
let operands =
|
||||
List.map
|
||||
(fun t ->
|
||||
match Marked.unmark t with TLit tl -> tl | _ -> raise Not_found)
|
||||
operands
|
||||
in
|
||||
resolve_overload_aux (Marked.unmark op) operands
|
||||
with Not_found ->
|
||||
Errors.raise_multispanned_error
|
||||
((None, Marked.get_mark op)
|
||||
:: List.map
|
||||
(fun ty ->
|
||||
( Some
|
||||
(Format.asprintf "Type %a coming from expression:"
|
||||
(Print.typ ctx) ty),
|
||||
Marked.get_mark ty ))
|
||||
operands)
|
||||
"I don't know how to apply operator %a on types %a" Print.operator
|
||||
(Marked.unmark op)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun ppf () -> Format.fprintf ppf " and@ ")
|
||||
(Print.typ ctx))
|
||||
operands
|
||||
|
||||
let overload_type ctx (op : ('a, overloaded) t Marked.pos) (operands : typ list)
|
||||
: typ =
|
||||
let rop = fst (resolve_overload ctx op operands) in
|
||||
resolved_type (Marked.same_mark_as rop op)
|
85
compiler/shared_ast/operator.mli
Normal file
85
compiler/shared_ast/operator.mli
Normal file
@ -0,0 +1,85 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
(** {1 Catala operator utilities} *)
|
||||
|
||||
(** Resolving operators from the surface syntax proceeds in three steps:
|
||||
|
||||
- During desugaring, the operators may remain untyped (with [TAny]) or, if
|
||||
they have an explicit type suffix (e.g. the [$] for "money" in [+$]),
|
||||
their operands types are already explicited in the [EOp] expression node.
|
||||
- {!modules:Shared_ast.Typing} will then enforce these constraints in
|
||||
addition to the known built-in type for each operator (e.g.
|
||||
[Eq: 'a -> 'a -> 'a] isn't encoded in the first-order AST types).
|
||||
- Finally, during {!modules:Scopelang.From_desugared}, these types are
|
||||
leveraged to resolve the overloaded operators to their concrete,
|
||||
monomorphic counterparts
|
||||
*)
|
||||
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
include module type of Definitions.Op
|
||||
|
||||
val equal : ('a1, 'k1) t -> ('a2, 'k2) t -> bool
|
||||
val compare : ('a1, 'k1) t -> ('a2, 'k2) t -> int
|
||||
|
||||
val name : ('a, 'k) t -> string
|
||||
(** Returns the operator name as a valid ident starting with a lowercase
|
||||
character. This is different from Print.operator which returns operator
|
||||
symbols, e.g. [+$]. *)
|
||||
|
||||
val kind_dispatch :
|
||||
polymorphic:((_ any, polymorphic) t -> 'b) ->
|
||||
monomorphic:((_ any, monomorphic) t -> 'b) ->
|
||||
?overloaded:((desugared, overloaded) t -> 'b) ->
|
||||
?resolved:(([< scopelang | dcalc | lcalc ], resolved) t -> 'b) ->
|
||||
('a, 'k) t ->
|
||||
'b
|
||||
(** Calls one of the supplied functions depending on the kind of the operator *)
|
||||
|
||||
val translate :
|
||||
([< scopelang | dcalc | lcalc ], 'k) t ->
|
||||
([< scopelang | dcalc | lcalc ], 'k) t
|
||||
(** An identity function that allows translating an operator between different
|
||||
passes that don't change operator types *)
|
||||
|
||||
(** {2 Getting the types of operators} *)
|
||||
|
||||
val monomorphic_type : ('a any, monomorphic) t Marked.pos -> typ
|
||||
|
||||
val resolved_type :
|
||||
([< scopelang | dcalc | lcalc ], resolved) t Marked.pos -> typ
|
||||
|
||||
val overload_type :
|
||||
decl_ctx -> (desugared, overloaded) t Marked.pos -> typ list -> typ
|
||||
(** The type for typing overloads is different since the types of the operands
|
||||
are required in advance.
|
||||
|
||||
@raise a detailed user error if no matching operator can be found *)
|
||||
|
||||
(** Polymorphic operators are typed directly within [Typing], since their types
|
||||
may contain type variables that can't be expressed outside of it*)
|
||||
|
||||
(** {2 Overload handling} *)
|
||||
|
||||
val resolve_overload :
|
||||
decl_ctx ->
|
||||
(desugared, overloaded) t Marked.pos ->
|
||||
typ list ->
|
||||
([< scopelang | dcalc | lcalc ], resolved) t * [ `Straight | `Reversed ]
|
||||
(** Some overloads are sugar for an operation with reversed operands, e.g.
|
||||
[TRat * TMoney] is using [mult_mon_rat]. [`Reversed] is returned to signify
|
||||
this case. *)
|
@ -14,8 +14,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open String_common
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
let typ_needs_parens (ty : typ) : bool =
|
||||
@ -26,27 +25,28 @@ let uid_list (fmt : Format.formatter) (infos : Uid.MarkedString.info list) :
|
||||
Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.pp_print_char fmt '.')
|
||||
(fun fmt info ->
|
||||
Utils.Cli.format_with_style
|
||||
(if begins_with_uppercase (Marked.unmark info) then [ANSITerminal.red]
|
||||
Cli.format_with_style
|
||||
(if String.begins_with_uppercase (Marked.unmark info) then
|
||||
[ANSITerminal.red]
|
||||
else [])
|
||||
fmt
|
||||
(Utils.Uid.MarkedString.to_string info))
|
||||
(Uid.MarkedString.to_string info))
|
||||
fmt infos
|
||||
|
||||
let keyword (fmt : Format.formatter) (s : string) : unit =
|
||||
Utils.Cli.format_with_style [ANSITerminal.red] fmt s
|
||||
Cli.format_with_style [ANSITerminal.red] fmt s
|
||||
|
||||
let base_type (fmt : Format.formatter) (s : string) : unit =
|
||||
Utils.Cli.format_with_style [ANSITerminal.yellow] fmt s
|
||||
Cli.format_with_style [ANSITerminal.yellow] fmt s
|
||||
|
||||
let punctuation (fmt : Format.formatter) (s : string) : unit =
|
||||
Utils.Cli.format_with_style [ANSITerminal.cyan] fmt s
|
||||
Cli.format_with_style [ANSITerminal.cyan] fmt s
|
||||
|
||||
let operator (fmt : Format.formatter) (s : string) : unit =
|
||||
Utils.Cli.format_with_style [ANSITerminal.green] fmt s
|
||||
let op_style (fmt : Format.formatter) (s : string) : unit =
|
||||
Cli.format_with_style [ANSITerminal.green] fmt s
|
||||
|
||||
let lit_style (fmt : Format.formatter) (s : string) : unit =
|
||||
Utils.Cli.format_with_style [ANSITerminal.yellow] fmt s
|
||||
Cli.format_with_style [ANSITerminal.yellow] fmt s
|
||||
|
||||
let tlit (fmt : Format.formatter) (l : typ_lit) : unit =
|
||||
base_type fmt
|
||||
@ -68,7 +68,7 @@ let location (type a) (fmt : Format.formatter) (l : a glocation) : unit =
|
||||
ScopeVar.format_t (Marked.unmark subvar)
|
||||
|
||||
let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit =
|
||||
Utils.Cli.format_with_style [ANSITerminal.magenta] fmt
|
||||
Cli.format_with_style [ANSITerminal.magenta] fmt
|
||||
(Format.asprintf "%a" EnumConstructor.format_t c)
|
||||
|
||||
let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
|
||||
@ -81,7 +81,7 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
|
||||
| TTuple ts ->
|
||||
Format.fprintf fmt "@[<hov 2>(%a)@]"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " operator "*")
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " op_style "*")
|
||||
typ)
|
||||
ts
|
||||
| TStruct s -> (
|
||||
@ -94,9 +94,9 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
|
||||
(fun fmt (field, mty) ->
|
||||
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
|
||||
StructFieldName.format_t field punctuation "\"" punctuation ":"
|
||||
typ mty))
|
||||
(StructMap.find s ctx.ctx_structs)
|
||||
StructField.format_t field punctuation "\"" punctuation ":" typ
|
||||
mty))
|
||||
(StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs))
|
||||
punctuation "}")
|
||||
| TEnum e -> (
|
||||
match ctx with
|
||||
@ -109,11 +109,11 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
|
||||
(fun fmt (case, mty) ->
|
||||
Format.fprintf fmt "%a%a@ %a" enum_constructor case punctuation ":"
|
||||
typ mty))
|
||||
(EnumMap.find e ctx.ctx_enums)
|
||||
(EnumConstructor.Map.bindings (EnumName.Map.find e ctx.ctx_enums))
|
||||
punctuation "]")
|
||||
| TOption t -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" typ t
|
||||
| TArrow (t1, t2) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 operator "→"
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 op_style "→"
|
||||
typ t2
|
||||
| TArray t1 ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "collection" typ t1
|
||||
@ -127,9 +127,9 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
|
||||
| LUnit -> lit_style fmt "()"
|
||||
| LRat i ->
|
||||
lit_style fmt
|
||||
(Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i)
|
||||
(Runtime.decimal_to_string ~max_prec_digits:!Cli.max_prec_digits i)
|
||||
| LMoney e -> (
|
||||
match !Utils.Cli.locale_lang with
|
||||
match !Cli.locale_lang with
|
||||
| En -> lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e))
|
||||
| Fr -> lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e))
|
||||
| Pl -> lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e))
|
||||
@ -137,72 +137,112 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
|
||||
| LDate d -> lit_style fmt (Runtime.date_to_string d)
|
||||
| LDuration d -> lit_style fmt (Runtime.duration_to_string d)
|
||||
|
||||
let op_kind (fmt : Format.formatter) (k : op_kind) =
|
||||
Format.fprintf fmt "%s"
|
||||
(match k with
|
||||
| KInt -> ""
|
||||
| KRat -> "."
|
||||
| KMoney -> "$"
|
||||
| KDate -> "@"
|
||||
| KDuration -> "^")
|
||||
|
||||
let binop (fmt : Format.formatter) (op : binop) : unit =
|
||||
operator fmt
|
||||
(match op with
|
||||
| Add k -> Format.asprintf "+%a" op_kind k
|
||||
| Sub k -> Format.asprintf "-%a" op_kind k
|
||||
| Mult k -> Format.asprintf "*%a" op_kind k
|
||||
| Div k -> Format.asprintf "/%a" op_kind k
|
||||
| And -> "&&"
|
||||
| Or -> "||"
|
||||
| Xor -> "xor"
|
||||
| Eq -> "="
|
||||
| Neq -> "!="
|
||||
| Lt k -> Format.asprintf "%s%a" "<" op_kind k
|
||||
| Lte k -> Format.asprintf "%s%a" "<=" op_kind k
|
||||
| Gt k -> Format.asprintf "%s%a" ">" op_kind k
|
||||
| Gte k -> Format.asprintf "%s%a" ">=" op_kind k
|
||||
| Concat -> "++"
|
||||
| Map -> "map"
|
||||
| Filter -> "filter")
|
||||
|
||||
let ternop (fmt : Format.formatter) (op : ternop) : unit =
|
||||
match op with Fold -> keyword fmt "fold"
|
||||
|
||||
let log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
|
||||
Format.fprintf fmt "@<2>%a"
|
||||
(fun fmt -> function
|
||||
| VarDef _ -> Utils.Cli.format_with_style [ANSITerminal.blue] fmt "≔ "
|
||||
| BeginCall -> Utils.Cli.format_with_style [ANSITerminal.yellow] fmt "→ "
|
||||
| EndCall -> Utils.Cli.format_with_style [ANSITerminal.yellow] fmt "← "
|
||||
| VarDef _ -> Cli.format_with_style [ANSITerminal.blue] fmt "≔ "
|
||||
| BeginCall -> Cli.format_with_style [ANSITerminal.yellow] fmt "→ "
|
||||
| EndCall -> Cli.format_with_style [ANSITerminal.yellow] fmt "← "
|
||||
| PosRecordIfTrueBool ->
|
||||
Utils.Cli.format_with_style [ANSITerminal.green] fmt "☛ ")
|
||||
Cli.format_with_style [ANSITerminal.green] fmt "☛ ")
|
||||
entry
|
||||
|
||||
let unop (fmt : Format.formatter) (op : unop) : unit =
|
||||
let operator_to_string : type a k. (a, k) Op.t -> string = function
|
||||
| Not -> "~"
|
||||
| Length -> "length"
|
||||
| GetDay -> "get_day"
|
||||
| GetMonth -> "get_month"
|
||||
| GetYear -> "get_year"
|
||||
| FirstDayOfMonth -> "first_day_of_month"
|
||||
| LastDayOfMonth -> "last_day_of_month"
|
||||
| ToRat -> "to_rat"
|
||||
| ToRat_int -> "to_rat_int"
|
||||
| ToRat_mon -> "to_rat_mon"
|
||||
| ToMoney -> "to_mon"
|
||||
| ToMoney_rat -> "to_mon_rat"
|
||||
| Round -> "round"
|
||||
| Round_rat -> "round_rat"
|
||||
| Round_mon -> "round_mon"
|
||||
| Log _ -> "Log"
|
||||
| Minus -> "-"
|
||||
| Minus_int -> "-!"
|
||||
| Minus_rat -> "-."
|
||||
| Minus_mon -> "-$"
|
||||
| Minus_dur -> "-^"
|
||||
| And -> "&&"
|
||||
| Or -> "||"
|
||||
| Xor -> "xor"
|
||||
| Eq -> "="
|
||||
| Map -> "map"
|
||||
| Reduce -> "reduce"
|
||||
| Concat -> "++"
|
||||
| Filter -> "filter"
|
||||
| Add -> "+"
|
||||
| Add_int_int -> "+!"
|
||||
| Add_rat_rat -> "+."
|
||||
| Add_mon_mon -> "+$"
|
||||
| Add_dat_dur -> "+@"
|
||||
| Add_dur_dur -> "+^"
|
||||
| Sub -> "-"
|
||||
| Sub_int_int -> "-!"
|
||||
| Sub_rat_rat -> "-."
|
||||
| Sub_mon_mon -> "-$"
|
||||
| Sub_dat_dat -> "-@"
|
||||
| Sub_dat_dur -> "-@^"
|
||||
| Sub_dur_dur -> "-^"
|
||||
| Mult -> "*"
|
||||
| Mult_int_int -> "*!"
|
||||
| Mult_rat_rat -> "*."
|
||||
| Mult_mon_rat -> "*$"
|
||||
| Mult_dur_int -> "*^"
|
||||
| Div -> "/"
|
||||
| Div_int_int -> "/!"
|
||||
| Div_rat_rat -> "/."
|
||||
| Div_mon_mon -> "/$"
|
||||
| Div_mon_rat -> "/$."
|
||||
| Lt -> "<"
|
||||
| Lt_int_int -> "<!"
|
||||
| Lt_rat_rat -> "<."
|
||||
| Lt_mon_mon -> "<$"
|
||||
| Lt_dur_dur -> "<^"
|
||||
| Lt_dat_dat -> "<@"
|
||||
| Lte -> "<="
|
||||
| Lte_int_int -> "<=!"
|
||||
| Lte_rat_rat -> "<=."
|
||||
| Lte_mon_mon -> "<=$"
|
||||
| Lte_dur_dur -> "<=^"
|
||||
| Lte_dat_dat -> "<=@"
|
||||
| Gt -> ">"
|
||||
| Gt_int_int -> ">!"
|
||||
| Gt_rat_rat -> ">."
|
||||
| Gt_mon_mon -> ">$"
|
||||
| Gt_dur_dur -> ">^"
|
||||
| Gt_dat_dat -> ">@"
|
||||
| Gte -> ">="
|
||||
| Gte_int_int -> ">=!"
|
||||
| Gte_rat_rat -> ">=."
|
||||
| Gte_mon_mon -> ">=$"
|
||||
| Gte_dur_dur -> ">=^"
|
||||
| Gte_dat_dat -> ">=@"
|
||||
| Eq_int_int -> "=!"
|
||||
| Eq_rat_rat -> "=."
|
||||
| Eq_mon_mon -> "=$"
|
||||
| Eq_dur_dur -> "=^"
|
||||
| Eq_dat_dat -> "=@"
|
||||
| Fold -> "fold"
|
||||
|
||||
let operator (type k) (fmt : Format.formatter) (op : ('a, k) operator) : unit =
|
||||
match op with
|
||||
| Minus _ -> Format.pp_print_string fmt "-"
|
||||
| Not -> Format.pp_print_string fmt "~"
|
||||
| Log (entry, infos) ->
|
||||
Format.fprintf fmt "log@[<hov 2>[%a|%a]@]" log_entry entry
|
||||
Format.fprintf fmt "%a@[<hov 2>[%a|%a]@]" op_style "log" log_entry entry
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ".")
|
||||
(fun fmt info -> Utils.Uid.MarkedString.format_info fmt info))
|
||||
(fun fmt info -> Uid.MarkedString.format fmt info))
|
||||
infos
|
||||
| Length -> Format.pp_print_string fmt "length"
|
||||
| IntToRat -> Format.pp_print_string fmt "int_to_rat"
|
||||
| MoneyToRat -> Format.pp_print_string fmt "money_to_rat"
|
||||
| RatToMoney -> Format.pp_print_string fmt "rat_to_money"
|
||||
| GetDay -> Format.pp_print_string fmt "get_day"
|
||||
| GetMonth -> Format.pp_print_string fmt "get_month"
|
||||
| GetYear -> Format.pp_print_string fmt "get_year"
|
||||
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
|
||||
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
|
||||
| RoundMoney -> Format.pp_print_string fmt "round_money"
|
||||
| RoundDecimal -> Format.pp_print_string fmt "round_decimal"
|
||||
| op -> Format.fprintf fmt "%a" op_style (operator_to_string op)
|
||||
|
||||
let except (fmt : Format.formatter) (exn : except) : unit =
|
||||
operator fmt
|
||||
op_style fmt
|
||||
(match exn with
|
||||
| EmptyError -> "EmptyError"
|
||||
| ConflictError -> "ConflictError"
|
||||
@ -215,7 +255,7 @@ let var_debug fmt v =
|
||||
let var fmt v = Format.pp_print_string fmt (Bindlib.name_of v)
|
||||
|
||||
let needs_parens (type a) (e : (a, _) gexpr) : bool =
|
||||
match Marked.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false
|
||||
match Marked.unmark e with EAbs _ | EStruct _ -> true | _ -> false
|
||||
|
||||
let rec expr_aux :
|
||||
type a.
|
||||
@ -228,6 +268,7 @@ let rec expr_aux :
|
||||
fun ?(debug = false) ctx bnd_ctx fmt e ->
|
||||
let exprb bnd_ctx e = expr_aux ~debug ctx bnd_ctx e in
|
||||
let expr e = exprb bnd_ctx e in
|
||||
let var = if debug then var_debug else var in
|
||||
let with_parens fmt e =
|
||||
if needs_parens e then (
|
||||
punctuation fmt "(";
|
||||
@ -236,79 +277,28 @@ let rec expr_aux :
|
||||
else expr fmt e
|
||||
in
|
||||
match Marked.unmark e with
|
||||
| EVar v -> if debug then var_debug fmt v else var fmt v
|
||||
| ETuple (es, None) ->
|
||||
| EVar v -> var fmt v
|
||||
| ETuple es ->
|
||||
Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "("
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
|
||||
(fun fmt e -> expr fmt e))
|
||||
es punctuation ")"
|
||||
| ETuple (es, Some s) -> (
|
||||
match ctx with
|
||||
| None -> expr fmt (Marked.same_mark_as (ETuple (es, None)) e)
|
||||
| Some ctx ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]" StructName.format_t
|
||||
s punctuation "{"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
|
||||
(fun fmt (e, struct_field) ->
|
||||
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
|
||||
StructFieldName.format_t struct_field punctuation "\""
|
||||
punctuation "=" expr e))
|
||||
(List.combine es (List.map fst (StructMap.find s ctx.ctx_structs)))
|
||||
punctuation "}")
|
||||
| EArray es ->
|
||||
Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "["
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
|
||||
(fun fmt e -> expr fmt e))
|
||||
es punctuation "]"
|
||||
| ETupleAccess (e1, n, s, _ts) -> (
|
||||
match s, ctx with
|
||||
| None, _ | _, None ->
|
||||
expr fmt e1;
|
||||
punctuation fmt ".";
|
||||
Format.pp_print_int fmt n
|
||||
| Some s, Some ctx ->
|
||||
expr fmt e1;
|
||||
operator fmt ".";
|
||||
punctuation fmt "\"";
|
||||
StructFieldName.format_t fmt
|
||||
(fst (List.nth (StructMap.find s ctx.ctx_structs) n));
|
||||
punctuation fmt "\"")
|
||||
| EInj (e, n, en, _ts) -> (
|
||||
match ctx with
|
||||
| None ->
|
||||
Format.fprintf fmt "@[<hov 2>%a[%d]@ %a@]" EnumName.format_t en n expr e
|
||||
| Some ctx ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" enum_constructor
|
||||
(fst (List.nth (EnumMap.find en ctx.ctx_enums) n))
|
||||
expr e)
|
||||
| EMatch (e, es, e_name) -> (
|
||||
match ctx with
|
||||
| None ->
|
||||
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
|
||||
expr e keyword "with"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (e, i) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a[%d]%a@ %a@]" punctuation "|"
|
||||
EnumName.format_t e_name i punctuation ":" expr e))
|
||||
(List.mapi (fun i e -> e, i) es)
|
||||
| Some ctx ->
|
||||
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
|
||||
expr e keyword "with"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (e, c) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" punctuation "|"
|
||||
enum_constructor c punctuation ":" expr e))
|
||||
(List.combine es (List.map fst (EnumMap.find e_name ctx.ctx_enums))))
|
||||
| ETupleAccess { e; index; _ } ->
|
||||
expr fmt e;
|
||||
punctuation fmt ".";
|
||||
Format.pp_print_int fmt index
|
||||
| ELit l -> lit fmt l
|
||||
| EApp ((EAbs (binder, taus), _), args) ->
|
||||
| EApp { f = EAbs { binder; tys }, _; args } ->
|
||||
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
|
||||
let expr = exprb bnd_ctx in
|
||||
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) taus in
|
||||
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) tys in
|
||||
let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
|
||||
Format.fprintf fmt "%a%a"
|
||||
(Format.pp_print_list
|
||||
@ -318,10 +308,10 @@ let rec expr_aux :
|
||||
"let" var x punctuation ":" (typ ctx) tau punctuation "=" expr arg
|
||||
keyword "in"))
|
||||
xs_tau_arg expr body
|
||||
| EAbs (binder, taus) ->
|
||||
| EAbs { binder; tys } ->
|
||||
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
|
||||
let expr = exprb bnd_ctx in
|
||||
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) taus in
|
||||
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) tys in
|
||||
Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" punctuation "λ"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
@ -329,29 +319,28 @@ let rec expr_aux :
|
||||
Format.fprintf fmt "%a%a%a %a%a" punctuation "(" var x punctuation
|
||||
":" (typ ctx) tau punctuation ")"))
|
||||
xs_tau punctuation "→" expr body
|
||||
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" binop op with_parens arg1
|
||||
| EApp { f = EOp { op = (Map | Filter) as op; _ }, _; args = [arg1; arg2] } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" operator op with_parens arg1
|
||||
with_parens arg2
|
||||
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 binop op
|
||||
| EApp { f = EOp { op; _ }, _; args = [arg1; arg2] } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 operator op
|
||||
with_parens arg2
|
||||
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> expr fmt arg1
|
||||
| EApp ((EOp (Unop op), _), [arg1]) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" unop op with_parens arg1
|
||||
| EApp (f, args) ->
|
||||
| EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } when not debug ->
|
||||
expr fmt arg1
|
||||
| EApp { f = EOp { op; _ }, _; args = [arg1] } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" operator op with_parens arg1
|
||||
| EApp { f; args } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" expr f
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
|
||||
with_parens)
|
||||
args
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr e1
|
||||
keyword "then" expr e2 keyword "else" expr e3
|
||||
| EOp (Ternop op) -> ternop fmt op
|
||||
| EOp (Binop op) -> binop fmt op
|
||||
| EOp (Unop op) -> unop fmt op
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
if List.length exceptions = 0 then
|
||||
| EIfThenElse { cond; etrue; efalse } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr
|
||||
cond keyword "then" expr etrue keyword "else" expr efalse
|
||||
| EOp { op; _ } -> operator fmt op
|
||||
| EDefault { excepts; just; cons } ->
|
||||
if List.length excepts = 0 then
|
||||
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" punctuation "⟨" expr just
|
||||
punctuation "⊢" expr cons punctuation "⟩"
|
||||
else
|
||||
@ -359,45 +348,48 @@ let rec expr_aux :
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ",")
|
||||
expr)
|
||||
exceptions punctuation "|" expr just punctuation "⊢" expr cons
|
||||
punctuation "⟩"
|
||||
| ErrorOnEmpty e' ->
|
||||
Format.fprintf fmt "%a@ %a" operator "error_empty" with_parens e'
|
||||
excepts punctuation "|" expr just punctuation "⊢" expr cons punctuation
|
||||
"⟩"
|
||||
| EErrorOnEmpty e' ->
|
||||
Format.fprintf fmt "%a@ %a" op_style "error_empty" with_parens e'
|
||||
| EAssert e' ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" keyword "assert" punctuation "("
|
||||
expr e' punctuation ")"
|
||||
| ECatch (e1, exn, e2) ->
|
||||
| ECatch { body; exn; handler } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" keyword "try"
|
||||
with_parens e1 keyword "with" except exn with_parens e2
|
||||
with_parens body keyword "with" except exn with_parens handler
|
||||
| ERaise exn ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" keyword "raise" except exn
|
||||
| ELocation loc -> location fmt loc
|
||||
| EStruct (name, fields) ->
|
||||
Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
|
||||
| EDStructAccess { e; field; _ } ->
|
||||
Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\""
|
||||
IdentName.format_t field punctuation "\""
|
||||
| EStruct { name; fields } ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
|
||||
punctuation "{"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
|
||||
(fun fmt (field_name, field_expr) ->
|
||||
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
|
||||
StructFieldName.format_t field_name punctuation "\"" punctuation
|
||||
"=" expr field_expr))
|
||||
(StructFieldMap.bindings fields)
|
||||
StructField.format_t field_name punctuation "\"" punctuation "="
|
||||
expr field_expr))
|
||||
(StructField.Map.bindings fields)
|
||||
punctuation "}"
|
||||
| EStructAccess (e1, field, _) ->
|
||||
Format.fprintf fmt "%a%a%a%a%a" expr e1 punctuation "." punctuation "\""
|
||||
StructFieldName.format_t field punctuation "\""
|
||||
| EEnumInj (e1, cons, _) ->
|
||||
Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
| EStructAccess { e; field; _ } ->
|
||||
Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\""
|
||||
StructField.format_t field punctuation "\""
|
||||
| EInj { e; cons; _ } ->
|
||||
Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e
|
||||
| EMatch { e; cases; _ } ->
|
||||
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
|
||||
expr e1 keyword "with"
|
||||
expr e keyword "with"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (cons_name, case_expr) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" punctuation "|"
|
||||
enum_constructor cons_name punctuation "→" expr case_expr))
|
||||
(EnumConstructorMap.bindings cases)
|
||||
| EScopeCall (scope, fields) ->
|
||||
(EnumConstructor.Map.bindings cases)
|
||||
| EScopeCall { scope; args } ->
|
||||
Format.pp_open_hovbox fmt 2;
|
||||
ScopeName.format_t fmt scope;
|
||||
Format.pp_print_space fmt ();
|
||||
@ -411,7 +403,7 @@ let rec expr_aux :
|
||||
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" ScopeVar.format_t
|
||||
field_name punctuation "\"" punctuation "=" expr field_expr)
|
||||
fmt
|
||||
(ScopeVarMap.bindings fields);
|
||||
(ScopeVar.Map.bindings args);
|
||||
Format.pp_close_box fmt ();
|
||||
punctuation fmt "}";
|
||||
Format.pp_close_box fmt ()
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
(** Printing functions for the default calculus AST *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
(** {1 Common syntax highlighting helpers}*)
|
||||
@ -24,7 +24,7 @@ open Definitions
|
||||
val base_type : Format.formatter -> string -> unit
|
||||
val keyword : Format.formatter -> string -> unit
|
||||
val punctuation : Format.formatter -> string -> unit
|
||||
val operator : Format.formatter -> string -> unit
|
||||
val op_style : Format.formatter -> string -> unit
|
||||
val lit_style : Format.formatter -> string -> unit
|
||||
|
||||
(** {1 Formatters} *)
|
||||
@ -35,13 +35,11 @@ val tlit : Format.formatter -> typ_lit -> unit
|
||||
val location : Format.formatter -> 'a glocation -> unit
|
||||
val typ : decl_ctx -> Format.formatter -> typ -> unit
|
||||
val lit : Format.formatter -> 'a glit -> unit
|
||||
val op_kind : Format.formatter -> op_kind -> unit
|
||||
val binop : Format.formatter -> binop -> unit
|
||||
val ternop : Format.formatter -> ternop -> unit
|
||||
val operator : Format.formatter -> ('a any, 'k) operator -> unit
|
||||
val log_entry : Format.formatter -> log_entry -> unit
|
||||
val unop : Format.formatter -> unop -> unit
|
||||
val except : Format.formatter -> except -> unit
|
||||
val var : Format.formatter -> 'e Var.t -> unit
|
||||
val var_debug : Format.formatter -> 'e Var.t -> unit
|
||||
|
||||
val expr :
|
||||
?debug:bool (** [true] for debug printing *) ->
|
||||
|
@ -22,6 +22,18 @@ let map_exprs ~f ~varf { scopes; decl_ctx } =
|
||||
(fun scopes -> { scopes; decl_ctx })
|
||||
(Scope.map_exprs ~f ~varf scopes)
|
||||
|
||||
let get_scope_body { scopes; _ } scope =
|
||||
match
|
||||
Scope.fold_left ~init:None
|
||||
~f:(fun acc scope_def _ ->
|
||||
if ScopeName.equal scope_def.scope_name scope then
|
||||
Some scope_def.scope_body
|
||||
else acc)
|
||||
scopes
|
||||
with
|
||||
| None -> raise Not_found
|
||||
| Some body -> body
|
||||
|
||||
let untype : 'm. ('a, 'm mark) gexpr program -> ('a, untyped mark) gexpr program
|
||||
=
|
||||
fun prg -> Bindlib.unbox (map_exprs ~f:Expr.untype ~varf:Var.translate prg)
|
||||
|
@ -25,6 +25,9 @@ val map_exprs :
|
||||
'expr1 program ->
|
||||
'expr2 program Bindlib.box
|
||||
|
||||
val get_scope_body :
|
||||
(([< dcalc | lcalc ], _) gexpr as 'e) program -> ScopeName.t -> 'e scope_body
|
||||
|
||||
val untype :
|
||||
(([< dcalc | lcalc ] as 'a), 'm mark) gexpr program ->
|
||||
('a, untyped mark) gexpr program
|
||||
|
@ -15,7 +15,7 @@
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
let rec fold_left_lets ~f ~init scope_body_expr =
|
||||
@ -106,7 +106,7 @@ let rec get_body_expr_mark = function
|
||||
get_body_expr_mark e
|
||||
| Result e ->
|
||||
let m = Marked.get_mark e in
|
||||
Expr.with_ty m (Utils.Marked.mark (Expr.mark_pos m) TAny)
|
||||
Expr.with_ty m (Marked.mark (Expr.mark_pos m) TAny)
|
||||
|
||||
let get_body_mark scope_body =
|
||||
let _, e = Bindlib.unbind scope_body.scope_body_expr in
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
(** Functions handling the scope structures of [shared_ast] *)
|
||||
|
||||
open Utils
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
(** {2 Traversal functions} *)
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
include Definitions
|
||||
module Var = Var
|
||||
module Type = Type
|
||||
module Operator = Operator
|
||||
module Expr = Expr
|
||||
module Scope = Scope
|
||||
module Program = Program
|
||||
|
87
compiler/shared_ast/type.ml
Normal file
87
compiler/shared_ast/type.ml
Normal file
@ -0,0 +1,87 @@
|
||||
(* This file is part of the Catala compiler, a specification language for tax
|
||||
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
||||
Louis Gesbert <louis.gesbert@inria.fr>
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
use this file except in compliance with the License. You may obtain a copy of
|
||||
the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
License for the specific language governing permissions and limitations under
|
||||
the License. *)
|
||||
|
||||
open Catala_utils
|
||||
open Definitions
|
||||
|
||||
type t = typ
|
||||
|
||||
let equal_tlit l1 l2 = l1 = l2
|
||||
let compare_tlit l1 l2 = Stdlib.compare l1 l2
|
||||
|
||||
let rec equal ty1 ty2 =
|
||||
match Marked.unmark ty1, Marked.unmark ty2 with
|
||||
| TLit l1, TLit l2 -> equal_tlit l1 l2
|
||||
| TTuple tys1, TTuple tys2 -> equal_list tys1 tys2
|
||||
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
|
||||
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
|
||||
| TOption t1, TOption t2 -> equal t1 t2
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> equal t1 t2 && equal t1' t2'
|
||||
| TArray t1, TArray t2 -> equal t1 t2
|
||||
| TAny, TAny -> true
|
||||
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
|
||||
| TArray _ | TAny ),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
and equal_list tys1 tys2 =
|
||||
try List.for_all2 equal tys1 tys2 with Invalid_argument _ -> false
|
||||
|
||||
(* Similar to [equal], but allows TAny holes *)
|
||||
let rec unifiable ty1 ty2 =
|
||||
match Marked.unmark ty1, Marked.unmark ty2 with
|
||||
| TAny, _ | _, TAny -> true
|
||||
| TLit l1, TLit l2 -> equal_tlit l1 l2
|
||||
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
|
||||
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
|
||||
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
|
||||
| TOption t1, TOption t2 -> unifiable t1 t2
|
||||
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
|
||||
| TArray t1, TArray t2 -> unifiable t1 t2
|
||||
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
and unifiable_list tys1 tys2 =
|
||||
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
|
||||
|
||||
let rec compare ty1 ty2 =
|
||||
match Marked.unmark ty1, Marked.unmark ty2 with
|
||||
| TLit l1, TLit l2 -> compare_tlit l1 l2
|
||||
| TTuple tys1, TTuple tys2 -> List.compare compare tys1 tys2
|
||||
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
|
||||
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
|
||||
| TOption t1, TOption t2 -> compare t1 t2
|
||||
| TArrow (a1, b1), TArrow (a2, b2) -> (
|
||||
match compare a1 a2 with 0 -> compare b1 b2 | n -> n)
|
||||
| TArray t1, TArray t2 -> compare t1 t2
|
||||
| TAny, TAny -> 0
|
||||
| TLit _, _ -> -1
|
||||
| _, TLit _ -> 1
|
||||
| TTuple _, _ -> -1
|
||||
| _, TTuple _ -> 1
|
||||
| TStruct _, _ -> -1
|
||||
| _, TStruct _ -> 1
|
||||
| TEnum _, _ -> -1
|
||||
| _, TEnum _ -> 1
|
||||
| TOption _, _ -> -1
|
||||
| _, TOption _ -> 1
|
||||
| TArrow _, _ -> -1
|
||||
| _, TArrow _ -> 1
|
||||
| TArray _, _ -> -1
|
||||
| _, TArray _ -> 1
|
||||
|
||||
let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user