Merge branch 'master' into afromher_334

This commit is contained in:
Denis Merigoux 2023-01-20 14:05:38 -05:00
commit 7cffc53169
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
331 changed files with 33415 additions and 30203 deletions

View File

@ -1,28 +0,0 @@
{ lib, fetchFromGitHub, buildDunePackage }:
# We need the very last version "bleeding edge" since previous versions don't use dune.
buildDunePackage rec {
pname = "bindlib";
version = "5.0.1a";
minimumOCamlVersion = "4.0.8";
useDune2 = true;
src = fetchFromGitHub {
owner = "rlepigre";
repo = "ocaml-${pname}";
rev = "317f195d22c75f556053039cd94b52bd0c423709";
name = pname;
hash = "sha256-uO/Ko9PmQ+wE0d9jfEngd4G014B4nxGgfQyEvB52Pz8=";
};
meta = with lib; {
homepage = "https://rlepigre.github.io/ocaml-bindlib/";
description =
"Bindlib is a library allowing the manipulation of data structures with bound variables";
license = licenses.lgpl3;
maintainers = [ ];
};
}

View File

@ -5,7 +5,7 @@
, bindlib
, buildDunePackage
, calendar
, cmdliner_1_1_0
, cmdliner
, cppo
, dates_calc
, fetchFromGitHub
@ -42,7 +42,7 @@ buildDunePackage rec {
ansiterminal
benchmark
bindlib
cmdliner_1_1_0
cmdliner
cppo
dates_calc
js_of_ocaml

View File

@ -1,32 +1,13 @@
{ ocamlPackages, fetchurl }:
ocamlPackages.overrideScope' (self: super: {
cmdliner_1_1_0 = super.cmdliner.overrideAttrs (o: rec {
version = "1.1.0";
src = fetchurl {
url = "https://erratique.ch/software/${o.pname}/releases/${o.pname }-${version}.tbz";
sha256 = "sha256-irWd4HTlJSYuz3HMgi1de2GVL2qus0QjeCe1WdsSs8Q=";
};
});
alcotest = (super.alcotest.override {
cmdliner = self.cmdliner_1_1_0;
}).overrideAttrs (_: {
alcotest = (super.alcotest.override {}).overrideAttrs (_: {
doCheck = false;
});
# Use a more recent version of `re` than the one packaged in nixpkgs
re = super.re.overrideAttrs (o: rec {
version = "1.10.4";
src = fetchurl {
url = "https://github.com/ocaml/ocaml-${o.pname}/releases/download/${version}/${o.pname}-${version}.tbz";
sha256 = "sha256-g+s+QwCqmx3HggdJAQ9DYuqDUkdCEwUk14wgzpnKdHw=";
};
});
catala = self.callPackage ./catala.nix { };
bindlib = self.callPackage ./bindlib.nix { };
unionfind = self.callPackage ./unionfind.nix { };
ninja_utils = self.callPackage ./ninja_utils.nix { };
clerk = self.callPackage ./clerk.nix { };
ppx_yojson_conv = self.callPackage ./ppx_yojson_conv.nix { };
ubase = self.callPackage ./ubase.nix { };
dates_calc = self.callPackage ./dates_calc.nix { };
})

View File

@ -1,20 +0,0 @@
{ lib, fetchurl, buildDunePackage, ppxlib, ppx_yojson_conv_lib, ppx_js_style }:
buildDunePackage rec {
pname = "ppx_yojson_conv";
version = "0.14.0";
minimumOCamlVersion = "4.0.8";
useDune2 = true;
propagatedBuildInputs = [
ppxlib ppx_yojson_conv_lib ppx_js_style
];
src = fetchurl
{
url = "https://ocaml.janestreet.com/ocaml-core/v0.14/files/ppx_yojson_conv-v0.14.0.tar.gz";
sha256 = "0ls6vzj7k0wrjliifqczs78anbc8b88as5w7a3wixfcs1gjfsp2w";
};
}

View File

@ -104,19 +104,20 @@ need more, here is how one can be added:
- Choose a name wisely. Be ready to patch any code that already used the name
for scope parameters, variables or structure fields, since it won't compile
anymore.
- Add an element to the `builtin_expression` type in `surface/ast.ml(i)`
- Add an element to the `builtin_expression` type in `surface/ast.ml`
- Add your builtin in the `builtins` list in `surface/lexer.cppo.ml`, and with
proper translations in all of the language-specific modules
`surface/lexer_en.cppo.ml`, `surface/lexer_fr.cppo.ml`, etc. Don't forget the
macro at the beginning of `lexer.cppo.ml`.
- The rest can all be done by following the type errors downstream:
- Add a corresponding element to the lower-level AST in `dcalc/ast.ml(i)`, type `unop`
- Extend the translation accordingly in `surface/desugaring.ml`
- Extend the printer (`dcalc/print.ml`) and the typer with correct type
information (`dcalc/typing.ml`)
- Add a corresponding element to the lower-level AST in `shared_ast/definitions.ml`, type `Op.t`
- Extend the generic operations on operators in `shared_ast/operators.ml` as well as the type information for the operator
- Extend the translation accordingly in `desugared/from_surface.ml`
- Extend the printer (`shared_ast/print.ml`)
- Finally, provide the implementations:
- in `lcalc/to_ocaml.ml`, function `format_unop`
- in `dcalc/interpreter.ml`, function `evaluate_operator`
- in `../runtimes/ocaml/runtime.ml`
- in `../runtimes/python/catala/src/catala/runtime.py`
- Update the syntax guide in `doc/syntax/syntax.tex` with your new builtin
### Internationalization of the Catala syntax

View File

@ -3,7 +3,7 @@
FROM ocamlpro/ocaml:4.14-2022-07-17 AS dev-build-context
# pandoc is not in alpine stable yet, install it manually with an explicit repository
RUN sudo apk add pandoc --repository=http://dl-cdn.alpinelinux.org/alpine/edge/testing/
RUN sudo apk add pandoc --repository=http://dl-cdn.alpinelinux.org/alpine/edge/community/
RUN mkdir catala
WORKDIR catala

View File

@ -22,18 +22,31 @@ Finally, start a shell inside a new container created from the newly built image
The repository provides nix files to build or develop the catala compiler.
Once [nix is installed](https://nixos.org/manual/nix/stable/#ch-installing-binary),
it is possible to enter a development shell:
with flakes enabled it is possible to enter a development shell:
nix-shell
nix develop
or to build the Catala compiler, documentation and runtime library:
nix-build release.nix
nix build
Dependencies not yet in nixpkgs (`bindlib` and `unionFind` at the moment of writing)
are hardcoded inside the `.nix` directory. The `default.nix` should be compatible with
Dependencies not yet in nixpkgs (`ubase` and `unionFind` at the moment of writing)
are hardcoded inside the `.nix` directory. The `.nix/catala.nix` should be compatible with
nixpkgs, if it finds a maintainer.
To develop catala's compiler using vscode using ocaml's [lsp](https://microsoft.github.io/language-server-protocol/), you can use the [ocaml-platform extension](https://marketplace.visualstudio.com/items?itemName=ocamllabs.ocaml-platform) with the following settings (inside the file `.vscode/settings.json`).
```json
{
"ocaml.sandbox": {
"kind": "custom",
"template": "nix develop --command $prog $args"
},
}
```
The nix build is updated weekly by an automatic github action.
### With opam
The Catala compiler is written using OCaml. First, you have to install `opam`,

View File

@ -299,8 +299,9 @@ run_french_law_library_benchmark_python: $(PY_VIRTUALENV) \
CATALA_OPTS?=
CLERK_OPTS?=--makeflags="$(MAKEFLAGS)"
CATALA_BIN=_build/default/compiler/catala.exe
CLERK_BIN=_build/default/build_system/clerk.exe
CATALA_BIN=_build/default/$(COMPILER_DIR)/catala.exe
CLERK_BIN=_build/default/$(BUILD_SYSTEM_DIR)/clerk.exe
CATALA_LEGIFRANCE_BIN=_build/default/$(CATALA_LEGIFRANCE_DIR)/catala_legifrance.exe
CLERK=$(CLERK_BIN) --exe $(CATALA_BIN) \
$(CLERK_OPTS) $(if $(CATALA_OPTS),--catala-opts=$(CATALA_OPTS),)
@ -336,7 +337,7 @@ tests/%: .FORCE
# Website assets
##########################################
WEBSITE_ASSETS = grammar.html catala.html
WEBSITE_ASSETS = grammar.html catala.html clerk.html catala_legifrance.html
$(addprefix _build/default/,$(WEBSITE_ASSETS)):
dune build $@
@ -386,6 +387,11 @@ help_clerk:
help_catala:
$(CATALA_BIN) --help
#> help_catala_legifrance : Display the catala_legifrance man page
help_catala_legifrance:
$(CATALA_LEGIFRANCE_BIN) --help
##########################################
# Special targets
##########################################

View File

@ -16,7 +16,7 @@
the License. *)
open Cmdliner
open Utils
open Catala_utils
open Ninja_utils
module Nj = Ninja_utils
@ -524,7 +524,7 @@ let collect_all_ninja_build
(tested_file : string)
(reset_test_outputs : bool) : (string * ninja) option =
let expected_outputs = search_for_expected_outputs tested_file in
if List.length expected_outputs = 0 then (
if expected_outputs = [] then (
Cli.debug_print "No expected outputs were found for test file %s"
tested_file;
None)
@ -890,10 +890,18 @@ let driver
let files_or_folders = List.sort_uniq String.compare files_or_folders
and catala_exe = Option.fold ~none:"catala" ~some:Fun.id catala_exe
and catala_opts = Option.fold ~none:"" ~some:Fun.id catala_opts
and ninja_output =
Option.fold
~none:(Filename.temp_file "clerk_build_" ".ninja")
~some:Fun.id ninja_output
and with_ninja_output k =
match ninja_output with
| Some f -> k f
| None -> (
let f = Filename.temp_file "clerk_build_" ".ninja" in
match k f with
| exception e ->
if not debug then Sys.remove f;
raise e
| r ->
Sys.remove f;
r)
in
match String.lowercase_ascii command with
| "test" -> (
@ -919,20 +927,22 @@ let driver
if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then
return_ok
else
try
File.with_formatter_of_file ninja_output (fun fmt ->
Cli.debug_print "writing %s..." ninja_output;
with_ninja_output
@@ fun nin ->
match
File.with_formatter_of_file nin (fun fmt ->
Cli.debug_print "writing %s..." nin;
Nj.format fmt
(add_root_test_build ninja ctx.all_file_names
ctx.all_test_builds));
ctx.all_test_builds))
with
| () ->
let ninja_cmd =
"ninja -k 0 -f " ^ ninja_output ^ " " ^ ninja_flags ^ " test"
"ninja -k 0 -f " ^ nin ^ " " ^ ninja_flags ^ " test"
in
Cli.debug_print "executing '%s'..." ninja_cmd;
let return = Sys.command ninja_cmd in
if not debug then Sys.remove ninja_output;
return
with Sys_error e ->
Sys.command ninja_cmd
| exception Sys_error e ->
Cli.error_print "can not write in %s" e;
return_err)
| "run" -> (

View File

@ -9,7 +9,7 @@
(public_name clerk.driver)
(libraries
catala.runtime_ocaml
catala.utils
catala.catala_utils
ninja_utils
cmdliner
re

View File

@ -34,6 +34,7 @@ depends: [
"ppx_yojson_conv" {>= "0.14.0"}
"re" {>= "1.9.0"}
"sedlex" {>= "2.4"}
"uutf" {>= "1.0.3"}
"ubase" {>= "0.05"}
"unionFind" {>= "20200320"}
"visitors" {>= "20200210"}
@ -45,6 +46,7 @@ depends: [
"obelisk" {cataladevmode}
"conf-npm" {cataladevmode}
"conf-python-3-dev" {cataladevmode}
"cpdf" {cataladevmode}
"z3" {catalaz3mode}
]
depopts: ["z3"]

View File

@ -7,12 +7,12 @@ In {{: desugared.html} the desugared representation} or in the
global identifiers. These identifiers use OCaml's type system to statically
distinguish e.g. a scope identifier from a struct identifier.
The {!module: Utils.Uid} module provides a generative functor whose output is
The {!module: Uid} module provides a generative functor whose output is
a fresh sort of global identifiers.
Related modules:
{!modules: Utils.Uid}
{!modules: Uid}
{1 Source code positions}
@ -22,7 +22,7 @@ code. These annotations are critical to produce readable error messages.
Related modules:
{!modules: Utils.Pos}
{!modules: Pos}
{1 Error messages}

View File

@ -172,7 +172,7 @@ let plugins_dirs =
let default =
let ( / ) = Filename.concat in
[
Sys.executable_name
Filename.dirname Sys.executable_name
/ Filename.parent_dir_name
/ "lib"
/ "catala"

View File

@ -1,8 +1,8 @@
(library
(name utils)
(public_name catala.utils)
(name catala_utils)
(public_name catala.catala_utils)
(libraries cmdliner ubase ANSITerminal re bindlib catala.runtime_ocaml))
(documentation
(package catala)
(mld_files utils))
(mld_files catala_utils))

View File

@ -26,7 +26,7 @@ exception StructuredError of (string * (string option * Pos.t) list)
let print_structured_error (msg : string) (pos : (string option * Pos.t) list) :
string =
Printf.sprintf "%s%s%s" msg
(if List.length pos = 0 then "" else "\n\n")
(if pos = [] then "" else "\n\n")
(String.concat "\n\n"
(List.map
(fun (msg, pos) ->

View File

@ -79,11 +79,11 @@ let to_string (pos : t) : string =
let to_string_short (pos : t) : string =
let s, e = pos.code_pos in
if e.Lexing.pos_lnum = s.Lexing.pos_lnum then
Printf.sprintf "%s:%d.%d-%d" s.Lexing.pos_fname s.Lexing.pos_lnum
Printf.sprintf "%s:%d.%d-%d:" s.Lexing.pos_fname s.Lexing.pos_lnum
(s.Lexing.pos_cnum - s.Lexing.pos_bol)
(e.Lexing.pos_cnum - e.Lexing.pos_bol)
else
Printf.sprintf "%s:%d.%d-%d.%d" s.Lexing.pos_fname s.Lexing.pos_lnum
Printf.sprintf "%s:%d.%d-%d.%d:" s.Lexing.pos_fname s.Lexing.pos_lnum
(s.Lexing.pos_cnum - s.Lexing.pos_bol)
e.Lexing.pos_lnum
(e.Lexing.pos_cnum - e.Lexing.pos_bol)
@ -102,6 +102,27 @@ let string_repeat n s =
done;
Bytes.to_string buf
(* Note: this should do, but remains incorrect for combined unicode characters
that display as one (e.g. `e` + postfix `'`). We should switch to Uuseg at
some poing *)
let string_columns s =
let len = String.length s in
let rec aux ncols i =
if i >= len then ncols
else if s.[i] = '\t' then aux (ncols + 8) (i + 1)
else
aux (ncols + 1) (i + Uchar.utf_decode_length (String.get_utf_8_uchar s i))
in
aux 0 0
let utf8_byte_index s ui0 =
let rec aux bi ui =
if ui >= ui0 then bi
else
aux (bi + Uchar.utf_decode_length (String.get_utf_8_uchar s bi)) (ui + 1)
in
aux 0 0
let retrieve_loc_text (pos : t) : string =
try
let filename = get_file pos in
@ -132,34 +153,32 @@ let retrieve_loc_text (pos : t) : string =
let print_matched_line (line : string) (line_no : int) : string =
let line_indent = indent_number line in
let error_indicator_style = [ANSITerminal.red; ANSITerminal.Bold] in
line
^
if line_no >= sline && line_no <= eline then
"\n"
^
if line_no = sline && line_no = eline then
Cli.with_style error_indicator_style "%*s%s"
(get_start_column pos - 1)
""
(string_repeat
(max (get_end_column pos - get_start_column pos) 0)
"")
else if line_no = sline && line_no <> eline then
Cli.with_style error_indicator_style "%*s%s"
(get_start_column pos - 1)
""
(string_repeat
(max (String.length line - get_start_column pos) 0)
"")
else if line_no <> sline && line_no <> eline then
Cli.with_style error_indicator_style "%*s%s" line_indent ""
(string_repeat (max (String.length line - line_indent) 0) "")
else if line_no <> sline && line_no = eline then
Cli.with_style error_indicator_style "%*s%*s" line_indent ""
(get_end_column pos - 1 - line_indent)
(string_repeat (max (get_end_column pos - line_indent) 0) "")
else assert false (* should not happen *)
else ""
let match_start_index =
utf8_byte_index line
(if line_no = sline then get_start_column pos - 1 else line_indent)
in
let match_end_index =
if line_no = eline then utf8_byte_index line (get_end_column pos - 1)
else String.length line
in
let unmatched_prefix = String.sub line 0 match_start_index in
let matched_substring =
String.sub line match_start_index
(max 0 (match_end_index - match_start_index))
in
let match_start_col = string_columns unmatched_prefix in
let match_num_cols = string_columns matched_substring in
String.concat ""
(line
:: "\n"
::
(if line_no >= sline && line_no <= eline then
[
string_repeat match_start_col " ";
Cli.with_style error_indicator_style "%s"
(string_repeat match_num_cols "");
]
else []))
in
let include_extra_count = 0 in
let rec get_lines (n : int) : string list =
@ -193,10 +212,8 @@ let retrieve_loc_text (pos : t) : string =
(Cli.with_style blue_style "└%s┐" (string_repeat spaces ""));
Buffer.add_char buf '\n';
Buffer.add_string buf
(Cli.add_prefix_to_each_line
(String.concat "\n" ("" :: pos_lines))
(fun i ->
let cur_line = sline - include_extra_count + i - 1 in
(Cli.add_prefix_to_each_line (String.concat "\n" pos_lines) (fun i ->
let cur_line = sline - include_extra_count + i in
if
cur_line >= sline
&& cur_line <= sline + (2 * (eline - sline))

View File

@ -14,39 +14,47 @@
License for the specific language governing permissions and limitations under
the License. *)
let to_ascii : string -> string = Ubase.from_utf8
include Stdlib.String
let is_uppercase_ascii (c : char) : bool =
let c = Char.code c in
(* 'A' <= c && c <= 'Z' *)
0x41 <= c && c <= 0x5b
let to_ascii : string -> string = Ubase.from_utf8
let is_uppercase_ascii = function 'A' .. 'Z' -> true | _ -> false
let begins_with_uppercase (s : string) : bool =
if "" = s then false else is_uppercase_ascii (to_ascii s).[0]
"" <> s && is_uppercase_ascii (get (to_ascii s) 0)
let to_snake_case (s : string) : string =
let out = ref "" in
to_ascii s
|> String.iteri (fun i c ->
|> iteri (fun i c ->
out :=
!out
^ (if is_uppercase_ascii c && 0 <> i then "_" else "")
^ String.lowercase_ascii (String.make 1 c));
^ lowercase_ascii (make 1 c));
!out
let to_camel_case (s : string) : string =
let last_was_underscore = ref false in
let out = ref "" in
to_ascii s
|> String.iteri (fun i c ->
|> iteri (fun i c ->
let is_underscore = c = '_' in
let c_string = String.make 1 c in
let c_string = make 1 c in
out :=
!out
^
if is_underscore then ""
else if !last_was_underscore || 0 = i then
String.uppercase_ascii c_string
else if !last_was_underscore || 0 = i then uppercase_ascii c_string
else c_string;
last_was_underscore := is_underscore);
!out
let remove_prefix ~prefix s =
if starts_with ~prefix s then
let plen = length prefix in
sub s plen (length s - plen)
else s
let format_t = Format.pp_print_string
module Set = Set.Make (Stdlib.String)
module Map = Map.Make (Stdlib.String)

View File

@ -14,6 +14,10 @@
License for the specific language governing permissions and limitations under
the License. *)
include module type of Stdlib.String
module Set : Set.S with type elt = string
module Map : Map.S with type key = string
(** Helper functions used for string manipulation. *)
val to_ascii : string -> string
@ -34,3 +38,11 @@ val to_snake_case : string -> string
val to_camel_case : string -> string
(** Converts snake_case into CamlCase after removing Remove all diacritics on
Latin letters. *)
val remove_prefix : prefix:string -> string -> string
(** [remove_prefix ~prefix str] returns
- if [str] starts with [prefix], a string [s] such that [prefix ^ s = str]
- otherwise, [str] unchanged *)
val format_t : Format.formatter -> string -> unit

View File

@ -18,7 +18,7 @@ module type Info = sig
type info
val to_string : info -> string
val format_info : Format.formatter -> info -> unit
val format : Format.formatter -> info -> unit
val equal : info -> info -> bool
val compare : info -> info -> int
end
@ -33,10 +33,21 @@ module type Id = sig
val equal : t -> t -> bool
val format_t : Format.formatter -> t -> unit
val hash : t -> int
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
end
module Make (X : Info) () : Id with type info = X.info = struct
type t = { id : int; info : X.info }
module Ordering = struct
type t = { id : int; info : X.info }
let compare (x : t) (y : t) : int = compare x.id y.id
let equal x y = Int.equal x.id y.id
end
include Ordering
type info = X.info
let counter = ref 0
@ -46,20 +57,20 @@ module Make (X : Info) () : Id with type info = X.info = struct
{ id = !counter; info }
let get_info (uid : t) : X.info = uid.info
let compare (x : t) (y : t) : int = compare x.id y.id
let equal x y = Int.equal x.id y.id
let format_t (fmt : Format.formatter) (x : t) : unit =
X.format_info fmt x.info
let format_t (fmt : Format.formatter) (x : t) : unit = X.format fmt x.info
let hash (x : t) : int = x.id
module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering)
end
module MarkedString = struct
type info = string Marked.pos
let to_string (s, _) = s
let format_info fmt i = Format.pp_print_string fmt (to_string i)
let format fmt i = Format.pp_print_string fmt (to_string i)
let equal i1 i2 = String.equal (Marked.unmark i1) (Marked.unmark i2)
let compare i1 i2 = String.compare (Marked.unmark i1) (Marked.unmark i2)
end
module Gen () = Make (MarkedString) ()

View File

@ -21,7 +21,7 @@ module type Info = sig
type info
val to_string : info -> string
val format_info : Format.formatter -> info -> unit
val format : Format.formatter -> info -> unit
val equal : info -> info -> bool
(** Equality disregards position *)
@ -48,9 +48,15 @@ module type Id = sig
val equal : t -> t -> bool
val format_t : Format.formatter -> t -> unit
val hash : t -> int
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
end
(** This is the generative functor that ensures that two modules resulting from
two different calls to [Make] will be viewed as different types [t] by the
OCaml typechecker. Prevents mixing up different sorts of identifiers. *)
module Make (X : Info) () : Id with type info = X.info
module Gen () : Id with type info = MarkedString.info
(** Shortcut for creating a kind of uids over marked strings *)

View File

@ -1,3 +1,4 @@
open Catala_utils
open Driver
open Js_of_ocaml
@ -12,7 +13,7 @@ let _ =
driver
(Contents (Js.to_string contents))
{
Utils.Cli.debug = false;
Cli.debug = false;
color = Never;
wrap_weaved_output = false;
avoid_exceptions = false;

View File

@ -1,7 +1,15 @@
(library
(name dcalc)
(public_name catala.dcalc)
(libraries bindlib unionFind utils re ubase catala.runtime_ocaml shared_ast)
(libraries
bindlib
unionFind
catala_utils
re
ubase
catala.runtime_ocaml
shared_ast
scopelang)
(preprocess
(pps visitors.ppx)))

View File

@ -16,4 +16,4 @@
(** Scope language to default calculus translator *)
val translate_program : 'm Ast.program -> 'm Dcalc.Ast.program
val translate_program : 'm Scopelang.Ast.program -> 'm Ast.program

View File

@ -16,7 +16,7 @@
(** Reference interpreter for the default calculus *)
open Utils
open Catala_utils
open Shared_ast
module Runtime = Runtime_ocaml.Runtime
@ -29,272 +29,117 @@ let log_indent = ref 0
(** {1 Evaluation} *)
let rec evaluate_operator
(ctx : decl_ctx)
(op : operator)
(pos : Pos.t)
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
(* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use
[op] to raise multispanned errors. *)
let apply_div_or_raise_err (div : unit -> 'm Ast.naked_expr) :
'm Ast.naked_expr =
try div ()
with Division_by_zero ->
let print_log ctx entry infos pos e =
if !Cli.trace_flag then
match entry with
| VarDef _ ->
(* TODO: this usage of Format is broken, Formatting requires that all is
formatted in one pass, without going through intermediate "%s" *)
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(match Marked.unmark e with
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
| _ ->
let expr_str =
Format.asprintf "%a" (Expr.format ctx ~debug:false) e
in
let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
expr_str
in
Cli.with_style [ANSITerminal.green] "%s" expr_str)
| PosRecordIfTrueBool -> (
match pos <> Pos.no_pos, Marked.unmark e with
| true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry entry
(Cli.with_style [ANSITerminal.green] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ())
| BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos;
log_indent := !log_indent + 1
| EndCall ->
log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(* Todo: this should be handled early when resolving overloads. Here we have
proper structural equality, but the OCaml backend for example uses the
builtin equality function instead of this. *)
let rec handle_eq ctx pos e1 e2 =
let open Runtime.Oper in
match e1, e2 with
| ELit LUnit, ELit LUnit -> true
| ELit (LBool b1), ELit (LBool b2) -> not (o_xor b1 b2)
| ELit (LInt x1), ELit (LInt x2) -> o_eq_int_int x1 x2
| ELit (LRat x1), ELit (LRat x2) -> o_eq_rat_rat x1 x2
| ELit (LMoney x1), ELit (LMoney x2) -> o_eq_mon_mon x1 x2
| ELit (LDuration x1), ELit (LDuration x2) -> o_eq_dur_dur x1 x2
| ELit (LDate x1), ELit (LDate x2) -> o_eq_dat_dat x1 x2
| EArray es1, EArray es2 -> (
try
List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false)
| EStruct { fields = es1; name = s1 }, EStruct { fields = es2; name = s2 } ->
StructName.equal s1 s2
&& StructField.Map.equal
(fun e1 e2 ->
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
| ( EInj { e = e1; cons = i1; name = en1 },
EInj { e = e2; cons = i2; name = en2 } ) -> (
try
EnumName.equal en1 en2
&& EnumConstructor.equal i1 i2
&&
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *)
with Invalid_argument _ -> false)
| _, _ -> false (* comparing anything else return false *)
(* Call-by-value: the arguments are expected to be already evaluated here *)
and evaluate_operator :
type k.
decl_ctx ->
(dcalc, k) operator ->
Pos.t ->
'm Ast.expr list ->
'm Ast.naked_expr =
fun ctx op pos args ->
let protect f x y =
let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
[None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false
in
try f x y with
| Division_by_zero ->
Errors.raise_multispanned_error
[
Some "The division operator:", pos;
Some "The null denominator:", Expr.pos (List.nth args 1);
]
"division by zero at runtime"
in
let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
[None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false
in
(* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched,
use [args] to raise multispanned errors. *)
let apply_cmp_or_raise_err
(cmp : unit -> 'm Ast.naked_expr)
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
try cmp ()
with Runtime.UncomparableDurations ->
| Runtime.UncomparableDurations ->
Errors.raise_multispanned_error (get_binop_args_pos args)
"Cannot compare together durations that cannot be converted to a \
precise number of days"
in
match op, List.map Marked.unmark args with
| Ternop Fold, [_f; _init; EArray es] ->
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp (List.nth args 0, [acc; e'])) e'))
(List.nth args 1) es)
| Binop And, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 && b2))
| Binop Or, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 || b2))
| Binop Xor, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 <> b2))
| Binop (Add KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 +! i2))
| Binop (Sub KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 -! i2))
| Binop (Mult KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 *! i2))
| Binop (Div KInt), [ELit (LInt i1); ELit (LInt i2)] ->
apply_div_or_raise_err (fun _ -> ELit (LInt Runtime.(i1 /! i2)))
| Binop (Add KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 +& i2))
| Binop (Sub KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 -& i2))
| Binop (Mult KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 *& i2))
| Binop (Div KRat), [ELit (LRat i1); ELit (LRat i2)] ->
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(i1 /& i2)))
| Binop (Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LMoney Runtime.(m1 +$ m2))
| Binop (Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LMoney Runtime.(m1 -$ m2))
| Binop (Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] ->
ELit (LMoney Runtime.(m1 *$ m2))
| Binop (Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(m1 /$ m2)))
| Binop (Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LDuration Runtime.(d1 +^ d2))
| Binop (Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LDuration Runtime.(d1 -^ d2))
| Binop (Sub KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LDuration Runtime.(d1 -@ d2))
| Binop (Add KDate), [ELit (LDate d1); ELit (LDuration d2)] ->
ELit (LDate Runtime.(d1 +@ d2))
| Binop (Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] ->
ELit (LDuration Runtime.(d1 *^ i1))
| Binop (Lt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 <! i2))
| Binop (Lte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 <=! i2))
| Binop (Gt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 >! i2))
| Binop (Gte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 >=! i2))
| Binop (Lt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 <& i2))
| Binop (Lte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 <=& i2))
| Binop (Gt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 >& i2))
| Binop (Gte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 >=& i2))
| Binop (Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 <$ m2))
| Binop (Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 <=$ m2))
| Binop (Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 >$ m2))
| Binop (Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 >=$ m2))
| Binop (Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <^ d2))) args
| Binop (Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <=^ d2))) args
| Binop (Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >^ d2))) args
| Binop (Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >=^ d2))) args
| Binop (Lt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 <@ d2))
| Binop (Lte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 <=@ d2))
| Binop (Gt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 >@ d2))
| Binop (Gte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 >=@ d2))
| Binop Eq, [ELit LUnit; ELit LUnit] -> ELit (LBool true)
| Binop Eq, [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LBool Runtime.(d1 =^ d2))
| Binop Eq, [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 =@ d2))
| Binop Eq, [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 =$ m2))
| Binop Eq, [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 =& i2))
| Binop Eq, [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 =! i2))
| Binop Eq, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 = b2))
| Binop Eq, [EArray es1; EArray es2] ->
ELit
(LBool
(try
List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false))
| Binop Eq, [ETuple (es1, s1); ETuple (es2, s2)] ->
ELit
(LBool
(try
s1 = s2
&& List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false))
| Binop Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] ->
ELit
(LBool
(try
en1 = en2
&& i1 = i2
&&
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *)
with Invalid_argument _ -> false))
| Binop Eq, [_; _] ->
ELit (LBool false) (* comparing anything else return false *)
| Binop Neq, [_; _] -> (
match evaluate_operator ctx (Binop Eq) pos args with
| ELit (LBool b) -> ELit (LBool (not b))
| _ -> assert false (*should not happen *))
| Binop Concat, [EArray es1; EArray es2] -> EArray (es1 @ es2)
| Binop Map, [_; EArray es] ->
EArray
(List.map
(fun e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp (List.nth args 0, [e'])) e'))
es)
| Binop Filter, [_; EArray es] ->
EArray
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp (List.nth args 0, [e'])) e')
with
| ELit (LBool b), _ -> b
| _ ->
Errors.raise_spanned_error
(Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a boolean \
(should not happen if the term was well-typed)")
es)
| Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> ELit LEmptyError
| Unop (Minus KInt), [ELit (LInt i)] ->
ELit (LInt Runtime.(integer_of_int 0 -! i))
| Unop (Minus KRat), [ELit (LRat i)] ->
ELit (LRat Runtime.(decimal_of_string "0" -& i))
| Unop (Minus KMoney), [ELit (LMoney i)] ->
ELit (LMoney Runtime.(money_of_units_int 0 -$ i))
| Unop (Minus KDuration), [ELit (LDuration i)] ->
ELit (LDuration Runtime.(~-^i))
| Unop Not, [ELit (LBool b)] -> ELit (LBool (not b))
| Unop Length, [EArray es] ->
ELit (LInt (Runtime.integer_of_int (List.length es)))
| Unop GetDay, [ELit (LDate d)] ->
ELit (LInt Runtime.(day_of_month_of_date d))
| Unop GetMonth, [ELit (LDate d)] ->
ELit (LInt Runtime.(month_number_of_date d))
| Unop GetYear, [ELit (LDate d)] -> ELit (LInt Runtime.(year_of_date d))
| Unop FirstDayOfMonth, [ELit (LDate d)] ->
ELit (LDate Runtime.(first_day_of_month d))
| Unop LastDayOfMonth, [ELit (LDate d)] ->
ELit (LDate Runtime.(first_day_of_month d))
| Unop IntToRat, [ELit (LInt i)] -> ELit (LRat Runtime.(decimal_of_integer i))
| Unop MoneyToRat, [ELit (LMoney i)] ->
ELit (LRat Runtime.(decimal_of_money i))
| Unop RatToMoney, [ELit (LRat i)] ->
ELit (LMoney Runtime.(money_of_decimal i))
| Unop RoundMoney, [ELit (LMoney m)] -> ELit (LMoney Runtime.(money_round m))
| Unop RoundDecimal, [ELit (LRat m)] -> ELit (LRat Runtime.(decimal_round m))
| Unop (Log (entry, infos)), [e'] ->
if !Cli.trace_flag then (
match entry with
| VarDef _ ->
(* TODO: this usage of Format is broken, Formatting requires that all is
formatted in one pass, without going through intermediate "%s" *)
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(match e' with
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
| _ ->
let expr_str =
Format.asprintf "%a" (Expr.format ctx ~debug:false) (List.hd args)
in
let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
expr_str
in
Cli.with_style [ANSITerminal.green] "%s" expr_str)
| PosRecordIfTrueBool -> (
match pos <> Pos.no_pos, e' with
| true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry
entry
(Cli.with_style [ANSITerminal.green] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ())
| BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos;
log_indent := !log_indent + 1
| EndCall ->
log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos)
else ();
e'
| Unop _, [ELit LEmptyError] -> ELit LEmptyError
| _ ->
let err () =
Errors.raise_multispanned_error
([Some "Operator:", pos]
@ List.mapi
@ -307,6 +152,162 @@ let rec evaluate_operator
args)
"Operator applied to the wrong arguments\n\
(should not happen if the term was well-typed)"
in
let open Runtime.Oper in
if List.exists (function ELit LEmptyError, _ -> true | _ -> false) args then
ELit LEmptyError
else
Operator.kind_dispatch op
~polymorphic:(fun op ->
match op, args with
| Length, [(EArray es, _)] ->
ELit (LInt (Runtime.integer_of_int (List.length es)))
| Log (entry, infos), [e'] ->
print_log ctx entry infos pos e';
Marked.unmark e'
| Eq, [(e1, _); (e2, _)] -> ELit (LBool (handle_eq ctx pos e1 e2))
| Map, [f; (EArray es, _)] ->
EArray
(List.map
(fun e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e'))
es)
| Reduce, [_; default; (EArray [], _)] -> Marked.unmark default
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
Marked.unmark
(List.fold_left
(fun acc x ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; x] }) f))
x0 xn)
| Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2)
| Filter, [f; (EArray es, _)] ->
EArray
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e')
with
| ELit (LBool b), _ -> b
| _ ->
Errors.raise_spanned_error
(Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a \
boolean (should not happen if the term was well-typed)")
es)
| Fold, [f; init; (EArray es, _)] ->
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; e'] }) e'))
init es)
| (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ ->
err ())
~monomorphic:(fun op ->
let rlit =
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
| Not, [LBool b] -> LBool (o_not b)
| GetDay, [LDate d] -> LInt (o_getDay d)
| GetMonth, [LDate d] -> LInt (o_getMonth d)
| GetYear, [LDate d] -> LInt (o_getYear d)
| FirstDayOfMonth, [LDate d] -> LDate (o_firstDayOfMonth d)
| LastDayOfMonth, [LDate d] -> LDate (o_lastDayOfMonth d)
| And, [LBool b1; LBool b2] -> LBool (o_and b1 b2)
| Or, [LBool b1; LBool b2] -> LBool (o_or b1 b2)
| Xor, [LBool b1; LBool b2] -> LBool (o_xor b1 b2)
| ( ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth
| LastDayOfMonth | And | Or | Xor ),
_ ) ->
err ()
in
ELit rlit)
~resolved:(fun op ->
let rlit =
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
| Minus_int, [LInt x] -> LInt (o_minus_int x)
| Minus_rat, [LRat x] -> LRat (o_minus_rat x)
| Minus_mon, [LMoney x] -> LMoney (o_minus_mon x)
| Minus_dur, [LDuration x] -> LDuration (o_minus_dur x)
| ToRat_int, [LInt i] -> LRat (o_torat_int i)
| ToRat_mon, [LMoney i] -> LRat (o_torat_mon i)
| ToMoney_rat, [LRat i] -> LMoney (o_tomoney_rat i)
| Round_mon, [LMoney m] -> LMoney (o_round_mon m)
| Round_rat, [LRat m] -> LRat (o_round_rat m)
| Add_int_int, [LInt x; LInt y] -> LInt (o_add_int_int x y)
| Add_rat_rat, [LRat x; LRat y] -> LRat (o_add_rat_rat x y)
| Add_mon_mon, [LMoney x; LMoney y] -> LMoney (o_add_mon_mon x y)
| Add_dat_dur, [LDate x; LDuration y] -> LDate (o_add_dat_dur x y)
| Add_dur_dur, [LDuration x; LDuration y] ->
LDuration (o_add_dur_dur x y)
| Sub_int_int, [LInt x; LInt y] -> LInt (o_sub_int_int x y)
| Sub_rat_rat, [LRat x; LRat y] -> LRat (o_sub_rat_rat x y)
| Sub_mon_mon, [LMoney x; LMoney y] -> LMoney (o_sub_mon_mon x y)
| Sub_dat_dat, [LDate x; LDate y] -> LDuration (o_sub_dat_dat x y)
| Sub_dat_dur, [LDate x; LDuration y] -> LDate (o_sub_dat_dur x y)
| Sub_dur_dur, [LDuration x; LDuration y] ->
LDuration (o_sub_dur_dur x y)
| Mult_int_int, [LInt x; LInt y] -> LInt (o_mult_int_int x y)
| Mult_rat_rat, [LRat x; LRat y] -> LRat (o_mult_rat_rat x y)
| Mult_mon_rat, [LMoney x; LRat y] -> LMoney (o_mult_mon_rat x y)
| Mult_dur_int, [LDuration x; LInt y] ->
LDuration (o_mult_dur_int x y)
| Div_int_int, [LInt x; LInt y] -> LRat (protect o_div_int_int x y)
| Div_rat_rat, [LRat x; LRat y] -> LRat (protect o_div_rat_rat x y)
| Div_mon_mon, [LMoney x; LMoney y] ->
LRat (protect o_div_mon_mon x y)
| Div_mon_rat, [LMoney x; LRat y] ->
LMoney (protect o_div_mon_rat x y)
| Lt_int_int, [LInt x; LInt y] -> LBool (o_lt_int_int x y)
| Lt_rat_rat, [LRat x; LRat y] -> LBool (o_lt_rat_rat x y)
| Lt_mon_mon, [LMoney x; LMoney y] -> LBool (o_lt_mon_mon x y)
| Lt_dat_dat, [LDate x; LDate y] -> LBool (o_lt_dat_dat x y)
| Lt_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_lt_dur_dur x y)
| Lte_int_int, [LInt x; LInt y] -> LBool (o_lte_int_int x y)
| Lte_rat_rat, [LRat x; LRat y] -> LBool (o_lte_rat_rat x y)
| Lte_mon_mon, [LMoney x; LMoney y] -> LBool (o_lte_mon_mon x y)
| Lte_dat_dat, [LDate x; LDate y] -> LBool (o_lte_dat_dat x y)
| Lte_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_lte_dur_dur x y)
| Gt_int_int, [LInt x; LInt y] -> LBool (o_gt_int_int x y)
| Gt_rat_rat, [LRat x; LRat y] -> LBool (o_gt_rat_rat x y)
| Gt_mon_mon, [LMoney x; LMoney y] -> LBool (o_gt_mon_mon x y)
| Gt_dat_dat, [LDate x; LDate y] -> LBool (o_gt_dat_dat x y)
| Gt_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_gt_dur_dur x y)
| Gte_int_int, [LInt x; LInt y] -> LBool (o_gte_int_int x y)
| Gte_rat_rat, [LRat x; LRat y] -> LBool (o_gte_rat_rat x y)
| Gte_mon_mon, [LMoney x; LMoney y] -> LBool (o_gte_mon_mon x y)
| Gte_dat_dat, [LDate x; LDate y] -> LBool (o_gte_dat_dat x y)
| Gte_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_gte_dur_dur x y)
| Eq_int_int, [LInt x; LInt y] -> LBool (o_eq_int_int x y)
| Eq_rat_rat, [LRat x; LRat y] -> LBool (o_eq_rat_rat x y)
| Eq_mon_mon, [LMoney x; LMoney y] -> LBool (o_eq_mon_mon x y)
| Eq_dat_dat, [LDate x; LDate y] -> LBool (o_eq_dat_dat x y)
| Eq_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_eq_dur_dur x y)
| ( ( Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int
| ToRat_mon | ToMoney_rat | Round_rat | Round_mon | Add_int_int
| Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat
| Sub_dat_dur | Sub_dur_dur | Mult_int_int | Mult_rat_rat
| Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat
| Lte_mon_mon | Lte_dat_dat | Lte_dur_dur | Gt_int_int
| Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur | Gte_int_int
| Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur
),
_ ) ->
err ()
in
ELit rlit)
~overloaded:(fun _ -> assert false)
and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
match Marked.unmark e with
@ -314,11 +315,11 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
Errors.raise_spanned_error (Expr.pos e)
"free variable found at evaluation (should not happen if term was \
well-typed"
| EApp (e1, args) -> (
| EApp { f = e1; args } -> (
let e1 = evaluate_expr ctx e1 in
let args = List.map (evaluate_expr ctx) args in
match Marked.unmark e1 with
| EAbs (binder, _) ->
| EAbs { binder; _ } ->
if Bindlib.mbinder_arity binder = List.length args then
evaluate_expr ctx
(Bindlib.msubst binder (Array.of_list (List.map Marked.unmark args)))
@ -327,7 +328,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
"wrong function call, expected %d arguments, got %d"
(Bindlib.mbinder_arity binder)
(List.length args)
| EOp op ->
| EOp { op; _ } ->
Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ ->
@ -335,69 +336,66 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
"function has not been reduced to a lambda at evaluation (should not \
happen if the term was well-typed")
| EAbs _ | ELit _ | EOp _ -> e (* these are values *)
| ETuple (es, s) ->
let new_es = List.map (evaluate_expr ctx) es in
if List.exists is_empty_error new_es then
| EStruct { fields = es; name } ->
let new_es = StructField.Map.map (evaluate_expr ctx) es in
if StructField.Map.exists (fun _ e -> is_empty_error e) new_es then
Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (ETuple (new_es, s)) e
| ETupleAccess (e1, n, s, _) -> (
else Marked.same_mark_as (EStruct { fields = new_es; name }) e
| EStructAccess { e = e1; name = s; field } -> (
let e1 = evaluate_expr ctx e1 in
match Marked.unmark e1 with
| ETuple (es, s') -> (
(match s, s' with
| None, None -> ()
| Some s, Some s' when s = s' -> ()
| _ ->
| EStruct { fields = es; name = s' } -> (
if not (StructName.equal s s') then
Errors.raise_multispanned_error
[None, Expr.pos e; None, Expr.pos e1]
"Error during tuple access: not the same structs (should not happen \
if the term was well-typed)");
match List.nth_opt es n with
"Error during struct access: not the same structs (should not happen \
if the term was well-typed)";
match StructField.Map.find_opt field es with
| Some e' -> e'
| None ->
Errors.raise_spanned_error (Expr.pos e1)
"The tuple has %d components but the %i-th element was requested \
(should not happen if the term was well-type)"
(List.length es) n)
"Invalid field access %a in struct %a (should not happen if the term \
was well-typed)"
StructField.format_t field StructName.format_t s)
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ ->
Errors.raise_spanned_error (Expr.pos e1)
"The expression %a should be a tuple with %d components but is not \
(should not happen if the term was well-typed)"
"The expression %a should be a struct %a but is not (should not happen \
if the term was well-typed)"
(Expr.format ctx ~debug:true)
e n)
| EInj (e1, n, en, ts) ->
e StructName.format_t s)
| EInj { e = e1; name; cons } ->
let e1' = evaluate_expr ctx e1 in
if is_empty_error e1' then Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (EInj (e1', n, en, ts)) e
| EMatch (e1, es, e_name) -> (
if is_empty_error e then Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (EInj { e = e1'; name; cons }) e
| EMatch { e = e1; cases = es; name } -> (
let e1 = evaluate_expr ctx e1 in
match Marked.unmark e1 with
| EInj (e1, n, e_name', _) ->
if e_name <> e_name' then
| EInj { e = e1; cons; name = name' } ->
if not (EnumName.equal name name') then
Errors.raise_multispanned_error
[None, Expr.pos e; None, Expr.pos e1]
"Error during match: two different enums found (should not happen if \
the term was well-typed)";
let es_n =
match List.nth_opt es n with
match EnumConstructor.Map.find_opt cons es with
| Some es_n -> es_n
| None ->
Errors.raise_spanned_error (Expr.pos e)
"sum type index error (should not happen if the term was \
well-typed)"
in
let new_e = Marked.same_mark_as (EApp (es_n, [e1])) e in
let new_e = Marked.same_mark_as (EApp { f = es_n; args = [e1] }) e in
evaluate_expr ctx new_e
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ ->
Errors.raise_spanned_error (Expr.pos e1)
"Expected a term having a sum type as an argument to a match (should \
not happen if the term was well-typed")
| EDefault (exceptions, just, cons) -> (
let exceptions = List.map (evaluate_expr ctx) exceptions in
let empty_count = List.length (List.filter is_empty_error exceptions) in
match List.length exceptions - empty_count with
| EDefault { excepts; just; cons } -> (
let excepts = List.map (evaluate_expr ctx) excepts in
let empty_count = List.length (List.filter is_empty_error excepts) in
match List.length excepts - empty_count with
| 0 -> (
let just = evaluate_expr ctx just in
match Marked.unmark just with
@ -408,19 +406,19 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
Errors.raise_spanned_error (Expr.pos e)
"Default justification has not been reduced to a boolean at \
evaluation (should not happen if the term was well-typed")
| 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions
| 1 -> List.find (fun sub -> not (is_empty_error sub)) excepts
| _ ->
Errors.raise_multispanned_error
(List.map
(fun except ->
Some "This consequence has a valid justification:", Expr.pos except)
(List.filter (fun sub -> not (is_empty_error sub)) exceptions))
(List.filter (fun sub -> not (is_empty_error sub)) excepts))
"There is a conflict between multiple valid consequences for assigning \
the same variable.")
| EIfThenElse (cond, et, ef) -> (
| EIfThenElse { cond; etrue; efalse } -> (
match Marked.unmark (evaluate_expr ctx cond) with
| ELit (LBool true) -> evaluate_expr ctx et
| ELit (LBool false) -> evaluate_expr ctx ef
| ELit (LBool true) -> evaluate_expr ctx etrue
| ELit (LBool false) -> evaluate_expr ctx efalse
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ ->
Errors.raise_spanned_error (Expr.pos cond)
@ -431,7 +429,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
if List.exists is_empty_error new_es then
Marked.same_mark_as (ELit LEmptyError) e
else Marked.same_mark_as (EArray new_es) e
| ErrorOnEmpty e' ->
| EErrorOnEmpty e' ->
let e' = evaluate_expr ctx e' in
if Marked.unmark e' = ELit LEmptyError then
Errors.raise_spanned_error (Expr.pos e')
@ -443,23 +441,44 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
| ELit (LBool true) -> Marked.same_mark_as (ELit LUnit) e'
| ELit (LBool false) -> (
match Marked.unmark e' with
| ErrorOnEmpty
| EErrorOnEmpty
( EApp
((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)]),
_ )
| EApp
( (EOp (Unop (Log _)), _),
[
( EApp
( (EOp (Binop op), _),
[((ELit _, _) as e1); ((ELit _, _) as e2)] ),
_ );
] )
| EApp ((EOp (Binop op), _), [((ELit _, _) as e1); ((ELit _, _) as e2)])
->
{
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
},
_ ) ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.binop op
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| EApp
{
f = EOp { op = Log _; _ }, _;
args =
[
( EApp
{
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
},
_ );
];
} ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| EApp
{
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
} ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| _ ->
@ -479,19 +498,22 @@ let interpret_program :
fun (ctx : decl_ctx) (e : 'm Ast.expr) :
(Uid.MarkedString.info * 'm Ast.expr) list ->
match evaluate_expr ctx e with
| (EAbs (_, [((TStruct s_in, _) as _targs)]), mark_e) as e -> begin
| (EAbs { tys = [((TStruct s_in, _) as _targs)]; _ }, mark_e) as e -> begin
(* At this point, the interpreter seeks to execute the scope but does not
have a way to retrieve input values from the command line. [taus] contain
the types of the scope arguments. For [context] arguments, we can provide
an empty thunked term. But for [input] arguments of another type, we
cannot provide anything so we have to fail. *)
let taus = StructMap.find s_in ctx.ctx_structs in
let taus = StructName.Map.find s_in ctx.ctx_structs in
let application_term =
List.map
(fun (_, ty) ->
StructField.Map.map
(fun ty ->
match Marked.unmark ty with
| TArrow ((TLit TUnit, _), ty_in) ->
Expr.empty_thunked_term (Expr.with_ty mark_e ty_in)
| TArrow (ty_in, ty_out) ->
Expr.make_abs
[| Var.make "_" |]
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
[ty_in] (Expr.mark_pos mark_e)
| _ ->
Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \
@ -503,17 +525,14 @@ let interpret_program :
in
let to_interpret =
Expr.make_app (Expr.box e)
[Expr.make_tuple application_term (Some s_in) mark_e]
[Expr.estruct s_in application_term mark_e]
(Expr.pos e)
in
match Marked.unmark (evaluate_expr ctx (Expr.unbox to_interpret)) with
| ETuple (args, Some s_out) ->
let s_out_fields =
List.map
(fun (f, _) -> StructFieldName.get_info f)
(StructMap.find s_out ctx.ctx_structs)
in
List.map2 (fun arg var -> var, arg) args s_out_fields
| EStruct { fields; _ } ->
List.map
(fun (fld, e) -> StructField.get_info fld, e)
(StructField.Map.bindings fields)
| _ ->
Errors.raise_spanned_error (Expr.pos e)
"The interpretation of a program should always yield a struct \

View File

@ -16,7 +16,7 @@
(** Reference interpreter for the default calculus *)
open Utils
open Catala_utils
open Shared_ast
val evaluate_expr : decl_ctx -> 'm Ast.expr -> 'm Ast.expr

View File

@ -14,7 +14,7 @@
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
open Ast
@ -24,179 +24,206 @@ type partial_evaluation_ctx = {
}
let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
'm expr Bindlib.box =
(dcalc, 'm mark) boxed_gexpr =
(* We proceed bottom-up, first apply on the subterms *)
let e = Expr.map ~f:(partial_evaluation ctx) e in
let mark = Marked.get_mark e in
let rec_helper = partial_evaluation ctx in
match Marked.unmark e with
| EApp
( (( EOp (Unop Not), _
| EApp ((EOp (Unop (Log _)), _), [(EOp (Unop Not), _)]), _ ) as op),
[e1] ) ->
(* reduction of logical not *)
(Bindlib.box_apply (fun e1 ->
match e1 with
| ELit (LBool false), _ -> ELit (LBool true), mark
| ELit (LBool true), _ -> ELit (LBool false), mark
| _ -> EApp (op, [e1]), mark))
(rec_helper e1)
| EApp
( (( EOp (Binop Or), _
| EApp ((EOp (Unop (Log _)), _), [(EOp (Binop Or), _)]), _ ) as op),
[e1; e2] ) ->
(* reduction of logical or *)
(Bindlib.box_apply2 (fun e1 e2 ->
match e1, e2 with
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) ->
new_e
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) ->
ELit (LBool true), mark
| _ -> EApp (op, [e1; e2]), mark))
(rec_helper e1) (rec_helper e2)
| EApp
( (( EOp (Binop And), _
| EApp ((EOp (Unop (Log _)), _), [(EOp (Binop And), _)]), _ ) as op),
[e1; e2] ) ->
(* reduction of logical and *)
(Bindlib.box_apply2 (fun e1 e2 ->
match e1, e2 with
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) ->
new_e
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
ELit (LBool false), mark
| _ -> EApp (op, [e1; e2]), mark))
(rec_helper e1) (rec_helper e2)
| EVar x -> Bindlib.box_apply (fun x -> x, mark) (Bindlib.box_var x)
| ETuple (args, s_name) ->
Bindlib.box_apply
(fun args -> ETuple (args, s_name), mark)
(List.map rec_helper args |> Bindlib.box_list)
| ETupleAccess (arg, i, s_name, typs) ->
Bindlib.box_apply
(fun arg -> ETupleAccess (arg, i, s_name, typs), mark)
(rec_helper arg)
| EInj (arg, i, e_name, typs) ->
Bindlib.box_apply
(fun arg -> EInj (arg, i, e_name, typs), mark)
(rec_helper arg)
| EMatch (arg, arms, e_name) ->
Bindlib.box_apply2
(fun arg arms ->
match arg, arms with
| (EInj (e1, i, e_name', _ts), _), _
when EnumName.compare e_name e_name' = 0 ->
(* iota reduction *)
EApp (List.nth arms i, [e1]), mark
| _ -> EMatch (arg, arms, e_name), mark)
(rec_helper arg)
(List.map rec_helper arms |> Bindlib.box_list)
| EArray args ->
Bindlib.box_apply
(fun args -> EArray args, mark)
(List.map rec_helper args |> Bindlib.box_list)
| ELit l -> Bindlib.box (ELit l, mark)
| EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in
let new_body = rec_helper body in
let new_binder = Bindlib.bind_mvar vars new_body in
Bindlib.box_apply (fun binder -> EAbs (binder, typs), mark) new_binder
| EApp (f, args) ->
Bindlib.box_apply2
(fun f args ->
match Marked.unmark f with
| EAbs (binder, _ts) ->
(* beta reduction *)
Bindlib.msubst binder (List.map fst args |> Array.of_list)
| _ -> EApp (f, args), mark)
(rec_helper f)
(List.map rec_helper args |> Bindlib.box_list)
| EAssert e1 -> Bindlib.box_apply (fun e1 -> EAssert e1, mark) (rec_helper e1)
| EOp op -> Bindlib.box (EOp op, mark)
| EDefault (exceptions, just, cons) ->
Bindlib.box_apply3
(fun exceptions just cons ->
(* TODO: mechanically prove each of these optimizations correct :) *)
match
( List.filter
(fun except ->
match Marked.unmark except with
| ELit LEmptyError -> false
| _ -> true)
exceptions
(* we can discard the exceptions that are always empty error *),
just,
cons )
with
| exceptions, just, cons
when List.fold_left
(fun nb except -> if Expr.is_value except then nb + 1 else nb)
0 exceptions
> 1 ->
(* at this point we know a conflict error will be triggered so we just
feed the expression to the interpreter that will print the
beautiful right error message *)
Interpreter.evaluate_expr ctx.decl_ctx
(EDefault (exceptions, just, cons), mark)
| [except], _, _ when Expr.is_value except ->
(* Then reduce the parent node *)
let reduce e =
(* Todo: improve the handling of eapp(log,elit) cases here, it obfuscates
the matches and the log calls are not preserved, which would be a good
property *)
match Marked.unmark e with
| EApp
{
f =
( EOp { op = Not; _ }, _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(EOp { op = Not; _ }, _)];
},
_ ) ) as op;
args = [e1];
} -> (
(* reduction of logical not *)
match e1 with
| ELit (LBool false), _ -> ELit (LBool true)
| ELit (LBool true), _ -> ELit (LBool false)
| e1 -> EApp { f = op; args = [e1] })
| EApp
{
f =
( EOp { op = Or; _ }, _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(EOp { op = Or; _ }, _)];
},
_ ) ) as op;
args = [e1; e2];
} -> (
(* reduction of logical or *)
match e1, e2 with
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) ->
Marked.unmark new_e
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) ->
ELit (LBool true)
| _ -> EApp { f = op; args = [e1; e2] })
| EApp
{
f =
( EOp { op = And; _ }, _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(EOp { op = And; _ }, _)];
},
_ ) ) as op;
args = [e1; e2];
} -> (
(* reduction of logical and *)
match e1, e2 with
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) ->
Marked.unmark new_e
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
ELit (LBool false)
| _ -> EApp { f = op; args = [e1; e2] })
| EMatch { e = EInj { e; name = name1; cons }, _; cases; name }
when EnumName.equal name name1 ->
(* iota reduction *)
EApp { f = EnumConstructor.Map.find cons cases; args = [e] }
| EApp { f = EAbs { binder; _ }, _; args } ->
(* beta reduction *)
Marked.unmark (Bindlib.msubst binder (List.map fst args |> Array.of_list))
| EDefault { excepts; just; cons } -> (
(* TODO: mechanically prove each of these optimizations correct :) *)
let excepts =
List.filter
(fun except -> Marked.unmark except <> ELit LEmptyError)
excepts
(* we can discard the exceptions that are always empty error *)
in
let value_except_count =
List.fold_left
(fun nb except -> if Expr.is_value except then nb + 1 else nb)
0 excepts
in
if value_except_count > 1 then
(* at this point we know a conflict error will be triggered so we just
feed the expression to the interpreter that will print the beautiful
right error message *)
Marked.unmark (Interpreter.evaluate_expr ctx.decl_ctx e)
else
match excepts, just with
| [except], _ when Expr.is_value except ->
(* if there is only one exception and it is a non-empty value it is
always chosen *)
except
Marked.unmark except
| ( [],
( ( ELit (LBool true)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ),
_ ),
cons ) ->
cons
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool true), _)];
} ),
_ ) ) ->
Marked.unmark cons
| ( [],
( ( ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ),
_ ),
_ ) ->
ELit LEmptyError, mark
| [], just, cons when not !Cli.avoid_exceptions_flag ->
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool false), _)];
} ),
_ ) ) ->
ELit LEmptyError
| [], just when not !Cli.avoid_exceptions_flag ->
(* without exceptions, a default is just an [if then else] raising an
error in the else case. This exception is only valid in the context
of compilation_with_exceptions, so we desactivate with a global
flag to know if we will be compiling using exceptions or the option
monad. *)
EIfThenElse (just, cons, (ELit LEmptyError, mark)), mark
| exceptions, just, cons -> EDefault (exceptions, just, cons), mark)
(List.map rec_helper exceptions |> Bindlib.box_list)
(rec_helper just) (rec_helper cons)
| EIfThenElse (e1, e2, e3) ->
Bindlib.box_apply3
(fun e1 e2 e3 ->
match Marked.unmark e1, Marked.unmark e2, Marked.unmark e3 with
| ELit (LBool true), _, _
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]), _, _ ->
e2
| ELit (LBool false), _, _
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]), _, _ ->
e3
| ( _,
( ELit (LBool true)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ),
( ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ) ) ->
e1
| _ when Expr.equal e2 e3 -> e2
| _ -> EIfThenElse (e1, e2, e3), mark)
(rec_helper e1) (rec_helper e2) (rec_helper e3)
| ErrorOnEmpty e1 ->
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) (rec_helper e1)
monad. FIXME: move this optimisation somewhere else to avoid this
check *)
EIfThenElse
{ cond = just; etrue = cons; efalse = ELit LEmptyError, mark }
| excepts, just -> EDefault { excepts; just; cons })
| EIfThenElse
{
cond =
( ELit (LBool true), _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool true), _)];
},
_ ) );
etrue;
_;
} ->
Marked.unmark etrue
| EIfThenElse
{
cond =
( ( ELit (LBool false)
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool false), _)];
} ),
_ );
efalse;
_;
} ->
Marked.unmark efalse
| EIfThenElse
{
cond;
etrue =
( ( ELit (LBool btrue)
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool btrue), _)];
} ),
_ );
efalse =
( ( ELit (LBool bfalse)
| EApp
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool bfalse), _)];
} ),
_ );
} ->
if btrue && not bfalse then Marked.unmark cond
else if (not btrue) && bfalse then
EApp
{
f = EOp { op = Not; tys = [TLit TBool, Expr.mark_pos mark] }, mark;
args = [cond];
}
(* note: this last call eliminates the condition & might skip log calls
as well *)
else (* btrue = bfalse *) ELit (LBool btrue)
| e -> e
in
Expr.Box.app1 e reduce mark
let optimize_expr (decl_ctx : decl_ctx) (e : 'm expr) =
partial_evaluation { var_values = Var.Map.empty; decl_ctx } e
let rec scope_lets_map
(t : 'a -> 'm expr -> 'm expr Bindlib.box)
(t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
(ctx : 'a)
(scope_body_expr : 'm expr scope_body_expr) :
'm expr scope_body_expr Bindlib.box =
match scope_body_expr with
| Result e -> Bindlib.box_apply (fun e' -> Result e') (t ctx e)
| Result e ->
Bindlib.box_apply (fun e' -> Result e') (Expr.Box.lift (t ctx e))
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
let new_scope_let_expr = t ctx scope_let.scope_let_expr in
let new_scope_let_expr = Expr.Box.lift (t ctx scope_let.scope_let_expr) in
let new_next = scope_lets_map t ctx next in
let new_next = Bindlib.bind_var var new_next in
Bindlib.box_apply2
@ -210,7 +237,7 @@ let rec scope_lets_map
new_scope_let_expr new_next
let rec scopes_map
(t : 'a -> 'm expr -> 'm expr Bindlib.box)
(t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
(ctx : 'a)
(scopes : 'm expr scopes) : 'm expr scopes Bindlib.box =
match scopes with
@ -241,7 +268,7 @@ let rec scopes_map
new_scope_body_expr new_scope_next
let program_map
(t : 'a -> 'm expr -> 'm expr Bindlib.box)
(t : 'a -> 'm expr -> (dcalc, 'm mark) boxed_gexpr)
(ctx : 'a)
(p : 'm program) : 'm program Bindlib.box =
Bindlib.box_apply

View File

@ -20,5 +20,5 @@
open Shared_ast
open Ast
val optimize_expr : decl_ctx -> 'm expr -> 'm expr Bindlib.box
val optimize_expr : decl_ctx -> 'm expr -> (dcalc, 'm mark) boxed_gexpr
val optimize_program : 'm program -> 'm program

View File

@ -16,25 +16,11 @@
(** Abstract syntax tree of the desugared representation *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Names, Maps and Keys} *)
module IdentMap : Map.S with type key = String.t = Map.Make (String)
module RuleName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName)
module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName)
module LabelName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName)
module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName)
(** Inside a scope, a definition can refer either to a scope def, or a subscope
def *)
module ScopeDef = struct
@ -103,6 +89,9 @@ module ExprMap = Map.Make (struct
let compare = Expr.compare
end)
type io_input = NoInput | OnlyInput | Reentrant
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
type exception_situation =
| BaseCase
| ExceptionToLabel of LabelName.t Marked.pos
@ -136,7 +125,7 @@ module Rule = struct
Expr.compare c1 c2
| n -> n)
| Some (v1, t1), Some (v2, t2) -> (
match Shared_ast.Expr.compare_typ t1 t2 with
match Type.compare t1 t2 with
| 0 -> (
let open Bindlib in
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in
@ -189,29 +178,32 @@ type meta_assertion =
| VariesWith of unit * variation_typ Marked.pos option
type scope_def = {
scope_def_rules : rule RuleMap.t;
scope_def_rules : rule RuleName.Map.t;
scope_def_typ : typ;
scope_def_is_condition : bool;
scope_def_io : Scopelang.Ast.io;
scope_def_io : io;
}
type var_or_states = WholeVar | States of StateName.t list
type scope = {
scope_vars : var_or_states ScopeVarMap.t;
scope_sub_scopes : ScopeName.t SubScopeMap.t;
scope_vars : var_or_states ScopeVar.Map.t;
scope_sub_scopes : ScopeName.t SubScopeName.Map.t;
scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDefMap.t;
scope_assertions : assertion list;
scope_meta_assertions : meta_assertion list;
}
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx }
type program = {
program_scopes : scope ScopeName.Map.t;
program_ctx : decl_ctx;
}
let rec locations_used e : LocationSet.t =
match e with
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
| EAbs (binder, _), _ ->
| EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in
locations_used body
| e ->
@ -219,7 +211,7 @@ let rec locations_used e : LocationSet.t =
(fun e -> LocationSet.union (locations_used e))
e LocationSet.empty
let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t =
let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) :
Pos.t ScopeDefMap.t =
LocationSet.fold
@ -235,7 +227,7 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
loc_pos acc)
locs acc
in
RuleMap.fold
RuleName.Map.fold
(fun _ rule acc ->
let locs =
LocationSet.union

View File

@ -16,19 +16,9 @@
(** Abstract syntax tree of the desugared representation *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Names, Maps and Keys} *)
module IdentMap : Map.S with type key = String.t
module RuleName : Uid.Id with type info = Uid.MarkedString.info
module RuleMap : Map.S with type key = RuleName.t
module RuleSet : Set.S with type elt = RuleName.t
module LabelName : Uid.Id with type info = Uid.MarkedString.info
module LabelMap : Map.S with type key = LabelName.t
module LabelSet : Set.S with type elt = LabelName.t
(** Inside a scope, a definition can refer either to a scope def, or a subscope
def *)
module ScopeDef : sig
@ -88,27 +78,51 @@ type meta_assertion =
| FixedBy of reference_typ Marked.pos
| VariesWith of unit * variation_typ Marked.pos option
(** This type characterizes the three levels of visibility for a given scope
variable with regards to the scope's input and possible redefinitions inside
the scope.. *)
type io_input =
| NoInput
(** For an internal variable defined only in the scope, and does not
appear in the input. *)
| OnlyInput
(** For variables that should not be redefined in the scope, because they
appear in the input. *)
| Reentrant
(** For variables defined in the scope that can also be redefined by the
caller as they appear in the input. *)
type io = {
io_output : bool Marked.pos;
(** [true] is present in the output of the scope. *)
io_input : io_input Marked.pos;
}
(** Characterization of the input/output status of a scope variable. *)
type scope_def = {
scope_def_rules : rule RuleMap.t;
scope_def_rules : rule RuleName.Map.t;
scope_def_typ : typ;
scope_def_is_condition : bool;
scope_def_io : Scopelang.Ast.io;
scope_def_io : io;
}
type var_or_states = WholeVar | States of StateName.t list
type scope = {
scope_vars : var_or_states ScopeVarMap.t;
scope_sub_scopes : ScopeName.t SubScopeMap.t;
scope_vars : var_or_states ScopeVar.Map.t;
scope_sub_scopes : ScopeName.t SubScopeName.Map.t;
scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDefMap.t;
scope_assertions : assertion list;
scope_meta_assertions : meta_assertion list;
}
type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx }
type program = {
program_scopes : scope ScopeName.Map.t;
program_ctx : decl_ctx;
}
(** {1 Helpers} *)
val locations_used : expr -> LocationSet.t
val free_variables : rule RuleMap.t -> Pos.t ScopeDefMap.t
val free_variables : rule RuleName.Map.t -> Pos.t ScopeDefMap.t

View File

@ -17,7 +17,7 @@
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
OCamlgraph} *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Scope variables dependency graph} *)
@ -143,7 +143,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
let g = ScopeDependencies.empty in
(* Add all the vertices to the graph *)
let g =
ScopeVarMap.fold
ScopeVar.Map.fold
(fun (v : ScopeVar.t) var_or_state g ->
match var_or_state with
| Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None))
@ -155,7 +155,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
scope.scope_vars g
in
let g =
SubScopeMap.fold
SubScopeName.Map.fold
(fun (v : SubScopeName.t) _ g ->
ScopeDependencies.add_vertex g (Vertex.SubScope v))
scope.scope_sub_scopes g
@ -229,10 +229,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
(** {2 Graph declaration} *)
module ExceptionVertex = struct
include Ast.RuleSet
include RuleName.Set
let hash (x : t) : int =
Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0
RuleName.Set.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0
let equal x y = compare x y = 0
end
@ -257,13 +257,13 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies)
(** {2 Graph computations} *)
type exception_edge = {
label_from : Ast.LabelName.t;
label_to : Ast.LabelName.t;
label_from : LabelName.t;
label_to : LabelName.t;
edge_positions : Pos.t list;
}
let build_exceptions_graph
(def : Ast.rule Ast.RuleMap.t)
(def : Ast.rule RuleName.Map.t)
(def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t =
(* First we partition the definitions into groups bearing the same label. To
handle the rules that were not labeled by the user, we create implicit
@ -271,63 +271,59 @@ let build_exceptions_graph
(* All the rules of the form [definition x ...] are base case with no explicit
label, so they should share this implicit label. *)
let base_case_implicit_label =
Ast.LabelName.fresh ("base_case", Pos.no_pos)
in
let base_case_implicit_label = LabelName.fresh ("base_case", Pos.no_pos) in
(* When declaring [exception definition x ...], it means there is a unique
rule [R] to which this can be an exception to. So we give a unique label to
all the rules that are implicitly exceptions to rule [R]. *)
let exception_to_rule_implicit_labels : Ast.LabelName.t Ast.RuleMap.t =
Ast.RuleMap.fold
let exception_to_rule_implicit_labels : LabelName.t RuleName.Map.t =
RuleName.Map.fold
(fun _ rule_from exception_to_rule_implicit_labels ->
match rule_from.Ast.rule_exception with
| Ast.ExceptionToRule (rule_to, _) -> (
match
Ast.RuleMap.find_opt rule_to exception_to_rule_implicit_labels
RuleName.Map.find_opt rule_to exception_to_rule_implicit_labels
with
| Some _ ->
(* we already created the label *) exception_to_rule_implicit_labels
| None ->
Ast.RuleMap.add rule_to
(Ast.LabelName.fresh
( "exception_to_"
^ Marked.unmark (Ast.RuleName.get_info rule_to),
RuleName.Map.add rule_to
(LabelName.fresh
( "exception_to_" ^ Marked.unmark (RuleName.get_info rule_to),
Pos.no_pos ))
exception_to_rule_implicit_labels)
| _ -> exception_to_rule_implicit_labels)
def Ast.RuleMap.empty
def RuleName.Map.empty
in
(* When declaring [exception foo_l definition x ...], the rule is exception to
all the rules sharing label [foo_l]. So we give a unique label to all the
rules that are implicitly exceptions to rule [foo_l]. *)
let exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t =
Ast.RuleMap.fold
let exception_to_label_implicit_labels : LabelName.t LabelName.Map.t =
RuleName.Map.fold
(fun _ rule_from
(exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t) ->
(exception_to_label_implicit_labels : LabelName.t LabelName.Map.t) ->
match rule_from.Ast.rule_exception with
| Ast.ExceptionToLabel (label_to, _) -> (
match
Ast.LabelMap.find_opt label_to exception_to_label_implicit_labels
LabelName.Map.find_opt label_to exception_to_label_implicit_labels
with
| Some _ ->
(* we already created the label *)
exception_to_label_implicit_labels
| None ->
Ast.LabelMap.add label_to
(Ast.LabelName.fresh
( "exception_to_"
^ Marked.unmark (Ast.LabelName.get_info label_to),
LabelName.Map.add label_to
(LabelName.fresh
( "exception_to_" ^ Marked.unmark (LabelName.get_info label_to),
Pos.no_pos ))
exception_to_label_implicit_labels)
| _ -> exception_to_label_implicit_labels)
def Ast.LabelMap.empty
def LabelName.Map.empty
in
(* Now we have all the labels necessary to partition our rules into sets, each
one corresponding to a label relating to the structure of the exception
DAG. *)
let label_to_rule_sets =
Ast.RuleMap.fold
RuleName.Map.fold
(fun rule_name rule rule_sets ->
let label_of_rule =
match rule.Ast.rule_label with
@ -336,23 +332,23 @@ let build_exceptions_graph
match rule.Ast.rule_exception with
| BaseCase -> base_case_implicit_label
| ExceptionToRule (r, _) ->
Ast.RuleMap.find r exception_to_rule_implicit_labels
RuleName.Map.find r exception_to_rule_implicit_labels
| ExceptionToLabel (l', _) ->
Ast.LabelMap.find l' exception_to_label_implicit_labels)
LabelName.Map.find l' exception_to_label_implicit_labels)
in
Ast.LabelMap.update label_of_rule
LabelName.Map.update label_of_rule
(fun rule_set ->
match rule_set with
| None -> Some (Ast.RuleSet.singleton rule_name)
| Some rule_set -> Some (Ast.RuleSet.add rule_name rule_set))
| None -> Some (RuleName.Set.singleton rule_name)
| Some rule_set -> Some (RuleName.Set.add rule_name rule_set))
rule_sets)
def Ast.LabelMap.empty
def LabelName.Map.empty
in
let find_label_of_rule (r : Ast.RuleName.t) : Ast.LabelName.t =
let find_label_of_rule (r : RuleName.t) : LabelName.t =
fst
(Ast.LabelMap.choose
(Ast.LabelMap.filter
(fun _ rule_set -> Ast.RuleSet.mem r rule_set)
(LabelName.Map.choose
(LabelName.Map.filter
(fun _ rule_set -> RuleName.Set.mem r rule_set)
label_to_rule_sets))
in
(* Next, we collect the exception edges between those groups of rules referred
@ -360,7 +356,7 @@ let build_exceptions_graph
edges as they are declared at each rule but should be the same for all the
rules of the same group. *)
let exception_edges : exception_edge list =
Ast.RuleMap.fold
RuleName.Map.fold
(fun rule_name rule exception_edges ->
let label_from = find_label_of_rule rule_name in
let label_to_and_pos =
@ -374,16 +370,16 @@ let build_exceptions_graph
| Some (label_to, edge_pos) -> (
let other_edges_originating_from_same_label =
List.filter
(fun edge -> Ast.LabelName.compare edge.label_from label_from = 0)
(fun edge -> LabelName.compare edge.label_from label_from = 0)
exception_edges
in
(* We check the consistency*)
if Ast.LabelName.compare label_from label_to = 0 then
if LabelName.compare label_from label_to = 0 then
Errors.raise_spanned_error edge_pos
"Cannot define rule as an exception to itself";
List.iter
(fun edge ->
if Ast.LabelName.compare edge.label_to label_to <> 0 then
if LabelName.compare edge.label_to label_to <> 0 then
Errors.raise_multispanned_error
(( Some
"This declaration contradicts another exception \
@ -401,8 +397,8 @@ let build_exceptions_graph
let existing_edge =
List.find_opt
(fun edge ->
Ast.LabelName.compare edge.label_from label_from = 0
&& Ast.LabelName.compare edge.label_to label_to = 0)
LabelName.compare edge.label_from label_from = 0
&& LabelName.compare edge.label_to label_to = 0)
exception_edges
in
match existing_edge with
@ -420,7 +416,7 @@ let build_exceptions_graph
in
(* We've got the vertices and the edges, let's build the graph! *)
let g =
Ast.LabelMap.fold
LabelName.Map.fold
(fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set)
label_to_rule_sets ExceptionsDependencies.empty
in
@ -429,10 +425,10 @@ let build_exceptions_graph
List.fold_left
(fun g edge ->
let rule_group_from =
Ast.LabelMap.find edge.label_from label_to_rule_sets
LabelName.Map.find edge.label_from label_to_rule_sets
in
let rule_group_to =
Ast.LabelMap.find edge.label_to label_to_rule_sets
LabelName.Map.find edge.label_to label_to_rule_sets
in
let edge =
ExceptionsDependencies.E.create rule_group_from edge.edge_positions
@ -453,11 +449,10 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
let spans =
List.flatten
(List.map
(fun (vs : Ast.RuleSet.t) ->
let v = Ast.RuleSet.choose vs in
(fun (vs : RuleName.Set.t) ->
let v = RuleName.Set.choose vs in
let var_str, var_info =
( Format.asprintf "%a" Ast.RuleName.format_t v,
Ast.RuleName.get_info v )
Format.asprintf "%a" RuleName.format_t v, RuleName.get_info v
in
let succs = ExceptionsDependencies.succ_e g vs in
let _, edge_pos, _ =

View File

@ -17,7 +17,8 @@
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
OCamlgraph} *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Scope variables dependency graph} *)
@ -71,9 +72,9 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t
module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list
module ExceptionsDependencies :
Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = EdgeExceptions.t
Graph.Sig.P with type V.t = RuleName.Set.t and type E.label = EdgeExceptions.t
val build_exceptions_graph :
Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
Ast.rule RuleName.Map.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
val check_for_exception_cycle : ExceptionsDependencies.t -> unit

View File

@ -1,670 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
open Utils
open Shared_ast
(** {1 Expression translation}*)
type target_scope_vars =
| WholeVar of ScopeVar.t
| States of (StateName.t * ScopeVar.t) list
type ctx = {
scope_var_mapping : target_scope_vars ScopeVarMap.t;
var_mapping : (Ast.expr, untyped Scopelang.Ast.expr Var.t) Var.Map.t;
}
let tag_with_log_entry
(e : untyped Scopelang.Ast.expr boxed)
(l : log_entry)
(markings : Utils.Uid.MarkedString.info list) :
untyped Scopelang.Ast.expr boxed =
Expr.eapp
(Expr.eop (Unop (Log (l, markings))) (Marked.get_mark e))
[e] (Marked.get_mark e)
let rec translate_expr (ctx : ctx) (e : Ast.expr) :
untyped Scopelang.Ast.expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
(* When referring to a subscope variable in an expression, we are referring
to the output, hence we take the last state. *)
let new_s_var =
match ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States states ->
Marked.same_mark_as (snd (List.hd (List.rev states))) s_var
in
Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m
| ELocation (DesugaredScopeVar (s_var, None)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States _ -> failwith "should not happen"))
m
| ELocation (DesugaredScopeVar (s_var, Some state)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar _ -> failwith "should not happen"
| States states -> Marked.same_mark_as (List.assoc state states) s_var))
m
| EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m
| EStruct (s_name, fields) ->
Expr.estruct s_name (StructFieldMap.map (translate_expr ctx) fields) m
| EStructAccess (e1, f_name, s_name) ->
Expr.estructaccess (translate_expr ctx e1) f_name s_name m
| EEnumInj (e1, cons, e_name) ->
Expr.eenuminj (translate_expr ctx e1) cons e_name m
| EMatchS (e1, e_name, arms) ->
Expr.ematchs (translate_expr ctx e1) e_name
(EnumConstructorMap.map (translate_expr ctx) arms)
m
| EScopeCall (sc_name, fields) ->
Expr.escopecall sc_name
(ScopeVarMap.fold
(fun v e fields' ->
let v' =
match ScopeVarMap.find v ctx.scope_var_mapping with
| WholeVar v' -> v'
| States ((_, v') :: _) ->
(* When there are multiple states, the input is always the first
one *)
v'
| States [] -> assert false
in
ScopeVarMap.add v' (translate_expr ctx e) fields')
fields ScopeVarMap.empty)
m
| ELit
(( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ) as l) ->
Expr.elit l m
| EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind binder in
let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in
let ctx =
List.fold_left2
(fun ctx var new_var ->
{ ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping })
ctx (Array.to_list vars) (Array.to_list new_vars)
in
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) typs m
| EApp (e1, args) ->
Expr.eapp (translate_expr ctx e1) (List.map (translate_expr ctx) args) m
| EOp op -> Expr.eop op m
| EDefault (excepts, just, cons) ->
Expr.edefault
(List.map (translate_expr ctx) excepts)
(translate_expr ctx just) (translate_expr ctx cons) m
| EIfThenElse (e1, e2, e3) ->
Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2)
(translate_expr ctx e3) m
| EArray args -> Expr.earray (List.map (translate_expr ctx) args) m
| ErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m
(** {1 Rule tree construction} *)
(** Intermediate representation for the exception tree of rules for a particular
scope definition. *)
type rule_tree =
| Leaf of Ast.rule list
(** Rules defining a base case piecewise. List is non-empty. *)
| Node of rule_tree list * Ast.rule list
(** [Node (exceptions, base_case)] is a list of exceptions to a non-empty
list of rules defining a base case piecewise. *)
(** Transforms a flat list of rules into a tree, taking into account the
priorities declared between rules *)
let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) :
rule_tree list =
let exc_graph = Dependency.build_exceptions_graph def def_info in
Dependency.check_for_exception_cycle exc_graph;
(* we start by the base cases: they are the vertices which have no
successors *)
let base_cases =
Dependency.ExceptionsDependencies.fold_vertex
(fun v base_cases ->
if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then
v :: base_cases
else base_cases)
exc_graph []
in
let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree =
let exceptions =
Dependency.ExceptionsDependencies.pred exc_graph base_cases
in
let base_case_as_rule_list =
List.map
(fun r -> Ast.RuleMap.find r def)
(Ast.RuleSet.elements base_cases)
in
match exceptions with
| [] -> Leaf base_case_as_rule_list
| _ -> Node (List.map build_tree exceptions, base_case_as_rule_list)
in
List.map build_tree base_cases
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault}
expression in the scope language. The [~toplevel] parameter is used to know
when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr
~(toplevel : bool)
(ctx : ctx)
(def_pos : Pos.t)
(is_func : Ast.expr Var.t option)
(tree : rule_tree) : untyped Scopelang.Ast.expr boxed =
let emark = Untyped { pos = def_pos } in
let exceptions, base_rules =
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
in
(* because each rule has its own variable parameter and we want to convert the
whole rule tree into a function, we need to perform some alpha-renaming of
all the expressions *)
let substitute_parameter (e : Ast.expr boxed) (rule : Ast.rule) :
Ast.expr boxed =
match is_func, rule.Ast.rule_parameter with
| Some new_param, Some (old_param, _) ->
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
Marked.mark (Marked.get_mark e)
@@ Bindlib.box_apply2
(fun binder new_param -> Bindlib.subst binder new_param)
binder
(Bindlib.box_var new_param)
| None, None -> e
| _ -> assert false
(* should not happen *)
in
let ctx =
match is_func with
| None -> ctx
| Some new_param -> (
match Var.Map.find_opt new_param ctx.var_mapping with
| None ->
let new_param_scope = Var.make (Bindlib.name_of new_param) in
{
ctx with
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
}
| Some _ ->
(* We only create a mapping if none exists because [rule_tree_to_expr]
is called recursively on the exceptions of the tree and we don't want
to create a new Scopelang variable for the parameter at each tree
level. *)
ctx)
in
let base_just_list =
List.map
(fun rule -> substitute_parameter rule.Ast.rule_just rule)
base_rules
in
let base_cons_list =
List.map
(fun rule -> substitute_parameter rule.Ast.rule_cons rule)
base_rules
in
let translate_and_unbox_list (list : Ast.expr boxed list) :
untyped Scopelang.Ast.expr boxed list =
List.map
(fun e ->
(* There are two levels of boxing here, the outermost is introduced by
the [translate_expr] function for which all of the bindings should
have been closed by now, so we can safely unbox. *)
translate_expr ctx (Expr.unbox e))
list
in
let default_containing_base_cases =
Expr.make_default
(List.map2
(fun base_just base_cons ->
Expr.make_default []
(* Here we insert the logging command that records when a decision
is taken for the value of a variable. *)
(tag_with_log_entry base_just PosRecordIfTrueBool [])
base_cons emark)
(translate_and_unbox_list base_just_list)
(translate_and_unbox_list base_cons_list))
(Expr.elit (LBool false) emark)
(Expr.elit LEmptyError emark)
emark
in
let exceptions =
List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions
in
let default =
Expr.make_default exceptions
(Expr.elit (LBool true) emark)
default_containing_base_cases emark
in
match is_func, (List.hd base_rules).Ast.rule_parameter with
| None, None -> default
| Some new_param, Some (_, typ) ->
if toplevel then
(* When we're creating a function from multiple defaults, we must check
that the result returned by the function is not empty *)
let default = Expr.eerroronempty default emark in
Expr.make_abs
[| Var.Map.find new_param ctx.var_mapping |]
default [typ] def_pos
else default
| _ -> (* should not happen *) assert false
(** {1 AST translation} *)
(** Translates a definition inside a scope, the resulting expression should be
an {!constructor: Dcalc.EDefault} *)
let translate_def
(ctx : ctx)
(def_info : Ast.ScopeDef.t)
(def : Ast.rule Ast.RuleMap.t)
(typ : typ)
(io : Scopelang.Ast.io)
~(is_cond : bool)
~(is_subscope_var : bool) : untyped Scopelang.Ast.expr boxed =
(* Here, we have to transform this list of rules into a default tree. *)
let is_def_func =
match Marked.unmark typ with TArrow (_, _) -> true | _ -> false
in
let is_rule_func _ (r : Ast.rule) : bool =
Option.is_some r.Ast.rule_parameter
in
let all_rules_func = Ast.RuleMap.for_all is_rule_func def in
let all_rules_not_func =
Ast.RuleMap.for_all (fun n r -> not (is_rule_func n r)) def
in
let is_def_func_param_typ : typ option =
if is_def_func && all_rules_func then
match Marked.unmark typ with
| TArrow (t_param, _) -> Some t_param
| _ ->
Errors.raise_spanned_error (Marked.get_mark typ)
"The definitions of %a are function but it doesn't have a function \
type"
Ast.ScopeDef.format_t def_info
else if (not is_def_func) && all_rules_not_func then None
else
let spans =
List.map
(fun (_, r) ->
Some "This definition is a function:", Expr.pos r.Ast.rule_cons)
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_rule_func def))
@ List.map
(fun (_, r) ->
( Some "This definition is not a function:",
Expr.pos r.Ast.rule_cons ))
(Ast.RuleMap.bindings
(Ast.RuleMap.filter (fun n r -> not (is_rule_func n r)) def))
in
Errors.raise_multispanned_error spans
"some definitions of the same variable are functions while others \
aren't"
in
let top_list = def_map_to_tree def_info def in
let is_input =
match Marked.unmark io.Scopelang.Ast.io_input with
| OnlyInput -> true
| _ -> false
in
let top_value =
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
(* We add the bottom [false] value for conditions, only for the scope
where the condition is declared. Except when the variable is an input,
where we want the [false] to be added at each caller parent scope. *)
Some
(Ast.always_false_rule
(Ast.ScopeDef.get_position def_info)
is_def_func_param_typ)
else None
in
if
Ast.RuleMap.cardinal def = 0
&& is_subscope_var
(* Here we have a special case for the empty definitions. Indeed, we could
use the code for the regular case below that would create a convoluted
default always returning empty error, and this would be correct. But it
gets more complicated with functions. Indeed, if we create an empty
definition for a subscope argument whose type is a function, we get
something like [fun () -> (fun real_param -> < ... >)] that is passed as
an argument to the subscope. The sub-scope de-thunks but the de-thunking
does not return empty error, signalling there is not reentrant variable,
because functions are values! So the subscope does not see that there is
not reentrant variable and does not pick its internal definition instead.
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
test case exercising that subtlety.
To avoid this complication we special case here and put an empty error
for all subscope variables that are not defined. It covers the subtlety
with functions described above but also conditions with the false default
value. *)
&& not (is_cond && is_input)
(* However, this special case suffers from an exception: when a condition is
defined as an OnlyInput to a subscope, since the [false] default value
will not be provided by the calee scope, it has to be placed in the
caller. *)
then
Expr.elit LEmptyError (Untyped { pos = Ast.ScopeDef.get_position def_info })
else
rule_tree_to_expr ~toplevel:true ctx
(Ast.ScopeDef.get_position def_info)
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
(match top_list, top_value with
| [], None ->
(* In this case, there are no rules to define the expression and no
default value so we put an empty rule. *)
Leaf [Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ]
| [], Some top_value ->
(* In this case, there are no rules to define the expression but a
default value so we put it. *)
Leaf [top_value]
| _, Some top_value ->
(* When there are rules + a default value, we put the rules as
exceptions to the default value *)
Node (top_list, [top_value])
| [top_tree], None -> top_tree
| _, None ->
Node
( top_list,
[Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ] ))
(** Translates a scope *)
let translate_scope (ctx : ctx) (scope : Ast.scope) :
untyped Scopelang.Ast.scope_decl =
let scope_dependencies = Dependency.build_scope_dependencies scope in
Dependency.check_for_cycle scope scope_dependencies;
let scope_ordering =
Dependency.correct_computation_ordering scope_dependencies
in
let scope_decl_rules =
List.flatten
(List.map
(fun vertex ->
match vertex with
| Dependency.Vertex.Var (var, state) -> (
let scope_def =
Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, state))
scope.scope_defs
in
let var_def = scope_def.scope_def_rules in
let var_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match Marked.unmark scope_def.Ast.scope_def_io.io_input with
| OnlyInput when not (Ast.RuleMap.is_empty var_def) ->
(* If the variable is tagged as input, then it shall not be
redefined. *)
Errors.raise_multispanned_error
(( Some "Incriminated variable:",
Marked.get_mark (ScopeVar.get_info var) )
:: List.map
(fun (rule, _) ->
( Some "Incriminated variable definition:",
Marked.get_mark (Ast.RuleName.get_info rule) ))
(Ast.RuleMap.bindings var_def))
"It is impossible to give a definition to a scope variable \
tagged as input."
| OnlyInput ->
[]
(* we do not provide any definition for an input-only variable *)
| _ ->
let expr_def =
translate_def ctx
(Ast.ScopeDef.Var (var, state))
var_def var_typ scope_def.Ast.scope_def_io ~is_cond
~is_subscope_var:false
in
let scope_var =
match ScopeVarMap.find var ctx.scope_var_mapping, state with
| WholeVar v, None -> v
| States states, Some state -> List.assoc state states
| _ -> failwith "should not happen"
in
[
Scopelang.Ast.Definition
( ( ScopelangScopeVar
( scope_var,
Marked.get_mark (ScopeVar.get_info scope_var) ),
Marked.get_mark (ScopeVar.get_info scope_var) ),
var_typ,
scope_def.Ast.scope_def_io,
Expr.unbox expr_def );
])
| Dependency.Vertex.SubScope sub_scope_index ->
(* Before calling the sub_scope, we need to include all the
re-definitions of subscope parameters*)
let sub_scope =
SubScopeMap.find sub_scope_index scope.scope_sub_scopes
in
let sub_scope_vars_redefs_candidates =
Ast.ScopeDefMap.filter
(fun def_key scope_def ->
match def_key with
| Ast.ScopeDef.Var _ -> false
| Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) ->
sub_scope_index = sub_scope_index'
(* We exclude subscope variables that have 0 re-definitions
and are not visible in the input of the subscope *)
&& not
((match
Marked.unmark scope_def.Ast.scope_def_io.io_input
with
| Scopelang.Ast.NoInput -> true
| _ -> false)
&& Ast.RuleMap.is_empty scope_def.scope_def_rules))
scope.scope_defs
in
let sub_scope_vars_redefs =
Ast.ScopeDefMap.mapi
(fun def_key scope_def ->
let def = scope_def.Ast.scope_def_rules in
let def_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match def_key with
| Ast.ScopeDef.Var _ -> assert false (* should not happen *)
| Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) ->
(* This definition redefines a variable of the correct
subscope. But we have to check that this redefinition is
allowed with respect to the io parameters of that
subscope variable. *)
(match
Marked.unmark scope_def.Ast.scope_def_io.io_input
with
| Scopelang.Ast.NoInput ->
Errors.raise_multispanned_error
(( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) )
:: ( Some "Incriminated variable:",
Marked.get_mark (ScopeVar.get_info sub_scope_var)
)
:: List.map
(fun (rule, _) ->
( Some
"Incriminated subscope variable definition:",
Marked.get_mark (Ast.RuleName.get_info rule) ))
(Ast.RuleMap.bindings def))
"It is impossible to give a definition to a subscope \
variable not tagged as input or context."
| OnlyInput when Ast.RuleMap.is_empty def && not is_cond ->
(* If the subscope variable is tagged as input, then it
shall be defined. *)
Errors.raise_multispanned_error
[
( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) );
Some "Incriminated variable:", pos;
]
"This subscope variable is a mandatory input but no \
definition was provided."
| _ -> ());
(* Now that all is good, we can proceed with translating
this redefinition to a proper Scopelang term. *)
let expr_def =
translate_def ctx def_key def def_typ
scope_def.Ast.scope_def_io ~is_cond
~is_subscope_var:true
in
let subscop_real_name =
SubScopeMap.find sub_scope_index scope.scope_sub_scopes
in
let var_pos = Ast.ScopeDef.get_position def_key in
Scopelang.Ast.Definition
( ( SubScopeVar
( subscop_real_name,
(sub_scope_index, var_pos),
match
ScopeVarMap.find sub_scope_var
ctx.scope_var_mapping
with
| WholeVar v -> v, var_pos
| States states ->
(* When defining a sub-scope variable, we
always define its first state in the
sub-scope. *)
snd (List.hd states), var_pos ),
var_pos ),
def_typ,
scope_def.Ast.scope_def_io,
Expr.unbox expr_def ))
sub_scope_vars_redefs_candidates
in
let sub_scope_vars_redefs =
List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
in
sub_scope_vars_redefs
@ [
Scopelang.Ast.Call
( sub_scope,
sub_scope_index,
Untyped
{
pos =
Marked.get_mark
(SubScopeName.get_info sub_scope_index);
} );
])
scope_ordering)
in
(* Then, after having computed all the scopes variables, we add the
assertions. TODO: the assertions should be interleaved with the
definitions! *)
let scope_decl_rules =
scope_decl_rules
@ List.map
(fun e ->
let scope_e = translate_expr ctx (Expr.unbox e) in
Scopelang.Ast.Assertion (Expr.unbox scope_e))
scope.Ast.scope_assertions
in
let scope_sig =
ScopeVarMap.fold
(fun var (states : Ast.var_or_states) acc ->
match states with
| WholeVar ->
let scope_def =
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs
in
let typ = scope_def.scope_def_typ in
ScopeVarMap.add
(match ScopeVarMap.find var ctx.scope_var_mapping with
| WholeVar v -> v
| States _ -> failwith "should not happen")
(typ, scope_def.scope_def_io)
acc
| States states ->
(* What happens in the case of variables with multiple states is
interesting. We need to create as many Scopelang.Var entries in the
scope signature as there are states. *)
List.fold_left
(fun acc (state : StateName.t) ->
let scope_def =
Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, Some state))
scope.scope_defs
in
ScopeVarMap.add
(match ScopeVarMap.find var ctx.scope_var_mapping with
| WholeVar _ -> failwith "should not happen"
| States states' -> List.assoc state states')
(scope_def.scope_def_typ, scope_def.scope_def_io)
acc)
acc states)
scope.scope_vars ScopeVarMap.empty
in
let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in
{
Scopelang.Ast.scope_decl_name = scope.scope_uid;
Scopelang.Ast.scope_decl_rules;
Scopelang.Ast.scope_sig;
Scopelang.Ast.scope_mark = Untyped { pos };
}
(** {1 API} *)
let translate_program (pgrm : Ast.program) : untyped Scopelang.Ast.program =
(* First we give mappings to all the locations between Desugared and
Scopelang. This involves creating a new Scopelang scope variable for every
state of a Desugared variable. *)
let ctx =
ScopeMap.fold
(fun _scope scope_decl ctx ->
ScopeVarMap.fold
(fun scope_var (states : Ast.var_or_states) ctx ->
match states with
| Ast.WholeVar ->
{
ctx with
scope_var_mapping =
ScopeVarMap.add scope_var
(WholeVar (ScopeVar.fresh (ScopeVar.get_info scope_var)))
ctx.scope_var_mapping;
}
| States states ->
{
ctx with
scope_var_mapping =
ScopeVarMap.add scope_var
(States
(List.map
(fun state ->
( state,
ScopeVar.fresh
(let state_name, state_pos =
StateName.get_info state
in
( Marked.unmark (ScopeVar.get_info scope_var)
^ "_"
^ state_name,
state_pos )) ))
states))
ctx.scope_var_mapping;
})
scope_decl.Ast.scope_vars ctx)
pgrm.Ast.program_scopes
{ scope_var_mapping = ScopeVarMap.empty; var_mapping = Var.Map.empty }
in
{
Scopelang.Ast.program_scopes =
ScopeMap.map (translate_scope ctx) pgrm.program_scopes;
Scopelang.Ast.program_ctx = pgrm.program_ctx;
}

View File

@ -0,0 +1,78 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Shared_ast
open Ast
let expr ctx env e =
(* The typer takes care of disambiguating: this consists in: - ensuring
[EAbs.tys] doesn't contain any [TAny] - [EDStructAccess.name_opt] is always
[Some] *)
(* Intermediate unboxings are fine since the last [untype] will rebox in
depth *)
Typing.check_expr ctx ~env (Expr.unbox e)
let rule ctx env rule =
let env =
match rule.rule_parameter with
| None -> env
| Some (v, ty) -> Typing.Env.add_var v ty env
in
(* Note: we could use the known rule type here to direct typing. We choose not
to because it shouldn't be needed for disambiguation, and we prefer to
focus on local type errors first. *)
{
rule with
rule_just = expr ctx env rule.rule_just;
rule_cons = expr ctx env rule.rule_cons;
}
let scope ctx env scope =
let env = Typing.Env.open_scope scope.scope_uid env in
let scope_defs =
ScopeDefMap.map
(fun def ->
let scope_def_rules =
(* Note: ordering in file order might be better for error reporting ?
When we gather errors, the ordering could be done afterwards,
though *)
RuleName.Map.map (rule ctx env) def.scope_def_rules
in
{ def with scope_def_rules })
scope.scope_defs
in
let scope_assertions = List.map (expr ctx env) scope.scope_assertions in
{ scope with scope_defs; scope_assertions }
let program prg =
let env =
ScopeName.Map.fold
(fun scope_name scope env ->
let vars =
ScopeDefMap.fold
(fun var def vars ->
match var with
| Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars
| SubScopeVar _ -> vars)
scope.scope_defs ScopeVar.Map.empty
in
Typing.Env.add_scope scope_name ~vars env)
prg.program_scopes Typing.Env.empty
in
let program_scopes =
ScopeName.Map.map (scope prg.program_ctx env) prg.program_scopes
in
{ prg with program_scopes }

View File

@ -0,0 +1,24 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** This module does local typing in order to fill some missing type information
in the AST:
- it fills the types of arguments in [EAbs] nodes, (untyped ones are
inserted during desugaring, e.g. by `let-in` constructs),
- it resolves the structure names of [EDStructAccess] nodes. *)
val program : Ast.program -> Ast.program

View File

@ -1,7 +1,7 @@
(library
(name desugared)
(public_name catala.desugared)
(libraries utils dcalc scopelang ocamlgraph))
(libraries ocamlgraph catala_utils shared_ast surface))
(documentation
(package catala)

View File

@ -20,6 +20,6 @@
- Removes syntactic sugars
- Separate code from legislation *)
val desugar_program :
Name_resolution.context -> Ast.program -> Desugared.Ast.program
val translate_program :
Name_resolution.context -> Surface.Ast.program -> Ast.program
(** Main function of this module *)

View File

@ -18,20 +18,18 @@
(** Builds a context that allows for mapping each name to a precise uid, taking
lexical scopes into account *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Name resolution context} *)
type ident = string
type unique_rulename =
| Ambiguous of Pos.t list
| Unique of Desugared.Ast.RuleName.t Marked.pos
| Unique of RuleName.t Marked.pos
type scope_def_context = {
default_exception_rulename : unique_rulename option;
label_idmap : Desugared.Ast.LabelName.t Desugared.Ast.IdentMap.t;
label_idmap : LabelName.t IdentName.Map.t;
}
type scope_var_or_subscope =
@ -39,26 +37,26 @@ type scope_var_or_subscope =
| SubScope of SubScopeName.t * ScopeName.t
type scope_context = {
var_idmap : scope_var_or_subscope Desugared.Ast.IdentMap.t;
var_idmap : scope_var_or_subscope IdentName.Map.t;
(** All variables, including scope variables and subscopes *)
scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t;
scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t;
(** What is the default rule to refer to for unnamed exceptions, if any *)
sub_scopes : ScopeSet.t;
sub_scopes : ScopeName.Set.t;
(** Other scopes referred to by this scope. Used for dependency analysis *)
}
(** Inside a scope, we distinguish between the variables and the subscopes. *)
type struct_context = typ StructFieldMap.t
type struct_context = typ StructField.Map.t
(** Types of the fields of a struct *)
type enum_context = typ EnumConstructorMap.t
type enum_context = typ EnumConstructor.Map.t
(** Types of the payloads of the cases of an enum *)
type var_sig = {
var_sig_typ : typ;
var_sig_is_condition : bool;
var_sig_io : Ast.scope_decl_context_io;
var_sig_states_idmap : StateName.t Desugared.Ast.IdentMap.t;
var_sig_io : Surface.Ast.scope_decl_context_io;
var_sig_states_idmap : StateName.t IdentName.Map.t;
var_sig_states_list : StateName.t list;
}
@ -67,25 +65,26 @@ type var_sig = {
type typedef =
| TStruct of StructName.t
| TEnum of EnumName.t
| TScope of ScopeName.t * StructName.t
| TScope of ScopeName.t * scope_out_struct
(** Implicitly defined output struct *)
type context = {
local_var_idmap : Desugared.Ast.expr Var.t Desugared.Ast.IdentMap.t;
local_var_idmap : Ast.expr Var.t IdentName.Map.t;
(** Inside a definition, local variables can be introduced by functions
arguments or pattern matching *)
typedefs : typedef Desugared.Ast.IdentMap.t;
typedefs : typedef IdentName.Map.t;
(** Gathers the names of the scopes, structs and enums *)
field_idmap : StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t;
field_idmap : StructField.t StructName.Map.t IdentName.Map.t;
(** The names of the struct fields. Names of fields can be shared between
different structs *)
constructor_idmap : EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t;
constructor_idmap : EnumConstructor.t EnumName.Map.t IdentName.Map.t;
(** The names of the enum constructors. Constructor names can be shared
between different enums *)
scopes : scope_context ScopeMap.t; (** For each scope, its context *)
structs : struct_context StructMap.t; (** For each struct, its context *)
enums : enum_context EnumMap.t; (** For each enum, its context *)
var_typs : var_sig ScopeVarMap.t;
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
structs : struct_context StructName.Map.t;
(** For each struct, its context *)
enums : enum_context EnumName.Map.t; (** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *)
}
(** Main context used throughout {!module: Surface.Desugaring} *)
@ -96,7 +95,7 @@ val raise_unsupported_feature : string -> Pos.t -> 'a
(** Temporary function raising an error message saying that a feature is not
supported yet *)
val raise_unknown_identifier : string -> ident Marked.pos -> 'a
val raise_unknown_identifier : string -> IdentName.t Marked.pos -> 'a
(** Function to call whenever an identifier used somewhere has not been declared
in the program previously *)
@ -104,53 +103,53 @@ val get_var_typ : context -> ScopeVar.t -> typ
(** Gets the type associated to an uid *)
val is_var_cond : context -> ScopeVar.t -> bool
val get_var_io : context -> ScopeVar.t -> Ast.scope_decl_context_io
val get_var_io : context -> ScopeVar.t -> Surface.Ast.scope_decl_context_io
val get_var_uid : ScopeName.t -> context -> ident Marked.pos -> ScopeVar.t
val get_var_uid : ScopeName.t -> context -> IdentName.t Marked.pos -> ScopeVar.t
(** Get the variable uid inside the scope given in argument *)
val get_subscope_uid :
ScopeName.t -> context -> ident Marked.pos -> SubScopeName.t
ScopeName.t -> context -> IdentName.t Marked.pos -> SubScopeName.t
(** Get the subscope uid inside the scope given in argument *)
val is_subscope_uid : ScopeName.t -> context -> ident -> bool
val is_subscope_uid : ScopeName.t -> context -> IdentName.t -> bool
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
subscopes of [scope_uid]. *)
val belongs_to : context -> ScopeVar.t -> ScopeName.t -> bool
(** Checks if the var_uid belongs to the scope scope_uid *)
val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ
val get_def_typ : context -> Ast.ScopeDef.t -> typ
(** Retrieves the type of a scope definition from the context *)
val is_def_cond : context -> Desugared.Ast.ScopeDef.t -> bool
val is_type_cond : Ast.typ -> bool
val is_def_cond : context -> Ast.ScopeDef.t -> bool
val is_type_cond : Surface.Ast.typ -> bool
val add_def_local_var : context -> ident -> context * Desugared.Ast.expr Var.t
val add_def_local_var : context -> IdentName.t -> context * Ast.expr Var.t
(** Adds a binding to the context *)
val get_def_key :
Ast.qident ->
Ast.ident Marked.pos option ->
Surface.Ast.scope_var ->
Surface.Ast.lident Marked.pos option ->
ScopeName.t ->
context ->
Pos.t ->
Desugared.Ast.ScopeDef.t
Ast.ScopeDef.t
(** Usage: [get_def_key var_name var_state scope_uid ctxt pos]*)
val get_enum : context -> ident Marked.pos -> EnumName.t
val get_enum : context -> IdentName.t Marked.pos -> EnumName.t
(** Find an enum definition from the typedefs, failing if there is none or it
has a different kind *)
val get_struct : context -> ident Marked.pos -> StructName.t
val get_struct : context -> IdentName.t Marked.pos -> StructName.t
(** Find a struct definition from the typedefs (possibly an implicit output
struct from a scope), failing if there is none or it has a different kind *)
val get_scope : context -> ident Marked.pos -> ScopeName.t
val get_scope : context -> IdentName.t Marked.pos -> ScopeName.t
(** Find a scope definition from the typedefs, failing if there is none or it
has a different kind *)
(** {1 API} *)
val form_context : Ast.program -> context
val form_context : Surface.Ast.program -> context
(** Derive the context from metadata, in one pass over the declarations *)

View File

@ -15,10 +15,7 @@
License for the specific language governing permissions and limitations under
the License. *)
module Cli = Utils.Cli
module File = Utils.File
module Errors = Utils.Errors
module Pos = Utils.Pos
open Catala_utils
(** Associates a {!type: Cli.backend_lang} with its string represtation. *)
let languages = ["en", Cli.En; "fr", Cli.Fr; "pl", Cli.Pl]
@ -76,7 +73,15 @@ let driver source_file (options : Cli.options) : int =
try `Plugin (Plugin.find s)
with Not_found ->
Errors.raise_error
"The selected backend (%s) is not supported by Catala" backend)
"The selected backend (%s) is not supported by Catala, nor was a \
plugin by this name found under %a"
backend
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or @ ")
(fun ppf dir ->
Format.pp_print_string ppf
(try Unix.readlink dir with _ -> dir)))
options.plugins_dirs)
in
let prgm =
Surface.Parser_driver.parse_top_level_file source_file language
@ -143,7 +148,7 @@ let driver source_file (options : Cli.options) : int =
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Dcalc
| `Scopelang | `Proof | `Plugin _ ) as backend -> (
Cli.debug_print "Name resolution...";
let ctxt = Surface.Name_resolution.form_context prgm in
let ctxt = Desugared.Name_resolution.form_context prgm in
let scope_uid =
match options.ex_scope, backend with
| None, `Interpret ->
@ -151,27 +156,29 @@ let driver source_file (options : Cli.options) : int =
| None, _ ->
let _, scope =
try
Desugared.Ast.IdentMap.filter_map
Shared_ast.IdentName.Map.filter_map
(fun _ -> function
| Surface.Name_resolution.TScope (uid, _) -> Some uid
| Desugared.Name_resolution.TScope (uid, _) -> Some uid
| _ -> None)
ctxt.typedefs
|> Desugared.Ast.IdentMap.choose
|> Shared_ast.IdentName.Map.choose
with Not_found ->
Errors.raise_error "There isn't any scope inside the program."
in
scope
| Some name, _ -> (
match Desugared.Ast.IdentMap.find_opt name ctxt.typedefs with
| Some (Surface.Name_resolution.TScope (uid, _)) -> uid
match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with
| Some (Desugared.Name_resolution.TScope (uid, _)) -> uid
| _ ->
Errors.raise_error "There is no scope \"%s\" inside the program."
name)
in
Cli.debug_print "Desugaring...";
let prgm = Surface.Desugaring.desugar_program ctxt prgm in
let prgm = Desugared.From_surface.translate_program ctxt prgm in
Cli.debug_print "Disambiguating...";
let prgm = Desugared.Disambiguate.program prgm in
Cli.debug_print "Collecting rules...";
let prgm = Desugared.Desugared_to_scope.translate_program prgm in
let prgm = Scopelang.From_desugared.translate_program prgm in
match backend with
| `Scopelang ->
let _output_file, with_output = get_output_format () in
@ -180,7 +187,8 @@ let driver source_file (options : Cli.options) : int =
if Option.is_some options.ex_scope then
Format.fprintf fmt "%a\n"
(Scopelang.Print.scope prgm.program_ctx ~debug:options.debug)
(scope_uid, Shared_ast.ScopeMap.find scope_uid prgm.program_scopes)
( scope_uid,
Shared_ast.ScopeName.Map.find scope_uid prgm.program_scopes )
else
Format.fprintf fmt "%a\n"
(Scopelang.Print.program ~debug:options.debug)
@ -194,7 +202,7 @@ let driver source_file (options : Cli.options) : int =
in
let prgm = Scopelang.Ast.type_program prgm in
Cli.debug_print "Translating to default calculus...";
let prgm = Scopelang.Scope_to_dcalc.translate_program prgm in
let prgm = Dcalc.From_scopelang.translate_program prgm in
let prgm =
if options.optimize then begin
Cli.debug_print "Optimizing default calculus...";
@ -202,8 +210,21 @@ let driver source_file (options : Cli.options) : int =
end
else prgm
in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
(Print.typ prgm.decl_ctx) typ); *)
match backend with
| `Typecheck ->
Cli.debug_print "Typechecking again...";
let _ =
try Shared_ast.Typing.program prgm
with Errors.StructuredError (msg, details) ->
let msg =
"Typing error occured during re-typing on the 'default \
calculus'. This is a bug in the Catala compiler.\n"
^ msg
in
raise (Errors.StructuredError (msg, details))
in
(* That's it! *)
Cli.result_print "Typechecking successful!"
| `Dcalc ->
@ -229,7 +250,7 @@ let driver source_file (options : Cli.options) : int =
Shared_ast.Expr.unbox (Shared_ast.Program.to_expr prgm scope_uid)
in
Format.fprintf fmt "%a\n"
(Shared_ast.Expr.format prgm.decl_ctx)
(Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx)
prgrm_dcalc_expr
| (`Interpret | `OCaml | `Python | `Scalc | `Lcalc | `Proof | `Plugin _)
as backend -> (
@ -244,8 +265,6 @@ let driver source_file (options : Cli.options) : int =
in
raise (Errors.StructuredError (msg, details))
in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
(Print.typ prgm.decl_ctx) typ); *)
match backend with
| `Proof ->
let vcs =
@ -308,24 +327,14 @@ let driver source_file (options : Cli.options) : int =
if Option.is_some options.ex_scope then
Format.fprintf fmt "%a\n"
(Shared_ast.Scope.format ~debug:options.debug prgm.decl_ctx)
( scope_uid,
Option.get
(Shared_ast.Scope.fold_left ~init:None
~f:(fun acc scope_def _ ->
if
Shared_ast.ScopeName.compare scope_def.scope_name
scope_uid
= 0
then Some scope_def.scope_body
else acc)
prgm.scopes) )
(scope_uid, Shared_ast.Program.get_scope_body prgm scope_uid)
else
let prgrm_lcalc_expr =
Shared_ast.Expr.unbox
(Shared_ast.Program.to_expr prgm scope_uid)
in
Format.fprintf fmt "%a\n"
(Shared_ast.Expr.format prgm.decl_ctx)
(Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx)
prgrm_lcalc_expr
| (`OCaml | `Python | `Scalc | `Plugin _) as backend -> (
match backend with

View File

@ -15,9 +15,10 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
module Plugin = Plugin.PluginAPI
val driver : Utils.Pos.input_file -> Utils.Cli.options -> int
val driver : Pos.input_file -> Cli.options -> int
(** Entry function for the executable. Returns a negative number in case of
error. *)

View File

@ -3,7 +3,7 @@
(public_name catala.driver)
(libraries
dynlink
utils
catala_utils
surface
desugared
literate
@ -50,3 +50,7 @@
(documentation
(package catala)
(mld_files index))
(alias
(name catala)
(deps catala.exe))

View File

@ -103,7 +103,7 @@ Two more modules contain additional features for the compiler:
{ul
{li {{: literate.html} Literate programming}}
{li {{: utils.html} Compiler utilities}}
{li {{: catala_utils.html} Compiler utilities}}
}
The Catala runtimes documentation is available here:

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
include Shared_ast
type lit = lcalc glit
@ -28,31 +28,32 @@ let option_enum : EnumName.t = EnumName.fresh ("eoption", Pos.no_pos)
let none_constr : EnumConstructor.t = EnumConstructor.fresh ("ENone", Pos.no_pos)
let some_constr : EnumConstructor.t = EnumConstructor.fresh ("ESome", Pos.no_pos)
let option_enum_config : (EnumConstructor.t * typ) list =
[none_constr, (TLit TUnit, Pos.no_pos); some_constr, (TAny, Pos.no_pos)]
let option_enum_config : typ EnumConstructor.Map.t =
EnumConstructor.Map.empty
|> EnumConstructor.Map.add none_constr (TLit TUnit, Pos.no_pos)
|> EnumConstructor.Map.add some_constr (TAny, Pos.no_pos)
(* FIXME: proper typing in all the constructors below *)
let make_none m =
let tunit = TLit TUnit, Expr.mark_pos m in
Expr.einj
(Expr.elit LUnit (Expr.with_ty m tunit))
0 option_enum
[TLit TUnit, Pos.no_pos; TAny, Pos.no_pos]
m
Expr.einj (Expr.elit LUnit (Expr.with_ty m tunit)) none_constr option_enum m
let make_some e =
let m = Marked.get_mark e in
Expr.einj e 1 option_enum
[TLit TUnit, Expr.mark_pos m; TAny, Expr.mark_pos m]
m
Expr.einj e some_constr option_enum m
(** [make_matchopt_with_abs_arms arg e_none e_some] build an expression
[match arg with |None -> e_none | Some -> e_some] and requires e_some and
e_none to be in the form [EAbs ...].*)
let make_matchopt_with_abs_arms arg e_none e_some =
let m = Marked.get_mark arg in
Expr.ematch arg [e_none; e_some] option_enum m
let cases =
EnumConstructor.Map.empty
|> EnumConstructor.Map.add none_constr e_none
|> EnumConstructor.Map.add some_constr e_some
in
Expr.ematch arg option_enum cases m
(** [make_matchopt pos v tau arg e_none e_some] builds an expression
[match arg with | None () -> e_none | Some v -> e_some]. It binds v to

View File

@ -14,6 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
(** Abstract syntax tree for the lambda calculus *)
@ -32,7 +33,7 @@ type 'm program = 'm expr Shared_ast.program
val option_enum : EnumName.t
val none_constr : EnumConstructor.t
val some_constr : EnumConstructor.t
val option_enum_config : (EnumConstructor.t * typ) list
val option_enum_config : typ EnumConstructor.Map.t
val make_none : 'm mark -> 'm expr boxed
val make_some : 'm expr boxed -> 'm expr boxed
@ -40,7 +41,7 @@ val make_matchopt_with_abs_arms :
'm expr boxed -> 'm expr boxed -> 'm expr boxed -> 'm expr boxed
val make_matchopt :
Utils.Pos.t ->
Pos.t ->
'm expr Var.t ->
typ ->
'm expr boxed ->

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
open Ast
module D = Dcalc.Ast
@ -31,74 +31,56 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
let rec aux e =
let m = Marked.get_mark e in
match Marked.unmark e with
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _
| ECatch _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
| EVar v ->
( (Bindlib.box_var v, m),
if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
else Var.Set.singleton v )
| ETuple (args, s) ->
let new_args, free_vars =
List.fold_left
(fun (new_args, free_vars) arg ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union new_free_vars free_vars)
([], Var.Set.empty) args
in
Expr.etuple (List.rev new_args) s m, free_vars
| ETupleAccess (e1, n, s, typs) ->
let new_e1, free_vars = aux e1 in
Expr.etupleaccess new_e1 n s typs m, free_vars
| EInj (e1, n, e_name, typs) ->
let new_e1, free_vars = aux e1 in
Expr.einj new_e1 n e_name typs m, free_vars
| EMatch (e1, arms, e_name) ->
let new_e1, free_vars = aux e1 in
( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
else Var.Set.singleton v),
(Bindlib.box_var v, m) )
| EMatch { e; cases; name } ->
let free_vars, new_e = aux e in
(* We do not close the clotures inside the arms of the match expression,
since they get a special treatment at compilation to Scalc. *)
let new_arms, free_vars =
List.fold_right
(fun arm (new_arms, free_vars) ->
match Marked.unmark arm with
| EAbs (binder, typs) ->
let free_vars, new_cases =
EnumConstructor.Map.fold
(fun cons e1 (free_vars, new_cases) ->
match Marked.unmark e1 with
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let new_body, new_free_vars = aux body in
let new_free_vars, new_body = aux body in
let new_binder = Expr.bind vars new_body in
( Expr.eabs new_binder typs (Marked.get_mark arm) :: new_arms,
Var.Set.union free_vars new_free_vars )
( Var.Set.union free_vars new_free_vars,
EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Marked.get_mark e1))
new_cases )
| _ -> failwith "should not happen")
arms ([], free_vars)
cases
(free_vars, EnumConstructor.Map.empty)
in
Expr.ematch new_e1 new_arms e_name m, free_vars
| EArray args ->
let new_args, free_vars =
List.fold_right
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], Var.Set.empty)
in
Expr.earray new_args m, free_vars
| ELit l -> Expr.elit l m, Var.Set.empty
| EApp ((EAbs (binder, typs_abs), e1_pos), args) ->
free_vars, Expr.ematch new_e name new_cases m
| EApp { f = EAbs { binder; tys }, e1_pos; args } ->
(* let-binding, we should not close these *)
let vars, body = Bindlib.unmbind binder in
let new_body, free_vars = aux body in
let free_vars, new_body = aux body in
let new_binder = Expr.bind vars new_body in
let new_args, free_vars =
let free_vars, new_args =
List.fold_right
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], free_vars)
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = aux arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (free_vars, [])
in
Expr.eapp (Expr.eabs new_binder typs_abs e1_pos) new_args m, free_vars
| EAbs (binder, typs) ->
free_vars, Expr.eapp (Expr.eabs new_binder tys e1_pos) new_args m
| EAbs { binder; tys } ->
(* λ x.t *)
let binder_mark = m in
let binder_pos = Expr.mark_pos binder_mark in
(* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in
(* t *)
let new_body, body_vars = aux body in
let body_vars, new_body = aux body in
(* [[t]] *)
let extra_vars =
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars))
@ -117,8 +99,8 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
(fun i _ ->
Expr.etupleaccess
(Expr.evar inner_c_var binder_mark)
(i + 1) None
(List.map (fun _ -> any_ty) extra_vars_list)
(i + 1)
(List.length extra_vars_list)
binder_mark)
extra_vars_list)
new_body
@ -128,10 +110,11 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
Expr.make_abs
(Array.concat [Array.make 1 inner_c_var; vars])
new_closure_body
((TAny, binder_pos) :: typs)
((TAny, binder_pos) :: tys)
(Expr.pos e)
in
( Expr.make_let_in code_var
( extra_vars,
Expr.make_let_in code_var
(TAny, Expr.pos e)
new_closure
(Expr.etuple
@ -139,40 +122,25 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
:: List.map
(fun extra_var -> Bindlib.box_var extra_var, binder_mark)
extra_vars_list)
None m)
(Expr.pos e),
extra_vars )
| EApp ((EOp op, pos_op), args) ->
m)
(Expr.pos e) )
| EApp { f = EOp _, _; _ } ->
(* This corresponds to an operator call, which we don't want to
transform*)
let new_args, free_vars =
List.fold_right
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], Var.Set.empty)
in
Expr.eapp (Expr.eop op pos_op) new_args m, free_vars
| EApp ((EVar v, v_pos), args) when Var.Set.mem v ctx.globally_bound_vars ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
| EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars ->
(* This corresponds to a scope call, which we don't want to transform*)
let new_args, free_vars =
List.fold_right
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], Var.Set.empty)
in
Expr.eapp (Bindlib.box_var v, v_pos) new_args m, free_vars
| EApp (e1, args) ->
let new_e1, free_vars = aux e1 in
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
| EApp { f = e1; args } ->
let free_vars, new_e1 = aux e1 in
let env_var = Var.make "env" in
let code_var = Var.make "code" in
let new_args, free_vars =
let free_vars, new_args =
List.fold_right
(fun arg (new_args, free_vars) ->
let new_arg, new_free_vars = aux arg in
new_arg :: new_args, Var.Set.union free_vars new_free_vars)
args ([], free_vars)
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = aux arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (free_vars, [])
in
let call_expr =
let m1 = Marked.get_mark e1 in
@ -180,7 +148,8 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
(TAny, Expr.pos e)
(Expr.etupleaccess
(Bindlib.box_var env_var, m1)
0 None [ (*TODO: fill?*) ]
0
(List.length new_args + 1)
m)
(Expr.eapp
(Bindlib.box_var code_var, m1)
@ -188,25 +157,12 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
m)
(Expr.pos e)
in
( Expr.make_let_in env_var (TAny, Expr.pos e) new_e1 call_expr (Expr.pos e),
free_vars )
| EAssert e1 ->
let new_e1, free_vars = aux e1 in
Expr.eassert new_e1 m, free_vars
| EOp op -> Expr.eop op m, Var.Set.empty
| EIfThenElse (e1, e2, e3) ->
let new_e1, free_vars1 = aux e1 in
let new_e2, free_vars2 = aux e2 in
let new_e3, free_vars3 = aux e3 in
( Expr.eifthenelse new_e1 new_e2 new_e3 m,
Var.Set.union (Var.Set.union free_vars1 free_vars2) free_vars3 )
| ERaise except -> Expr.eraise except m, Var.Set.empty
| ECatch (e1, except, e2) ->
let new_e1, free_vars1 = aux e1 in
let new_e2, free_vars2 = aux e2 in
Expr.ecatch new_e1 except new_e2 m, Var.Set.union free_vars1 free_vars2
( free_vars,
Expr.make_let_in env_var
(TAny, Expr.pos e)
new_e1 call_expr (Expr.pos e) )
in
let e', _vars = aux e in
let _vars, e' = aux e in
e'
let closure_conversion (p : 'm program) : 'm program Bindlib.box =

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
module D = Dcalc.Ast
module A = Ast
@ -43,7 +43,7 @@ let rec translate_default
Expr.make_app
(Expr.make_var
(Var.translate A.handle_default)
(Expr.with_ty mark_default (Utils.Marked.mark pos TAny)))
(Expr.with_ty mark_default (Marked.mark pos TAny)))
[
Expr.earray exceptions mark_default;
thunk_expr (translate_expr ctx just);
@ -54,39 +54,39 @@ let rec translate_default
exceptions
and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| EVar v -> Expr.make_var (Var.Map.find v ctx) (Marked.get_mark e)
| ETuple (args, s) ->
Expr.etuple (List.map (translate_expr ctx) args) s (Marked.get_mark e)
| ETupleAccess (e1, i, s, ts) ->
Expr.etupleaccess (translate_expr ctx e1) i s ts (Marked.get_mark e)
| EInj (e1, i, en, ts) ->
Expr.einj (translate_expr ctx e1) i en ts (Marked.get_mark e)
| EMatch (e1, cases, en) ->
Expr.ematch (translate_expr ctx e1)
(List.map (translate_expr ctx) cases)
en (Marked.get_mark e)
| EArray es ->
Expr.earray (List.map (translate_expr ctx) es) (Marked.get_mark e)
| EVar v -> Expr.make_var (Var.Map.find v ctx) m
| EStruct { name; fields } ->
Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m
| EStructAccess { name; e; field } ->
Expr.estructaccess (translate_expr ctx e) field name m
| EInj { name; e; cons } -> Expr.einj (translate_expr ctx e) cons name m
| EMatch { name; e; cases } ->
Expr.ematch (translate_expr ctx e) name
(EnumConstructor.Map.map (translate_expr ctx) cases)
m
| EArray es -> Expr.earray (List.map (translate_expr ctx) es) m
| ELit
((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as
l) ->
Expr.elit l (Marked.get_mark e)
| ELit LEmptyError -> Expr.eraise EmptyError (Marked.get_mark e)
| EOp op -> Expr.eop op (Marked.get_mark e)
| EIfThenElse (e1, e2, e3) ->
Expr.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2)
(translate_expr ctx e3) (Marked.get_mark e)
| EAssert e1 -> Expr.eassert (translate_expr ctx e1) (Marked.get_mark e)
| ErrorOnEmpty arg ->
Expr.elit l m
| ELit LEmptyError -> Expr.eraise EmptyError m
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys m
| EIfThenElse { cond; etrue; efalse } ->
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
(translate_expr ctx efalse)
m
| EAssert e1 -> Expr.eassert (translate_expr ctx e1) m
| EErrorOnEmpty arg ->
Expr.ecatch (translate_expr ctx arg) EmptyError
(Expr.eraise NoValueProvided (Marked.get_mark e))
(Marked.get_mark e)
| EApp (e1, args) ->
Expr.eapp (translate_expr ctx e1)
(Expr.eraise NoValueProvided m)
m
| EApp { f; args } ->
Expr.eapp (translate_expr ctx f)
(List.map (translate_expr ctx) args)
(Marked.get_mark e)
| EAbs (binder, ts) ->
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars =
Array.fold_right
@ -98,15 +98,16 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
let lc_vars = Array.of_list lc_vars in
let new_body = translate_expr ctx body in
let new_binder = Expr.bind lc_vars new_body in
Expr.eabs new_binder ts (Marked.get_mark e)
| EDefault ([exn], just, cons) when !Cli.optimize_flag ->
Expr.eabs new_binder tys (Marked.get_mark e)
| EDefault { excepts = [exn]; just; cons } when !Cli.optimize_flag ->
(* FIXME: bad place to rely on a global flag *)
Expr.ecatch (translate_expr ctx exn) EmptyError
(Expr.eifthenelse (translate_expr ctx just) (translate_expr ctx cons)
(Expr.eraise EmptyError (Marked.get_mark e))
(Marked.get_mark e))
(Marked.get_mark e)
| EDefault (exceptions, just, cons) ->
translate_default ctx exceptions just cons (Marked.get_mark e)
| EDefault { excepts; just; cons } ->
translate_default ctx excepts just cons (Marked.get_mark e)
let rec translate_scope_lets
(decl_ctx : decl_ctx)

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
module D = Dcalc.Ast
module A = Ast
@ -170,7 +170,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
created a variable %a to replace it" Print.var v Print.var v'; *)
Expr.make_var v' mark, Var.Map.singleton v' e
else (find ~info:"should never happen" v ctx).expr, Var.Map.empty
| EApp ((EVar v, p), [(ELit LUnit, _)]) ->
| EApp { f = EVar v, p; args = [(ELit LUnit, _)] } ->
if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = Var.make (Bindlib.name_of v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
@ -179,7 +179,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
else
Errors.raise_spanned_error (Expr.pos e)
"Internal error: an pure variable was found in an unpure environment."
| EDefault (_exceptions, _just, _cons) ->
| EDefault _ ->
let v' = Var.make "default_term" in
Expr.make_var v' mark, Var.Map.singleton v' e
| ELit LEmptyError ->
@ -187,7 +187,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
Expr.make_var v' mark, Var.Map.singleton v' e
(* This one is a very special case. It transform an unpure expression
environement to a pure expression. *)
| ErrorOnEmpty arg ->
| EErrorOnEmpty arg ->
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
let silent_var = Var.make "_" in
let x = Var.make "non_empty_argument" in
@ -206,22 +206,23 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as
l) ->
Expr.elit l mark, Var.Map.empty
| EIfThenElse (e1, e2, e3) ->
let e1', h1 = translate_and_hoist ctx e1 in
let e2', h2 = translate_and_hoist ctx e2 in
let e3', h3 = translate_and_hoist ctx e3 in
| EIfThenElse { cond; etrue; efalse } ->
let cond', h1 = translate_and_hoist ctx cond in
let etrue', h2 = translate_and_hoist ctx etrue in
let efalse', h3 = translate_and_hoist ctx efalse in
let e' = Expr.eifthenelse e1' e2' e3' mark in
let e' = Expr.eifthenelse cond' etrue' efalse' mark in
(*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' =
e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *)
(*(* equivalent code : *) let e' = let+ cond' = cond' and+ etrue' = etrue'
and+ efalse' = efalse' in (A.EIfThenElse (cond', etrue', efalse'), pos)
in *)
e', disjoint_union_maps (Expr.pos e) [h1; h2; h3]
| EAssert e1 ->
(* same behavior as in the ICFP paper: if e1 is empty, then no error is
raised. *)
let e1', h1 = translate_and_hoist ctx e1 in
Expr.eassert e1' mark, h1
| EAbs (binder, ts) ->
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars =
ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
@ -242,8 +243,8 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let new_body, hoists = translate_and_hoist ctx body in
let new_binder = Expr.bind lc_vars new_body in
Expr.eabs new_binder (List.map translate_typ ts) mark, hoists
| EApp (e1, args) ->
Expr.eabs new_binder (List.map translate_typ tys) mark, hoists
| EApp { f = e1; args } ->
let e1', h1 = translate_and_hoist ctx e1 in
let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split
@ -252,35 +253,43 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_args) in
let e' = Expr.eapp e1' args' mark in
e', hoists
| ETuple (args, s) ->
let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split
| EStruct { name; fields } ->
let fields', h_fields =
StructField.Map.fold
(fun field e (fields, hoists) ->
let e, h = translate_and_hoist ctx e in
StructField.Map.add field e fields, h :: hoists)
fields
(StructField.Map.empty, [])
in
let hoists = disjoint_union_maps (Expr.pos e) h_args in
Expr.etuple args' s mark, hoists
| ETupleAccess (e1, i, s, ts) ->
let hoists = disjoint_union_maps (Expr.pos e) h_fields in
Expr.estruct name fields' mark, hoists
| EStructAccess { name; e = e1; field } ->
let e1', hoists = translate_and_hoist ctx e1 in
let e1' = Expr.etupleaccess e1' i s ts mark in
let e1' = Expr.estructaccess e1' field name mark in
e1', hoists
| EInj (e1, i, en, ts) ->
| EInj { name; e = e1; cons } ->
let e1', hoists = translate_and_hoist ctx e1 in
let e1' = Expr.einj e1' i en ts mark in
let e1' = Expr.einj e1' cons name mark in
e1', hoists
| EMatch (e1, cases, en) ->
| EMatch { name; e = e1; cases } ->
let e1', h1 = translate_and_hoist ctx e1 in
let cases', h_cases =
cases |> List.map (translate_and_hoist ctx) |> List.split
EnumConstructor.Map.fold
(fun cons e (cases, hoists) ->
let e', h = translate_and_hoist ctx e in
EnumConstructor.Map.add cons e' cases, h :: hoists)
cases
(EnumConstructor.Map.empty, [])
in
let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in
let e' = Expr.ematch e1' cases' en mark in
let e' = Expr.ematch e1' name cases' mark in
e', hoists
| EArray es ->
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
Expr.earray es' mark, disjoint_union_maps (Expr.pos e) hoists
| EOp op -> Expr.eop op mark, Var.Map.empty
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys mark, Var.Map.empty
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
'm A.expr boxed =
@ -302,14 +311,14 @@ and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
(* Here we have to handle only the cases appearing in hoists, as defined
the [translate_and_hoist] function. *)
| EVar v -> (find ~info:"should never happen" v ctx).expr
| EDefault (excep, just, cons) ->
let excep' = List.map (translate_expr ctx) excep in
| EDefault { excepts; just; cons } ->
let excepts' = List.map (translate_expr ctx) excepts in
let just' = translate_expr ctx just in
let cons' = translate_expr ctx cons in
(* calls handle_option. *)
Expr.make_app
(Expr.make_var (Var.translate A.handle_default_opt) mark_hoist)
[Expr.earray excep' mark_hoist; just'; cons']
[Expr.earray excepts' mark_hoist; just'; cons']
pos
| ELit LEmptyError -> A.make_none mark_hoist
| EAssert arg ->
@ -354,7 +363,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
{
scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ;
scope_let_expr = EAbs (binder, _), emark;
scope_let_expr = EAbs { binder; _ }, emark;
scope_let_next = next;
scope_let_pos = pos;
} ->
@ -385,7 +394,7 @@ let rec translate_scope_let (ctx : 'm ctx) (lets : 'm D.expr scope_body_expr) :
{
scope_let_kind = SubScopeVarDefinition;
scope_let_typ = typ;
scope_let_expr = (ErrorOnEmpty _, emark) as expr;
scope_let_expr = (EErrorOnEmpty _, emark) as expr;
scope_let_next = next;
scope_let_pos = pos;
} ->
@ -529,7 +538,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
prgm.decl_ctx with
ctx_enums =
prgm.decl_ctx.ctx_enums
|> EnumMap.add A.option_enum A.option_enum_config;
|> EnumName.Map.add A.option_enum A.option_enum_config;
}
in
let decl_ctx =
@ -537,15 +546,14 @@ let translate_program (prgm : 'm D.program) : 'm A.program =
decl_ctx with
ctx_structs =
prgm.decl_ctx.ctx_structs
|> StructMap.mapi (fun n l ->
|> StructName.Map.mapi (fun n str ->
if List.mem n inputs_structs then
ListLabels.map l ~f:(fun (n, tau) ->
(* Cli.debug_print @@ Format.asprintf "Input type: %a"
(Print.typ decl_ctx) tau; Cli.debug_print @@
Format.asprintf "Output type: %a" (Print.typ decl_ctx)
(translate_typ tau); *)
n, translate_typ tau)
else l);
StructField.Map.map translate_typ str
(* Cli.debug_print @@ Format.asprintf "Input type: %a"
(Print.typ decl_ctx) tau; Cli.debug_print @@ Format.asprintf
"Output type: %a" (Print.typ decl_ctx) (translate_typ
tau); *)
else str);
}
in

View File

@ -0,0 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
let translate_program_with_exceptions =
Compile_with_exceptions.translate_program
let translate_program_without_exceptions =
Compile_without_exceptions.translate_program

View File

@ -0,0 +1,26 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
val translate_program_with_exceptions : 'm Dcalc.Ast.program -> 'm Ast.program
(** Translation from the default calculus to the lambda calculus. This
translation uses exceptions to handle empty default terms. *)
val translate_program_without_exceptions :
'm Dcalc.Ast.program -> 'm Ast.program
(** Translation from the default calculus to the lambda calculus. This
translation uses an option monad to handle empty defaults terms. This
transformation is one piece to permit to compile toward legacy languages
that does not contains exceptions. *)

View File

@ -13,50 +13,47 @@
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
open Ast
module D = Dcalc.Ast
let visitor_map (t : 'a -> 'm expr -> 'm expr boxed) (ctx : 'a) (e : 'm expr) :
'm expr boxed =
Expr.map ctx ~f:t e
let visitor_map (t : 'm expr -> 'm expr boxed) (e : 'm expr) : 'm expr boxed =
Expr.map ~f:t e
let rec iota_expr (_ : unit) (e : 'm expr) : 'm expr boxed =
let rec iota_expr (e : 'm expr) : 'm expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| EMatch ((EInj (e1, i, n', _ts), _), cases, n) when EnumName.compare n n' = 0
->
let e1 = visitor_map iota_expr () e1 in
let case = visitor_map iota_expr () (List.nth cases i) in
| EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n }
when EnumName.equal n n' ->
let e1 = visitor_map iota_expr e' in
let case = visitor_map iota_expr (EnumConstructor.Map.find cons cases) in
Expr.eapp case [e1] m
| EMatch (e', cases, n)
| EMatch { e = e'; cases; name = n }
when cases
|> List.mapi (fun i (case, _pos) ->
match case with
| EInj (_ei, i', n', _ts') ->
i = i' && (* n = n' *) EnumName.compare n n' = 0
|> EnumConstructor.Map.mapi (fun i case ->
match Marked.unmark case with
| EInj { cons = i'; name = n'; _ } ->
EnumConstructor.equal i i' && EnumName.equal n n'
| _ -> false)
|> List.for_all Fun.id ->
visitor_map iota_expr () e'
| _ -> visitor_map iota_expr () e
|> EnumConstructor.Map.for_all (fun _ b -> b) ->
visitor_map iota_expr e'
| _ -> visitor_map iota_expr e
let rec beta_expr (e : 'm expr) : 'm expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| EApp (e1, args) ->
| EApp { f = e1; args } ->
Expr.Box.app1n (beta_expr e1) (List.map beta_expr args)
(fun e1 args ->
match Marked.unmark e1 with
| EAbs (binder, _) -> Marked.unmark (Expr.subst binder args)
| _ -> EApp (e1, args))
| EAbs { binder; _ } -> Marked.unmark (Expr.subst binder args)
| _ -> EApp { f = e1; args })
m
| _ -> visitor_map (fun () -> beta_expr) () e
| _ -> visitor_map beta_expr e
let iota_optimizations (p : 'm program) : 'm program =
let new_scopes =
Scope.map_exprs ~f:(iota_expr ()) ~varf:(fun v -> v) p.scopes
in
let new_scopes = Scope.map_exprs ~f:iota_expr ~varf:(fun v -> v) p.scopes in
{ p with scopes = Bindlib.unbox new_scopes }
(* TODO: beta optimizations apply inlining of the program. We left the inclusion
@ -70,30 +67,32 @@ let _beta_optimizations (p : 'm program) : 'm program =
let rec peephole_expr (e : 'm expr) : 'm expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| EIfThenElse (e1, e2, e3) ->
Expr.Box.app3 (peephole_expr e1) (peephole_expr e2) (peephole_expr e3)
(fun e1 e2 e3 ->
match Marked.unmark e1 with
| EIfThenElse { cond; etrue; efalse } ->
Expr.Box.app3 (peephole_expr cond) (peephole_expr etrue)
(peephole_expr efalse)
(fun cond etrue efalse ->
match Marked.unmark cond with
| ELit (LBool true)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ->
Marked.unmark e2
| EApp { f = EOp { op = Log _; _ }, _; args = [(ELit (LBool true), _)] }
->
Marked.unmark etrue
| ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ->
Marked.unmark e3
| _ -> EIfThenElse (e1, e2, e3))
| EApp
{ f = EOp { op = Log _; _ }, _; args = [(ELit (LBool false), _)] }
->
Marked.unmark efalse
| _ -> EIfThenElse { cond; etrue; efalse })
m
| ECatch (e1, except, e2) ->
Expr.Box.app2 (peephole_expr e1) (peephole_expr e2)
(fun e1 e2 ->
match Marked.unmark e1, Marked.unmark e2 with
| ERaise except', ERaise except''
when except' = except && except = except'' ->
ERaise except
| ERaise except', _ when except' = except -> Marked.unmark e2
| _, ERaise except' when except' = except -> Marked.unmark e1
| _ -> ECatch (e1, except, e2))
| ECatch { body; exn; handler } ->
Expr.Box.app2 (peephole_expr body) (peephole_expr handler)
(fun body handler ->
match Marked.unmark body, Marked.unmark handler with
| ERaise exn', ERaise exn'' when exn' = exn && exn = exn'' -> ERaise exn
| ERaise exn', _ when exn' = exn -> Marked.unmark handler
| _, ERaise exn' when exn' = exn -> Marked.unmark body
| _ -> ECatch { body; exn; handler })
m
| _ -> visitor_map (fun () -> peephole_expr) () e
| _ -> visitor_map peephole_expr e
let peephole_optimizations (p : 'm program) : 'm program =
let new_scopes =

View File

@ -14,24 +14,21 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
open Ast
open String_common
module D = Dcalc.Ast
let find_struct (s : StructName.t) (ctx : decl_ctx) :
(StructFieldName.t * typ) list =
try StructMap.find s ctx.ctx_structs
let find_struct (s : StructName.t) (ctx : decl_ctx) : typ StructField.Map.t =
try StructName.Map.find s ctx.ctx_structs
with Not_found ->
let s_name, pos = StructName.get_info s in
Errors.raise_spanned_error pos
"Internal Error: Structure %s was not found in the current environment."
s_name
let find_enum (en : EnumName.t) (ctx : decl_ctx) :
(EnumConstructor.t * typ) list =
try EnumMap.find en ctx.ctx_enums
let find_enum (en : EnumName.t) (ctx : decl_ctx) : typ EnumConstructor.Map.t =
try EnumName.Map.find en ctx.ctx_enums
with Not_found ->
let en_name, pos = EnumName.get_info en in
Errors.raise_spanned_error pos
@ -57,43 +54,13 @@ let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days
let format_op_kind (fmt : Format.formatter) (k : op_kind) =
Format.fprintf fmt "%s"
(match k with
| KInt -> "!"
| KRat -> "&"
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit =
match Marked.unmark op with
| Add k -> Format.fprintf fmt "+%a" format_op_kind k
| Sub k -> Format.fprintf fmt "-%a" format_op_kind k
| Mult k -> Format.fprintf fmt "*%a" format_op_kind k
| Div k -> Format.fprintf fmt "/%a" format_op_kind k
| And -> Format.fprintf fmt "%s" "&&"
| Or -> Format.fprintf fmt "%s" "||"
| Eq -> Format.fprintf fmt "%s" "="
| Neq | Xor -> Format.fprintf fmt "%s" "<>"
| Lt k -> Format.fprintf fmt "%s%a" "<" format_op_kind k
| Lte k -> Format.fprintf fmt "%s%a" "<=" format_op_kind k
| Gt k -> Format.fprintf fmt "%s%a" ">" format_op_kind k
| Gte k -> Format.fprintf fmt "%s%a" ">=" format_op_kind k
| Concat -> Format.fprintf fmt "@"
| Map -> Format.fprintf fmt "Array.map"
| Filter -> Format.fprintf fmt "array_filter"
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit =
Format.fprintf fmt "@[<hov 2>[%a]@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt info ->
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info))
Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info))
uids
let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -106,26 +73,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids
let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit =
match Marked.unmark op with
| Minus k -> Format.fprintf fmt "~-%a" format_op_kind k
| Not -> Format.fprintf fmt "%s" "not"
| Log (_entry, _infos) ->
Errors.raise_spanned_error (Marked.get_mark op)
"Internal error: a log operator has not been caught by the expression \
match"
| Length -> Format.fprintf fmt "%s" "array_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
let avoid_keywords (s : string) : string =
match s with
(* list taken from
@ -137,14 +84,14 @@ let avoid_keywords (s : string) : string =
| "match" | "method" | "mod" | "module" | "mutable" | "new" | "nonrec"
| "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct"
| "then" | "to" | "true" | "try" | "type" | "val" | "virtual" | "when"
| "while" | "with" ->
| "while" | "with" | "Stdlib" | "Runtime" | "Oper" ->
s ^ "_user"
| _ -> s
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
Format.asprintf "%a" StructName.format_t v
|> to_ascii
|> to_snake_case
|> String.to_ascii
|> String.to_snake_case
|> avoid_keywords
|> Format.fprintf fmt "%s"
@ -154,8 +101,8 @@ let format_to_module_name
(match name with
| `Ename v -> Format.asprintf "%a" EnumName.format_t v
| `Sname v -> Format.asprintf "%a" StructName.format_t v)
|> to_ascii
|> to_snake_case
|> String.to_ascii
|> String.to_snake_case
|> avoid_keywords
|> String.split_on_char '_'
|> List.map String.capitalize_ascii
@ -164,24 +111,25 @@ let format_to_module_name
let format_struct_field_name
(fmt : Format.formatter)
((sname_opt, v) : StructName.t option * StructFieldName.t) : unit =
((sname_opt, v) : StructName.t option * StructField.t) : unit =
(match sname_opt with
| Some sname ->
Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname)
| None -> Format.fprintf fmt "%s")
(avoid_keywords
(to_ascii (Format.asprintf "%a" StructFieldName.format_t v)))
(String.to_ascii (Format.asprintf "%a" StructField.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(to_snake_case (to_ascii (Format.asprintf "%a" EnumName.format_t v))))
(String.to_snake_case
(String.to_ascii (Format.asprintf "%a" EnumName.format_t v))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit =
Format.fprintf fmt "%s"
(avoid_keywords
(to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
let rec typ_embedding_name (fmt : Format.formatter) (ty : typ) : unit =
match Marked.unmark ty with
@ -225,25 +173,27 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
| TAny -> Format.fprintf fmt "_"
let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit =
let lowercase_name = to_snake_case (to_ascii (Bindlib.name_of v)) in
let lowercase_name =
String.to_snake_case (String.to_ascii (Bindlib.name_of v))
in
let lowercase_name =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.")
~subst:(fun _ -> "_dot_")
lowercase_name
in
let lowercase_name = avoid_keywords (to_ascii lowercase_name) in
let lowercase_name = avoid_keywords (String.to_ascii lowercase_name) in
if
List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|| begins_with_uppercase (Bindlib.name_of v)
then Format.fprintf fmt "%s" lowercase_name
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name
|| String.begins_with_uppercase (Bindlib.name_of v)
then Format.pp_print_string fmt lowercase_name
else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name
else (
Cli.debug_print "lowercase_name: %s " lowercase_name;
Format.fprintf fmt "%s_" lowercase_name)
let needs_parens (e : 'm expr) : bool =
match Marked.unmark e with
| EApp ((EAbs (_, _), _), _)
| EApp { f = EAbs _, _; _ }
| ELit (LBool _ | LUnit)
| EVar _ | ETuple _ | EOp _ ->
false
@ -279,56 +229,52 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
in
match Marked.unmark e with
| EVar v -> Format.fprintf fmt "%a" format_var v
| ETuple (es, None) ->
| ETuple es ->
Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_with_parens e))
es
| ETuple (es, Some s) ->
if List.length es = 0 then Format.fprintf fmt "()"
| EStruct { name = s; fields = es } ->
if StructField.Map.is_empty es then Format.fprintf fmt "()"
else
Format.fprintf fmt "{@[<hov 2>%a@]}"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt (e, struct_field) ->
(fun fmt (struct_field, e) ->
Format.fprintf fmt "@[<hov 2>%a =@ %a@]" format_struct_field_name
(Some s, struct_field) format_with_parens e))
(List.combine es (List.map fst (find_struct s ctx)))
(StructField.Map.bindings es)
| EArray es ->
Format.fprintf fmt "@[<hov 2>[|%a|]@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_with_parens e))
es
| ETupleAccess (e1, n, s, ts) -> (
match s with
| None ->
Format.fprintf fmt "let@ %a@ = %a@ in@ x"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt i -> Format.fprintf fmt "%s" (if i = n then "x" else "_")))
(List.mapi (fun i _ -> i) ts)
format_with_parens e1
| Some s ->
Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name
(Some s, fst (List.nth (find_struct s ctx) n)))
| EInj (e, n, en, _ts) ->
Format.fprintf fmt "@[<hov 2>%a.%a@ %a@]" format_to_module_name (`Ename en)
format_enum_cons_name
(fst (List.nth (find_enum en ctx) n))
format_with_parens e
| EMatch (e, es, e_name) ->
| ETupleAccess { e; index; size } ->
Format.fprintf fmt "let@ %a@ = %a@ in@ x"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt i ->
Format.pp_print_string fmt (if i = index then "x" else "_")))
(List.init size Fun.id) format_with_parens e
| EStructAccess { e; field; name } ->
Format.fprintf fmt "%a.%a" format_with_parens e format_struct_field_name
(Some name, field)
| EInj { e; cons; name } ->
Format.fprintf fmt "@[<hov 2>%a.%a@ %a@]" format_to_module_name
(`Ename name) format_enum_cons_name cons format_with_parens e
| EMatch { e; cases; name } ->
Format.fprintf fmt "@[<hv>@[<hov 2>match@ %a@]@ with@\n| %a@]"
format_with_parens e
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ | ")
(fun fmt (e, c) ->
(fun fmt (c, e) ->
Format.fprintf fmt "@[<hov 2>%a.%a %a@]" format_to_module_name
(`Ename e_name) format_enum_cons_name c
(`Ename name) format_enum_cons_name c
(fun fmt e ->
match Marked.unmark e with
| EAbs (binder, _) ->
| EAbs { binder; _ } ->
let xs, body = Bindlib.unmbind binder in
Format.fprintf fmt "%a ->@ %a"
(Format.pp_print_list
@ -338,11 +284,11 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
| _ -> assert false
(* should not happen *))
e))
(List.combine es (List.map fst (find_enum e_name ctx)))
(EnumConstructor.Map.bindings cases)
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l)
| EApp ((EAbs (binder, taus), _), args) ->
| EApp { f = EAbs { binder; tys }, _; args } ->
let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
Format.fprintf fmt "(%a%a)"
(Format.pp_print_list
@ -351,30 +297,28 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
Format.fprintf fmt "@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n"
format_var x format_typ tau format_with_parens arg))
xs_tau_arg format_with_parens body
| EAbs (binder, taus) ->
| EAbs { binder; tys } ->
let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) tys in
Format.fprintf fmt "@[<hov 2>fun@ %a ->@ %a@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (x, tau) ->
Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau))
xs_tau format_expr body
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos)
format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
| EApp
{
f = EApp { f = EOp { op = Log (BeginCall, info); _ }, _; args = [f] }, _;
args = [arg];
}
when !Cli.trace_flag ->
Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info
format_with_parens f format_with_parens arg
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag
->
| EApp { f = EOp { op = Log (VarDef tau, info); _ }, _; args = [arg1] }
when !Cli.trace_flag ->
Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list
info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), m), [arg1])
| EApp { f = EOp { op = Log (PosRecordIfTrueBool, _); _ }, m; args = [arg1] }
when !Cli.trace_flag ->
let pos = Expr.mark_pos m in
Format.fprintf fmt
@ -383,15 +327,13 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) format_with_parens arg1
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag ->
| EApp { f = EOp { op = Log (EndCall, info); _ }, _; args = [arg1] }
when !Cli.trace_flag ->
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info
format_with_parens arg1
| EApp ((EOp (Unop (Log _)), _), [arg1]) ->
| EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp ((EVar x, pos), args)
| EApp { f = EVar x, pos; args }
when Var.compare x (Var.translate Ast.handle_default) = 0
|| Var.compare x (Var.translate Ast.handle_default_opt) = 0 ->
Format.fprintf fmt
@ -409,19 +351,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args
| EApp (f, args) ->
| EApp { f; args } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args
| EIfThenElse (e1, e2, e3) ->
| EIfThenElse { cond; etrue; efalse } ->
Format.fprintf fmt
"@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]@]"
format_with_parens e1 format_with_parens e2 format_with_parens e3
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
format_with_parens cond format_with_parens etrue format_with_parens efalse
| EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op)
| EAssert e' ->
Format.fprintf fmt
"@[<hov 2>if@ %a@ then@ ()@ else@ raise (AssertionFailed @[<hov \
@ -437,18 +377,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
(Pos.get_law_info (Expr.pos e'))
| ERaise exc ->
Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e)
| ECatch (e1, exc, e2) ->
| ECatch { body; exn; handler } ->
Format.fprintf fmt
"@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]"
format_with_parens e1 format_exception
(exc, Expr.pos e)
format_with_parens e2
format_with_parens body format_exception
(exn, Expr.pos e)
format_with_parens handler
let format_struct_embedding
(fmt : Format.formatter)
((struct_name, struct_fields) :
StructName.t * (StructFieldName.t * typ) list) =
if List.length struct_fields = 0 then
((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) =
if StructField.Map.is_empty struct_fields then
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
format_struct_name struct_name format_to_module_name (`Sname struct_name)
else
@ -461,16 +400,16 @@ let format_struct_embedding
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n")
(fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructFieldName.format_t
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructField.format_t
struct_field typ_embedding_name struct_field_type
format_struct_field_name
(Some struct_name, struct_field)))
struct_fields
(StructField.Map.bindings struct_fields)
let format_enum_embedding
(fmt : Format.formatter)
((enum_name, enum_cases) : EnumName.t * (EnumConstructor.t * typ) list) =
if List.length enum_cases = 0 then
((enum_name, enum_cases) : EnumName.t * typ EnumConstructor.Map.t) =
if EnumConstructor.Map.is_empty enum_cases then
Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n"
format_to_module_name (`Ename enum_name) format_enum_name enum_name
else
@ -486,14 +425,14 @@ let format_enum_embedding
Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]"
format_enum_cons_name enum_cons EnumConstructor.format_t enum_cons
typ_embedding_name enum_cons_type))
enum_cases
(EnumConstructor.Map.bindings enum_cases)
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
if List.length struct_fields = 0 then
if StructField.Map.is_empty struct_fields then
Format.fprintf fmt
"@[<v 2>module %a = struct@\n@[<hov 2>type t = unit@]@]@\nend@\n"
format_to_module_name (`Sname struct_name)
@ -508,7 +447,7 @@ let format_ctx
(fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "@[<hov 2>%a:@ %a@]" format_struct_field_name
(None, struct_field) format_typ struct_field_type))
struct_fields;
(StructField.Map.bindings struct_fields);
if !Cli.trace_flag then
format_struct_embedding fmt (struct_name, struct_fields)
in
@ -521,7 +460,7 @@ let format_ctx
(fun _fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "@[<hov 2>| %a@ of@ %a@]" format_enum_cons_name
enum_cons format_typ enum_cons_type))
enum_cons;
(EnumConstructor.Map.bindings enum_cons);
if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons)
in
let is_in_type_ordering s =
@ -535,8 +474,8 @@ let format_ctx
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructMap.bindings
(StructMap.filter
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
in

View File

@ -14,29 +14,28 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Shared_ast
open Ast
(** Formats a lambda calculus program into a valid OCaml program *)
val avoid_keywords : string -> string
val find_struct : StructName.t -> decl_ctx -> (StructFieldName.t * typ) list
val find_enum : EnumName.t -> decl_ctx -> (EnumConstructor.t * typ) list
val find_struct : StructName.t -> decl_ctx -> typ StructField.Map.t
val find_enum : EnumName.t -> decl_ctx -> typ EnumConstructor.Map.t
val typ_needs_parens : typ -> bool
val needs_parens : 'm expr -> bool
(* val needs_parens : 'm expr -> bool *)
val format_enum_name : Format.formatter -> EnumName.t -> unit
val format_enum_cons_name : Format.formatter -> EnumConstructor.t -> unit
val format_struct_name : Format.formatter -> StructName.t -> unit
val format_struct_field_name :
Format.formatter -> StructName.t option * StructFieldName.t -> unit
Format.formatter -> StructName.t option * StructField.t -> unit
val format_to_module_name :
Format.formatter -> [< `Ename of EnumName.t | `Sname of StructName.t ] -> unit
(* * val format_lit : Format.formatter -> lit Marked.pos -> unit * val
format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit *)
val format_lit : Format.formatter -> lit Marked.pos -> unit
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
val format_var : Format.formatter -> 'm Var.t -> unit
val format_program :

View File

@ -1,7 +1,7 @@
(library
(name literate)
(public_name catala.literate)
(libraries re utils surface ubase))
(libraries re catala_utils surface ubase uutf))
(documentation
(package catala)

View File

@ -18,7 +18,7 @@
(** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *)
open Utils
open Catala_utils
open Literate_common
module A = Surface.Ast
module P = Printf
@ -91,7 +91,7 @@ let wrap_html
</ul>\n"
css_as_string (literal_title language)
(literal_generated_by language)
Utils.Cli.version
Cli.version
(pre_html (literal_disclaimer_and_link language))
(literal_source_files language)
(String.concat "\n"
@ -133,7 +133,7 @@ let pygmentize_code (c : string Marked.pos) (language : C.backend_lang) : string
"html";
"-O";
"style=colorful,anchorlinenos=True,lineanchors=\""
^ String_common.to_ascii (Pos.get_file (Marked.get_mark c))
^ String.to_ascii (Pos.get_file (Marked.get_mark c))
^ "\",linenos=table,linenostart="
^ string_of_int (Pos.get_start_line (Marked.get_mark c));
"-o";
@ -160,7 +160,7 @@ let pygmentize_code (c : string Marked.pos) (language : C.backend_lang) : string
let sanitize_html_href str =
str
|> String_common.to_ascii
|> String.to_ascii
|> R.substitute ~rex:(R.regexp "[' '°\"]") ~subst:(function _ -> "%20")
let rec law_structure_to_html

View File

@ -17,7 +17,7 @@
(** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *)
open Utils
open Catala_utils
(** {1 Helpers} *)

View File

@ -18,7 +18,7 @@
(** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *)
open Utils
open Catala_utils
open Literate_common
module A = Surface.Ast
module R = Re.Pcre
@ -61,7 +61,7 @@ let wrap_latex
%s
\usepackage{minted}
\usepackage{longtable}
\usepackage{booktabs}
\usepackage{booktabs,tabularx}
\usepackage{newunicodechar}
\usepackage{textcomp}
\usepackage[hidelinks]{hyperref}
@ -122,8 +122,8 @@ let wrap_latex
\newunicodechar{}{$\rightarrow$}
\newunicodechar{}{$\neq$}
\newcommand*\FancyVerbStartString{```catala}
\newcommand*\FancyVerbStopString{```}
\newcommand*\FancyVerbStartString{\PYG{l+s}{```catala}}
\newcommand*\FancyVerbStopString{\PYG{l+s}{```}}
\fvset{
numbers=left,
@ -151,14 +151,15 @@ codes={\catcode`\$=3\catcode`\^=7}
\tableofcontents
\[\star\star\star\]
\clearpage|latex}
\clearpage
|latex}
(match language with Fr -> "french" | En -> "english" | Pl -> "polish")
(match language with Fr -> "\\setmainfont{Marianne}" | _ -> "")
(* for France, we use the official font of the French state design system
https://gouvfr.atlassian.net/wiki/spaces/DB/pages/223019527/Typographie+-+Typography *)
(literal_title language)
(literal_generated_by language)
Utils.Cli.version
Cli.version
(pre_latexify (literal_disclaimer_and_link language))
(literal_source_files language)
(String.concat
@ -243,7 +244,7 @@ let rec law_structure_to_latex
| En -> "Metadata"
| Pl -> "Metadane"
in
let start_line = Pos.get_start_line (Marked.get_mark c) - 1 in
let start_line = Pos.get_start_line (Marked.get_mark c) + 1 in
let filename = Filename.basename (Pos.get_file (Marked.get_mark c)) in
let block_content = Marked.unmark c in
check_exceeding_lines start_line filename block_content;
@ -252,7 +253,7 @@ let rec law_structure_to_latex
"\\begin{tcolorbox}[colframe=OliveGreen, breakable, \
title=\\textcolor{black}{\\texttt{%s}},title after \
break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\
\\begin{minted}[numbersep=9mm, firstnumber=%d, breaklines, \
\\begin{minted}[numbersep=9mm, firstnumber=%d, \
label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\
```catala\n\
%s```\n\

View File

@ -17,7 +17,7 @@
(** This modules weaves the source code and the legislative text together into a
document that law professionals can understand. *)
open Utils
open Catala_utils
(** {1 Helpers} *)

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Cli
let literal_title = function

View File

@ -14,32 +14,30 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
val literal_title : Cli.backend_lang -> string
(** Return the title traduction according the given
{!type:Utils.Cli.backend_lang}. *)
(** Return the title traduction according the given {!type:Cli.backend_lang}. *)
val literal_generated_by : Cli.backend_lang -> string
(** Return the 'generated by' traduction according the given
{!type:Utils.Cli.backend_lang}. *)
{!type:Cli.backend_lang}. *)
val literal_source_files : Cli.backend_lang -> string
(** Return the 'source files weaved' traduction according the given
{!type:Utils.Cli.backend_lang}. *)
{!type:Cli.backend_lang}. *)
val literal_disclaimer_and_link : Cli.backend_lang -> string
(** Return the traduction of a paragraph giving a basic disclaimer about Catala
and a link to the website according the given {!type:
Utils.Cli.backend_lang}. *)
and a link to the website according the given {!type: Cli.backend_lang}. *)
val literal_last_modification : Cli.backend_lang -> string
(** Return the 'last modification' traduction according the given
{!type:Utils.Cli.backend_lang}. *)
{!type:Cli.backend_lang}. *)
val get_language_extension : Cli.backend_lang -> string
(** Return the file extension corresponding to the given
{!type:Utils.Cli.backend_lang}. *)
{!type:Cli.backend_lang}. *)
val run_pandoc : string -> [ `Html | `Latex ] -> string
(** Runs the [pandoc] on a string to pretty-print markdown features into the

View File

@ -14,8 +14,10 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
type 'ast plugin_apply_fun_typ =
source_file:Utils.Pos.input_file ->
source_file:Pos.input_file ->
output_file:string option ->
scope:string option ->
'ast ->
@ -51,17 +53,21 @@ let find name = Hashtbl.find backend_plugins (String.lowercase_ascii name)
let load_file f =
try
Dynlink.loadfile f;
Utils.Cli.debug_print "Plugin %S loaded" f
Cli.debug_print "Plugin %S loaded" f
with e ->
Utils.Errors.format_warning "Could not load plugin %S: %s" f
Errors.format_warning "Could not load plugin %S: %s" f
(Printexc.to_string e)
let load_dir d =
let rec load_dir d =
let dynlink_exts =
if Dynlink.is_native then [".cmxs"] else [".cmo"; ".cma"]
in
Array.iter
(fun f ->
if List.exists (Filename.check_suffix f) dynlink_exts then
load_file (Filename.concat d f))
if f.[0] = '.' then ()
else
let f = Filename.concat d f in
if Sys.is_directory f then load_dir f
else if List.exists (Filename.check_suffix f) dynlink_exts then
load_file f)
(Sys.readdir d)

View File

@ -16,8 +16,10 @@
(** {2 catala-facing API} *)
open Catala_utils
type 'ast plugin_apply_fun_typ =
source_file:Utils.Pos.input_file ->
source_file:Pos.input_file ->
output_file:string option ->
scope:string option ->
'ast ->

View File

@ -18,9 +18,8 @@
(** Catala plugin for generating web APIs. It generates OCaml code before the
the associated [js_of_ocaml] wrapper. *)
open Utils
open Catala_utils
open Shared_ast
open String_common
open Lcalc
open Lcalc.Ast
open Lcalc.To_ocaml
@ -40,11 +39,11 @@ module To_jsoo = struct
let format_struct_field_name_camel_case
(fmt : Format.formatter)
(v : StructFieldName.t) : unit =
(v : StructField.t) : unit =
let s =
Format.asprintf "%a" StructFieldName.format_t v
|> to_ascii
|> to_snake_case
Format.asprintf "%a" StructField.format_t v
|> String.to_ascii
|> String.to_snake_case
|> avoid_keywords
|> to_camel_case
in
@ -118,17 +117,17 @@ module To_jsoo = struct
let format_var_camel_case (fmt : Format.formatter) (v : 'm Var.t) : unit =
let lowercase_name =
Bindlib.name_of v
|> to_ascii
|> to_snake_case
|> String.to_ascii
|> String.to_snake_case
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ ->
"_dot_")
|> to_ascii
|> String.to_ascii
|> avoid_keywords
|> to_camel_case
in
if
List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|| begins_with_uppercase (Bindlib.name_of v)
|| String.begins_with_uppercase (Bindlib.name_of v)
then Format.fprintf fmt "%s" lowercase_name
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name
else Format.fprintf fmt "%s_" lowercase_name
@ -166,7 +165,7 @@ module To_jsoo = struct
format_struct_field_name_camel_case struct_field
format_typ_to_jsoo struct_field_type fmt_struct_name ()
format_struct_field_name (None, struct_field)))
struct_fields
(StructField.Map.bindings struct_fields)
in
let fmt_of_jsoo fmt _ =
Format.fprintf fmt "%a"
@ -186,7 +185,7 @@ module To_jsoo = struct
format_struct_field_name (None, struct_field)
format_typ_of_jsoo struct_field_type fmt_struct_name ()
format_struct_field_name_camel_case struct_field))
struct_fields
(StructField.Map.bindings struct_fields)
in
let fmt_conv_funs fmt _ =
Format.fprintf fmt
@ -203,7 +202,7 @@ module To_jsoo = struct
() fmt_struct_name () fmt_module_struct_name () fmt_of_jsoo ()
in
if List.length struct_fields = 0 then
if StructField.Map.is_empty struct_fields then
Format.fprintf fmt
"class type %a =@ object end@\n\
let %a_to_jsoo (_ : %a.t) : %a Js.t = object%%js end@\n\
@ -220,11 +219,11 @@ module To_jsoo = struct
Format.fprintf fmt "@[<hov 2>method %a:@ %a %a@]"
format_struct_field_name_camel_case struct_field format_typ
struct_field_type format_prop_or_meth struct_field_type))
struct_fields fmt_conv_funs ()
(StructField.Map.bindings struct_fields)
fmt_conv_funs ()
in
let format_enum_decl
fmt
(enum_name, (enum_cons : (EnumConstructor.t * typ) list)) =
let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructor.Map.t))
=
let fmt_enum_name fmt _ = format_enum_name fmt enum_name in
let fmt_module_enum_name fmt _ =
To_ocaml.format_to_module_name fmt (`Ename enum_name)
@ -247,7 +246,7 @@ module To_jsoo = struct
end@]"
format_enum_cons_name cname format_enum_cons_name cname
format_typ_to_jsoo typ))
enum_cons
(EnumConstructor.Map.bindings enum_cons)
in
let fmt_of_jsoo fmt _ =
Format.fprintf fmt
@ -273,7 +272,8 @@ module To_jsoo = struct
format_enum_cons_name cname fmt_module_enum_name ()
format_enum_cons_name cname format_typ_of_jsoo typ
fmt_enum_name ()))
enum_cons fmt_module_enum_name ()
(EnumConstructor.Map.bindings enum_cons)
fmt_module_enum_name ()
in
let fmt_conv_funs fmt _ =
@ -301,7 +301,8 @@ module To_jsoo = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (enum_cons, _) ->
Format.fprintf fmt "- \"%a\"" format_enum_cons_name enum_cons))
enum_cons fmt_conv_funs ()
(EnumConstructor.Map.bindings enum_cons)
fmt_conv_funs ()
in
let is_in_type_ordering s =
List.exists
@ -314,8 +315,8 @@ module To_jsoo = struct
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructMap.bindings
(StructMap.filter
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
in

View File

@ -1,18 +1,22 @@
(executable
(library
(name python)
(modes plugin)
(public_name catala.plugins.python)
(synopsis
"Demonstration Catala plugin that reproduces the behaviour of the built-in python backend")
(modules python)
(libraries catala.driver))
(executable
(library
(name api_web)
(modes plugin)
(public_name catala.plugins.api_web)
(synopsis "Catala plugin for interaction with a web interface")
(modules api_web)
(libraries catala.driver))
(executable
(library
(name json_schema)
(modes plugin)
(public_name catala.plugins.json_schema)
(synopsis "Catala plugin generating JSON schemas useful to build web-forms")
(modules json_schema)
(libraries catala.driver))

View File

@ -20,8 +20,7 @@
let name = "json_schema"
let extension = "_schema.json"
open Utils
open String_common
open Catala_utils
open Shared_ast
open Lcalc.Ast
open Lcalc.To_ocaml
@ -38,11 +37,11 @@ module To_json = struct
let format_struct_field_name_camel_case
(fmt : Format.formatter)
(v : StructFieldName.t) : unit =
(v : StructField.t) : unit =
let s =
Format.asprintf "%a" StructFieldName.format_t v
|> to_ascii
|> to_snake_case
Format.asprintf "%a" StructField.format_t v
|> String.to_ascii
|> String.to_snake_case
|> avoid_keywords
|> to_camel_case
in
@ -97,7 +96,7 @@ module To_json = struct
(fun fmt (field_name, field_type) ->
Format.fprintf fmt "@[<hov 2>\"%a\": {@\n%a@]@\n}"
format_struct_field_name_camel_case field_name fmt_type field_type))
(find_struct sname ctx)
(StructField.Map.bindings (find_struct sname ctx))
let fmt_definitions
(ctx : decl_ctx)
@ -118,11 +117,14 @@ module To_json = struct
(t :: acc) @ collect_required_type_defs_from_scope_input s
| TEnum e ->
List.fold_left collect (t :: acc)
(List.map snd (EnumMap.find e ctx.ctx_enums))
(List.map snd
(EnumConstructor.Map.bindings
(EnumName.Map.find e ctx.ctx_enums)))
| TArray t -> collect acc t
| _ -> acc
in
find_struct input_struct ctx
|> StructField.Map.bindings
|> List.fold_left (fun acc (_, field_typ) -> collect acc field_typ) []
|> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t'))
in
@ -146,7 +148,7 @@ module To_json = struct
Format.fprintf fmt
"@[<hov 2>{@\n\"type\": \"string\",@\n\"enum\": [\"%a\"]@]@\n}"
format_enum_cons_name enum_cons))
enum_def
(EnumConstructor.Map.bindings enum_def)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n")
(fun fmt (enum_cons, payload_type) ->
@ -168,7 +170,7 @@ module To_json = struct
}@]@\n\
}"
format_enum_cons_name enum_cons fmt_type payload_type))
enum_def
(EnumConstructor.Map.bindings enum_def)
in
Format.fprintf fmt "@\n%a"

View File

@ -20,13 +20,15 @@
The code for the Python backend already has first-class support, so there
would be no reason to use this plugin instead *)
open Catala_utils
let name = "python-plugin"
let extension = ".py"
let apply ~source_file ~output_file ~scope prgm type_ordering =
ignore source_file;
ignore scope;
Utils.File.with_formatter_of_opt_file output_file
File.with_formatter_of_opt_file output_file
@@ fun fmt -> Scalc.To_python.format_program fmt prgm type_ordering
let () = Driver.Plugin.register_scalc ~name ~extension apply

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
module D = Dcalc.Ast
module L = Lcalc.Ast
@ -28,15 +28,15 @@ let handle_default_opt = TopLevelName.fresh ("handle_default_opt", Pos.no_pos)
type expr = naked_expr Marked.pos
and naked_expr =
| EVar of LocalName.t
| EFunc of TopLevelName.t
| EStruct of expr list * StructName.t
| EStructFieldAccess of expr * StructFieldName.t * StructName.t
| EInj of expr * EnumConstructor.t * EnumName.t
| EArray of expr list
| ELit of L.lit
| EApp of expr * expr list
| EOp of operator
| EVar : LocalName.t -> naked_expr
| EFunc : TopLevelName.t -> naked_expr
| EStruct : expr list * StructName.t -> naked_expr
| EStructFieldAccess : expr * StructField.t * StructName.t -> naked_expr
| EInj : expr * EnumConstructor.t * EnumName.t -> naked_expr
| EArray : expr list -> naked_expr
| ELit : L.lit -> naked_expr
| EApp : expr * expr list -> naked_expr
| EOp : (lcalc, _) operator -> naked_expr
type stmt =
| SInnerFuncDef of LocalName.t Marked.pos * func

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
module A = Ast
module L = Lcalc.Ast
@ -35,36 +35,37 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
| EVar v ->
let local_var =
try A.EVar (Var.Map.find v ctxt.var_dict)
with Not_found -> A.EFunc (Var.Map.find v ctxt.func_dict)
with Not_found -> (
try A.EFunc (Var.Map.find v ctxt.func_dict)
with Not_found ->
Errors.raise_spanned_error (Expr.pos expr)
"Var not found in lambda→scalc: %a@\nknown: @[<hov>%a@]@\n"
Print.var_debug v
(Format.pp_print_list ~pp_sep:Format.pp_print_space
(fun ppf (v, _) -> Print.var_debug ppf v))
(Var.Map.bindings ctxt.var_dict))
in
[], (local_var, Expr.pos expr)
| ETuple (args, Some s_name) ->
| EStruct { fields; name } ->
let args_stmts, new_args =
List.fold_left
(fun (args_stmts, new_args) arg ->
StructField.Map.fold
(fun _ arg (args_stmts, new_args) ->
let arg_stmts, new_arg = translate_expr ctxt arg in
arg_stmts @ args_stmts, new_arg :: new_args)
([], []) args
fields ([], [])
in
let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in
args_stmts, (A.EStruct (new_args, s_name), Expr.pos expr)
| ETuple (_, None) -> failwith "Non-struct tuples cannot be compiled to scalc"
| ETupleAccess (e1, num_field, Some s_name, _) ->
args_stmts, (A.EStruct (new_args, name), Expr.pos expr)
| ETuple _ -> failwith "Tuples cannot be compiled to scalc"
| EStructAccess { e = e1; field; name } ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in
let field_name =
fst (List.nth (StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field)
in
e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), Expr.pos expr)
| ETupleAccess (_, _, None, _) ->
failwith "Non-struct tuples cannot be compiled to scalc"
| EInj (e1, num_cons, e_name, _) ->
e1_stmts, (A.EStructFieldAccess (new_e1, field, name), Expr.pos expr)
| ETupleAccess _ -> failwith "Non-struct tuples cannot be compiled to scalc"
| EInj { e = e1; cons; name } ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in
let cons_name =
fst (List.nth (EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons)
in
e1_stmts, (A.EInj (new_e1, cons_name, e_name), Expr.pos expr)
| EApp (f, args) ->
e1_stmts, (A.EInj (new_e1, cons, name), Expr.pos expr)
| EApp { f; args } ->
let f_stmts, new_f = translate_expr ctxt f in
let args_stmts, new_args =
List.fold_left
@ -85,7 +86,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
in
let new_args = List.rev new_args in
args_stmts, (A.EArray new_args, Expr.pos expr)
| EOp op -> [], (A.EOp op, Expr.pos expr)
| EOp { op; _ } -> [], (A.EOp op, Expr.pos expr)
| ELit l -> [], (A.ELit l, Expr.pos expr)
| _ ->
let tmp_var =
@ -120,11 +121,11 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
(* Assertions are always encapsulated in a unit-typed let binding *)
let e_stmts, new_e = translate_expr ctxt e in
e_stmts @ [A.SAssert (Marked.unmark new_e), Expr.pos block_expr]
| EApp ((EAbs (binder, taus), binder_mark), args) ->
| EApp { f = EAbs { binder; tys }, binder_mark; args } ->
(* This defines multiple local variables at the time *)
let binder_pos = Expr.mark_pos binder_mark in
let vars, body = Bindlib.unmbind binder in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in
let ctxt =
{
ctxt with
@ -167,10 +168,10 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
in
let rest_of_block = translate_statements ctxt body in
local_decls @ List.flatten def_blocks @ rest_of_block
| EAbs (binder, taus) ->
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let binder_pos = Expr.pos block_expr in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in
let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) tys in
let closure_name =
match ctxt.inside_definition_of with
| None -> A.LocalName.fresh (ctxt.context_name, Expr.pos block_expr)
@ -203,13 +204,13 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
} ),
binder_pos );
]
| EMatch (e1, args, e_name) ->
| EMatch { e = e1; cases; name } ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in
let new_args =
List.fold_left
(fun new_args arg ->
let new_cases =
EnumConstructor.Map.fold
(fun _ arg new_args ->
match Marked.unmark arg with
| EAbs (binder, _) ->
| EAbs { binder; _ } ->
let vars, body = Bindlib.unmbind binder in
assert (Array.length vars = 1);
let var = vars.(0) in
@ -223,20 +224,20 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
(new_arg, scalc_var) :: new_args
| _ -> assert false
(* should not happen *))
[] args
cases []
in
let new_args = List.rev new_args in
e1_stmts @ [A.SSwitch (new_e1, e_name, new_args), Expr.pos block_expr]
| EIfThenElse (cond, e_true, e_false) ->
let new_args = List.rev new_cases in
e1_stmts @ [A.SSwitch (new_e1, name, new_args), Expr.pos block_expr]
| EIfThenElse { cond; etrue; efalse } ->
let cond_stmts, s_cond = translate_expr ctxt cond in
let s_e_true = translate_statements ctxt e_true in
let s_e_false = translate_statements ctxt e_false in
let s_e_true = translate_statements ctxt etrue in
let s_e_false = translate_statements ctxt efalse in
cond_stmts
@ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Expr.pos block_expr]
| ECatch (e_try, except, e_catch) ->
let s_e_try = translate_statements ctxt e_try in
let s_e_catch = translate_statements ctxt e_catch in
[A.STryExcept (s_e_try, except, s_e_catch), Expr.pos block_expr]
| ECatch { body; exn; handler } ->
let s_e_try = translate_statements ctxt body in
let s_e_catch = translate_statements ctxt handler in
[A.STryExcept (s_e_try, exn, s_e_catch), Expr.pos block_expr]
| ERaise except ->
(* Before raising the exception, we still give a dummy definition to the
current variable so that tools like mypy don't complain. *)

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
open Ast
@ -44,11 +44,12 @@ let rec format_expr
Print.punctuation "{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) ->
(fun fmt (e, (struct_field, _)) ->
Format.fprintf fmt "%a%a%a%a %a" Print.punctuation "\""
StructFieldName.format_t struct_field Print.punctuation "\""
StructField.format_t struct_field Print.punctuation "\""
Print.punctuation ":" format_expr e))
(List.combine es (List.map fst (StructMap.find s decl_ctx.ctx_structs)))
(List.combine es
(StructField.Map.bindings (StructName.Map.find s decl_ctx.ctx_structs)))
Print.punctuation "}"
| EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "["
@ -56,41 +57,31 @@ let rec format_expr
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es Print.punctuation "]"
| EStructFieldAccess (e1, field, s) ->
| EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Print.punctuation "."
Print.punctuation "\"" StructFieldName.format_t
(fst
(List.find
(fun (field', _) -> StructFieldName.compare field' field = 0)
(StructMap.find s decl_ctx.ctx_structs)))
Print.punctuation "\""
| EInj (e, case, enum) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor
(fst
(List.find
(fun (case', _) -> EnumConstructor.compare case' case = 0)
(EnumMap.find enum decl_ctx.ctx_enums)))
Print.punctuation "\"" StructField.format_t field Print.punctuation "\""
| EInj (e, cons, _) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor cons
format_expr e
| ELit l -> Print.lit fmt l
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.binop op format_with_parens
arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
| EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.operator op
format_with_parens arg1 format_with_parens arg2
| EApp ((EOp op, _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
Print.binop op format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug ->
Print.operator op format_with_parens arg2
| EApp ((EOp (Log _), _), [arg1]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.unop op format_with_parens arg1
| EApp ((EOp op, _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.operator op format_with_parens
arg1
| EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args
| EOp (Ternop op) -> Format.fprintf fmt "%a" Print.ternop op
| EOp (Binop op) -> Format.fprintf fmt "%a" Print.binop op
| EOp (Unop op) -> Format.fprintf fmt "%a" Print.unop op
| EOp op -> Format.fprintf fmt "%a" Print.operator op
let rec format_statement
(decl_ctx : decl_ctx)
@ -101,22 +92,22 @@ let rec format_statement
match Marked.unmark stmt with
| SInnerFuncDef (name, func) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Print.keyword
"let" LocalName.format_t (Marked.unmark name)
"let" format_local_name (Marked.unmark name)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "("
LocalName.format_t name Print.punctuation ":" (Print.typ decl_ctx)
format_local_name name Print.punctuation ":" (Print.typ decl_ctx)
typ Print.punctuation ")"))
func.func_params Print.punctuation "="
(format_block decl_ctx ~debug)
func.func_body
| SLocalDecl (name, typ) ->
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Print.keyword "decl"
LocalName.format_t (Marked.unmark name) Print.punctuation ":"
format_local_name (Marked.unmark name) Print.punctuation ":"
(Print.typ decl_ctx) typ
| SLocalDef (name, naked_expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" LocalName.format_t
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_local_name
(Marked.unmark name) Print.punctuation "="
(format_expr decl_ctx ~debug)
naked_expr
@ -156,10 +147,13 @@ let rec format_statement
(fun fmt ((case, _), (arm_block, payload_name)) ->
Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Print.punctuation
"|" Print.enum_constructor case Print.punctuation ":"
LocalName.format_t payload_name Print.punctuation ""
format_local_name payload_name Print.punctuation ""
(format_block decl_ctx ~debug)
arm_block))
(List.combine (EnumMap.find enum decl_ctx.ctx_enums) arms)
(List.combine
(EnumConstructor.Map.bindings
(EnumName.Map.find enum decl_ctx.ctx_enums))
arms)
and format_block
(decl_ctx : decl_ctx)
@ -183,8 +177,8 @@ let format_scope
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Print.punctuation "("
LocalName.format_t name Print.punctuation ":" (Print.typ decl_ctx)
typ Print.punctuation ")"))
format_local_name name Print.punctuation ":" (Print.typ decl_ctx) typ
Print.punctuation ")"))
body.scope_body_func.func_params Print.punctuation "="
(format_block decl_ctx ~debug)
body.scope_body_func.func_body

View File

@ -15,21 +15,20 @@
the License. *)
[@@@warning "-32-27"]
open Utils
open Catala_utils
open Shared_ast
open Ast
open String_common
module Runtime = Runtime_ocaml.Runtime
module D = Dcalc.Ast
module L = Lcalc.Ast
let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
match Marked.unmark l with
| LBool true -> Format.fprintf fmt "True"
| LBool false -> Format.fprintf fmt "False"
| LBool true -> Format.pp_print_string fmt "True"
| LBool false -> Format.pp_print_string fmt "False"
| LInt i ->
Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i)
| LUnit -> Format.fprintf fmt "Unit()"
| LUnit -> Format.pp_print_string fmt "Unit()"
| LRat i -> Format.fprintf fmt "decimal_of_string(\"%a\")" Print.lit (LRat i)
| LMoney e ->
Format.fprintf fmt "money_of_cents_string(\"%s\")"
@ -45,31 +44,60 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
match entry with
| VarDef _ -> Format.fprintf fmt ":="
| BeginCall -> Format.fprintf fmt ""
| VarDef _ -> Format.pp_print_string fmt ":="
| BeginCall -> Format.pp_print_string fmt ""
| EndCall -> Format.fprintf fmt "%s" ""
| PosRecordIfTrueBool -> Format.fprintf fmt ""
| PosRecordIfTrueBool -> Format.pp_print_string fmt ""
let format_binop (fmt : Format.formatter) (op : binop Marked.pos) : unit =
let format_op
(type k)
(fmt : Format.formatter)
(op : (lcalc, k) operator Marked.pos) : unit =
match Marked.unmark op with
| Add _ | Concat -> Format.fprintf fmt "+"
| Sub _ -> Format.fprintf fmt "-"
| Mult _ -> Format.fprintf fmt "*"
| Div KInt -> Format.fprintf fmt "//"
| Div _ -> Format.fprintf fmt "/"
| And -> Format.fprintf fmt "and"
| Or -> Format.fprintf fmt "or"
| Eq -> Format.fprintf fmt "=="
| Neq | Xor -> Format.fprintf fmt "!="
| Lt _ -> Format.fprintf fmt "<"
| Lte _ -> Format.fprintf fmt "<="
| Gt _ -> Format.fprintf fmt ">"
| Gte _ -> Format.fprintf fmt ">="
| Map -> Format.fprintf fmt "list_map"
| Filter -> Format.fprintf fmt "list_filter"
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left"
| Log (entry, infos) -> assert false
| Minus_int | Minus_rat | Minus_mon | Minus_dur ->
Format.pp_print_string fmt "-"
(* Todo: use the names from [Operator.name] *)
| Not -> Format.pp_print_string fmt "not"
| Length -> Format.pp_print_string fmt "list_length"
| ToRat_int -> Format.pp_print_string fmt "decimal_of_integer"
| ToRat_mon -> Format.pp_print_string fmt "decimal_of_money"
| ToMoney_rat -> Format.pp_print_string fmt "money_of_decimal"
| GetDay -> Format.pp_print_string fmt "day_of_month_of_date"
| GetMonth -> Format.pp_print_string fmt "month_number_of_date"
| GetYear -> Format.pp_print_string fmt "year_of_date"
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
| Round_mon -> Format.pp_print_string fmt "money_round"
| Round_rat -> Format.pp_print_string fmt "decimal_round"
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur | Concat
->
Format.pp_print_string fmt "+"
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
| Sub_dur_dur ->
Format.pp_print_string fmt "-"
| Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int ->
Format.pp_print_string fmt "*"
| Div_int_int -> Format.pp_print_string fmt "//"
| Div_rat_rat | Div_mon_mon | Div_mon_rat -> Format.pp_print_string fmt "/"
| And -> Format.pp_print_string fmt "and"
| Or -> Format.pp_print_string fmt "or"
| Eq -> Format.pp_print_string fmt "=="
| Xor -> Format.pp_print_string fmt "!="
| Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur ->
Format.pp_print_string fmt "<"
| Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur ->
Format.pp_print_string fmt "<="
| Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur ->
Format.pp_print_string fmt ">"
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur ->
Format.pp_print_string fmt ">="
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
Format.pp_print_string fmt "=="
| Map -> Format.pp_print_string fmt "list_map"
| Reduce -> Format.pp_print_string fmt "list_reduce"
| Filter -> Format.pp_print_string fmt "list_filter"
| Fold -> Format.pp_print_string fmt "list_fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit =
@ -77,7 +105,7 @@ let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt info ->
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info))
Format.fprintf fmt "\"%a\"" Uid.MarkedString.format info))
uids
let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -90,23 +118,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids
let format_unop (fmt : Format.formatter) (op : unop Marked.pos) : unit =
match Marked.unmark op with
| Minus _ -> Format.fprintf fmt "-"
| Not -> Format.fprintf fmt "not"
| Log (entry, infos) -> assert false (* should not happen *)
| Length -> Format.fprintf fmt "%s" "list_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
let avoid_keywords (s : string) : string =
if
match s with
@ -125,24 +136,26 @@ let avoid_keywords (s : string) : string =
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(to_camel_case (to_ascii (Format.asprintf "%a" StructName.format_t v))))
(String.to_camel_case
(String.to_ascii (Format.asprintf "%a" StructName.format_t v))))
let format_struct_field_name (fmt : Format.formatter) (v : StructFieldName.t) :
unit =
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
=
Format.fprintf fmt "%s"
(avoid_keywords
(to_ascii (Format.asprintf "%a" StructFieldName.format_t v)))
(String.to_ascii (Format.asprintf "%a" StructField.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(to_camel_case (to_ascii (Format.asprintf "%a" EnumName.format_t v))))
(String.to_camel_case
(String.to_ascii (Format.asprintf "%a" EnumName.format_t v))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit =
Format.fprintf fmt "%s"
(avoid_keywords
(to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format_t v)))
let typ_needs_parens (e : typ) : bool =
match Marked.unmark e with TArrow _ | TArray _ -> true | _ -> false
@ -180,10 +193,10 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
s
|> to_ascii
|> to_snake_case
|> String.to_ascii
|> String.to_snake_case
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_")
|> to_ascii
|> String.to_ascii
|> avoid_keywords
|> Format.fprintf fmt "%s"
@ -268,10 +281,11 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
Format.fprintf fmt "%a(%a)" format_struct_name s
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) ->
(fun fmt (e, (struct_field, _)) ->
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_expression ctx) e))
(List.combine es (List.map fst (StructMap.find s ctx.ctx_structs)))
(List.combine es
(StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs)))
| EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1
format_struct_field_name field
@ -296,21 +310,20 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
es
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e)
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos)
| EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
Format.fprintf fmt "%a(%a,@ %a)" format_op (op, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop
| EApp ((EOp op, _), [arg1; arg2]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_op
(op, Pos.no_pos) (format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
| EApp ((EApp ((EOp (Log (BeginCall, info)), _), [f]), _), [arg])
when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
(format_expression ctx) f (format_expression ctx) arg
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag
->
| EApp ((EOp (Log (VarDef tau, info)), _), [arg1]) when !Cli.trace_flag ->
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), pos), [arg1])
| EApp ((EOp (Log (PosRecordIfTrueBool, _)), pos), [arg1])
when !Cli.trace_flag ->
Format.fprintf fmt
"log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \
@ -318,16 +331,21 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) (format_expression ctx) arg1
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag ->
| EApp ((EOp (Log (EndCall, info)), _), [arg1]) when !Cli.trace_flag ->
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1
| EApp ((EOp (Unop (Log _)), _), [arg1]) ->
| EApp ((EOp (Log _), _), [arg1]) ->
Format.fprintf fmt "%a" (format_expression ctx) arg1
| EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) ->
Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos)
| EApp ((EOp Not, _), [arg1]) ->
Format.fprintf fmt "%a %a" format_op (Not, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos)
| EApp
((EOp ((Minus_int | Minus_rat | Minus_mon | Minus_dur) as op), _), [arg1])
->
Format.fprintf fmt "%a %a" format_op (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EOp op, _), [arg1]) ->
Format.fprintf fmt "%a(%a)" format_op (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EFunc x, pos), args)
when Ast.TopLevelName.compare x Ast.handle_default = 0
@ -348,9 +366,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
let rec format_statement
(ctx : decl_ctx)
@ -400,7 +416,7 @@ let rec format_statement
List.map2
(fun (x, y) (cons, _) -> x, y, cons)
cases
(EnumMap.find e_name ctx.ctx_enums)
(EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
in
let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var
@ -442,6 +458,7 @@ let format_ctx
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
let fields = StructField.Map.bindings struct_fields in
Format.fprintf fmt
"class %a:@\n\
\ def __init__(self, %a) -> None:@\n\
@ -461,40 +478,41 @@ let format_ctx
struct_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun _fmt (struct_field, struct_field_type) ->
(fun fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "%a: %a" format_struct_field_name struct_field
format_typ struct_field_type))
struct_fields
(if List.length struct_fields = 0 then fun fmt _ ->
fields
(if StructField.Map.is_empty struct_fields then fun fmt _ ->
Format.fprintf fmt " pass"
else
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (struct_field, _) ->
(fun fmt (struct_field, _) ->
Format.fprintf fmt " self.%a = %a" format_struct_field_name
struct_field format_struct_field_name struct_field))
struct_fields format_struct_name struct_name
(if List.length struct_fields > 0 then
fields format_struct_name struct_name
(if not (StructField.Map.is_empty struct_fields) then
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ")
(fun _fmt (struct_field, _) ->
(fun fmt (struct_field, _) ->
Format.fprintf fmt "self.%a == other.%a" format_struct_field_name
struct_field format_struct_field_name struct_field)
else fun fmt _ -> Format.fprintf fmt "True")
struct_fields format_struct_name struct_name
fields format_struct_name struct_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",")
(fun _fmt (struct_field, _) ->
(fun fmt (struct_field, _) ->
Format.fprintf fmt "%a={}" format_struct_field_name struct_field))
struct_fields
fields
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun _fmt (struct_field, _) ->
(fun fmt (struct_field, _) ->
Format.fprintf fmt "self.%a" format_struct_field_name struct_field))
struct_fields
fields
in
let format_enum_decl fmt (enum_name, enum_cons) =
if List.length enum_cons = 0 then failwith "no constructors in the enum"
if EnumConstructor.Map.is_empty enum_cons then
failwith "no constructors in the enum"
else
Format.fprintf fmt
"@[<hov 4>class %a_Code(Enum):@\n\
@ -522,9 +540,11 @@ let format_ctx
format_enum_name enum_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (i, enum_cons, enum_cons_type) ->
(fun fmt (i, enum_cons, enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi (fun i (x, y) -> i, x, y) enum_cons)
(List.mapi
(fun i (x, y) -> i, x, y)
(EnumConstructor.Map.bindings enum_cons))
format_enum_name enum_name format_enum_name enum_name format_enum_name
enum_name
in
@ -540,8 +560,8 @@ let format_ctx
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructMap.bindings
(StructMap.filter
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
in
@ -550,10 +570,10 @@ let format_ctx
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
Format.fprintf fmt "%a@\n@\n" format_struct_decl
(s, StructMap.find s ctx.ctx_structs)
(s, StructName.Map.find s ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
Format.fprintf fmt "%a@\n@\n" format_enum_decl
(e, EnumMap.find e ctx.ctx_enums))
(e, EnumName.Map.find e ctx.ctx_enums))
(type_ordering @ scope_structs)
let format_program

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
type location = scopelang glocation
@ -31,7 +31,7 @@ type 'm expr = (scopelang, 'm mark) gexpr
let rec locations_used (e : 'm expr) : LocationSet.t =
match e with
| ELocation l, pos -> LocationSet.singleton (l, Expr.mark_pos pos)
| EAbs (binder, _), _ ->
| EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in
locations_used body
| e ->
@ -39,23 +39,20 @@ let rec locations_used (e : 'm expr) : LocationSet.t =
(fun e -> LocationSet.union (locations_used e))
e LocationSet.empty
type io_input = NoInput | OnlyInput | Reentrant
type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos }
type 'm rule =
| Definition of location Marked.pos * typ * io * 'm expr
| Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr
| Assertion of 'm expr
| Call of ScopeName.t * SubScopeName.t * 'm mark
type 'm scope_decl = {
scope_decl_name : ScopeName.t;
scope_sig : (typ * io) ScopeVarMap.t;
scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t;
scope_decl_rules : 'm rule list;
scope_mark : 'm mark;
}
type 'm program = {
program_scopes : 'm scope_decl ScopeMap.t;
program_scopes : 'm scope_decl ScopeName.Map.t;
program_ctx : decl_ctx;
}
@ -73,17 +70,17 @@ let type_rule decl_ctx env = function
let type_program (prg : 'm program) : typed program =
let typing_env =
ScopeMap.fold
ScopeName.Map.fold
(fun scope_name scope_decl ->
let vars = ScopeVarMap.map fst scope_decl.scope_sig in
let vars = ScopeVar.Map.map fst scope_decl.scope_sig in
Typing.Env.add_scope scope_name ~vars)
prg.program_scopes Typing.Env.empty
in
let program_scopes =
ScopeMap.map
ScopeName.Map.map
(fun scope_decl ->
let typing_env =
ScopeVarMap.fold
ScopeVar.Map.fold
(fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env)
scope_decl.scope_sig typing_env
in

View File

@ -16,7 +16,7 @@
(** Abstract syntax tree of the scope language *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Identifiers} *)
@ -31,41 +31,20 @@ type 'm expr = (scopelang, 'm mark) gexpr
val locations_used : 'm expr -> LocationSet.t
(** This type characterizes the three levels of visibility for a given scope
variable with regards to the scope's input and possible redefinitions inside
the scope.. *)
type io_input =
| NoInput
(** For an internal variable defined only in the scope, and does not
appear in the input. *)
| OnlyInput
(** For variables that should not be redefined in the scope, because they
appear in the input. *)
| Reentrant
(** For variables defined in the scope that can also be redefined by the
caller as they appear in the input. *)
type io = {
io_output : bool Marked.pos;
(** [true] is present in the output of the scope. *)
io_input : io_input Marked.pos;
}
(** Characterization of the input/output status of a scope variable. *)
type 'm rule =
| Definition of location Marked.pos * typ * io * 'm expr
| Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr
| Assertion of 'm expr
| Call of ScopeName.t * SubScopeName.t * 'm mark
type 'm scope_decl = {
scope_decl_name : ScopeName.t;
scope_sig : (typ * io) ScopeVarMap.t;
scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t;
scope_decl_rules : 'm rule list;
scope_mark : 'm mark;
}
type 'm program = {
program_scopes : 'm scope_decl ScopeMap.t;
program_scopes : 'm scope_decl ScopeName.Map.t;
program_ctx : decl_ctx;
}

View File

@ -17,7 +17,7 @@
(** Graph representation of the dependencies between scopes in the Catala
program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils
open Catala_utils
open Shared_ast
module SVertex = ScopeName
@ -41,13 +41,13 @@ module SSCC = Graph.Components.Make (SDependencies)
let rec expr_used_scopes e =
let recurse_subterms e =
Expr.shallow_fold
(fun e -> ScopeMap.union (fun _ x _ -> Some x) (expr_used_scopes e))
e ScopeMap.empty
(fun e -> ScopeName.Map.union (fun _ x _ -> Some x) (expr_used_scopes e))
e ScopeName.Map.empty
in
match e with
| (EScopeCall (scope, _), m) as e ->
ScopeMap.add scope (Expr.mark_pos m) (recurse_subterms e)
| EAbs (binder, _), _ ->
| (EScopeCall { scope; _ }, m) as e ->
ScopeName.Map.add scope (Expr.mark_pos m) (recurse_subterms e)
| EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in
expr_used_scopes body
| e -> recurse_subterms e
@ -58,28 +58,28 @@ let rule_used_scopes = function
walking through all exprs again *)
expr_used_scopes e
| Ast.Call (subscope, subindex, _) ->
ScopeMap.singleton subscope
ScopeName.Map.singleton subscope
(Marked.get_mark (SubScopeName.get_info subindex))
let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t =
let g = SDependencies.empty in
let g =
ScopeMap.fold
ScopeName.Map.fold
(fun v _ g -> SDependencies.add_vertex g v)
prgm.program_scopes g
in
ScopeMap.fold
ScopeName.Map.fold
(fun scope_name scope g ->
List.fold_left
(fun g rule ->
let used_scopes = rule_used_scopes rule in
if ScopeMap.mem scope_name used_scopes then
if ScopeName.Map.mem scope_name used_scopes then
Errors.raise_spanned_error
(Marked.get_mark (ScopeName.get_info scope.Ast.scope_decl_name))
"The scope %a is calling into itself as a subscope, which is \
forbidden since Catala does not provide recursion"
ScopeName.format_t scope.Ast.scope_decl_name;
ScopeMap.fold
ScopeName.Map.fold
(fun used_scope pos g ->
let edge = SDependencies.E.create used_scope pos scope_name in
SDependencies.add_edge_e g edge)
@ -190,10 +190,10 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
=
let g = TDependencies.empty in
let g =
StructMap.fold
StructName.Map.fold
(fun s fields g ->
List.fold_left
(fun g (_, typ) ->
StructField.Map.fold
(fun _ typ g ->
let def = TVertex.Struct s in
let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in
@ -210,14 +210,14 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
in
TDependencies.add_edge_e g edge)
used g)
g fields)
fields g)
structs g
in
let g =
EnumMap.fold
EnumName.Map.fold
(fun e cases g ->
List.fold_left
(fun g (_, typ) ->
EnumConstructor.Map.fold
(fun _ typ g ->
let def = TVertex.Enum e in
let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in
@ -234,7 +234,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
in
TDependencies.add_edge_e g edge)
used g)
g cases)
cases g)
enums g
in
g

View File

@ -17,7 +17,7 @@
(** Graph representation of the dependencies between scopes in the Catala
program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils
open Catala_utils
open Shared_ast
(** {1 Scope dependencies} *)

View File

@ -1,7 +1,7 @@
(library
(name scopelang)
(public_name catala.scopelang)
(libraries utils dcalc ocamlgraph)
(libraries catala_utils ocamlgraph desugared)
(flags
(:standard -short-paths)))

View File

@ -0,0 +1,730 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
open Catala_utils
open Shared_ast
(** {1 Expression translation}*)
type target_scope_vars =
| WholeVar of ScopeVar.t
| States of (StateName.t * ScopeVar.t) list
type ctx = {
decl_ctx : decl_ctx;
scope_var_mapping : target_scope_vars ScopeVar.Map.t;
var_mapping : (Desugared.Ast.expr, untyped Ast.expr Var.t) Var.Map.t;
}
let tag_with_log_entry
(e : untyped Ast.expr boxed)
(l : log_entry)
(markings : Uid.MarkedString.info list) : untyped Ast.expr boxed =
Expr.eapp
(Expr.eop (Log (l, markings)) [TAny, Expr.pos e] (Marked.get_mark e))
[e] (Marked.get_mark e)
let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) :
untyped Ast.expr boxed =
let m = Marked.get_mark e in
match Marked.unmark e with
| ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
(* When referring to a subscope variable in an expression, we are referring
to the output, hence we take the last state. *)
let new_s_var =
match ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States states ->
Marked.same_mark_as (snd (List.hd (List.rev states))) s_var
in
Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m
| ELocation (DesugaredScopeVar (s_var, None)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var
| States _ -> failwith "should not happen"))
m
| ELocation (DesugaredScopeVar (s_var, Some state)) ->
Expr.elocation
(ScopelangScopeVar
(match
ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping
with
| WholeVar _ -> failwith "should not happen"
| States states -> Marked.same_mark_as (List.assoc state states) s_var))
m
| EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m
| EStruct { name; fields } ->
Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m
| EDStructAccess { name_opt = None; _ } ->
(* Note: this could only happen if disambiguation was disabled. If we want
to support it, we should still allow this case when the field has only
one possible matching structure *)
Errors.raise_spanned_error (Expr.mark_pos m)
"Ambiguous structure field access"
| EDStructAccess { e; field; name_opt = Some name } ->
let e' = translate_expr ctx e in
let field =
try
StructName.Map.find name
(IdentName.Map.find field ctx.decl_ctx.ctx_struct_fields)
with Not_found ->
(* Should not happen after disambiguation *)
Errors.raise_spanned_error (Expr.mark_pos m)
"Field %s does not belong to structure %a" field StructName.format_t
name
in
Expr.estructaccess e' field name m
| EInj { e; cons; name } -> Expr.einj (translate_expr ctx e) cons name m
| EMatch { e; name; cases } ->
Expr.ematch (translate_expr ctx e) name
(EnumConstructor.Map.map (translate_expr ctx) cases)
m
| EScopeCall { scope; args } ->
Expr.escopecall scope
(ScopeVar.Map.fold
(fun v e args' ->
let v' =
match ScopeVar.Map.find v ctx.scope_var_mapping with
| WholeVar v' -> v'
| States ((_, v') :: _) ->
(* When there are multiple states, the input is always the first
one *)
v'
| States [] -> assert false
in
ScopeVar.Map.add v' (translate_expr ctx e) args')
args ScopeVar.Map.empty)
m
| ELit
(( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ) as l) ->
Expr.elit l m
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in
let ctx =
List.fold_left2
(fun ctx var new_var ->
{ ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping })
ctx (Array.to_list vars) (Array.to_list new_vars)
in
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m
| EApp { f = EOp { op; tys }, m1; args } ->
let args = List.map (translate_expr ctx) args in
Operator.kind_dispatch op
~monomorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
~polymorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
~overloaded:(fun op ->
match
Operator.resolve_overload ctx.decl_ctx
(Marked.mark (Expr.pos e) op)
tys
with
| op, `Straight -> Expr.eapp (Expr.eop op tys m1) args m
| op, `Reversed ->
Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m)
| EOp _ -> assert false (* Only allowed within [EApp] *)
| EApp { f; args } ->
Expr.eapp (translate_expr ctx f) (List.map (translate_expr ctx) args) m
| EDefault { excepts; just; cons } ->
Expr.edefault
(List.map (translate_expr ctx) excepts)
(translate_expr ctx just) (translate_expr ctx cons) m
| EIfThenElse { cond; etrue; efalse } ->
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
(translate_expr ctx efalse)
m
| EArray args -> Expr.earray (List.map (translate_expr ctx) args) m
| EErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m
(** {1 Rule tree construction} *)
(** Intermediate representation for the exception tree of rules for a particular
scope definition. *)
type rule_tree =
| Leaf of Desugared.Ast.rule list
(** Rules defining a base case piecewise. List is non-empty. *)
| Node of rule_tree list * Desugared.Ast.rule list
(** [Node (exceptions, base_case)] is a list of exceptions to a non-empty
list of rules defining a base case piecewise. *)
(** Transforms a flat list of rules into a tree, taking into account the
priorities declared between rules *)
let def_map_to_tree
(def_info : Desugared.Ast.ScopeDef.t)
(def : Desugared.Ast.rule RuleName.Map.t) : rule_tree list =
let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in
Desugared.Dependency.check_for_exception_cycle exc_graph;
(* we start by the base cases: they are the vertices which have no
successors *)
let base_cases =
Desugared.Dependency.ExceptionsDependencies.fold_vertex
(fun v base_cases ->
if
Desugared.Dependency.ExceptionsDependencies.out_degree exc_graph v = 0
then v :: base_cases
else base_cases)
exc_graph []
in
let rec build_tree (base_cases : RuleName.Set.t) : rule_tree =
let exceptions =
Desugared.Dependency.ExceptionsDependencies.pred exc_graph base_cases
in
let base_case_as_rule_list =
List.map
(fun r -> RuleName.Map.find r def)
(RuleName.Set.elements base_cases)
in
match exceptions with
| [] -> Leaf base_case_as_rule_list
| _ -> Node (List.map build_tree exceptions, base_case_as_rule_list)
in
List.map build_tree base_cases
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault}
expression in the scope language. The [~toplevel] parameter is used to know
when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr
~(toplevel : bool)
~(is_reentrant_var : bool)
(ctx : ctx)
(def_pos : Pos.t)
(is_func : Desugared.Ast.expr Var.t option)
(tree : rule_tree) : untyped Ast.expr boxed =
let emark = Untyped { pos = def_pos } in
let exceptions, base_rules =
match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r
in
(* because each rule has its own variable parameter and we want to convert the
whole rule tree into a function, we need to perform some alpha-renaming of
all the expressions *)
let substitute_parameter
(e : Desugared.Ast.expr boxed)
(rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed =
match is_func, rule.Desugared.Ast.rule_parameter with
| Some new_param, Some (old_param, _) ->
let binder = Bindlib.bind_var old_param (Marked.unmark e) in
Marked.mark (Marked.get_mark e)
@@ Bindlib.box_apply2
(fun binder new_param -> Bindlib.subst binder new_param)
binder
(Bindlib.box_var new_param)
| None, None -> e
| _ -> assert false
(* should not happen *)
in
let ctx =
match is_func with
| None -> ctx
| Some new_param -> (
match Var.Map.find_opt new_param ctx.var_mapping with
| None ->
let new_param_scope = Var.make (Bindlib.name_of new_param) in
{
ctx with
var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping;
}
| Some _ ->
(* We only create a mapping if none exists because [rule_tree_to_expr]
is called recursively on the exceptions of the tree and we don't want
to create a new Scopelang variable for the parameter at each tree
level. *)
ctx)
in
let base_just_list =
List.map
(fun rule -> substitute_parameter rule.Desugared.Ast.rule_just rule)
base_rules
in
let base_cons_list =
List.map
(fun rule -> substitute_parameter rule.Desugared.Ast.rule_cons rule)
base_rules
in
let translate_and_unbox_list (list : Desugared.Ast.expr boxed list) :
untyped Ast.expr boxed list =
List.map
(fun e ->
(* There are two levels of boxing here, the outermost is introduced by
the [translate_expr] function for which all of the bindings should
have been closed by now, so we can safely unbox. *)
translate_expr ctx (Expr.unbox e))
list
in
let default_containing_base_cases =
Expr.make_default
(List.map2
(fun base_just base_cons ->
Expr.make_default []
(* Here we insert the logging command that records when a decision
is taken for the value of a variable. *)
(tag_with_log_entry base_just PosRecordIfTrueBool [])
base_cons emark)
(translate_and_unbox_list base_just_list)
(translate_and_unbox_list base_cons_list))
(Expr.elit (LBool false) emark)
(Expr.elit LEmptyError emark)
emark
in
let exceptions =
List.map
(rule_tree_to_expr ~toplevel:false ~is_reentrant_var ctx def_pos is_func)
exceptions
in
let default =
Expr.make_default exceptions
(Expr.elit (LBool true) emark)
default_containing_base_cases emark
in
match is_func, (List.hd base_rules).Desugared.Ast.rule_parameter with
| None, None -> default
| Some new_param, Some (_, typ) ->
if toplevel then
(* When we're creating a function from multiple defaults, we must check
that the result returned by the function is not empty, unless we're
dealing with a context variable which is reentrant (either in the
caller or callee). In this case the ErrorOnEmpty will be added later in
the scopelang->dcalc translation. *)
let default =
if is_reentrant_var then default else Expr.eerroronempty default emark
in
Expr.make_abs
[| Var.Map.find new_param ctx.var_mapping |]
default [typ] def_pos
else default
| _ -> (* should not happen *) assert false
(** {1 AST translation} *)
(** Translates a definition inside a scope, the resulting expression should be
an {!constructor: Dcalc.EDefault} *)
let translate_def
(ctx : ctx)
(def_info : Desugared.Ast.ScopeDef.t)
(def : Desugared.Ast.rule RuleName.Map.t)
(typ : typ)
(io : Desugared.Ast.io)
~(is_cond : bool)
~(is_subscope_var : bool) : untyped Ast.expr boxed =
(* Here, we have to transform this list of rules into a default tree. *)
let is_def_func =
match Marked.unmark typ with TArrow (_, _) -> true | _ -> false
in
let is_rule_func _ (r : Desugared.Ast.rule) : bool =
Option.is_some r.Desugared.Ast.rule_parameter
in
let all_rules_func = RuleName.Map.for_all is_rule_func def in
let all_rules_not_func =
RuleName.Map.for_all (fun n r -> not (is_rule_func n r)) def
in
let is_def_func_param_typ : typ option =
if is_def_func && all_rules_func then
match Marked.unmark typ with
| TArrow (t_param, _) -> Some t_param
| _ ->
Errors.raise_spanned_error (Marked.get_mark typ)
"The definitions of %a are function but it doesn't have a function \
type"
Desugared.Ast.ScopeDef.format_t def_info
else if (not is_def_func) && all_rules_not_func then None
else
let spans =
List.map
(fun (_, r) ->
( Some "This definition is a function:",
Expr.pos r.Desugared.Ast.rule_cons ))
(RuleName.Map.bindings (RuleName.Map.filter is_rule_func def))
@ List.map
(fun (_, r) ->
( Some "This definition is not a function:",
Expr.pos r.Desugared.Ast.rule_cons ))
(RuleName.Map.bindings
(RuleName.Map.filter (fun n r -> not (is_rule_func n r)) def))
in
Errors.raise_multispanned_error spans
"some definitions of the same variable are functions while others \
aren't"
in
let top_list = def_map_to_tree def_info def in
let is_input =
match Marked.unmark io.Desugared.Ast.io_input with
| OnlyInput -> true
| _ -> false
in
let is_reentrant =
match Marked.unmark io.Desugared.Ast.io_input with
| Reentrant -> true
| _ -> false
in
let top_value =
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
(* We add the bottom [false] value for conditions, only for the scope
where the condition is declared. Except when the variable is an input,
where we want the [false] to be added at each caller parent scope. *)
Some
(Desugared.Ast.always_false_rule
(Desugared.Ast.ScopeDef.get_position def_info)
is_def_func_param_typ)
else None
in
if
RuleName.Map.cardinal def = 0
&& is_subscope_var
(* Here we have a special case for the empty definitions. Indeed, we could
use the code for the regular case below that would create a convoluted
default always returning empty error, and this would be correct. But it
gets more complicated with functions. Indeed, if we create an empty
definition for a subscope argument whose type is a function, we get
something like [fun () -> (fun real_param -> < ... >)] that is passed as
an argument to the subscope. The sub-scope de-thunks but the de-thunking
does not return empty error, signalling there is not reentrant variable,
because functions are values! So the subscope does not see that there is
not reentrant variable and does not pick its internal definition instead.
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
test case exercising that subtlety.
To avoid this complication we special case here and put an empty error
for all subscope variables that are not defined. It covers the subtlety
with functions described above but also conditions with the false default
value. *)
&& not (is_cond && is_input)
(* However, this special case suffers from an exception: when a condition is
defined as an OnlyInput to a subscope, since the [false] default value
will not be provided by the calee scope, it has to be placed in the
caller. *)
then
let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
let empty_error = Expr.elit LEmptyError m in
match is_def_func_param_typ with
| Some ty ->
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
| _ -> empty_error
else
rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
(Desugared.Ast.ScopeDef.get_position def_info)
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
(match top_list, top_value with
| [], None ->
(* In this case, there are no rules to define the expression and no
default value so we put an empty rule. *)
Leaf
[Desugared.Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ]
| [], Some top_value ->
(* In this case, there are no rules to define the expression but a
default value so we put it. *)
Leaf [top_value]
| _, Some top_value ->
(* When there are rules + a default value, we put the rules as
exceptions to the default value *)
Node (top_list, [top_value])
| [top_tree], None -> top_tree
| _, None ->
Node
( top_list,
[
Desugared.Ast.empty_rule (Marked.get_mark typ)
is_def_func_param_typ;
] ))
let translate_rule ctx (scope : Desugared.Ast.scope) = function
| Desugared.Dependency.Vertex.Var (var, state) -> (
let scope_def =
Desugared.Ast.ScopeDefMap.find
(Desugared.Ast.ScopeDef.Var (var, state))
scope.scope_defs
in
let var_def = scope_def.scope_def_rules in
let var_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with
| OnlyInput when not (RuleName.Map.is_empty var_def) ->
(* If the variable is tagged as input, then it shall not be redefined. *)
Errors.raise_multispanned_error
((Some "Incriminated variable:", Marked.get_mark (ScopeVar.get_info var))
:: List.map
(fun (rule, _) ->
( Some "Incriminated variable definition:",
Marked.get_mark (RuleName.get_info rule) ))
(RuleName.Map.bindings var_def))
"It is impossible to give a definition to a scope variable tagged as \
input."
| OnlyInput -> []
(* we do not provide any definition for an input-only variable *)
| _ ->
let expr_def =
translate_def ctx
(Desugared.Ast.ScopeDef.Var (var, state))
var_def var_typ scope_def.Desugared.Ast.scope_def_io ~is_cond
~is_subscope_var:false
in
let scope_var =
match ScopeVar.Map.find var ctx.scope_var_mapping, state with
| WholeVar v, None -> v
| States states, Some state -> List.assoc state states
| _ -> failwith "should not happen"
in
[
Ast.Definition
( ( ScopelangScopeVar
(scope_var, Marked.get_mark (ScopeVar.get_info scope_var)),
Marked.get_mark (ScopeVar.get_info scope_var) ),
var_typ,
scope_def.Desugared.Ast.scope_def_io,
Expr.unbox expr_def );
])
| Desugared.Dependency.Vertex.SubScope sub_scope_index ->
(* Before calling the sub_scope, we need to include all the re-definitions
of subscope parameters*)
let sub_scope =
SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes
in
let sub_scope_vars_redefs_candidates =
Desugared.Ast.ScopeDefMap.filter
(fun def_key scope_def ->
match def_key with
| Desugared.Ast.ScopeDef.Var _ -> false
| Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) ->
sub_scope_index = sub_scope_index'
(* We exclude subscope variables that have 0 re-definitions and are
not visible in the input of the subscope *)
&& not
((match
Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input
with
| Desugared.Ast.NoInput -> true
| _ -> false)
&& RuleName.Map.is_empty scope_def.scope_def_rules))
scope.scope_defs
in
let sub_scope_vars_redefs =
Desugared.Ast.ScopeDefMap.mapi
(fun def_key scope_def ->
let def = scope_def.Desugared.Ast.scope_def_rules in
let def_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match def_key with
| Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *)
| Desugared.Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) ->
(* This definition redefines a variable of the correct subscope. But
we have to check that this redefinition is allowed with respect
to the io parameters of that subscope variable. *)
(match
Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input
with
| Desugared.Ast.NoInput ->
Errors.raise_multispanned_error
(( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) )
:: ( Some "Incriminated variable:",
Marked.get_mark (ScopeVar.get_info sub_scope_var) )
:: List.map
(fun (rule, _) ->
( Some "Incriminated subscope variable definition:",
Marked.get_mark (RuleName.get_info rule) ))
(RuleName.Map.bindings def))
"It is impossible to give a definition to a subscope variable \
not tagged as input or context."
| OnlyInput when RuleName.Map.is_empty def && not is_cond ->
(* If the subscope variable is tagged as input, then it shall be
defined. *)
Errors.raise_multispanned_error
[
( Some "Incriminated subscope:",
Marked.get_mark (SubScopeName.get_info sscope) );
Some "Incriminated variable:", pos;
]
"This subscope variable is a mandatory input but no definition \
was provided."
| _ -> ());
(* Now that all is good, we can proceed with translating this
redefinition to a proper Scopelang term. *)
let expr_def =
translate_def ctx def_key def def_typ
scope_def.Desugared.Ast.scope_def_io ~is_cond
~is_subscope_var:true
in
let subscop_real_name =
SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes
in
let var_pos = Desugared.Ast.ScopeDef.get_position def_key in
Ast.Definition
( ( SubScopeVar
( subscop_real_name,
(sub_scope_index, var_pos),
match
ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping
with
| WholeVar v -> v, var_pos
| States states ->
(* When defining a sub-scope variable, we always define
its first state in the sub-scope. *)
snd (List.hd states), var_pos ),
var_pos ),
def_typ,
scope_def.Desugared.Ast.scope_def_io,
Expr.unbox expr_def ))
sub_scope_vars_redefs_candidates
in
let sub_scope_vars_redefs =
List.map snd (Desugared.Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
in
sub_scope_vars_redefs
@ [
Ast.Call
( sub_scope,
sub_scope_index,
Untyped
{ pos = Marked.get_mark (SubScopeName.get_info sub_scope_index) }
);
]
(** Translates a scope *)
let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) :
untyped Ast.scope_decl =
let scope_dependencies =
Desugared.Dependency.build_scope_dependencies scope
in
Desugared.Dependency.check_for_cycle scope scope_dependencies;
let scope_ordering =
Desugared.Dependency.correct_computation_ordering scope_dependencies
in
let scope_decl_rules =
List.flatten (List.map (translate_rule ctx scope) scope_ordering)
in
(* Then, after having computed all the scopes variables, we add the
assertions. TODO: the assertions should be interleaved with the
definitions! *)
let scope_decl_rules =
scope_decl_rules
@ List.map
(fun e ->
let scope_e = translate_expr ctx (Expr.unbox e) in
Ast.Assertion (Expr.unbox scope_e))
scope.Desugared.Ast.scope_assertions
in
let scope_sig =
ScopeVar.Map.fold
(fun var (states : Desugared.Ast.var_or_states) acc ->
match states with
| WholeVar ->
let scope_def =
Desugared.Ast.ScopeDefMap.find
(Desugared.Ast.ScopeDef.Var (var, None))
scope.scope_defs
in
let typ = scope_def.scope_def_typ in
ScopeVar.Map.add
(match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar v -> v
| States _ -> failwith "should not happen")
(typ, scope_def.scope_def_io)
acc
| States states ->
(* What happens in the case of variables with multiple states is
interesting. We need to create as many Var entries in the scope
signature as there are states. *)
List.fold_left
(fun acc (state : StateName.t) ->
let scope_def =
Desugared.Ast.ScopeDefMap.find
(Desugared.Ast.ScopeDef.Var (var, Some state))
scope.scope_defs
in
ScopeVar.Map.add
(match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar _ -> failwith "should not happen"
| States states' -> List.assoc state states')
(scope_def.scope_def_typ, scope_def.scope_def_io)
acc)
acc states)
scope.scope_vars ScopeVar.Map.empty
in
let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in
{
Ast.scope_decl_name = scope.scope_uid;
Ast.scope_decl_rules;
Ast.scope_sig;
Ast.scope_mark = Untyped { pos };
}
(** {1 API} *)
let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program =
(* First we give mappings to all the locations between Desugared and This
involves creating a new Scopelang scope variable for every state of a
Desugared variable. *)
let ctx =
(* Todo: since we rename all scope vars at this point, it would be better to
have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *)
ScopeName.Map.fold
(fun _scope scope_decl ctx ->
ScopeVar.Map.fold
(fun scope_var (states : Desugared.Ast.var_or_states) ctx ->
let var_name, var_pos = ScopeVar.get_info scope_var in
let new_var =
match states with
| Desugared.Ast.WholeVar ->
WholeVar (ScopeVar.fresh (var_name, var_pos))
| States states ->
let var_prefix = var_name ^ "_" in
let state_var state =
ScopeVar.fresh
(Marked.map_under_mark (( ^ ) var_prefix)
(StateName.get_info state))
in
States (List.map (fun state -> state, state_var state) states)
in
{
ctx with
scope_var_mapping =
ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping;
})
scope_decl.Desugared.Ast.scope_vars ctx)
pgrm.Desugared.Ast.program_scopes
{
scope_var_mapping = ScopeVar.Map.empty;
var_mapping = Var.Map.empty;
decl_ctx = pgrm.program_ctx;
}
in
let ctx_scopes =
ScopeName.Map.map
(fun out_str ->
let out_struct_fields =
ScopeVar.Map.fold
(fun var fld out_map ->
let var' =
match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar v -> v
| States l -> snd (List.hd (List.rev l))
in
ScopeVar.Map.add var' fld out_map)
out_str.out_struct_fields ScopeVar.Map.empty
in
{ out_str with out_struct_fields })
pgrm.Desugared.Ast.program_ctx.ctx_scopes
in
{
Ast.program_scopes =
ScopeName.Map.map (translate_scope ctx) pgrm.program_scopes;
program_ctx = { pgrm.program_ctx with ctx_scopes };
}

View File

@ -16,4 +16,4 @@
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
val translate_program : Ast.program -> Shared_ast.untyped Scopelang.Ast.program
val translate_program : Desugared.Ast.program -> Shared_ast.untyped Ast.program

View File

@ -14,7 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Shared_ast
open Ast
@ -22,21 +22,22 @@ let struc
ctx
(fmt : Format.formatter)
(name : StructName.t)
(fields : (StructFieldName.t * typ) list) : unit =
(fields : typ StructField.Map.t) : unit =
Format.fprintf fmt "%a %a %a %a@\n@[<hov 2> %a@]@\n%a" Print.keyword "struct"
StructName.format_t name Print.punctuation "=" Print.punctuation "{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (field_name, typ) ->
Format.fprintf fmt "%a%a %a" StructFieldName.format_t field_name
Format.fprintf fmt "%a%a %a" StructField.format_t field_name
Print.punctuation ":" (Print.typ ctx) typ))
fields Print.punctuation "}"
(StructField.Map.bindings fields)
Print.punctuation "}"
let enum
ctx
(fmt : Format.formatter)
(name : EnumName.t)
(cases : (EnumConstructor.t * typ) list) : unit =
(cases : typ EnumConstructor.Map.t) : unit =
Format.fprintf fmt "%a %a %a @\n@[<hov 2> %a@]" Print.keyword "enum"
EnumName.format_t name Print.punctuation "="
(Format.pp_print_list
@ -45,7 +46,7 @@ let enum
Format.fprintf fmt "%a %a%a %a" Print.punctuation "|"
EnumConstructor.format_t field_name Print.punctuation ":"
(Print.typ ctx) typ))
cases
(EnumConstructor.Map.bindings cases)
let scope ?(debug = false) ctx fmt (name, decl) =
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]"
@ -55,16 +56,16 @@ let scope ?(debug = false) ctx fmt (name, decl) =
Format.fprintf fmt "%a%a%a %a%a%a%a%a" Print.punctuation "("
ScopeVar.format_t scope_var Print.punctuation ":" (Print.typ ctx) typ
Print.punctuation "|" Print.keyword
(match Marked.unmark vis.io_input with
(match Marked.unmark vis.Desugared.Ast.io_input with
| NoInput -> "internal"
| OnlyInput -> "input"
| Reentrant -> "context")
(if Marked.unmark vis.io_output then fun fmt () ->
(if Marked.unmark vis.Desugared.Ast.io_output then fun fmt () ->
Format.fprintf fmt "%a@,%a" Print.punctuation "|" Print.keyword
"output"
else fun fmt () -> Format.fprintf fmt "@<0>")
() Print.punctuation ")"))
(ScopeVarMap.bindings decl.scope_sig)
(ScopeVar.Map.bindings decl.scope_sig)
Print.punctuation "="
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";")
@ -80,11 +81,11 @@ let scope ?(debug = false) ctx fmt (name, decl) =
| ScopelangScopeVar v -> (
match
Marked.unmark
(snd (ScopeVarMap.find (Marked.unmark v) decl.scope_sig))
(snd (ScopeVar.Map.find (Marked.unmark v) decl.scope_sig))
.io_input
with
| Reentrant ->
Format.fprintf fmt "%a@ %a" Print.operator
Format.fprintf fmt "%a@ %a" Print.op_style
"reentrant or by default" (Print.expr ~debug ctx) e
| _ -> Format.fprintf fmt "%a" (Print.expr ~debug ctx) e))
e
@ -105,16 +106,16 @@ let program ?(debug : bool = false) (fmt : Format.formatter) (p : 'm program) :
Format.pp_print_cut fmt ()
in
Format.pp_open_vbox fmt 0;
StructMap.iter
StructName.Map.iter
(fun n s ->
struc ctx fmt n s;
pp_sep fmt ())
ctx.ctx_structs;
EnumMap.iter
EnumName.Map.iter
(fun n e ->
enum ctx fmt n e;
pp_sep fmt ())
ctx.ctx_enums;
Format.pp_print_list ~pp_sep (scope ~debug ctx) fmt
(ScopeMap.bindings p.program_scopes);
(ScopeName.Map.bindings p.program_scopes);
Format.pp_close_box fmt ()

View File

@ -20,59 +20,47 @@
(* Doesn't define values, so OK to have without an mli *)
open Utils
open Catala_utils
module Runtime = Runtime_ocaml.Runtime
module ScopeName = Uid.Gen ()
module StructName = Uid.Gen ()
module StructField = Uid.Gen ()
module EnumName = Uid.Gen ()
module EnumConstructor = Uid.Gen ()
module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
(** Only used by surface *)
module ScopeSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName)
module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName)
module RuleName = Uid.Gen ()
module LabelName = Uid.Gen ()
module StructName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
(** Used for unresolved structs/maps in desugared *)
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
module EnumName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
module IdentName = String
(** Only used by desugared/scopelang *)
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar)
module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar)
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module SubScopeNameSet : Set.S with type elt = SubScopeName.t =
Set.Make (SubScopeName)
module SubScopeMap : Map.S with type key = SubScopeName.t =
Map.Make (SubScopeName)
module StructFieldMap : Map.S with type key = StructFieldName.t =
Map.Make (StructFieldName)
module EnumConstructorMap : Map.S with type key = EnumConstructor.t =
Map.Make (EnumConstructor)
module StateName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module ScopeVar = Uid.Gen ()
module SubScopeName = Uid.Gen ()
module StateName = Uid.Gen ()
(** {1 Abstract syntax tree} *)
(** Define a common base type for the expressions in most passes of the compiler *)
type desugared = [ `Desugared ]
(** {2 Phantom types used to select relevant cases on the generic AST}
we instantiate them with a polymorphic variant to take advantage of
sub-typing. The values aren't actually used. *)
type scopelang = [ `Scopelang ]
type dcalc = [ `Dcalc ]
type lcalc = [ `Lcalc ]
type 'a any = [< desugared | scopelang | dcalc | lcalc ] as 'a
(** ['a any] is 'a, but adds the constraint that it should be restricted to
valid AST kinds *)
(** {2 Types} *)
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
@ -94,33 +82,6 @@ and naked_typ =
type date = Runtime.date
type duration = Runtime.duration
type op_kind =
| KInt
| KRat
| KMoney
| KDate
| KDuration (** All ops don't have a KDate and KDuration. *)
type ternop = Fold
type binop =
| And
| Or
| Xor
| Add of op_kind
| Sub of op_kind
| Mult of op_kind
| Div of op_kind
| Lt of op_kind
| Lte of op_kind
| Gt of op_kind
| Gte of op_kind
| Eq
| Neq
| Map
| Concat
| Filter
type log_entry =
| VarDef of naked_typ
(** During code generation, we need to know the type of the variable being
@ -129,35 +90,140 @@ type log_entry =
| EndCall
| PosRecordIfTrueBool
type unop =
| Not
| Minus of op_kind
| Log of log_entry * Uid.MarkedString.info list
| Length
| IntToRat
| MoneyToRat
| RatToMoney
| GetDay
| GetMonth
| GetYear
| FirstDayOfMonth
| LastDayOfMonth
| RoundMoney
| RoundDecimal
module Op = struct
(** Classification of operators on how they should be typed *)
type operator = Ternop of ternop | Binop of binop | Unop of unop
type monomorphic =
| Monomorphic (** Operands and return types of the operator are fixed *)
type polymorphic =
| Polymorphic
(** The operator is truly polymorphic: it's the same runtime function
that may work on multiple types. We require that resolving the
argument types from right to left trivially resolves all type
variables declared in the operator type. *)
type overloaded =
| Overloaded
(** The operator is ambiguous and requires the types of its arguments to
be known before it can be typed, using a pre-defined table *)
type resolved =
| Resolved (** Explicit monomorphic versions of the overloaded operators *)
(** Classification of operators. This could be inlined in the definition of
[t] but is more concise this way *)
type (_, _) kind =
| Monomorphic : ('a any, monomorphic) kind
| Polymorphic : ('a any, polymorphic) kind
| Overloaded : ([< desugared ], overloaded) kind
| Resolved : ([< scopelang | dcalc | lcalc ], resolved) kind
type (_, _) t =
(* unary *)
(* * monomorphic *)
| Not : ('a any, monomorphic) t
| GetDay : ('a any, monomorphic) t
| GetMonth : ('a any, monomorphic) t
| GetYear : ('a any, monomorphic) t
| FirstDayOfMonth : ('a any, monomorphic) t
| LastDayOfMonth : ('a any, monomorphic) t
(* * polymorphic *)
| Length : ('a any, polymorphic) t
| Log : log_entry * Uid.MarkedString.info list -> ('a any, polymorphic) t
(* * overloaded *)
| Minus : (desugared, overloaded) t
| Minus_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| ToRat : (desugared, overloaded) t
| ToRat_int : ([< scopelang | dcalc | lcalc ], resolved) t
| ToRat_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| ToMoney : (desugared, overloaded) t
| ToMoney_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Round : (desugared, overloaded) t
| Round_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Round_mon : ([< scopelang | dcalc | lcalc ], resolved) t
(* binary *)
(* * monomorphic *)
| And : ('a any, monomorphic) t
| Or : ('a any, monomorphic) t
| Xor : ('a any, monomorphic) t
(* * polymorphic *)
| Eq : ('a any, polymorphic) t
| Map : ('a any, polymorphic) t
| Concat : ('a any, polymorphic) t
| Filter : ('a any, polymorphic) t
| Reduce : ('a any, polymorphic) t
(* * overloaded *)
| Add : (desugared, overloaded) t
| Add_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub : (desugared, overloaded) t
| Sub_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult : (desugared, overloaded) t
| Mult_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_dur_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Div : (desugared, overloaded) t
| Div_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt : (desugared, overloaded) t
| Lt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte : (desugared, overloaded) t
| Lte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt : (desugared, overloaded) t
| Gt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte : (desugared, overloaded) t
| Gte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
(* Todo: Eq is not an overload at the moment, but it should be one. The
trick is that it needs generation of specific code for arrays, every
struct and enum: operators [Eq_structs of StructName.t], etc. *)
| Eq_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
(* ternary *)
(* * polymorphic *)
| Fold : ('a any, polymorphic) t
end
type ('a, 'k) operator = ('a any, 'k) Op.t
type except = ConflictError | EmptyError | NoValueProvided | Crash
(** {2 Generic expressions} *)
(** Define a common base type for the expressions in most passes of the compiler *)
type desugared = [ `Desugared ]
type scopelang = [ `Scopelang ]
type dcalc = [ `Dcalc ]
type lcalc = [ `Lcalc ]
type 'a any = [< desugared | scopelang | dcalc | lcalc ] as 'a
(** Literals are the same throughout compilation except for the [LEmptyError]
case which is eliminated midway through. *)
type 'a glit =
@ -192,65 +258,98 @@ type ('a, 't) gexpr = (('a, 't) naked_gexpr, 't) Marked.t
- To write a function that handles cases from different ASTs, explicit the
type variables: [fun (type a) (x: a naked_gexpr) -> ...]
- For recursive functions, you may need to additionally explicit the
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...] *)
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...]
- Always think of using the pre-defined map/fold functions in [Expr] rather
than completely defining your recursion manually. *)
and ('a, 't) naked_gexpr =
(* Constructors common to all ASTs *)
| ELit : 'a glit -> ('a any, 't) naked_gexpr
| EApp : ('a, 't) gexpr * ('a, 't) gexpr list -> ('a any, 't) naked_gexpr
| EOp : operator -> ('a any, 't) naked_gexpr
| EApp : {
f : ('a, 't) gexpr;
args : ('a, 't) gexpr list;
}
-> ('a any, 't) naked_gexpr
| EOp : { op : ('a, _) operator; tys : typ list } -> ('a any, 't) naked_gexpr
| EArray : ('a, 't) gexpr list -> ('a any, 't) naked_gexpr
| EVar : ('a, 't) naked_gexpr Bindlib.var -> ('a any, 't) naked_gexpr
| EAbs :
(('a, 't) naked_gexpr, ('a, 't) gexpr) Bindlib.mbinder * typ list
| EAbs : {
binder : (('a, 't) naked_gexpr, ('a, 't) gexpr) Bindlib.mbinder;
tys : typ list;
}
-> ('a any, 't) naked_gexpr
| EIfThenElse :
('a, 't) gexpr * ('a, 't) gexpr * ('a, 't) gexpr
| EIfThenElse : {
cond : ('a, 't) gexpr;
etrue : ('a, 't) gexpr;
efalse : ('a, 't) gexpr;
}
-> ('a any, 't) naked_gexpr
| EStruct : {
name : StructName.t;
fields : ('a, 't) gexpr StructField.Map.t;
}
-> ('a any, 't) naked_gexpr
| EInj : {
name : EnumName.t;
e : ('a, 't) gexpr;
cons : EnumConstructor.t;
}
-> ('a any, 't) naked_gexpr
| EMatch : {
name : EnumName.t;
e : ('a, 't) gexpr;
cases : ('a, 't) gexpr EnumConstructor.Map.t;
}
-> ('a any, 't) naked_gexpr
(* Early stages *)
| ELocation :
'a glocation
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EStruct :
StructName.t * ('a, 't) gexpr StructFieldMap.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EStructAccess :
('a, 't) gexpr * StructFieldName.t * StructName.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EEnumInj :
('a, 't) gexpr * EnumConstructor.t * EnumName.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EMatchS :
('a, 't) gexpr * EnumName.t * ('a, 't) gexpr EnumConstructorMap.t
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EScopeCall :
ScopeName.t * ('a, 't) gexpr ScopeVarMap.t
| EScopeCall : {
scope : ScopeName.t;
args : ('a, 't) gexpr ScopeVar.Map.t;
}
-> (([< desugared | scopelang ] as 'a), 't) naked_gexpr
| EDStructAccess : {
name_opt : StructName.t option;
e : ('a, 't) gexpr;
field : IdentName.t;
}
-> ((desugared as 'a), 't) naked_gexpr
(** [desugared] has ambiguous struct fields *)
| EStructAccess : {
name : StructName.t;
e : ('a, 't) gexpr;
field : StructField.t;
}
-> (([< scopelang | dcalc | lcalc ] as 'a), 't) naked_gexpr
(** Resolved struct/enums, after [desugared] *)
(* Lambda-like *)
| ETuple :
('a, 't) gexpr list * StructName.t option
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| ETupleAccess :
('a, 't) gexpr * int * StructName.t option * typ list
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| EInj :
('a, 't) gexpr * int * EnumName.t * typ list
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| EMatch :
('a, 't) gexpr * ('a, 't) gexpr list * EnumName.t
-> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
| EAssert : ('a, 't) gexpr -> (([< dcalc | lcalc ] as 'a), 't) naked_gexpr
(* Default terms *)
| EDefault :
('a, 't) gexpr list * ('a, 't) gexpr * ('a, 't) gexpr
| EDefault : {
excepts : ('a, 't) gexpr list;
just : ('a, 't) gexpr;
cons : ('a, 't) gexpr;
}
-> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr
| ErrorOnEmpty :
| EErrorOnEmpty :
('a, 't) gexpr
-> (([< desugared | scopelang | dcalc ] as 'a), 't) naked_gexpr
(* Lambda calculus with exceptions *)
| ETuple : ('a, 't) gexpr list -> ((lcalc as 'a), 't) naked_gexpr
| ETupleAccess : {
e : ('a, 't) gexpr;
index : int;
size : int;
}
-> ((lcalc as 'a), 't) naked_gexpr
| ERaise : except -> ((lcalc as 'a), 't) naked_gexpr
| ECatch :
('a, 't) gexpr * except * ('a, 't) gexpr
| ECatch : {
body : ('a, 't) gexpr;
exn : except;
handler : ('a, 't) gexpr;
}
-> ((lcalc as 'a), 't) naked_gexpr
type ('a, 't) boxed_gexpr = (('a, 't) naked_gexpr Bindlib.box, 't) Marked.t
@ -276,9 +375,9 @@ type typed = { pos : Pos.t; ty : typ }
(** The generic type of AST markings. Using a GADT allows functions to be
polymorphic in the marking, but still do transformations on types when
appropriate. Expected to fill the ['t] parameter of [naked_gexpr] and
[gexpr] (a ['t] annotation different from this type is used in the middle of
the typing processing, but all visible ASTs should otherwise use this. *)
appropriate. Expected to fill the ['t] parameter of [gexpr] and [gexpr] (a
['t] annotation different from this type is used in the middle of the typing
processing, but all visible ASTs should otherwise use this. *)
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
(** Useful for errors and printing, for example *)
@ -287,11 +386,10 @@ type any_expr = AnyExpr : (_, _ mark) gexpr -> any_expr
(** {2 Higher-level program structure} *)
(** Constructs scopes and programs on top of expressions. The ['e] type
parameter throughout is expected to match instances of the [naked_gexpr]
type defined above. Markings are constrained to the [mark] GADT defined
above. Note that this structure is at the moment only relevant for [dcalc]
and [lcalc], as [scopelang] has its own scope structure, as the name
implies. *)
parameter throughout is expected to match instances of the [gexpr] type
defined above. Markings are constrained to the [mark] GADT defined above.
Note that this structure is at the moment only relevant for [dcalc] and
[lcalc], as [scopelang] has its own scope structure, as the name implies. *)
(** This kind annotation signals that the let-binding respects a structural
invariant. These invariants concern the shape of the expression in the
@ -350,14 +448,20 @@ and 'e scopes =
| ScopeDef of 'e scope_def
constraint 'e = (_ any, _ mark) gexpr
type struct_ctx = (StructFieldName.t * typ) list StructMap.t
type enum_ctx = (EnumConstructor.t * typ) list EnumMap.t
type struct_ctx = typ StructField.Map.t StructName.Map.t
type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t
type scope_out_struct = {
out_struct_name : StructName.t;
out_struct_fields : StructField.t ScopeVar.Map.t;
}
type decl_ctx = {
ctx_enums : enum_ctx;
ctx_structs : struct_ctx;
ctx_scopes : StructName.t ScopeMap.t;
(** The output structure type of every scope *)
ctx_struct_fields : StructField.t StructName.Map.t IdentName.Map.t;
(** needed for disambiguation (desugared -> scope) *)
ctx_scopes : scope_out_struct ScopeName.Map.t;
}
type 'e program = { decl_ctx : decl_ctx; scopes : 'e scopes }

View File

@ -3,4 +3,4 @@
(public_name catala.shared_ast)
(flags
(:standard -short-paths))
(libraries bindlib unionFind utils catala.runtime_ocaml))
(libraries bindlib unionFind catala_utils catala.runtime_ocaml))

View File

@ -15,7 +15,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Definitions
(** Functions handling the types of [shared_ast] *)
@ -57,15 +57,15 @@ module Box = struct
fun em ->
B.box_apply (fun e -> Marked.mark (Marked.get_mark em) e) (Marked.unmark em)
module LiftStruct = Bindlib.Lift (StructFieldMap)
module LiftStruct = Bindlib.Lift (StructField.Map)
let lift_struct = LiftStruct.lift_box
module LiftEnum = Bindlib.Lift (EnumConstructorMap)
module LiftEnum = Bindlib.Lift (EnumConstructor.Map)
let lift_enum = LiftEnum.lift_box
module LiftScopeVars = Bindlib.Lift (ScopeVarMap)
module LiftScopeVars = Bindlib.Lift (ScopeVar.Map)
let lift_scope_vars = LiftScopeVars.lift_box
end
@ -76,61 +76,64 @@ let subst binder vars =
Bindlib.msubst binder (Array.of_list (List.map Marked.unmark vars))
let evar v mark = Marked.mark mark (Bindlib.box_var v)
let etuple args s = Box.appn args @@ fun args -> ETuple (args, s)
let etuple args = Box.appn args @@ fun args -> ETuple args
let etupleaccess e1 i s typs =
Box.app1 e1 @@ fun e1 -> ETupleAccess (e1, i, s, typs)
let einj e1 i e_name typs = Box.app1 e1 @@ fun e1 -> EInj (e1, i, e_name, typs)
let ematch arg arms e_name =
Box.app1n arg arms @@ fun arg arms -> EMatch (arg, arms, e_name)
let etupleaccess e index size =
assert (index < size);
Box.app1 e @@ fun e -> ETupleAccess { e; index; size }
let earray args = Box.appn args @@ fun args -> EArray args
let elit l mark = Marked.mark mark (Bindlib.box (ELit l))
let eabs binder typs mark =
Bindlib.box_apply (fun binder -> EAbs (binder, typs)) binder, mark
let eabs binder tys mark =
Bindlib.box_apply (fun binder -> EAbs { binder; tys }) binder, mark
let eapp e1 args = Box.app1n e1 args @@ fun e1 args -> EApp (e1, args)
let eapp f args = Box.app1n f args @@ fun f args -> EApp { f; args }
let eassert e1 = Box.app1 e1 @@ fun e1 -> EAssert e1
let eop op = Box.app0 @@ EOp op
let eop op tys = Box.app0 @@ EOp { op; tys }
let edefault excepts just cons =
Box.app2n just cons excepts
@@ fun just cons excepts -> EDefault (excepts, just, cons)
@@ fun just cons excepts -> EDefault { excepts; just; cons }
let eifthenelse e1 e2 e3 =
Box.app3 e1 e2 e3 @@ fun e1 e2 e3 -> EIfThenElse (e1, e2, e3)
let eifthenelse cond etrue efalse =
Box.app3 cond etrue efalse
@@ fun cond etrue efalse -> EIfThenElse { cond; etrue; efalse }
let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> ErrorOnEmpty e1
let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> EErrorOnEmpty e1
let eraise e1 = Box.app0 @@ ERaise e1
let ecatch e1 exn e2 = Box.app2 e1 e2 @@ fun e1 e2 -> ECatch (e1, exn, e2)
let ecatch body exn handler =
Box.app2 body handler @@ fun body handler -> ECatch { body; exn; handler }
let elocation loc = Box.app0 @@ ELocation loc
let estruct name (fields : ('a, 't) boxed_gexpr StructFieldMap.t) mark =
let estruct name (fields : ('a, 't) boxed_gexpr StructField.Map.t) mark =
Marked.mark mark
@@ Bindlib.box_apply
(fun fields -> EStruct (name, fields))
(Box.lift_struct (StructFieldMap.map Box.lift fields))
(fun fields -> EStruct { name; fields })
(Box.lift_struct (StructField.Map.map Box.lift fields))
let estructaccess e1 field struc =
Box.app1 e1 @@ fun e1 -> EStructAccess (e1, field, struc)
let edstructaccess e field name_opt =
Box.app1 e @@ fun e -> EDStructAccess { name_opt; e; field }
let eenuminj e1 cons enum = Box.app1 e1 @@ fun e1 -> EEnumInj (e1, cons, enum)
let estructaccess e field name =
Box.app1 e @@ fun e -> EStructAccess { name; e; field }
let ematchs e1 enum cases mark =
let einj e cons name = Box.app1 e @@ fun e -> EInj { name; e; cons }
let ematch e name cases mark =
Marked.mark mark
@@ Bindlib.box_apply2
(fun e1 cases -> EMatchS (e1, enum, cases))
(Box.lift e1)
(Box.lift_enum (EnumConstructorMap.map Box.lift cases))
(fun e cases -> EMatch { name; e; cases })
(Box.lift e)
(Box.lift_enum (EnumConstructor.Map.map Box.lift cases))
let escopecall scope_name fields mark =
let escopecall scope args mark =
Marked.mark mark
@@ Bindlib.box_apply
(fun fields -> EScopeCall (scope_name, fields))
(Box.lift_scope_vars (ScopeVarMap.map Box.lift fields))
(fun args -> EScopeCall { scope; args })
(Box.lift_scope_vars (ScopeVar.Map.map Box.lift args))
(* - Manipulation of marks - *)
@ -203,49 +206,46 @@ let maybe_ty (type m) ?(typ = TAny) (m : m mark) : typ =
(* shallow map *)
let map
(type a)
(ctx : 'ctx)
~(f : 'ctx -> (a, 'm1) gexpr -> (a, 'm2) boxed_gexpr)
~(f : (a, 'm1) gexpr -> (a, 'm2) boxed_gexpr)
(e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : (a, 'm2) boxed_gexpr =
let m = Marked.get_mark e in
match Marked.unmark e with
| ELit l -> elit l m
| EApp (e1, args) -> eapp (f ctx e1) (List.map (f ctx) args) m
| EOp op -> eop op m
| EArray args -> earray (List.map (f ctx) args) m
| EApp { f = e1; args } -> eapp (f e1) (List.map f args) m
| EOp { op; tys } -> eop op tys m
| EArray args -> earray (List.map f args) m
| EVar v -> evar (Var.translate v) m
| EAbs (binder, typs) ->
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let body = f ctx body in
let body = f body in
let binder = bind (Array.map Var.translate vars) body in
eabs binder typs m
| EIfThenElse (e1, e2, e3) ->
eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) m
| ETuple (args, s) -> etuple (List.map (f ctx) args) s m
| ETupleAccess (e1, n, s_name, typs) ->
etupleaccess ((f ctx) e1) n s_name typs m
| EInj (e1, i, e_name, typs) -> einj ((f ctx) e1) i e_name typs m
| EMatch (arg, arms, e_name) ->
ematch ((f ctx) arg) (List.map (f ctx) arms) e_name m
| EAssert e1 -> eassert ((f ctx) e1) m
| EDefault (excepts, just, cons) ->
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) m
eabs binder tys m
| EIfThenElse { cond; etrue; efalse } ->
eifthenelse (f cond) (f etrue) (f efalse) m
| ETuple args -> etuple (List.map f args) m
| ETupleAccess { e; index; size } -> etupleaccess (f e) index size m
| EInj { e; name; cons } -> einj (f e) cons name m
| EAssert e1 -> eassert (f e1) m
| EDefault { excepts; just; cons } ->
edefault (List.map f excepts) (f just) (f cons) m
| EErrorOnEmpty e1 -> eerroronempty (f e1) m
| ECatch { body; exn; handler } -> ecatch (f body) exn (f handler) m
| ERaise exn -> eraise exn m
| ELocation loc -> elocation loc m
| EStruct (name, fields) ->
let fields = StructFieldMap.map (f ctx) fields in
| EStruct { name; fields } ->
let fields = StructField.Map.map f fields in
estruct name fields m
| EStructAccess (e1, field, struc) -> estructaccess (f ctx e1) field struc m
| EEnumInj (e1, cons, enum) -> eenuminj (f ctx e1) cons enum m
| EMatchS (e1, enum, cases) ->
let cases = EnumConstructorMap.map (f ctx) cases in
ematchs (f ctx e1) enum cases m
| EScopeCall (scope_name, fields) ->
let fields = ScopeVarMap.map (f ctx) fields in
escopecall scope_name fields m
| EDStructAccess { e; field; name_opt } ->
edstructaccess (f e) field name_opt m
| EStructAccess { e; field; name } -> estructaccess (f e) field name m
| EMatch { e; name; cases } ->
let cases = EnumConstructor.Map.map f cases in
ematch (f e) name cases m
| EScopeCall { scope; args } ->
let fields = ScopeVar.Map.map f args in
escopecall scope fields m
let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e)
let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e)
let map_marks ~f e =
map_top_down ~f:(fun e -> Marked.(mark (f (get_mark e)) (unmark e))) e
@ -260,31 +260,130 @@ let shallow_fold
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
match Marked.unmark e with
| ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ -> acc
| EApp (e1, args) -> acc |> f e1 |> lfold args
| EApp { f = e; args } -> acc |> f e |> lfold args
| EArray args -> acc |> lfold args
| EAbs _ -> acc
| EIfThenElse (e1, e2, e3) -> acc |> f e1 |> f e2 |> f e3
| ETuple (args, _) -> acc |> lfold args
| ETupleAccess (e1, _, _, _) -> acc |> f e1
| EInj (e1, _, _, _) -> acc |> f e1
| EMatch (arg, arms, _) -> acc |> f arg |> lfold arms
| EAssert e1 -> acc |> f e1
| EDefault (excepts, just, cons) -> acc |> lfold excepts |> f just |> f cons
| ErrorOnEmpty e1 -> acc |> f e1
| ECatch (e1, _, e2) -> acc |> f e1 |> f e2
| EStruct (_, fields) -> acc |> StructFieldMap.fold (fun _ -> f) fields
| EStructAccess (e1, _, _) -> acc |> f e1
| EEnumInj (e1, _, _) -> acc |> f e1
| EMatchS (e1, _, cases) ->
acc |> f e1 |> EnumConstructorMap.fold (fun _ -> f) cases
| EScopeCall (_, fields) -> acc |> ScopeVarMap.fold (fun _ -> f) fields
| EIfThenElse { cond; etrue; efalse } -> acc |> f cond |> f etrue |> f efalse
| ETuple args -> acc |> lfold args
| ETupleAccess { e; _ } -> acc |> f e
| EInj { e; _ } -> acc |> f e
| EAssert e -> acc |> f e
| EDefault { excepts; just; cons } -> acc |> lfold excepts |> f just |> f cons
| EErrorOnEmpty e -> acc |> f e
| ECatch { body; handler; _ } -> acc |> f body |> f handler
| EStruct { fields; _ } -> acc |> StructField.Map.fold (fun _ -> f) fields
| EDStructAccess { e; _ } -> acc |> f e
| EStructAccess { e; _ } -> acc |> f e
| EMatch { e; cases; _ } ->
acc |> f e |> EnumConstructor.Map.fold (fun _ -> f) cases
| EScopeCall { args; _ } -> acc |> ScopeVar.Map.fold (fun _ -> f) args
(* Like [map], but also allows to gather a result bottom-up. *)
let map_gather
(type a)
~(acc : 'acc)
~(join : 'acc -> 'acc -> 'acc)
~(f : (a, 'm1) gexpr -> 'acc * (a, 'm2) boxed_gexpr)
(e : ((a, 'm1) naked_gexpr, 'm2) Marked.t) : 'acc * (a, 'm2) boxed_gexpr =
let m = Marked.get_mark e in
let lfoldmap es =
let acc, r_es =
List.fold_left
(fun (acc, es) e ->
let acc1, e = f e in
join acc acc1, e :: es)
(acc, []) es
in
acc, List.rev r_es
in
match Marked.unmark e with
| ELit l -> acc, elit l m
| EApp { f = e1; args } ->
let acc1, f = f e1 in
let acc2, args = lfoldmap args in
join acc1 acc2, eapp f args m
| EOp { op; tys } -> acc, eop op tys m
| EArray args ->
let acc, args = lfoldmap args in
acc, earray args m
| EVar v -> acc, evar (Var.translate v) m
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let acc, body = f body in
let binder = bind (Array.map Var.translate vars) body in
acc, eabs binder tys m
| EIfThenElse { cond; etrue; efalse } ->
let acc1, cond = f cond in
let acc2, etrue = f etrue in
let acc3, efalse = f efalse in
join (join acc1 acc2) acc3, eifthenelse cond etrue efalse m
| ETuple args ->
let acc, args = lfoldmap args in
acc, etuple args m
| ETupleAccess { e; index; size } ->
let acc, e = f e in
acc, etupleaccess e index size m
| EInj { e; name; cons } ->
let acc, e = f e in
acc, einj e cons name m
| EAssert e ->
let acc, e = f e in
acc, eassert e m
| EDefault { excepts; just; cons } ->
let acc1, excepts = lfoldmap excepts in
let acc2, just = f just in
let acc3, cons = f cons in
join (join acc1 acc2) acc3, edefault excepts just cons m
| EErrorOnEmpty e ->
let acc, e = f e in
acc, eerroronempty e m
| ECatch { body; exn; handler } ->
let acc1, body = f body in
let acc2, handler = f handler in
join acc1 acc2, ecatch body exn handler m
| ERaise exn -> acc, eraise exn m
| ELocation loc -> acc, elocation loc m
| EStruct { name; fields } ->
let acc, fields =
StructField.Map.fold
(fun cons e (acc, fields) ->
let acc1, e = f e in
join acc acc1, StructField.Map.add cons e fields)
fields
(acc, StructField.Map.empty)
in
acc, estruct name fields m
| EDStructAccess { e; field; name_opt } ->
let acc, e = f e in
acc, edstructaccess e field name_opt m
| EStructAccess { e; field; name } ->
let acc, e = f e in
acc, estructaccess e field name m
| EMatch { e; name; cases } ->
let acc, e = f e in
let acc, cases =
EnumConstructor.Map.fold
(fun cons e (acc, cases) ->
let acc1, e = f e in
join acc acc1, EnumConstructor.Map.add cons e cases)
cases
(acc, EnumConstructor.Map.empty)
in
acc, ematch e name cases m
| EScopeCall { scope; args } ->
let acc, args =
ScopeVar.Map.fold
(fun var e (acc, args) ->
let acc1, e = f e in
join acc acc1, ScopeVar.Map.add var e args)
args (acc, ScopeVar.Map.empty)
in
acc, escopecall scope args m
(* - *)
(** See [Bindlib.box_term] documentation for why we are doing that. *)
let rebox e =
let rec id_t () e = map () ~f:id_t e in
id_t () e
let rec rebox e = map ~f:rebox e
let box e = Marked.same_mark_as (Bindlib.box (Marked.unmark e)) e
let unbox (e, m) = Bindlib.unbox e, m
@ -297,99 +396,36 @@ let is_value (type a) (e : (a, _) gexpr) =
| ELit _ | EAbs _ | EOp _ | ERaise _ -> true
| _ -> false
let equal_tlit l1 l2 = l1 = l2
let compare_tlit l1 l2 = Stdlib.compare l1 l2
let rec equal_typ ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> equal_typ_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> equal_typ t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> equal_typ t1 t2 && equal_typ t1' t2'
| TArray t1, TArray t2 -> equal_typ t1 t2
| TAny, TAny -> true
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
| TArray _ | TAny ),
_ ) ->
false
and equal_typ_list tys1 tys2 =
try List.for_all2 equal_typ tys1 tys2 with Invalid_argument _ -> false
(* Similar to [equal_typ], but allows TAny holes *)
let rec unifiable ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TAny, _ | _, TAny -> true
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> unifiable t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
| TArray t1, TArray t2 -> unifiable t1 t2
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
_ ) ->
false
and unifiable_list tys1 tys2 =
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
let rec compare_typ ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> compare_tlit l1 l2
| TTuple tys1, TTuple tys2 -> List.compare compare_typ tys1 tys2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TOption t1, TOption t2 -> compare_typ t1 t2
| TArrow (a1, b1), TArrow (a2, b2) -> (
match compare_typ a1 a2 with 0 -> compare_typ b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare_typ t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TTuple _, _ -> -1
| _, TTuple _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TOption _, _ -> -1
| _, TOption _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
let equal_lit (type a) (l1 : a glit) (l2 : a glit) =
let open Runtime.Oper in
match l1, l2 with
| LBool b1, LBool b2 -> Bool.equal b1 b2
| LBool b1, LBool b2 -> not (o_xor b1 b2)
| LEmptyError, LEmptyError -> true
| LInt n1, LInt n2 -> Runtime.( =! ) n1 n2
| LRat r1, LRat r2 -> Runtime.( =& ) r1 r2
| LMoney m1, LMoney m2 -> Runtime.( =$ ) m1 m2
| LInt n1, LInt n2 -> o_eq_int_int n1 n2
| LRat r1, LRat r2 -> o_eq_rat_rat r1 r2
| LMoney m1, LMoney m2 -> o_eq_mon_mon m1 m2
| LUnit, LUnit -> true
| LDate d1, LDate d2 -> Runtime.( =@ ) d1 d2
| LDuration d1, LDuration d2 -> Runtime.( =^ ) d1 d2
| LDate d1, LDate d2 -> o_eq_dat_dat d1 d2
| LDuration d1, LDuration d2 -> o_eq_dur_dur d1 d2
| ( ( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ),
_ ) ->
false
let compare_lit (type a) (l1 : a glit) (l2 : a glit) =
let open Runtime.Oper in
match l1, l2 with
| LBool b1, LBool b2 -> Bool.compare b1 b2
| LEmptyError, LEmptyError -> 0
| LInt n1, LInt n2 ->
if Runtime.( <! ) n1 n2 then -1 else if Runtime.( =! ) n1 n2 then 0 else 1
if o_lt_int_int n1 n2 then -1 else if o_eq_int_int n1 n2 then 0 else 1
| LRat r1, LRat r2 ->
if Runtime.( <& ) r1 r2 then -1 else if Runtime.( =& ) r1 r2 then 0 else 1
if o_lt_rat_rat r1 r2 then -1 else if o_eq_rat_rat r1 r2 then 0 else 1
| LMoney m1, LMoney m2 ->
if Runtime.( <$ ) m1 m2 then -1 else if Runtime.( =$ ) m1 m2 then 0 else 1
if o_lt_mon_mon m1 m2 then -1 else if o_eq_mon_mon m1 m2 then 0 else 1
| LUnit, LUnit -> 0
| LDate d1, LDate d2 ->
if Runtime.( <@ ) d1 d2 then -1 else if Runtime.( =@ ) d1 d2 then 0 else 1
if o_lt_dat_dat d1 d2 then -1 else if o_eq_dat_dat d1 d2 then 0 else 1
| LDuration d1, LDuration d2 -> (
(* Duration comparison in the runtime may fail, so rely on a basic
lexicographic comparison instead *)
@ -441,119 +477,6 @@ let compare_location
| _, SubScopeVar _ -> .
let equal_location a b = compare_location a b = 0
let equal_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> equal_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
| x, y -> x = y
let compare_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> compare_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
| BeginCall, BeginCall
| EndCall, EndCall
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
0
| VarDef _, _ -> -1
| _, VarDef _ -> 1
| BeginCall, _ -> -1
| _, BeginCall -> 1
| EndCall, _ -> -1
| _, EndCall -> 1
| PosRecordIfTrueBool, _ -> .
| _, PosRecordIfTrueBool -> .
(* let equal_op_kind = Stdlib.(=) *)
let compare_op_kind = Stdlib.compare
let equal_unops op1 op2 =
match op1, op2 with
(* Log entries contain a typ which contain position information, we thus need
to descend into them *)
| Log (l1, info1), Log (l2, info2) ->
equal_log_entries l1 l2 && List.equal Uid.MarkedString.equal info1 info2
| Log _, _ | _, Log _ -> false
(* All the other cases can be discharged through equality *)
| ( ( Not | Minus _ | Length | IntToRat | MoneyToRat | RatToMoney | GetDay
| GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | RoundMoney
| RoundDecimal ),
_ ) ->
op1 = op2
let compare_unops op1 op2 =
match op1, op2 with
| Not, Not -> 0
| Minus k1, Minus k2 -> compare_op_kind k1 k2
| Log (l1, info1), Log (l2, info2) -> (
match compare_log_entries l1 l2 with
| 0 -> List.compare Uid.MarkedString.compare info1 info2
| n -> n)
| Length, Length
| IntToRat, IntToRat
| MoneyToRat, MoneyToRat
| RatToMoney, RatToMoney
| GetDay, GetDay
| GetMonth, GetMonth
| GetYear, GetYear
| FirstDayOfMonth, FirstDayOfMonth
| LastDayOfMonth, LastDayOfMonth
| RoundMoney, RoundMoney
| RoundDecimal, RoundDecimal ->
0
| Not, _ -> -1
| _, Not -> 1
| Minus _, _ -> -1
| _, Minus _ -> 1
| Log _, _ -> -1
| _, Log _ -> 1
| Length, _ -> -1
| _, Length -> 1
| IntToRat, _ -> -1
| _, IntToRat -> 1
| MoneyToRat, _ -> -1
| _, MoneyToRat -> 1
| RatToMoney, _ -> -1
| _, RatToMoney -> 1
| GetDay, _ -> -1
| _, GetDay -> 1
| GetMonth, _ -> -1
| _, GetMonth -> 1
| GetYear, _ -> -1
| _, GetYear -> 1
| FirstDayOfMonth, _ -> -1
| _, FirstDayOfMonth -> 1
| LastDayOfMonth, _ -> -1
| _, LastDayOfMonth -> 1
| RoundMoney, _ -> -1
| _, RoundMoney -> 1
| RoundDecimal, _ -> .
| _, RoundDecimal -> .
let equal_binop = Stdlib.( = )
let compare_binop = Stdlib.compare
let equal_ternop = Stdlib.( = )
let compare_ternop = Stdlib.compare
let equal_ops op1 op2 =
match op1, op2 with
| Ternop op1, Ternop op2 -> equal_ternop op1 op2
| Binop op1, Binop op2 -> equal_binop op1 op2
| Unop op1, Unop op2 -> equal_unops op1 op2
| _, _ -> false
let compare_op op1 op2 =
match op1, op2 with
| Ternop op1, Ternop op2 -> compare_ternop op1 op2
| Binop op1, Binop op2 -> compare_binop op1 op2
| Unop op1, Unop op2 -> compare_unops op1 op2
| Ternop _, _ -> -1
| _, Ternop _ -> 1
| Binop _, _ -> -1
| _, Binop _ -> 1
| Unop _, _ -> .
| _, Unop _ -> .
let equal_except ex1 ex2 = ex1 = ex2
let compare_except ex1 ex2 = Stdlib.compare ex1 ex2
@ -567,50 +490,60 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
fun e1 e2 ->
match Marked.unmark e1, Marked.unmark e2 with
| EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2
| ETuple (es1, n1), ETuple (es2, n2) -> n1 = n2 && equal_list es1 es2
| ETupleAccess (e1, id1, n1, tys1), ETupleAccess (e2, id2, n2, tys2) ->
equal e1 e2 && id1 = id2 && n1 = n2 && equal_typ_list tys1 tys2
| EInj (e1, id1, n1, tys1), EInj (e2, id2, n2, tys2) ->
equal e1 e2 && id1 = id2 && n1 = n2 && equal_typ_list tys1 tys2
| EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) ->
n1 = n2 && equal e1 e2 && equal_list cases1 cases2
| ETuple es1, ETuple es2 -> equal_list es1 es2
| ( ETupleAccess { e = e1; index = id1; size = s1 },
ETupleAccess { e = e2; index = id2; size = s2 } ) ->
s1 = s2 && equal e1 e2 && id1 = id2
| EArray es1, EArray es2 -> equal_list es1 es2
| ELit l1, ELit l2 -> l1 = l2
| EAbs (b1, tys1), EAbs (b2, tys2) ->
equal_typ_list tys1 tys2
| EAbs { binder = b1; tys = tys1 }, EAbs { binder = b2; tys = tys2 } ->
Type.equal_list tys1 tys2
&&
let vars1, body1 = Bindlib.unmbind b1 in
let body2 = Bindlib.msubst b2 (Array.map (fun x -> EVar x) vars1) in
equal body1 body2
| EApp (e1, args1), EApp (e2, args2) -> equal e1 e2 && equal_list args1 args2
| EApp { f = e1; args = args1 }, EApp { f = e2; args = args2 } ->
equal e1 e2 && equal_list args1 args2
| EAssert e1, EAssert e2 -> equal e1 e2
| EOp op1, EOp op2 -> equal_ops op1 op2
| EDefault (exc1, def1, cons1), EDefault (exc2, def2, cons2) ->
| EOp { op = op1; tys = tys1 }, EOp { op = op2; tys = tys2 } ->
Operator.equal op1 op2 && Type.equal_list tys1 tys2
| ( EDefault { excepts = exc1; just = def1; cons = cons1 },
EDefault { excepts = exc2; just = def2; cons = cons2 } ) ->
equal def1 def2 && equal cons1 cons2 && equal_list exc1 exc2
| EIfThenElse (if1, then1, else1), EIfThenElse (if2, then2, else2) ->
| ( EIfThenElse { cond = if1; etrue = then1; efalse = else1 },
EIfThenElse { cond = if2; etrue = then2; efalse = else2 } ) ->
equal if1 if2 && equal then1 then2 && equal else1 else2
| ErrorOnEmpty e1, ErrorOnEmpty e2 -> equal e1 e2
| EErrorOnEmpty e1, EErrorOnEmpty e2 -> equal e1 e2
| ERaise ex1, ERaise ex2 -> equal_except ex1 ex2
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) ->
| ( ECatch { body = etry1; exn = ex1; handler = ewith1 },
ECatch { body = etry2; exn = ex2; handler = ewith2 } ) ->
equal etry1 etry2 && equal_except ex1 ex2 && equal ewith1 ewith2
| ELocation l1, ELocation l2 ->
equal_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2)
| EStruct (s1, fields1), EStruct (s2, fields2) ->
StructName.equal s1 s2 && StructFieldMap.equal equal fields1 fields2
| EStructAccess (e1, f1, s1), EStructAccess (e2, f2, s2) ->
StructName.equal s1 s2 && StructFieldName.equal f1 f2 && equal e1 e2
| EEnumInj (e1, c1, n1), EEnumInj (e2, c2, n2) ->
| ( EStruct { name = s1; fields = fields1 },
EStruct { name = s2; fields = fields2 } ) ->
StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2
| ( EDStructAccess { e = e1; field = f1; name_opt = s1 },
EDStructAccess { e = e2; field = f2; name_opt = s2 } ) ->
Option.equal StructName.equal s1 s2 && IdentName.equal f1 f2 && equal e1 e2
| ( EStructAccess { e = e1; field = f1; name = s1 },
EStructAccess { e = e2; field = f2; name = s2 } ) ->
StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2
| EInj { e = e1; cons = c1; name = n1 }, EInj { e = e2; cons = c2; name = n2 }
->
EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2
| EMatchS (e1, n1, cases1), EMatchS (e2, n2, cases2) ->
| ( EMatch { e = e1; name = n1; cases = cases1 },
EMatch { e = e2; name = n2; cases = cases2 } ) ->
EnumName.equal n1 n2
&& equal e1 e2
&& EnumConstructorMap.equal equal cases1 cases2
| EScopeCall (s1, fields1), EScopeCall (s2, fields2) ->
ScopeName.equal s1 s2 && ScopeVarMap.equal equal fields1 fields2
| ( ( EVar _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | EArray _
| ELit _ | EAbs _ | EApp _ | EAssert _ | EOp _ | EDefault _
| EIfThenElse _ | ErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _
| EStruct _ | EStructAccess _ | EEnumInj _ | EMatchS _ | EScopeCall _ ),
&& EnumConstructor.Map.equal equal cases1 cases2
| ( EScopeCall { scope = s1; args = fields1 },
EScopeCall { scope = s2; args = fields2 } ) ->
ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2
| ( ( EVar _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _ | EAbs _ | EApp _
| EAssert _ | EOp _ | EDefault _ | EIfThenElse _ | EErrorOnEmpty _
| ERaise _ | ECatch _ | ELocation _ | EStruct _ | EDStructAccess _
| EStructAccess _ | EInj _ | EMatch _ | EScopeCall _ ),
_ ) ->
false
@ -623,72 +556,76 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
match[@ocamlformat "disable"] Marked.unmark e1, Marked.unmark e2 with
| ELit l1, ELit l2 ->
compare_lit l1 l2
| EApp (f1, args1), EApp (f2, args2) ->
| EApp {f=f1; args=args1}, EApp {f=f2; args=args2} ->
compare f1 f2 @@< fun () ->
List.compare compare args1 args2
| EOp op1, EOp op2 ->
compare_op op1 op2
| EOp {op=op1; tys=tys1}, EOp {op=op2; tys=tys2} ->
Operator.compare op1 op2 @@< fun () ->
List.compare Type.compare tys1 tys2
| EArray a1, EArray a2 ->
List.compare compare a1 a2
| EVar v1, EVar v2 ->
Bindlib.compare_vars v1 v2
| EAbs (binder1, typs1), EAbs (binder2, typs2) ->
List.compare compare_typ typs1 typs2 @@< fun () ->
| EAbs {binder=binder1; tys=typs1},
EAbs {binder=binder2; tys=typs2} ->
List.compare Type.compare typs1 typs2 @@< fun () ->
let _, e1, e2 = Bindlib.unmbind2 binder1 binder2 in
compare e1 e2
| EIfThenElse (i1, t1, e1), EIfThenElse (i2, t2, e2) ->
| EIfThenElse {cond=i1; etrue=t1; efalse=e1},
EIfThenElse {cond=i2; etrue=t2; efalse=e2} ->
compare i1 i2 @@< fun () ->
compare t1 t2 @@< fun () ->
compare e1 e2
| ELocation l1, ELocation l2 ->
compare_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2)
| EStruct (name1, field_map1), EStruct (name2, field_map2) ->
| EStruct {name=name1; fields=field_map1},
EStruct {name=name2; fields=field_map2} ->
StructName.compare name1 name2 @@< fun () ->
StructFieldMap.compare compare field_map1 field_map2
| EStructAccess (e1, field_name1, struct_name1),
EStructAccess (e2, field_name2, struct_name2) ->
StructField.Map.compare compare field_map1 field_map2
| EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1},
EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2} ->
compare e1 e2 @@< fun () ->
StructFieldName.compare field_name1 field_name2 @@< fun () ->
IdentName.compare field_name1 field_name2 @@< fun () ->
Option.compare StructName.compare struct_name1 struct_name2
| EStructAccess {e=e1; field=field_name1; name=struct_name1},
EStructAccess {e=e2; field=field_name2; name=struct_name2} ->
compare e1 e2 @@< fun () ->
StructField.compare field_name1 field_name2 @@< fun () ->
StructName.compare struct_name1 struct_name2
| EEnumInj (e1, cstr1, name1), EEnumInj (e2, cstr2, name2) ->
compare e1 e2 @@< fun () ->
| EMatch {e=e1; name=name1; cases=emap1},
EMatch {e=e2; name=name2; cases=emap2} ->
EnumName.compare name1 name2 @@< fun () ->
EnumConstructor.compare cstr1 cstr2
| EMatchS (e1, name1, emap1), EMatchS (e2, name2, emap2) ->
compare e1 e2 @@< fun () ->
EnumName.compare name1 name2 @@< fun () ->
EnumConstructorMap.compare compare emap1 emap2
| EScopeCall (name1, field_map1), EScopeCall (name2, field_map2) ->
EnumConstructor.Map.compare compare emap1 emap2
| EScopeCall {scope=name1; args=field_map1},
EScopeCall {scope=name2; args=field_map2} ->
ScopeName.compare name1 name2 @@< fun () ->
ScopeVarMap.compare compare field_map1 field_map2
| ETuple (es1, s1), ETuple (es2, s2) ->
Option.compare StructName.compare s1 s2 @@< fun () ->
ScopeVar.Map.compare compare field_map1 field_map2
| ETuple es1, ETuple es2 ->
List.compare compare es1 es2
| ETupleAccess (e1, n1, s1, tys1), ETupleAccess (e2, n2, s2, tys2) ->
Option.compare StructName.compare s1 s2 @@< fun () ->
| ETupleAccess {e=e1; index=n1; size=s1},
ETupleAccess {e=e2; index=n2; size=s2} ->
Int.compare s1 s2 @@< fun () ->
Int.compare n1 n2 @@< fun () ->
List.compare compare_typ tys1 tys2 @@< fun () ->
compare e1 e2
| EInj (e1, n1, name1, ts1), EInj (e2, n2, name2, ts2) ->
| EInj {e=e1; name=name1; cons=cons1},
EInj {e=e2; name=name2; cons=cons2} ->
EnumName.compare name1 name2 @@< fun () ->
Int.compare n1 n2 @@< fun () ->
List.compare compare_typ ts1 ts2 @@< fun () ->
EnumConstructor.compare cons1 cons2 @@< fun () ->
compare e1 e2
| EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) ->
EnumName.compare n1 n2 @@< fun () ->
compare e1 e2 @@< fun () ->
List.compare compare cases1 cases2
| EAssert e1, EAssert e2 ->
compare e1 e2
| EDefault (exs1, just1, cons1), EDefault (exs2, just2, cons2) ->
| EDefault {excepts=exs1; just=just1; cons=cons1},
EDefault {excepts=exs2; just=just2; cons=cons2} ->
compare just1 just2 @@< fun () ->
compare cons1 cons2 @@< fun () ->
List.compare compare exs1 exs2
| ErrorOnEmpty e1, ErrorOnEmpty e2 ->
| EErrorOnEmpty e1, EErrorOnEmpty e2 ->
compare e1 e2
| ERaise ex1, ERaise ex2 ->
compare_except ex1 ex2
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) ->
| ECatch {body=etry1; exn=ex1; handler=ewith1},
ECatch {body=etry2; exn=ex2; handler=ewith2} ->
compare_except ex1 ex2 @@< fun () ->
compare etry1 etry2 @@< fun () ->
compare ewith1 ewith2
@ -701,34 +638,33 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
| EIfThenElse _, _ -> -1 | _, EIfThenElse _ -> 1
| ELocation _, _ -> -1 | _, ELocation _ -> 1
| EStruct _, _ -> -1 | _, EStruct _ -> 1
| EDStructAccess _, _ -> -1 | _, EDStructAccess _ -> 1
| EStructAccess _, _ -> -1 | _, EStructAccess _ -> 1
| EEnumInj _, _ -> -1 | _, EEnumInj _ -> 1
| EMatchS _, _ -> -1 | _, EMatchS _ -> 1
| EMatch _, _ -> -1 | _, EMatch _ -> 1
| EScopeCall _, _ -> -1 | _, EScopeCall _ -> 1
| ETuple _, _ -> -1 | _, ETuple _ -> 1
| ETupleAccess _, _ -> -1 | _, ETupleAccess _ -> 1
| EInj _, _ -> -1 | _, EInj _ -> 1
| EMatch _, _ -> -1 | _, EMatch _ -> 1
| EAssert _, _ -> -1 | _, EAssert _ -> 1
| EDefault _, _ -> -1 | _, EDefault _ -> 1
| ErrorOnEmpty _, _ -> . | _, ErrorOnEmpty _ -> .
| EErrorOnEmpty _, _ -> . | _, EErrorOnEmpty _ -> .
| ERaise _, _ -> -1 | _, ERaise _ -> 1
| ECatch _, _ -> . | _, ECatch _ -> .
let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function
| EVar v, _ -> Var.Set.singleton v
| EAbs (binder, _), _ ->
| EAbs { binder; _ }, _ ->
let vs, body = Bindlib.unmbind binder in
Array.fold_right Var.Set.remove vs (free_vars body)
| e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
let remove_logging_calls e =
let rec f () e =
let rec f e =
match Marked.unmark e with
| EApp ((EOp (Unop (Log _)), _), [arg]) -> map () ~f arg
| _ -> map () ~f e
| EApp { f = EOp { op = Log _; _ }, _; args = [arg] } -> map ~f arg
| _ -> map ~f e
in
f () e
f e
let format ?debug decl_ctx ppf e = Print.expr ?debug decl_ctx ppf e
@ -736,36 +672,35 @@ let rec size : type a. (a, 't) gexpr -> int =
fun e ->
match Marked.unmark e with
| EVar _ | ELit _ | EOp _ -> 1
| ETuple (args, _) -> List.fold_left (fun acc arg -> acc + size arg) 1 args
| ETuple args -> List.fold_left (fun acc arg -> acc + size arg) 1 args
| EArray args -> List.fold_left (fun acc arg -> acc + size arg) 1 args
| ETupleAccess (e1, _, _, _) -> size e1 + 1
| EInj (e1, _, _, _) -> size e1 + 1
| EAssert e1 -> size e1 + 1
| ErrorOnEmpty e1 -> size e1 + 1
| EMatch (arg, args, _) ->
List.fold_left (fun acc arg -> acc + size arg) (1 + size arg) args
| EApp (arg, args) ->
List.fold_left (fun acc arg -> acc + size arg) (1 + size arg) args
| EAbs (binder, _) ->
| ETupleAccess { e; _ } -> size e + 1
| EInj { e; _ } -> size e + 1
| EAssert e -> size e + 1
| EErrorOnEmpty e -> size e + 1
| EApp { f; args } ->
List.fold_left (fun acc arg -> acc + size arg) (1 + size f) args
| EAbs { binder; _ } ->
let _, body = Bindlib.unmbind binder in
1 + size body
| EIfThenElse (e1, e2, e3) -> 1 + size e1 + size e2 + size e3
| EDefault (exceptions, just, cons) ->
| EIfThenElse { cond; etrue; efalse } ->
1 + size cond + size etrue + size efalse
| EDefault { excepts; just; cons } ->
List.fold_left
(fun acc except -> acc + size except)
(1 + size just + size cons)
exceptions
excepts
| ERaise _ -> 1
| ECatch (etry, _, ewith) -> 1 + size etry + size ewith
| ECatch { body; handler; _ } -> 1 + size body + size handler
| ELocation _ -> 1
| EStruct (_, fields) ->
StructFieldMap.fold (fun _ e acc -> acc + 1 + size e) fields 0
| EStructAccess (e1, _, _) -> 1 + size e1
| EEnumInj (e1, _, _) -> 1 + size e1
| EMatchS (e1, _, cases) ->
EnumConstructorMap.fold (fun _ e acc -> acc + 1 + size e) cases (size e1)
| EScopeCall (_, fields) ->
ScopeVarMap.fold (fun _ e acc -> acc + 1 + size e) fields 1
| EStruct { fields; _ } ->
StructField.Map.fold (fun _ e acc -> acc + 1 + size e) fields 0
| EDStructAccess { e; _ } -> 1 + size e
| EStructAccess { e; _ } -> 1 + size e
| EMatch { e; cases; _ } ->
EnumConstructor.Map.fold (fun _ e acc -> acc + 1 + size e) cases (size e)
| EScopeCall { args; _ } ->
ScopeVar.Map.fold (fun _ e acc -> acc + 1 + size e) args 1
(* - Expression building helpers - *)
@ -794,7 +729,7 @@ let make_app e u pos =
(fun tf tx ->
match Marked.unmark tf with
| TArrow (tx', tr) ->
assert (unifiable tx.ty tx');
assert (Type.unifiable tx.ty tx');
(* wrong arg type *)
tr
| TAny -> tf
@ -818,50 +753,35 @@ let make_let_in x tau e1 e2 mpos =
let make_multiple_let_in xs taus e1s e2 mpos =
make_app (make_abs xs e2 taus mpos) e1s (pos e2)
let make_default_unboxed exceptions just cons =
let make_default_unboxed excepts just cons =
let rec bool_value = function
| ELit (LBool b), _ -> Some b
| EApp ((EOp (Unop (Log (l, _))), _), [e]), _
| EApp { f = EOp { op = Log (l, _); _ }, _; args = [e]; _ }, _
when l <> PosRecordIfTrueBool
(* we don't remove the log calls corresponding to source code
definitions !*) ->
bool_value e
| _ -> None
in
match exceptions, bool_value just, cons with
match excepts, bool_value just, cons with
| [], Some true, cons -> Marked.unmark cons
| exceptions, Some true, (EDefault ([], just, cons), _) ->
EDefault (exceptions, just, cons)
| excepts, Some true, (EDefault { excepts = []; just; cons }, _) ->
EDefault { excepts; just; cons }
| [except], Some false, _ -> Marked.unmark except
| exceptions, _, cons -> EDefault (exceptions, just, cons)
| excepts, _, cons -> EDefault { excepts; just; cons }
let make_default exceptions just cons =
Box.app2n just cons exceptions
@@ fun just cons exceptions -> make_default_unboxed exceptions just cons
let make_tuple el structname m0 =
let make_tuple el m0 =
match el with
| [] ->
etuple [] structname
(with_ty m0
(match structname with
| Some n -> TStruct n, mark_pos m0
| None -> TTuple [], mark_pos m0))
| [] -> etuple [] (with_ty m0 (TTuple [], mark_pos m0))
| el ->
let m =
fold_marks
(fun posl -> List.hd posl)
(fun ml ->
let pos = (List.hd ml).pos in
match structname with
| Some n -> TStruct n, pos
| None -> TTuple (List.map (fun t -> t.ty) ml), pos)
(fun ml -> TTuple (List.map (fun t -> t.ty) ml), (List.hd ml).pos)
(List.map (fun e -> Marked.get_mark e) el)
in
etuple el structname m
let make_struct fieldmap structname m =
let fields =
List.rev (StructFieldMap.fold (fun _ e acc -> e :: acc) fieldmap [])
in
make_tuple fields (Some structname) m
etuple el m

View File

@ -17,7 +17,7 @@
(** Functions handling the expressions of [shared_ast] *)
open Utils
open Catala_utils
open Definitions
(** {2 Boxed constructors} *)
@ -43,34 +43,10 @@ val subst :
('a, 't) gexpr list ->
('a, 't) gexpr
val etuple :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr list ->
StructName.t option ->
't ->
('a, 't) boxed_gexpr
val etuple : (lcalc, 't) boxed_gexpr list -> 't -> (lcalc, 't) boxed_gexpr
val etupleaccess :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
int ->
StructName.t option ->
typ list ->
't ->
('a, 't) boxed_gexpr
val einj :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
int ->
EnumName.t ->
typ list ->
't ->
('a, 't) boxed_gexpr
val ematch :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
('a, 't) boxed_gexpr list ->
EnumName.t ->
't ->
('a, 't) boxed_gexpr
(lcalc, 't) boxed_gexpr -> int -> int -> 't -> (lcalc, 't) boxed_gexpr
val earray : ('a any, 't) boxed_gexpr list -> 't -> ('a, 't) boxed_gexpr
val elit : 'a any glit -> 't -> ('a, 't) boxed_gexpr
@ -90,7 +66,7 @@ val eapp :
val eassert :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr -> 't -> ('a, 't) boxed_gexpr
val eop : operator -> 't -> (_ any, 't) boxed_gexpr
val eop : ('a any, 'k) operator -> typ list -> 't -> ('a, 't) boxed_gexpr
val edefault :
(([< desugared | scopelang | dcalc ] as 'a), 't) boxed_gexpr list ->
@ -125,34 +101,41 @@ val elocation :
val estruct :
StructName.t ->
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr StructFieldMap.t ->
('a any, 't) boxed_gexpr StructField.Map.t ->
't ->
('a, 't) boxed_gexpr
val edstructaccess :
(desugared, 't) boxed_gexpr ->
IdentName.t ->
StructName.t option ->
't ->
(desugared, 't) boxed_gexpr
val estructaccess :
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ->
StructFieldName.t ->
(([< scopelang | dcalc | lcalc ] as 'a), 't) boxed_gexpr ->
StructField.t ->
StructName.t ->
't ->
('a, 't) boxed_gexpr
val eenuminj :
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ->
val einj :
('a any, 't) boxed_gexpr ->
EnumConstructor.t ->
EnumName.t ->
't ->
('a, 't) boxed_gexpr
val ematchs :
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ->
val ematch :
('a any, 't) boxed_gexpr ->
EnumName.t ->
('a, 't) boxed_gexpr EnumConstructorMap.t ->
('a, 't) boxed_gexpr EnumConstructor.Map.t ->
't ->
('a, 't) boxed_gexpr
val escopecall :
ScopeName.t ->
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVarMap.t ->
(([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVar.Map.t ->
't ->
('a, 't) boxed_gexpr
@ -194,28 +177,25 @@ val untype : ('a, 'm mark) gexpr -> ('a, untyped mark) boxed_gexpr
(** {2 Traversal functions} *)
val map :
'ctx ->
f:('ctx -> ('a, 't1) gexpr -> ('a, 't2) boxed_gexpr) ->
f:(('a, 't1) gexpr -> ('a, 't2) boxed_gexpr) ->
(('a, 't1) naked_gexpr, 't2) Marked.t ->
('a, 't2) boxed_gexpr
(** Flat (non-recursive) mapping on expressions.
(** Shallow mapping on expressions (non recursive): applies the given function
to all sub-terms of the given expression, and rebuilds the node.
If you want to apply a map transform to an expression, you can save up
writing a painful match over all the cases of the AST. For instance, if you
want to remove all errors on empty, you can write
When applying a map transform to an expression, this avoids expliciting all
cases that remain unchanged. For instance, if you want to remove all errors
on empty, you can write
{[
let remove_error_empty =
let rec f () e =
let rec f e =
match Marked.unmark e with
| ErrorOnEmpty e1 -> Expr.map () f e1
| _ -> Expr.map () f e
| ErrorOnEmpty e1 -> Expr.map f e1
| _ -> Expr.map f e
in
f () e
]}
The first argument of map_expr is an optional context that you can carry
around during your map traversal. *)
f e
]} *)
val map_top_down :
f:(('a, 't1) gexpr -> (('a, 't1) naked_gexpr, 't2) Marked.t) ->
@ -231,7 +211,42 @@ val shallow_fold :
(('a, 't) gexpr -> 'acc -> 'acc) -> ('a, 't) gexpr -> 'acc -> 'acc
(** Applies a function on all sub-terms of the given expression. Does not
recurse, and doesn't open binders. Useful as helper for recursive calls
within traversal functions *)
within traversal functions. This can be used to compute free variables with
e.g.:
{[
let rec free_vars = function
| EVar v, _ -> Var.Set.singleton v
| EAbs { binder; _ }, _ ->
let vs, body = Bindlib.unmbind binder in
Array.fold_right Var.Set.remove vs (free_vars body)
| e ->
shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
]} *)
val map_gather :
acc:'acc ->
join:('acc -> 'acc -> 'acc) ->
f:(('a, 't1) gexpr -> 'acc * ('a, 't2) boxed_gexpr) ->
(('a, 't1) naked_gexpr, 't2) Marked.t ->
'acc * ('a, 't2) boxed_gexpr
(** Shallow mapping similar to [map], but additionally allows to gather an
accumulator bottom-up. [acc] is the accumulator value returned on terminal
nodes, and [join] is used to merge accumulators from the different sub-terms
of an expression. [acc] is assumed to be a neutral element for [join].
Typically used with a set of variables used in the rewrite:
{[
let rec rewrite e =
match Marked.unmark e with
| Specific_case ->
Var.Set.singleton x, some_rewrite_fun e
| _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:rewrite e
}]
See [Lcalc.closure_conversion] for a real-world example. *)
(** {2 Expression building helpers} *)
@ -289,21 +304,10 @@ val make_default :
- [<ex | false :- _>], when [ex] is a single exception, is rewritten as [ex] *)
val make_tuple :
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr list ->
StructName.t option ->
'm mark ->
('a, 'm mark) boxed_gexpr
(lcalc, 'm mark) boxed_gexpr list -> 'm mark -> (lcalc, 'm mark) boxed_gexpr
(** Builds a tuple; the mark argument is only used as witness and for position
when building 0-uples *)
val make_struct :
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr StructFieldMap.t ->
StructName.t ->
'm mark ->
('a, 'm mark) boxed_gexpr
(** Builds the tuple of values for the given struct with proper ordering,
assuming the structfieldmap contains the fields defined for structname *)
(** {2 Transformations} *)
val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr
@ -331,8 +335,6 @@ val compare : ('a, 't) gexpr -> ('a, 't) gexpr -> int
(** Standard comparison function, suitable for e.g. [Set.Make]. Ignores position
information *)
val equal_typ : typ -> typ -> bool
val compare_typ : typ -> typ -> int
val is_value : ('a any, 't) gexpr -> bool
val free_vars : ('a any, 't) gexpr -> ('a, 't) gexpr Var.Set.t
@ -363,10 +365,10 @@ module Box : sig
a separate argument. *)
val app1 :
('a, 't) boxed_gexpr ->
(('a, 't) gexpr -> ('a, 't) naked_gexpr) ->
't ->
('a, 't) boxed_gexpr
('a, 't1) boxed_gexpr ->
(('a, 't1) gexpr -> ('a, 't2) naked_gexpr) ->
't2 ->
('a, 't2) boxed_gexpr
val app2 :
('a, 't) boxed_gexpr ->

View File

@ -0,0 +1,582 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Definitions
include Definitions.Op
let name : type a k. (a, k) t -> string = function
| Not -> "o_not"
| Length -> "o_length"
| GetDay -> "o_getDay"
| GetMonth -> "o_getMonth"
| GetYear -> "o_getYear"
| FirstDayOfMonth -> "o_firstDayOfMonth"
| LastDayOfMonth -> "o_lastDayOfMonth"
| Log _ -> "o_log"
| Minus -> "o_minus"
| Minus_int -> "o_minus_int"
| Minus_rat -> "o_minus_rat"
| Minus_mon -> "o_minus_mon"
| Minus_dur -> "o_minus_dur"
| ToRat -> "o_torat"
| ToRat_int -> "o_torat_int"
| ToRat_mon -> "o_torat_mon"
| ToMoney -> "o_tomoney"
| ToMoney_rat -> "o_tomoney_rat"
| Round -> "o_round"
| Round_rat -> "o_round_rat"
| Round_mon -> "o_round_mon"
| And -> "o_and"
| Or -> "o_or"
| Xor -> "o_xor"
| Eq -> "o_eq"
| Map -> "o_map"
| Concat -> "o_concat"
| Filter -> "o_filter"
| Reduce -> "o_reduce"
| Add -> "o_add"
| Add_int_int -> "o_add_int_int"
| Add_rat_rat -> "o_add_rat_rat"
| Add_mon_mon -> "o_add_mon_mon"
| Add_dat_dur -> "o_add_dat_dur"
| Add_dur_dur -> "o_add_dur_dur"
| Sub -> "o_sub"
| Sub_int_int -> "o_sub_int_int"
| Sub_rat_rat -> "o_sub_rat_rat"
| Sub_mon_mon -> "o_sub_mon_mon"
| Sub_dat_dat -> "o_sub_dat_dat"
| Sub_dat_dur -> "o_sub_dat_dur"
| Sub_dur_dur -> "o_sub_dur_dur"
| Mult -> "o_mult"
| Mult_int_int -> "o_mult_int_int"
| Mult_rat_rat -> "o_mult_rat_rat"
| Mult_mon_rat -> "o_mult_mon_rat"
| Mult_dur_int -> "o_mult_dur_int"
| Div -> "o_div"
| Div_int_int -> "o_div_int_int"
| Div_rat_rat -> "o_div_rat_rat"
| Div_mon_mon -> "o_div_mon_mon"
| Div_mon_rat -> "o_div_mon_mon"
| Lt -> "o_lt"
| Lt_int_int -> "o_lt_int_int"
| Lt_rat_rat -> "o_lt_rat_rat"
| Lt_mon_mon -> "o_lt_mon_mon"
| Lt_dur_dur -> "o_lt_dur_dur"
| Lt_dat_dat -> "o_lt_dat_dat"
| Lte -> "o_lte"
| Lte_int_int -> "o_lte_int_int"
| Lte_rat_rat -> "o_lte_rat_rat"
| Lte_mon_mon -> "o_lte_mon_mon"
| Lte_dur_dur -> "o_lte_dur_dur"
| Lte_dat_dat -> "o_lte_dat_dat"
| Gt -> "o_gt"
| Gt_int_int -> "o_gt_int_int"
| Gt_rat_rat -> "o_gt_rat_rat"
| Gt_mon_mon -> "o_gt_mon_mon"
| Gt_dur_dur -> "o_gt_dur_dur"
| Gt_dat_dat -> "o_gt_dat_dat"
| Gte -> "o_gte"
| Gte_int_int -> "o_gte_int_int"
| Gte_rat_rat -> "o_gte_rat_rat"
| Gte_mon_mon -> "o_gte_mon_mon"
| Gte_dur_dur -> "o_gte_dur_dur"
| Gte_dat_dat -> "o_gte_dat_dat"
| Eq_int_int -> "o_eq_int_int"
| Eq_rat_rat -> "o_eq_rat_rat"
| Eq_mon_mon -> "o_eq_mon_mon"
| Eq_dur_dur -> "o_eq_dur_dur"
| Eq_dat_dat -> "o_eq_dat_dat"
| Fold -> "o_fold"
let compare_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> Type.compare (t1, Pos.no_pos) (t2, Pos.no_pos)
| BeginCall, BeginCall
| EndCall, EndCall
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
0
| VarDef _, _ -> -1
| _, VarDef _ -> 1
| BeginCall, _ -> -1
| _, BeginCall -> 1
| EndCall, _ -> -1
| _, EndCall -> 1
| PosRecordIfTrueBool, _ -> .
| _, PosRecordIfTrueBool -> .
let compare (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) =
match[@ocamlformat "disable"] t1, t2 with
| Log (l1, info1), Log (l2, info2) -> (
match compare_log_entries l1 l2 with
| 0 -> List.compare Uid.MarkedString.compare info1 info2
| n -> n)
| Not, Not
| Length, Length
| GetDay, GetDay
| GetMonth, GetMonth
| GetYear, GetYear
| FirstDayOfMonth, FirstDayOfMonth
| LastDayOfMonth, LastDayOfMonth
| Minus, Minus
| Minus_int, Minus_int
| Minus_rat, Minus_rat
| Minus_mon, Minus_mon
| Minus_dur, Minus_dur
| ToRat, ToRat
| ToRat_int, ToRat_int
| ToRat_mon, ToRat_mon
| ToMoney, ToMoney
| ToMoney_rat, ToMoney_rat
| Round, Round
| Round_rat, Round_rat
| Round_mon, Round_mon
| And, And
| Or, Or
| Xor, Xor
| Eq, Eq
| Map, Map
| Concat, Concat
| Filter, Filter
| Reduce, Reduce
| Add, Add
| Add_int_int, Add_int_int
| Add_rat_rat, Add_rat_rat
| Add_mon_mon, Add_mon_mon
| Add_dat_dur, Add_dat_dur
| Add_dur_dur, Add_dur_dur
| Sub, Sub
| Sub_int_int, Sub_int_int
| Sub_rat_rat, Sub_rat_rat
| Sub_mon_mon, Sub_mon_mon
| Sub_dat_dat, Sub_dat_dat
| Sub_dat_dur, Sub_dat_dur
| Sub_dur_dur, Sub_dur_dur
| Mult, Mult
| Mult_int_int, Mult_int_int
| Mult_rat_rat, Mult_rat_rat
| Mult_mon_rat, Mult_mon_rat
| Mult_dur_int, Mult_dur_int
| Div, Div
| Div_int_int, Div_int_int
| Div_rat_rat, Div_rat_rat
| Div_mon_mon, Div_mon_mon
| Div_mon_rat, Div_mon_rat
| Lt, Lt
| Lt_int_int, Lt_int_int
| Lt_rat_rat, Lt_rat_rat
| Lt_mon_mon, Lt_mon_mon
| Lt_dat_dat, Lt_dat_dat
| Lt_dur_dur, Lt_dur_dur
| Lte, Lte
| Lte_int_int, Lte_int_int
| Lte_rat_rat, Lte_rat_rat
| Lte_mon_mon, Lte_mon_mon
| Lte_dat_dat, Lte_dat_dat
| Lte_dur_dur, Lte_dur_dur
| Gt, Gt
| Gt_int_int, Gt_int_int
| Gt_rat_rat, Gt_rat_rat
| Gt_mon_mon, Gt_mon_mon
| Gt_dat_dat, Gt_dat_dat
| Gt_dur_dur, Gt_dur_dur
| Gte, Gte
| Gte_int_int, Gte_int_int
| Gte_rat_rat, Gte_rat_rat
| Gte_mon_mon, Gte_mon_mon
| Gte_dat_dat, Gte_dat_dat
| Gte_dur_dur, Gte_dur_dur
| Eq_int_int, Eq_int_int
| Eq_rat_rat, Eq_rat_rat
| Eq_mon_mon, Eq_mon_mon
| Eq_dat_dat, Eq_dat_dat
| Eq_dur_dur, Eq_dur_dur
| Fold, Fold -> 0
| Not, _ -> -1 | _, Not -> 1
| Length, _ -> -1 | _, Length -> 1
| GetDay, _ -> -1 | _, GetDay -> 1
| GetMonth, _ -> -1 | _, GetMonth -> 1
| GetYear, _ -> -1 | _, GetYear -> 1
| FirstDayOfMonth, _ -> -1 | _, FirstDayOfMonth -> 1
| LastDayOfMonth, _ -> -1 | _, LastDayOfMonth -> 1
| Log _, _ -> -1 | _, Log _ -> 1
| Minus, _ -> -1 | _, Minus -> 1
| Minus_int, _ -> -1 | _, Minus_int -> 1
| Minus_rat, _ -> -1 | _, Minus_rat -> 1
| Minus_mon, _ -> -1 | _, Minus_mon -> 1
| Minus_dur, _ -> -1 | _, Minus_dur -> 1
| ToRat, _ -> -1 | _, ToRat -> 1
| ToRat_int, _ -> -1 | _, ToRat_int -> 1
| ToRat_mon, _ -> -1 | _, ToRat_mon -> 1
| ToMoney, _ -> -1 | _, ToMoney -> 1
| ToMoney_rat, _ -> -1 | _, ToMoney_rat -> 1
| Round, _ -> -1 | _, Round -> 1
| Round_rat, _ -> -1 | _, Round_rat -> 1
| Round_mon, _ -> -1 | _, Round_mon -> 1
| And, _ -> -1 | _, And -> 1
| Or, _ -> -1 | _, Or -> 1
| Xor, _ -> -1 | _, Xor -> 1
| Eq, _ -> -1 | _, Eq -> 1
| Map, _ -> -1 | _, Map -> 1
| Concat, _ -> -1 | _, Concat -> 1
| Filter, _ -> -1 | _, Filter -> 1
| Reduce, _ -> -1 | _, Reduce -> 1
| Add, _ -> -1 | _, Add -> 1
| Add_int_int, _ -> -1 | _, Add_int_int -> 1
| Add_rat_rat, _ -> -1 | _, Add_rat_rat -> 1
| Add_mon_mon, _ -> -1 | _, Add_mon_mon -> 1
| Add_dat_dur, _ -> -1 | _, Add_dat_dur -> 1
| Add_dur_dur, _ -> -1 | _, Add_dur_dur -> 1
| Sub, _ -> -1 | _, Sub -> 1
| Sub_int_int, _ -> -1 | _, Sub_int_int -> 1
| Sub_rat_rat, _ -> -1 | _, Sub_rat_rat -> 1
| Sub_mon_mon, _ -> -1 | _, Sub_mon_mon -> 1
| Sub_dat_dat, _ -> -1 | _, Sub_dat_dat -> 1
| Sub_dat_dur, _ -> -1 | _, Sub_dat_dur -> 1
| Sub_dur_dur, _ -> -1 | _, Sub_dur_dur -> 1
| Mult, _ -> -1 | _, Mult -> 1
| Mult_int_int, _ -> -1 | _, Mult_int_int -> 1
| Mult_rat_rat, _ -> -1 | _, Mult_rat_rat -> 1
| Mult_mon_rat, _ -> -1 | _, Mult_mon_rat -> 1
| Mult_dur_int, _ -> -1 | _, Mult_dur_int -> 1
| Div, _ -> -1 | _, Div -> 1
| Div_int_int, _ -> -1 | _, Div_int_int -> 1
| Div_rat_rat, _ -> -1 | _, Div_rat_rat -> 1
| Div_mon_mon, _ -> -1 | _, Div_mon_mon -> 1
| Div_mon_rat, _ -> -1 | _, Div_mon_rat -> 1
| Lt, _ -> -1 | _, Lt -> 1
| Lt_int_int, _ -> -1 | _, Lt_int_int -> 1
| Lt_rat_rat, _ -> -1 | _, Lt_rat_rat -> 1
| Lt_mon_mon, _ -> -1 | _, Lt_mon_mon -> 1
| Lt_dat_dat, _ -> -1 | _, Lt_dat_dat -> 1
| Lt_dur_dur, _ -> -1 | _, Lt_dur_dur -> 1
| Lte, _ -> -1 | _, Lte -> 1
| Lte_int_int, _ -> -1 | _, Lte_int_int -> 1
| Lte_rat_rat, _ -> -1 | _, Lte_rat_rat -> 1
| Lte_mon_mon, _ -> -1 | _, Lte_mon_mon -> 1
| Lte_dat_dat, _ -> -1 | _, Lte_dat_dat -> 1
| Lte_dur_dur, _ -> -1 | _, Lte_dur_dur -> 1
| Gt, _ -> -1 | _, Gt -> 1
| Gt_int_int, _ -> -1 | _, Gt_int_int -> 1
| Gt_rat_rat, _ -> -1 | _, Gt_rat_rat -> 1
| Gt_mon_mon, _ -> -1 | _, Gt_mon_mon -> 1
| Gt_dat_dat, _ -> -1 | _, Gt_dat_dat -> 1
| Gt_dur_dur, _ -> -1 | _, Gt_dur_dur -> 1
| Gte, _ -> -1 | _, Gte -> 1
| Gte_int_int, _ -> -1 | _, Gte_int_int -> 1
| Gte_rat_rat, _ -> -1 | _, Gte_rat_rat -> 1
| Gte_mon_mon, _ -> -1 | _, Gte_mon_mon -> 1
| Gte_dat_dat, _ -> -1 | _, Gte_dat_dat -> 1
| Gte_dur_dur, _ -> -1 | _, Gte_dur_dur -> 1
| Eq_int_int, _ -> -1 | _, Eq_int_int -> 1
| Eq_rat_rat, _ -> -1 | _, Eq_rat_rat -> 1
| Eq_mon_mon, _ -> -1 | _, Eq_mon_mon -> 1
| Eq_dat_dat, _ -> -1 | _, Eq_dat_dat -> 1
| Eq_dur_dur, _ -> -1 | _, Eq_dur_dur -> 1
| Fold, _ | _, Fold -> .
let equal (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) = compare t1 t2 = 0
(* Classification of operators *)
let kind_dispatch :
type a b k.
polymorphic:((_, polymorphic) t -> b) ->
monomorphic:((_, monomorphic) t -> b) ->
?overloaded:((_, overloaded) t -> b) ->
?resolved:((_, resolved) t -> b) ->
(a, k) t ->
b =
fun ~polymorphic ~monomorphic ?(overloaded = fun _ -> assert false)
?(resolved = fun _ -> assert false) op ->
match op with
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
| Or | Xor ) as op ->
monomorphic op
| (Log _ | Length | Eq | Map | Concat | Filter | Reduce | Fold) as op ->
polymorphic op
| ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
| Gte ) as op ->
overloaded op
| ( Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int | ToRat_mon
| ToMoney_rat | Round_rat | Round_mon | Add_int_int | Add_rat_rat
| Add_mon_mon | Add_dat_dur | Add_dur_dur | Sub_int_int | Sub_rat_rat
| Sub_mon_mon | Sub_dat_dat | Sub_dat_dur | Sub_dur_dur | Mult_int_int
| Mult_rat_rat | Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat | Lte_mon_mon
| Lte_dat_dat | Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon
| Gt_dat_dat | Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon
| Gte_dat_dat | Gte_dur_dur | Eq_int_int | Eq_rat_rat | Eq_mon_mon
| Eq_dat_dat | Eq_dur_dur ) as op ->
resolved op
(* Glorified identity... allowed operators are the same in scopelang, dcalc,
lcalc *)
let translate :
type k.
([< scopelang | dcalc | lcalc ], k) t ->
([< scopelang | dcalc | lcalc ], k) t =
fun op ->
match op with
| Length -> Length
| Log (i, l) -> Log (i, l)
| Eq -> Eq
| Map -> Map
| Concat -> Concat
| Filter -> Filter
| Reduce -> Reduce
| Fold -> Fold
| Not -> Not
| GetDay -> GetDay
| GetMonth -> GetMonth
| GetYear -> GetYear
| FirstDayOfMonth -> FirstDayOfMonth
| LastDayOfMonth -> LastDayOfMonth
| And -> And
| Or -> Or
| Xor -> Xor
| Minus_int -> Minus_int
| Minus_rat -> Minus_rat
| Minus_mon -> Minus_mon
| Minus_dur -> Minus_dur
| ToRat_int -> ToRat_int
| ToRat_mon -> ToRat_mon
| ToMoney_rat -> ToMoney_rat
| Round_rat -> Round_rat
| Round_mon -> Round_mon
| Add_int_int -> Add_int_int
| Add_rat_rat -> Add_rat_rat
| Add_mon_mon -> Add_mon_mon
| Add_dat_dur -> Add_dat_dur
| Add_dur_dur -> Add_dur_dur
| Sub_int_int -> Sub_int_int
| Sub_rat_rat -> Sub_rat_rat
| Sub_mon_mon -> Sub_mon_mon
| Sub_dat_dat -> Sub_dat_dat
| Sub_dat_dur -> Sub_dat_dur
| Sub_dur_dur -> Sub_dur_dur
| Mult_int_int -> Mult_int_int
| Mult_rat_rat -> Mult_rat_rat
| Mult_mon_rat -> Mult_mon_rat
| Mult_dur_int -> Mult_dur_int
| Div_int_int -> Div_int_int
| Div_rat_rat -> Div_rat_rat
| Div_mon_mon -> Div_mon_mon
| Div_mon_rat -> Div_mon_rat
| Lt_int_int -> Lt_int_int
| Lt_rat_rat -> Lt_rat_rat
| Lt_mon_mon -> Lt_mon_mon
| Lt_dat_dat -> Lt_dat_dat
| Lt_dur_dur -> Lt_dur_dur
| Lte_int_int -> Lte_int_int
| Lte_rat_rat -> Lte_rat_rat
| Lte_mon_mon -> Lte_mon_mon
| Lte_dat_dat -> Lte_dat_dat
| Lte_dur_dur -> Lte_dur_dur
| Gt_int_int -> Gt_int_int
| Gt_rat_rat -> Gt_rat_rat
| Gt_mon_mon -> Gt_mon_mon
| Gt_dat_dat -> Gt_dat_dat
| Gt_dur_dur -> Gt_dur_dur
| Gte_int_int -> Gte_int_int
| Gte_rat_rat -> Gte_rat_rat
| Gte_mon_mon -> Gte_mon_mon
| Gte_dat_dat -> Gte_dat_dat
| Gte_dur_dur -> Gte_dur_dur
| Eq_int_int -> Eq_int_int
| Eq_rat_rat -> Eq_rat_rat
| Eq_mon_mon -> Eq_mon_mon
| Eq_dat_dat -> Eq_dat_dat
| Eq_dur_dur -> Eq_dur_dur
let monomorphic_type (op, pos) =
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
match op with
| Not -> TBool @-> TBool
| GetDay -> TDate @-> TInt
| GetMonth -> TDate @-> TInt
| GetYear -> TDate @-> TInt
| FirstDayOfMonth -> TDate @-> TDate
| LastDayOfMonth -> TDate @-> TDate
| And -> TBool @- TBool @-> TBool
| Or -> TBool @- TBool @-> TBool
| Xor -> TBool @- TBool @-> TBool
(** Rules for overloads definitions:
- the concrete operator, including its return type, is uniquely determined
by the type of the operands
- no resolved version of an operator should be the redefinition of another
one with an added conversion. For example, [int + rat -> rat] is not
acceptable (that would amount to implicit casts).
These two points can be generalised for binary operators as: when
considering an operator with type ['a -> 'b -> 'c], for any given two among
['a], ['b] and ['c], there should be a unique solution for the third. *)
let resolved_type (op, pos) =
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
match op with
| Minus_int -> TInt @-> TInt
| Minus_rat -> TRat @-> TRat
| Minus_mon -> TMoney @-> TMoney
| Minus_dur -> TDuration @-> TDuration
| ToRat_int -> TInt @-> TRat
| ToRat_mon -> TMoney @-> TRat
| ToMoney_rat -> TRat @-> TMoney
| Round_rat -> TRat @-> TRat
| Round_mon -> TMoney @-> TMoney
| Add_int_int -> TInt @- TInt @-> TInt
| Add_rat_rat -> TRat @- TRat @-> TRat
| Add_mon_mon -> TMoney @- TMoney @-> TMoney
| Add_dat_dur -> TDate @- TDuration @-> TDate
| Add_dur_dur -> TDuration @- TDuration @-> TDuration
| Sub_int_int -> TInt @- TInt @-> TInt
| Sub_rat_rat -> TRat @- TRat @-> TRat
| Sub_mon_mon -> TMoney @- TMoney @-> TMoney
| Sub_dat_dat -> TDate @- TDate @-> TDuration
| Sub_dat_dur -> TDate @- TDuration @-> TDuration
| Sub_dur_dur -> TDuration @- TDuration @-> TDuration
| Mult_int_int -> TInt @- TInt @-> TInt
| Mult_rat_rat -> TRat @- TRat @-> TRat
| Mult_mon_rat -> TMoney @- TRat @-> TMoney
| Mult_dur_int -> TDuration @- TInt @-> TDuration
| Div_int_int -> TInt @- TInt @-> TRat
| Div_rat_rat -> TRat @- TRat @-> TRat
| Div_mon_mon -> TMoney @- TMoney @-> TRat
| Div_mon_rat -> TMoney @- TRat @-> TMoney
| Lt_int_int -> TInt @- TInt @-> TBool
| Lt_rat_rat -> TRat @- TRat @-> TBool
| Lt_mon_mon -> TMoney @- TMoney @-> TBool
| Lt_dat_dat -> TDate @- TDate @-> TBool
| Lt_dur_dur -> TDuration @- TDuration @-> TBool
| Lte_int_int -> TInt @- TInt @-> TBool
| Lte_rat_rat -> TRat @- TRat @-> TBool
| Lte_mon_mon -> TMoney @- TMoney @-> TBool
| Lte_dat_dat -> TDate @- TDate @-> TBool
| Lte_dur_dur -> TDuration @- TDuration @-> TBool
| Gt_int_int -> TInt @- TInt @-> TBool
| Gt_rat_rat -> TRat @- TRat @-> TBool
| Gt_mon_mon -> TMoney @- TMoney @-> TBool
| Gt_dat_dat -> TDate @- TDate @-> TBool
| Gt_dur_dur -> TDuration @- TDuration @-> TBool
| Gte_int_int -> TInt @- TInt @-> TBool
| Gte_rat_rat -> TRat @- TRat @-> TBool
| Gte_mon_mon -> TMoney @- TMoney @-> TBool
| Gte_dat_dat -> TDate @- TDate @-> TBool
| Gte_dur_dur -> TDuration @- TDuration @-> TBool
| Eq_int_int -> TInt @- TInt @-> TBool
| Eq_rat_rat -> TRat @- TRat @-> TBool
| Eq_mon_mon -> TMoney @- TMoney @-> TBool
| Eq_dat_dat -> TDate @- TDate @-> TBool
| Eq_dur_dur -> TDuration @- TDuration @-> TBool
let resolve_overload_aux (op : ('a, overloaded) t) (operands : typ_lit list) :
('b, resolved) t * [ `Straight | `Reversed ] =
match op, operands with
| Minus, [TInt] -> Minus_int, `Straight
| Minus, [TRat] -> Minus_rat, `Straight
| Minus, [TMoney] -> Minus_mon, `Straight
| Minus, [TDuration] -> Minus_dur, `Straight
| ToRat, [TInt] -> ToRat_int, `Straight
| ToRat, [TMoney] -> ToRat_mon, `Straight
| ToMoney, [TRat] -> ToMoney_rat, `Straight
| Round, [TRat] -> Round_rat, `Straight
| Round, [TMoney] -> Round_mon, `Straight
| Add, [TInt; TInt] -> Add_int_int, `Straight
| Add, [TRat; TRat] -> Add_rat_rat, `Straight
| Add, [TMoney; TMoney] -> Add_mon_mon, `Straight
| Add, [TDuration; TDuration] -> Add_dur_dur, `Straight
| Add, [TDate; TDuration] -> Add_dat_dur, `Straight
| Add, [TDuration; TDate] -> Add_dat_dur, `Reversed
| Sub, [TInt; TInt] -> Sub_int_int, `Straight
| Sub, [TRat; TRat] -> Sub_rat_rat, `Straight
| Sub, [TMoney; TMoney] -> Sub_mon_mon, `Straight
| Sub, [TDuration; TDuration] -> Sub_dur_dur, `Straight
| Sub, [TDate; TDate] -> Sub_dat_dat, `Straight
| Sub, [TDate; TDuration] -> Sub_dat_dur, `Straight
| Mult, [TInt; TInt] -> Mult_int_int, `Straight
| Mult, [TRat; TRat] -> Mult_rat_rat, `Straight
| Mult, [TMoney; TRat] -> Mult_mon_rat, `Straight
| Mult, [TRat; TMoney] -> Mult_mon_rat, `Reversed
| Mult, [TDuration; TInt] -> Mult_dur_int, `Straight
| Mult, [TInt; TDuration] -> Mult_dur_int, `Reversed
| Div, [TInt; TInt] -> Div_int_int, `Straight
| Div, [TRat; TRat] -> Div_rat_rat, `Straight
| Div, [TMoney; TMoney] -> Div_mon_mon, `Straight
| Div, [TMoney; TRat] -> Div_mon_rat, `Straight
| Lt, [TInt; TInt] -> Lt_int_int, `Straight
| Lt, [TRat; TRat] -> Lt_rat_rat, `Straight
| Lt, [TMoney; TMoney] -> Lt_mon_mon, `Straight
| Lt, [TDuration; TDuration] -> Lt_dur_dur, `Straight
| Lt, [TDate; TDate] -> Lt_dat_dat, `Straight
| Lte, [TInt; TInt] -> Lte_int_int, `Straight
| Lte, [TRat; TRat] -> Lte_rat_rat, `Straight
| Lte, [TMoney; TMoney] -> Lte_mon_mon, `Straight
| Lte, [TDuration; TDuration] -> Lte_dur_dur, `Straight
| Lte, [TDate; TDate] -> Lte_dat_dat, `Straight
| Gt, [TInt; TInt] -> Gt_int_int, `Straight
| Gt, [TRat; TRat] -> Gt_rat_rat, `Straight
| Gt, [TMoney; TMoney] -> Gt_mon_mon, `Straight
| Gt, [TDuration; TDuration] -> Gt_dur_dur, `Straight
| Gt, [TDate; TDate] -> Gt_dat_dat, `Straight
| Gte, [TInt; TInt] -> Gte_int_int, `Straight
| Gte, [TRat; TRat] -> Gte_rat_rat, `Straight
| Gte, [TMoney; TMoney] -> Gte_mon_mon, `Straight
| Gte, [TDuration; TDuration] -> Gte_dur_dur, `Straight
| Gte, [TDate; TDate] -> Gte_dat_dat, `Straight
| ( ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
| Gte ),
_ ) ->
raise Not_found
let resolve_overload
ctx
(op : ('a, overloaded) t Marked.pos)
(operands : typ list) : ('b, resolved) t * [ `Straight | `Reversed ] =
try
let operands =
List.map
(fun t ->
match Marked.unmark t with TLit tl -> tl | _ -> raise Not_found)
operands
in
resolve_overload_aux (Marked.unmark op) operands
with Not_found ->
Errors.raise_multispanned_error
((None, Marked.get_mark op)
:: List.map
(fun ty ->
( Some
(Format.asprintf "Type %a coming from expression:"
(Print.typ ctx) ty),
Marked.get_mark ty ))
operands)
"I don't know how to apply operator %a on types %a" Print.operator
(Marked.unmark op)
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf " and@ ")
(Print.typ ctx))
operands
let overload_type ctx (op : ('a, overloaded) t Marked.pos) (operands : typ list)
: typ =
let rop = fst (resolve_overload ctx op operands) in
resolved_type (Marked.same_mark_as rop op)

View File

@ -0,0 +1,85 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** {1 Catala operator utilities} *)
(** Resolving operators from the surface syntax proceeds in three steps:
- During desugaring, the operators may remain untyped (with [TAny]) or, if
they have an explicit type suffix (e.g. the [$] for "money" in [+$]),
their operands types are already explicited in the [EOp] expression node.
- {!modules:Shared_ast.Typing} will then enforce these constraints in
addition to the known built-in type for each operator (e.g.
[Eq: 'a -> 'a -> 'a] isn't encoded in the first-order AST types).
- Finally, during {!modules:Scopelang.From_desugared}, these types are
leveraged to resolve the overloaded operators to their concrete,
monomorphic counterparts
*)
open Catala_utils
open Definitions
include module type of Definitions.Op
val equal : ('a1, 'k1) t -> ('a2, 'k2) t -> bool
val compare : ('a1, 'k1) t -> ('a2, 'k2) t -> int
val name : ('a, 'k) t -> string
(** Returns the operator name as a valid ident starting with a lowercase
character. This is different from Print.operator which returns operator
symbols, e.g. [+$]. *)
val kind_dispatch :
polymorphic:((_ any, polymorphic) t -> 'b) ->
monomorphic:((_ any, monomorphic) t -> 'b) ->
?overloaded:((desugared, overloaded) t -> 'b) ->
?resolved:(([< scopelang | dcalc | lcalc ], resolved) t -> 'b) ->
('a, 'k) t ->
'b
(** Calls one of the supplied functions depending on the kind of the operator *)
val translate :
([< scopelang | dcalc | lcalc ], 'k) t ->
([< scopelang | dcalc | lcalc ], 'k) t
(** An identity function that allows translating an operator between different
passes that don't change operator types *)
(** {2 Getting the types of operators} *)
val monomorphic_type : ('a any, monomorphic) t Marked.pos -> typ
val resolved_type :
([< scopelang | dcalc | lcalc ], resolved) t Marked.pos -> typ
val overload_type :
decl_ctx -> (desugared, overloaded) t Marked.pos -> typ list -> typ
(** The type for typing overloads is different since the types of the operands
are required in advance.
@raise a detailed user error if no matching operator can be found *)
(** Polymorphic operators are typed directly within [Typing], since their types
may contain type variables that can't be expressed outside of it*)
(** {2 Overload handling} *)
val resolve_overload :
decl_ctx ->
(desugared, overloaded) t Marked.pos ->
typ list ->
([< scopelang | dcalc | lcalc ], resolved) t * [ `Straight | `Reversed ]
(** Some overloads are sugar for an operation with reversed operands, e.g.
[TRat * TMoney] is using [mult_mon_rat]. [`Reversed] is returned to signify
this case. *)

View File

@ -14,8 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open String_common
open Catala_utils
open Definitions
let typ_needs_parens (ty : typ) : bool =
@ -26,27 +25,28 @@ let uid_list (fmt : Format.formatter) (infos : Uid.MarkedString.info list) :
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.pp_print_char fmt '.')
(fun fmt info ->
Utils.Cli.format_with_style
(if begins_with_uppercase (Marked.unmark info) then [ANSITerminal.red]
Cli.format_with_style
(if String.begins_with_uppercase (Marked.unmark info) then
[ANSITerminal.red]
else [])
fmt
(Utils.Uid.MarkedString.to_string info))
(Uid.MarkedString.to_string info))
fmt infos
let keyword (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.red] fmt s
Cli.format_with_style [ANSITerminal.red] fmt s
let base_type (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.yellow] fmt s
Cli.format_with_style [ANSITerminal.yellow] fmt s
let punctuation (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.cyan] fmt s
Cli.format_with_style [ANSITerminal.cyan] fmt s
let operator (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.green] fmt s
let op_style (fmt : Format.formatter) (s : string) : unit =
Cli.format_with_style [ANSITerminal.green] fmt s
let lit_style (fmt : Format.formatter) (s : string) : unit =
Utils.Cli.format_with_style [ANSITerminal.yellow] fmt s
Cli.format_with_style [ANSITerminal.yellow] fmt s
let tlit (fmt : Format.formatter) (l : typ_lit) : unit =
base_type fmt
@ -68,7 +68,7 @@ let location (type a) (fmt : Format.formatter) (l : a glocation) : unit =
ScopeVar.format_t (Marked.unmark subvar)
let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit =
Utils.Cli.format_with_style [ANSITerminal.magenta] fmt
Cli.format_with_style [ANSITerminal.magenta] fmt
(Format.asprintf "%a" EnumConstructor.format_t c)
let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
@ -81,7 +81,7 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
| TTuple ts ->
Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " operator "*")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " op_style "*")
typ)
ts
| TStruct s -> (
@ -94,9 +94,9 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
(fun fmt (field, mty) ->
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
StructFieldName.format_t field punctuation "\"" punctuation ":"
typ mty))
(StructMap.find s ctx.ctx_structs)
StructField.format_t field punctuation "\"" punctuation ":" typ
mty))
(StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs))
punctuation "}")
| TEnum e -> (
match ctx with
@ -109,11 +109,11 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
(fun fmt (case, mty) ->
Format.fprintf fmt "%a%a@ %a" enum_constructor case punctuation ":"
typ mty))
(EnumMap.find e ctx.ctx_enums)
(EnumConstructor.Map.bindings (EnumName.Map.find e ctx.ctx_enums))
punctuation "]")
| TOption t -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" typ t
| TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 operator ""
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 op_style ""
typ t2
| TArray t1 ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "collection" typ t1
@ -127,9 +127,9 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
| LUnit -> lit_style fmt "()"
| LRat i ->
lit_style fmt
(Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i)
(Runtime.decimal_to_string ~max_prec_digits:!Cli.max_prec_digits i)
| LMoney e -> (
match !Utils.Cli.locale_lang with
match !Cli.locale_lang with
| En -> lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e))
| Fr -> lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e))
| Pl -> lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e))
@ -137,72 +137,112 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
| LDate d -> lit_style fmt (Runtime.date_to_string d)
| LDuration d -> lit_style fmt (Runtime.duration_to_string d)
let op_kind (fmt : Format.formatter) (k : op_kind) =
Format.fprintf fmt "%s"
(match k with
| KInt -> ""
| KRat -> "."
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let binop (fmt : Format.formatter) (op : binop) : unit =
operator fmt
(match op with
| Add k -> Format.asprintf "+%a" op_kind k
| Sub k -> Format.asprintf "-%a" op_kind k
| Mult k -> Format.asprintf "*%a" op_kind k
| Div k -> Format.asprintf "/%a" op_kind k
| And -> "&&"
| Or -> "||"
| Xor -> "xor"
| Eq -> "="
| Neq -> "!="
| Lt k -> Format.asprintf "%s%a" "<" op_kind k
| Lte k -> Format.asprintf "%s%a" "<=" op_kind k
| Gt k -> Format.asprintf "%s%a" ">" op_kind k
| Gte k -> Format.asprintf "%s%a" ">=" op_kind k
| Concat -> "++"
| Map -> "map"
| Filter -> "filter")
let ternop (fmt : Format.formatter) (op : ternop) : unit =
match op with Fold -> keyword fmt "fold"
let log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
Format.fprintf fmt "@<2>%a"
(fun fmt -> function
| VarDef _ -> Utils.Cli.format_with_style [ANSITerminal.blue] fmt ""
| BeginCall -> Utils.Cli.format_with_style [ANSITerminal.yellow] fmt ""
| EndCall -> Utils.Cli.format_with_style [ANSITerminal.yellow] fmt ""
| VarDef _ -> Cli.format_with_style [ANSITerminal.blue] fmt ""
| BeginCall -> Cli.format_with_style [ANSITerminal.yellow] fmt ""
| EndCall -> Cli.format_with_style [ANSITerminal.yellow] fmt ""
| PosRecordIfTrueBool ->
Utils.Cli.format_with_style [ANSITerminal.green] fmt "")
Cli.format_with_style [ANSITerminal.green] fmt "")
entry
let unop (fmt : Format.formatter) (op : unop) : unit =
let operator_to_string : type a k. (a, k) Op.t -> string = function
| Not -> "~"
| Length -> "length"
| GetDay -> "get_day"
| GetMonth -> "get_month"
| GetYear -> "get_year"
| FirstDayOfMonth -> "first_day_of_month"
| LastDayOfMonth -> "last_day_of_month"
| ToRat -> "to_rat"
| ToRat_int -> "to_rat_int"
| ToRat_mon -> "to_rat_mon"
| ToMoney -> "to_mon"
| ToMoney_rat -> "to_mon_rat"
| Round -> "round"
| Round_rat -> "round_rat"
| Round_mon -> "round_mon"
| Log _ -> "Log"
| Minus -> "-"
| Minus_int -> "-!"
| Minus_rat -> "-."
| Minus_mon -> "-$"
| Minus_dur -> "-^"
| And -> "&&"
| Or -> "||"
| Xor -> "xor"
| Eq -> "="
| Map -> "map"
| Reduce -> "reduce"
| Concat -> "++"
| Filter -> "filter"
| Add -> "+"
| Add_int_int -> "+!"
| Add_rat_rat -> "+."
| Add_mon_mon -> "+$"
| Add_dat_dur -> "+@"
| Add_dur_dur -> "+^"
| Sub -> "-"
| Sub_int_int -> "-!"
| Sub_rat_rat -> "-."
| Sub_mon_mon -> "-$"
| Sub_dat_dat -> "-@"
| Sub_dat_dur -> "-@^"
| Sub_dur_dur -> "-^"
| Mult -> "*"
| Mult_int_int -> "*!"
| Mult_rat_rat -> "*."
| Mult_mon_rat -> "*$"
| Mult_dur_int -> "*^"
| Div -> "/"
| Div_int_int -> "/!"
| Div_rat_rat -> "/."
| Div_mon_mon -> "/$"
| Div_mon_rat -> "/$."
| Lt -> "<"
| Lt_int_int -> "<!"
| Lt_rat_rat -> "<."
| Lt_mon_mon -> "<$"
| Lt_dur_dur -> "<^"
| Lt_dat_dat -> "<@"
| Lte -> "<="
| Lte_int_int -> "<=!"
| Lte_rat_rat -> "<=."
| Lte_mon_mon -> "<=$"
| Lte_dur_dur -> "<=^"
| Lte_dat_dat -> "<=@"
| Gt -> ">"
| Gt_int_int -> ">!"
| Gt_rat_rat -> ">."
| Gt_mon_mon -> ">$"
| Gt_dur_dur -> ">^"
| Gt_dat_dat -> ">@"
| Gte -> ">="
| Gte_int_int -> ">=!"
| Gte_rat_rat -> ">=."
| Gte_mon_mon -> ">=$"
| Gte_dur_dur -> ">=^"
| Gte_dat_dat -> ">=@"
| Eq_int_int -> "=!"
| Eq_rat_rat -> "=."
| Eq_mon_mon -> "=$"
| Eq_dur_dur -> "=^"
| Eq_dat_dat -> "=@"
| Fold -> "fold"
let operator (type k) (fmt : Format.formatter) (op : ('a, k) operator) : unit =
match op with
| Minus _ -> Format.pp_print_string fmt "-"
| Not -> Format.pp_print_string fmt "~"
| Log (entry, infos) ->
Format.fprintf fmt "log@[<hov 2>[%a|%a]@]" log_entry entry
Format.fprintf fmt "%a@[<hov 2>[%a|%a]@]" op_style "log" log_entry entry
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ".")
(fun fmt info -> Utils.Uid.MarkedString.format_info fmt info))
(fun fmt info -> Uid.MarkedString.format fmt info))
infos
| Length -> Format.pp_print_string fmt "length"
| IntToRat -> Format.pp_print_string fmt "int_to_rat"
| MoneyToRat -> Format.pp_print_string fmt "money_to_rat"
| RatToMoney -> Format.pp_print_string fmt "rat_to_money"
| GetDay -> Format.pp_print_string fmt "get_day"
| GetMonth -> Format.pp_print_string fmt "get_month"
| GetYear -> Format.pp_print_string fmt "get_year"
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
| RoundMoney -> Format.pp_print_string fmt "round_money"
| RoundDecimal -> Format.pp_print_string fmt "round_decimal"
| op -> Format.fprintf fmt "%a" op_style (operator_to_string op)
let except (fmt : Format.formatter) (exn : except) : unit =
operator fmt
op_style fmt
(match exn with
| EmptyError -> "EmptyError"
| ConflictError -> "ConflictError"
@ -215,7 +255,7 @@ let var_debug fmt v =
let var fmt v = Format.pp_print_string fmt (Bindlib.name_of v)
let needs_parens (type a) (e : (a, _) gexpr) : bool =
match Marked.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false
match Marked.unmark e with EAbs _ | EStruct _ -> true | _ -> false
let rec expr_aux :
type a.
@ -228,6 +268,7 @@ let rec expr_aux :
fun ?(debug = false) ctx bnd_ctx fmt e ->
let exprb bnd_ctx e = expr_aux ~debug ctx bnd_ctx e in
let expr e = exprb bnd_ctx e in
let var = if debug then var_debug else var in
let with_parens fmt e =
if needs_parens e then (
punctuation fmt "(";
@ -236,79 +277,28 @@ let rec expr_aux :
else expr fmt e
in
match Marked.unmark e with
| EVar v -> if debug then var_debug fmt v else var fmt v
| ETuple (es, None) ->
| EVar v -> var fmt v
| ETuple es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "("
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> expr fmt e))
es punctuation ")"
| ETuple (es, Some s) -> (
match ctx with
| None -> expr fmt (Marked.same_mark_as (ETuple (es, None)) e)
| Some ctx ->
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]" StructName.format_t
s punctuation "{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
(fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
StructFieldName.format_t struct_field punctuation "\""
punctuation "=" expr e))
(List.combine es (List.map fst (StructMap.find s ctx.ctx_structs)))
punctuation "}")
| EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" punctuation "["
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> expr fmt e))
es punctuation "]"
| ETupleAccess (e1, n, s, _ts) -> (
match s, ctx with
| None, _ | _, None ->
expr fmt e1;
punctuation fmt ".";
Format.pp_print_int fmt n
| Some s, Some ctx ->
expr fmt e1;
operator fmt ".";
punctuation fmt "\"";
StructFieldName.format_t fmt
(fst (List.nth (StructMap.find s ctx.ctx_structs) n));
punctuation fmt "\"")
| EInj (e, n, en, _ts) -> (
match ctx with
| None ->
Format.fprintf fmt "@[<hov 2>%a[%d]@ %a@]" EnumName.format_t en n expr e
| Some ctx ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" enum_constructor
(fst (List.nth (EnumMap.find en ctx.ctx_enums) n))
expr e)
| EMatch (e, es, e_name) -> (
match ctx with
| None ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
expr e keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, i) ->
Format.fprintf fmt "@[<hov 2>%a %a[%d]%a@ %a@]" punctuation "|"
EnumName.format_t e_name i punctuation ":" expr e))
(List.mapi (fun i e -> e, i) es)
| Some ctx ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
expr e keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, c) ->
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" punctuation "|"
enum_constructor c punctuation ":" expr e))
(List.combine es (List.map fst (EnumMap.find e_name ctx.ctx_enums))))
| ETupleAccess { e; index; _ } ->
expr fmt e;
punctuation fmt ".";
Format.pp_print_int fmt index
| ELit l -> lit fmt l
| EApp ((EAbs (binder, taus), _), args) ->
| EApp { f = EAbs { binder; tys }, _; args } ->
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
let expr = exprb bnd_ctx in
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) taus in
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) tys in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
Format.fprintf fmt "%a%a"
(Format.pp_print_list
@ -318,10 +308,10 @@ let rec expr_aux :
"let" var x punctuation ":" (typ ctx) tau punctuation "=" expr arg
keyword "in"))
xs_tau_arg expr body
| EAbs (binder, taus) ->
| EAbs { binder; tys } ->
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
let expr = exprb bnd_ctx in
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) taus in
let xs_tau = List.mapi (fun i tau -> xs.(i), tau) tys in
Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" punctuation "λ"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
@ -329,29 +319,28 @@ let rec expr_aux :
Format.fprintf fmt "%a%a%a %a%a" punctuation "(" var x punctuation
":" (typ ctx) tau punctuation ")"))
xs_tau punctuation "" expr body
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" binop op with_parens arg1
| EApp { f = EOp { op = (Map | Filter) as op; _ }, _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" operator op with_parens arg1
with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 binop op
| EApp { f = EOp { op; _ }, _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 operator op
with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> expr fmt arg1
| EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" unop op with_parens arg1
| EApp (f, args) ->
| EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } when not debug ->
expr fmt arg1
| EApp { f = EOp { op; _ }, _; args = [arg1] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" operator op with_parens arg1
| EApp { f; args } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" expr f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
with_parens)
args
| EIfThenElse (e1, e2, e3) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr e1
keyword "then" expr e2 keyword "else" expr e3
| EOp (Ternop op) -> ternop fmt op
| EOp (Binop op) -> binop fmt op
| EOp (Unop op) -> unop fmt op
| EDefault (exceptions, just, cons) ->
if List.length exceptions = 0 then
| EIfThenElse { cond; etrue; efalse } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr
cond keyword "then" expr etrue keyword "else" expr efalse
| EOp { op; _ } -> operator fmt op
| EDefault { excepts; just; cons } ->
if List.length excepts = 0 then
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" punctuation "" expr just
punctuation "" expr cons punctuation ""
else
@ -359,45 +348,48 @@ let rec expr_aux :
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ",")
expr)
exceptions punctuation "|" expr just punctuation "" expr cons
punctuation ""
| ErrorOnEmpty e' ->
Format.fprintf fmt "%a@ %a" operator "error_empty" with_parens e'
excepts punctuation "|" expr just punctuation "" expr cons punctuation
""
| EErrorOnEmpty e' ->
Format.fprintf fmt "%a@ %a" op_style "error_empty" with_parens e'
| EAssert e' ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" keyword "assert" punctuation "("
expr e' punctuation ")"
| ECatch (e1, exn, e2) ->
| ECatch { body; exn; handler } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" keyword "try"
with_parens e1 keyword "with" except exn with_parens e2
with_parens body keyword "with" except exn with_parens handler
| ERaise exn ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" keyword "raise" except exn
| ELocation loc -> location fmt loc
| EStruct (name, fields) ->
Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
| EDStructAccess { e; field; _ } ->
Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\""
IdentName.format_t field punctuation "\""
| EStruct { name; fields } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
punctuation "{"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
(fun fmt (field_name, field_expr) ->
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
StructFieldName.format_t field_name punctuation "\"" punctuation
"=" expr field_expr))
(StructFieldMap.bindings fields)
StructField.format_t field_name punctuation "\"" punctuation "="
expr field_expr))
(StructField.Map.bindings fields)
punctuation "}"
| EStructAccess (e1, field, _) ->
Format.fprintf fmt "%a%a%a%a%a" expr e1 punctuation "." punctuation "\""
StructFieldName.format_t field punctuation "\""
| EEnumInj (e1, cons, _) ->
Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e1
| EMatchS (e1, _, cases) ->
| EStructAccess { e; field; _ } ->
Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\""
StructField.format_t field punctuation "\""
| EInj { e; cons; _ } ->
Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e
| EMatch { e; cases; _ } ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
expr e1 keyword "with"
expr e keyword "with"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (cons_name, case_expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" punctuation "|"
enum_constructor cons_name punctuation "" expr case_expr))
(EnumConstructorMap.bindings cases)
| EScopeCall (scope, fields) ->
(EnumConstructor.Map.bindings cases)
| EScopeCall { scope; args } ->
Format.pp_open_hovbox fmt 2;
ScopeName.format_t fmt scope;
Format.pp_print_space fmt ();
@ -411,7 +403,7 @@ let rec expr_aux :
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" ScopeVar.format_t
field_name punctuation "\"" punctuation "=" expr field_expr)
fmt
(ScopeVarMap.bindings fields);
(ScopeVar.Map.bindings args);
Format.pp_close_box fmt ();
punctuation fmt "}";
Format.pp_close_box fmt ()

View File

@ -16,7 +16,7 @@
(** Printing functions for the default calculus AST *)
open Utils
open Catala_utils
open Definitions
(** {1 Common syntax highlighting helpers}*)
@ -24,7 +24,7 @@ open Definitions
val base_type : Format.formatter -> string -> unit
val keyword : Format.formatter -> string -> unit
val punctuation : Format.formatter -> string -> unit
val operator : Format.formatter -> string -> unit
val op_style : Format.formatter -> string -> unit
val lit_style : Format.formatter -> string -> unit
(** {1 Formatters} *)
@ -35,13 +35,11 @@ val tlit : Format.formatter -> typ_lit -> unit
val location : Format.formatter -> 'a glocation -> unit
val typ : decl_ctx -> Format.formatter -> typ -> unit
val lit : Format.formatter -> 'a glit -> unit
val op_kind : Format.formatter -> op_kind -> unit
val binop : Format.formatter -> binop -> unit
val ternop : Format.formatter -> ternop -> unit
val operator : Format.formatter -> ('a any, 'k) operator -> unit
val log_entry : Format.formatter -> log_entry -> unit
val unop : Format.formatter -> unop -> unit
val except : Format.formatter -> except -> unit
val var : Format.formatter -> 'e Var.t -> unit
val var_debug : Format.formatter -> 'e Var.t -> unit
val expr :
?debug:bool (** [true] for debug printing *) ->

View File

@ -22,6 +22,18 @@ let map_exprs ~f ~varf { scopes; decl_ctx } =
(fun scopes -> { scopes; decl_ctx })
(Scope.map_exprs ~f ~varf scopes)
let get_scope_body { scopes; _ } scope =
match
Scope.fold_left ~init:None
~f:(fun acc scope_def _ ->
if ScopeName.equal scope_def.scope_name scope then
Some scope_def.scope_body
else acc)
scopes
with
| None -> raise Not_found
| Some body -> body
let untype : 'm. ('a, 'm mark) gexpr program -> ('a, untyped mark) gexpr program
=
fun prg -> Bindlib.unbox (map_exprs ~f:Expr.untype ~varf:Var.translate prg)

View File

@ -25,6 +25,9 @@ val map_exprs :
'expr1 program ->
'expr2 program Bindlib.box
val get_scope_body :
(([< dcalc | lcalc ], _) gexpr as 'e) program -> ScopeName.t -> 'e scope_body
val untype :
(([< dcalc | lcalc ] as 'a), 'm mark) gexpr program ->
('a, untyped mark) gexpr program

View File

@ -15,7 +15,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Utils
open Catala_utils
open Definitions
let rec fold_left_lets ~f ~init scope_body_expr =
@ -106,7 +106,7 @@ let rec get_body_expr_mark = function
get_body_expr_mark e
| Result e ->
let m = Marked.get_mark e in
Expr.with_ty m (Utils.Marked.mark (Expr.mark_pos m) TAny)
Expr.with_ty m (Marked.mark (Expr.mark_pos m) TAny)
let get_body_mark scope_body =
let _, e = Bindlib.unbind scope_body.scope_body_expr in

View File

@ -17,7 +17,7 @@
(** Functions handling the scope structures of [shared_ast] *)
open Utils
open Catala_utils
open Definitions
(** {2 Traversal functions} *)

View File

@ -16,6 +16,8 @@
include Definitions
module Var = Var
module Type = Type
module Operator = Operator
module Expr = Expr
module Scope = Scope
module Program = Program

View File

@ -0,0 +1,87 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Definitions
type t = typ
let equal_tlit l1 l2 = l1 = l2
let compare_tlit l1 l2 = Stdlib.compare l1 l2
let rec equal ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> equal_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> equal t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> equal t1 t2 && equal t1' t2'
| TArray t1, TArray t2 -> equal t1 t2
| TAny, TAny -> true
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
| TArray _ | TAny ),
_ ) ->
false
and equal_list tys1 tys2 =
try List.for_all2 equal tys1 tys2 with Invalid_argument _ -> false
(* Similar to [equal], but allows TAny holes *)
let rec unifiable ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TAny, _ | _, TAny -> true
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> unifiable t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
| TArray t1, TArray t2 -> unifiable t1 t2
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
_ ) ->
false
and unifiable_list tys1 tys2 =
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
let rec compare ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> compare_tlit l1 l2
| TTuple tys1, TTuple tys2 -> List.compare compare tys1 tys2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TOption t1, TOption t2 -> compare t1 t2
| TArrow (a1, b1), TArrow (a2, b2) -> (
match compare a1 a2 with 0 -> compare b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TTuple _, _ -> -1
| _, TTuple _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TOption _, _ -> -1
| _, TOption _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t

Some files were not shown because too many files have changed in this diff Show More