Merge branch 'master' into input_context_subscopes

This commit is contained in:
Denis Merigoux 2024-07-19 19:02:18 +02:00
commit 48498ef466
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
149 changed files with 4671 additions and 4178 deletions

View File

@ -85,15 +85,30 @@ jobs:
runs-on: self-hosted
container:
image: ${{ needs.build.outputs.image }}
options: --user ocaml
options: --user root
steps:
- name: Fix home
# We need to run as root as a Workaround Github actions issues
# (https://github.com/actions/checkout/issues/1014) but need ~ocaml as
# home.
run: sudo sh -c "echo HOME=/home/ocaml >> ${GITHUB_ENV}"
- name: Check promoted files
run: |
cd /home/ocaml/catala && opam exec -- make check-promoted
git diff --exit-code
- name: Run tests
if: ${{ always() }}
run: cd /home/ocaml/catala && opam exec -- make tests
run: cd /home/ocaml/catala && opam exec -- make testsuite
- name: Generate test report
if: ${{ always() }}
run: |
cd /home/ocaml/catala
opam exec -- clerk report --xml _build/*@test _build/test-*/*@test >report.junit.xml
- name: Test Summary
uses: test-summary/action@v2
with:
paths: /home/ocaml/catala/report.junit.xml
if: ${{ always() }}
examples:
name: Build examples and generate artifacts
@ -101,7 +116,7 @@ jobs:
runs-on: self-hosted
container:
image: ${{ needs.build.outputs.image }}
options: --user ocaml
options: --user root
env:
DUNE_PROFILE: release
steps:
@ -112,7 +127,9 @@ jobs:
- name: Install LaTeX deps
# This is done late because caching would not benefit compared to
# installation through apk (1,5G upload is slow)
run: sudo apk add texlive-xetex texmf-dist-latexextra texmf-dist-pictures font-dejavu groff
run: sudo apk add texlive-xetex texmf-dist-latexextra texmf-dist-binextra texmf-dist-pictures texmf-dist-fontsrecommended font-dejavu groff texmf-dist-lang
# Fewer texmf deps should be required once
# https://gitlab.alpinelinux.org/alpine/aports/-/issues/16190 is fixed
- name: Build Catala extra docs
run: |
cd ~/catala
@ -128,6 +145,16 @@ jobs:
run: |
cd ~/catala-examples
opam --cli=2.1 exec -- make all testsuite install
- name: Generate examples test report
if: ${{ always() }}
run: |
cd ~/catala-examples
opam exec -- clerk report --xml _build/clerk_tests/*@test _build/clerk_tests/test-*/*@test >report.junit.xml
- name: Test Summary
uses: test-summary/action@v2
with:
paths: "/home/ocaml/catala-examples/report.junit.xml"
if: ${{ always() }}
- name: Checkout french-law repo
run: |
git clone https://github.com/CatalaLang/french-law --depth 1 ~/french-law -b "${{ github.head_ref || github.ref_name }}" ||
@ -164,10 +191,6 @@ jobs:
cp catala-examples/_build/french_law_python.tar.gz artifacts/
mv catala/website-assets.tar.gz artifacts/
- name: Upload artifacts
continue-on-error: true
# Uploading artifacts works but then return failure with:
# EACCES: permission denied, open '/__w/_temp/_runner_file_commands/set_output_xxx'
# a chmod doesn't work around it so we resort to just ignoring the error...
uses: actions/upload-artifact@v4
with:
name: Catala examples

View File

@ -187,3 +187,25 @@ reformat your branch patch by patch before rebasing.
Requirements of catala that are not inside [nixpkgs](https://github.com/nixos/nixpkgs) are available inside the `.nix` directory of the repo. The main part is inside the `.nix/packages.nix`, where all the packages are either added (because absent from nixpkgs) using `ocamlPackage.callPackage`; or modified from nixpkgs, for instance cmdliner is currently pinned at version 1.1.0.
### Pull Requests Policies
Pull requests must be approved by, at least, one knowledgable
contributor before merging.
Unless there exists legitimate reasons, every commit of the pull
request must compile, and, the final commit must successfully pass the
CI check.
All requested changes should ideally be included in the PR. However,
if the PR is merged while there are still open discussions or if there
are late remarks, it should be addressed as soon as possible in a
follow-up PR.
As much as possible, offline interactions between the author and
reviewer(s) leading to a discussion resolution should result in a
quick summary that documents the decision.
Whenever major changes are requested, both PR's author and reviewer(s)
may reach an agreement to delay the resolution (e.g., in a future PR)
in which case it must be documented as an issue in order to properly
track it.

View File

@ -2,7 +2,7 @@
# STAGE 1: setup an opam switch with all dependencies installed
#
# (only depends on the opam files)
FROM ocamlpro/ocaml:4.14-2024-01-14 AS dev-build-context
FROM ocamlpro/ocaml:4.14-2024-05-26 AS dev-build-context
# Image from https://hub.docker.com/r/ocamlpro/ocaml
RUN mkdir catala
@ -25,10 +25,6 @@ RUN opam --cli=2.1 switch create catala ocaml-system && \
# Note: just `opam switch create . --deps-only --with-test --with-doc && opam clean`
# should be enough once opam 2.2 is released (see opam#5185)
# This is temporary, to avoid pulling in a dependency to Str, until it's merged
# and release into dates_calc
RUN opam --cli=2.1 pin dates_calc.0.0.5 git+https://github.com/AltGr/dates-calc#nostr
#
# STAGE 2: get the whole repo and build
#

View File

@ -99,14 +99,13 @@ prepare-install:
dune build @install --promote-install-files
install: prepare-install
if [ x$$($(OPAM) --version) = "x2.1.5" ]; then \
$(OPAM) install . --working-dir; \
else \
$(OPAM) install . --working-dir --assume-built; \
fi
case x$$($(OPAM) --version) in \
x2.1.5|x2.1.6) $(OPAM) install . --working-dir;; \
*) $(OPAM) install . --working-dir --assume-built;; \
esac
# `dune install` would work, but does a dirty install to the opam prefix without
# registering with opam.
# --assume-built is broken in 2.1.5
# --assume-built is broken in 2.1.5 and 2.1.6
inst: prepare-install
@opam custom-install \
@ -218,7 +217,7 @@ tests: test
TEST_FLAGS_LIST = ""\
-O \
--lcalc \
--lcalc,--avoid-exceptions,-O
--lcalc,--closure-conversion,-O
# Does not include running dune (to avoid duplication when run among bigger rules)
testsuite-base: .FORCE
@ -230,6 +229,7 @@ testsuite-base: .FORCE
#> testsuite : Run interpreter tests over a selection of configurations
testsuite: unit-tests
$(CLERK_TEST) doc
$(MAKE) testsuite-base
#> reset-tests : Update the expected test results from current run
@ -299,7 +299,8 @@ BRANCH = $(shell git branch --show-current 2>/dev/null || echo master)
# its usage.
local_tmp_clone = { \
rm -rf $1.tmp && \
trap "rm -rf $1.tmp" EXIT && \
CLEANUP_TMP_GIT_CLONES="$${CLEANUP_TMP_GIT_CLONES}rm -rf $1.tmp; " && \
trap "$$CLEANUP_TMP_GIT_CLONES" EXIT && \
git clone https://github.com/CatalaLang/$1 \
--depth 1 --reference-if-able ../$1 \
$1.tmp -b $(BRANCH) || \
@ -336,8 +337,12 @@ alltest: dependencies-python
bench_ocaml \
bench_js \
bench_python && \
printf "\n# \e[42;30m[ ALL TESTS PASSED ]\e[m \e[32m☺\e[m\n" || \
{ printf "\n# \e[41;30m[ TESTS FAILED ]\e[m \e[31m☹\e[m\n" ; exit 1; }
printf "\n# Full Catala testsuite:\t\t\e[42;30m ALL TESTS PASSED \e[m\t\t\e[32m☺\e[m\n" || \
{ printf "\n# Full Catala testsuite:\t\t\e[41;30m TESTS FAILED \e[m\t\t\e[31m☹\e[m\n" ; exit 1; }
#> alltest- : Like 'alltest', but skips doc building and is much faster
alltest-:
@$(MAKE) alltest NODOC=1
#> clean : Clean build artifacts
clean:

View File

@ -0,0 +1,112 @@
(* This file is part of the Catala build system, a specification language for
tax and social benefits computation rules. Copyright (C) 2024 Inria,
contributors: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Otoml
type t = {
catala_opts : string list;
build_dir : File.t;
include_dirs : File.t list;
}
let default = { catala_opts = []; build_dir = "_build"; include_dirs = [] }
let toml_to_config toml =
{
catala_opts = Helpers.find_strings_exn toml ["build"; "catala_opts"];
build_dir = Helpers.find_string_exn toml ["build"; "build_dir"];
include_dirs = Helpers.find_strings_exn toml ["project"; "include_dirs"];
}
let config_to_toml t =
table
[
( "build",
table
[
"catala_opts", array (List.map string t.catala_opts);
"build_dir", string t.build_dir;
] );
"project", table ["include_dirs", array (List.map string t.include_dirs)];
]
let default_toml = config_to_toml default
(* joins default and supplied conf, ensuring types match. The filename is for
error reporting *)
let rec join ?(rpath = []) fname t1 t2 =
match t1, t2 with
| TomlString _, TomlString _
| TomlInteger _, TomlInteger _
| TomlFloat _, TomlFloat _
| TomlBoolean _, TomlBoolean _
| TomlOffsetDateTime _, TomlOffsetDateTime _
| TomlLocalDateTime _, TomlLocalDateTime _
| TomlLocalDate _, TomlLocalDate _
| TomlLocalTime _, TomlLocalTime _
| TomlArray _, TomlArray _
| TomlTableArray _, TomlTableArray _ ->
t2
| TomlTable tt1, TomlTable tt2 | TomlInlineTable tt1, TomlInlineTable tt2 ->
let m1 = String.Map.of_list tt1 in
let m2 = String.Map.of_list tt2 in
TomlTable
(String.Map.merge
(fun key t1 t2 ->
match t1, t2 with
| None, Some _ ->
Message.error
"While parsing %a: invalid key @{<red>%S@} at @{<bold>%s@}"
File.format fname key
(if rpath = [] then "file root"
else String.concat "." (List.rev rpath))
| Some t1, Some t2 -> Some (join ~rpath:(key :: rpath) fname t1 t2)
| Some t1, None -> Some t1
| None, None -> assert false)
m1 m2
|> String.Map.bindings)
| _ ->
Message.error
"While parsing %a: Wrong type for config value @{<bold>%s@}, was \
expecting @{<bold>%s@}"
File.format fname
(String.concat "." (List.rev rpath))
(match t1 with
| TomlString _ -> "a string"
| TomlInteger _ -> "an integer"
| TomlFloat _ -> "a float"
| TomlBoolean _ -> "a boolean"
| TomlOffsetDateTime _ -> "an offsetdatetime"
| TomlLocalDateTime _ -> "a localdatetime"
| TomlLocalDate _ -> "a localdate"
| TomlLocalTime _ -> "a localtime"
| TomlArray _ | TomlTableArray _ -> "an array"
| TomlTable _ | TomlInlineTable _ -> "a table")
let read f =
let toml =
try Parser.from_file f
with Parse_error (Some (li, col), msg) ->
Message.error
~pos:(Pos.from_info f li col li (col + 1))
"Error in Clerk configuration:@ %a" Format.pp_print_text msg
in
toml_to_config (join f default_toml toml)
let write f t =
let toml = config_to_toml t in
File.with_out_channel f @@ fun oc -> Printer.to_channel oc toml

View File

@ -1,6 +1,6 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
(* This file is part of the Catala build system, a specification language for
tax and social benefits computation rules. Copyright (C) 2024 Inria,
contributors: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
@ -14,7 +14,14 @@
License for the specific language governing permissions and limitations under
the License. *)
(** Translation from the default calculus to the lambda calculus. This
translation uses exceptions to handle empty default terms. *)
open Catala_utils
val translate_program : 'm Dcalc.Ast.program -> 'm Ast.program
type t = {
catala_opts : string list;
build_dir : File.t;
include_dirs : File.t list;
}
val default : t
val read : File.t -> t
val write : File.t -> t -> unit

View File

@ -1,7 +1,7 @@
(* This file is part of the Catala build system, a specification language for
tax and social benefits computation rules. Copyright (C) 2020 Inria,
contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
<emile.rolley@tuta.io>, Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
@ -75,14 +75,33 @@ module Cli = struct
tests. Comma-separated list. A subset may also be applied to the \
compilation of modules, as needed.\n\
WARNING: flag shortcuts are not allowed here (i.e. don't use \
non-ambiguous prefixes such as $(b,--avoid-ex) for \
$(b,--avoid-exceptions))\n\
non-ambiguous prefixes such as $(b,--closure) for \
$(b,--closure-conversion))\n\
NOTE: if this is set, all inline tests that are $(i,not) \
$(b,catala test-scope) are skipped to avoid redundant testing.")
let runtest_report =
Arg.(
value
& opt (some string) None
& info ["report"] ~docv:"FILE"
~doc:
"If set, $(i,clerk runtest) will output a tests result summary in \
binary format to the given $(b,FILE)")
let runtest_out =
Arg.(
value
& pos 1 (some string) None
& info [] ~docv:"OUTFILE"
~doc:"Write the test outcome to file $(b,OUTFILE) instead of stdout.")
module Global : sig
val color : Catala_utils.Global.when_enum Term.t
val debug : bool Term.t
val term :
(chdir:File.t option ->
(config_file:File.t option ->
catala_exe:File.t option ->
catala_opts:string list ->
build_dir:File.t option ->
@ -93,12 +112,14 @@ module Cli = struct
'a) ->
'a Term.t
end = struct
let chdir =
let config_file =
Arg.(
value
& opt (some string) None
& info ["C"] ~docv:"DIR"
~doc:"Change to the given directory before processing")
& opt (some file) None
& info ["config"] ~docv:"FILE"
~doc:
"Clerk configuration file to use, instead of looking up \
\"clerk.toml\" in parent directories.")
let color =
Arg.(
@ -129,7 +150,7 @@ module Cli = struct
Term.(
const
(fun
chdir
config_file
catala_exe
catala_opts
build_dir
@ -138,9 +159,9 @@ module Cli = struct
debug
ninja_output
->
f ~chdir ~catala_exe ~catala_opts ~build_dir ~include_dirs ~color
~debug ~ninja_output)
$ chdir
f ~config_file ~catala_exe ~catala_opts ~build_dir ~include_dirs
~color ~debug ~ninja_output)
$ config_file
$ catala_exe
$ catala_opts
$ build_dir
@ -156,6 +177,12 @@ module Cli = struct
& pos_all string []
& info [] ~docv:"FILE(S)" ~doc:"File(s) or folder(s) to process")
let files =
Arg.(
value
& pos_all file []
& info [] ~docv:"FILE(S)" ~doc:"File(s) to process")
let single_file =
Arg.(
required
@ -189,6 +216,45 @@ module Cli = struct
"Flags or targets to forward to Ninja directly (use $(b,-- \
ninja_flags) to separate Ninja flags from Clerk flags)")
let report_verbosity =
Arg.(
value
& vflag `Failures
[
( `Summary,
info ["summary"] ~doc:"Only display a summary of the test results"
);
( `Short,
info ["short"] ~doc:"Don't display detailed test failures diff" );
( `Failures,
info ["failures"]
~doc:"Show details of files with failed tests only" );
( `Verbose,
info ["verbose"; "v"]
~doc:"Display the full list of tests that have been run" );
])
let report_xml =
Arg.(
value
& flag
& info ["xml"]
~env:(Cmd.Env.info "CATALA_XML_REPORT")
~doc:"Output the test report in JUnit-compatible XML format")
let diff_command =
Arg.(
value
& opt ~vopt:(Some None) (some (some string)) None
& info ["diff"]
~env:(Cmd.Env.info "CATALA_DIFF_COMMAND")
~doc:
"Use a standard $(i,diff) command instead of the default \
side-by-side view. If no argument is supplied, the command will \
be $(b,patdiff) if available or $(b,diff) otherwise. A supplied \
argument will be used as diff command with arguments pointing to \
the reference file and the output file")
let ninja_flags =
let env =
Cmd.Env.info
@ -253,36 +319,18 @@ end
(** Some functions that poll the surrounding systems (think [./configure]) *)
module Poll = struct
let project_root_absrel : (File.t option * File.t) Lazy.t =
lazy
(let open File in
let home = try Sys.getenv "HOME" with Not_found -> "" in
let rec lookup dir rel =
if
Sys.file_exists (dir / "catala.opam")
|| Sys.file_exists (dir / ".git")
|| Sys.file_exists (dir / "clerk.toml")
then Some dir, rel
else if dir = home then None, Filename.current_dir_name
else
let parent = Filename.dirname dir in
if parent = dir then None, Filename.current_dir_name
else lookup parent (rel / Filename.parent_dir_name)
in
lookup (Sys.getcwd ()) Filename.current_dir_name)
let project_root = lazy (fst (Lazy.force project_root_absrel))
let project_root_relative = lazy (snd (Lazy.force project_root_absrel))
(** This module is sensitive to the CWD at first use. Therefore it's expected
that [chdir] has been run beforehand to the project root. *)
let root = lazy (Sys.getcwd ())
(** Scans for a parent directory being the root of the Catala source repo *)
let catala_project_root : File.t option Lazy.t =
lazy
(match Lazy.force project_root with
| Some root
when Sys.file_exists File.(root / "catala.opam")
&& Sys.file_exists File.(root / "dune-project") ->
Some root
| _ -> None)
root
|> Lazy.map
@@ fun root ->
if File.(exists (root / "catala.opam") && exists (root / "dune-project"))
then Some root
else None
let exec_dir : File.t = Catala_utils.Cli.exec_dir
let clerk_exe : File.t Lazy.t = lazy (Unix.realpath Sys.executable_name)
@ -292,14 +340,14 @@ module Poll = struct
(let f = File.(exec_dir / "catala") in
if Sys.file_exists f then Unix.realpath f
else
match Lazy.force project_root with
| Some root when Sys.file_exists File.(root / "catala.opam") ->
match catala_project_root with
| (lazy (Some root)) ->
Unix.realpath
File.(root / "_build" / "default" / "compiler" / "catala.exe")
| _ -> File.check_exec "catala")
let build_dir : ?dir:File.t -> unit -> File.t =
fun ?(dir = "_build") () ->
let build_dir : dir:File.t -> unit -> File.t =
fun ~dir () ->
let d = File.clean_path dir in
File.ensure_dir d;
d
@ -385,36 +433,8 @@ module Poll = struct
let ocaml_link_flags : string list Lazy.t =
lazy (snd (Lazy.force ocaml_include_and_lib_flags))
let has_command cmd =
let check_cmd = Printf.sprintf "type %s >/dev/null 2>&1" cmd in
Sys.command check_cmd = 0
let diff_command =
lazy
(if has_command "patdiff" then
["patdiff"; "-alt-old"; "reference"; "-alt-new"; "current-output"]
else
[
"diff";
"-u";
"-b";
"--color";
"--label";
"reference";
"--label";
"current-output";
])
end
(* Adjusts paths specified from the command-line relative to the user cwd to be
instead relative to the project root *)
let fix_path =
let from_dir = Sys.getcwd () in
fun d ->
let to_dir = Lazy.force Poll.project_root_relative in
Catala_utils.File.reverse_path ~from_dir ~to_dir d
(**{1 Building rules}*)
(** Ninja variable names *)
@ -435,8 +455,6 @@ module Var = struct
let ocamlopt_exe = make "OCAMLOPT_EXE"
let ocaml_flags = make "OCAML_FLAGS"
let runtime_ocaml_libs = make "RUNTIME_OCAML_LIBS"
let diff = make "DIFF"
let post_test = make "POST_TEST"
(** Rule vars, Used in specific rules *)
@ -447,7 +465,6 @@ module Var = struct
let orig_src = make "orig-src"
let scope = make "scope"
let test_id = make "test-id"
let test_command = make "test-command"
let ( ! ) = Var.v
end
@ -464,15 +481,13 @@ let base_bindings catala_exe catala_flags build_dir include_dirs test_flags =
let catala_flags_ocaml =
List.filter
(function
| "--avoid-exceptions" | "-O" | "--optimize" -> true | _ -> false)
| "-O" | "--optimize" | "--closure-conversion" -> true | _ -> false)
test_flags
in
let catala_flags_python =
List.filter
(function
| "--avoid-exceptions" | "-O" | "--optimize" | "--closure-conversion" ->
true
| _ -> false)
| "-O" | "--optimize" | "--closure-conversion" -> true | _ -> false)
test_flags
in
let ocaml_flags = Lazy.force Poll.ocaml_include_flags in
@ -500,13 +515,10 @@ let base_bindings catala_exe catala_flags build_dir include_dirs test_flags =
Nj.binding Var.ocamlopt_exe ["ocamlopt"];
Nj.binding Var.ocaml_flags (ocaml_flags @ includes);
Nj.binding Var.runtime_ocaml_libs (Lazy.force Poll.ocaml_link_flags);
Nj.binding Var.diff (Lazy.force Poll.diff_command);
Nj.binding Var.post_test [Var.(!diff)];
]
let[@ocamlformat "disable"] static_base_rules =
let open Var in
let color = Message.has_color stdout in
let shellout l = Format.sprintf "$$(%s)" (String.concat " " l) in
[
Nj.rule "copy"
@ -545,27 +557,11 @@ let[@ocamlformat "disable"] static_base_rules =
!input; "-o"; !output]
~description:["<catala>"; "python"; ""; !output];
Nj.rule "out-test"
~command: [
!catala_exe; !test_command; "--plugin-dir="; "-o -"; !catala_flags;
!input; ">"; !output; "2>&1";
"||"; "true";
]
~description:
["<catala>"; "test"; !test_id; ""; !input; "(" ^ !test_command ^ ")"];
Nj.rule "inline-tests"
Nj.rule "tests"
~command:
[!clerk_exe; "runtest"; !clerk_flags; !input; ">"; !output; "2>&1";
"||"; "true"]
~description:["<catala>"; "inline-tests"; ""; !input];
Nj.rule "post-test"
~command:[
!post_test; !input; ";";
"echo"; "-n"; "$$?"; ">"; !output;
]
~description:["<test>"; !test_id];
[!clerk_exe; "runtest"; !clerk_flags; !input;
"--report"; !output;]
~description:["<catala>"; "tests"; ""; !input];
Nj.rule "interpret"
~command:
@ -576,37 +572,6 @@ let[@ocamlformat "disable"] static_base_rules =
Nj.rule "dir-tests"
~command:["cat"; !input; ">"; !output; ";"]
~description:["<test>"; !test_id];
Nj.rule "test-results"
~command:[
"out=" ^ !output; ";";
"success=$$("; "tr"; "-cd"; "0"; "<"; !input; "|"; "wc"; "-c"; ")"; ";";
"total=$$("; "wc"; "-c"; "<"; !input; ")"; ";";
"pass=$$("; ")"; ";";
"if"; "test"; "\"$$success\""; "-eq"; "\"$$total\""; ";"; "then";
"printf";
(if color then "\"\\n[\\033[32mPASS\\033[m] \\033[1m%s\\033[m: \
\\033[32m%3d\\033[m/\\033[32m%d\\033[m\\n\""
else "\"\\n[PASS] %s: %3d/%d\\n\"");
"$${out%@test}"; "$$success"; "$$total"; ";";
"else";
"printf";
(if color then "\"\\n[\\033[31mFAIL\\033[m] \\033[1m%s\\033[m: \
\\033[31m%3d\\033[m/\\033[32m%d\\033[m\\n\""
else "\"\\n[FAIL] %s: %3d/%d\\n\"");
"$${out%@test}"; "$$success"; "$$total"; ";";
"return"; "1"; ";";
"fi";
]
~description:["<test>"; !output];
(* Note: this last rule looks horrible, but the processing is pretty simple:
in the rules above, we output the returning code of diffing individual
tests to a [<testfile>@test] file, then the rules for directories just
concat these files. What this last rule does is then just count the number
of `0` and the total number of characters in the file, and print a readable
message. Instead of this disgusting shell code embedded in the ninja file,
this could be a specialised subcommand of clerk, e.g. `clerk
test-diagnostic <results-file@test>` *)
]
let gen_build_statements
@ -664,7 +629,7 @@ let gen_build_statements
let obj =
Nj.build "ocaml-object" ~inputs:[ml_file]
~implicit_in:(!Var.catala_exe :: List.map modd modules)
~outputs:(List.map target_file ["mli"; "cmi"; "cmo"; "cmx"; "cmt"; "o"])
~outputs:(List.map target_file ["mli"; "cmi"; "cmo"; "cmx"; "o"])
~vars:
[
( Var.ocaml_flags,
@ -719,75 +684,29 @@ let gen_build_statements
(src /../ "output" / Filename.basename src) -.- test.Scan.id
in
let tests =
let legacy_tests =
List.fold_left
(fun acc test ->
let vars =
[Var.test_id, [test.Scan.id]; Var.test_command, test.Scan.cmd]
in
let reference = legacy_test_reference test in
let test_out =
(!Var.builddir / src /../ "output" / Filename.basename src)
-.- test.id
in
Nj.build "out-test"
~inputs:[inc srcv]
~implicit_in:interp_deps ~outputs:[test_out] ~vars
:: (* The test reference is an input because of the cases when we run
diff; it should actually be an output for the cases when we
reset but that shouldn't cause trouble. *)
Nj.build "post-test" ~inputs:[reference; test_out]
~implicit_in:["always"]
~outputs:[(!Var.builddir / reference) ^ "@post"]
~vars:[Var.test_id, [reference]]
:: acc)
[] item.legacy_tests
let out_tests_references =
List.map (fun test -> legacy_test_reference test) item.legacy_tests
in
let inline_tests =
if not item.has_inline_tests then []
else
[
Nj.build "inline-tests"
~inputs:[inc srcv]
~implicit_in:(!Var.clerk_exe :: interp_deps)
~outputs:[(!Var.builddir / srcv) ^ "@out"];
]
let out_tests_prepare =
List.map
(fun f -> Nj.build "copy" ~inputs:[f] ~outputs:[inc f])
out_tests_references
in
let tests =
let results =
Nj.build "test-results"
~outputs:[srcv ^ "@test"]
~inputs:[inc (srcv ^ "@test")]
in
let inline_test label =
Nj.build "post-test"
~outputs:[inc (srcv ^ label)]
~inputs:[srcv; inc (srcv ^ "@out")]
~implicit_in:["always"]
~vars:[Var.test_id, [srcv]]
in
match item.legacy_tests with
| [] ->
if item.has_inline_tests then [inline_test "@test"; results] else []
| legacy ->
let inline =
if item.has_inline_tests then [inline_test "@inline"] else []
in
inline
@ [
Nj.build "dir-tests"
~outputs:[inc (srcv ^ "@test")]
~inputs:
((if item.has_inline_tests then [inc (srcv ^ "@inline")] else [])
@ List.map
(fun test ->
(!Var.builddir / legacy_test_reference test) ^ "@post")
legacy)
~vars:[Var.test_id, [srcv]];
results;
]
if (not item.has_inline_tests) && item.legacy_tests = [] then []
else
[
Nj.build "tests"
~inputs:[inc srcv]
~implicit_in:
((!Var.clerk_exe :: interp_deps)
@ List.map inc out_tests_references)
~outputs:[inc srcv ^ "@test"; inc srcv ^ "@out"]
~implicit_out:
(List.map (fun o -> inc o ^ "@out") out_tests_references);
]
in
legacy_tests @ inline_tests @ tests
out_tests_prepare @ tests
in
Seq.concat
@@ List.to_seq
@ -836,9 +755,6 @@ let dir_test_rules dir subdirs items =
~outputs:[(Var.(!builddir) / dir) ^ "@test"]
~inputs
~vars:[Var.test_id, [dir]];
Nj.build "test-results"
~outputs:[dir ^ "@test"]
~inputs:[(Var.(!builddir) / dir) ^ "@test"];
]
let build_statements include_dirs dir =
@ -851,11 +767,6 @@ let build_statements include_dirs dir =
let gen_ninja_file catala_exe catala_flags build_dir include_dirs test_flags dir
=
let build_dir =
match test_flags with
| [] -> build_dir
| flags -> File.((build_dir / "test") ^ String.concat "" flags)
in
let ( @+ ) = Seq.append in
Seq.return
(Nj.Comment (Printf.sprintf "File generated by Clerk v.%s\n" version))
@ -871,8 +782,10 @@ let gen_ninja_file catala_exe catala_flags build_dir include_dirs test_flags dir
(** {1 Driver} *)
(* Last argument is a continuation taking as arguments [build_dir], the
[fix_path] function, and the ninja file name *)
let ninja_init
~chdir
~config_file
~catala_exe
~catala_opts
~build_dir
@ -880,13 +793,58 @@ let ninja_init
~color
~debug
~ninja_output :
extra:def Seq.t -> test_flags:string list -> (File.t -> 'a) -> 'a =
extra:def Seq.t ->
test_flags:string list ->
(File.t -> (File.t -> File.t) -> File.t -> 'a) ->
'a =
let _options = Catala_utils.Global.enforce_options ~debug ~color () in
let chdir =
match chdir with None -> Lazy.force Poll.project_root | some -> some
let default_config_file = "clerk.toml" in
let set_root_dir dir =
Message.debug "Entering directory %a" File.format dir;
Sys.chdir dir
in
Option.iter Sys.chdir chdir;
let build_dir = Poll.build_dir ?dir:build_dir () in
(* fix_path adjusts paths specified from the command-line relative to the user
cwd to be instead relative to the project root *)
let fix_path, config =
let from_dir = Sys.getcwd () in
match config_file with
| None -> (
match
File.(find_in_parents (fun dir -> exists (dir / default_config_file)))
with
| Some (root, rel) ->
set_root_dir root;
( Catala_utils.File.reverse_path ~from_dir ~to_dir:rel,
Clerk_config.read default_config_file )
| None -> (
match
File.(
find_in_parents (function dir ->
exists (dir / "catala.opam") || exists (dir / ".git")))
with
| Some (root, rel) ->
set_root_dir root;
( Catala_utils.File.reverse_path ~from_dir ~to_dir:rel,
Clerk_config.default )
| None -> Fun.id, Clerk_config.default))
| Some f ->
let root = Filename.dirname f in
let config = Clerk_config.read f in
set_root_dir root;
( (fun d ->
let r = Catala_utils.File.reverse_path ~from_dir ~to_dir:root d in
Message.debug "%a => %a" File.format d File.format r;
r),
config )
in
let build_dir =
let dir =
match build_dir with None -> config.build_dir | Some dir -> dir
in
Poll.build_dir ~dir ()
in
let catala_opts = config.catala_opts @ catala_opts in
let include_dirs = config.include_dirs @ include_dirs in
let with_ninja_output k =
match ninja_output with
| Some f -> k f
@ -895,6 +853,11 @@ let ninja_init
in
fun ~extra ~test_flags k ->
Message.debug "building ninja rules...";
let build_dir =
match test_flags with
| [] -> build_dir
| flags -> File.((build_dir / "test") ^ String.concat "" flags)
in
with_ninja_output
@@ fun nin_file ->
File.with_formatter_of_file nin_file (fun nin_ppf ->
@ -909,7 +872,7 @@ let ninja_init
]
in
Nj.format nin_ppf ninja_contents);
k nin_file
k build_dir fix_path nin_file
let cleaned_up_env () =
let passthrough_vars =
@ -944,14 +907,14 @@ let run_ninja ~clean_up_env cmdline =
| _, Unix.WEXITED n -> n
| _, (Unix.WSIGNALED n | Unix.WSTOPPED n) -> 128 - n
in
raise (Catala_utils.Cli.Exit_with return_code)
return_code
open Cmdliner
let build_cmd =
let run ninja_init (targets : string list) (ninja_flags : string list) =
ninja_init ~extra:Seq.empty ~test_flags:[]
@@ fun nin_file ->
@@ fun _build_dir fix_path nin_file ->
let targets =
List.map
(fun f ->
@ -962,7 +925,7 @@ let build_cmd =
in
let ninja_cmd = ninja_cmdline ninja_flags nin_file targets in
Message.debug "executing '%s'..." (String.concat " " ninja_cmd);
run_ninja ~clean_up_env:false ninja_cmd
raise (Catala_utils.Cli.Exit_with (run_ninja ~clean_up_env:false ninja_cmd))
in
let doc =
"Low-level build command: can be used to forward build targets or options \
@ -972,37 +935,83 @@ let build_cmd =
Term.(
const run $ Cli.Global.term ninja_init $ Cli.targets $ Cli.ninja_flags)
let set_report_verbosity = function
| `Summary -> Clerk_report.set_display_flags ~files:`None ~tests:`None ()
| `Short ->
Clerk_report.set_display_flags ~files:`Failed ~tests:`Failed ~diffs:false ()
| `Failures ->
if Global.options.debug then Clerk_report.set_display_flags ~files:`All ()
| `Verbose -> Clerk_report.set_display_flags ~files:`All ~tests:`All ()
let test_cmd =
let run
ninja_init
(files_or_folders : string list)
(reset_test_outputs : bool)
(test_flags : string list)
verbosity
xml
(diff_command : string option option)
(ninja_flags : string list) =
set_report_verbosity verbosity;
Clerk_report.set_display_flags ~diff_command ();
ninja_init ~extra:Seq.empty ~test_flags
@@ fun build_dir fix_path nin_file ->
let targets =
let fs = if files_or_folders = [] then ["."] else files_or_folders in
List.map (fun f -> fix_path f ^ "@test") fs
List.map File.(fun f -> (build_dir / fix_path f) ^ "@test") fs
in
let extra =
List.to_seq
((if reset_test_outputs then
[
Nj.binding Var.post_test
[
"test_reset() { if ! diff -q $$1 $$2 >/dev/null; then cp -f \
$$2 $$1; fi; }";
";";
"test_reset";
];
]
else [])
@ [Nj.default targets])
in
ninja_init ~extra ~test_flags
@@ fun nin_file ->
let ninja_cmd = ninja_cmdline ninja_flags nin_file targets in
Message.debug "executing '%s'..." (String.concat " " ninja_cmd);
run_ninja ~clean_up_env:true ninja_cmd
match run_ninja ~clean_up_env:true ninja_cmd with
| 0 ->
Message.debug "gathering test results...";
let open Clerk_report in
let reports = List.flatten (List.map read_many targets) in
if reset_test_outputs then
let () =
if xml then
Message.error
"Options @{<bold>--xml@} and @{<bold>--reset@} are incompatible";
let ppf = Message.formatter_of_out_channel stdout () in
match List.filter (fun f -> f.successful < f.total) reports with
| [] ->
Format.fprintf ppf
"[@{<green>DONE@}] All tests passed, nothing to reset@."
| need_reset ->
List.iter
(fun f ->
let files =
List.fold_left
(fun files t ->
if t.success then files
else
File.Map.add (fst t.result).Lexing.pos_fname
(String.remove_prefix
~prefix:File.(build_dir / "")
(fst t.expected).Lexing.pos_fname)
files)
File.Map.empty f.tests
in
File.Map.iter
(fun result expected ->
Format.kasprintf Sys.command "cp -f %a %a@." File.format
result File.format expected
|> ignore)
files)
need_reset;
Format.fprintf ppf
"[@{<green>DONE@}] @{<yellow;bold>%d@} test files were \
@{<yellow>RESET@}@."
(List.length need_reset)
in
raise (Catala_utils.Cli.Exit_with 0)
else if (if xml then print_xml else summary) ~build_dir reports then
raise (Catala_utils.Cli.Exit_with 0)
else raise (Catala_utils.Cli.Exit_with 1)
| 1 -> raise (Catala_utils.Cli.Exit_with 10) (* Ninja build failed *)
| err -> raise (Catala_utils.Cli.Exit_with err)
(* Other Ninja error ? *)
in
let doc =
"Scan the given files or directories for catala tests, build their \
@ -1017,6 +1026,9 @@ let test_cmd =
$ Cli.files_or_folders
$ Cli.reset_test_outputs
$ Cli.test_flags
$ Cli.report_verbosity
$ Cli.report_xml
$ Cli.diff_command
$ Cli.ninja_flags)
let run_cmd =
@ -1033,10 +1045,10 @@ let run_cmd =
(List.map (fun file -> file ^ "@interpret") files_or_folders)))
in
ninja_init ~extra ~test_flags:[]
@@ fun nin_file ->
@@ fun _build_dir _fix_path nin_file ->
let ninja_cmd = ninja_cmdline ninja_flags nin_file [] in
Message.debug "executing '%s'..." (String.concat " " ninja_cmd);
run_ninja ~clean_up_env:false ninja_cmd
raise (Catala_utils.Cli.Exit_with (run_ninja ~clean_up_env:false ninja_cmd))
in
let doc =
"Runs the Catala interpreter on the given files, after building their \
@ -1052,15 +1064,15 @@ let run_cmd =
$ Cli.ninja_flags)
let runtest_cmd =
let run catala_exe catala_opts include_dirs test_flags file =
let run catala_exe catala_opts include_dirs test_flags report out file =
let catala_opts =
List.fold_left
(fun opts dir -> "-I" :: dir :: opts)
catala_opts include_dirs
in
Clerk_runtest.run_inline_tests
(Option.value ~default:"catala" catala_exe)
catala_opts test_flags file;
Clerk_runtest.run_tests
~catala_exe:(Option.value ~default:"catala" catala_exe)
~catala_opts ~test_flags ~report ~out file;
0
in
let doc =
@ -1074,9 +1086,38 @@ let runtest_cmd =
$ Cli.catala_opts
$ Cli.include_dirs
$ Cli.test_flags
$ Cli.runtest_report
$ Cli.runtest_out
$ Cli.single_file)
let main_cmd = Cmd.group Cli.info [build_cmd; test_cmd; run_cmd; runtest_cmd]
let report_cmd =
let run color debug verbosity xml diff_command build_dir files =
let _options = Catala_utils.Global.enforce_options ~debug ~color () in
let build_dir = Option.value ~default:"_build" build_dir in
set_report_verbosity verbosity;
Clerk_report.set_display_flags ~diff_command ();
let open Clerk_report in
let tests = List.flatten (List.map read_many files) in
let success = (if xml then print_xml else summary) ~build_dir tests in
exit (if success then 0 else 1)
in
let doc =
"Mainly for internal purposes. Reads a test report file and displays a \
summary of the results, returning 0 on success and 1 if any test failed."
in
Cmd.v (Cmd.info ~doc "report")
Term.(
const run
$ Cli.Global.color
$ Cli.Global.debug
$ Cli.report_verbosity
$ Cli.report_xml
$ Cli.diff_command
$ Cli.build_dir
$ Cli.files)
let main_cmd =
Cmd.group Cli.info [build_cmd; test_cmd; run_cmd; runtest_cmd; report_cmd]
let main () =
try exit (Cmdliner.Cmd.eval' ~catch:false main_cmd) with

View File

@ -0,0 +1,440 @@
(* This file is part of the Catala build system, a specification language for
tax and social benefits computation rules. Copyright (C) 2024 Inria,
contributors: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** This module defines and manipulates Clerk test reports, which can be written
by `clerk runtest` and read to provide test result summaries. This only
concerns inline tests (```catala-test-inline blocks). *)
open Catala_utils
type test = {
success : bool;
command_line : string list;
expected : Lexing.position * Lexing.position;
result : Lexing.position * Lexing.position;
}
type file = { name : File.t; successful : int; total : int; tests : test list }
type disp_flags = {
mutable files : [ `All | `Failed | `None ];
mutable tests : [ `All | `FailedFile | `Failed | `None ];
mutable diffs : bool;
mutable diff_command : string option option;
}
let disp_flags =
{ files = `Failed; tests = `FailedFile; diffs = true; diff_command = None }
let set_display_flags
?(files = disp_flags.files)
?(tests = disp_flags.tests)
?(diffs = disp_flags.diffs)
?(diff_command = disp_flags.diff_command)
() =
disp_flags.files <- files;
disp_flags.tests <- tests;
disp_flags.diffs <- diffs;
disp_flags.diff_command <- diff_command
let write_to f file =
File.with_out_channel f (fun oc -> Marshal.to_channel oc (file : file) [])
let read_from f = File.with_in_channel f Marshal.from_channel
let read_many f =
File.with_in_channel f
@@ fun ic ->
let rec results () =
match Marshal.from_channel ic with
| file -> file :: results ()
| exception End_of_file -> []
in
results ()
let has_command cmd =
let check_cmd = Printf.sprintf "type %s >/dev/null 2>&1" cmd in
Sys.command check_cmd = 0
type 'a diff = Eq of 'a | Subs of 'a * 'a | Del of 'a | Add of 'a
let colordiff_str s1 s2 =
let split_re =
Re.(compile (alt [set "=()[]{};-,"; rep1 space; rep1 digit]))
in
let split s =
Re.Seq.split_full split_re s
|> Seq.map (function `Text t -> t | `Delim g -> Re.Group.get g 0)
in
let a1 = Array.of_seq (split s1) in
let n1 = Array.length a1 in
let a2 = Array.of_seq (split s2) in
let n2 = Array.length a2 in
let d = Array.make_matrix n1 n2 (0, []) in
let get i1 i2 =
if i1 < 0 then
( i2 + 1,
Array.fold_left (fun acc c -> Add c :: acc) [] (Array.sub a2 0 (i2 + 1))
)
else if i2 < 0 then
( i1 + 1,
Array.fold_left (fun acc c -> Del c :: acc) [] (Array.sub a1 0 (i1 + 1))
)
else d.(i1).(i2)
in
for i1 = 0 to n1 - 1 do
for i2 = 0 to n2 - 1 do
if a1.(i1) = a2.(i2) then
let eq, eqops = get (i1 - 1) (i2 - 1) in
d.(i1).(i2) <- eq, Eq a1.(i1) :: eqops
else
let del, delops = get (i1 - 1) i2 in
let add, addops = get i1 (i2 - 1) in
let subs, subsops = get (i1 - 1) (i2 - 1) in
if subs <= del && subs <= add then
d.(i1).(i2) <- subs + 1, Subs (a1.(i1), a2.(i2)) :: subsops
else if del <= add then d.(i1).(i2) <- del + 1, Del a1.(i1) :: delops
else d.(i1).(i2) <- add + 1, Add a2.(i2) :: addops
done
done;
let _, rops = get (n1 - 1) (n2 - 1) in
let ops = List.rev rops in
let pr_left ppf () =
Format.pp_print_list
~pp_sep:(fun _ () -> ())
(fun ppf -> function
| Eq w -> Format.fprintf ppf "%s" w
| Subs (w, _) | Del w -> Format.fprintf ppf "@{<green>%s@}" w
| Add _ -> ())
ppf ops
in
let pr_right ppf () =
Format.pp_print_list
~pp_sep:(fun _ () -> ())
(fun ppf -> function
| Eq w -> Format.fprintf ppf "%s" w
| Subs (_, w) | Add w -> Format.fprintf ppf "@{<red>%s@}" w
| Del _ -> ())
ppf ops
in
pr_left, pr_right
let diff_command =
let has_gnu_diff () =
File.process_out ~check_exit:ignore "diff" ["--version"]
|> Re.(execp (compile (str "GNU")))
in
lazy
begin
match disp_flags.diff_command with
| None when Message.has_color stdout && has_gnu_diff () ->
let width = Message.terminal_columns () - 5 in
( [
"diff";
"-y";
"-t";
"-W";
string_of_int (Message.terminal_columns () - 5);
],
fun ppf s ->
let mid = (width - 1) / 2 in
Format.fprintf ppf "@{<blue;ul>%*sReference%*s│%*sResult%*s@}@,"
((mid - 9) / 2)
""
(mid - 9 - ((mid - 9) / 2))
""
((width - mid - 7) / 2)
""
(width - mid - 7 - ((width - mid - 7) / 2))
"";
s
|> String.trim_end
|> String.split_on_char '\n'
|> Format.pp_print_list
(fun ppf li ->
let rec find_cut col index =
if index >= String.length li then None
else if col = mid then Some index
else
let c = String.get_utf_8_uchar li index in
find_cut (col + 1) (index + Uchar.utf_decode_length c)
in
match find_cut 0 0 with
| None ->
if li = "" then Format.fprintf ppf "%*s@{<blue>│@}" mid ""
else Format.pp_print_string ppf li
| Some i -> (
let l, c, r =
( String.sub li 0 i,
li.[i],
String.sub li (i + 1) (String.length li - i - 1) )
in
match c with
| ' ' -> Format.fprintf ppf "%s@{<blue>│@}%s" l r
| '>' ->
if String.for_all (( = ) ' ') l then
Format.fprintf ppf
"%*s@{<red>-@}@{<blue>│@}@{<red>%s@}" (mid - 1) "" r
else Format.fprintf ppf "%s@{<blue>│@}@{<red>%s@}" l r
| '<' -> Format.fprintf ppf "%s@{<blue>│@}@{<red>-@}" l
| '|' ->
let ppleft, ppright = colordiff_str l r in
Format.fprintf ppf "%a@{<blue>│@}%a" ppleft () ppright ()
| _ -> Format.pp_print_string ppf li))
ppf )
| Some cmd_opt | (None as cmd_opt) ->
let command =
match cmd_opt with
| Some str -> String.split_on_char ' ' str
| None ->
if Message.has_color stdout && has_command "patdiff" then
["patdiff"; "-alt-old"; "Reference"; "-alt-new"; "Result"]
else ["diff"; "-u"; "-L"; "Reference"; "-L"; "Result"]
in
( command,
fun ppf s ->
s
|> String.trim_end
|> String.split_on_char '\n'
|> Format.pp_print_list Format.pp_print_string ppf )
end
let print_diff ppf p1 p2 =
let get_str (pstart, pend) =
assert (pstart.Lexing.pos_fname = pend.Lexing.pos_fname);
File.with_in_channel pstart.Lexing.pos_fname
@@ fun ic ->
seek_in ic pstart.Lexing.pos_cnum;
really_input_string ic (pend.Lexing.pos_cnum - pstart.Lexing.pos_cnum)
in
File.with_temp_file "clerk-diff" "a" ~contents:(get_str p1)
@@ fun f1 ->
File.with_temp_file "clerk_diff" "b" ~contents:(get_str p2)
@@ fun f2 ->
match Lazy.force diff_command with
| [], _ -> assert false
| cmd :: args, printer ->
File.process_out ~check_exit:(fun _ -> ()) cmd (args @ [f1; f2])
|> printer ppf
let catala_commands_with_output_flag =
["makefile"; "html"; "latex"; "ocaml"; "python"; "r"; "c"]
let pfile ~build_dir f =
f
|> String.remove_prefix ~prefix:(build_dir ^ Filename.dir_sep)
|> String.remove_prefix ~prefix:(Sys.getcwd () ^ Filename.dir_sep)
let clean_command_line ~build_dir file cl =
cl
|> List.filter_map (fun s ->
if s = "--directory=" ^ build_dir then None
else Some (pfile ~build_dir s))
|> (function
| catala :: cmd :: args ->
catala :: cmd :: "-I" :: Filename.dirname file :: args
| cl -> cl)
|> function
| catala :: cmd :: args
when List.mem (String.lowercase_ascii cmd) catala_commands_with_output_flag
->
(catala :: cmd :: args) @ ["-o -"]
| cl -> cl
let display ~build_dir file ppf t =
let pp_pos ppf (start, stop) =
assert (start.Lexing.pos_fname = stop.Lexing.pos_fname);
Format.fprintf ppf "@{<cyan>%s:%d-%d@}"
(pfile ~build_dir start.Lexing.pos_fname)
start.Lexing.pos_lnum stop.Lexing.pos_lnum
in
let print_command () =
Format.fprintf ppf "@,@[<h>$ @{<yellow>%a@}@]"
(Format.pp_print_list ~pp_sep:Format.pp_print_space Format.pp_print_string)
(clean_command_line ~build_dir file t.command_line)
in
Format.pp_open_vbox ppf 2;
if t.success then (
Format.fprintf ppf "@{<green>■@} %a passed" pp_pos t.expected;
if Global.options.debug then print_command ())
else (
Format.fprintf ppf "@{<red>■@} %a failed" pp_pos t.expected;
print_command ();
if disp_flags.diffs then (
Format.pp_print_cut ppf ();
print_diff ppf t.expected t.result));
Format.pp_close_box ppf ()
let display_file ~build_dir ppf t =
let pfile f = String.remove_prefix ~prefix:(build_dir ^ Filename.dir_sep) f in
let print_tests tests =
let tests =
match disp_flags.tests with
| `All | `FailedFile -> tests
| `Failed -> List.filter (fun t -> not t.success) tests
| `None -> assert false
in
Format.pp_print_break ppf 0 3;
Format.pp_open_vbox ppf 0;
Format.pp_print_list (display ~build_dir t.name) ppf tests;
Format.pp_close_box ppf ()
in
if t.successful = t.total then (
if disp_flags.files = `All then (
Format.fprintf ppf
"@{<green;reverse;ul> @} @{<cyan>%s@}: @{<green;bold>%d@} / %d tests \
passed"
(pfile t.name) t.successful t.total;
if disp_flags.tests = `All then print_tests t.tests;
Format.pp_print_cut ppf ()))
else
let () =
match t.successful with
| 0 -> Format.fprintf ppf "@{<red;reverse;ul> @}"
| _ -> Format.fprintf ppf "@{<yellow;reverse;ul> @}"
in
Format.fprintf ppf " @{<cyan>%s@}: " (pfile t.name);
(function
| 0 -> Format.fprintf ppf "@{<red;bold>0@}"
| n -> Format.fprintf ppf "@{<yellow;bold>%d@}" n)
t.successful;
Format.fprintf ppf " / %d tests passed" t.total;
if disp_flags.tests <> `None then print_tests t.tests;
Format.pp_print_cut ppf ()
type box = { print_line : 'a. ('a, Format.formatter, unit) format -> 'a }
[@@ocaml.unboxed]
let print_box tcolor ppf title (pcontents : box -> unit) =
let columns = Message.terminal_columns () in
let tpad = columns - String.width title - 6 in
Format.fprintf ppf "@,%t┏%t @{<bold;reverse> %s @} %t┓@}@," tcolor
(Message.pad (tpad / 2) "")
title
(Message.pad (tpad - (tpad / 2)) "");
Format.pp_open_tbox ppf ();
Format.fprintf ppf "%t@<1>%s@}%*s" tcolor "" (columns - 2) "";
Format.pp_set_tab ppf ();
Format.fprintf ppf "%t┃@}@," tcolor;
let box =
{
print_line =
(fun fmt ->
Format.kfprintf
(fun ppf ->
Format.pp_print_tab ppf ();
Format.fprintf ppf "%t┃@}@," tcolor)
ppf ("%t@<1>%s@} " ^^ fmt) tcolor "");
}
in
pcontents box;
box.print_line "";
Format.pp_close_tbox ppf ();
Format.fprintf ppf "%t┗%t┛@}@," tcolor (Message.pad (columns - 2) "")
let summary ~build_dir tests =
let ppf = Message.formatter_of_out_channel stdout () in
Format.pp_open_vbox ppf 0;
let tests = List.filter (fun f -> f.total > 0) tests in
let files, success_files, success, total =
List.fold_left
(fun (files, success_files, success, total) file ->
( files + 1,
(if file.successful < file.total then success_files
else success_files + 1),
success + file.successful,
total + file.total ))
(0, 0, 0, 0) tests
in
if disp_flags.files <> `None then
List.iter (fun f -> display_file ~build_dir ppf f) tests;
let result_box =
if success < total then
print_box (fun ppf -> Format.fprintf ppf "@{<red>") ppf "TESTS FAILED"
else
print_box
(fun ppf -> Format.fprintf ppf "@{<green>")
ppf "ALL TESTS PASSED"
in
result_box (fun box ->
box.print_line "@{<ul>%-5s %10s %10s %10s@}" "" "FAILED" "PASSED" "TOTAL";
if files > 1 then
box.print_line "%-5s @{<red;bold>%a@} @{<green;bold>%a@} @{<bold>%10d@}"
"files"
(fun ppf -> function
| 0 -> Format.fprintf ppf "@{<green>%10d@}" 0
| n -> Format.fprintf ppf "%10d" n)
(files - success_files)
(fun ppf -> function
| 0 -> Format.fprintf ppf "@{<red>%10d@}" 0
| n -> Format.fprintf ppf "%10d" n)
success_files files;
box.print_line "%-5s @{<red;bold>%a@} @{<green;bold>%a@} @{<bold>%10d@}"
"tests"
(fun ppf -> function
| 0 -> Format.fprintf ppf "@{<green>%10d@}" 0
| n -> Format.fprintf ppf "%10d" n)
(total - success)
(fun ppf -> function
| 0 -> Format.fprintf ppf "@{<red>%10d@}" 0
| n -> Format.fprintf ppf "%10d" n)
success total);
Format.pp_close_box ppf ();
Format.pp_print_flush ppf ();
success = total
let print_xml ~build_dir tests =
let ffile ppf f = Format.pp_print_string ppf (pfile ~build_dir f) in
let ppf = Message.formatter_of_out_channel stdout () in
let tests = List.filter (fun f -> f.total > 0) tests in
let success, total =
List.fold_left
(fun (success, total) file ->
success + file.successful, total + file.total)
(0, 0) tests
in
Format.fprintf ppf "@[<v><?xml version=\"1.0\" encoding=\"UTF-8\"?>@,";
Format.fprintf ppf "@[<v 2><testsuites tests=\"%d\" failures=\"%d\">@,"
success (total - success);
Format.pp_print_list
(fun ppf f ->
Format.fprintf ppf
"@[<v 2>@[<hov 1><testsuite@ name=\"%a\"@ tests=\"%d\"@ \
failures=\"%d\">@]@,"
ffile f.name f.total (f.total - f.successful);
Format.pp_print_list
(fun ppf t ->
Format.fprintf ppf "@[<v 2><testcase line=\"%d\">"
(fst t.expected).Lexing.pos_lnum;
Format.fprintf ppf
"@,\
@[<hv 2><property name=\"description\">@,\
@[<hov 2>%a@]@;\
<0 -2></property>@]"
(Format.pp_print_list ~pp_sep:Format.pp_print_space
Format.pp_print_string)
(clean_command_line ~build_dir f.name t.command_line);
if not t.success then (
Format.fprintf ppf
"@,@[<v 2><failure message=\"Output differs from reference\">@,";
print_diff ppf t.expected t.result;
Format.fprintf ppf "@]@,</failure>");
Format.fprintf ppf "@]@,</testcase>")
ppf f.tests;
Format.fprintf ppf "@]@,</testsuite>")
ppf tests;
Format.fprintf ppf "@]@,</testsuites>@,@]@.";
success = total

View File

@ -0,0 +1,52 @@
(* This file is part of the Catala build system, a specification language for
tax and social benefits computation rules. Copyright (C) 2024 Inria,
contributors: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** This module defines and manipulates Clerk test reports, which can be written
by `clerk runtest` and read to provide test result summaries. This only
concerns inline tests (```catala-test-inline blocks). *)
open Catala_utils
type test = {
success : bool;
command_line : string list;
expected : Lexing.position * Lexing.position;
(** The precise offsets of the expected result in the source file *)
result : Lexing.position * Lexing.position;
(** Same for the actual result in the destination file *)
}
type file = { name : File.t; successful : int; total : int; tests : test list }
val write_to : File.t -> file -> unit
val read_from : File.t -> file
val read_many : File.t -> file list
val display : build_dir:File.t -> File.t -> Format.formatter -> test -> unit
val summary : build_dir:File.t -> file list -> bool
(** Displays a summary to stdout; returns true if all tests succeeded *)
val print_xml : build_dir:File.t -> file list -> bool
(** Displays a summary in JUnit XML comptible format to stdout; returns true if
all tests succeeded *)
val set_display_flags :
?files:[ `All | `Failed | `None ] ->
?tests:[ `All | `FailedFile | `Failed | `None ] ->
?diffs:bool ->
?diff_command:string option option ->
unit ->
unit

View File

@ -16,10 +16,53 @@
open Catala_utils
let run_catala_test test_flags catala_exe catala_opts file program args oc =
let cmd_in_rd, cmd_in_wr = Unix.pipe () in
Unix.set_close_on_exec cmd_in_wr;
let command_oc = Unix.out_channel_of_descr cmd_in_wr in
type output_buf = { oc : out_channel; mutable pos : Lexing.position }
let pos0 pos_fname =
{ Lexing.pos_fname; pos_cnum = 0; pos_lnum = 1; pos_bol = 0 }
let with_output file_opt f =
match file_opt with
| Some file ->
File.with_out_channel file @@ fun oc -> f { oc; pos = pos0 file }
| None -> f { oc = stdout; pos = pos0 "<stdout>" }
let out_line output_buf str =
let len = String.length str in
let has_nl = str <> "" && str.[len - 1] = '\n' in
output_string output_buf.oc str;
if not has_nl then output_char output_buf.oc '\n';
let pos_cnum = output_buf.pos.pos_cnum + len + if has_nl then 0 else 1 in
output_buf.pos <-
{
output_buf.pos with
Lexing.pos_cnum;
pos_lnum = output_buf.pos.pos_lnum + 1;
pos_bol = pos_cnum;
}
let sanitize =
let re_endtest = Re.(compile @@ seq [bol; str "```"]) in
let re_modhash =
Re.(
compile
@@ seq
[
str "\"CM0|";
repn xdigit 8 (Some 8);
char '|';
repn xdigit 8 (Some 8);
char '|';
repn xdigit 8 (Some 8);
char '"';
])
in
fun str ->
str
|> Re.replace_string re_endtest ~by:"\\```"
|> Re.replace_string re_modhash ~by:"\"CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX\""
let catala_test_command test_flags catala_exe catala_opts args out =
let catala_exe =
(* If the exe name contains directories, make it absolute. Otherwise don't
modify it so that it can be looked up in PATH. *)
@ -27,15 +70,15 @@ let run_catala_test test_flags catala_exe catala_opts file program args oc =
Unix.realpath catala_exe
else catala_exe
in
let cmd =
match args with
| cmd0 :: flags ->
match args with
| cmd0 :: flags -> (
try
let cmd0, flags =
match String.lowercase_ascii cmd0, flags, test_flags with
| "test-scope", scope_name :: flags, test_flags ->
"interpret", (("--scope=" ^ scope_name) :: flags) @ test_flags
"interpret", flags @ test_flags @ ["--scope=" ^ scope_name]
| "test-scope", [], _ ->
output_string oc
out_line out
"[INVALID TEST] Invalid test command syntax, the 'test-scope' \
pseudo-command takes a scope name as first argument\n";
"interpret", test_flags
@ -43,38 +86,69 @@ let run_catala_test test_flags catala_exe catala_opts file program args oc =
| _, _, _ :: _ ->
raise Exit (* Skip other tests when test-flags is specified *)
in
Array.of_list
((catala_exe :: cmd0 :: catala_opts) @ flags @ ["--name=" ^ file; "-"])
| [] -> Array.of_list ((catala_exe :: catala_opts) @ [file])
Some (Array.of_list ((catala_exe :: cmd0 :: catala_opts) @ flags))
with Exit -> None)
| [] -> Some (Array.of_list (catala_exe :: catala_opts))
let catala_test_env () =
Unix.environment ()
|> Array.to_seq
|> Seq.filter (fun s ->
not
(String.starts_with ~prefix:"OCAMLRUNPARAM=" s
|| String.starts_with ~prefix:"CATALA_" s))
|> Seq.cons "CATALA_OUT=-"
|> Seq.cons "CATALA_COLOR=never"
|> Seq.cons "CATALA_PLUGINS="
|> Array.of_seq
let run_catala_test filename cmd program expected out_line =
let cmd_in_rd, cmd_in_wr = Unix.pipe ~cloexec:true () in
let cmd_out_rd, cmd_out_wr = Unix.pipe ~cloexec:true () in
let command_oc = Unix.out_channel_of_descr cmd_in_wr in
let command_ic = Unix.in_channel_of_descr cmd_out_rd in
let env = catala_test_env () in
let cmd = Array.append cmd [| "--name=" ^ filename; "-" |] in
let pid =
Unix.create_process_env cmd.(0) cmd env cmd_in_rd cmd_out_wr cmd_out_wr
in
let env =
Unix.environment ()
|> Array.to_seq
|> Seq.filter (fun s ->
not
(String.starts_with ~prefix:"OCAMLRUNPARAM=" s
|| String.starts_with ~prefix:"CATALA_" s))
|> Seq.cons "CATALA_OUT=-"
(* |> Seq.cons "CATALA_COLOR=never" *)
|> Seq.cons "CATALA_PLUGINS="
|> Array.of_seq
in
flush oc;
let ocfd = Unix.descr_of_out_channel oc in
let pid = Unix.create_process_env catala_exe cmd env cmd_in_rd ocfd ocfd in
Unix.close cmd_in_rd;
Unix.close cmd_out_wr;
Seq.iter (output_string command_oc) program;
close_out command_oc;
let out_lines =
Seq.of_dispenser (fun () -> In_channel.input_line command_ic)
in
let success, expected =
Seq.fold_left
(fun (success, expected) result_line ->
let result_line = sanitize result_line ^ "\n" in
out_line result_line;
match Seq.uncons expected with
| Some (l, expected) -> success && String.equal result_line l, expected
| None -> false, Seq.empty)
(true, expected) out_lines
in
let return_code =
match Unix.waitpid [] pid with
| _, Unix.WEXITED n -> n
| _, (Unix.WSIGNALED n | Unix.WSTOPPED n) -> 128 - n
in
if return_code <> 0 then Printf.fprintf oc "#return code %d#\n" return_code
let success, expected =
if return_code = 0 then success, expected
else
let line = Printf.sprintf "#return code %d#\n" return_code in
out_line line;
match Seq.uncons expected with
| Some (l, expected) when String.equal l line -> success, expected
| Some (_, expected) -> false, expected
| None -> false, Seq.empty
in
success && Seq.is_empty expected
(** Directly runs the test (not using ninja, this will be called by ninja rules
through the "clerk runtest" command) *)
let run_inline_tests catala_exe catala_opts test_flags filename =
let run_tests ~catala_exe ~catala_opts ~test_flags ~report ~out filename =
let module L = Surface.Lexer_common in
let lang =
match Clerk_scan.get_lang filename with
@ -84,30 +158,83 @@ let run_inline_tests catala_exe catala_opts test_flags filename =
File.format filename
in
let lines = Surface.Parser_driver.lines filename lang in
let oc = stdout in
with_output out
@@ fun out ->
let lines_until_now = Queue.create () in
let push str =
output_string oc str;
let push_line str =
out_line out str;
Queue.add str lines_until_now
in
let rec run_test lines =
let rtests : Clerk_report.test list ref = ref [] in
let rec skip_block lines =
match Seq.uncons lines with
| Some ((l, tok, _), lines) ->
push_line l;
if tok = L.LINE_BLOCK_END then lines else skip_block lines
| None -> Seq.empty
in
let rec get_block acc lines =
let return lines acc =
let endpos =
match acc with
| (_, _, (_, epos)) :: _ -> epos
| [] -> { Lexing.dummy_pos with pos_fname = filename }
in
let block = List.rev acc in
let startpos =
match block with
| (_, _, (spos, _)) :: _ -> spos
| [] -> { Lexing.dummy_pos with pos_fname = filename }
in
lines, block, (startpos, endpos)
in
match Seq.uncons lines with
| None -> return Seq.empty acc
| Some ((_, L.LINE_BLOCK_END, _), lines) -> return lines acc
| Some (li, lines) -> get_block (li :: acc) lines
in
let broken_test msg =
let opos_start = out.pos in
push_line msg;
{
Clerk_report.success = false;
command_line = [];
expected =
( { Lexing.dummy_pos with pos_fname = filename },
{ Lexing.dummy_pos with pos_fname = filename } );
result = opos_start, out.pos;
}
in
let get_test_command lines =
match Seq.uncons lines with
| None ->
output_string oc
"[INVALID TEST] Missing test command, use '$ catala <args>'\n"
| Some ((str, L.LINE_BLOCK_END), lines) ->
output_string oc
"[INVALID TEST] Missing test command, use '$ catala <args>'\n";
push str;
process lines
| Some ((str, _), lines) -> (
push str;
let t =
broken_test
"[INVALID TEST] Missing test command, use '$ catala <args>'\n"
in
rtests := t :: !rtests;
None, Seq.empty
| Some ((str, L.LINE_BLOCK_END, _), lines) ->
let t =
broken_test
"[INVALID TEST] Missing test command, use '$ catala <args>'\n"
in
rtests := t :: !rtests;
push_line str;
None, lines
| Some ((str, _, _), lines) -> (
push_line str;
match Clerk_scan.test_command_args str with
| None ->
output_string oc
"[INVALID TEST] Invalid test command syntax, must match '$ catala \
<args>'\n";
skip_block lines
let t =
broken_test
"[INVALID TEST] Invalid test command syntax, must match '$ catala \
<args>'\n"
in
let lines, _, ipos = get_block [] lines in
push_line "```\n";
rtests := { t with Clerk_report.expected = ipos } :: !rtests;
None, lines
| Some args -> (
let args = String.split_on_char ' ' args in
let program =
@ -121,29 +248,107 @@ let run_inline_tests catala_exe catala_opts test_flags filename =
in
Queue.to_seq lines_until_now |> drop_last |> drop_last
in
let opos_start = out.pos in
match
run_catala_test test_flags catala_exe catala_opts filename program
args oc
catala_test_command test_flags catala_exe catala_opts args out
with
| () -> skip_block lines
| exception Exit -> process lines))
and skip_block lines =
match Seq.uncons lines with
| None -> ()
| Some ((str, L.LINE_BLOCK_END), lines) ->
push str;
| Some cmd -> Some (cmd, program, opos_start), lines
| None -> None, skip_block lines))
in
let rec run_inline_test lines =
match get_test_command lines with
| None, lines -> process lines
| Some (cmd, program, opos_start), lines ->
let lines, expected, ipos = get_block [] lines in
let expected = Seq.map (fun (s, _, _) -> s) (List.to_seq expected) in
let success = run_catala_test filename cmd program expected push_line in
let opos_end = out.pos in
push_line "```\n";
rtests :=
{
Clerk_report.success;
command_line = Array.to_list cmd @ [filename];
result = opos_start, opos_end;
expected = ipos;
}
:: !rtests;
process lines
and run_output_test id lines =
match get_test_command lines with
| None, lines -> process lines
| Some (cmd, program, _), lines ->
let lines = skip_block lines in
let ref_file =
File.((filename /../ "output" / Filename.basename filename) -.- id)
in
if not (Sys.file_exists ref_file) then
(* Create the file if it doesn't exist *)
File.with_out_channel ref_file ignore;
let output = ref_file ^ "@out" in
let ipos_start = pos0 ref_file in
let ipos_end = ref ipos_start in
let report =
File.with_in_channel ref_file
@@ fun ic ->
let expected =
Seq.of_dispenser (fun () ->
match In_channel.input_line ic with
| None -> None
| Some s ->
let s = s ^ "\n" in
let pos_cnum = !ipos_end.pos_cnum + String.length s in
ipos_end :=
{
!ipos_end with
Lexing.pos_cnum;
pos_lnum = !ipos_end.pos_lnum + 1;
pos_bol = pos_cnum;
};
Some s)
in
with_output (Some output)
@@ fun test_out ->
let opos_start = test_out.pos in
let success =
run_catala_test filename cmd program expected (out_line test_out)
in
Seq.iter ignore expected;
{
Clerk_report.success;
command_line = Array.to_list cmd @ [filename];
result = opos_start, test_out.pos;
expected = ipos_start, !ipos_end;
}
in
rtests := report :: !rtests;
process lines
| Some ((str, _), lines) ->
Queue.add str lines_until_now;
skip_block lines
and process lines =
match Seq.uncons lines with
| Some ((str, L.LINE_INLINE_TEST), lines) ->
push str;
run_test lines
| Some ((str, _), lines) ->
push str;
| Some ((str, L.LINE_INLINE_TEST, _), lines) ->
push_line str;
run_inline_test lines
| Some ((str, L.LINE_TEST id, _), lines) ->
push_line str;
run_output_test id lines
| Some ((str, _, _), lines) ->
push_line str;
process lines
| None -> ()
in
process lines
process lines;
let tests_report =
List.fold_left
Clerk_report.(
fun tests t ->
{
tests with
total = tests.total + 1;
successful = (tests.successful + if t.success then 1 else 0);
tests = t :: tests.tests;
})
{ Clerk_report.name = filename; successful = 0; total = 0; tests = [] }
!rtests
in
match report with
| Some file -> Clerk_report.write_to file tests_report
| None -> ()

View File

@ -22,7 +22,14 @@
open Catala_utils
val run_inline_tests : string -> string list -> string list -> File.t -> unit
(** [run_inline_tests catala_exe catala_opts test_flags file] runs the tests in
Catala [file] using the given path to the Catala executable and the provided
options. Output is printed to [stdout]. *)
val run_tests :
catala_exe:string ->
catala_opts:string list ->
test_flags:string list ->
report:File.t option ->
out:File.t option ->
File.t ->
unit
(** [run_tests ~catala_exe ~catala_opts ~test_flags ~report ~out file] runs the
tests in Catala [file] using the given path to the Catala executable and the
provided options. Output is printed to [stdout] if [out] is [None]. *)

View File

@ -60,10 +60,10 @@ let catala_file (file : File.t) (lang : Catala_utils.Global.backend_lang) : item
let rec parse lines n acc =
match Seq.uncons lines with
| None -> acc
| Some ((_, L.LINE_TEST id), lines) ->
| Some ((_, L.LINE_TEST id, _), lines) ->
let test, lines, n = parse_test id lines (n + 1) in
parse lines n { acc with legacy_tests = test :: acc.legacy_tests }
| Some ((_, line), lines) -> (
| Some ((_, line, _), lines) -> (
parse lines (n + 1)
@@
match line with
@ -88,7 +88,7 @@ let catala_file (file : File.t) (lang : Catala_utils.Global.backend_lang) : item
[Format.asprintf "'invalid test syntax at %a:%d'" File.format file n]
in
match Seq.uncons lines with
| Some ((str, L.LINE_ANY), lines) -> (
| Some ((str, L.LINE_ANY, _), lines) -> (
match test_command_args str with
| Some cmd ->
let cmd, lines, n = parse_block lines (n + 1) [cmd] in
@ -103,8 +103,8 @@ let catala_file (file : File.t) (lang : Catala_utils.Global.backend_lang) : item
| None -> { test with cmd = err n }, lines, n
and parse_block lines n acc =
match Seq.uncons lines with
| Some ((_, L.LINE_BLOCK_END), lines) -> List.rev acc, lines, n + 1
| Some ((str, _), lines) -> String.trim str :: acc, lines, n + 1
| Some ((_, L.LINE_BLOCK_END, _), lines) -> List.rev acc, lines, n + 1
| Some ((str, _, _), lines) -> String.trim str :: acc, lines, n + 1
| None -> List.rev acc, lines, n
in
parse

View File

@ -8,8 +8,9 @@
ninja_utils
cmdliner
re
ocolor)
(modules clerk_scan clerk_runtest clerk_driver))
ocolor
otoml)
(modules clerk_scan clerk_report clerk_runtest clerk_config clerk_driver))
(rule
(target custom_linking.sexp)

View File

@ -22,7 +22,7 @@ depends: [
"bindlib" {>= "6.0"}
"cmdliner" {>= "1.1.0"}
"cppo" {>= "1"}
"dates_calc" {>= "0.0.4"}
"dates_calc" {>= "0.0.6"}
"dune" {>= "3.11"}
"js_of_ocaml-ppx" {= "4.1.0"}
"menhir" {>= "20200211"}
@ -31,7 +31,7 @@ depends: [
"ocamlfind" {!= "1.9.5"}
"ocamlgraph" {>= "1.8.8"}
"re" {>= "1.10"}
"sedlex" {>= "2.4"}
"sedlex" {>= "3.1"}
"uutf" {>= "1.0.3"}
"ubase" {>= "0.05"}
"unionFind" {>= "20220109"}
@ -47,10 +47,10 @@ depends: [
"conf-npm" {cataladevmode}
"conf-python-3-dev" {cataladevmode}
"cpdf" {cataladevmode}
"conf-diffutils" {cataladevmode}
"conf-pandoc" {cataladevmode}
"z3" {catalaz3mode}
"conf-ninja"
"otoml" {>= "1.0"}
]
depopts: ["z3"]
conflicts: [

View File

@ -199,6 +199,12 @@ module Flags = struct
"Behave as if run from the given directory for file and error \
reporting. Does not affect resolution of files in arguments."
let stop_on_error =
value
& flag
& info ["x"; "stop-on-error"]
~doc:"Stops the compilation as soon as an error is encountered."
let flags =
let make
language
@ -209,7 +215,8 @@ module Flags = struct
plugins_dirs
disable_warnings
max_prec_digits
directory : options =
directory
stop_on_error : options =
if debug then Printexc.record_backtrace true;
let path_rewrite =
match directory with
@ -223,7 +230,8 @@ module Flags = struct
(* This sets some global refs for convenience, but most importantly
returns the options record. *)
Global.enforce_options ~language ~debug ~color ~message_format ~trace
~plugins_dirs ~disable_warnings ~max_prec_digits ~path_rewrite ()
~plugins_dirs ~disable_warnings ~max_prec_digits ~path_rewrite
~stop_on_error ()
in
Term.(
const make
@ -235,7 +243,8 @@ module Flags = struct
$ plugins_dirs
$ disable_warnings
$ max_prec_digits
$ directory)
$ directory
$ stop_on_error)
let options =
let make input_src name directory options : options =
@ -325,13 +334,6 @@ module Flags = struct
~env:(Cmd.Env.info "CATALA_OPTIMIZE")
~doc:"Run compiler optimizations."
let avoid_exceptions =
value
& flag
& info ["avoid-exceptions"]
~env:(Cmd.Env.info "CATALA_AVOID_EXCEPTIONS")
~doc:"Compiles the default calculus without exceptions."
let keep_special_ops =
value
& flag
@ -372,9 +374,7 @@ module Flags = struct
value
& flag
& info ["closure-conversion"]
~doc:
"Performs closure conversion on the lambda calculus. Implies \
$(b,--avoid-exceptions)."
~doc:"Performs closure conversion on the lambda calculus."
let disable_counterexamples =
value

View File

@ -55,7 +55,6 @@ module Flags : sig
val ex_variable : string Term.t
val output : raw_file option Term.t
val optimize : bool Term.t
val avoid_exceptions : bool Term.t
val closure_conversion : bool Term.t
val keep_special_ops : bool Term.t
val monomorphize_types : bool Term.t

View File

@ -66,6 +66,8 @@ let clean_path p =
in
if p = "" then "." else p
let exists = Sys.file_exists
let rec ensure_dir dir =
match Sys.is_directory dir with
| true -> ()
@ -104,6 +106,20 @@ let reverse_path ?(from_dir = Sys.getcwd ()) ~to_dir f =
String.concat Filename.dir_sep
(aux (path_to_list f) rbase (path_to_list to_dir))
let find_in_parents predicate =
let home = try Sys.getenv "HOME" with Not_found -> "" in
let rec lookup dir rel =
if predicate dir then Some dir, rel
else if dir = home then None, Filename.current_dir_name
else
let parent = Filename.dirname dir in
if parent = dir then None, Filename.current_dir_name
else lookup parent (rel / Filename.parent_dir_name)
in
match lookup (Sys.getcwd ()) Filename.current_dir_name with
| Some dir, rel -> Some (dir, rel)
| None, _ -> None
let with_out_channel filename f =
ensure_dir (Filename.dirname filename);
let oc = open_out filename in
@ -185,21 +201,24 @@ let process_out ?check_exit cmd args =
let () =
let default = 80 in
let get_terminal_cols () =
let from_env () =
try int_of_string (Sys.getenv "COLUMNS") with Not_found | Failure _ -> 0
in
let count =
try
(* terminfo *)
process_out "tput" ["cols"] |> String.trim |> int_of_string
with Failure _ -> (
if not Unix.(isatty stdin) then from_env ()
else
try
(* stty *)
process_out "stty" ["size"]
|> String.trim
|> fun s ->
let i = String.rindex s ' ' + 1 in
String.sub s i (String.length s - i) |> int_of_string
with Failure _ | Not_found | Invalid_argument _ -> (
try int_of_string (Sys.getenv "COLUMNS")
with Not_found | Failure _ -> 0))
(* terminfo *)
process_out "tput" ["cols"] |> String.trim |> int_of_string
with Failure _ -> (
try
(* stty *)
process_out "stty" ["size"]
|> String.trim
|> fun s ->
let i = String.rindex s ' ' + 1 in
String.sub s i (String.length s - i) |> int_of_string
with Failure _ | Not_found | Invalid_argument _ -> from_env ())
in
if count > 0 then count else default
in

View File

@ -89,6 +89,9 @@ val ensure_dir : t -> unit
(** Creates the directory (and parents recursively) if it doesn't exist already.
Errors out if the file exists but is not a directory *)
val exists : t -> bool
(** Alias for Sys.file_exists*)
val check_file : t -> t option
(** Returns its argument if it exists and is a plain file, [None] otherwise.
Does not do resolution like [check_directory]. *)
@ -122,6 +125,12 @@ val reverse_path : ?from_dir:t -> to_dir:t -> t -> t
leading to [f] from [to_dir]. The results attempts to be relative to
[to_dir]. *)
val find_in_parents : (t -> bool) -> (t * t) option
(** Checks for the first directory matching the given predicate from the current
directory upwards. Recursion stops at home. Returns a pair [dir, rel_path],
where [dir] is the ancestor directory matching the predicate, and [rel_path]
is a path pointing to it from the current dir. *)
val ( /../ ) : t -> t -> t
(** Sugar for [parent a / b] *)

View File

@ -37,6 +37,7 @@ type options = {
mutable disable_warnings : bool;
mutable max_prec_digits : int;
mutable path_rewrite : raw_file -> file;
mutable stop_on_error : bool;
}
(* Note: we force that the global options (ie options common to all commands)
@ -56,6 +57,7 @@ let options =
disable_warnings = false;
max_prec_digits = 20;
path_rewrite = (fun _ -> assert false);
stop_on_error = false;
}
let enforce_options
@ -69,6 +71,7 @@ let enforce_options
?disable_warnings
?max_prec_digits
?path_rewrite
?stop_on_error
() =
Option.iter (fun x -> options.input_src <- x) input_src;
Option.iter (fun x -> options.language <- x) language;
@ -80,6 +83,7 @@ let enforce_options
Option.iter (fun x -> options.disable_warnings <- x) disable_warnings;
Option.iter (fun x -> options.max_prec_digits <- x) max_prec_digits;
Option.iter (fun x -> options.path_rewrite <- x) path_rewrite;
Option.iter (fun x -> options.stop_on_error <- x) stop_on_error;
options
let input_src_file = function FileName f | Contents (_, f) | Stdin f -> f

View File

@ -56,6 +56,7 @@ type options = private {
mutable disable_warnings : bool;
mutable max_prec_digits : int;
mutable path_rewrite : raw_file -> file;
mutable stop_on_error : bool;
}
(** Global options, common to all subcommands (note: the fields are internally
mutable only for purposes of the [globals] toplevel value defined below) *)
@ -76,6 +77,7 @@ val enforce_options :
?disable_warnings:bool ->
?max_prec_digits:int ->
?path_rewrite:(raw_file -> file) ->
?stop_on_error:bool ->
unit ->
options
(** Sets up the global options (side-effect); for specific use-cases only, this

View File

@ -0,0 +1,105 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2024 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
type t = int
let mix (h1 : t) (h2 : t) : t = Hashtbl.hash (h1, h2)
let raw = Hashtbl.hash
module Op = struct
let ( % ) = mix
let ( ! ) = raw
end
open Op
let option f = function None -> !`None | Some x -> !`Some % f x
let list hf l = List.fold_left (fun acc x -> acc % hf x) !`ListEmpty l
let map fold_fun kh vh map =
fold_fun (fun k v acc -> acc lxor (kh k % vh v)) map !`HashMapDelim
module Flags : sig
type nonrec t = private t
val pass :
(t -> 'a) -> closure_conversion:bool -> monomorphize_types:bool -> 'a
val of_t : int -> t
end = struct
type nonrec t = t
let pass k ~closure_conversion ~monomorphize_types =
(* Should not affect the call convention or actual interfaces: include,
optimize, check_invariants, typed *)
!(closure_conversion : bool)
% !(monomorphize_types : bool)
% (* The following may not affect the call convention, but we want it set in
an homogeneous way *)
!(Global.options.trace : bool)
% !(Global.options.max_prec_digits : int)
|> k
let of_t t = t
end
type full = { catala_version : t; flags_hash : Flags.t; contents : t }
let finalise t =
Flags.pass (fun flags_hash ->
{ catala_version = !(Version.v : string); flags_hash; contents = t })
let to_string full =
Printf.sprintf "CM0|%08x|%08x|%08x" full.catala_version
(full.flags_hash :> int)
full.contents
(* Putting color inside the hash makes them much easier to differentiate at a
glance *)
let format ppf full =
let open Ocolor_types in
let pcolor col f x =
Format.pp_open_stag ppf Ocolor_format.(Ocolor_style_tag (Fg (C24 col)));
f x;
Format.pp_close_stag ppf ()
in
let tag = pcolor { r24 = 172; g24 = 172; b24 = 172 } in
let auto i =
{
r24 = 128 + (i mod 128);
g24 = 128 + ((i lsr 10) mod 128);
b24 = 128 + ((i lsr 20) mod 128);
}
in
let phash h =
let col = auto h in
pcolor col (Format.fprintf ppf "%08x") h
in
tag (Format.pp_print_string ppf) "CM0|";
phash full.catala_version;
tag (Format.pp_print_string ppf) "|";
phash (full.flags_hash :> int);
tag (Format.pp_print_string ppf) "|";
phash full.contents
let of_string s =
try
Scanf.sscanf s "CM0|%08x|%08x|%08x"
(fun catala_version flags_hash contents ->
{ catala_version; flags_hash = Flags.of_t flags_hash; contents })
with Scanf.Scan_failure _ -> failwith "Hash.of_string"
let external_placeholder = "*external*"

View File

@ -0,0 +1,76 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2024 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Hashes for the identification of modules.
In contrast with OCaml's basic `Hashtbl.hash`, they process the full depth
of terms. Any meaningful interface change in a module should only be in hash
collision with a 1/2^30 probability. *)
type t = private int
(** Native Hasthbl.hash hashes, value is truncated to 30 bits whatever the
architecture (positive 31-bit integers) *)
type full
(** A "full" hash includes the Catala version and compilation flags, alongside
the module interface *)
val raw : 'a -> t
(** [Hashtbl.hash]. Do not use on deep types (it has a bounded depth), use
specific hashing functions. *)
module Op : sig
val ( ! ) : 'a -> t
(** Shortcut to [raw]. Same warning: use with an explicit type annotation
[!(foo: string)] to ensure it's not called on types that are recursive or
include annotations.
Hint: we use [!`Foo] as a fancy way to generate constants for
discriminating constructions *)
val ( % ) : t -> t -> t
(** Safe combination of two hashes (non commutative or associative, etc.) *)
end
val option : ('a -> t) -> 'a option -> t
val list : ('a -> t) -> 'a list -> t
val map :
(('k -> 'v -> t -> t) -> 'map -> t -> t) ->
('k -> t) ->
('v -> t) ->
'map ->
t
(** [map fold_f key_hash_f value_hash_f map] computes the hash of a map. The
first argument is expected to be a [Foo.Map.fold] function. The result is
independent of the ordering of the map. *)
val finalise : t -> closure_conversion:bool -> monomorphize_types:bool -> full
(** Turns a raw interface hash into a full hash, ready for printing *)
val to_string : full -> string
val format : Format.formatter -> full -> unit
val of_string : string -> full
(** @raise Failure *)
val external_placeholder : string
(** It's inconvenient to need hash updates on external modules. This string is
uses as a hash instead for those cases.
NOTE: This is a temporary solution A future approach could be to have Catala
generate a module loader (with the proper hash), relieving the user
implementation from having to do the registration. *)

View File

@ -29,6 +29,7 @@ let fold f (x, _) = f x
let fold2 f (x, _) (y, _) = f x y
let compare cmp a b = fold2 cmp a b
let equal eq a b = fold2 eq a b
let hash f (x, _) = f x
class ['self] marked_map =
object (_self : 'self)

View File

@ -41,6 +41,10 @@ val compare : ('a -> 'a -> int) -> ('a, 'm) ed -> ('a, 'm) ed -> int
val equal : ('a -> 'a -> bool) -> ('a, 'm) ed -> ('a, 'm) ed -> bool
(** Tests equality of two marked values {b ignoring marks} *)
val hash : ('a -> Hash.t) -> ('a, 'm) ed -> Hash.t
(** Computes the hash of the marked values using the given function
{b ignoring mark} *)
(** Visitors *)
class ['self] marked_map : object ('self)

View File

@ -34,14 +34,14 @@ let unstyle_formatter ppf =
[Format.sprintf] etc. functions (ignoring them) *)
let () = ignore (unstyle_formatter Format.str_formatter)
let terminal_columns, set_terminal_width_function =
let get_cols = ref (fun () -> 80) in
(fun () -> !get_cols ()), fun f -> get_cols := f
(* Note: we could do the same for std_formatter, err_formatter... but we'd
rather promote the use of the formatting functions of this module and the
below std_ppf / err_ppf *)
let terminal_columns, set_terminal_width_function =
let get_cols = ref (fun () -> 80) in
(fun () -> !get_cols ()), fun f -> get_cols := f
let has_color_raw ~(tty : bool Lazy.t) =
match Global.options.color with
| Global.Never -> false
@ -90,6 +90,8 @@ let unformat (f : Format.formatter -> unit) : string =
Format.pp_print_flush ppf ();
Buffer.contents buf
let pad n s ppf = Pos.pad_fmt n s ppf
(**{2 Message types and output helpers *)
type level = Error | Warning | Debug | Log | Result
@ -109,10 +111,9 @@ let print_time_marker =
let old_time = !time in
time := new_time;
let delta = (new_time -. old_time) *. 1000. in
if delta > 50. then
Format.fprintf ppf "@{<bold;black>[TIME] %.0fms@}@\n" delta
if delta > 50. then Format.fprintf ppf " @{<bold;black>%.0fms@}" delta
let pp_marker target ppf =
let pp_marker ?extra_label target ppf =
let open Ocolor_types in
let tags, str =
match target with
@ -122,10 +123,15 @@ let pp_marker target ppf =
| Result -> [Bold; Fg (C4 green)], "RESULT"
| Log -> [Bold; Fg (C4 black)], "LOG"
in
if target = Debug then print_time_marker ppf ();
let str =
match extra_label with
| None -> str
| Some lbl -> Printf.sprintf "%s %s" str lbl
in
Format.pp_open_stag ppf (Ocolor_format.Ocolor_styles_tag tags);
Format.pp_print_string ppf str;
Format.pp_close_stag ppf ()
Format.pp_close_stag ppf ();
if target = Debug then print_time_marker ppf ()
(**{2 Printers}*)
@ -165,7 +171,7 @@ module Content = struct
let of_string (s : string) : t =
[MainMessage (fun ppf -> Format.pp_print_text ppf s)]
let basic_msg ppf target content =
let basic_msg ?(pp_marker = pp_marker) ppf target content =
Format.pp_open_vbox ppf 0;
Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf "@,@,")
@ -184,7 +190,7 @@ module Content = struct
Format.pp_close_box ppf ();
Format.pp_print_newline ppf ()
let fancy_msg ppf target content =
let fancy_msg ?(pp_marker = pp_marker) ppf target content =
let ppf_out_fcts = Format.pp_get_formatter_out_functions ppf () in
let restore_ppf () =
Format.pp_print_flush ppf ();
@ -269,13 +275,13 @@ module Content = struct
restore_ppf ();
Format.pp_print_newline ppf ()
let emit (content : t) (target : level) : unit =
let emit ?(pp_marker = pp_marker) (content : t) (target : level) : unit =
match Global.options.message_format with
| Global.Human -> (
let ppf = get_ppf target in
match target with
| Debug | Log -> basic_msg ppf target content
| Result | Warning | Error -> fancy_msg ppf target content)
| Debug | Log -> basic_msg ~pp_marker ppf target content
| Result | Warning | Error -> fancy_msg ~pp_marker ppf target content)
| Global.GNU ->
(* The top message doesn't come with a position, which is not something
the GNU standard allows. So we look the position list and put the top
@ -320,6 +326,21 @@ module Content = struct
| None -> ())
ppf content;
Format.pp_print_newline ppf ()
let emit_n (target : level) = function
| [content] -> emit content target
| contents ->
let ppf = get_ppf target in
let len = List.length contents in
List.iteri
(fun i c ->
if i > 0 then Format.pp_print_newline ppf ();
let extra_label = Printf.sprintf "(%d/%d)" (succ i) len in
let pp_marker ?extra_label:_ = pp_marker ~extra_label in
emit ~pp_marker c target)
contents
let emit (content : t) (target : level) = emit content target
end
open Content
@ -327,6 +348,7 @@ open Content
(** {1 Error exception} *)
exception CompilerError of Content.t
exception CompilerErrors of Content.t list
(** {1 Error printing} *)
@ -404,3 +426,47 @@ let result = make ~level:Result ~cont:emit
let results r = emit (List.flatten (List.map of_result r)) Result
let warning = make ~level:Warning ~cont:emit
let error = make ~level:Error ~cont:(fun m _ -> raise (CompilerError m))
(* Multiple errors handling *)
type global_errors = {
mutable errors : t list option;
mutable stop_on_error : bool;
}
let global_errors = { errors = None; stop_on_error = false }
let delayed_error x =
make ~level:Error ~cont:(fun m _ ->
if global_errors.stop_on_error then raise (CompilerError m);
match global_errors.errors with
| None ->
error ~internal:true
"delayed error called outside scope: encapsulate using \
'with_delayed_errors' first"
| Some l ->
global_errors.errors <- Some (m :: l);
x)
let with_delayed_errors
?(stop_on_error = Global.options.stop_on_error)
(f : unit -> 'a) : 'a =
(match global_errors.errors with
| None -> global_errors.errors <- Some []
| Some _ ->
error ~internal:true
"delayed error called outside scope: encapsulate using \
'with_delayed_errors' first");
global_errors.stop_on_error <- stop_on_error;
let r = f () in
match global_errors.errors with
| None -> error ~internal:true "intertwined delayed error scope"
| Some [] ->
global_errors.errors <- None;
r
| Some [err] ->
global_errors.errors <- None;
raise (CompilerError err)
| Some errs ->
global_errors.errors <- None;
raise (CompilerErrors (List.rev errs))

View File

@ -55,14 +55,16 @@ module Content : sig
(** {2 Content emission}*)
val emit : t -> level -> unit
val emit_n : level -> t list -> unit
end
(** This functions emits the message according to the emission type defined by
[Cli.message_format_flag]. *)
(** {1 Error exception} *)
(** {1 Error exceptions} *)
exception CompilerError of Content.t
exception CompilerErrors of Content.t list
(** {1 Some formatting helpers}*)
@ -72,6 +74,11 @@ val unformat : (Format.formatter -> unit) -> string
val has_color : out_channel -> bool
val set_terminal_width_function : (unit -> int) -> unit
val terminal_columns : unit -> int
val pad : int -> string -> Format.formatter -> unit
(** Prints the given character the given number of times (assuming it is of
width 1) *)
(* {1 More general color-enabled formatting helpers}*)
@ -98,5 +105,18 @@ val log : ('a, unit) emitter
val debug : ('a, unit) emitter
val result : ('a, unit) emitter
val warning : ('a, unit) emitter
val error : ('a, 'b) emitter
val error : ('a, 'exn) emitter
val results : Content.message list -> unit
(** Multiple errors *)
val with_delayed_errors : ?stop_on_error:bool -> (unit -> 'a) -> 'a
(** [with_delayed_errors ?stop_on_error f] calls [f] and registers each error
triggered using [delayed_error]. [stop_on_error] defaults to
[Global.options.stop_on_error].
@raise CompilerErrors when delayed errors were registered.
@raise CompilerError
on the first error encountered when the [stop_on_error] flag is set. *)
val delayed_error : 'b -> ('a, 'b) emitter

View File

@ -105,14 +105,6 @@ let indent_number (s : string) : int =
aux 0
with Invalid_argument _ -> String.length s
let string_repeat n s =
let slen = String.length s in
let buf = Bytes.create (n * slen) in
for i = 0 to n - 1 do
Bytes.blit_string s 0 buf (i * slen) slen
done;
Bytes.to_string buf
let utf8_byte_index s ui0 =
let rec aux bi ui =
if ui >= ui0 then bi
@ -121,6 +113,11 @@ let utf8_byte_index s ui0 =
in
aux 0 0
let rec pad_fmt n s ppf =
if n > 0 then (
Format.pp_print_as ppf 1 s;
pad_fmt (n - 1) s ppf)
let format_loc_text_parts (pos : t) =
let filename = get_file pos in
if filename = "" then
@ -199,14 +196,12 @@ let format_loc_text_parts (pos : t) =
line;
Format.pp_print_cut ppf ();
if line_no >= sline && line_no <= eline then
Format.fprintf ppf "@{<blue>%s │@} %s@{<bold;red>%a@}"
(string_repeat nspaces " ")
(string_repeat match_start_col " ")
(fun ppf -> Format.pp_print_as ppf match_num_cols)
(string_repeat match_num_cols "")
Format.fprintf ppf "@{<blue>%*s │@} %*s@{<bold;red>%t@}" nspaces ""
match_start_col ""
(pad_fmt match_num_cols "")
in
let pr_context ppf =
Format.fprintf ppf "@{<blue> %s│@}@," (string_repeat nspaces " ");
Format.fprintf ppf "@{<blue> %*s│@}@," nspaces "";
Format.pp_print_list print_matched_line ppf pos_lines
in
let legal_pos_lines =

View File

@ -69,3 +69,8 @@ val format_loc_text_parts :
val no_pos : t
(** Placeholder position *)
(**/**)
val pad_fmt : int -> string -> Format.formatter -> unit
(** Exported as [Message.pad] *)

View File

@ -50,6 +50,18 @@ let remove_prefix ~prefix s =
sub s plen (length s - plen)
else s
let trim_end s =
let rec stop n =
if n < 0 then n
else
match get s n with
| ' ' | '\x0c' | '\n' | '\r' | '\t' -> stop (n - 1)
| _ -> n
in
let last = length s - 1 in
let i = stop last in
if i = last then s else sub s 0 (i + 1)
(* Note: this should do, but remains incorrect for combined unicode characters
that display as one (e.g. `e` + postfix `'`). We should switch to Uuseg at
some poing *)
@ -101,6 +113,7 @@ module Arg = struct
end
let compare = Arg.compare
let hash t = Hash.raw t
module Set = Set.Make (Arg)
module Map = Map.Make (Arg)

View File

@ -23,6 +23,8 @@ module Map : Map.S with type key = string
val compare : string -> string -> int
(** String comparison with natural ordering of numbers within strings *)
val hash : string -> Hash.t
val to_ascii : string -> string
(** Removes all non-ASCII diacritics from a string by converting them to their
base letter in the Latin alphabet. *)
@ -48,6 +50,9 @@ val remove_prefix : prefix:string -> string -> string
- if [str] starts with [prefix], a string [s] such that [prefix ^ s = str]
- otherwise, [str] unchanged *)
val trim_end : string -> string
(** Like [Stdlib.String.trim], but only trims at the end of the string *)
val format : Format.formatter -> string -> unit
val width : string -> int

View File

@ -48,59 +48,54 @@ let levenshtein_distance (s : string) (t : string) : int =
d.(m).(n)
(*We create a list composed by strings that satisfy the following rule : they
have the same levenshtein distance, which is the minimum distance between the
reference word "keyword" and all the strings in "candidates" (with the
condition that this minimum is equal to or less than one third of the length
of keyword + 1, in order to get suggestions close to "keyword")*)
let suggestion_minimum_levenshtein_distance_association
(candidates : string list)
(keyword : string) : string list =
let rec strings_minimum_levenshtein_distance
(minimum : int)
(result : string list)
(candidates' : string list) : string list =
(*As we iterate through the "candidates'" list, we create a list "result"
with all strings that have the last minimum levenshtein distance found
("minimum").*)
match candidates' with
(*When a new minimum levenshtein distance is found, the new result list is
our new element "current_string" followed by strings that have the same
minimum distance. It will be the "result" list if there is no levenshtein
distance smaller than this new minimum.*)
| current_string :: tail ->
let current_levenshtein_distance =
levenshtein_distance current_string keyword
in
if current_levenshtein_distance < minimum then
strings_minimum_levenshtein_distance current_levenshtein_distance
[current_string] tail
(*The "result" list is updated (we append "current_string" to "result")
when a new string shares the same minimum levenshtein distance
"minimum"*)
else if current_levenshtein_distance = minimum then
strings_minimum_levenshtein_distance minimum
(result @ [current_string])
tail
(*If a levenshtein distance greater than the minimum is found, "result"
doesn't change*)
else strings_minimum_levenshtein_distance minimum result tail
(*The "result" list is returned at the end of the "candidates'" list.*)
| [] -> result
in
strings_minimum_levenshtein_distance
(1 + (String.length keyword / 3))
(*In order to select suggestions that are not too far away from the
keyword*)
[] candidates
module M = Stdlib.Map.Make (Int)
let format (ppf : Format.formatter) (suggestions_list : string list) =
match suggestions_list with
let compute_candidates (candidates : string list) (word : string) :
string list M.t =
List.fold_left
(fun m candidate ->
let distance = levenshtein_distance word candidate in
M.update distance
(function None -> Some [candidate] | Some l -> Some (candidate :: l))
m)
M.empty candidates
let best_candidates candidates word =
let candidates = compute_candidates candidates word in
M.choose_opt candidates |> function None -> [] | Some (_, l) -> List.rev l
let sorted_candidates ?(max_elements = 5) suggs given =
let rec sub acc n = function
| [] -> List.rev acc
| x :: t when n > 0 -> sub (x :: acc) (pred n) t
| _ -> List.rev acc
in
let candidates =
List.map
(fun (_, l) -> List.rev l)
(M.bindings (compute_candidates suggs given))
in
List.concat candidates |> sub [] max_elements
let format ppf suggs =
let open Format in
let pp_elt elt = fprintf ppf "@{<yellow>\"%s\"@}" elt in
let rec loop = function
| [] -> assert false
| [h] ->
pp_elt h;
pp_print_string ppf "?"
| [h; t] ->
pp_elt h;
fprintf ppf "@ or@ ";
loop [t]
| h :: t ->
pp_elt h;
fprintf ppf ",@ ";
loop t
in
match suggs with
| [] -> ()
| _ :: _ ->
Format.pp_print_string ppf "Maybe you wanted to write : ";
Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ or ")
(fun ppf string -> Format.fprintf ppf "@{<yellow>\"%s\"@}" string)
ppf suggestions_list;
Format.pp_print_string ppf " ?"
| suggs ->
pp_print_string ppf "Maybe you wanted to write: ";
loop suggs

View File

@ -15,9 +15,20 @@
License for the specific language governing permissions and limitations under
the License. *)
val suggestion_minimum_levenshtein_distance_association :
string list -> string -> string list
(**Returns a list of the closest words into {!name:candidates} to the keyword
{!name:keyword}*)
val levenshtein_distance : string -> string -> int
(** [levenshtein_distance w1 w2] computes the levenshtein distance separating
[w1] from [w2]. *)
val best_candidates : string list -> string -> string list
(** [best_candidates suggestions word] returns the subset of elements in
[suggestions] that minimize the levenshtein distance to [word]. Multiple
candidates that have a same distance is possible. *)
val sorted_candidates :
?max_elements:int -> string list -> string -> string list
(** [sorted_candidates ?max_elements suggestions word] sorts the [suggestions]
list and retain at most [max_elements] (defaults to 5). This list is ordered
by their levenshtein distance to [word], i.e., the first elements are the
most similar. *)
val format : Format.formatter -> string list -> unit

View File

@ -21,6 +21,7 @@ module type Info = sig
val format : Format.formatter -> info -> unit
val equal : info -> info -> bool
val compare : info -> info -> int
val hash : info -> Hash.t
end
module type Id = sig
@ -33,7 +34,8 @@ module type Id = sig
val equal : t -> t -> bool
val format : Format.formatter -> t -> unit
val to_string : t -> string
val hash : t -> int
val id : t -> int
val hash : t -> Hash.t
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
@ -68,8 +70,9 @@ module Make (X : Info) (S : Style) () : Id with type info = X.info = struct
{ id = !counter; info }
let get_info (uid : t) : X.info = uid.info
let hash (x : t) : int = x.id
let id (x : t) : int = x.id
let to_string t = X.to_string t.info
let hash t = X.hash t.info
module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering)
@ -84,6 +87,7 @@ module MarkedString = struct
let format fmt i = String.format fmt (to_string i)
let equal = Mark.equal String.equal
let compare = Mark.compare String.compare
let hash = Mark.hash String.hash
end
module Gen (S : Style) () = Make (MarkedString) (S) ()
@ -109,6 +113,15 @@ module Path = struct
let to_string p = String.concat "." (List.map Module.to_string p)
let equal = List.equal Module.equal
let compare = List.compare Module.compare
let strip prefix p0 =
let rec aux prefix p =
match prefix, p with
| pfx1 :: pfx, p1 :: p -> if Module.equal pfx1 p1 then aux pfx p else p0
| [], p -> p
| _ -> p0
in
aux prefix p0
end
module QualifiedMarkedString = struct
@ -125,12 +138,21 @@ module QualifiedMarkedString = struct
let compare (p1, i1) (p2, i2) =
match Path.compare p1 p2 with 0 -> MarkedString.compare i1 i2 | n -> n
let hash (p, i) =
let open Hash.Op in
Hash.list Module.hash p % MarkedString.hash i
end
module Gen_qualified (S : Style) () = struct
include Make (QualifiedMarkedString) (S) ()
let fresh path t = fresh (path, t)
let hash ~strip t =
let p, i = get_info t in
QualifiedMarkedString.hash (Path.strip strip p, i)
let path t = fst (get_info t)
let get_info t = snd (get_info t)
end

View File

@ -28,6 +28,9 @@ module type Info = sig
val compare : info -> info -> int
(** Comparison disregards position *)
val hash : info -> Hash.t
(** Hashing disregards position *)
end
module MarkedString : Info with type info = string Mark.pos
@ -48,7 +51,15 @@ module type Id = sig
val equal : t -> t -> bool
val format : Format.formatter -> t -> unit
val to_string : t -> string
val hash : t -> int
val id : t -> int
(** Returns the unique ID of the identifier *)
val hash : t -> Hash.t
(** While [id] returns a unique ID valable for a given Uid instance within a
given run of catala, this is a raw hash of the identifier string.
Therefore, it may collide within a given program, but remains meaninful
across separate compilations. *)
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
@ -79,6 +90,10 @@ module Path : sig
val format : Format.formatter -> t -> unit
val equal : t -> t -> bool
val compare : t -> t -> int
val strip : t -> t -> t
(** [strip pfx p] removed [pfx] from the start of [p]. if [p] doesn't start
with [pfx], it is returned unchanged *)
end
(** Same as [Gen] but also registers path information *)
@ -88,4 +103,6 @@ module Gen_qualified (_ : Style) () : sig
val fresh : Path.t -> MarkedString.info -> t
val path : t -> Path.t
val get_info : t -> MarkedString.info
val hash : strip:Path.t -> t -> Hash.t
(* [strip] strips that prefix from the start of the path before hashing *)
end

View File

@ -592,20 +592,21 @@ let translate_rule
match rule with
| S.ScopeVarDefinition { var; typ; e; _ }
| S.SubScopeVarDefinition { var; typ; e; _ } ->
let scope_var = Mark.remove var in
let decl_pos = Mark.get (ScopeVar.get_info scope_var) in
let pos_mark, _ = pos_mark_mk e in
let scope_let_kind, io =
match rule with
| S.ScopeVarDefinition { io; _ } -> ScopeVarDefinition, io
| S.SubScopeVarDefinition _ ->
let pos = Mark.get var in
( SubScopeVarDefinition,
{ io_input = NoInput, pos; io_output = false, pos } )
{ io_input = NoInput, decl_pos; io_output = false, decl_pos } )
| S.Assertion _ -> assert false
in
let a_name = ScopeVar.get_info (Mark.remove var) in
let a_var = Var.make (Mark.remove a_name) in
let new_e = translate_expr ctx e in
let a_expr = Expr.make_var a_var (pos_mark (Mark.get var)) in
let a_expr = Expr.make_var a_var (pos_mark decl_pos) in
let is_func = match Mark.remove typ with TArrow _ -> true | _ -> false in
let merged_expr =
match Mark.remove io.io_input with
@ -632,7 +633,7 @@ let translate_rule
scope_let_typ = typ;
scope_let_expr = merged_expr;
scope_let_kind;
scope_let_pos = Mark.get var;
scope_let_pos = decl_pos;
},
next ))
(Bindlib.bind_var a_var next)

View File

@ -136,22 +136,27 @@ module ScopeDef = struct
ScopeVar.format ppf (Mark.remove v);
format_kind ppf k
let rec hash_kind = function
| ScopeVarKind None -> 0
| ScopeVarKind (Some st) -> StateName.hash st
| SubScopeInputKind (Direct { var_within_sub_scope = v; _ }) ->
ScopeVar.hash v
open Hash.Op
let rec hash_kind ~strip = function
| ScopeVarKind v -> !`Var % Hash.option StateName.hash v
| SubScopeInputKind
(Direct { var_within_sub_scope = v; sub_scope_name = s }) ->
!`SubScopeInputKind % ScopeName.hash ~strip s % ScopeVar.hash v
| SubScopeInputKind
(NestedSubScope
{
nested_sub_scope_var_within_sub_scope = v;
nested_input_var = i;
_;
sub_scope_name = s;
}) ->
Int.logxor (ScopeVar.hash v) (hash_kind (SubScopeInputKind i))
!`SubScopeInputKind
% ScopeName.hash ~strip s
% ScopeVar.hash v
% hash_kind ~strip (SubScopeInputKind i)
let hash { scope_def_var_within_scope = v; scope_def_kind = k } =
Int.logxor (ScopeVar.hash (Mark.remove v)) (hash_kind k)
let hash ~strip { scope_def_var_within_scope = v; scope_def_kind = k } =
Hash.Op.(ScopeVar.hash (Mark.remove v) % hash_kind ~strip k)
end
include Base
@ -306,6 +311,8 @@ type scope_def = {
type var_or_states = WholeVar | States of StateName.t list
(* If fields are added, make sure to consider including them in the hash
computations below *)
type scope = {
scope_vars : var_or_states ScopeVar.Map.t;
scope_sub_scopes : ScopeName.t ScopeVar.Map.t;
@ -314,21 +321,76 @@ type scope = {
scope_assertions : assertion AssertionName.Map.t;
scope_options : catala_option Mark.pos list;
scope_meta_assertions : meta_assertion list;
scope_visibility : visibility;
}
type topdef = {
topdef_expr : expr option;
topdef_type : typ;
topdef_visibility : visibility;
}
type modul = {
module_scopes : scope ScopeName.Map.t;
module_topdefs : (expr option * typ) TopdefName.Map.t;
module_topdefs : topdef TopdefName.Map.t;
}
type program = {
program_module_name : Ident.t Mark.pos option;
program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx;
program_modules : modul ModuleName.Map.t;
program_root : modul;
program_lang : Global.backend_lang;
}
module Hash = struct
open Hash.Op
let var_or_state = function
| WholeVar -> !`WholeVar
| States s -> !`States % Hash.list StateName.hash s
let io x =
!(Mark.remove x.io_input : Runtime.io_input)
% !(Mark.remove x.io_output : bool)
let scope_decl ~strip d =
(* scope_def_rules is ignored (not part of the interface) *)
Type.hash ~strip d.scope_def_typ
% Hash.option
(fun (lst, _) ->
List.fold_left
(fun acc (name, ty) ->
acc % Uid.MarkedString.hash name % Type.hash ~strip ty)
!`SDparams lst)
d.scope_def_parameters
% !(d.scope_def_is_condition : bool)
% io d.scope_def_io
let scope ~strip s =
Hash.map ScopeVar.Map.fold ScopeVar.hash var_or_state s.scope_vars
% Hash.map ScopeVar.Map.fold ScopeVar.hash (ScopeName.hash ~strip)
s.scope_sub_scopes
% ScopeName.hash ~strip s.scope_uid
% Hash.map ScopeDef.Map.fold (ScopeDef.hash ~strip) (scope_decl ~strip)
s.scope_defs
(* assertions, options, etc. are not expected to be part of interfaces *)
let modul ?(strip = []) m =
Hash.map ScopeName.Map.fold (ScopeName.hash ~strip) (scope ~strip)
(ScopeName.Map.filter
(fun _ s -> s.scope_visibility = Public)
m.module_scopes)
% Hash.map TopdefName.Map.fold (TopdefName.hash ~strip)
(fun td -> Type.hash ~strip td.topdef_type)
(TopdefName.Map.filter
(fun _ td -> td.topdef_visibility = Public)
m.module_topdefs)
let module_binding modname m =
ModuleName.hash modname % modul ~strip:[modname] m
end
let rec locations_used e : LocationSet.t =
match e with
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
@ -391,5 +453,5 @@ let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a =
p.program_root.module_scopes init
in
TopdefName.Map.fold
(fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e)
(fun _ tdef acc -> Option.fold ~none:acc ~some:(f acc) tdef.topdef_expr)
p.program_root.module_topdefs acc

View File

@ -58,12 +58,12 @@ module ScopeDef : sig
val equal_kind : kind -> kind -> bool
val compare_kind : kind -> kind -> int
val format_kind : Format.formatter -> kind -> unit
val hash_kind : kind -> int
val hash_kind : strip:Uid.Path.t -> kind -> Hash.t
val equal : t -> t -> bool
val compare : t -> t -> int
val get_position : t -> Pos.t
val format : Format.formatter -> t -> unit
val hash : t -> int
val hash : strip:Uid.Path.t -> t -> Hash.t
module Map : Map.S with type key = t
module Set : Set.S with type elt = t
@ -146,16 +146,23 @@ type scope = {
(** empty outside of the root module *)
scope_options : catala_option Mark.pos list;
scope_meta_assertions : meta_assertion list;
scope_visibility : visibility;
}
type topdef = {
topdef_expr : expr option; (** Always [None] outside of the root module *)
topdef_type : typ;
topdef_visibility : visibility;
(** Necessarily [Public] outside of the root module *)
}
type modul = {
module_scopes : scope ScopeName.Map.t;
module_topdefs : (expr option * typ) TopdefName.Map.t;
(** the expr is [None] outside of the root module *)
module_topdefs : topdef TopdefName.Map.t;
}
type program = {
program_module_name : Ident.t Mark.pos option;
program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx;
program_modules : modul ModuleName.Map.t;
(** Contains all submodules of the program, in a flattened structure *)
@ -163,6 +170,18 @@ type program = {
program_lang : Global.backend_lang;
}
(** {1 Interface hash computations} *)
(** These hashes are computed on interfaces: only signatures are considered. *)
module Hash : sig
(** The [strip] argument below strips as many leading path components before
hashing *)
val scope : strip:Uid.Path.t -> scope -> Hash.t
val modul : ?strip:Uid.Path.t -> modul -> Hash.t
val module_binding : ModuleName.t -> modul -> Hash.t
end
(** {1 Helpers} *)
val locations_used : expr -> LocationSet.t

View File

@ -39,9 +39,9 @@ module Vertex = struct
let hash x =
match x with
| Var (x, None) -> ScopeVar.hash x
| Var (x, Some sx) -> Int.logxor (ScopeVar.hash x) (StateName.hash sx)
| Assertion a -> Ast.AssertionName.hash a
| Var (x, None) -> ScopeVar.id x
| Var (x, Some sx) -> Hashtbl.hash (ScopeVar.id x, StateName.id sx)
| Assertion a -> Hashtbl.hash (`Assert (Ast.AssertionName.id a))
let compare x y =
match x, y with
@ -257,7 +257,7 @@ module ExceptionVertex = struct
let hash (x : t) : int =
RuleName.Map.fold
(fun r _ acc -> Int.logxor (RuleName.hash r) acc)
(fun r _ acc -> Hashtbl.hash (RuleName.id r, acc))
x.rules 0
let equal x y = compare x y = 0

View File

@ -107,13 +107,19 @@ let program prg =
in
let module_topdefs =
TopdefName.Map.map
(function
| Some e, ty ->
Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty
| None, ty -> None, ty)
(fun def ->
{
def with
topdef_expr =
Option.map
(fun e -> Expr.unbox (expr prg.program_ctx env (Expr.box e)))
def.topdef_expr;
})
prg.program_root.module_topdefs
in
let module_scopes =
ScopeName.Map.map (scope prg.program_ctx env) prg.program_root.module_scopes
in
{ prg with program_root = { module_topdefs; module_scopes } }
let program prg = Message.with_delayed_errors (fun () -> program prg)

View File

@ -138,8 +138,7 @@ let raise_error_cons_not_found
(constructor : string Mark.pos) =
let constructors = Ident.Map.keys ctxt.local.constructor_idmap in
let closest_constructors =
Suggestions.suggestion_minimum_levenshtein_distance_association constructors
(Mark.remove constructor)
Suggestions.best_candidates constructors (Mark.remove constructor)
in
Message.error
~pos_msg:(fun ppf -> Format.fprintf ppf "Here is your code :")
@ -313,7 +312,7 @@ let rec translate_expr
in
let e2 = rec_helper ~local_vars e2 in
Expr.make_abs [| binding_var |] e2 [tau] pos_op)
(EnumName.Map.find enum_uid ctxt.enums)
(fst (EnumName.Map.find enum_uid ctxt.enums))
in
Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark
| Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) ->
@ -709,7 +708,7 @@ let rec translate_expr
StructField.Map.add f_uid f_e s_fields)
StructField.Map.empty fields
in
let expected_s_fields = StructName.Map.find s_uid ctxt.structs in
let expected_s_fields, _ = StructName.Map.find s_uid ctxt.structs in
if
StructField.Map.exists
(fun expected_f _ -> not (StructField.Map.mem expected_f s_fields))
@ -815,7 +814,7 @@ let rec translate_expr
Expr.make_abs [| nop_var |]
(Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark)
[tau] pos)
(EnumName.Map.find enum_uid ctxt.enums)
(fst (EnumName.Map.find enum_uid ctxt.enums))
in
Expr.ematch ~e:(rec_helper e1) ~name:enum_uid ~cases emark
| ArrayLit es -> Expr.earray (List.map rec_helper es) emark
@ -1066,7 +1065,7 @@ and disambiguate_match_and_build_expression
Expr.eabs e_binder
[
EnumConstructor.Map.find c_uid
(EnumName.Map.find e_uid ctxt.Name_resolution.enums);
(fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums));
]
(Mark.get case_body)
in
@ -1133,6 +1132,7 @@ and disambiguate_match_and_build_expression
if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err ();
let missing_constructors =
EnumName.Map.find e_uid ctxt.Name_resolution.enums
|> fst
|> EnumConstructor.Map.filter_map (fun c_uid _ ->
match EnumConstructor.Map.find_opt c_uid cases_d with
| Some _ -> None
@ -1557,6 +1557,7 @@ let process_scope_use
let process_topdef
(ctxt : Name_resolution.context)
(prgm : Ast.program)
(is_public : bool)
(def : S.top_def) : Ast.program =
let id =
Ident.Map.find
@ -1599,12 +1600,14 @@ let process_topdef
in
Some (Expr.unbox_closed e)
in
let topdef_visibility = if is_public then Public else Private in
let module_topdefs =
TopdefName.Map.update id
(fun def0 ->
match def0, expr_opt with
| None, eopt -> Some (eopt, typ)
| Some (eopt0, ty0), eopt -> (
| None, eopt ->
Some { Ast.topdef_expr = eopt; topdef_visibility; topdef_type = typ }
| Some def0, eopt -> (
let err msg =
Message.error
~extra_pos:
@ -1614,13 +1617,16 @@ let process_topdef
]
(msg ^^ " for %a") TopdefName.format id
in
if not (Type.equal ty0 typ) then err "Conflicting type definitions"
if not (Type.equal def0.Ast.topdef_type typ) then
err "Conflicting type definitions"
else
match eopt0, eopt with
match def0.Ast.topdef_expr, eopt with
| None, None -> err "Multiple declarations"
| Some _, Some _ -> err "Multiple definitions"
| Some e, None -> Some (Some e, typ)
| None, Some e -> Some (Some e, ty0)))
| (Some _ as topdef_expr), None ->
Some { Ast.topdef_expr; topdef_visibility; topdef_type = typ }
| None, (Some _ as topdef_expr) ->
Some { def0 with Ast.topdef_expr }))
prgm.Ast.program_root.module_topdefs
in
{ prgm with program_root = { prgm.program_root with module_topdefs } }
@ -1851,6 +1857,7 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
scope_meta_assertions = [];
scope_options = [];
scope_uid = s_uid;
scope_visibility = s_context.Name_resolution.scope_visibility;
}
in
let get_scopes mctx =
@ -1863,22 +1870,32 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
mctx.Name_resolution.typedefs ScopeName.Map.empty
in
let program_modules =
ModuleName.Map.map
(fun mctx ->
{
Ast.module_scopes = get_scopes mctx;
Ast.module_topdefs =
Ident.Map.fold
(fun _ name acc ->
TopdefName.Map.add name
( None,
TopdefName.Map.find name ctxt.Name_resolution.topdef_types
)
acc)
mctx.topdefs TopdefName.Map.empty;
})
ModuleName.Map.mapi
(fun mname mctx ->
let m =
{
Ast.module_scopes = get_scopes mctx;
Ast.module_topdefs =
Ident.Map.fold
(fun _ name acc ->
let topdef_type, topdef_visibility =
TopdefName.Map.find name ctxt.Name_resolution.topdefs
in
TopdefName.Map.add name
{ Ast.topdef_expr = None; topdef_visibility; topdef_type }
acc)
mctx.topdefs TopdefName.Map.empty;
}
in
m, Ast.Hash.module_binding mname m)
ctxt.modules
in
let program_root =
{
Ast.module_scopes = get_scopes ctxt.Name_resolution.local;
Ast.module_topdefs = TopdefName.Map.empty;
}
in
let program_ctx =
let open Name_resolution in
let ctx_scopes mctx acc =
@ -1891,23 +1908,30 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
in
let ctx_modules =
let rec aux mctx =
Ident.Map.fold
(fun _ m (M acc) ->
let sub = aux (ModuleName.Map.find m ctxt.modules) in
M (ModuleName.Map.add m sub acc))
mctx.used_modules (M ModuleName.Map.empty)
let subs =
Ident.Map.fold
(fun _ m acc ->
let mctx = ModuleName.Map.find m ctxt.Name_resolution.modules in
let deps = aux mctx in
let hash = snd (ModuleName.Map.find m program_modules) in
ModuleName.Map.add m
{ deps; intf_id = { hash; is_external = mctx.is_external } }
acc)
mctx.used_modules ModuleName.Map.empty
in
subs
in
aux ctxt.local
in
{
ctx_structs = ctxt.structs;
ctx_enums = ctxt.enums;
ctx_structs = StructName.Map.map fst ctxt.structs;
ctx_enums = EnumName.Map.map fst ctxt.enums;
ctx_scopes =
ModuleName.Map.fold
(fun _ -> ctx_scopes)
ctxt.modules
(ctx_scopes ctxt.local ScopeName.Map.empty);
ctx_topdefs = ctxt.topdef_types;
ctx_topdefs = TopdefName.Map.map fst ctxt.topdefs;
ctx_struct_fields = ctxt.local.field_idmap;
ctx_enum_constrs = ctxt.local.constructor_idmap;
ctx_scope_index =
@ -1919,25 +1943,29 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
ctx_modules;
}
in
let program_module_name =
surface.Surface.Ast.program_module
|> Option.map
@@ fun { Surface.Ast.module_name; module_external } ->
let mname = ModuleName.fresh module_name in
let hash_placeholder = Hash.raw 0 in
mname, { hash = hash_placeholder; is_external = module_external }
in
let desugared =
{
Ast.program_lang = surface.program_lang;
Ast.program_module_name = surface.Surface.Ast.program_module_name;
Ast.program_modules;
Ast.program_module_name;
Ast.program_modules = ModuleName.Map.map fst program_modules;
Ast.program_ctx;
Ast.program_root =
{
Ast.module_scopes = get_scopes ctxt.Name_resolution.local;
Ast.module_topdefs = TopdefName.Map.empty;
};
Ast.program_root;
}
in
let process_code_block ctxt prgm block =
let process_code_block ctxt prgm is_meta block =
List.fold_left
(fun prgm item ->
match Mark.remove item with
| S.ScopeUse use -> process_scope_use ctxt prgm use
| S.Topdef def -> process_topdef ctxt prgm def
| S.Topdef def -> process_topdef ctxt prgm is_meta def
| S.ScopeDecl _ | S.StructDecl _ | S.EnumDecl _ -> prgm)
prgm block
in
@ -1948,7 +1976,22 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
List.fold_left
(fun prgm child -> process_structure prgm child)
prgm children
| S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block
| S.CodeBlock (block, _, is_meta) ->
process_code_block ctxt prgm is_meta block
| S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm
in
List.fold_left process_structure desugared surface.S.program_items
let desugared =
List.fold_left process_structure desugared surface.S.program_items
in
{
desugared with
Ast.program_module_name =
(desugared.Ast.program_module_name
|> Option.map
@@ fun (mname, intf_id) ->
( mname,
{
intf_id with
hash = Ast.Hash.module_binding mname desugared.Ast.program_root;
} ));
}

View File

@ -273,8 +273,8 @@ let detect_dead_code (p : program) : unit =
let emit_unused_warning vx =
Message.warning
~pos:(Mark.get (Dependency.Vertex.info vx))
"Unused varible:@ %a@ does@ not@ contribute@ to@ computing@ any@ of@ \
scope@ %a@ outputs.@ Did you forget something?"
"Unused variable:@ %a@ does@ not@ contribute@ to@ computing@ any@ \
of@ scope@ %a@ outputs.@ Did you forget something?"
Dependency.Vertex.format vx ScopeName.format scope_name
in
Dependency.ScopeDependencies.iter_vertex

View File

@ -39,6 +39,7 @@ type scope_context = {
scope_out_struct : StructName.t;
sub_scopes : ScopeName.Set.t;
(** Other scopes referred to by this scope. Used for dependency analysis *)
scope_visibility : visibility;
}
(** Inside a scope, we distinguish between the variables and the subscopes. *)
@ -77,15 +78,17 @@ type module_context = {
between different enums *)
topdefs : TopdefName.t Ident.Map.t; (** Global definitions *)
used_modules : ModuleName.t Ident.Map.t;
is_external : bool;
}
(** Context for name resolution, valid within a given module *)
type context = {
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdef_types : typ TopdefName.Map.t;
structs : struct_context StructName.Map.t;
topdefs : (typ * visibility) TopdefName.Map.t;
structs : (struct_context * visibility) StructName.Map.t;
(** For each struct, its context *)
enums : enum_context EnumName.Map.t; (** For each enum, its context *)
enums : (enum_context * visibility) EnumName.Map.t;
(** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *)
modules : module_context ModuleName.Map.t;
@ -450,8 +453,10 @@ let process_data_decl
}
(** Process a struct declaration *)
let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) :
context =
let process_struct_decl
?(visibility = Public)
(ctxt : context)
(sdecl : Surface.Ast.struct_decl) : context =
let s_uid = get_struct ctxt sdecl.struct_decl_name in
if sdecl.struct_decl_fields = [] then
Message.error
@ -478,25 +483,28 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) :
let ctxt = { ctxt with local } in
let structs =
StructName.Map.update s_uid
(fun fields ->
match fields with
(function
| None ->
Some
(StructField.Map.singleton f_uid
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ))
| Some fields ->
( StructField.Map.singleton f_uid
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ),
visibility )
| Some (fields, _) ->
Some
(StructField.Map.add f_uid
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)
fields))
( StructField.Map.add f_uid
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)
fields,
visibility ))
ctxt.structs
in
{ ctxt with structs })
ctxt sdecl.struct_decl_fields
(** Process an enum declaration *)
let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context
=
let process_enum_decl
?(visibility = Public)
(ctxt : context)
(edecl : Surface.Ast.enum_decl) : context =
let e_uid = get_enum ctxt edecl.enum_decl_name in
if List.length edecl.enum_decl_cases = 0 then
Message.error
@ -530,23 +538,24 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context
| Some typ -> process_type ctxt typ
in
match cases with
| None -> Some (EnumConstructor.Map.singleton c_uid typ)
| Some fields -> Some (EnumConstructor.Map.add c_uid typ fields))
| None -> Some (EnumConstructor.Map.singleton c_uid typ, visibility)
| Some (fields, _) ->
Some (EnumConstructor.Map.add c_uid typ fields, visibility))
ctxt.enums
in
{ ctxt with enums })
ctxt edecl.enum_decl_cases
let process_topdef ctxt def =
let process_topdef ?(visibility = Public) ctxt def =
let uid =
Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs
in
{
ctxt with
topdef_types =
topdefs =
TopdefName.Map.add uid
(process_type ctxt def.Surface.Ast.topdef_type)
ctxt.topdef_types;
(process_type ctxt def.Surface.Ast.topdef_type, visibility)
ctxt.topdefs;
}
(** Process an item declaration *)
@ -560,8 +569,10 @@ let process_item_decl
process_subscope_decl scope ctxt sub_decl
(** Process a scope declaration *)
let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) :
context =
let process_scope_decl
?(visibility = Public)
(ctxt : context)
(decl : Surface.Ast.scope_decl) : context =
let scope_uid = get_scope ctxt decl.scope_decl_name in
let ctxt =
List.fold_left
@ -612,11 +623,12 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) :
structs =
StructName.Map.add
(get_struct ctxt decl.scope_decl_name)
StructField.Map.empty ctxt.structs;
(StructField.Map.empty, visibility)
ctxt.structs;
}
else
let ctxt =
process_struct_decl ctxt
process_struct_decl ~visibility ctxt
{
struct_decl_name = decl.scope_decl_name;
struct_decl_fields = output_fields;
@ -660,8 +672,10 @@ let typedef_info = function
| TScope (s, _) -> ScopeName.get_info s
(** Process the names of all declaration items *)
let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
context =
let process_name_item
?(visibility = Public)
(ctxt : context)
(item : Surface.Ast.code_item Mark.pos) : context =
let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg =
Message.error
~fmt_pos:
@ -702,6 +716,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
scope_in_struct = in_struct_name;
scope_out_struct = out_struct_name;
sub_scopes = ScopeName.Set.empty;
scope_visibility = visibility;
}
ctxt.scopes
in
@ -746,14 +761,16 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
{ ctxt with local = { ctxt.local with topdefs } }
(** Process a code item that is a declaration *)
let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
context =
let process_decl_item
?visibility
(ctxt : context)
(item : Surface.Ast.code_item Mark.pos) : context =
match Mark.remove item with
| ScopeDecl decl -> process_scope_decl ctxt decl
| StructDecl sdecl -> process_struct_decl ctxt sdecl
| EnumDecl edecl -> process_enum_decl ctxt edecl
| ScopeDecl decl -> process_scope_decl ?visibility ctxt decl
| StructDecl sdecl -> process_struct_decl ?visibility ctxt sdecl
| EnumDecl edecl -> process_enum_decl ?visibility ctxt edecl
| ScopeUse _ -> ctxt
| Topdef def -> process_topdef ctxt def
| Topdef def -> process_topdef ?visibility ctxt def
(** Process a code block *)
let process_code_block
@ -764,7 +781,11 @@ let process_code_block
(** Process a law structure, only considering the code blocks *)
let rec process_law_structure
(process_item : context -> Surface.Ast.code_item Mark.pos -> context)
(process_item :
?visibility:visibility ->
context ->
Surface.Ast.code_item Mark.pos ->
context)
(ctxt : context)
(s : Surface.Ast.law_structure) : context =
match s with
@ -772,10 +793,14 @@ let rec process_law_structure
List.fold_left
(fun ctxt child -> process_law_structure process_item ctxt child)
ctxt children
| Surface.Ast.CodeBlock (block, _, _) ->
process_code_block process_item ctxt block
| Surface.Ast.CodeBlock (block, _, is_meta) ->
process_code_block
(process_item ~visibility:(if is_meta then Public else Private))
ctxt block
| Surface.Ast.ModuleDef (_, is_external) ->
{ ctxt with local = { ctxt.local with is_external } }
| Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt
| Surface.Ast.ModuleDef _ | Surface.Ast.ModuleUse _ -> ctxt
| Surface.Ast.ModuleUse _ -> ctxt
(** {1 Scope uses pass} *)
@ -1025,12 +1050,13 @@ let empty_module_ctxt =
constructor_idmap = Ident.Map.empty;
topdefs = Ident.Map.empty;
used_modules = Ident.Map.empty;
is_external = false;
}
let empty_ctxt =
{
scopes = ScopeName.Map.empty;
topdef_types = TopdefName.Map.empty;
topdefs = TopdefName.Map.empty;
var_typs = ScopeVar.Map.empty;
structs = StructName.Map.empty;
enums = EnumName.Map.empty;
@ -1053,7 +1079,13 @@ let form_context (surface, mod_uses) surface_modules : context =
let ctxt =
{
ctxt with
local = { ctxt.local with used_modules = mod_uses; path = [m] };
local =
{
ctxt.local with
used_modules = mod_uses;
path = [m];
is_external = intf.Surface.Ast.intf_modname.module_external;
};
}
in
let ctxt =
@ -1085,7 +1117,7 @@ let form_context (surface, mod_uses) surface_modules : context =
in
let ctxt =
List.fold_left
(process_law_structure process_use_item)
(process_law_structure (fun ?visibility:_ -> process_use_item))
ctxt surface.Surface.Ast.program_items
in
(* Gather struct fields and enum constrs from direct modules: this helps with

View File

@ -39,6 +39,7 @@ type scope_context = {
scope_out_struct : StructName.t;
sub_scopes : ScopeName.Set.t;
(** Other scopes referred to by this scope. Used for dependency analysis *)
scope_visibility : visibility;
}
(** Inside a scope, we distinguish between the variables and the subscopes. *)
@ -82,16 +83,18 @@ type module_context = {
topdefs : TopdefName.t Ident.Map.t; (** Global definitions *)
used_modules : ModuleName.t Ident.Map.t;
(** Module aliases and the modules they point to *)
is_external : bool;
}
(** Context for name resolution, valid within a given module *)
type context = {
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdef_types : typ TopdefName.Map.t;
topdefs : (typ * visibility) TopdefName.Map.t;
(** Types associated with the global definitions *)
structs : struct_context StructName.Map.t;
structs : (struct_context * visibility) StructName.Map.t;
(** For each struct, its context *)
enums : enum_context EnumName.Map.t; (** For each enum, its context *)
enums : (enum_context * visibility) EnumName.Map.t;
(** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *)
modules : module_context ModuleName.Map.t;

View File

@ -93,7 +93,7 @@ let load_module_interfaces
Surface.Parser_driver.load_interface ?default_module_name
(Global.FileName f)
in
let modname = ModuleName.fresh intf.intf_modname in
let modname = ModuleName.fresh intf.intf_modname.module_name in
let seen = File.Map.add f None seen in
let seen, sub_use_map =
aux
@ -107,9 +107,9 @@ let load_module_interfaces
(seen, Ident.Map.empty) uses
in
let seen =
match program.Surface.Ast.program_module_name with
match program.Surface.Ast.program_module with
| Some m ->
let file = Pos.get_file (Mark.get m) in
let file = Pos.get_file (Mark.get m.module_name) in
File.Map.singleton file None
| None -> File.Map.empty
in
@ -202,15 +202,9 @@ module Passes = struct
in
let (prg : ty Dcalc.Ast.program) =
match typed with
| Typed _ -> (
| Typed _ ->
Message.debug "Typechecking again...";
try Typing.program prg
with Message.CompilerError error_content ->
let bt = Printexc.get_raw_backtrace () in
Printexc.raise_with_backtrace
(Message.CompilerError
(Message.Content.to_internal_error error_content))
bt)
Typing.program ~internal_check:true prg
| Untyped _ -> prg
| Custom _ -> assert false
in
@ -233,7 +227,6 @@ module Passes = struct
~optimize
~check_invariants
~(typed : ty mark)
~avoid_exceptions
~closure_conversion
~monomorphize_types :
typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list =
@ -241,23 +234,11 @@ module Passes = struct
dcalc options ~includes ~optimize ~check_invariants ~typed
in
debug_pass_name "lcalc";
let avoid_exceptions = avoid_exceptions || closure_conversion in
(* --closure-conversion implies --avoid-exceptions *)
let prg =
if avoid_exceptions && options.trace then
Message.warning
"It is discouraged to use option @{<yellow>--avoid-exceptions@} if \
you@ also@ need@ @{<yellow>--trace@},@ the@ resulting@ trace@ may@ \
be@ unreliable@ at@ the@ moment.";
match avoid_exceptions, typed with
| true, Untyped _ ->
Lcalc.From_dcalc.translate_program_without_exceptions prg
| true, Typed _ ->
Lcalc.From_dcalc.translate_program_without_exceptions prg
| false, Typed _ -> Lcalc.From_dcalc.translate_program_with_exceptions prg
| false, Untyped _ ->
Lcalc.From_dcalc.translate_program_with_exceptions prg
| _, Custom _ -> invalid_arg "Driver.Passes.lcalc"
match typed with
| Untyped _ -> Lcalc.From_dcalc.translate_program prg
| Typed _ -> Lcalc.From_dcalc.translate_program prg
| Custom _ -> invalid_arg "Driver.Passes.lcalc"
in
let prg =
if optimize then begin
@ -269,7 +250,7 @@ module Passes = struct
let prg =
if not closure_conversion then (
Message.debug "Retyping lambda calculus...";
Typing.program ~fail_on_any:false prg)
Typing.program ~fail_on_any:false ~internal_check:true prg)
else (
Message.debug "Performing closure conversion...";
let prg = Lcalc.Closure_conversion.closure_conversion prg in
@ -280,14 +261,17 @@ module Passes = struct
else prg
in
Message.debug "Retyping lambda calculus...";
Typing.program ~fail_on_any:false prg)
Typing.program ~fail_on_any:false ~internal_check:true prg)
in
let prg, type_ordering =
if monomorphize_types then (
Message.debug "Monomorphizing types...";
let prg, type_ordering = Lcalc.Monomorphize.program prg in
Message.debug "Retyping lambda calculus...";
let prg = Typing.program ~fail_on_any:false ~assume_op_types:true prg in
let prg =
Typing.program ~fail_on_any:false ~assume_op_types:true
~internal_check:true prg
in
prg, type_ordering)
else prg, type_ordering
in
@ -298,7 +282,6 @@ module Passes = struct
~includes
~optimize
~check_invariants
~avoid_exceptions
~closure_conversion
~keep_special_ops
~dead_value_assignment
@ -307,7 +290,7 @@ module Passes = struct
Scalc.Ast.program * Scopelang.Dependency.TVertex.t list =
let prg, type_ordering =
lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed
~avoid_exceptions ~closure_conversion ~monomorphize_types
~closure_conversion ~monomorphize_types
in
debug_pass_name "scalc";
( Scalc.From_lcalc.translate_program
@ -715,7 +698,9 @@ module Commands = struct
let prg, _ =
Passes.dcalc options ~includes ~optimize ~check_invariants ~typed
in
Interpreter.load_runtime_modules prg;
Interpreter.load_runtime_modules
~hashf:Hash.(finalise ~closure_conversion:false ~monomorphize_types:false)
prg;
print_interpretation_results options Interpreter.interpret_program_dcalc prg
(get_scopeopt_uid prg.decl_ctx ex_scope_opt)
@ -726,13 +711,12 @@ module Commands = struct
output
optimize
check_invariants
avoid_exceptions
closure_conversion
monomorphize_types
ex_scope_opt =
let prg, _ =
Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed ~monomorphize_types
~closure_conversion ~typed ~monomorphize_types
in
let _output_file, with_output = get_output_format options output in
with_output
@ -765,14 +749,12 @@ module Commands = struct
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.monomorphize_types
$ Cli.Flags.ex_scope_opt)
let interpret_lcalc
typed
avoid_exceptions
closure_conversion
monomorphize_types
options
@ -782,29 +764,27 @@ module Commands = struct
ex_scope_opt =
let prg, _ =
Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~monomorphize_types ~typed
~closure_conversion ~monomorphize_types ~typed
in
Interpreter.load_runtime_modules prg;
Interpreter.load_runtime_modules
~hashf:(Hash.finalise ~closure_conversion ~monomorphize_types)
prg;
print_interpretation_results options Interpreter.interpret_program_lcalc prg
(get_scopeopt_uid prg.decl_ctx ex_scope_opt)
let interpret_cmd =
let f lcalc avoid_exceptions closure_conversion monomorphize_types no_typing
=
let f lcalc closure_conversion monomorphize_types no_typing =
if not lcalc then
if avoid_exceptions || closure_conversion || monomorphize_types then
if closure_conversion || monomorphize_types then
Message.error
"The flags @{<bold>--avoid-exceptions@}, \
@{<bold>--closure-conversion@} and @{<bold>--monomorphize-types@} \
only make sense with the @{<bold>--lcalc@} option"
"The flags @{<bold>--closure-conversion@} and \
@{<bold>--monomorphize-types@} only make sense with the \
@{<bold>--lcalc@} option"
else if no_typing then interpret_dcalc Expr.untyped
else interpret_dcalc Expr.typed
else if no_typing then
interpret_lcalc Expr.untyped avoid_exceptions closure_conversion
monomorphize_types
else
interpret_lcalc Expr.typed avoid_exceptions closure_conversion
monomorphize_types
interpret_lcalc Expr.untyped closure_conversion monomorphize_types
else interpret_lcalc Expr.typed closure_conversion monomorphize_types
in
Cmd.v
(Cmd.info "interpret"
@ -815,7 +795,6 @@ module Commands = struct
Term.(
const f
$ Cli.Flags.lcalc
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.monomorphize_types
$ Cli.Flags.no_typing
@ -831,12 +810,11 @@ module Commands = struct
output
optimize
check_invariants
avoid_exceptions
closure_conversion
ex_scope_opt =
let prg, type_ordering =
Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~typed:Expr.typed ~closure_conversion:false
~monomorphize_types:false
~typed:Expr.typed ~closure_conversion ~monomorphize_types:false
in
let output_file, with_output =
get_output_format options ~ext:".ml" output
@ -847,7 +825,8 @@ module Commands = struct
Message.debug "Writing to %s..."
(Option.value ~default:"stdout" output_file);
let exec_scope = Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt in
Lcalc.To_ocaml.format_program fmt prg ?exec_scope type_ordering
let hashf = Hash.finalise ~closure_conversion ~monomorphize_types:false in
Lcalc.To_ocaml.format_program fmt prg ?exec_scope ~hashf type_ordering
let ocaml_cmd =
Cmd.v
@ -860,7 +839,7 @@ module Commands = struct
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.ex_scope_opt)
let scalc
@ -869,7 +848,6 @@ module Commands = struct
output
optimize
check_invariants
avoid_exceptions
closure_conversion
keep_special_ops
dead_value_assignment
@ -878,8 +856,8 @@ module Commands = struct
ex_scope_opt =
let prg, _ =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~keep_special_ops
~dead_value_assignment ~no_struct_literals ~monomorphize_types
~closure_conversion ~keep_special_ops ~dead_value_assignment
~no_struct_literals ~monomorphize_types
in
let _output_file, with_output = get_output_format options output in
with_output
@ -911,7 +889,6 @@ module Commands = struct
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.keep_special_ops
$ Cli.Flags.dead_value_assignment
@ -925,13 +902,11 @@ module Commands = struct
output
optimize
check_invariants
avoid_exceptions
closure_conversion =
let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~keep_special_ops:false
~dead_value_assignment:true ~no_struct_literals:false
~monomorphize_types:false
~closure_conversion ~keep_special_ops:false ~dead_value_assignment:true
~no_struct_literals:false ~monomorphize_types:false
in
let output_file, with_output =
@ -954,39 +929,12 @@ module Commands = struct
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion)
let r options includes output optimize check_invariants closure_conversion =
let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions:false ~closure_conversion ~keep_special_ops:false
~dead_value_assignment:false ~no_struct_literals:false
~monomorphize_types:false
in
let output_file, with_output = get_output_format options ~ext:".r" output in
Message.debug "Compiling program into R...";
Message.debug "Writing to %s..."
(Option.value ~default:"stdout" output_file);
with_output @@ fun fmt -> Scalc.To_r.format_program fmt prg type_ordering
let r_cmd =
Cmd.v
(Cmd.info "r" ~doc:"Generates an R translation of the Catala program.")
Term.(
const r
$ Cli.Flags.Global.options
$ Cli.Flags.include_dirs
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.closure_conversion)
let c options includes output optimize check_invariants =
let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions:true ~closure_conversion:true ~keep_special_ops:true
~closure_conversion:true ~keep_special_ops:true
~dead_value_assignment:false ~no_struct_literals:true
~monomorphize_types:true
in
@ -1013,7 +961,7 @@ module Commands = struct
let prg =
Surface.Ast.
{
program_module_name = None;
program_module = None;
program_items = [];
program_source_files = [];
program_used_modules =
@ -1041,7 +989,7 @@ module Commands = struct
in
Format.open_hbox ();
Format.pp_print_list ~pp_sep:Format.pp_print_space
(fun ppf m ->
(fun ppf (m, _) ->
let f = Pos.get_file (Mark.get (ModuleName.get_info m)) in
let f =
match prefix with
@ -1101,7 +1049,6 @@ module Commands = struct
proof_cmd;
ocaml_cmd;
python_cmd;
r_cmd;
c_cmd;
latex_cmd;
html_cmd;
@ -1184,6 +1131,12 @@ let main () =
in
let command = catala_t plugins in
let open Cmdliner in
let[@inline] exit_with_error excode fcontent =
let bt = Printexc.get_raw_backtrace () in
Message.Content.emit (fcontent ()) Error;
if Global.options.debug then Printexc.print_raw_backtrace stderr bt;
exit excode
in
match Cmd.eval_value ~catch:false ~argv command with
| Ok _ -> exit Cmd.Exit.ok
| Error e ->
@ -1191,29 +1144,22 @@ let main () =
exit Cmd.Exit.cli_error
| exception Cli.Exit_with n -> exit n
| exception Message.CompilerError content ->
exit_with_error Cmd.Exit.some_error @@ fun () -> content
| exception Message.CompilerErrors contents ->
let bt = Printexc.get_raw_backtrace () in
Message.Content.emit content Error;
Message.Content.emit_n Error contents;
if Global.options.debug then Printexc.print_raw_backtrace stderr bt;
exit Cmd.Exit.some_error
| exception Failure msg ->
let bt = Printexc.get_raw_backtrace () in
Message.Content.emit (Message.Content.of_string msg) Error;
if Printexc.backtrace_status () then Printexc.print_raw_backtrace stderr bt;
exit Cmd.Exit.some_error
exit_with_error Cmd.Exit.some_error
@@ fun () -> Message.Content.of_string msg
| exception Sys_error msg ->
let bt = Printexc.get_raw_backtrace () in
Message.Content.emit
(Message.Content.of_string ("System error: " ^ msg))
Error;
if Printexc.backtrace_status () then Printexc.print_raw_backtrace stderr bt;
exit Cmd.Exit.internal_error
exit_with_error Cmd.Exit.internal_error
@@ fun () -> Message.Content.of_string ("System error: " ^ msg)
| exception e ->
let bt = Printexc.get_raw_backtrace () in
Message.Content.emit
(Message.Content.of_string ("Unexpected error: " ^ Printexc.to_string e))
Error;
if Printexc.backtrace_status () then Printexc.print_raw_backtrace stderr bt;
exit Cmd.Exit.internal_error
exit_with_error Cmd.Exit.internal_error
@@ fun () ->
Message.Content.of_string ("Unexpected error: " ^ Printexc.to_string e)
(* Export module PluginAPI, hide parent module Plugin *)
module Plugin = struct

View File

@ -51,7 +51,6 @@ module Passes : sig
optimize:bool ->
check_invariants:bool ->
typed:'m Shared_ast.mark ->
avoid_exceptions:bool ->
closure_conversion:bool ->
monomorphize_types:bool ->
Shared_ast.typed Lcalc.Ast.program * Scopelang.Dependency.TVertex.t list
@ -61,7 +60,6 @@ module Passes : sig
includes:Global.raw_file list ->
optimize:bool ->
check_invariants:bool ->
avoid_exceptions:bool ->
closure_conversion:bool ->
keep_special_ops:bool ->
dead_value_assignment:bool ->

View File

@ -19,63 +19,186 @@ open Shared_ast
open Ast
module D = Dcalc.Ast
type name_context = { prefix : string; mutable counter : int }
type 'm ctx = {
name_context : string;
decl_ctx : decl_ctx;
name_context : name_context;
globally_bound_vars : ('m expr, typ) Var.Map.t;
}
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
let new_var ?(pfx = "") name_context =
name_context.counter <- name_context.counter + 1;
Var.make (pfx ^ name_context.prefix ^ string_of_int name_context.counter)
(* TODO: Closures end up as a toplevel names. However for now we assume toplevel
names are unique, this is a temporary workaround to avoid name wrangling in
the backends. We need to have a better system for name disambiguation when
for instance printing to Dcalc/Lcalc/Scalc but also OCaml, Python, etc. *)
let new_context prefix = { prefix; counter = 0 }
(** Function types will be transformed in this way throughout, including in
[decl_ctx] *)
let rec translate_type t =
let pos = Mark.get t in
match Mark.remove t with
| TArrow (t1, t2) ->
( TTuple
[
( TArrow
( (TClosureEnv, Pos.no_pos) :: List.map translate_type t1,
translate_type t2 ),
Pos.no_pos );
TClosureEnv, Pos.no_pos;
],
pos )
| TDefault t' -> TDefault (translate_type t'), pos
| TOption t' -> TOption (translate_type t'), pos
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
| TArray ts -> TArray (translate_type ts), pos
| TTuple ts -> TTuple (List.map translate_type ts), pos
let translate_mark e = Mark.map_mark (Expr.map_ty translate_type) e
let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t =
fun m1 m2 -> Var.Map.union (fun _ a _ -> Some a) m1 m2
(** {1 Transforming closures}*)
let build_closure :
type m.
m ctx ->
(m expr Var.t * m mark) list ->
m expr boxed ->
m expr Var.t array ->
typ list ->
m mark ->
m expr boxed =
fun ctx free_vars body args tys m ->
(* λ x.t *)
let pos = Expr.mark_pos m in
let mark_ty ty = Expr.with_ty m ty in
let free_vars_types = List.map (fun (_, m) -> Expr.maybe_ty m) free_vars in
(* x1, ..., xn *)
let code_var = new_var ctx.name_context in
(* code *)
let closure_env_arg_var = Var.make "env" in
let closure_env_var = Var.make "env" in
let env_ty = TTuple free_vars_types, pos in
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
let new_closure_body =
Expr.make_let_in closure_env_var env_ty
(Expr.eappop
~op:(Operator.FromClosureEnv, pos)
~tys:[TClosureEnv, pos]
~args:[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, pos))]
(mark_ty env_ty))
(Expr.make_multiple_let_in
(Array.of_list (List.map fst free_vars))
free_vars_types
(List.mapi
(fun i _ ->
Expr.make_tupleaccess
(Expr.evar closure_env_var (mark_ty env_ty))
i (List.length free_vars) pos)
free_vars)
body pos)
pos
in
(* fun env arg0 ... -> new_closure_body *)
let new_closure =
Expr.make_abs
(Array.append [| closure_env_arg_var |] args)
new_closure_body
((TClosureEnv, pos) :: tys)
pos
in
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
Expr.make_let_in code_var new_closure_ty new_closure
(Expr.make_tuple
((Bindlib.box_var code_var, mark_ty new_closure_ty)
:: [
Expr.eappop
~op:(Operator.ToClosureEnv, pos)
~tys:[TTuple free_vars_types, pos]
~args:
[
Expr.etuple
(List.map
(fun (extra_var, m) ->
Bindlib.box_var extra_var, Expr.with_pos pos m)
free_vars)
(mark_ty (TTuple free_vars_types, pos));
]
(mark_ty (TClosureEnv, pos));
])
m)
pos
(** Returns the expression with closed closures and the set of free variables
inside this new expression. Implementation guided by
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10
(environment-passing closure conversion). *)
let rec transform_closures_expr :
type m. m ctx -> m expr -> m expr Var.Set.t * m expr boxed =
type m. m ctx -> m expr -> (m expr, m mark) Var.Map.t * m expr boxed =
fun ctx e ->
let e = translate_mark e in
let m = Mark.get e in
match Mark.remove e with
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _
| ELit _ | EExternal _ | EAssert _ | EFatalError _ | EIfThenElse _
| ERaiseEmpty | ECatchEmpty _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
| ELit _ | EAssert _ | EFatalError _ | EIfThenElse _ ->
Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
~f:(transform_closures_expr ctx)
e
| EVar v -> (
match Var.Map.find_opt v ctx.globally_bound_vars with
| None -> Var.Set.singleton v, (Bindlib.box_var v, m)
| Some (TArrow (targs, tret), _) ->
| (EVar _ | EExternal _) as e -> (
let body, (free_vars, fty) =
match e with
| EVar v -> (
( Bindlib.box_var v,
match Var.Map.find_opt v ctx.globally_bound_vars with
| None -> Var.Map.singleton v m, None
| Some ((TArrow (targs, tret), _) as fty) ->
Var.Map.empty, Some (targs, tret, fty)
| Some _ -> Var.Map.empty, None ))
| EExternal { name = External_value td, _ } as e ->
( Bindlib.box e,
( Var.Map.empty,
match TopdefName.Map.find td ctx.decl_ctx.ctx_topdefs with
| (TArrow (targs, tret), _) as fty -> Some (targs, tret, fty)
| _ -> None ) )
| EExternal { name = External_scope s, pos } ->
let fty =
let si = ScopeName.Map.find s ctx.decl_ctx.ctx_scopes in
let t_in = TStruct si.in_struct_name, pos in
let t_out = TStruct si.out_struct_name, pos in
[t_in], t_out, (TArrow ([t_in], t_out), pos)
in
Bindlib.box e, (Var.Map.empty, Some fty)
| _ -> assert false
in
match fty with
| None -> free_vars, (body, m)
| Some (targs, tret, fty) ->
(* Here we eta-expand the argument to make sure function pointers are
correctly casted as closures *)
let args = Array.init (List.length targs) (fun _ -> Var.make "eta_arg") in
let args =
Array.init (List.length targs) (fun i ->
Var.make ("x" ^ string_of_int i))
in
let arg_vars =
List.map2
(fun v ty -> Expr.evar v (Expr.with_ty m ty))
(Array.to_list args) targs
in
let e =
Expr.eabs
(Expr.bind args
(Expr.eapp ~f:(Expr.rebox e) ~args:arg_vars ~tys:targs
(Expr.with_ty m tret)))
targs m
in
let boxed =
let ctx =
(* We hide the type of the toplevel definition so that the function
doesn't loop *)
{
ctx with
globally_bound_vars =
Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars;
}
let closure =
let body =
Expr.eapp
~f:(body, Expr.with_ty m fty)
~args:arg_vars ~tys:targs (Expr.with_ty m tret)
in
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
build_closure ctx [] body args targs m
in
Bindlib.unbox boxed
| Some _ -> Var.Set.empty, (Bindlib.box_var v, m))
Var.Map.empty, closure)
| EMatch { e; cases; name } ->
let free_vars, new_e = (transform_closures_expr ctx) e in
(* We do not close the clotures inside the arms of the match expression,
@ -89,17 +212,15 @@ let rec transform_closures_expr :
let new_free_vars, new_body = (transform_closures_expr ctx) body in
let new_free_vars =
Array.fold_left
(fun acc v -> Var.Set.remove v acc)
(fun acc v -> Var.Map.remove v acc)
new_free_vars vars
in
let new_binder = Expr.bind vars new_body in
( Var.Set.union free_vars
(Var.Set.diff new_free_vars
(Var.Set.of_list (Array.to_list vars))),
( join_vars free_vars new_free_vars,
EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Mark.get e1))
new_cases )
| _ -> failwith "should not happen")
| _ -> assert false)
cases
(free_vars, EnumConstructor.Map.empty)
in
@ -109,97 +230,33 @@ let rec transform_closures_expr :
let vars, body = Bindlib.unmbind binder in
let free_vars, new_body = (transform_closures_expr ctx) body in
let free_vars =
Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars
Array.fold_left (fun acc v -> Var.Map.remove v acc) free_vars vars
in
let new_binder = Expr.bind vars new_body in
let free_vars, new_args =
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
join_vars free_vars new_free_vars, new_arg :: new_args)
args (free_vars, [])
in
( free_vars,
Expr.eapp
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos)
~args:new_args ~tys m )
| EAbs { binder; tys } ->
(* λ x.t *)
let binder_mark = Expr.with_ty m (TAny, Expr.mark_pos m) in
let binder_pos = Expr.mark_pos binder_mark in
(* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in
(* t *)
let body_vars, new_body = (transform_closures_expr ctx) body in
let free_vars, body = (transform_closures_expr ctx) body in
(* [[t]] *)
let extra_vars =
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars))
let free_vars =
Array.fold_left (fun m v -> Var.Map.remove v m) free_vars vars
in
let extra_vars_list = Var.Set.elements extra_vars in
(* x1, ..., xn *)
let code_var = Var.make ctx.name_context in
(* code *)
let closure_env_arg_var = Var.make "env" in
let closure_env_var = Var.make "env" in
let any_ty = TAny, binder_pos in
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
let new_closure_body =
Expr.make_let_in closure_env_var any_ty
(Expr.eappop
~op:(Operator.FromClosureEnv, binder_pos)
~tys:[TClosureEnv, binder_pos]
~args:[Expr.evar closure_env_arg_var binder_mark]
binder_mark)
(Expr.make_multiple_let_in
(Array.of_list extra_vars_list)
(List.map (fun _ -> any_ty) extra_vars_list)
(List.mapi
(fun i _ ->
Expr.make_tupleaccess
(Expr.evar closure_env_var binder_mark)
i
(List.length extra_vars_list)
binder_pos)
extra_vars_list)
new_body binder_pos)
binder_pos
in
(* fun env arg0 ... -> new_closure_body *)
let new_closure =
Expr.make_abs
(Array.concat [Array.make 1 closure_env_arg_var; vars])
new_closure_body
((TClosureEnv, binder_pos) :: tys)
(Expr.pos e)
in
( extra_vars,
Expr.make_let_in code_var
(TAny, Expr.pos e)
new_closure
(Expr.make_tuple
((Bindlib.box_var code_var, binder_mark)
:: [
Expr.eappop
~op:(Operator.ToClosureEnv, binder_pos)
~tys:[TAny, Expr.pos e]
~args:
[
(if extra_vars_list = [] then Expr.elit LUnit binder_mark
else
Expr.etuple
(List.map
(fun extra_var ->
Bindlib.box_var extra_var, binder_mark)
extra_vars_list)
m);
]
(Mark.get e);
])
m)
(Expr.pos e) )
free_vars, build_closure ctx (Var.Map.bindings free_vars) body vars tys m
| EAppOp
{
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
op = ((HandleExceptions | Fold | Map | Map2 | Filter | Reduce), _) as op;
tys;
args;
} ->
@ -216,19 +273,22 @@ let rec transform_closures_expr :
| EAbs { binder; tys } ->
let vars, arg = Bindlib.unmbind binder in
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
let new_free_vars =
Array.fold_left (fun m v -> Var.Map.remove v m) new_free_vars vars
in
let new_arg =
Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg)
in
Var.Set.union free_vars new_free_vars, new_arg :: new_args
join_vars free_vars new_free_vars, new_arg :: new_args
| _ ->
let new_free_vars, new_arg = transform_closures_expr ctx arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, [])
join_vars free_vars new_free_vars, new_arg :: new_args)
args (Var.Map.empty, [])
in
free_vars, Expr.eappop ~op ~tys ~args:new_args (Mark.get e)
| EAppOp _ ->
(* This corresponds to an operator call, which we don't want to transform *)
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
~f:(transform_closures_expr ctx)
e
| EApp { f = EVar v, f_m; args; tys }
@ -239,12 +299,16 @@ let rec transform_closures_expr :
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, [])
join_vars free_vars new_free_vars, new_arg :: new_args)
args (Var.Map.empty, [])
in
free_vars, Expr.eapp ~f:(Expr.evar v f_m) ~args:new_args ~tys m
| EApp { f = e1; args; tys } ->
let free_vars, new_e1 = (transform_closures_expr ctx) e1 in
let tys = List.map translate_type tys in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos new_e1 in
let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in
let code_env_var = Var.make "code_and_env" in
let code_env_expr =
let pos = Expr.pos e1 in
@ -252,8 +316,7 @@ let rec transform_closures_expr :
(Expr.with_ty (Mark.get e1)
( TTuple
[
( TArrow ((TClosureEnv, pos) :: tys, (TAny, Expr.pos e)),
Expr.pos e );
TArrow ((TClosureEnv, pos) :: tys, Expr.maybe_ty m), Expr.pos e;
TClosureEnv, pos;
],
pos ))
@ -264,29 +327,25 @@ let rec transform_closures_expr :
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
join_vars free_vars new_free_vars, new_arg :: new_args)
args (free_vars, [])
in
let call_expr =
let m1 = Mark.get e1 in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos e1 in
let fun_ty = TArrow (env_arg_ty :: tys, (TAny, Expr.pos e)), Expr.pos e in
let m1 = Mark.get new_e1 in
Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
[
Expr.make_tupleaccess code_env_expr 0 2 pos;
Expr.make_tupleaccess code_env_expr 1 2 pos;
]
(Expr.eapp
~f:(Bindlib.box_var code_var, m1)
~args:((Bindlib.box_var env_var, m1) :: new_args)
~tys:(env_arg_ty :: tys) m)
(Expr.pos e)
(Expr.make_app
(Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
(env_arg_ty
:: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys)
pos)
pos
in
( free_vars,
Expr.make_let_in code_env_var
(TAny, Expr.pos e)
new_e1 call_expr (Expr.pos e) )
free_vars, Expr.make_let_in code_env_var (TAny, pos) new_e1 call_expr pos
| _ -> .
let transform_closures_scope_let ctx scope_body_expr =
@ -294,7 +353,7 @@ let transform_closures_scope_let ctx scope_body_expr =
~f:(fun var_next scope_let ->
let _free_vars, new_scope_let_expr =
(transform_closures_expr
{ ctx with name_context = Bindlib.name_of var_next })
{ ctx with name_context = new_context (Bindlib.name_of var_next) })
scope_let.scope_let_expr
in
( var_next,
@ -325,7 +384,8 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
in
let ctx =
{
name_context = Mark.remove (ScopeName.get_info name);
decl_ctx = p.decl_ctx;
name_context = new_context (Mark.remove (ScopeName.get_info name));
globally_bound_vars = toplevel_vars;
}
in
@ -352,7 +412,9 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let v, expr = Bindlib.unmbind binder in
let ctx =
{
name_context = Mark.remove (TopdefName.get_info name);
decl_ctx = p.decl_ctx;
name_context =
new_context (Mark.remove (TopdefName.get_info name));
globally_bound_vars = toplevel_vars;
}
in
@ -366,7 +428,9 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
| Topdef (name, ty, expr) ->
let ctx =
{
name_context = Mark.remove (TopdefName.get_info name);
decl_ctx = p.decl_ctx;
name_context =
new_context (Mark.remove (TopdefName.get_info name));
globally_bound_vars = toplevel_vars;
}
in
@ -393,33 +457,15 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
capture footprint. See
[tests/tests_func/good/scope_call_func_struct_closure.catala_en]. *)
let new_decl_ctx =
let rec replace_fun_typs t =
match Mark.remove t with
| TArrow (t1, t2) ->
( TTuple
[
( TArrow
( (TClosureEnv, Pos.no_pos) :: List.map replace_fun_typs t1,
replace_fun_typs t2 ),
Pos.no_pos );
TClosureEnv, Pos.no_pos;
],
Mark.get t )
| TDefault t' -> TDefault (replace_fun_typs t'), Mark.get t
| TOption t' -> TOption (replace_fun_typs t'), Mark.get t
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
| TArray ts -> TArray (replace_fun_typs ts), Mark.get t
| TTuple ts -> TTuple (List.map replace_fun_typs ts), Mark.get t
in
{
p.decl_ctx with
ctx_structs =
StructName.Map.map
(StructField.Map.map replace_fun_typs)
(StructField.Map.map translate_type)
p.decl_ctx.ctx_structs;
ctx_enums =
EnumName.Map.map
(EnumConstructor.Map.map replace_fun_typs)
(EnumConstructor.Map.map translate_type)
p.decl_ctx.ctx_enums;
(* Toplevel definitions may not contain scope calls or take functions as
arguments at the moment, which ensures that their interfaces aren't
@ -445,7 +491,7 @@ type 'm hoisted_closure = {
}
let rec hoist_closures_expr :
type m. string -> m expr -> m hoisted_closure list * m expr boxed =
type m. name_context -> m expr -> m hoisted_closure list * m expr boxed =
fun name_context e ->
let m = Mark.get e in
match Mark.remove e with
@ -467,7 +513,7 @@ let rec hoist_closures_expr :
EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Mark.get e1))
new_cases )
| _ -> failwith "should not happen")
| _ -> assert false)
cases
(collected_closures, EnumConstructor.Map.empty)
in
@ -489,15 +535,8 @@ let rec hoist_closures_expr :
args (collected_closures, [])
in
( collected_closures,
Expr.eapp
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
~args:new_args ~tys m )
| EAppOp
{
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
tys;
args;
} ->
Expr.eapp ~f:(Expr.eabs new_binder tys e1_pos) ~args:new_args ~tys m )
| EAppOp { op = ((Fold | Map | Filter | Reduce), _) as op; tys; args } ->
(* Special case for some operators: its arguments closures thunks because if
you want to extract it as a function you need these closures to preserve
evaluation order, but backends that don't support closures will simply
@ -524,34 +563,30 @@ let rec hoist_closures_expr :
args ([], [])
in
collected_closures, Expr.eappop ~op ~args:new_args ~tys (Mark.get e)
| EAbs { tys; _ } ->
(* this is the closure we want to hoist*)
let closure_var = Var.make ("closure_" ^ name_context) in
(* TODO: This will end up as a toplevel name. However for now we assume
toplevel names are unique, but this breaks this assertions and can lead
to name wrangling in the backends. We need to have a better system for
name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but
also OCaml, Python, etc. *)
( [
{
name = closure_var;
ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m;
closure = Expr.rebox e;
};
],
| EAbs { binder; tys } ->
(* this is the closure we want to hoist *)
let closure_var = new_var ~pfx:"closure_" name_context in
let pos = Expr.mark_pos m in
let ty = Expr.maybe_ty ~typ:(TArrow (tys, (TAny, pos))) m in
let vars, body = Bindlib.unmbind binder in
let collected_closures, new_body =
(hoist_closures_expr name_context) body
in
let closure = Expr.make_abs vars new_body tys pos in
( { name = closure_var; ty; closure } :: collected_closures,
Expr.make_var closure_var m )
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _
| ERaiseEmpty | ECatchEmpty _ | EVar _ ->
| EVar _ ->
Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_closures_expr name_context) e
| EExternal _ -> failwith "unimplemented"
| EExternal { name } -> [], Expr.box (EExternal { name }, m)
| _ -> .
let hoist_closures_scope_let name_context scope_body_expr =
BoundList.fold_right
~f:(fun scope_let var_next (hoisted_closures, next_scope_lets) ->
let new_hoisted_closures, new_scope_let_expr =
(hoist_closures_expr (Bindlib.name_of var_next))
(hoist_closures_expr (new_context (Bindlib.name_of var_next)))
scope_let.scope_let_expr
in
( new_hoisted_closures @ hoisted_closures,
@ -588,7 +623,7 @@ let rec hoist_closures_code_item_list
in
let new_hoisted_closures, new_scope_lets =
hoist_closures_scope_let
(fst (ScopeName.get_info name))
(new_context (fst (ScopeName.get_info name)))
scope_body_expr
in
let new_scope_body_expr =
@ -602,7 +637,9 @@ let rec hoist_closures_code_item_list
| Topdef (name, ty, (EAbs { binder; tys }, m)) ->
let v, expr = Bindlib.unmbind binder in
let new_hoisted_closures, new_expr =
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
hoist_closures_expr
(new_context (Mark.remove (TopdefName.get_info name)))
expr
in
let new_binder = Expr.bind v new_expr in
( new_hoisted_closures,
@ -611,7 +648,9 @@ let rec hoist_closures_code_item_list
(Expr.Box.lift (Expr.eabs new_binder tys m)) )
| Topdef (name, ty, expr) ->
let new_hoisted_closures, new_expr =
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
hoist_closures_expr
(new_context (Mark.remove (TopdefName.get_info name)))
expr
in
( new_hoisted_closures,
Bindlib.box_apply
@ -660,9 +699,7 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
(** {1 Closure conversion}*)
let closure_conversion (p : 'm program) : untyped program =
let closure_conversion (p : 'm program) : 'm program =
let new_p = transform_closures_program p in
let new_p = hoist_closures_program (Bindlib.unbox new_p) in
(* FIXME: either fix the types of the marks, or remove the types annotations
during the main processing (rather than requiring a new traversal) *)
Program.untype (Bindlib.unbox new_p)
Bindlib.unbox new_p

View File

@ -21,4 +21,4 @@
After closure conversion, closure hoisting is perform and all closures end
up as toplevel definitions. *)
val closure_conversion : 'm Ast.program -> Shared_ast.untyped Ast.program
val closure_conversion : 'm Ast.program -> 'm Ast.program

View File

@ -1,95 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
module D = Dcalc.Ast
module A = Ast
let rec translate_typ (tau : typ) : typ =
Mark.map
(function
| TDefault t -> Mark.remove (translate_typ t)
| TLit l -> TLit l
| TTuple ts -> TTuple (List.map translate_typ ts)
| TStruct s -> TStruct s
| TEnum en -> TEnum en
| TOption _ ->
Message.error ~internal:true
"The types option should not appear before the dcalc -> lcalc \
translation step."
| TClosureEnv ->
Message.error ~internal:true
"The types closure_env should not appear before the dcalc -> lcalc \
translation step."
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts)
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2))
tau
let translate_mark m = Expr.map_ty translate_typ m
let rec translate_default
(exceptions : 'm D.expr list)
(just : 'm D.expr)
(cons : 'm D.expr)
(mark_default : 'm mark) : 'm A.expr boxed =
let pos = Expr.mark_pos mark_default in
let exceptions =
List.map (fun except -> Expr.thunk_term (translate_expr except)) exceptions
in
Expr.eappop
~op:(Op.HandleDefault, Expr.pos cons)
~tys:
[
TArray (TArrow ([TLit TUnit, pos], (TAny, pos)), pos), pos;
TArrow ([TLit TUnit, pos], (TLit TBool, pos)), pos;
TArrow ([TLit TUnit, pos], (TAny, pos)), pos;
]
~args:
[
Expr.earray exceptions
(Expr.map_ty
(fun ty -> TArray (TArrow ([TLit TUnit, pos], ty), pos), pos)
mark_default);
Expr.thunk_term (translate_expr just);
Expr.thunk_term (translate_expr cons);
]
mark_default
and translate_expr (e : 'm D.expr) : 'm A.expr boxed =
match e with
| EEmpty, m -> Expr.eraiseempty (translate_mark m)
| EErrorOnEmpty arg, m ->
let m = translate_mark m in
Expr.ecatchempty (translate_expr arg) (Expr.efatalerror Runtime.NoValue m) m
| EDefault { excepts; just; cons }, m ->
translate_default excepts just cons (translate_mark m)
| EPureDefault e, _ -> translate_expr e
| EAppOp { op; args; tys }, m ->
Expr.eappop ~op:(Operator.translate op)
~args:(List.map translate_expr args)
~tys:(List.map translate_typ tys)
(translate_mark m)
| ( ( ELit _ | EArray _ | EVar _ | EAbs _ | EApp _ | EExternal _
| EIfThenElse _ | ETuple _ | ETupleAccess _ | EInj _ | EAssert _
| EFatalError _ | EStruct _ | EStructAccess _ | EMatch _ ),
_ ) as e ->
Expr.map ~f:translate_expr ~typ:translate_typ e
| _ -> .
let translate_program (prg : 'm D.program) : 'm A.program =
Program.map_exprs prg ~typ:translate_typ ~varf:Var.translate ~f:translate_expr

View File

@ -1,126 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
module D = Dcalc.Ast
module A = Ast
(** We make use of the strong invriants on the structure of programs:
Defaultable values can only appear in certin positions. This information is
given by the type structure of expressions. In particular this mean we don't
need to use the monadic bind while computing arithmetic opertions or
function calls. The resulting function is not more difficult than what we
had when translating without exceptions.
The typing translation is to simply trnsform default type into option types. *)
let rec translate_typ (tau : typ) : typ =
Mark.copy tau
begin
match Mark.remove tau with
| TDefault t -> TOption (translate_typ t)
| TLit l -> TLit l
| TTuple ts -> TTuple (List.map translate_typ ts)
| TStruct s -> TStruct s
| TEnum en -> TEnum en
| TOption _ ->
Message.error ~internal:true
"The types option should not appear before the dcalc -> lcalc \
translation step."
| TClosureEnv ->
Message.error ~internal:true
"The types closure_env should not appear before the dcalc -> lcalc \
translation step."
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts)
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
end
let translate_mark m = Expr.map_ty translate_typ m
let rec translate_default
(exceptions : 'm D.expr list)
(just : 'm D.expr)
(cons : 'm D.expr)
(mark_default : 'm mark) : 'm A.expr boxed =
(* Since the program is well typed, all exceptions have as type [option 't] *)
let pos = Expr.mark_pos mark_default in
let exceptions = List.map translate_expr exceptions in
let exceptions_and_cons_ty = Expr.maybe_ty mark_default in
Expr.eappop
~op:(Op.HandleDefaultOpt, Expr.pos cons)
~tys:
[
TArray exceptions_and_cons_ty, pos;
TArrow ([TLit TUnit, pos], (TLit TBool, pos)), pos;
TArrow ([TLit TUnit, pos], exceptions_and_cons_ty), pos;
]
~args:
[
Expr.earray exceptions
(Expr.map_ty (fun ty -> TArray ty, pos) mark_default);
(* In call-by-value programming languages, as lcalc, arguments are
evalulated before calling the function. Since we don't want to
execute the justification and conclusion while before checking every
exceptions, we need to thunk them. *)
Expr.thunk_term (translate_expr just);
Expr.thunk_term (translate_expr cons);
]
mark_default
and translate_expr (e : 'm D.expr) : 'm A.expr boxed =
match e with
| EEmpty, m ->
let m = translate_mark m in
let pos = Expr.mark_pos m in
Expr.einj
~e:(Expr.elit LUnit (Expr.with_ty m (TLit TUnit, pos)))
~cons:Expr.none_constr ~name:Expr.option_enum m
| EErrorOnEmpty arg, m ->
let m = translate_mark m in
let pos = Expr.mark_pos m in
let cases =
EnumConstructor.Map.of_list
[
( Expr.none_constr,
let x = Var.make "_" in
Expr.make_abs [| x |] (Expr.efatalerror NoValue m) [TAny, pos] pos
);
(* | None x -> raise NoValueProvided *)
Expr.some_constr, Expr.fun_id ~var_name:"arg" m (* | Some x -> x *);
]
in
Expr.ematch ~e:(translate_expr arg) ~name:Expr.option_enum ~cases m
| EDefault { excepts; just; cons }, m ->
translate_default excepts just cons (translate_mark m)
| EPureDefault e, m ->
Expr.einj ~e:(translate_expr e) ~cons:Expr.some_constr
~name:Expr.option_enum (translate_mark m)
| EAppOp { op; tys; args }, m ->
Expr.eappop ~op:(Operator.translate op)
~tys:(List.map translate_typ tys)
~args:(List.map translate_expr args)
(translate_mark m)
| ( ( ELit _ | EArray _ | EVar _ | EApp _ | EAbs _ | EExternal _
| EIfThenElse _ | ETuple _ | ETupleAccess _ | EInj _ | EAssert _
| EFatalError _ | EStruct _ | EStructAccess _ | EMatch _ ),
_ ) as e ->
Expr.map ~f:translate_expr ~typ:translate_typ e
| _ -> .
let translate_program (prg : 'm D.program) : 'm A.program =
Program.map_exprs prg ~typ:translate_typ ~varf:Var.translate ~f:translate_expr

View File

@ -1,22 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Translation from the default calculus to the lambda calculus. This
translation uses an option monad to handle empty defaults terms. This
transformation is one piece to permit to compile toward legacy languages
that does not contains exceptions. *)
val translate_program : 'm Dcalc.Ast.program -> 'm Ast.program

View File

@ -1,6 +1,6 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
@ -14,7 +14,130 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
module D = Dcalc.Ast
module A = Ast
(** We make use of the strong invriants on the structure of programs:
Defaultable values can only appear in certin positions. This information is
given by the type structure of expressions. In particular this mean we don't
need to use the monadic bind while computing arithmetic opertions or
function calls. The resulting function is not more difficult than what we
had when translating without exceptions.
The typing translation is to simply trnsform default type into option types. *)
let rec translate_typ (tau : typ) : typ =
Mark.copy tau
begin
match Mark.remove tau with
| TDefault t -> TOption (translate_typ t)
| TLit l -> TLit l
| TTuple ts -> TTuple (List.map translate_typ ts)
| TStruct s -> TStruct s
| TEnum en -> TEnum en
| TOption _ ->
Message.error ~internal:true
"The types option should not appear before the dcalc -> lcalc \
translation step."
| TClosureEnv ->
Message.error ~internal:true
"The types closure_env should not appear before the dcalc -> lcalc \
translation step."
| TAny -> TAny
| TArray ts -> TArray (translate_typ ts)
| TArrow (t1, t2) -> TArrow (List.map translate_typ t1, translate_typ t2)
end
let translate_mark m = Expr.map_ty translate_typ m
let rec translate_default
(exceptions : 'm D.expr list)
(just : 'm D.expr)
(cons : 'm D.expr)
(mark_default : 'm mark) : 'm A.expr boxed =
(* Since the program is well typed, all exceptions have as type [option 't] *)
let pos = Expr.mark_pos mark_default in
let exceptions = List.map translate_expr exceptions in
let ty_option = Expr.maybe_ty mark_default in
let ty_array = TArray ty_option, pos in
let ty_alpha =
match ty_option with
| TOption ty, _ -> ty
| (TAny, _) as ty -> ty
| _ -> assert false
in
let mark_alpha = Expr.with_ty mark_default ty_alpha in
Expr.ematch ~name:Expr.option_enum
~e:
(Expr.eappop
~op:(Op.HandleExceptions, Expr.pos cons)
~tys:[ty_array]
~args:[Expr.earray exceptions (Expr.with_ty mark_default ty_array)]
mark_default)
~cases:
(EnumConstructor.Map.of_list
[
(* Some x -> Some x *)
( Expr.some_constr,
let x = Var.make "x" in
Expr.make_abs [| x |]
(Expr.einj ~name:Expr.option_enum ~cons:Expr.some_constr
~e:(Expr.evar x mark_alpha) mark_default)
[ty_alpha] pos );
(* None -> if just then cons else None *)
( Expr.none_constr,
Expr.thunk_term
(Expr.eifthenelse (translate_expr just) (translate_expr cons)
(Expr.einj
~e:
(Expr.elit LUnit
(Expr.with_ty mark_default (TLit TUnit, pos)))
~cons:Expr.none_constr ~name:Expr.option_enum mark_default)
mark_default) );
])
mark_default
and translate_expr (e : 'm D.expr) : 'm A.expr boxed =
match e with
| EEmpty, m ->
let m = translate_mark m in
let pos = Expr.mark_pos m in
Expr.einj
~e:(Expr.elit LUnit (Expr.with_ty m (TLit TUnit, pos)))
~cons:Expr.none_constr ~name:Expr.option_enum m
| EErrorOnEmpty arg, m ->
let m = translate_mark m in
let pos = Expr.mark_pos m in
let cases =
EnumConstructor.Map.of_list
[
( Expr.none_constr,
let x = Var.make "_" in
Expr.make_abs [| x |] (Expr.efatalerror NoValue m) [TAny, pos] pos
);
(* | None x -> raise NoValueProvided *)
Expr.some_constr, Expr.fun_id ~var_name:"arg" m (* | Some x -> x *);
]
in
Expr.ematch ~e:(translate_expr arg) ~name:Expr.option_enum ~cases m
| EDefault { excepts; just; cons }, m ->
translate_default excepts just cons (translate_mark m)
| EPureDefault e, m ->
Expr.einj ~e:(translate_expr e) ~cons:Expr.some_constr
~name:Expr.option_enum (translate_mark m)
| EAppOp { op; tys; args }, m ->
Expr.eappop ~op:(Operator.translate op)
~tys:(List.map translate_typ tys)
~args:(List.map translate_expr args)
(translate_mark m)
| ( ( ELit _ | EArray _ | EVar _ | EApp _ | EAbs _ | EExternal _
| EIfThenElse _ | ETuple _ | ETupleAccess _ | EInj _ | EAssert _
| EFatalError _ | EStruct _ | EStructAccess _ | EMatch _ ),
_ ) as e ->
Expr.map ~f:translate_expr ~typ:translate_typ e
| _ -> .
let add_option_type ctx =
{
@ -26,9 +149,7 @@ let add_option_type ctx =
let add_option_type_program prg =
{ prg with decl_ctx = add_option_type prg.decl_ctx }
let translate_program_with_exceptions =
Compile_with_exceptions.translate_program
let translate_program_without_exceptions prg =
let prg = add_option_type_program prg in
Compile_without_exceptions.translate_program prg
let translate_program (prg : 'm D.program) : 'm A.program =
Program.map_exprs
(add_option_type_program prg)
~typ:translate_typ ~varf:Var.translate ~f:translate_expr

View File

@ -1,6 +1,6 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
@ -14,13 +14,9 @@
License for the specific language governing permissions and limitations under
the License. *)
val translate_program_with_exceptions : 'm Dcalc.Ast.program -> 'm Ast.program
(** Translation from the default calculus to the lambda calculus. This
translation uses exceptions to handle empty default terms. *)
val translate_program_without_exceptions :
'm Dcalc.Ast.program -> 'm Ast.program
(** Translation from the default calculus to the lambda calculus. This
translation uses an option monad to handle empty defaults terms. This
transformation is one piece to permit to compile toward legacy languages
that does not contains catchable exceptions. *)
that does not contains exceptions. *)
val translate_program : 'm Dcalc.Ast.program -> 'm Ast.program

View File

@ -78,7 +78,7 @@ let collect_monomorphized_instances (prg : typed program) :
args;
name =
StructName.fresh []
( "tuple_" ^ string_of_int !option_instances_counter,
( "tuple_" ^ string_of_int !tuple_instances_counter,
Pos.no_pos );
})
acc.tuples;
@ -90,7 +90,7 @@ let collect_monomorphized_instances (prg : typed program) :
{
acc with
arrays =
Type.Map.update t
Type.Map.update typ
(fun monomorphized_name ->
match monomorphized_name with
| Some e -> Some e
@ -118,7 +118,7 @@ let collect_monomorphized_instances (prg : typed program) :
{
acc with
options =
Type.Map.update t
Type.Map.update typ
(fun monomorphized_name ->
match monomorphized_name with
| Some e -> Some e
@ -146,15 +146,9 @@ let collect_monomorphized_instances (prg : typed program) :
collect_typ new_acc t
| TStruct _ | TEnum _ | TAny | TClosureEnv | TLit _ -> acc
| TOption _ | TTuple _ ->
raise
(Message.CompilerError
(Message.Content.add_position
(Message.Content.to_internal_error
(Message.Content.of_message (fun fmt ->
Format.fprintf fmt
"Some types in tuples or option have not been resolved \
by the typechecking before monomorphization.")))
(Mark.get typ)))
Message.error ~internal:true ~pos:(Mark.get typ)
"Some types in tuples or option have not been resolved by the \
typechecking before monomorphization."
in
let rec collect_expr e acc =
Expr.shallow_fold collect_expr e (collect_typ acc (Expr.ty e))
@ -179,8 +173,9 @@ let rec monomorphize_typ
(typ : typ) : typ =
match Mark.remove typ with
| TStruct _ | TEnum _ | TAny | TClosureEnv | TLit _ -> typ
| TArray t1 ->
TStruct (Type.Map.find t1 monomorphized_instances.arrays).name, Mark.get typ
| TArray _ ->
( TStruct (Type.Map.find typ monomorphized_instances.arrays).name,
Mark.get typ )
| TDefault t1 ->
TDefault (monomorphize_typ monomorphized_instances t1), Mark.get typ
| TArrow (t1s, t2) ->
@ -191,8 +186,8 @@ let rec monomorphize_typ
| TTuple _ ->
( TStruct (Type.Map.find typ monomorphized_instances.tuples).name,
Mark.get typ )
| TOption t1 ->
TEnum (Type.Map.find t1 monomorphized_instances.options).name, Mark.get typ
| TOption _ ->
TEnum (Type.Map.find typ monomorphized_instances.options).name, Mark.get typ
let is_some c =
EnumConstructor.equal Expr.some_constr c
@ -239,7 +234,12 @@ let rec monomorphize_expr
field = fst (List.nth tuple_instance.fields index);
}
| EMatch { name; e; cases } when EnumName.equal name Expr.option_enum ->
let option_instance = Type.Map.find ty0 monomorphized_instances.options in
let opt_ty =
match e0 with EMatch { e; _ }, _ -> Expr.ty e | _ -> assert false
in
let option_instance =
Type.Map.find opt_ty monomorphized_instances.options
in
EMatch
{
name = option_instance.name;
@ -253,11 +253,7 @@ let rec monomorphize_expr
cases EnumConstructor.Map.empty;
}
| EInj { name; e; cons } when EnumName.equal name Expr.option_enum ->
let option_instance =
Type.Map.find
(match Mark.remove ty0 with TOption t -> t | _ -> assert false)
monomorphized_instances.options
in
let option_instance = Type.Map.find ty0 monomorphized_instances.options in
EInj
{
name = option_instance.name;
@ -270,7 +266,7 @@ let rec monomorphize_expr
let elt_ty =
match Mark.remove ty0 with TArray t -> t | _ -> assert false
in
let array_instance = Type.Map.find elt_ty monomorphized_instances.arrays in
let array_instance = Type.Map.find ty0 monomorphized_instances.arrays in
EStruct
{
name = array_instance.name;

View File

@ -219,6 +219,7 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
in
match Mark.remove typ with
| TLit l -> Format.fprintf fmt "%a" Print.tlit l
| TTuple [] -> Format.fprintf fmt "unit"
| TTuple ts ->
Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list
@ -239,7 +240,7 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit =
(t1 @ [t2])
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "_"
| TClosureEnv -> failwith "unimplemented!"
| TClosureEnv -> Format.fprintf fmt "Obj.t"
let format_var_str (fmt : Format.formatter) (v : string) : unit =
let lowercase_name = String.to_snake_case (String.to_ascii v) in
@ -408,21 +409,6 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
format_with_parens arg1
| EAppOp { op = Log _, _; args = [arg1]; _ } ->
Format.fprintf fmt "%a" format_with_parens arg1
| EAppOp
{
op = ((HandleDefault | HandleDefaultOpt) as op), _;
args = (EArray excs, _) :: _ as args;
_;
} ->
let pos = List.map Expr.pos excs in
Format.fprintf fmt "@[<hov 2>%s@ [|%a|]@ %a@]"
(Print.operator_to_string op)
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ")
format_pos)
pos
(Format.pp_print_list ~pp_sep:Format.pp_print_space format_with_parens)
args
| EApp { f; args; _ } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f
(Format.pp_print_list
@ -442,6 +428,12 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
Format.fprintf ppf "%a@ " format_pos pos
| Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur ->
Format.fprintf ppf "%a@ " format_pos (Expr.pos (List.nth args 1))
| HandleExceptions ->
Format.fprintf ppf "[|@[<hov>%a@]|]@ "
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ")
format_pos)
(List.map Expr.pos args)
| _ -> ())
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
@ -456,10 +448,6 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
| EFatalError er ->
Format.fprintf fmt "raise@ (Runtime_ocaml.Runtime.Error (%a, [%a]))"
Print.runtime_error er format_pos (Expr.pos e)
| ERaiseEmpty -> Format.fprintf fmt "raise Empty"
| ECatchEmpty { body; handler } ->
Format.fprintf fmt "@[<hv>@[<hov 2>try@ %a@]@ with Empty ->@]@ @[%a@]"
format_with_parens body format_with_parens handler
| _ -> .
let format_struct_embedding
@ -716,9 +704,21 @@ let commands = if commands = [] then entry_scopes else commands
name format_var var name)
scopes_with_no_input
let reexport_used_modules fmt modules =
let check_and_reexport_used_modules fmt ~hashf modules =
List.iter
(fun m ->
(fun (m, intf_id) ->
Format.fprintf fmt
"@[<hv 2>let () =@ @[<hov 2>match Runtime_ocaml.Runtime.check_module \
%S \"%a\"@ with@]@,\
| Ok () -> ()@,\
@[<hv 2>| Error h -> failwith \"Hash mismatch for module %a, it may \
need recompiling\"@]@]@,"
(ModuleName.to_string m)
(fun ppf h ->
if intf_id.is_external then
Format.pp_print_string ppf Hash.external_placeholder
else Hash.format ppf h)
(hashf intf_id.hash) ModuleName.format m;
Format.fprintf fmt "@[<hv 2>module %a@ = %a@]@," ModuleName.format m
ModuleName.format m)
modules
@ -726,7 +726,9 @@ let reexport_used_modules fmt modules =
let format_module_registration
fmt
(bnd : ('m Ast.expr Var.t * _) String.Map.t)
modname =
modname
hash
is_external =
Format.pp_open_vbox fmt 2;
Format.pp_print_string fmt "let () =";
Format.pp_print_space fmt ();
@ -743,11 +745,17 @@ let format_module_registration
(fun fmt (id, (var, _)) ->
Format.fprintf fmt "@[<hov 2>%S,@ Obj.repr %a@]" id format_var var)
fmt (String.Map.to_seq bnd);
(* TODO: pass the visibility info down from desugared, and filter what is
exported here *)
Format.pp_close_box fmt ();
Format.pp_print_char fmt ' ';
Format.pp_print_string fmt "]";
Format.pp_print_space fmt ();
Format.pp_print_string fmt "\"todo-module-hash\"";
Format.fprintf fmt "\"%a\""
(fun ppf h ->
if is_external then Format.pp_print_string ppf Hash.external_placeholder
else Hash.format ppf h)
hash;
Format.pp_close_box fmt ();
Format.pp_close_box fmt ();
Format.pp_print_newline fmt ()
@ -766,17 +774,21 @@ let format_program
(fmt : Format.formatter)
?exec_scope
?(exec_args = true)
~(hashf : Hash.t -> Hash.full)
(p : 'm Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
Format.pp_open_vbox fmt 0;
Format.pp_print_string fmt header;
reexport_used_modules fmt (Program.modules_to_list p.decl_ctx.ctx_modules);
check_and_reexport_used_modules fmt ~hashf
(Program.modules_to_list p.decl_ctx.ctx_modules);
format_ctx type_ordering fmt p.decl_ctx;
let bnd = format_code_items p.decl_ctx fmt p.code_items in
Format.pp_print_cut fmt ();
let () =
match p.module_name, exec_scope with
| Some modname, None -> format_module_registration fmt bnd modname
| Some (modname, intf_id), None ->
format_module_registration fmt bnd modname (hashf intf_id.hash)
intf_id.is_external
| None, Some scope_name ->
let scope_body = Program.get_scope_body p scope_name in
format_scope_exec p.decl_ctx fmt bnd scope_name scope_body

View File

@ -14,6 +14,7 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
(** Formats a lambda calculus program into a valid OCaml program *)
@ -40,6 +41,7 @@ val format_program :
Format.formatter ->
?exec_scope:ScopeName.t ->
?exec_args:bool ->
hashf:(Hash.t -> Hash.full) ->
'm Ast.program ->
Scopelang.Dependency.TVertex.t list ->
unit

View File

@ -317,7 +317,8 @@ let rec law_structure_to_latex
Format.fprintf fmt
"\\begin{tcolorbox}[colframe=OliveGreen, breakable, \
title=\\textcolor{black}{\\texttt{%s}},title after \
break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\
break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em, \
left=0.6em, right=0.6em]\n\
%a\n\
\\end{tcolorbox}"
metadata_title metadata_title

View File

@ -471,15 +471,13 @@ let run
output
optimize
check_invariants
avoid_exceptions
closure_conversion
monomorphize_types
_options =
let options = Global.enforce_options ~trace:true () in
let prg, type_ordering =
Driver.Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed:Expr.typed
~monomorphize_types
~closure_conversion ~typed:Expr.typed ~monomorphize_types
in
let jsoo_output_file, with_formatter =
Driver.Commands.get_output_format options ~ext:"_api_web.ml" output
@ -489,7 +487,7 @@ let run
(Option.value ~default:"stdout" jsoo_output_file);
let modname =
match prg.module_name with
| Some m -> ModuleName.to_string m
| Some (m, _) -> ModuleName.to_string m
| None ->
String.capitalize_ascii
Filename.(
@ -506,7 +504,6 @@ let term =
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.monomorphize_types

View File

@ -1085,8 +1085,7 @@ let expr_to_dot_label0 :
| Reduce -> xlang () ~en:"reduce" ~fr:"réunion"
| Filter -> xlang () ~en:"filter" ~fr:"filtre"
| Fold -> xlang () ~en:"fold" ~fr:"pliage"
| HandleDefault -> ""
| HandleDefaultOpt -> ""
| HandleExceptions -> ""
| ToClosureEnv -> ""
| FromClosureEnv -> ""
in
@ -1381,7 +1380,8 @@ let run includes optimize ex_scope explain_options global_options =
Driver.Passes.dcalc global_options ~includes ~optimize
~check_invariants:false ~typed:Expr.typed
in
Interpreter.load_runtime_modules prg;
Interpreter.load_runtime_modules prg
~hashf:(Hash.finalise ~closure_conversion:false ~monomorphize_types:false);
let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
(* let result_expr, env = interpret_program prg scope in *)
let g, base_vars, env = program_to_graph explain_options prg scope in

View File

@ -210,15 +210,13 @@ let run
output
optimize
check_invariants
avoid_exceptions
closure_conversion
monomorphize_types
ex_scope
options =
let prg, _ =
Driver.Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed:Expr.typed
~monomorphize_types
~closure_conversion ~typed:Expr.typed ~monomorphize_types
in
let output_file, with_output =
Driver.Commands.get_output_format options ~ext:"_schema.json" output
@ -239,7 +237,6 @@ let term =
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
$ Cli.Flags.monomorphize_types
$ Cli.Flags.ex_scope

View File

@ -271,7 +271,8 @@ let run includes optimize check_invariants ex_scope options =
Driver.Passes.dcalc options ~includes ~optimize ~check_invariants
~typed:Expr.typed
in
Interpreter.load_runtime_modules prg;
Interpreter.load_runtime_modules prg
~hashf:(Hash.finalise ~closure_conversion:false ~monomorphize_types:false);
let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
let result_expr, _env = interpret_program prg scope in
let fmt = Format.std_formatter in

View File

@ -22,20 +22,12 @@
open Catala_utils
let run
includes
output
optimize
check_invariants
avoid_exceptions
closure_conversion
options =
let run includes output optimize check_invariants closure_conversion options =
let open Driver.Commands in
let prg, type_ordering =
Driver.Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~keep_special_ops:false
~dead_value_assignment:true ~no_struct_literals:false
~monomorphize_types:false
~closure_conversion ~keep_special_ops:false ~dead_value_assignment:true
~no_struct_literals:false ~monomorphize_types:false
in
let output_file, with_output = get_output_format options ~ext:".py" output in
@ -50,7 +42,6 @@ let term =
$ Cli.Flags.output
$ Cli.Flags.optimize
$ Cli.Flags.check_invariants
$ Cli.Flags.avoid_exceptions
$ Cli.Flags.closure_conversion
let () =

View File

@ -33,10 +33,6 @@ module VarName =
end)
()
let dead_value = VarName.fresh ("dead_value", Pos.no_pos)
let handle_default = FuncName.fresh ("handle_default", Pos.no_pos)
let handle_default_opt = FuncName.fresh ("handle_default_opt", Pos.no_pos)
type operator = Shared_ast.lcalc Shared_ast.operator
type expr = naked_expr Mark.pos
@ -121,5 +117,5 @@ type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t }
type program = {
ctx : ctx;
code_items : code_item list;
module_name : ModuleName.t option;
module_name : (ModuleName.t * module_intf_id) option;
}

View File

@ -35,13 +35,6 @@ type 'm ctxt = {
program_ctx : A.ctx;
}
let unthunk e =
match Mark.remove e with
| EAbs { binder; tys = [(TLit TUnit, _)] } ->
let _, e = Bindlib.unmbind binder in
e
| _ -> failwith "should not happen"
(* Expressions can spill out side effect, hence this function also returns a
list of statements to be prepended before the expression is evaluated *)
@ -138,15 +131,6 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr =
| ETupleAccess { e = e1; index; _ } ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in
e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr)
| EAppOp
{
op = Op.HandleDefaultOpt, _;
args = [_exceptions; _just; _cons];
tys = _;
}
when ctxt.config.keep_special_ops ->
(* This should be translated as a statement *)
raise (NotAnExpr { needs_a_local_decl = true })
| EAppOp { op; args; tys = _ } ->
let args_stmts, new_args = translate_expr_list ctxt args in
(* FIXME: what happens if [arg] is not a tuple but reduces to one ? *)
@ -227,8 +211,7 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr =
Expr.pos expr )
in
RevBlock.empty, (EExternal { modname; name }, Expr.pos expr)
| ECatchEmpty _ | EAbs _ | EIfThenElse _ | EMatch _ | EAssert _
| EFatalError _ | ERaiseEmpty ->
| EAbs _ | EIfThenElse _ | EMatch _ | EAssert _ | EFatalError _ ->
raise (NotAnExpr { needs_a_local_decl = true })
| _ -> .
with NotAnExpr { needs_a_local_decl } ->
@ -274,60 +257,60 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr]
e_stmts
| EFatalError err -> [SFatalError err, Expr.pos block_expr]
| EAppOp
{ op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] }
when ctxt.config.keep_special_ops ->
let exceptions =
match Mark.remove exceptions with
| EStruct { fields; _ } -> (
let _, exceptions =
List.find
(fun (field, _) ->
String.equal (Mark.remove (StructField.get_info field)) "content")
(StructField.Map.bindings fields)
in
match Mark.remove exceptions with
| EArray exceptions -> exceptions
| _ -> failwith "should not happen")
| _ -> failwith "should not happen"
in
let just = unthunk just in
let cons = unthunk cons in
let exceptions_stmts, new_exceptions =
translate_expr_list ctxt exceptions
in
let just_stmts, new_just = translate_expr ctxt just in
let cons_stmts, new_cons = translate_expr ctxt cons in
RevBlock.rebuild exceptions_stmts
~tail:
(RevBlock.rebuild just_stmts
~tail:
[
( A.SSpecialOp
(OHandleDefaultOpt
{
exceptions = new_exceptions;
just = new_just;
cons =
RevBlock.rebuild cons_stmts
~tail:
[
( (match ctxt.inside_definition_of with
| None -> A.SReturn (Mark.remove new_cons)
| Some x ->
A.SLocalDef
{
name = Mark.copy new_cons x;
expr = new_cons;
typ =
Expr.maybe_ty (Mark.get block_expr);
}),
Expr.pos block_expr );
];
return_typ = Expr.maybe_ty (Mark.get block_expr);
}),
Expr.pos block_expr );
])
(* | EAppOp
* { op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] }
* when ctxt.config.keep_special_ops ->
* let exceptions =
* match Mark.remove exceptions with
* | EStruct { fields; _ } -> (
* let _, exceptions =
* List.find
* (fun (field, _) ->
* String.equal (Mark.remove (StructField.get_info field)) "content")
* (StructField.Map.bindings fields)
* in
* match Mark.remove exceptions with
* | EArray exceptions -> exceptions
* | _ -> failwith "should not happen")
* | _ -> failwith "should not happen"
* in
* let just = unthunk just in
* let cons = unthunk cons in
* let exceptions_stmts, new_exceptions =
* translate_expr_list ctxt exceptions
* in
* let just_stmts, new_just = translate_expr ctxt just in
* let cons_stmts, new_cons = translate_expr ctxt cons in
* RevBlock.rebuild exceptions_stmts
* ~tail:
* (RevBlock.rebuild just_stmts
* ~tail:
* [
* ( A.SSpecialOp
* (OHandleDefaultOpt
* {
* exceptions = new_exceptions;
* just = new_just;
* cons =
* RevBlock.rebuild cons_stmts
* ~tail:
* [
* ( (match ctxt.inside_definition_of with
* | None -> A.SReturn (Mark.remove new_cons)
* | Some x ->
* A.SLocalDef
* {
* name = Mark.copy new_cons x;
* expr = new_cons;
* typ =
* Expr.maybe_ty (Mark.get block_expr);
* }),
* Expr.pos block_expr );
* ];
* return_typ = Expr.maybe_ty (Mark.get block_expr);
* }),
* Expr.pos block_expr );
* ]) *)
| EApp { f = EAbs { binder; tys }, binder_mark; args; _ } ->
(* This defines multiple local variables at the time *)
let binder_pos = Expr.mark_pos binder_mark in
@ -483,29 +466,6 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
},
Expr.pos block_expr );
]
| ECatchEmpty { body; handler } ->
let s_e_try = translate_statements ctxt body in
let s_e_catch = translate_statements ctxt handler in
[
( A.STryWEmpty { try_block = s_e_try; with_block = s_e_catch },
Expr.pos block_expr );
]
| ERaiseEmpty ->
(* Before raising the exception, we still give a dummy definition to the
current variable so that tools like mypy don't complain. *)
(match ctxt.inside_definition_of with
| Some x when ctxt.config.dead_value_assignment ->
[
( A.SLocalDef
{
name = x, Expr.pos block_expr;
expr = Ast.EVar Ast.dead_value, Expr.pos block_expr;
typ = Expr.maybe_ty (Mark.get block_expr);
},
Expr.pos block_expr );
]
| _ -> [])
@ [A.SRaiseEmpty, Expr.pos block_expr]
| EInj { e = e1; cons; name } when ctxt.config.no_struct_literals ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in
let tmp_struct_var_name =
@ -659,7 +619,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
A.program =
let modules =
List.fold_left
(fun acc m ->
(fun acc (m, _) ->
let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in
(* The "Module_" prefix is a workaround name clashes for same-name
structs and modules, Python in particular mixes everything in one

View File

@ -21,10 +21,10 @@ open Ast
let needs_parens (_e : expr) : bool = false
let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit =
Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v)
Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v)
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
Format.fprintf fmt "@{<green>%a_%d@}" FuncName.format v (FuncName.hash v)
Format.fprintf fmt "@{<green>%a_%d@}" FuncName.format v (FuncName.id v)
let rec format_expr
(decl_ctx : decl_ctx)
@ -53,7 +53,7 @@ let rec format_expr
(StructField.Map.bindings es)
Print.punctuation "}"
| ETuple es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "()"
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "("
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> Format.fprintf fmt "%a" format_expr e))
@ -233,21 +233,11 @@ let format_item decl_ctx ?debug ppf def =
Format.pp_print_cut ppf ()
let format_program ?debug ppf prg =
let decl_ctx =
(* TODO: this is redundant with From_dcalc.add_option_type (which is already
applied in avoid_exceptions mode) *)
{
prg.ctx.decl_ctx with
ctx_enums =
EnumName.Map.add Expr.option_enum Expr.option_enum_config
prg.ctx.decl_ctx.ctx_enums;
}
in
Format.pp_open_vbox ppf 0;
ModuleName.Map.iter
(fun m var ->
Format.fprintf ppf "%a %a = %a@," Print.keyword "module" format_var_name
var ModuleName.format m)
prg.ctx.modules;
Format.pp_print_list (format_item decl_ctx ?debug) ppf prg.code_items;
Format.pp_print_list (format_item prg.ctx.decl_ctx ?debug) ppf prg.code_items;
Format.pp_close_box ppf ()

View File

@ -96,11 +96,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty
let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
let v_str = Mark.remove (VarName.get_info v) in
let hash = VarName.hash v in
let id = VarName.id v in
let local_id =
match StringMap.find_opt v_str !string_counter_map with
| Some ids -> (
match IntMap.find_opt hash ids with
match IntMap.find_opt id ids with
| None ->
let max_id =
snd
@ -111,13 +111,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
in
string_counter_map :=
StringMap.add v_str
(IntMap.add hash (max_id + 1) ids)
(IntMap.add id (max_id + 1) ids)
!string_counter_map;
max_id + 1
| Some local_id -> local_id)
| None ->
string_counter_map :=
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map;
StringMap.add v_str (IntMap.singleton id 0) !string_counter_map;
0
in
if v_str = "_" then Format.fprintf fmt "dummy_var"
@ -313,9 +313,8 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
| Reduce -> Format.pp_print_string fmt "catala_list_reduce"
| Filter -> Format.pp_print_string fmt "catala_list_filter"
| Fold -> Format.pp_print_string fmt "catala_list_fold_left"
| HandleDefault -> Format.pp_print_string fmt "catala_handle_default"
| HandleDefaultOpt | FromClosureEnv | ToClosureEnv | Map2 ->
failwith "unimplemented"
| HandleExceptions -> Format.pp_print_string fmt "catala_handle_exceptions"
| FromClosureEnv | ToClosureEnv | Map2 -> failwith "unimplemented"
let _format_string_list (fmt : Format.formatter) (uids : string list) : unit =
let sanitize_quotes = Re.compile (Re.char '"') in
@ -350,6 +349,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
failwith
"should not happen, array initialization is caught at the statement level"
| ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l)
| EAppOp { op = (ToClosureEnv | FromClosureEnv), _; args = [arg] } ->
format_expression ctx fmt arg
| EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2] } ->
Format.fprintf fmt "%a(%a,@ %a)" format_op op (format_expression ctx) arg1
(format_expression ctx) arg2
@ -366,8 +367,6 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1
| EAppOp { op; args = [arg1] } ->
Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1
| EAppOp { op = (HandleDefaultOpt | HandleDefault), _; args = _ } ->
failwith "should not happen because of keep_special_ops"
| EApp { f; args } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f
(Format.pp_print_list
@ -441,19 +440,14 @@ let rec format_statement
| SFatalError err ->
let pos = Mark.get s in
Format.fprintf fmt
"catala_fatal_error_raised.code = catala_%s;@,\
catala_fatal_error_raised.position.filename = \"%s\";@,\
catala_fatal_error_raised.position.start_line = %d;@,\
catala_fatal_error_raised.position.start_column = %d;@,\
catala_fatal_error_raised.position.end_line = %d;@,\
catala_fatal_error_raised.position.end_column = %d;@,\
longjmp(catala_fatal_error_jump_buffer, 0);"
"@[<hov 2>catala_raise_fatal_error (catala_%s,@ \"%s\",@ %d, %d, %d, \
%d);@]"
(String.to_snake_case (Runtime.error_to_string err))
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos)
| SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } ->
Format.fprintf fmt
"@[<hov 2>if (%a) {@\n%a@]@\n@[<hov 2>} else {@\n%a@]@\n}"
"@[<hv 2>@[<hov 2>if (%a) {@]@,%a@,@;<1 -2>} else {@,%a@,@;<1 -2>}@]"
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cases =
@ -463,34 +457,33 @@ let rec format_statement
(EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
in
let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "@[<hov 2>%a %a = %a;@]@\n@[<hov 2>if %a@]@\n}"
format_enum_name e_name format_var tmp_var (format_expression ctx) e1
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 2>} else if ")
(fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) ->
Format.fprintf fmt "(%a.code == %a_%a) {@\n%a = %a.payload.%a;@\n%a"
format_var tmp_var format_enum_name e_name format_enum_cons_name
cons_name
(format_typ ctx (fun fmt -> format_var fmt payload_var_name))
payload_var_typ format_var tmp_var format_enum_cons_name cons_name
(format_block ctx) case_block))
cases
Format.fprintf fmt "@[<hov 2>%a %a = %a;@]@," format_enum_name e_name
format_var tmp_var (format_expression ctx) e1;
Format.pp_open_vbox fmt 2;
Format.fprintf fmt "@[<hov 4>switch (%a.code) {@]@," format_var tmp_var;
Format.pp_print_list
(fun fmt ({ case_block; payload_var_name; payload_var_typ }, cons_name) ->
Format.fprintf fmt "@[<hv 2>case %a_%a:@ " format_enum_name e_name
format_enum_cons_name cons_name;
if not (Type.equal payload_var_typ (TLit TUnit, Pos.no_pos)) then
Format.fprintf fmt "%a = %a.payload.%a;@ "
(format_typ ctx (fun fmt -> format_var fmt payload_var_name))
payload_var_typ format_var tmp_var format_enum_cons_name cons_name;
Format.fprintf fmt "%a@ break;@]" (format_block ctx) case_block)
fmt cases;
(* Do we want to add 'default' case with a failure ? *)
Format.fprintf fmt "@;<0 -2>}";
Format.pp_close_box fmt ()
| SReturn e1 ->
Format.fprintf fmt "@[<hov 2>return %a;@]" (format_expression ctx)
(e1, Mark.get s)
| SAssert e1 ->
let pos = Mark.get s in
Format.fprintf fmt
"@[<hov 2>if (!(%a)) {@\n\
catala_fatal_error_raised.code = catala_assertion_failure;@,\
catala_fatal_error_raised.position.filename = \"%s\";@,\
catala_fatal_error_raised.position.start_line = %d;@,\
catala_fatal_error_raised.position.start_column = %d;@,\
catala_fatal_error_raised.position.end_line = %d;@,\
catala_fatal_error_raised.position.end_column = %d;@,\
longjmp(catala_fatal_error_jump_buffer, 0);@,\
}"
(format_expression ctx)
"@[<v 2>@[<hov 2>if (!(%a)) {@]@,\
@[<hov 2>catala_raise_fatal_error (catala_assertion_failed,@ \"%s\",@ \
%d, %d, %d, %d);@]@;\
<1 -2>}@]" (format_expression ctx)
(e1, Mark.get s)
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos)
@ -548,14 +541,9 @@ let rec format_statement
exceptions;
Format.fprintf fmt
"@[<v 2>if (%a) {@,\
catala_fatal_error_raised.code = catala_conflict;@,\
catala_fatal_error_raised.position.filename = \"%s\";@,\
catala_fatal_error_raised.position.start_line = %d;@,\
catala_fatal_error_raised.position.start_column = %d;@,\
catala_fatal_error_raised.position.end_line = %d;@,\
catala_fatal_error_raised.position.end_column = %d;@,\
longjmp(catala_fatal_error_jump_buffer, 0);@]@,\
}@,"
@[<hov 2>catala_raise_fatal_error(catala_conflict,@ \"%s\",@ %d, %d, \
%d, %d);@]@;\
<1 -2>}@]@,"
format_var exception_conflict (Pos.get_file pos)
(Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos);

View File

@ -88,8 +88,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
| Reduce -> Format.pp_print_string fmt "list_reduce"
| Filter -> Format.pp_print_string fmt "list_filter"
| Fold -> Format.pp_print_string fmt "list_fold_left"
| HandleDefault -> Format.pp_print_string fmt "handle_default"
| HandleDefaultOpt -> Format.pp_print_string fmt "handle_default_opt"
| HandleExceptions -> Format.pp_print_string fmt "handle_exceptions"
| FromClosureEnv | ToClosureEnv -> failwith "unimplemented"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
@ -152,20 +151,20 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty
let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
let v_str = clean_name (Mark.remove (VarName.get_info v)) in
let hash = VarName.hash v in
let id = VarName.id v in
let local_id =
match StringMap.find_opt v_str !string_counter_map with
| Some ids -> (
match IntMap.find_opt hash ids with
match IntMap.find_opt id ids with
| None ->
let id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in
let local_id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in
string_counter_map :=
StringMap.add v_str (IntMap.add hash id ids) !string_counter_map;
id
StringMap.add v_str (IntMap.add id local_id ids) !string_counter_map;
local_id
| Some local_id -> local_id)
| None ->
string_counter_map :=
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map;
StringMap.add v_str (IntMap.singleton id 0) !string_counter_map;
0
in
if v_str = "_" then Format.fprintf fmt "_"
@ -347,41 +346,27 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit =
args = [arg1];
} ->
Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1
| EAppOp { op = (HandleExceptions, _) as op; args = [(EArray el, _)] as args }
->
Format.fprintf fmt "@[<hv 4>%a(@,[%a],@ %a@;<0 -4>)@]" format_op op
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ ")
format_position)
(List.map Mark.get el)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EAppOp { op; args = [arg1] } ->
Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1
| EAppOp { op = ((HandleDefault | HandleDefaultOpt), _) as op; args } ->
let pos = Mark.get e in
Format.fprintf fmt
"%a(@[<hov 0>SourcePosition(filename=\"%s\",@ start_line=%d,@ \
start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]"
format_op op (Pos.get_file pos) (Pos.get_start_line pos)
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
format_string_list (Pos.get_law_info pos)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EApp { f = EFunc x, pos; args }
when Ast.FuncName.compare x Ast.handle_default = 0
|| Ast.FuncName.compare x Ast.handle_default_opt = 0 ->
Format.fprintf fmt
"%a(@[<hov 0>SourcePosition(filename=\"%s\",@ start_line=%d,@ \
start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]"
format_func_name x (Pos.get_file pos) (Pos.get_start_line pos)
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
format_string_list (Pos.get_law_info pos)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EApp { f; args } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f
Format.fprintf fmt "%a(@[<hv 0>%a)@]" (format_expression ctx) f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EAppOp { op; args } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" format_op op
Format.fprintf fmt "%a(@[<hv 0>%a)@]" format_op op
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
@ -402,10 +387,10 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
=
match Mark.remove s with
| SInnerFuncDef { name; func = { func_params; func_body; _ } } ->
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var
Format.fprintf fmt "@[<v 4>def %a(@[<hov>%a@]):@ %a@]" format_var
(Mark.remove name)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
(format_typ ctx) typ))
@ -414,16 +399,16 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
assert false (* We don't need to declare variables in Python *)
| SLocalDef { name = v; expr = e; _ } | SLocalInit { name = v; expr = e; _ }
->
Format.fprintf fmt "@[<hov 4>%a = %a@]" format_var (Mark.remove v)
Format.fprintf fmt "@[<hv 4>%a = %a@]" format_var (Mark.remove v)
(format_expression ctx) e
| STryWEmpty { try_block = try_b; with_block = catch_b } ->
Format.fprintf fmt "@[<v 4>try:@,%a@]@\n@[<v 4>except Empty:@,%a@]"
Format.fprintf fmt "@[<v 4>try:@ %a@]@,@[<v 4>except Empty:@ %a@]"
(format_block ctx) try_b (format_block ctx) catch_b
| SRaiseEmpty -> Format.fprintf fmt "raise Empty"
| SFatalError err ->
Format.fprintf fmt "@[<hov 4>raise %a@]" format_error (err, Mark.get s)
| SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } ->
Format.fprintf fmt "@[<hov 4>if %a:@\n%a@]@\n@[<hov 4>else:@\n%a@]"
Format.fprintf fmt "@[<v 4>if %a:@ %a@]@,@[<v 4>else:@ %a@]"
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch
{
@ -439,11 +424,11 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
when EnumName.equal e_name Expr.option_enum ->
(* We translate the option type with an overloading by Python's [None] *)
let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt "%a = %a@\n" format_var tmp_var (format_expression ctx)
e1;
Format.fprintf fmt "@[<v 4>if %a is None:@\n%a@]@\n" format_var tmp_var
Format.fprintf fmt "@[<hv 4>%a = %a@]@," format_var tmp_var
(format_expression ctx) e1;
Format.fprintf fmt "@[<v 4>if %a is None:@ %a@]@," format_var tmp_var
(format_block ctx) case_none;
Format.fprintf fmt "@[<v 4>else:@\n%a = %a@\n%a@]" format_var case_some_var
Format.fprintf fmt "@[<v 4>else:@ %a = %a@,%a@]" format_var case_some_var
format_var tmp_var (format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cons_map = EnumName.Map.find e_name ctx.decl_ctx.ctx_enums in
@ -470,10 +455,10 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
| SAssert e1 ->
let pos = Mark.get s in
Format.fprintf fmt
"@[<hov 4>if not (%a):@\n\
raise AssertionFailure(@[<hov 0>SourcePosition(@[<hov \
0>filename=\"%s\",@ start_line=%d,@ start_column=%d,@ end_line=%d,@ \
end_column=%d,@ law_headings=@[<hv>%a@])@])@]@]"
"@[<hv 4>if not (%a):@,\
raise AssertionFailure(@[<hov>SourcePosition(@[<hov 0>filename=\"%s\",@ \
start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
law_headings=@[<hv>%a@])@])@]@]"
(format_expression ctx)
(e1, Mark.get s)
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
@ -482,12 +467,14 @@ let rec format_statement ctx (fmt : Format.formatter) (s : stmt Mark.pos) : unit
| SSpecialOp _ -> failwith "should not happen"
and format_block ctx (fmt : Format.formatter) (b : block) : unit =
Format.pp_open_vbox fmt 0;
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@,")
(format_statement ctx) fmt
(List.filter
(fun s -> match Mark.remove s with SLocalDecl _ -> false | _ -> true)
b)
b);
Format.pp_close_box fmt ()
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
@ -496,20 +483,20 @@ let format_ctx
let format_struct_decl fmt (struct_name, struct_fields) =
let fields = StructField.Map.bindings struct_fields in
Format.fprintf fmt
"class %a:@\n\
\ def __init__(self, %a) -> None:@\n\
%a@\n\
@\n\
\ def __eq__(self, other: object) -> bool:@\n\
\ if isinstance(other, %a):@\n\
\ return @[<hov>(%a)@]@\n\
\ else:@\n\
\ return False@\n\
@\n\
\ def __ne__(self, other: object) -> bool:@\n\
\ return not (self == other)@\n\
@\n\
\ def __str__(self) -> str:@\n\
"class %a:@,\
\ def __init__(self, %a) -> None:@,\
%a@,\
@,\
\ def __eq__(self, other: object) -> bool:@,\
\ if isinstance(other, %a):@,\
\ return @[<hov>(%a)@]@,\
\ else:@,\
\ return False@,\
@,\
\ def __ne__(self, other: object) -> bool:@,\
\ return not (self == other)@,\
@,\
\ def __str__(self) -> str:@,\
\ @[<hov 4>return \"%a(%a)\".format(%a)@]" (format_struct_name ctx)
struct_name
(Format.pp_print_list
@ -521,9 +508,7 @@ let format_ctx
(if StructField.Map.is_empty struct_fields then fun fmt _ ->
Format.fprintf fmt " pass"
else
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (struct_field, _) ->
Format.pp_print_list (fun fmt (struct_field, _) ->
Format.fprintf fmt " self.%a = %a" format_struct_field_name
struct_field format_struct_field_name struct_field))
fields (format_struct_name ctx) struct_name
@ -551,32 +536,30 @@ let format_ctx
failwith "no constructors in the enum"
else
Format.fprintf fmt
"@[<hov 4>class %a_Code(Enum):@\n\
%a@]@\n\
@\n\
class %a:@\n\
\ def __init__(self, code: %a_Code, value: Any) -> None:@\n\
\ self.code = code@\n\
\ self.value = value@\n\
@\n\
@\n\
\ def __eq__(self, other: object) -> bool:@\n\
\ if isinstance(other, %a):@\n\
"@[<v 4>class %a_Code(Enum):@,\
%a@]@,\
@,\
class %a:@,\
\ def __init__(self, code: %a_Code, value: Any) -> None:@,\
\ self.code = code@,\
\ self.value = value@,\
@,\
@,\
\ def __eq__(self, other: object) -> bool:@,\
\ if isinstance(other, %a):@,\
\ return self.code == other.code and self.value == \
other.value@\n\
\ else:@\n\
\ return False@\n\
@\n\
@\n\
\ def __ne__(self, other: object) -> bool:@\n\
\ return not (self == other)@\n\
@\n\
\ def __str__(self) -> str:@\n\
other.value@,\
\ else:@,\
\ return False@,\
@,\
@,\
\ def __ne__(self, other: object) -> bool:@,\
\ return not (self == other)@,\
@,\
\ def __str__(self) -> str:@,\
\ @[<hov 4>return \"{}({})\".format(self.code, self.value)@]"
(format_enum_name ctx) enum_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (i, enum_cons, _enum_cons_type) ->
(Format.pp_print_list (fun fmt (i, enum_cons, _enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi
(fun i (x, y) -> i, x, y)
@ -606,11 +589,11 @@ let format_ctx
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
if StructName.path s = [] then
Format.fprintf fmt "%a@\n@\n" format_struct_decl
Format.fprintf fmt "%a@,@," format_struct_decl
(s, StructName.Map.find s ctx.decl_ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
if EnumName.path e = [] then
Format.fprintf fmt "%a@\n@\n" format_enum_decl
Format.fprintf fmt "%a@,@," format_enum_decl
(e, EnumName.Map.find e ctx.decl_ctx.ctx_enums))
(type_ordering @ scope_structs)
@ -626,14 +609,15 @@ let reserve_func_name = function
let format_code_item ctx fmt = function
| SVar { var; expr; typ = _ } ->
Format.fprintf fmt "@[<hv 4>%a = (@,%a@,@])@," format_var var
Format.fprintf fmt "@[<hv 4>%a = (@,%a@;<0 -4>)@]@," format_var var
(format_expression ctx) expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { Ast.func_params; Ast.func_body; _ } = func in
Format.fprintf fmt "@[<hv 4>def %a(%a):@\n%a@]@," format_func_name var
Format.fprintf fmt "@[<v 4>@[<hov 2>def %a(@,%a@;<0 -2>):@]@ %a@]@,"
format_func_name var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Mark.remove var)
(format_typ ctx) typ))

View File

@ -1,584 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
open Ast
module Runtime = Runtime_ocaml.Runtime
module D = Dcalc.Ast
module L = Lcalc.Ast
let format_lit (fmt : Format.formatter) (l : lit Mark.pos) : unit =
match Mark.remove l with
| LBool true -> Format.pp_print_string fmt "TRUE"
| LBool false -> Format.pp_print_string fmt "FALSE"
| LInt i ->
if Z.fits_nativeint i then
Format.fprintf fmt "catala_integer_from_numeric(%s)"
(Runtime.integer_to_string i)
else
Format.fprintf fmt "catala_integer_from_string(\"%s\")"
(Runtime.integer_to_string i)
| LUnit -> Format.pp_print_string fmt "new(\"catala_unit\",v=0)"
| LRat i ->
Format.fprintf fmt "catala_decimal_from_fraction(%s,%s)"
(if Z.fits_nativeint (Q.num i) then Z.to_string (Q.num i)
else "\"" ^ Z.to_string (Q.num i) ^ "\"")
(if Z.fits_nativeint (Q.den i) then Z.to_string (Q.den i)
else "\"" ^ Z.to_string (Q.den i) ^ "\"")
| LMoney e ->
if Z.fits_nativeint e then
Format.fprintf fmt "catala_money_from_cents(%s)"
(Runtime.integer_to_string (Runtime.money_to_cents e))
else
Format.fprintf fmt "catala_money_from_cents(\"%s\")"
(Runtime.integer_to_string (Runtime.money_to_cents e))
| LDate d ->
Format.fprintf fmt "catala_date_from_ymd(%d,%d,%d)"
(Runtime.integer_to_int (Runtime.year_of_date d))
(Runtime.integer_to_int (Runtime.month_number_of_date d))
(Runtime.integer_to_int (Runtime.day_of_month_of_date d))
| LDuration d ->
let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "catala_duration_from_ymd(%d,%d,%d)" years months days
let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
match Mark.remove op with
| Log (_entry, _infos) -> assert false
| Minus_int | Minus_rat | Minus_mon | Minus_dur ->
Format.pp_print_string fmt "-"
(* Todo: use the names from [Operator.name] *)
| Not -> Format.pp_print_string fmt "!"
| Length -> Format.pp_print_string fmt "catala_list_length"
| ToRat_int -> Format.pp_print_string fmt "catala_decimal_from_integer"
| ToRat_mon -> Format.pp_print_string fmt "catala_decimal_from_money"
| ToMoney_rat -> Format.pp_print_string fmt "catala_money_from_decimal"
| GetDay -> Format.pp_print_string fmt "catala_day_of_month_of_date"
| GetMonth -> Format.pp_print_string fmt "catala_month_number_of_date"
| GetYear -> Format.pp_print_string fmt "catala_year_of_date"
| FirstDayOfMonth ->
Format.pp_print_string fmt "catala_date_first_day_of_month"
| LastDayOfMonth -> Format.pp_print_string fmt "catala_date_last_day_of_month"
| Round_mon -> Format.pp_print_string fmt "catala_money_round"
| Round_rat -> Format.pp_print_string fmt "catala_decimal_round"
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _ | Add_dur_dur
| Concat ->
Format.pp_print_string fmt "+"
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
| Sub_dur_dur ->
Format.pp_print_string fmt "-"
| Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int ->
Format.pp_print_string fmt "*"
| Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur ->
Format.pp_print_string fmt "/"
| And -> Format.pp_print_string fmt "&&"
| Or -> Format.pp_print_string fmt "||"
| Eq -> Format.pp_print_string fmt "=="
| Xor -> Format.pp_print_string fmt "!="
| Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur ->
Format.pp_print_string fmt "<"
| Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur ->
Format.pp_print_string fmt "<="
| Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur ->
Format.pp_print_string fmt ">"
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur ->
Format.pp_print_string fmt ">="
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
Format.pp_print_string fmt "=="
| Map -> Format.pp_print_string fmt "catala_list_map"
| Map2 -> Format.pp_print_string fmt "catala_list_map2"
| Reduce -> Format.pp_print_string fmt "catala_list_reduce"
| Filter -> Format.pp_print_string fmt "catala_list_filter"
| Fold -> Format.pp_print_string fmt "catala_list_fold_left"
| HandleDefault -> Format.pp_print_string fmt "catala_handle_default"
| HandleDefaultOpt | FromClosureEnv | ToClosureEnv -> failwith "unimplemented"
let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
let sanitize_quotes = Re.compile (Re.char '"') in
Format.fprintf fmt "c(%a)"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt info ->
Format.fprintf fmt "\"%s\""
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids
let avoid_keywords (s : string) : string =
if
match s with
(* list taken from
https://cran.r-project.org/doc/manuals/r-release/R-lang.html#Reserved-words *)
| "if" | "else" | "repeat" | "while" | "function" | "for" | "in" | "next"
| "break" | "TRUE" | "FALSE" | "NULL" | "Inf" | "NaN" | "NA" | "NA_integer_"
| "NA_real_" | "NA_complex_" | "NA_character_"
(* additions of things that are not keywords but that we should not
overwrite*)
| "list" | "c" | "character" | "logical" | "complex" | "setClass" | "new" ->
true
| _ -> false
then s ^ "_"
else s
let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_camel_case
(String.to_ascii (Format.asprintf "%a" StructName.format v))))
let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit
=
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_ascii (Format.asprintf "%a" StructField.format v)))
let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit =
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_camel_case
(String.to_ascii (Format.asprintf "%a" EnumName.format v))))
let format_enum_cons_name (fmt : Format.formatter) (v : EnumConstructor.t) :
unit =
Format.fprintf fmt "%s"
(avoid_keywords
(String.to_ascii (Format.asprintf "%a" EnumConstructor.format v)))
let rec format_typ ~inside_comment (fmt : Format.formatter) (typ : typ) : unit =
let format_typ = format_typ in
match Mark.remove typ with
| TLit TUnit -> Format.fprintf fmt "\"catala_unit\""
| TLit TMoney -> Format.fprintf fmt "\"catala_money\""
| TLit TInt -> Format.fprintf fmt "\"catala_integer\""
| TLit TRat -> Format.fprintf fmt "\"catala_decimal\""
| TLit TDate -> Format.fprintf fmt "\"catala_date\""
| TLit TDuration -> Format.fprintf fmt "\"catala_duration\""
| TLit TBool -> Format.fprintf fmt "\"logical\""
| TTuple ts ->
Format.fprintf fmt "\"list\"@ # tuple(%a)%t"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@;")
(format_typ ~inside_comment:true))
ts
(fun fmt -> if inside_comment then () else Format.pp_force_newline fmt ())
| TStruct s -> Format.fprintf fmt "\"catala_struct_%a\"" format_struct_name s
| TOption some_typ | TDefault some_typ ->
(* We loose track of optional value as they're crammed into NULL *)
format_typ ~inside_comment:false fmt some_typ
| TEnum e -> Format.fprintf fmt "\"catala_enum_%a\"" format_enum_name e
| TArrow (t1, t2) ->
Format.fprintf fmt "\"function\" # %a -> %a%t"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(format_typ ~inside_comment:true))
t1
(format_typ ~inside_comment:true)
t2
(fun fmt -> if inside_comment then () else Format.pp_force_newline fmt ())
| TArray t1 ->
Format.fprintf fmt "\"list\" # array(%a)%t"
(format_typ ~inside_comment:true) t1 (fun fmt ->
if inside_comment then () else Format.pp_force_newline fmt ())
| TAny -> Format.fprintf fmt "\"ANY\""
| TClosureEnv -> failwith "unimplemented!"
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
s
|> String.to_ascii
|> String.to_snake_case
|> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_")
|> String.to_ascii
|> avoid_keywords
|> Format.fprintf fmt "%s"
module StringMap = String.Map
module IntMap = Map.Make (struct
include Int
let format ppf i = Format.pp_print_int ppf i
end)
(** For each `VarName.t` defined by its string and then by its hash, we keep
track of which local integer id we've given it. This is used to keep
variable naming with low indices rather than one global counter for all
variables. TODO: should be removed when
https://github.com/CatalaLang/catala/issues/240 is fixed. *)
let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty
let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
let v_str = Mark.remove (VarName.get_info v) in
let hash = VarName.hash v in
let local_id =
match StringMap.find_opt v_str !string_counter_map with
| Some ids -> (
match IntMap.find_opt hash ids with
| None ->
let max_id =
snd
(List.hd
(List.fast_sort
(fun (_, x) (_, y) -> Int.compare y x)
(IntMap.bindings ids)))
in
string_counter_map :=
StringMap.add v_str
(IntMap.add hash (max_id + 1) ids)
!string_counter_map;
max_id + 1
| Some local_id -> local_id)
| None ->
string_counter_map :=
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map;
0
in
if v_str = "_" then Format.fprintf fmt "dummy_var"
(* special case for the unit pattern TODO escape dummy_var *)
else if local_id = 0 then format_name_cleaned fmt v_str
else Format.fprintf fmt "%a_%d" format_name_cleaned v_str local_id
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
let v_str = Mark.remove (FuncName.get_info v) in
format_name_cleaned fmt v_str
let format_position ppf pos =
Format.fprintf ppf
"@[<hov 2>catala_position(@,\
filename=\"%s\",@ start_line=%d, start_column=%d,@ end_line=%d, \
end_column=%d,@ law_headings=%a@;\
<0 -2>)@]" (Pos.get_file pos) (Pos.get_start_line pos)
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
format_string_list (Pos.get_law_info pos)
let format_error (ppf : Format.formatter) (err : Runtime.error Mark.pos) : unit
=
let pos = Mark.get err in
let tag = String.to_snake_case (Runtime.error_to_string (Mark.remove err)) in
Format.fprintf ppf "%s(%a)" tag format_position pos
let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
unit =
match Mark.remove e with
| EVar v -> format_var fmt v
| EFunc f -> format_func_name fmt f
| EStruct { fields = es; name = s } ->
Format.fprintf fmt "new(\"catala_struct_%a\",@ %a)" format_struct_name s
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (struct_field, e) ->
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_expression ctx) e))
(StructField.Map.bindings es)
| EStructFieldAccess { e1; field; _ } ->
Format.fprintf fmt "%a@%a" (format_expression ctx) e1
format_struct_field_name field
| EInj { cons; name = e_name; _ }
when EnumName.equal e_name Expr.option_enum
&& EnumConstructor.equal cons Expr.none_constr ->
(* We translate the option type with an overloading by R's [NULL] *)
Format.fprintf fmt "NULL"
| EInj { e1 = e; cons; name = e_name; _ }
when EnumName.equal e_name Expr.option_enum
&& EnumConstructor.equal cons Expr.some_constr ->
(* We translate the option type with an overloading by R's [NULL] *)
format_expression ctx fmt e
| EInj { e1 = e; cons; name = enum_name; _ } ->
Format.fprintf fmt "new(\"catala_enum_%a\", code = \"%a\",@ value = %a)"
format_enum_name enum_name format_enum_cons_name cons
(format_expression ctx) e
| EArray es ->
Format.fprintf fmt "list(%a)"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
es
| ELit l -> Format.fprintf fmt "%a" format_lit (Mark.copy e l)
| EAppOp { op = ((Map | Filter), _) as op; args = [arg1; arg2] } ->
Format.fprintf fmt "%a(%a,@ %a)" format_op op (format_expression ctx) arg1
(format_expression ctx) arg2
| EAppOp { op; args = [arg1; arg2] } ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_op op
(format_expression ctx) arg2
| EAppOp { op = (Not, _) as op; args = [arg1] } ->
Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1
| EAppOp
{
op = ((Minus_int | Minus_rat | Minus_mon | Minus_dur), _) as op;
args = [arg1];
} ->
Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1
| EAppOp { op; args = [arg1] } ->
Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1
| EAppOp { op = HandleDefaultOpt, _; _ } ->
Message.error ~internal:true
"R compilation does not currently support the avoiding of exceptions"
| EAppOp { op = (HandleDefault as op), _; args; _ } ->
let pos = Mark.get e in
Format.fprintf fmt
"%a(@[<hov 0>catala_position(filename=\"%s\",@ start_line=%d,@ \
start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]"
format_op (op, pos) (Pos.get_file pos) (Pos.get_start_line pos)
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
format_string_list (Pos.get_law_info pos)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EApp { f = EFunc x, pos; args }
when Ast.FuncName.compare x Ast.handle_default = 0
|| Ast.FuncName.compare x Ast.handle_default_opt = 0 ->
Format.fprintf fmt
"%a(@[<hov 0>catala_position(filename=\"%s\",@ start_line=%d,@ \
start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]"
format_func_name x (Pos.get_file pos) (Pos.get_start_line pos)
(Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos)
format_string_list (Pos.get_law_info pos)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EApp { f; args } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EAppOp { op; args } ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" format_op op
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| ETuple args ->
Format.fprintf fmt "list(@[<hov 0>%a)@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| ETupleAccess { e1; index } ->
Format.fprintf fmt "(%a)[%d]" (format_expression ctx) e1 index
| EExternal _ -> failwith "TODO"
let rec format_statement
(ctx : decl_ctx)
(fmt : Format.formatter)
(s : stmt Mark.pos) : unit =
match Mark.remove s with
| SInnerFuncDef { name; func = { func_params; func_body; _ } } ->
Format.fprintf fmt "@[<hov 2>%a <- function(@\n%a) {@\n%a@]@\n}" format_var
(Mark.remove name)
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n,@;")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
(format_typ ~inside_comment:true)
typ))
func_params (format_block ctx) func_body
| SLocalDecl _ ->
assert false (* We don't need to declare variables in Python *)
| SLocalDef { name = v; expr = e; _ } | SLocalInit { name = v; expr = e; _ }
->
Format.fprintf fmt "@[<hov 2>%a <- %a@]" format_var (Mark.remove v)
(format_expression ctx) e
| STryWEmpty { try_block = try_b; with_block = catch_b } ->
Format.fprintf fmt
(* TODO escape dummy__arg*)
"@[<hov 2>tryCatch(@[<hov 2>{@;\
%a@;\
}@],@;\
catala_empty_error() = function(dummy__arg) @[<hov 2>{@;\
%a@;\
}@])@]"
(format_block ctx) try_b (format_block ctx) catch_b
| SRaiseEmpty -> Format.pp_print_string fmt "stop(catala_empty_error())"
| SFatalError err ->
Format.fprintf fmt "@[<hov 2>stop(%a)@]" format_error (err, Mark.get s)
| SIfThenElse { if_expr = cond; then_block = b1; else_block = b2 } ->
Format.fprintf fmt
"@[<hov 2>if (%a) {@\n%a@]@\n@[<hov 2>} else {@\n%a@]@\n}"
(format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch
{
switch_expr = e1;
enum_name = e_name;
switch_cases =
[
{ case_block = case_none; _ };
{ case_block = case_some; payload_var_name = case_some_var; _ };
];
_;
}
when EnumName.equal e_name Expr.option_enum ->
(* We translate the option type with an overloading by Python's [None] *)
let tmp_var = VarName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt
"%a <- %a@\n\
@[<hov 2>if (is.null(%a)) {@\n\
%a@]@\n\
@[<hov 2>} else {@\n\
%a = %a@\n\
%a@]@\n\
}"
format_var tmp_var (format_expression ctx) e1 format_var tmp_var
(format_block ctx) case_none format_var case_some_var format_var tmp_var
(format_block ctx) case_some
| SSwitch { switch_expr = e1; enum_name = e_name; switch_cases = cases; _ } ->
let cases =
List.map2
(fun x (cons, _) -> x, cons)
cases
(EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums))
in
let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "@[<hov 2>%a <- %a@]@\n@[<hov 2>if %a@]@\n}" format_var
tmp_var (format_expression ctx) e1
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 2>} else if ")
(fun fmt ({ case_block; payload_var_name; _ }, cons_name) ->
Format.fprintf fmt "(%a@code == \"%a\") {@\n%a <- %a@value@\n%a"
format_var tmp_var format_enum_cons_name cons_name format_var
payload_var_name format_var tmp_var (format_block ctx) case_block))
cases
| SReturn e1 ->
Format.fprintf fmt "@[<hov 2>return(%a)@]" (format_expression ctx)
(e1, Mark.get s)
| SAssert e1 ->
let pos = Mark.get s in
Format.fprintf fmt
"@[<hov 2>if (!(%a)) {@\n\
stop(catala_assertion_failure(@[<hov 0>catala_position(@[<hov \
0>filename=\"%s\",@ start_line=%d,@ start_column=%d,@ end_line=%d,@ \
end_column=%d,@ law_headings=@[<hv>%a@])@])@])@]@\n\
}"
(format_expression ctx)
(e1, Mark.get s)
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)
| SSpecialOp _ -> failwith "should not happen"
and format_block (ctx : decl_ctx) (fmt : Format.formatter) (b : block) : unit =
Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(format_statement ctx) fmt
(List.filter
(fun s -> match Mark.remove s with SLocalDecl _ -> false | _ -> true)
b)
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
let fields = StructField.Map.bindings struct_fields in
Format.fprintf fmt
"@[<hov 2>setClass(@,\
\"catala_struct_%a\",@;\
representation@[<hov 2>(%a)@]@\n\
)@]"
format_struct_name struct_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@;")
(fun fmt (struct_field, typ) ->
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_typ ~inside_comment:false)
typ))
fields
in
let format_enum_decl fmt (enum_name, enum_cons) =
if EnumConstructor.Map.is_empty enum_cons then
failwith "no constructors in the enum"
else
Format.fprintf fmt
"# Enum cases: %a@\n\
@[<hov 2>setClass(@,\
\"catala_enum_%a\",@;\
representation@[<hov 2>(code =@;\
\"character\",@;\
value =@;\
\"ANY\")@])@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "\"%a\" (%a)" format_enum_cons_name enum_cons
(format_typ ~inside_comment:false)
enum_cons_type))
(EnumConstructor.Map.bindings enum_cons)
format_enum_name enum_name
in
let is_in_type_ordering s =
List.exists
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Enum _ -> false
| Scopelang.Dependency.TVertex.Struct s' -> s = s')
type_ordering
in
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
in
List.iter
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
Format.fprintf fmt "%a@\n@\n" format_struct_decl
(s, StructName.Map.find s ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
Format.fprintf fmt "%a@\n@\n" format_enum_decl
(e, EnumName.Map.find e ctx.ctx_enums))
(type_ordering @ scope_structs)
let format_program
(fmt : Format.formatter)
(p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(* We disable the style flag in order to enjoy formatting from the
pretty-printers of Dcalc and Lcalc but without the color terminal
markers. *)
Format.fprintf fmt
"@[<v># This file has been generated by the Catala compiler, do not edit!@,\
@,\
library(catalaRuntime)@,\
@,\
@[<v>%a@]@,\
@,\
%a@]@?"
(format_ctx type_ordering) p.ctx.decl_ctx
(Format.pp_print_list ~pp_sep:Format.pp_print_newline (fun fmt -> function
| SVar { var; expr; typ = _ } ->
Format.fprintf fmt "@[<hv 2>%a <- (@,%a@,@])@," format_var var
(format_expression p.ctx.decl_ctx)
expr
| SFunc { var; func }
| SScope { scope_body_var = var; scope_body_func = func; _ } ->
let { Ast.func_params; Ast.func_body; _ } = func in
Format.fprintf fmt "@[<hv 2>%a <- function(@\n%a) {@\n%a@]@\n}@,"
format_func_name var
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n,@;")
(fun fmt (var, typ) ->
Format.fprintf fmt "%a# (%a)@\n" format_var (Mark.remove var)
(format_typ ~inside_comment:true)
typ))
func_params
(format_block p.ctx.decl_ctx)
func_body))
p.code_items

View File

@ -1,21 +0,0 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2021 Inria, contributor:
Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Formats a lambda calculus program into a valid R program *)
val format_program :
Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit
(** Usage [format_program fmt p type_dependencies_ordering] *)

View File

@ -40,13 +40,13 @@ let rec locations_used (e : 'm expr) : LocationSet.t =
type 'm rule =
| ScopeVarDefinition of {
var : ScopeVar.t Mark.pos;
var : (ScopeVar.t, Pos.t list) Mark.ed;
typ : typ;
io : Desugared.Ast.io;
e : 'm expr;
}
| SubScopeVarDefinition of {
var : ScopeVar.t Mark.pos;
var : (ScopeVar.t, Pos.t list) Mark.ed;
var_within_origin_scope : ScopeVar.t;
typ : typ;
e : 'm expr;
@ -67,7 +67,7 @@ type 'm scope_decl = {
}
type 'm program = {
program_module_name : ModuleName.t option;
program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx;
program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t;
program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t;
@ -145,3 +145,5 @@ let type_program (type m) (prg : m program) : typed program =
prg.program_scopes
in
{ prg with program_topdefs; program_scopes }
let type_program prg = Message.with_delayed_errors (fun () -> type_program prg)

View File

@ -33,16 +33,14 @@ val locations_used : 'm expr -> LocationSet.t
type 'm rule =
| ScopeVarDefinition of {
var : ScopeVar.t Mark.pos;
var : ScopeVar.t * Pos.t list;
(** Scope variable and its list of definitions' positions *)
typ : typ;
io : Desugared.Ast.io;
e : 'm expr;
}
| SubScopeVarDefinition of {
var : ScopeVar.t Mark.pos; (** Variable within the current scope *)
(* scope: ScopeVar.t Mark.pos; (\** Variable pointing to the *\) *)
(* origin_var: ScopeVar.t Mark.pos;
* reentrant: bool; *)
var : ScopeVar.t * Pos.t list; (** Variable within the current scope *)
var_within_origin_scope : ScopeVar.t;
typ : typ; (* non-thunked at this point for reentrant vars *)
e : 'm expr;
@ -63,12 +61,13 @@ type 'm scope_decl = {
}
type 'm program = {
program_module_name : ModuleName.t option;
program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx;
program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t;
(* Using [nil] here ensure that program interfaces don't contain any
expressions. They won't contain any rules or topdefs, but will still have
the scope signatures needed to respect the call convention *)
expressions. They won't contain any rules or topdef implementations, but
will still have the scope signatures needed to respect the call
convention *)
program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t;
program_topdefs : ('m expr * typ) TopdefName.Map.t;
program_lang : Global.backend_lang;

View File

@ -42,9 +42,7 @@ module SVertex = struct
| Topdef g1, Topdef g2 -> TopdefName.equal g1 g2
| (Scope _ | Topdef _), _ -> false
let hash = function
| Scope s -> ScopeName.hash s
| Topdef g -> TopdefName.hash g
let hash = function Scope s -> ScopeName.id s | Topdef g -> TopdefName.id g
let format ppf = function
| Scope s -> ScopeName.format ppf s
@ -206,7 +204,9 @@ module TVertex = struct
type t = Struct of StructName.t | Enum of EnumName.t
let hash x =
match x with Struct x -> StructName.hash x | Enum x -> EnumName.hash x
match x with
| Struct x -> StructName.id x
| Enum x -> Hashtbl.hash (`Enum (EnumName.id x))
let compare x y =
match x, y with

View File

@ -637,17 +637,20 @@ let translate_rule
(exc_graphs :
Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) = function
| Desugared.Dependency.Vertex.Var (var, state) -> (
let pos = Mark.get (ScopeVar.get_info var) in
(* TODO: this may point to the place where the variable was declared instead
of the binding in the definition being explored. Needs double-checking
and maybe adding more position information *)
let decl_pos = Mark.get (ScopeVar.get_info var) in
let scope_def_key =
{
D.ScopeDef.scope_def_var_within_scope = var, pos;
D.ScopeDef.scope_def_var_within_scope = var, Pos.no_pos;
scope_def_kind = D.ScopeDef.ScopeVarKind state;
}
in
let scope_def = D.ScopeDef.Map.find scope_def_key scope.scope_defs in
let all_def_pos =
List.map
(fun r -> Mark.get (RuleName.get_info r))
(RuleName.Map.keys scope_def.scope_def_rules)
in
let scope_def = D.ScopeDef.Map.find scope_def_key scope.scope_defs in
match ScopeVar.Map.find_opt var scope.scope_sub_scopes with
| None -> (
let var_def = scope_def.D.scope_def_rules in
@ -660,6 +663,12 @@ let translate_rule
| OnlyInput -> []
(* we do not provide any definition for an input-only variable *)
| _ ->
let scope_def_key =
{
D.ScopeDef.scope_def_var_within_scope = var, decl_pos;
scope_def_kind = D.ScopeDef.ScopeVarKind state;
}
in
let expr_def =
translate_def ctx scope_def_key var_def var_params var_typ
scope_def.D.scope_def_io
@ -675,7 +684,7 @@ let translate_rule
[
Ast.ScopeVarDefinition
{
var = scope_var, pos;
var = Mark.add all_def_pos scope_var;
typ = var_typ;
io = scope_def.D.scope_def_io;
e = Expr.unbox expr_def;
@ -767,7 +776,8 @@ let translate_rule
scope.scope_defs ScopeVar.Map.empty
in
let subscope_expr =
Expr.escopecall ~scope:subscope ~args:subscope_params (Untyped { pos })
Expr.escopecall ~scope:subscope ~args:subscope_params
(Untyped { pos = decl_pos })
in
assert (RuleName.Map.is_empty scope_def.D.scope_def_rules);
(* The subscope will be defined by its inputs, it's not supposed to have
@ -781,7 +791,7 @@ let translate_rule
let subscope_def =
Ast.ScopeVarDefinition
{
var = subscope_var_dcalc, pos;
var = Mark.add all_def_pos subscope_var_dcalc;
typ =
( TStruct scope_info.out_struct_name,
Mark.get (ScopeVar.get_info var) );
@ -1046,8 +1056,9 @@ let translate_program
let program_topdefs =
TopdefName.Map.mapi
(fun id -> function
| Some e, ty -> Expr.unbox (translate_expr ctx e), ty
| None, (_, pos) ->
| { D.topdef_expr = Some e; topdef_type = ty; topdef_visibility = _ } ->
Expr.unbox (translate_expr ctx e), ty
| { D.topdef_expr = None; topdef_type = _, pos; _ } ->
Message.error ~pos "No definition found for %a" TopdefName.format id)
desugared.program_root.module_topdefs
in
@ -1057,8 +1068,7 @@ let translate_program
desugared.D.program_root.module_scopes
in
{
Ast.program_module_name =
Option.map ModuleName.fresh desugared.D.program_module_name;
Ast.program_module_name = desugared.D.program_module_name;
Ast.program_topdefs;
Ast.program_scopes;
Ast.program_ctx = ctx.decl_ctx;

View File

@ -20,6 +20,13 @@ type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list =
| Last of 'last
| Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder
let rec to_seq = function
| Last () -> Seq.empty
| Cons (item, next_bind) ->
fun () ->
let v, next = Bindlib.unbind next_bind in
Seq.Cons ((v, item), to_seq next)
let rec last = function
| Last e -> e
| Cons (_, bnd) ->

View File

@ -30,6 +30,7 @@ type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list =
| Last of 'last
| Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder
val to_seq : (((_, _) gexpr as 'e), 'elt, unit) t -> ('e Var.t * 'elt) Seq.t
val last : (_, _, 'a) t -> 'a
val iter : f:('e Var.t -> 'elt -> unit) -> ('e, 'elt, 'last) t -> 'last
val find : f:('elt -> 'a option) -> (_, 'elt, _) t -> 'a

View File

@ -137,7 +137,6 @@ type desugared =
; explicitScopes : yes
; assertions : no
; defaultTerms : yes
; exceptions : no
; custom : no >
(* Technically, desugared before name resolution has [syntacticNames: yes;
resolvedNames: no], and after name resolution has the opposite; but the
@ -158,7 +157,6 @@ type scopelang =
; explicitScopes : yes
; assertions : no
; defaultTerms : yes
; exceptions : no
; custom : no >
type dcalc =
@ -172,7 +170,6 @@ type dcalc =
; explicitScopes : no
; assertions : yes
; defaultTerms : yes
; exceptions : no
; custom : no >
type lcalc =
@ -186,7 +183,6 @@ type lcalc =
; explicitScopes : no
; assertions : yes
; defaultTerms : no
; exceptions : yes
; custom : no >
type 'a any = < .. > as 'a
@ -205,12 +201,11 @@ type dcalc_lcalc_features =
; assertions : yes >
(** Features that are common to Dcalc and Lcalc *)
type ('a, 'b) dcalc_lcalc =
< dcalc_lcalc_features ; defaultTerms : 'a ; exceptions : 'b ; custom : no >
type 'd dcalc_lcalc = < dcalc_lcalc_features ; defaultTerms : 'd ; custom : no >
(** This type regroups Dcalc and Lcalc ASTs. *)
type ('a, 'b, 'c) interpr_kind =
< dcalc_lcalc_features ; defaultTerms : 'a ; exceptions : 'b ; custom : 'c >
type ('d, 'c) interpr_kind =
< dcalc_lcalc_features ; defaultTerms : 'd ; custom : 'c >
(** This type corresponds to the types handled by the interpreter: it regroups
Dcalc and Lcalc ASTs and may have custom terms *)
@ -222,11 +217,11 @@ type typ = naked_typ Mark.pos
and naked_typ =
| TLit of typ_lit
| TArrow of typ list * typ
| TTuple of typ list
| TStruct of StructName.t
| TEnum of EnumName.t
| TOption of typ
| TArrow of typ list * typ
| TArray of typ
| TDefault of typ
| TAny
@ -371,8 +366,7 @@ module Op = struct
(* * polymorphic *)
| Reduce : < polymorphic ; .. > t
| Fold : < polymorphic ; .. > t
| HandleDefault : < polymorphic ; .. > t
| HandleDefaultOpt : < polymorphic ; .. > t
| HandleExceptions : < polymorphic ; .. > t
end
type 'a operator = 'a Op.t
@ -562,13 +556,6 @@ and ('a, 'b, 'm) base_gexpr =
| EErrorOnEmpty :
('a, 'm) gexpr
-> ('a, < defaultTerms : yes ; .. >, 'm) base_gexpr
(* Lambda calculus with exceptions *)
| ERaiseEmpty : ('a, < exceptions : yes ; .. >, 'm) base_gexpr
| ECatchEmpty : {
body : ('a, 'm) gexpr;
handler : ('a, 'm) gexpr;
}
-> ('a, < exceptions : yes ; .. >, 'm) base_gexpr
(* Only used during evaluation *)
| ECustom : {
obj : Obj.t;
@ -673,8 +660,14 @@ type scope_info = {
out_struct_fields : StructField.t ScopeVar.Map.t;
}
type module_intf_id = { hash : Hash.t; is_external : bool }
type module_tree_node = { deps : module_tree; intf_id : module_intf_id }
and module_tree = module_tree_node ModuleName.Map.t
(** In practice, this is a DAG: beware of repeated names *)
type module_tree = M of module_tree ModuleName.Map.t [@@caml.unboxed]
type visibility = Private | Public
type decl_ctx = {
ctx_enums : enum_ctx;
@ -693,5 +686,5 @@ type 'e program = {
decl_ctx : decl_ctx;
code_items : 'e code_item_list;
lang : Global.backend_lang;
module_name : ModuleName.t option;
module_name : (ModuleName.t * module_intf_id) option;
}

View File

@ -159,10 +159,6 @@ let eifthenelse cond etrue efalse =
let eerroronempty e1 = Box.app1 e1 @@ fun e1 -> EErrorOnEmpty e1
let eempty mark = Mark.add mark (Bindlib.box EEmpty)
let eraiseempty mark = Mark.add mark (Bindlib.box ERaiseEmpty)
let ecatchempty body handler =
Box.app2 body handler @@ fun body handler -> ECatchEmpty { body; handler }
let ecustom obj targs tret mark =
Mark.add mark (Bindlib.box (ECustom { obj; targs; tret }))
@ -347,8 +343,6 @@ let map
| EPureDefault e1 -> epuredefault (f e1) m
| EEmpty -> eempty m
| EErrorOnEmpty e1 -> eerroronempty (f e1) m
| ECatchEmpty { body; handler } -> ecatchempty (f body) (f handler) m
| ERaiseEmpty -> eraiseempty m
| ELocation loc -> elocation loc m
| EStruct { name; fields } ->
let fields = StructField.Map.map f fields in
@ -388,9 +382,7 @@ let shallow_fold
(acc : 'acc) : 'acc =
let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in
match Mark.remove e with
| ELit _ | EVar _ | EFatalError _ | EExternal _ | ERaiseEmpty | ELocation _
| EEmpty ->
acc
| ELit _ | EVar _ | EFatalError _ | EExternal _ | ELocation _ | EEmpty -> acc
| EApp { f = e; args; _ } -> acc |> f e |> lfold args
| EAppOp { args; _ } -> acc |> lfold args
| EArray args -> acc |> lfold args
@ -405,7 +397,6 @@ let shallow_fold
| EDefault { excepts; just; cons } -> acc |> lfold excepts |> f just |> f cons
| EPureDefault e -> acc |> f e
| EErrorOnEmpty e -> acc |> f e
| ECatchEmpty { body; handler } -> acc |> f body |> f handler
| EStruct { fields; _ } -> acc |> StructField.Map.fold (fun _ -> f) fields
| EDStructAmend { e; fields; _ } ->
acc |> f e |> Ident.Map.fold (fun _ -> f) fields
@ -492,11 +483,6 @@ let map_gather
| EErrorOnEmpty e ->
let acc, e = f e in
acc, eerroronempty e m
| ECatchEmpty { body; handler } ->
let acc1, body = f body in
let acc2, handler = f handler in
join acc1 acc2, ecatchempty body handler m
| ERaiseEmpty -> acc, eraiseempty m
| ELocation loc -> acc, elocation loc m
| EStruct { name; fields } ->
let acc, fields =
@ -573,7 +559,7 @@ let untype e = map_marks ~f:(fun m -> Untyped { pos = mark_pos m }) e
let is_value (type a) (e : (a, _) gexpr) =
match Mark.remove e with
| ELit _ | EAbs _ | ERaiseEmpty | ECustom _ | EExternal _ -> true
| ELit _ | EAbs _ | ECustom _ | EExternal _ -> true
| _ -> false
let equal_lit (l1 : lit) (l2 : lit) =
@ -705,10 +691,6 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
equal if1 if2 && equal then1 then2 && equal else1 else2
| EEmpty, EEmpty -> true
| EErrorOnEmpty e1, EErrorOnEmpty e2 -> equal e1 e2
| ERaiseEmpty, ERaiseEmpty -> true
| ( ECatchEmpty { body = etry1; handler = ewith1 },
ECatchEmpty { body = etry2; handler = ewith2 } ) ->
equal etry1 etry2 && equal ewith1 ewith2
| ELocation l1, ELocation l2 ->
equal_location (Mark.add Pos.no_pos l1) (Mark.add Pos.no_pos l2)
| ( EStruct { name = s1; fields = fields1 },
@ -753,10 +735,9 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
Type.equal_list targs1 targs2 && Type.equal tret1 tret2 && obj1 == obj2
| ( ( EVar _ | EExternal _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _
| EAbs _ | EApp _ | EAppOp _ | EAssert _ | EFatalError _ | EDefault _
| EPureDefault _ | EIfThenElse _ | EEmpty | EErrorOnEmpty _ | ERaiseEmpty
| ECatchEmpty _ | ELocation _ | EStruct _ | EDStructAmend _
| EDStructAccess _ | EStructAccess _ | EInj _ | EMatch _ | EScopeCall _
| ECustom _ ),
| EPureDefault _ | EIfThenElse _ | EEmpty | EErrorOnEmpty _ | ELocation _
| EStruct _ | EDStructAmend _ | EDStructAccess _ | EStructAccess _
| EInj _ | EMatch _ | EScopeCall _ | ECustom _ ),
_ ) ->
false
@ -860,11 +841,6 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
| EEmpty, EEmpty -> 0
| EErrorOnEmpty e1, EErrorOnEmpty e2 ->
compare e1 e2
| ERaiseEmpty, ERaiseEmpty -> 0
| ECatchEmpty {body=etry1; handler=ewith1},
ECatchEmpty {body=etry2; handler=ewith2} ->
compare etry1 etry2 @@< fun () ->
compare ewith1 ewith2
| ECustom _, _ | _, ECustom _ ->
(* fixme: ideally this would be forbidden by typing *)
invalid_arg "Custom block comparison"
@ -891,9 +867,7 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
| EDefault _, _ -> -1 | _, EDefault _ -> 1
| EPureDefault _, _ -> -1 | _, EPureDefault _ -> 1
| EEmpty , _ -> -1 | _, EEmpty -> 1
| EErrorOnEmpty _, _ -> -1 | _, EErrorOnEmpty _ -> 1
| ERaiseEmpty, _ -> -1 | _, ERaiseEmpty -> 1
| ECatchEmpty _, _ -> . | _, ECatchEmpty _ -> .
| EErrorOnEmpty _, _ -> . | _, EErrorOnEmpty _ -> .
let rec free_vars : ('a, 't) gexpr -> ('a, 't) gexpr Var.Set.t = function
| EVar v, _ -> Var.Set.singleton v
@ -1024,8 +998,6 @@ let rec size : type a. (a, 't) gexpr -> int =
(fun acc except -> acc + size except)
(1 + size just + size cons)
excepts
| ERaiseEmpty -> 1
| ECatchEmpty { body; handler } -> 1 + size body + size handler
| ELocation _ -> 1
| EStruct { fields; _ } ->
StructField.Map.fold (fun _ e acc -> acc + 1 + size e) fields 0

View File

@ -117,13 +117,6 @@ val eerroronempty :
'm mark ->
((< defaultTerms : yes ; .. > as 'a), 'm) boxed_gexpr
val ecatchempty :
('a, 'm) boxed_gexpr ->
('a, 'm) boxed_gexpr ->
'm mark ->
((< exceptions : yes ; .. > as 'a), 'm) boxed_gexpr
val eraiseempty : 'm mark -> (< exceptions : yes ; .. >, 'm) boxed_gexpr
val elocation : 'a glocation -> 'm mark -> ((< .. > as 'a), 'm) boxed_gexpr
val estruct :

View File

@ -422,36 +422,7 @@ let rec evaluate_operator
ELit (LBool (o_eq_dat_dat x y))
| Eq_dur_dur, [(ELit (LDuration x), _); (ELit (LDuration y), _)] ->
ELit (LBool (o_eq_dur_dur (rpos ()) x y))
| HandleDefault, [(EArray excepts, _); just; cons] -> (
(* This case is for lcalc with exceptions: we rely OCaml exception handling
here *)
match
List.filter_map
(fun e ->
try Some (evaluate_expr (Expr.unthunk_term_nobox e))
with Runtime.Empty -> None)
excepts
with
| [] -> (
let just = evaluate_expr (Expr.unthunk_term_nobox just) in
match Mark.remove just with
| ELit (LBool true) ->
Mark.remove (evaluate_expr (Expr.unthunk_term_nobox cons))
| ELit (LBool false) -> raise Runtime.Empty
| _ ->
Message.error ~pos
"Default justification has not been reduced to a boolean at@ \
evaluation@ (should not happen if the term was well-typed@\n\
%a@."
Expr.format just)
| [e] -> Mark.remove e
| es ->
raise
Runtime.(
Error
(Conflict, List.map (fun e -> Expr.pos_to_runtime (Expr.pos e)) es))
)
| HandleDefaultOpt, [(EArray exps, _); justification; conclusion] -> (
| HandleExceptions, [(EArray exps, _)] -> (
let valid_exceptions =
ListLabels.filter exps ~f:(function
| EInj { name; cons; _ }, _ when EnumName.equal name Expr.option_enum ->
@ -459,28 +430,9 @@ let rec evaluate_operator
| _ -> err ())
in
match valid_exceptions with
| [] -> (
let e = evaluate_expr (Expr.unthunk_term_nobox justification) in
match Mark.remove e with
| ELit (LBool true) ->
Mark.remove (evaluate_expr (Expr.unthunk_term_nobox conclusion))
| ELit (LBool false) ->
EInj
{
name = Expr.option_enum;
cons = Expr.none_constr;
e = Mark.copy justification (ELit LUnit);
}
| EInj { name; cons; e }
when EnumName.equal name Expr.option_enum
&& EnumConstructor.equal cons Expr.none_constr ->
EInj
{
name = Expr.option_enum;
cons = Expr.none_constr;
e = Mark.copy e (ELit LUnit);
}
| _ -> err ())
| [] ->
EInj
{ name = Expr.option_enum; cons = Expr.none_constr; e = ELit LUnit, m }
| [((EInj { cons; name; _ } as e), _)]
when EnumName.equal name Expr.option_enum
&& EnumConstructor.equal cons Expr.some_constr ->
@ -501,23 +453,22 @@ let rec evaluate_operator
| Lte_mon_mon | Lte_dat_dat | Lte_dur_dur | Gt_int_int | Gt_rat_rat
| Gt_mon_mon | Gt_dat_dat | Gt_dur_dur | Gte_int_int | Gte_rat_rat
| Gte_mon_mon | Gte_dat_dat | Gte_dur_dur | Eq_int_int | Eq_rat_rat
| Eq_mon_mon | Eq_dat_dat | Eq_dur_dur | HandleDefault | HandleDefaultOpt
),
| Eq_mon_mon | Eq_dat_dat | Eq_dur_dur | HandleExceptions ),
_ ) ->
err ()
(* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and*
of the OCaml runtime *)
let rec runtime_to_val :
type d e.
type d.
(decl_ctx ->
((d, e, _) interpr_kind, 'm) gexpr ->
((d, e, _) interpr_kind, 'm) gexpr) ->
((d, _) interpr_kind, 'm) gexpr ->
((d, _) interpr_kind, 'm) gexpr) ->
decl_ctx ->
'm mark ->
typ ->
Obj.t ->
(((d, e, yes) interpr_kind as 'a), 'm) gexpr =
(((d, yes) interpr_kind as 'a), 'm) gexpr =
fun eval_expr ctx m ty o ->
let m = Expr.map_ty (fun _ -> ty) m in
match Mark.remove ty with
@ -566,7 +517,11 @@ let rec runtime_to_val :
let e = runtime_to_val eval_expr ctx m ty (Obj.field o 0) in
EInj { name = Expr.option_enum; cons = Expr.some_constr; e }, m
| _ -> assert false)
| TClosureEnv -> assert false
| TClosureEnv ->
(* By construction, a closure environment can only be consumed from the same
scope where it was built (compiled or not) ; for this reason, we can
safely avoid converting in depth here *)
Obj.obj o, m
| TArray ty ->
( EArray
(List.map
@ -574,21 +529,26 @@ let rec runtime_to_val :
(Array.to_list (Obj.obj o))),
m )
| TArrow (targs, tret) -> ECustom { obj = o; targs; tret }, m
| TDefault ty -> runtime_to_val eval_expr ctx m ty o
| TDefault ty -> (
(* This case is only valid for ASTs including default terms; but the typer
isn't aware so we need some additional dark arts. *)
match (Obj.obj o : 'a Runtime.Eoption.t) with
| Runtime.Eoption.ENone () -> Obj.magic EEmpty, m
| Runtime.Eoption.ESome o -> Obj.magic (runtime_to_val eval_expr ctx m ty o)
)
| TAny -> assert false
and val_to_runtime :
type d e.
type d.
(decl_ctx ->
((d, e, _) interpr_kind, 'm) gexpr ->
((d, e, _) interpr_kind, 'm) gexpr) ->
((d, _) interpr_kind, 'm) gexpr ->
((d, _) interpr_kind, 'm) gexpr) ->
decl_ctx ->
typ ->
((d, e, _) interpr_kind, 'm) gexpr ->
((d, _) interpr_kind, 'm) gexpr ->
Obj.t =
fun eval_expr ctx ty v ->
match Mark.remove ty, Mark.remove v with
| _, EEmpty -> raise Runtime.Empty
| TLit TBool, ELit (LBool b) -> Obj.repr b
| TLit TUnit, ELit LUnit -> Obj.repr ()
| TLit TInt, ELit (LInt i) -> Obj.repr i
@ -655,18 +615,27 @@ and val_to_runtime :
curry (runtime_to_val eval_expr ctx m targ x :: acc) targs)
in
curry [] targs
| TDefault ty, _ -> val_to_runtime eval_expr ctx ty v
| TDefault ty, _ -> (
match v with
| EEmpty, _ -> Obj.repr (Runtime.Eoption.ENone ())
| EPureDefault e, _ | e ->
Obj.repr (Runtime.Eoption.ESome (val_to_runtime eval_expr ctx ty e)))
| TClosureEnv, v ->
(* By construction, a closure environment can only be consumed from the same
scope where it was built (compiled or not) ; for this reason, we can
safely avoid converting in depth here *)
Obj.repr v
| _ ->
Message.error ~internal:true
"Could not convert value of type %a@ to@ runtime:@ %a" (Print.typ ctx) ty
Expr.format v
let rec evaluate_expr :
type d e.
type d.
decl_ctx ->
Global.backend_lang ->
((d, e, yes) interpr_kind, 't) gexpr ->
((d, e, yes) interpr_kind, 't) gexpr =
((d, yes) interpr_kind, 't) gexpr ->
((d, yes) interpr_kind, 't) gexpr =
fun ctx lang e ->
let m = Mark.get e in
let pos = Expr.mark_pos m in
@ -866,18 +835,14 @@ let rec evaluate_expr :
in
raise Runtime.(Error (Conflict, poslist)))
| EPureDefault e -> evaluate_expr ctx lang e
| ERaiseEmpty -> raise Runtime.Empty
| ECatchEmpty { body; handler } -> (
try evaluate_expr ctx lang body
with Runtime.Empty -> evaluate_expr ctx lang handler)
| _ -> .
and partially_evaluate_expr_for_assertion_failure_message :
type d e.
type d.
decl_ctx ->
Global.backend_lang ->
((d, e, yes) interpr_kind, 't) gexpr ->
((d, e, yes) interpr_kind, 't) gexpr =
((d, yes) interpr_kind, 't) gexpr ->
((d, yes) interpr_kind, 't) gexpr =
fun ctx lang e ->
(* Here we want to print an expression that explains why an assertion has
failed. Since assertions have type [bool] and are usually constructed with
@ -912,11 +877,11 @@ and partially_evaluate_expr_for_assertion_failure_message :
| _ -> evaluate_expr ctx lang e
let evaluate_expr_trace :
type d e.
type d.
decl_ctx ->
Global.backend_lang ->
((d, e, yes) interpr_kind, 't) gexpr ->
((d, e, yes) interpr_kind, 't) gexpr =
((d, yes) interpr_kind, 't) gexpr ->
((d, yes) interpr_kind, 't) gexpr =
fun ctx lang e ->
Fun.protect
(fun () -> evaluate_expr ctx lang e)
@ -928,11 +893,11 @@ let evaluate_expr_trace :
(Runtime.EventParser.parse_raw_events trace)] fais here, check why *))
let evaluate_expr_safe :
type d e.
type d.
decl_ctx ->
Global.backend_lang ->
((d, e, yes) interpr_kind, 't) gexpr ->
((d, e, yes) interpr_kind, 't) gexpr =
((d, yes) interpr_kind, 't) gexpr ->
((d, yes) interpr_kind, 't) gexpr =
fun ctx lang e ->
try evaluate_expr_trace ctx lang e
with Runtime.Error (err, rpos) ->
@ -944,9 +909,9 @@ let evaluate_expr_safe :
(* Typing shenanigan to add custom terms to the AST type. *)
let addcustom e =
let rec f :
type c d e.
((d, e, c) interpr_kind, 't) gexpr ->
((d, e, yes) interpr_kind, 't) gexpr boxed = function
type c d.
((d, c) interpr_kind, 't) gexpr -> ((d, yes) interpr_kind, 't) gexpr boxed
= function
| (ECustom _, _) as e -> Expr.map ~f e
| EAppOp { op; tys; args }, m ->
Expr.eappop ~tys ~args:(List.map f args) ~op:(Operator.translate op) m
@ -954,8 +919,6 @@ let addcustom e =
| (EPureDefault _, _) as e -> Expr.map ~f e
| (EEmpty, _) as e -> Expr.map ~f e
| (EErrorOnEmpty _, _) as e -> Expr.map ~f e
| (ECatchEmpty _, _) as e -> Expr.map ~f e
| (ERaiseEmpty, _) as e -> Expr.map ~f e
| ( ( EAssert _ | EFatalError _ | ELit _ | EApp _ | EArray _ | EVar _
| EExternal _ | EAbs _ | EIfThenElse _ | ETuple _ | ETupleAccess _
| EInj _ | EStruct _ | EStructAccess _ | EMatch _ ),
@ -965,8 +928,8 @@ let addcustom e =
in
let open struct
external id :
(('d, 'e, 'c) interpr_kind, 't) gexpr ->
(('d, 'e, yes) interpr_kind, 't) gexpr = "%identity"
(('d, 'c) interpr_kind, 't) gexpr -> (('d, yes) interpr_kind, 't) gexpr
= "%identity"
end in
if false then Expr.unbox (f e)
(* We keep the implementation as a typing proof, but bypass the AST
@ -976,9 +939,9 @@ let addcustom e =
let delcustom e =
let rec f :
type c d e.
((d, e, c) interpr_kind, 't) gexpr ->
((d, e, no) interpr_kind, 't) gexpr boxed = function
type c d.
((d, c) interpr_kind, 't) gexpr -> ((d, no) interpr_kind, 't) gexpr boxed
= function
| ECustom _, _ -> invalid_arg "Custom term remaining in evaluated term"
| EAppOp { op; args; tys }, m ->
Expr.eappop ~tys ~args:(List.map f args) ~op:(Operator.translate op) m
@ -986,8 +949,6 @@ let delcustom e =
| (EPureDefault _, _) as e -> Expr.map ~f e
| (EEmpty, _) as e -> Expr.map ~f e
| (EErrorOnEmpty _, _) as e -> Expr.map ~f e
| (ECatchEmpty _, _) as e -> Expr.map ~f e
| (ERaiseEmpty, _) as e -> Expr.map ~f e
| ( ( EAssert _ | EFatalError _ | ELit _ | EApp _ | EArray _ | EVar _
| EExternal _ | EAbs _ | EIfThenElse _ | ETuple _ | ETupleAccess _
| EInj _ | EStruct _ | EStructAccess _ | EMatch _ ),
@ -1018,22 +979,13 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list
(fun ty ->
match Mark.remove ty with
| TArrow (ty_in, (TOption _, _)) ->
(* Context args may return an option if avoid_exceptions is on *)
(* Context args should return an option *)
Expr.make_abs
(Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
(Expr.einj ~e:(Expr.elit LUnit mark_e) ~cons:Expr.none_constr
~name:Expr.option_enum mark_e
: (_, _) boxed_gexpr)
ty_in pos
| TArrow (ty_in, ty_out) ->
(* Or a default term (translated into a plain one if it is off) *)
(* Note: this might catch non-context args, but since the
compilation to lcalc strips the default around [ty_out] we can't
tell with just this info. *)
Expr.make_abs
(Array.of_list @@ List.map (fun _ -> Var.make "_") ty_in)
(Expr.eraiseempty (Expr.with_ty mark_e ty_out))
ty_in (Expr.mark_pos mark_e)
| TTuple ((TArrow (ty_in, (TOption _, _)), _) :: _) ->
(* ... or a closure if closure conversion is enabled *)
Expr.make_tuple
@ -1155,29 +1107,57 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list
reflect that. *)
let evaluate_expr ctx lang e = evaluate_expr ctx lang (addcustom e)
let load_runtime_modules prg =
let load m =
let load_runtime_modules ~hashf prg =
let load (mname, intf_id) =
let hash = hashf intf_id.hash in
let expect_hash =
if intf_id.is_external then Hash.external_placeholder
else Hash.to_string hash
in
let obj_file =
Dynlink.adapt_filename
File.(Pos.get_file (Mark.get (ModuleName.get_info m)) -.- "cmo")
File.(Pos.get_file (Mark.get (ModuleName.get_info mname)) -.- "cmo")
in
if not (Sys.file_exists obj_file) then
(if not (Sys.file_exists obj_file) then
Message.error
~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here")
~pos:(Mark.get (ModuleName.get_info mname))
"Compiled OCaml object %a@ not@ found.@ Make sure it has been \
suitably compiled."
File.format obj_file
else
try Dynlink.loadfile obj_file
with Dynlink.Error dl_err ->
Message.error
"While loading compiled module from %a:@;<1 2>@[<hov>%a@]"
File.format obj_file Format.pp_print_text
(Dynlink.error_message dl_err));
match Runtime.check_module (ModuleName.to_string mname) expect_hash with
| Ok () -> ()
| Error bad_hash ->
Message.debug
"Module hash mismatch for %a:@ @[<v>Expected: %a@,Found: %a@]"
ModuleName.format mname Hash.format hash
(fun ppf h ->
try Hash.format ppf (Hash.of_string h)
with Failure _ ->
if h = Hash.external_placeholder then
Format.fprintf ppf "@{<cyan>%s@}" Hash.external_placeholder
else Format.fprintf ppf "@{<red><invalid>@}")
bad_hash;
Message.error
~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here")
~pos:(Mark.get (ModuleName.get_info m))
"Compiled OCaml object %a@ not@ found.@ Make sure it has been suitably \
compiled."
File.format obj_file
else
try Dynlink.loadfile obj_file
with Dynlink.Error dl_err ->
Message.error "Error loading compiled module from %a:@;<1 2>@[<hov>%a@]"
File.format obj_file Format.pp_print_text
(Dynlink.error_message dl_err)
"Module %a@ needs@ recompiling:@ %a@ was@ likely@ compiled@ from@ an@ \
older@ version@ or@ with@ incompatible@ flags."
ModuleName.format mname File.format obj_file
| exception Not_found ->
Message.error
"Module %a@ was loaded from file %a but did not register properly, \
there is something wrong in its code."
ModuleName.format mname File.format obj_file
in
let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in
if modules_list_topo <> [] then
Message.debug "Loading shared modules... %a"
(Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format)
modules_list_topo;
(List.map (fun (m, _) -> m) modules_list_topo);
List.iter load modules_list_topo

View File

@ -21,7 +21,7 @@ open Catala_utils
open Definitions
val evaluate_operator :
((((_, _, _) interpr_kind as 'a), 'm) gexpr -> ('a, 'm) gexpr) ->
((((_, _) interpr_kind as 'a), 'm) gexpr -> ('a, 'm) gexpr) ->
'a operator Mark.pos ->
'm mark ->
Global.backend_lang ->
@ -35,14 +35,14 @@ val evaluate_operator :
val evaluate_expr :
decl_ctx ->
Global.backend_lang ->
(('a, 'b, _) interpr_kind, 'm) gexpr ->
(('a, 'b, yes) interpr_kind, 'm) gexpr
(('a, _) interpr_kind, 'm) gexpr ->
(('a, yes) interpr_kind, 'm) gexpr
(** Evaluates an expression according to the semantics of the default calculus. *)
val interpret_program_dcalc :
(dcalc, 'm) gexpr program ->
ScopeName.t ->
(Uid.MarkedString.info * ((yes, no, yes) interpr_kind, 'm) gexpr) list
(Uid.MarkedString.info * ((yes, yes) interpr_kind, 'm) gexpr) list
(** Interprets a program. This function expects an expression typed as a
function whose argument are all thunked. The function is executed by
providing for each argument a thunked empty default. Returns a list of all
@ -51,17 +51,17 @@ val interpret_program_dcalc :
val interpret_program_lcalc :
(lcalc, 'm) gexpr program ->
ScopeName.t ->
(Uid.MarkedString.info * ((no, yes, yes) interpr_kind, 'm) gexpr) list
(Uid.MarkedString.info * ((no, yes) interpr_kind, 'm) gexpr) list
(** Interprets a program. This function expects an expression typed as a
function whose argument are all thunked. The function is executed by
providing for each argument a thunked empty default. Returns a list of all
the computed values for the scope variables of the executed scope. *)
val delcustom :
(('a, 'b, 'c) interpr_kind, 'm) gexpr -> (('a, 'b, no) interpr_kind, 'm) gexpr
(('a, 'b) interpr_kind, 'm) gexpr -> (('a, no) interpr_kind, 'm) gexpr
(** Runtime check that the term contains no custom terms (raises
[Invalid_argument] if that is the case *)
val load_runtime_modules : _ program -> unit
val load_runtime_modules : hashf:(Hash.t -> Hash.full) -> _ program -> unit
(** Dynlink the runtime modules required by the given program, in order to make
them callable by the interpreter. *)

View File

@ -108,8 +108,7 @@ let name : type a. a t -> string = function
| Eq_dur_dur -> "o_eq_dur_dur"
| Eq_dat_dat -> "o_eq_dat_dat"
| Fold -> "o_fold"
| HandleDefault -> "o_handledefault"
| HandleDefaultOpt -> "o_handledefaultopt"
| HandleExceptions -> "handle_exceptions"
| ToClosureEnv -> "o_toclosureenv"
| FromClosureEnv -> "o_fromclosureenv"
@ -232,8 +231,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) =
| Eq_dat_dat, Eq_dat_dat
| Eq_dur_dur, Eq_dur_dur
| Fold, Fold
| HandleDefault, HandleDefault
| HandleDefaultOpt, HandleDefaultOpt
| HandleExceptions, HandleExceptions
| FromClosureEnv, FromClosureEnv | ToClosureEnv, ToClosureEnv -> 0
| Not, _ -> -1 | _, Not -> 1
| Length, _ -> -1 | _, Length -> 1
@ -318,8 +316,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) =
| Eq_mon_mon, _ -> -1 | _, Eq_mon_mon -> 1
| Eq_dat_dat, _ -> -1 | _, Eq_dat_dat -> 1
| Eq_dur_dur, _ -> -1 | _, Eq_dur_dur -> 1
| HandleDefault, _ -> -1 | _, HandleDefault -> 1
| HandleDefaultOpt, _ -> -1 | _, HandleDefaultOpt -> 1
| HandleExceptions, _ -> -1 | _, HandleExceptions -> 1
| FromClosureEnv, _ -> -1 | _, FromClosureEnv -> 1
| ToClosureEnv, _ -> -1 | _, ToClosureEnv -> 1
| Fold, _ | _, Fold -> .
@ -344,7 +341,7 @@ let kind_dispatch :
_ ) as op ->
monomorphic op
| ( ( Log _ | Length | Eq | Map | Map2 | Concat | Filter | Reduce | Fold
| HandleDefault | HandleDefaultOpt | FromClosureEnv | ToClosureEnv ),
| HandleExceptions | FromClosureEnv | ToClosureEnv ),
_ ) as op ->
polymorphic op
| ( ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt
@ -377,19 +374,19 @@ type 'a no_overloads =
let translate (t : 'a no_overloads t Mark.pos) : 'b no_overloads t Mark.pos =
match t with
| ( ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth
| And | Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq
| Map | Map2 | Concat | Filter | Reduce | Fold | Minus_int | Minus_rat
| Minus_mon | Minus_dur | ToRat_int | ToRat_mon | ToMoney_rat | Round_rat
| Round_mon | Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _
| Add_dur_dur | Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat
| Sub_dat_dur | Sub_dur_dur | Mult_int_int | Mult_rat_rat | Mult_mon_rat
| Mult_dur_int | Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat
| Div_dur_dur | Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat
| Lt_dur_dur | Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat
| Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat
| Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat
| Gte_dur_dur | Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat
| Eq_dur_dur | FromClosureEnv | ToClosureEnv ),
| And | Or | Xor | HandleExceptions | Log _ | Length | Eq | Map | Map2
| Concat | Filter | Reduce | Fold | Minus_int | Minus_rat | Minus_mon
| Minus_dur | ToRat_int | ToRat_mon | ToMoney_rat | Round_rat | Round_mon
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _ | Add_dur_dur
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
| Sub_dur_dur | Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int
| Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur
| Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur
| Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur
| Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur
| FromClosureEnv | ToClosureEnv ),
_ ) as op ->
op

View File

@ -58,14 +58,14 @@ let all_match_cases_map_to_same_constructor cases n =
let binder_vars_used_at_most_once
(binder :
( (('a, 'b) dcalc_lcalc, ('a, 'b) dcalc_lcalc, 'm) base_gexpr,
(('a, 'b) dcalc_lcalc, 'm) gexpr )
( ('a dcalc_lcalc, 'a dcalc_lcalc, 'm) base_gexpr,
('a dcalc_lcalc, 'm) gexpr )
Bindlib.mbinder) : bool =
(* fast path: variables not used at all *)
(not (Array.exists Fun.id (Bindlib.mbinder_occurs binder)))
||
let vars, body = Bindlib.unmbind binder in
let rec vars_count (e : (('a, 'b) dcalc_lcalc, 'm) gexpr) : int array =
let rec vars_count (e : ('a dcalc_lcalc, 'm) gexpr) : int array =
match e with
| EVar v, _ ->
Array.map
@ -82,8 +82,8 @@ let binder_vars_used_at_most_once
let rec optimize_expr :
type a b.
(a, b, 'm) optimizations_ctx ->
((a, b) dcalc_lcalc, 'm) gexpr ->
((a, b) dcalc_lcalc, 'm) boxed_gexpr =
(a dcalc_lcalc, 'm) gexpr ->
(a dcalc_lcalc, 'm) boxed_gexpr =
fun ctx e ->
(* We proceed bottom-up, first apply on the subterms *)
let e = Expr.map ~f:(optimize_expr ctx) ~op:Fun.id e in
@ -92,7 +92,7 @@ let rec optimize_expr :
able to keep the inner position (see the division_by_zero test) *)
(* Then reduce the parent node (this is applied through Box.apply, therefore
delayed to unbinding time: no need to be concerned about reboxing) *)
let reduce (e : ((a, b) dcalc_lcalc, 'm) gexpr) =
let reduce (e : (a dcalc_lcalc, 'm) gexpr) =
(* Todo: improve the handling of eapp(log,elit) cases here, it obfuscates
the matches and the log calls are not preserved, which would be a good
property *)
@ -365,22 +365,15 @@ let rec optimize_expr :
el) ->
(* identity tuple reconstruction *)
Mark.remove e
| ECatchEmpty { body; handler } -> (
(* peephole exception catching reductions *)
match Mark.remove body, Mark.remove handler with
| ERaiseEmpty, _ -> Mark.remove handler
| _, ERaiseEmpty -> Mark.remove body
| _ -> ECatchEmpty { body; handler })
| e -> e
in
Expr.Box.app1 e reduce mark
let optimize_expr :
'm.
decl_ctx ->
(('a, 'b) dcalc_lcalc, 'm) gexpr ->
(('a, 'b) dcalc_lcalc, 'm) boxed_gexpr =
fun (decl_ctx : decl_ctx) (e : (('a, 'b) dcalc_lcalc, 'm) gexpr) ->
decl_ctx -> ('a dcalc_lcalc, 'm) gexpr -> ('a dcalc_lcalc, 'm) boxed_gexpr
=
fun (decl_ctx : decl_ctx) (e : ('a dcalc_lcalc, 'm) gexpr) ->
optimize_expr { decl_ctx } e
let optimize_program (p : 'm program) : 'm program =

View File

@ -21,13 +21,10 @@
open Definitions
val optimize_expr :
decl_ctx ->
(('a, 'b) dcalc_lcalc, 'm) gexpr ->
(('a, 'b) dcalc_lcalc, 'm) boxed_gexpr
decl_ctx -> ('a dcalc_lcalc, 'm) gexpr -> ('a dcalc_lcalc, 'm) boxed_gexpr
val optimize_program :
(('a, 'b) dcalc_lcalc, 'm) gexpr program ->
(('a, 'b) dcalc_lcalc, 'm) gexpr program
('a dcalc_lcalc, 'm) gexpr program -> ('a dcalc_lcalc, 'm) gexpr program
(** {1 Tests}*)

View File

@ -102,7 +102,7 @@ let rec typ_gen
Format.pp_open_hvbox fmt 2;
pp_color_string (List.hd colors) fmt "(";
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " %a@ " op_style "*")
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " op_style ",")
(typ ~colors:(List.tl colors)))
fmt ts;
Format.pp_close_box fmt ();
@ -142,7 +142,7 @@ let rec typ_gen
mty))
def punctuation "]")
| TOption t ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "eoption" (typ ~colors) t
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" (typ ~colors) t
| TArrow ([t1], t2) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" (typ_with_parens ~colors) t1
op_style "" (typ ~colors) t2
@ -280,8 +280,7 @@ let operator_to_string : type a. a Op.t -> string =
| Eq_dur_dur -> "=^"
| Eq_dat_dat -> "=@"
| Fold -> "fold"
| HandleDefault -> "handle_default"
| HandleDefaultOpt -> "handle_default_opt"
| HandleExceptions -> "handle_exceptions"
| ToClosureEnv -> "to_closure_env"
| FromClosureEnv -> "from_closure_env"
@ -325,8 +324,7 @@ let operator_to_shorter_string : type a. a Op.t -> string =
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dur_dur | Gte_dat_dat | Gte ->
">="
| Fold -> "fold"
| HandleDefault -> "handle_default"
| HandleDefaultOpt -> "handle_default_opt"
| HandleExceptions -> "handle_exceptions"
| ToClosureEnv -> "to_closure_env"
| FromClosureEnv -> "from_closure_env"
@ -402,8 +400,8 @@ module Precedence = struct
| Div | Div_int_int | Div_rat_rat | Div_mon_rat | Div_mon_mon
| Div_dur_dur ->
Op Div
| HandleDefault | HandleDefaultOpt | Map | Map2 | Concat | Filter | Reduce
| Fold | ToClosureEnv | FromClosureEnv ->
| HandleExceptions | Map | Map2 | Concat | Filter | Reduce | Fold
| ToClosureEnv | FromClosureEnv ->
App)
| EApp _ -> App
| EArray _ -> Contained
@ -426,8 +424,6 @@ module Precedence = struct
| EPureDefault _ -> Contained
| EEmpty -> Contained
| EErrorOnEmpty _ -> App
| ERaiseEmpty -> App
| ECatchEmpty _ -> App
| ECustom _ -> Contained
let needs_parens ~context ?(rhs = false) e =
@ -671,12 +667,6 @@ module ExprGen (C : EXPR_PARAM) = struct
| EFatalError err ->
Format.fprintf fmt "@[<hov 2>%a@ @{<red>%s@}@]" keyword "error"
(Runtime.error_to_string err)
| ECatchEmpty { body; handler } ->
Format.fprintf fmt
"@[<hv 0>@[<hov 2>%a@ %a@]@ @[<hov 2>%a@ %a ->@ %a@]@]" keyword "try"
expr body keyword "with" op_style "Empty" (rhs exprc) handler
| ERaiseEmpty ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" keyword "raise" op_style "Empty"
| ELocation loc -> location fmt loc
| EDStructAccess { e; field; _ } ->
Format.fprintf fmt "@[<hv 2>%a%a@,%a%a%a@]" (lhs exprc) e punctuation
@ -712,7 +702,6 @@ module ExprGen (C : EXPR_PARAM) = struct
Format.fprintf fmt "@[<v 0>@[<hv 2>%a@ %a@;<1 -2>%a@]@ %a@]" keyword
"match" (lhs exprc) e keyword "with"
(EnumConstructor.Map.format_bindings
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt pp_cons_name case_expr ->
match case_expr with
| EAbs { binder; tys; _ }, _ ->
@ -879,13 +868,12 @@ let enum
fmt
(pp_name : Format.formatter -> unit)
(c : typ EnumConstructor.Map.t) =
Format.fprintf fmt "@[<h 0>%a %t %a@ %a@]" keyword "type" pp_name punctuation
"="
(EnumConstructor.Map.format_bindings
~pp_sep:(fun _ _ -> ())
Format.fprintf fmt "@[<h 0>%a %t %a@ %a@]@," keyword "type" pp_name
punctuation "="
(EnumConstructor.Map.format_bindings ~pp_sep:Format.pp_print_space
(fun fmt pp_n ty ->
Format.fprintf fmt "@[<hov2> %a %t %a %a@]@;" punctuation "|" pp_n
keyword "of"
Format.fprintf fmt "@[<hov2>%a %t %a %a@]" punctuation "|" pp_n keyword
"of"
(if debug then typ_debug else typ decl_ctx)
ty))
c
@ -908,14 +896,10 @@ let struct_
let decl_ctx ?(debug = false) decl_ctx (fmt : Format.formatter) (ctx : decl_ctx)
: unit =
let { ctx_enums; ctx_structs; _ } = ctx in
Format.fprintf fmt "%a@.%a@.@."
(EnumName.Map.format_bindings
~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
(enum ~debug decl_ctx))
Format.fprintf fmt "@[<v>%a@,%a@,@,@]"
(EnumName.Map.format_bindings (enum ~debug decl_ctx))
ctx_enums
(StructName.Map.format_bindings
~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
(struct_ ~debug decl_ctx))
(StructName.Map.format_bindings (struct_ ~debug decl_ctx))
ctx_structs
let scope
@ -947,11 +931,15 @@ let code_item ?(debug = false) ?name decl_ctx fmt c =
"=" (expr ~debug ()) e
let code_item_list ?(debug = false) decl_ctx fmt c =
BoundList.iter c ~f:(fun x item ->
Format.pp_open_vbox fmt 0;
Format.pp_print_seq
(fun fmt (x, item) ->
code_item ~debug
~name:(Format.asprintf "%a" var_debug x)
decl_ctx fmt item;
Format.pp_print_newline fmt ())
Format.pp_print_cut fmt ())
fmt (BoundList.to_seq c);
Format.pp_close_box fmt ()
let program ?(debug = false) fmt p =
decl_ctx ~debug p.decl_ctx fmt p.decl_ctx;
@ -1124,6 +1112,8 @@ module UserFacing = struct
~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ")
(value ~fallback lang))
l
| ETuple [(EAbs { tys = (TClosureEnv, _) :: _; _ }, _); _] ->
Format.pp_print_string ppf "<function>"
| ETuple l ->
Format.fprintf ppf "@[<hv 2>(@,@[<hov>%a@]@;<0 -2>)@]"
(Format.pp_print_list
@ -1145,8 +1135,8 @@ module UserFacing = struct
| EExternal _ -> Format.pp_print_string ppf "<external>"
| EApp _ | EAppOp _ | EVar _ | EIfThenElse _ | EMatch _ | ETupleAccess _
| EStructAccess _ | EAssert _ | EFatalError _ | EDefault _ | EPureDefault _
| EErrorOnEmpty _ | ERaiseEmpty | ECatchEmpty _ | ELocation _ | EScopeCall _
| EDStructAmend _ | EDStructAccess _ | ECustom _ ->
| EErrorOnEmpty _ | ELocation _ | EScopeCall _ | EDStructAmend _
| EDStructAccess _ | ECustom _ ->
fallback ppf e
let expr :

View File

@ -58,7 +58,7 @@ let empty_ctx =
ctx_struct_fields = Ident.Map.empty;
ctx_enum_constrs = Ident.Map.empty;
ctx_scope_index = Ident.Map.empty;
ctx_modules = M ModuleName.Map.empty;
ctx_modules = ModuleName.Map.empty;
}
let get_scope_body { code_items; _ } scope =
@ -87,11 +87,11 @@ let to_expr p main_scope =
res
let modules_to_list (mt : module_tree) =
let rec aux acc (M mtree) =
let rec aux acc mtree =
ModuleName.Map.fold
(fun mname sub acc ->
if List.exists (ModuleName.equal mname) acc then acc
else mname :: aux acc sub)
(fun mname mnode acc ->
if List.exists (fun (m, _) -> ModuleName.equal m mname) acc then acc
else (mname, mnode.intf_id) :: aux acc mnode.deps)
mtree acc
in
List.rev (aux [] mt)

View File

@ -53,5 +53,6 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed
val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body
val modules_to_list : module_tree -> ModuleName.t list
(** Returns a list of used modules, in topological order *)
val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list
(** Returns a list of used modules, in topological order ; the boolean indicates
if the module is external *)

View File

@ -93,6 +93,22 @@ let rec compare ty1 ty2 =
| TClosureEnv, _ -> -1
| _, TClosureEnv -> 1
let rec hash ~strip ty =
let open Hash.Op in
match Mark.remove ty with
| TLit l -> !`TLit % !(l : typ_lit)
| TTuple tl -> List.fold_left (fun acc ty -> acc % hash ~strip ty) !`TTuple tl
| TStruct n -> !`TStruct % StructName.hash ~strip n
| TEnum n -> !`TEnum % EnumName.hash ~strip n
| TOption ty -> !`TOption % hash ~strip ty
| TArrow (tl, ty) ->
!`TArrow
% List.fold_left (fun acc ty -> acc % hash ~strip ty) (hash ~strip ty) tl
| TArray ty -> !`TArray % hash ~strip ty
| TDefault ty -> !`TDefault % hash ~strip ty
| TAny -> !`TAny
| TClosureEnv -> !`TClosureEnv
let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t
let format = Print.typ_debug

View File

@ -14,6 +14,8 @@
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
type t = Definitions.typ
val format : Format.formatter -> t -> unit
@ -23,6 +25,11 @@ module Map : Catala_utils.Map.S with type key = t
val equal : t -> t -> bool
val equal_list : t list -> t list -> bool
val compare : t -> t -> int
val hash : strip:Uid.Path.t -> t -> Hash.t
(** The [strip] argument strips the given leading path components in included
identifiers before hashing *)
val unifiable : t -> t -> bool
val unifiable_list : t list -> t list -> bool

View File

@ -31,6 +31,7 @@ module Any =
let format fmt () = Format.fprintf fmt "any"
let equal () () = true
let compare () () = 0
let hash () = Hash.raw `Any
end)
(struct
let style = Ocolor_types.(Fg (C4 hi_magenta))
@ -166,7 +167,7 @@ let rec format_typ
format_typ ~colors fmt t1;
Format.pp_print_as fmt 1 ""
| TAny v ->
if Global.options.debug then Format.fprintf fmt "<a%d>" (Any.hash v)
if Global.options.debug then Format.fprintf fmt "<a%d>" (Any.id v)
else Format.pp_print_string fmt "<any>"
| TClosureEnv -> Format.fprintf fmt "closure_env"
@ -174,9 +175,59 @@ let rec colors =
let open Ocolor_types in
blue :: cyan :: green :: yellow :: red :: magenta :: colors
let dummy_flags = { fail_on_any = false; assume_op_types = false }
let format_typ ctx fmt naked_typ = format_typ ctx ~colors fmt naked_typ
exception Type_error of A.any_expr * unionfind_typ * unionfind_typ
let record_type_error _ctx (A.AnyExpr e) t1 t2 =
(* We convert union-find types to ast ones otherwise error messages would be
hindered as union-find side-effects wrongly unify both types. The delayed
pretty-printing would yield messages such as: 'incompatible types (integer,
integer)' *)
let t1_repr = typ_to_ast ~flags:dummy_flags t1 in
let t2_repr = typ_to_ast ~flags:dummy_flags t2 in
let e_pos = Expr.pos e in
let t1_pos = Mark.get t1_repr in
let t2_pos = Mark.get t2_repr in
let pp_typ = Print.typ_debug in
let fmt_pos =
if e_pos = t1_pos then
[
( (fun ppf ->
Format.fprintf ppf "@[<hv 2>@[<hov>%a@ %a@]:" Format.pp_print_text
"This expression has type" pp_typ t1_repr;
if Global.options.debug then
Format.fprintf ppf "@ %a@]" Expr.format e
else Format.pp_close_box ppf ()),
e_pos );
( (fun ppf ->
Format.fprintf ppf
"@[<hov>Expected@ type@ %a@ coming@ from@ expression:@]" pp_typ
t2_repr),
t2_pos );
]
else
[
( (fun ppf ->
Format.fprintf ppf "@[<hv 2>@[<hov>%a:@]" Format.pp_print_text
"While typechecking the following expression";
if Global.options.debug then
Format.fprintf ppf "@ %a@]" Expr.format e
else Format.pp_close_box ppf ()),
e_pos );
( (fun ppf ->
Format.fprintf ppf "@[<hov>Type@ %a@ is@ coming@ from:@]" pp_typ
t1_repr),
t1_pos );
( (fun ppf ->
Format.fprintf ppf "@[<hov>Type@ %a@ is@ coming@ from:@]" pp_typ
t2_repr),
t2_pos );
]
in
Message.delayed_error () ~fmt_pos
"Error during typechecking, incompatible types:@\n\
@[<v>@{<blue>@<2>%s@} @[<hov>%a@]@,\
@{<blue>@<2>%s@} @[<hov>%a@]@]" "" pp_typ t1_repr "" pp_typ t2_repr
(** Raises an error if unification cannot be performed. The position annotation
of the second [unionfind_typ] argument is propagated (unless it is [TAny]). *)
@ -190,21 +241,21 @@ let rec unify
t2; *)
let t1_repr = UnionFind.get (UnionFind.find t1) in
let t2_repr = UnionFind.get (UnionFind.find t2) in
let raise_type_error () = raise (Type_error (A.AnyExpr e, t1, t2)) in
let record_type_error () = record_type_error ctx (A.AnyExpr e) t1 t2 in
let () =
match Mark.remove t1_repr, Mark.remove t2_repr with
| TLit tl1, TLit tl2 -> if tl1 <> tl2 then raise_type_error ()
| TLit tl1, TLit tl2 -> if tl1 <> tl2 then record_type_error ()
| TArrow (t11, t12), TArrow (t21, t22) -> (
unify e t12 t22;
try List.iter2 (unify e) t11 t21
with Invalid_argument _ -> raise_type_error ())
with Invalid_argument _ -> record_type_error ())
| TTuple ts1, TTuple ts2 -> (
try List.iter2 (unify e) ts1 ts2
with Invalid_argument _ -> raise_type_error ())
with Invalid_argument _ -> record_type_error ())
| TStruct s1, TStruct s2 ->
if not (A.StructName.equal s1 s2) then raise_type_error ()
if not (A.StructName.equal s1 s2) then record_type_error ()
| TEnum e1, TEnum e2 ->
if not (A.EnumName.equal e1 e2) then raise_type_error ()
if not (A.EnumName.equal e1 e2) then record_type_error ()
| TOption t1, TOption t2 -> unify e t1 t2
| TArray t1', TArray t2' -> unify e t1' t2'
| TDefault t1', TDefault t2' -> unify e t1' t2'
@ -213,62 +264,13 @@ let rec unify
| ( ( TLit _ | TArrow _ | TTuple _ | TStruct _ | TEnum _ | TOption _
| TArray _ | TDefault _ | TClosureEnv ),
_ ) ->
raise_type_error ()
record_type_error ()
in
ignore
@@ UnionFind.merge
(fun t1 t2 -> match Mark.remove t2 with TAny _ -> t1 | _ -> t2)
t1 t2
let handle_type_error ctx (A.AnyExpr e) t1 t2 =
(* TODO: if we get weird error messages, then it means that we should use the
persistent version of the union-find data structure. *)
let t1_repr = UnionFind.get (UnionFind.find t1) in
let t2_repr = UnionFind.get (UnionFind.find t2) in
let e_pos = Expr.pos e in
let t1_pos = Mark.get t1_repr in
let t2_pos = Mark.get t2_repr in
let fmt_pos =
if e_pos = t1_pos then
[
( (fun ppf ->
Format.fprintf ppf "@[<hv 2>@[<hov>%a@ %a@]:" Format.pp_print_text
"This expression has type" (format_typ ctx) t1;
if Global.options.debug then
Format.fprintf ppf "@ %a@]" Expr.format e
else Format.pp_close_box ppf ()),
e_pos );
( (fun ppf ->
Format.fprintf ppf
"@[<hov>Expected@ type@ %a@ coming@ from@ expression:@]"
(format_typ ctx) t2),
t2_pos );
]
else
[
( (fun ppf ->
Format.fprintf ppf "@[<hv 2>@[<hov>%a:@]" Format.pp_print_text
"While typechecking the following expression";
if Global.options.debug then
Format.fprintf ppf "@ %a@]" Expr.format e
else Format.pp_close_box ppf ()),
e_pos );
( (fun ppf ->
Format.fprintf ppf "@[<hov>Type@ %a@ is@ coming@ from:@]"
(format_typ ctx) t1),
t1_pos );
( (fun ppf ->
Format.fprintf ppf "@[<hov>Type@ %a@ is@ coming@ from:@]"
(format_typ ctx) t2),
t2_pos );
]
in
Message.error ~fmt_pos
"Error during typechecking, incompatible types:@\n\
@[<v>@{<blue>@<2>%s@} @[<hov>%a@]@,\
@{<blue>@<2>%s@} @[<hov>%a@]@]" "" (format_typ ctx) t1 ""
(format_typ ctx) t2
let lit_type (lit : A.lit) : naked_typ =
match lit with
| LBool _ -> TLit TBool
@ -292,7 +294,6 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
let any2 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let any3 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let bt = lazy (UnionFind.make (TLit TBool, pos)) in
let ut = lazy (UnionFind.make (TLit TUnit, pos)) in
let it = lazy (UnionFind.make (TLit TInt, pos)) in
let cet = lazy (UnionFind.make (TClosureEnv, pos)) in
let array a = lazy (UnionFind.make (TArray (Lazy.force a), pos)) in
@ -312,9 +313,7 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
| Log (PosRecordIfTrueBool, _) -> [bt] @-> bt
| Log _ -> [any] @-> any
| Length -> [array any] @-> it
| HandleDefault -> [array ([ut] @-> any); [ut] @-> bt; [ut] @-> any] @-> any
| HandleDefaultOpt ->
[array (option any); [ut] @-> bt; [ut] @-> option any] @-> option any
| HandleExceptions -> [array (option any)] @-> option any
| ToClosureEnv -> [any] @-> cet
| FromClosureEnv -> [cet] @-> any
in
@ -346,7 +345,7 @@ let polymorphic_op_return_type
| Log (PosRecordIfTrueBool, _), _ -> uf (TLit TBool)
| Log _, [tau] -> tau
| Length, _ -> uf (TLit TInt)
| (HandleDefault | HandleDefaultOpt), [_; _; tf] -> return_type tf 1
| HandleExceptions, [_] -> any ()
| ToClosureEnv, _ -> uf TClosureEnv
| FromClosureEnv, _ -> any ()
| _ -> Message.error ~pos "Mismatched operator arguments"
@ -616,7 +615,10 @@ and typecheck_expr_top_down :
"Variable @{<yellow>%s@} is not a declared output of scope %a."
field A.ScopeName.format scope_out
~suggestion:
(List.map A.StructField.to_string (A.StructField.Map.keys str))
(Suggestions.sorted_candidates
(List.map A.StructField.to_string
(A.StructField.Map.keys str))
field)
| None ->
Message.error
~extra_pos:
@ -627,7 +629,10 @@ and typecheck_expr_top_down :
"Field@ @{<yellow>\"%s\"@}@ does@ not@ belong@ to@ structure@ \
@{<yellow>\"%a\"@}."
field A.StructName.format name
~suggestion:(A.Ident.Map.keys ctx.ctx_struct_fields))
~suggestion:
(Suggestions.sorted_candidates
(A.Ident.Map.keys ctx.ctx_struct_fields)
field))
in
try A.StructName.Map.find name candidate_structs
with A.StructName.Map.Not_found _ ->
@ -758,13 +763,7 @@ and typecheck_expr_top_down :
(sub_scope, typecheck_scope_call_args sub_scope sub_args))
args
in
Expr.escopecall ~scope ~args:(typecheck_scope_call_args scope args) mark
| A.ERaiseEmpty -> Expr.eraiseempty context_mark
| A.ECatchEmpty { body; handler } ->
let body' = typecheck_expr_top_down ctx env tau body in
let handler' = typecheck_expr_top_down ctx env tau handler in
Expr.ecatchempty body' handler' context_mark
| A.EVar v ->
let tau' =
match Env.get env v with
@ -895,15 +894,17 @@ and typecheck_expr_top_down :
let args =
Operator.kind_dispatch (Mark.set pos_e op)
~polymorphic:(fun op ->
(* Type the operator first, then right-to-left: polymorphic operators
are required to allow the resolution of all type variables this
way *)
if not env.flags.assume_op_types then
unify ctx e (polymorphic_op_type op) t_func
else unify ctx e (polymorphic_op_return_type ctx e op t_args) tau;
List.rev_map2
(typecheck_expr_top_down ctx env)
(List.rev t_args) (List.rev args))
if env.flags.assume_op_types then (
unify ctx e (polymorphic_op_return_type ctx e op t_args) tau;
List.rev_map (typecheck_expr_bottom_up ctx env) (List.rev args))
else (
(* Type the operator first, then right-to-left: polymorphic
operators are required to allow the resolution of all type
variables this way *)
unify ctx e (polymorphic_op_type op) t_func;
List.rev_map2
(typecheck_expr_top_down ctx env)
(List.rev t_args) (List.rev args)))
~overloaded:(fun op ->
(* Typing the arguments first is required to resolve the operator *)
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
@ -966,18 +967,6 @@ and typecheck_expr_top_down :
in
Expr.ecustom obj targs tret mark
let wrap ctx f e =
try f e
with Type_error (e, ty1, ty2) -> (
let bt = Printexc.get_raw_backtrace () in
try handle_type_error ctx e ty1 ty2
with e -> Printexc.raise_with_backtrace e bt)
let wrap_expr ctx f e =
(* We need to unbox here, because the typing may otherwise be stored in
Bindlib closures and not yet applied, and would escape the `try..with` *)
wrap ctx (fun e -> Expr.unbox (f e)) e
(** {1 API} *)
let get_ty_mark ~flags (A.Custom { A.custom = uf; pos }) =
@ -994,7 +983,7 @@ let expr_raw
| None -> typecheck_expr_bottom_up ctx env
| Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ)
in
wrap_expr ctx fty e
Expr.unbox (fty e)
let check_expr ctx ?env ?typ e =
Expr.map_marks
@ -1009,14 +998,14 @@ let scope_body_expr ctx env ty_out body_expr =
let _env, ret =
BoundList.fold_map body_expr ~init:env
~last:(fun env e ->
let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in
let e' = Expr.unbox (typecheck_expr_top_down ctx env ty_out e) in
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
env, Expr.Box.lift e')
~f:(fun env var scope ->
let e0 = scope.A.scope_let_expr in
let ty_e = ast_to_typ scope.A.scope_let_typ in
let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in
wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e;
let e = Expr.unbox (typecheck_expr_bottom_up ctx env e0) in
unify ctx e0 (ty e) ty_e;
(* We could use [typecheck_expr_top_down] rather than this manual
unification, but we get better messages with this order of the
[unify] parameters, which keeps location of the type as defined
@ -1114,3 +1103,26 @@ let program ?fail_on_any ?assume_op_types prg =
prg.decl_ctx.ctx_enums;
};
}
let program ?fail_on_any ?assume_op_types ?(internal_check = false) prg =
let wrap =
if internal_check then (fun f ->
try Message.with_delayed_errors f
with (Message.CompilerError _ | Message.CompilerErrors _) as exc ->
let bt = Printexc.get_raw_backtrace () in
let err =
match exc with
| Message.CompilerError err ->
Message.CompilerError (Message.Content.to_internal_error err)
| Message.CompilerErrors errs ->
Message.CompilerErrors
(List.map Message.Content.to_internal_error errs)
| _ -> assert false
in
Message.debug "Faulty intermediate program:@ %a"
(Print.program ~debug:true)
prg;
Printexc.raise_with_backtrace err bt)
else fun f -> Message.with_delayed_errors f
in
wrap @@ fun () -> program ?fail_on_any ?assume_op_types prg

View File

@ -97,11 +97,15 @@ val check_expr :
val program :
?fail_on_any:bool ->
?assume_op_types:bool ->
?internal_check:bool ->
('a, 'm) gexpr program ->
('a, typed) gexpr program
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the
later dcalc/lcalc stages.
later dcalc/lcalc stages).
Any existing type annotations are checked for unification. Use
[Program.untype] to remove them beforehand if this is not the desired
behaviour. *)
behaviour.
If [internal_check] is set to [true], typing errors will be marked as
internal, and the faulty program will be printed if '--debug' is set. *)

View File

@ -100,6 +100,7 @@ module Map = struct
let empty = empty
let singleton v x = singleton (t v) x
let add v x m = add (t v) x m
let remove v m = remove (t v) m
let update v f m = update (t v) f m
let find v m = find (t v) m
let find_opt v m = find_opt (t v) m

View File

@ -64,6 +64,7 @@ module Map : sig
val empty : ('e, 'x) t
val singleton : 'e var -> 'x -> ('e, 'x) t
val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t
val remove : 'e var -> ('e, 'x) t -> ('e, 'x) t
val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t
val find : 'e var -> ('e, 'x) t -> 'x
val find_opt : 'e var -> ('e, 'x) t -> 'x option

View File

@ -318,7 +318,7 @@ and law_structure =
| CodeBlock of code_block * source_repr * bool (* Metadata if true *)
and interface = {
intf_modname : uident Mark.pos;
intf_modname : program_module;
intf_code : code_block;
(** Invariant: an interface shall only contain [*Decl] elements, or
[Topdef] elements with [topdef_expr = None] *)
@ -330,8 +330,10 @@ and module_use = {
mod_use_alias : uident Mark.pos;
}
and program_module = { module_name : uident Mark.pos; module_external : bool }
and program = {
program_module_name : uident Mark.pos option;
program_module : program_module option;
program_items : law_structure list;
program_source_files : (string[@opaque]) list;
program_used_modules : module_use list;

View File

@ -778,9 +778,6 @@ let lex_raw (lexbuf : lexbuf) : token =
| _ -> (
(* Nested match for lower priority; `_` matches length 0 so we effectively retry the
sub-match at the same point *)
let lexbuf = lexbuf in
(* workaround sedlex bug, see https://github.com/ocaml-community/sedlex/issues/12
(fixed in 3.1) *)
match%sedlex lexbuf with
| Star (Compl '\n'), ('\n' | eof) -> LAW_TEXT (Utf8.lexeme lexbuf)
| _ -> L.raise_lexer_error (Pos.from_lpos prev_pos) prev_lexeme)
@ -817,9 +814,6 @@ let lex_law (lexbuf : lexbuf) : token =
| _ -> (
(* Nested match for lower priority; `_` matches length 0 so we effectively retry the
sub-match at the same point *)
let lexbuf = lexbuf in
(* workaround sedlex bug, see https://github.com/ocaml-community/sedlex/issues/12
(fixed in 3.1) *)
match%sedlex lexbuf with
| Star (Compl '\n'), ('\n' | eof) -> LAW_TEXT (Utf8.lexeme lexbuf)
| _ -> L.raise_lexer_error (Pos.from_lpos prev_pos) prev_lexeme)

View File

@ -60,29 +60,6 @@ let rec law_struct_list_to_tree (f : Ast.law_structure list) :
let gobbled, rest_out = split_rest_tree rest_tree in
LawHeading (heading, gobbled) :: rest_out))
(** Usage: [raise_parser_error error_loc last_good_loc token msg]
Raises an error message featuring the [error_loc] position where the parser
has failed, the [token] on which the parser has failed, and the error
message [msg]. If available, displays [last_good_loc] the location of the
last token correctly parsed. *)
let raise_parser_error
?(suggestion : string list option)
(error_loc : Pos.t)
(last_good_loc : Pos.t option)
(token : string)
(msg : Format.formatter -> unit) : 'a =
Message.error ?suggestion
~extra_pos:
[
(match last_good_loc with
| None -> "Error token", error_loc
| Some last_good_loc -> "Last good token", last_good_loc);
]
"@[<hov>Syntax error at %a:@ %t@]"
(fun ppf string -> Format.fprintf ppf "@{<yellow>\"%s\"@}" string)
token msg
module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
include Parser.Make (LocalisedLexer)
module I = MenhirInterpreter
@ -93,40 +70,12 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
| MenhirLib.General.Nil -> 0
| MenhirLib.General.Cons (Element (s, _, _, _), _) -> I.number s
(** Usage: [fail lexbuf env token_list last_input_needed]
Raises an error with meaningful hints about what the parsing error was.
[lexbuf] is the lexing buffer state at the failure point, [env] is the
Menhir environment and [last_input_needed] is the last checkpoint of a
valid Menhir state before the parsing error. [token_list] is provided by
things like {!val: Surface.Lexer_common.token_list_language_agnostic} and
is used to provide suggestions of the tokens acceptable at the failure
point *)
let fail
let register_parsing_error
(lexbuf : lexbuf)
(env : 'semantic_value I.env)
(token_list : (string * Tokens.token) list)
(last_input_needed : 'semantic_value I.env option) : 'a =
let wrong_token = Utf8.lexeme lexbuf in
let acceptable_tokens, last_positions =
match last_input_needed with
| Some last_input_needed ->
( List.filter
(fun (_, t) ->
I.acceptable
(I.input_needed last_input_needed)
t
(fst (lexing_positions lexbuf)))
token_list,
Some (I.positions last_input_needed) )
| None -> token_list, None
in
let similar_acceptable_tokens =
Suggestions.suggestion_minimum_levenshtein_distance_association
(List.map (fun (s, _) -> s) acceptable_tokens)
wrong_token
in
(* The parser has suspended itself because of a syntax error. Stop. *)
(acceptable_tokens : (string * Tokens.token) list)
(similar_candidate_tokens : string list) : 'a =
(* The parser has suspended itself because of a syntax error. *)
let custom_menhir_message ppf =
(match Parser_errors.message (state env) with
| exception Not_found -> Format.fprintf ppf "@{<yellow>unexpected token@}"
@ -141,31 +90,163 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
(fun ppf string -> Format.fprintf ppf "@{<yellow>\"%s\"@}" string))
(List.map (fun (s, _) -> s) acceptable_tokens)
in
raise_parser_error ~suggestion:similar_acceptable_tokens
(Pos.from_lpos (lexing_positions lexbuf))
(Option.map Pos.from_lpos last_positions)
(Utf8.lexeme lexbuf) custom_menhir_message
let suggestion =
if similar_candidate_tokens = [] then None
else Some similar_candidate_tokens
in
let error_loc = Pos.from_lpos (lexing_positions lexbuf) in
let wrong_token = Utf8.lexeme lexbuf in
let msg = custom_menhir_message in
Message.delayed_error () ?suggestion
~extra_pos:["", error_loc]
"@[<hov>Syntax error at %a:@ %t@]"
(fun ppf string -> Format.fprintf ppf "@{<yellow>\"%s\"@}" string)
wrong_token msg
let sorted_candidate_tokens lexbuf token_list env =
let acceptable_tokens =
List.filter_map
(fun ((_, t) as elt) ->
if I.acceptable (I.input_needed env) t (fst (lexing_positions lexbuf))
then Some elt
else None)
token_list
in
let lexeme = Utf8.lexeme lexbuf in
let similar_acceptable_tokens =
Suggestions.best_candidates (List.map fst acceptable_tokens) lexeme
in
let module S = Set.Make (String) in
let s_toks = S.of_list similar_acceptable_tokens in
let sorted_acceptable_tokens =
List.sort
(fun (s, _) _ -> if S.mem s s_toks then -1 else 1)
acceptable_tokens
in
similar_acceptable_tokens, sorted_acceptable_tokens
type 'a ring_buffer = {
curr_idx : int;
start : int ref;
stop : int ref;
max_size : int;
feed : unit -> 'a;
data : 'a array;
}
let next ({ curr_idx; start; stop; max_size; feed; data } as buff) =
let next_idx = succ curr_idx mod max_size in
if curr_idx = !stop then (
let new_elt = feed () in
data.(curr_idx) <- new_elt;
let size = ((!stop - !start + max_size) mod max_size) + 1 in
stop := succ !stop mod max_size;
let is_full = size = max_size in
if is_full then
(* buffer will get full: start is also moved *)
start := succ !start mod max_size;
{ buff with curr_idx = next_idx }, new_elt)
else
let elt = data.(curr_idx) in
{ buff with curr_idx = next_idx }, elt
let create ?(max_size = 20) feed v =
{
curr_idx = 0;
start = ref 0;
stop = ref 0;
feed;
data = Array.make max_size v;
max_size;
}
let progress ?(max_step = 10) lexer_buffer env checkpoint : int =
let rec loop nth_step lexer_buffer env checkpoint =
if nth_step >= max_step then nth_step
else
match checkpoint with
| I.InputNeeded env ->
let new_lexer_buffer, token = next lexer_buffer in
let checkpoint = I.offer checkpoint token in
loop (succ nth_step) new_lexer_buffer env checkpoint
| I.Shifting _ | I.AboutToReduce _ ->
let checkpoint = I.resume checkpoint in
loop nth_step lexer_buffer env checkpoint
| I.HandlingError (_ : _ I.env) | I.Accepted _ | I.Rejected -> nth_step
in
loop 0 lexer_buffer env checkpoint
let recover_parsing_error lexer_buffer env acceptable_tokens =
let candidates_checkpoints =
let without_token = I.input_needed env in
let make_with_token tok =
let l, r = I.positions env in
let checkpoint = I.input_needed env in
I.offer checkpoint (tok, l, r)
in
without_token :: List.map make_with_token acceptable_tokens
in
let threshold = min 10 lexer_buffer.max_size in
let rec iterate ((curr_max_progress, _) as acc) = function
| [] -> acc
| cp :: t ->
if curr_max_progress >= 10 then acc
else
let cp_progress = progress ~max_step:threshold lexer_buffer env cp in
if cp_progress > curr_max_progress then iterate (cp_progress, cp) t
else iterate acc t
in
let best_progress, best_cp =
let dummy_cp = I.input_needed env in
iterate (-1, dummy_cp) candidates_checkpoints
in
(* We do not consider paths were progress isn't significant *)
if best_progress < 2 then None else Some best_cp
(** Main parsing loop *)
let rec loop
(next_token : unit -> Tokens.token * Lexing.position * Lexing.position)
let loop
(lexer_buffer :
(Tokens.token * Lexing.position * Lexing.position) ring_buffer)
(token_list : (string * Tokens.token) list)
(lexbuf : lexbuf)
(last_input_needed : 'semantic_value I.env option)
(checkpoint : 'semantic_value I.checkpoint) : Ast.source_file =
match checkpoint with
| I.InputNeeded env ->
let token = next_token () in
let checkpoint = I.offer checkpoint token in
loop next_token token_list lexbuf (Some env) checkpoint
| I.Shifting _ | I.AboutToReduce _ ->
let checkpoint = I.resume checkpoint in
loop next_token token_list lexbuf last_input_needed checkpoint
| I.HandlingError env -> fail lexbuf env token_list last_input_needed
| I.Accepted v -> v
| I.Rejected ->
(* Cannot happen as we stop at syntax error immediatly *)
assert false
let rec loop
(lexer_buffer :
(Tokens.token * Lexing.position * Lexing.position) ring_buffer)
(token_list : (string * Tokens.token) list)
(lexbuf : lexbuf)
(last_input_needed : 'semantic_value I.env option)
(checkpoint : 'semantic_value I.checkpoint) : Ast.source_file =
match checkpoint with
| I.InputNeeded env ->
let new_lexer_buffer, token = next lexer_buffer in
let checkpoint = I.offer checkpoint token in
loop new_lexer_buffer token_list lexbuf (Some env) checkpoint
| I.Shifting _ | I.AboutToReduce _ ->
let checkpoint = I.resume checkpoint in
loop lexer_buffer token_list lexbuf last_input_needed checkpoint
| I.HandlingError (env : 'semantic_value I.env) -> (
let similar_candidate_tokens, sorted_acceptable_tokens =
sorted_candidate_tokens lexbuf token_list env
in
register_parsing_error lexbuf env sorted_acceptable_tokens
similar_candidate_tokens;
let best_effort_checkpoint =
recover_parsing_error lexer_buffer env
(List.map snd sorted_acceptable_tokens)
in
match best_effort_checkpoint with
| None ->
(* No reasonable solution, aborting *)
[]
| Some best_effort_checkpoint ->
loop lexer_buffer token_list lexbuf last_input_needed
best_effort_checkpoint)
| I.Accepted v -> v
| I.Rejected -> []
in
loop lexer_buffer token_list lexbuf last_input_needed checkpoint
(** Stub that wraps the parsing main loop and handles the Menhir/Sedlex type
difference for [lexbuf]. *)
@ -174,12 +255,17 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
(token_list : (string * Tokens.token) list)
(target_rule : Lexing.position -> 'semantic_value I.checkpoint)
(lexbuf : lexbuf) : Ast.source_file =
let lexer : unit -> Tokens.token * Lexing.position * Lexing.position =
with_tokenizer lexer' lexbuf
let lexer_buffer :
(Tokens.token * Lexing.position * Lexing.position) ring_buffer =
let feed = with_tokenizer lexer' lexbuf in
create feed Lexing.(Tokens.EOF, dummy_pos, dummy_pos)
in
try
loop lexer token_list lexbuf None
(target_rule (fst @@ Sedlexing.lexing_positions lexbuf))
let target_rule =
target_rule (fst @@ Sedlexing.lexing_positions lexbuf)
in
Message.with_delayed_errors
@@ fun () -> loop lexer_buffer token_list lexbuf None target_rule
with Sedlexing.MalFormed | Sedlexing.InvalidCodepoint _ ->
Lexer_common.raise_lexer_error
(Pos.from_lpos (lexing_positions lexbuf))
@ -215,12 +301,13 @@ let lines (file : File.t) (language : Global.backend_lang) =
Sedlexing.set_filename lexbuf file;
let rec aux () =
match lex_line lexbuf with
| Some line -> Seq.Cons (line, aux)
| Some (str, tok) ->
Seq.Cons ((str, tok, Sedlexing.lexing_bytes_positions lexbuf), aux)
| None ->
close_in input;
Seq.Nil
in
aux
Seq.once aux
with exc ->
let bt = Printexc.get_raw_backtrace () in
close_in input;
@ -259,18 +346,21 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
List.fold_left
(fun acc command ->
let join_module_names name_opt =
match acc.Ast.program_module_name, name_opt with
match acc.Ast.program_module, name_opt with
| opt, None | None, opt -> opt
| Some id1, Some id2 ->
Message.error
~extra_pos:["", Mark.get id1; "", Mark.get id2]
~extra_pos:
["", Mark.get id1.module_name; "", Mark.get id2.module_name]
"Multiple definitions of the module name"
in
match command with
| Ast.ModuleDef (id, _) ->
| Ast.ModuleDef (id, is_external) ->
{
acc with
Ast.program_module_name = join_module_names (Some id);
Ast.program_module =
join_module_names
(Some { module_name = id; module_external = is_external });
Ast.program_items = command :: acc.Ast.program_items;
}
| Ast.ModuleUse (mod_use_name, alias) ->
@ -288,22 +378,22 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
@@ fun lexbuf ->
let includ_program = parse_source lexbuf in
let () =
includ_program.Ast.program_module_name
includ_program.Ast.program_module
|> Option.iter
@@ fun id ->
Message.error
~extra_pos:
[
"File include", Mark.get inc_file;
"Module declaration", Mark.get id;
"Module declaration", Mark.get id.Ast.module_name;
]
"A file that declares a module cannot be used through the raw \
'@{<yellow>> Include@}'@ directive.@ You should use it as a \
module with@ '@{<yellow>> Use @{<blue>%s@}@}'@ instead."
(Mark.remove id)
(Mark.remove id.Ast.module_name)
in
{
Ast.program_module_name = acc.program_module_name;
Ast.program_module = acc.program_module;
Ast.program_source_files =
List.rev_append includ_program.program_source_files
acc.Ast.program_source_files;
@ -316,7 +406,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
}
| Ast.LawHeading (heading, commands') ->
let {
Ast.program_module_name;
Ast.program_module;
Ast.program_items = commands';
Ast.program_source_files = new_sources;
Ast.program_used_modules = new_used_modules;
@ -325,7 +415,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
expand_includes source_file commands'
in
{
Ast.program_module_name = join_module_names program_module_name;
Ast.program_module = join_module_names program_module;
Ast.program_source_files =
List.rev_append new_sources acc.Ast.program_source_files;
Ast.program_items =
@ -336,7 +426,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
}
| i -> { acc with Ast.program_items = i :: acc.Ast.program_items })
{
Ast.program_module_name = None;
Ast.program_module = None;
Ast.program_source_files = [];
Ast.program_items = [];
Ast.program_used_modules = [];
@ -346,7 +436,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
in
{
Ast.program_lang = language;
Ast.program_module_name = rprg.Ast.program_module_name;
Ast.program_module = rprg.Ast.program_module;
Ast.program_source_files = List.rev rprg.Ast.program_source_files;
Ast.program_items = List.rev rprg.Ast.program_items;
Ast.program_used_modules = List.rev rprg.Ast.program_used_modules;
@ -396,8 +486,8 @@ let with_sedlex_source source_file f =
f lexbuf
let check_modname program source_file =
match program.Ast.program_module_name, source_file with
| ( Some (mname, pos),
match program.Ast.program_module, source_file with
| ( Some { module_name = mname, pos; _ },
(Global.FileName file | Global.Contents (_, file) | Global.Stdin file) )
when not File.(equal mname Filename.(remove_extension (basename file))) ->
Message.error ~pos
@ -413,10 +503,14 @@ let load_interface ?default_module_name source_file =
let program = with_sedlex_source source_file parse_source in
check_modname program source_file;
let modname =
match program.Ast.program_module_name, default_module_name with
match program.Ast.program_module, default_module_name with
| Some mname, _ -> mname
| None, Some n ->
n, Pos.from_info (Global.input_src_file source_file) 0 0 0 0
{
module_name =
n, Pos.from_info (Global.input_src_file source_file) 0 0 0 0;
module_external = false;
}
| None, None ->
Message.error
"%a doesn't define a module name. It should contain a '@{<cyan>> \

View File

@ -20,7 +20,9 @@
open Catala_utils
val lines :
File.t -> Global.backend_lang -> (string * Lexer_common.line_token) Seq.t
File.t ->
Global.backend_lang ->
(string * Lexer_common.line_token * (Lexing.position * Lexing.position)) Seq.t
(** Raw file parser that doesn't interpret any includes and returns the flat law
structure as is *)

View File

@ -31,10 +31,10 @@ catala implementation and compile to OCaml (removing the `external` directive):
```
```shell-session
$ clerk build _build/.../Prorata_external.ml
$ clerk build _build/.../prorata_external.ml
```
(beware the `_build/`, and the capitalisation of the module name)
(beware the `_build/`, it is required here)
## Write the OCaml implementation
@ -44,9 +44,11 @@ capitalisation to match). Edit to replace the dummy implementation by your code.
Refer to `runtimes/ocaml/runtime.mli` for what is available (especially the
`Oper` module to manipulate the types).
Keep the `register_module` at the end as is, it's needed for the toplevel to use
the value (you would get `Failure("Could not resolve reference to Xxx")` during
evaluation).
Keep the `register_module` at the end, but replace the hash (which should be of
the form `"CM0|XXXXXXXX|XXXXXXXX|XXXXXXXX"`) by the string `"*external*"`. This
section is needed for the Catala interpreter to find the declared values --- the
error `Failure("Could not resolve reference to Xxx")` during evaluation is a
symptom that it is missing.
## Compile and test

View File

@ -18,11 +18,13 @@
\usepackage[a4paper,landscape,margin=1cm,includehead,headsep=2ex,nofoot]{geometry}
\usepackage{fancyhdr}
\usepackage{array}
\usepackage[none]{hyphenat}
\usepackage[document]{ragged2e}
\usemintedstyle{tango}
\hyphenpenalty=10000
\exhyphenpenalty=10000
\setsansfont{DejaVu Sans}[Scale=0.9]
\setmonofont{DejaVu Sans Mono}[Scale=0.9]

View File

@ -18,11 +18,13 @@
\usepackage[a4paper,landscape,margin=1cm,includehead,headsep=2ex,nofoot]{geometry}
\usepackage{fancyhdr}
\usepackage{array}
\usepackage[none]{hyphenat}
\usepackage[document]{ragged2e}
\usemintedstyle{tango}
\hyphenpenalty=10000
\exhyphenpenalty=10000
\setsansfont{DejaVu Sans}[Scale=0.9]
\setmonofont{DejaVu Sans Mono}[Scale=0.9]

4
dune
View File

@ -1,6 +1,6 @@
(dirs runtimes compiler build_system)
(dirs runtimes compiler build_system tests)
(data_only_dirs tests syntax_highlighting)
(data_only_dirs syntax_highlighting)
(vendored_dirs catala-examples.tmp french-law.tmp)

View File

@ -33,6 +33,22 @@ catala_fatal_error catala_fatal_error_raised;
jmp_buf catala_fatal_error_jump_buffer;
void catala_raise_fatal_error(catala_fatal_error_code code,
char *filename,
unsigned int start_line,
unsigned int start_column,
unsigned int end_line,
unsigned int end_column)
{
catala_fatal_error_raised.code = code;
catala_fatal_error_raised.position.filename = filename;
catala_fatal_error_raised.position.start_line = start_line;
catala_fatal_error_raised.position.start_column = start_column;
catala_fatal_error_raised.position.end_line = end_line;
catala_fatal_error_raised.position.end_column = end_column;
longjmp(catala_fatal_error_jump_buffer, 0);
}
typedef struct pointer_list pointer_list;
struct pointer_list
{

Some files were not shown because too many files have changed in this diff Show More