Big reformatting

ocamlformat 0.19.0 -> 0.20.1
100 -> 80 columns per line
Reestablished @emilerolley's smart fun break
This commit is contained in:
Denis Merigoux 2022-03-08 15:03:14 +01:00
parent 65a5a42c16
commit 5bd66142a6
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
102 changed files with 7579 additions and 4271 deletions

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala build system, a specification language for tax and social (* This file is part of the Catala build system, a specification language for
benefits computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux tax and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
let () = Clerk_driver.main () let () = Clerk_driver.main ()

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala build system, a specification language for tax and social (* This file is part of the Catala build system, a specification language for
benefits computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux tax and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Cmdliner open Cmdliner
@ -21,7 +24,8 @@ module Nj = Ninja_utils
let files_or_folders = let files_or_folders =
Arg.( Arg.(
non_empty & pos_right 0 file [] & info [] ~docv:"FILE(S)" ~doc:"File(s) or folder(s) to process") non_empty & pos_right 0 file []
& info [] ~docv:"FILE(S)" ~doc:"File(s) or folder(s) to process")
let command = let command =
Arg.( Arg.(
@ -29,21 +33,23 @@ let command =
& pos 0 (some string) None & pos 0 (some string) None
& info [] ~docv:"COMMAND" ~doc:"Command selection among: test, run") & info [] ~docv:"COMMAND" ~doc:"Command selection among: test, run")
let debug = Arg.(value & flag & info [ "debug"; "d" ] ~doc:"Prints debug information") let debug =
Arg.(value & flag & info [ "debug"; "d" ] ~doc:"Prints debug information")
let reset_test_outputs = let reset_test_outputs =
Arg.( Arg.(
value & flag value & flag
& info [ "r"; "reset" ] & info [ "r"; "reset" ]
~doc: ~doc:
"Used with the `test` command, resets the test output to whatever is output by the \ "Used with the `test` command, resets the test output to whatever is \
Catala compiler.") output by the Catala compiler.")
let catalac = let catalac =
Arg.( Arg.(
value value
& opt (some string) None & opt (some string) None
& info [ "e"; "exe" ] ~docv:"EXE" ~doc:"Catala compiler executable, defaults to `catala`") & info [ "e"; "exe" ] ~docv:"EXE"
~doc:"Catala compiler executable, defaults to `catala`")
let ninja_output = let ninja_output =
Arg.( Arg.(
@ -51,22 +57,25 @@ let ninja_output =
& opt (some string) None & opt (some string) None
& info [ "o"; "output" ] ~docv:"OUTPUT" & info [ "o"; "output" ] ~docv:"OUTPUT"
~doc: ~doc:
"$(i, OUTPUT) is the file that will contain the build.ninja file output. If not \ "$(i, OUTPUT) is the file that will contain the build.ninja file \
specified, the build.ninja file will be outputed in the temporary directory of the \ output. If not specified, the build.ninja file will be outputed in \
system.") the temporary directory of the system.")
let scope = let scope =
Arg.( Arg.(
value value
& opt (some string) None & opt (some string) None
& info [ "s"; "scope" ] ~docv:"SCOPE" & info [ "s"; "scope" ] ~docv:"SCOPE"
~doc:"Used with the `run` command, selects which scope of a given Catala file to run.") ~doc:
"Used with the `run` command, selects which scope of a given Catala \
file to run.")
let catala_opts = let catala_opts =
Arg.( Arg.(
value value
& opt (some string) None & opt (some string) None
& info [ "c"; "catala-opts" ] ~docv:"LANG" ~doc:"Options to pass to the Catala compiler") & info [ "c"; "catala-opts" ] ~docv:"LANG"
~doc:"Options to pass to the Catala compiler")
let clerk_t f = let clerk_t f =
Term.( Term.(
@ -77,29 +86,34 @@ let version = "0.5.0"
let info = let info =
let doc = let doc =
"Build system for Catala, a specification language for tax and social benefits computation \ "Build system for Catala, a specification language for tax and social \
rules." benefits computation rules."
in in
let man = let man =
[ [
`S Manpage.s_description; `S Manpage.s_description;
`P `P
"$(b,clerk) is a build system for Catala, a specification language for tax and social \ "$(b,clerk) is a build system for Catala, a specification language for \
benefits computation rules"; tax and social benefits computation rules";
`S Manpage.s_commands; `S Manpage.s_commands;
`I `I
( "test", ( "test",
"Tests a Catala source file given expected outputs provided in a directory called \ "Tests a Catala source file given expected outputs provided in a \
`output` at the same level that the tested file. If the tested file is `foo.catala_en`, \ directory called `output` at the same level that the tested file. \
then `output` should contain expected output files like `foo.catala_en.$(i,BACKEND)` \ If the tested file is `foo.catala_en`, then `output` should contain \
where $(i,BACKEND) is chosen among: `Interpret`, `Dcalc`, `Scalc`, `Lcalc`, \ expected output files like `foo.catala_en.$(i,BACKEND)` where \
`Typecheck, `Scopelang`, `html`, `tex`, `py`, `ml` and `d` (for Makefile dependencies). \ $(i,BACKEND) is chosen among: `Interpret`, `Dcalc`, `Scalc`, \
For the `Interpret` backend, the scope to test is selected by naming the expected \ `Lcalc`, `Typecheck, `Scopelang`, `html`, `tex`, `py`, `ml` and `d` \
output file `foo.catala_en.$(i,SCOPE).interpret`. When the argument of $(b,clerk) is a \ (for Makefile dependencies). For the `Interpret` backend, the scope \
folder, it recursively looks for Catala files coupled with `output` directories and \ to test is selected by naming the expected output file \
matching expected output on which to perform tests." ); `foo.catala_en.$(i,SCOPE).interpret`. When the argument of \
$(b,clerk) is a folder, it recursively looks for Catala files \
coupled with `output` directories and matching expected output on \
which to perform tests." );
`I `I
("run", "Runs the Catala interpreter on a given scope of a given file. See the `-s` option."); ( "run",
"Runs the Catala interpreter on a given scope of a given file. See \
the `-s` option." );
`S Manpage.s_authors; `S Manpage.s_authors;
`P "Denis Merigoux <denis.merigoux@inria.fr>"; `P "Denis Merigoux <denis.merigoux@inria.fr>";
`P "Emile Rolley <emile.rolley@tuta.io>"; `P "Emile Rolley <emile.rolley@tuta.io>";
@ -107,7 +121,8 @@ let info =
`P "Typical usage:"; `P "Typical usage:";
`Pre "clerk test file.catala_en"; `Pre "clerk test file.catala_en";
`S Manpage.s_bugs; `S Manpage.s_bugs;
`P "Please file bug reports at https://github.com/CatalaLang/catala/issues"; `P
"Please file bug reports at https://github.com/CatalaLang/catala/issues";
] ]
in in
let exits = Term.default_exits @ [ Term.exit_info ~doc:"on error." 1 ] in let exits = Term.default_exits @ [ Term.exit_info ~doc:"on error." 1 ] in
@ -140,8 +155,8 @@ type expected_output_descr = {
let catala_suffix_regex = Re.Pcre.regexp "\\.catala_(\\w){2}" let catala_suffix_regex = Re.Pcre.regexp "\\.catala_(\\w){2}"
let filename_to_expected_output_descr (output_dir : string) (filename : string) : let filename_to_expected_output_descr (output_dir : string) (filename : string)
expected_output_descr option = : expected_output_descr option =
let complete_filename = filename in let complete_filename = filename in
let first_extension = Filename.extension filename in let first_extension = Filename.extension filename in
let filename = Filename.remove_extension filename in let filename = Filename.remove_extension filename in
@ -166,16 +181,19 @@ let filename_to_expected_output_descr (output_dir : string) (filename : string)
| Some backend -> | Some backend ->
let second_extension = Filename.extension filename in let second_extension = Filename.extension filename in
let base_filename, scope = let base_filename, scope =
if Re.Pcre.pmatch ~rex:catala_suffix_regex second_extension then (filename, None) if Re.Pcre.pmatch ~rex:catala_suffix_regex second_extension then
(filename, None)
else else
let scope_name_regex = Re.Pcre.regexp "\\.(.+)" in let scope_name_regex = Re.Pcre.regexp "\\.(.+)" in
let scope_name = (Re.Pcre.extract ~rex:scope_name_regex second_extension).(1) in let scope_name =
(Re.Pcre.extract ~rex:scope_name_regex second_extension).(1)
in
(Filename.remove_extension filename, Some scope_name) (Filename.remove_extension filename, Some scope_name)
in in
Some { output_dir; complete_filename; base_filename; backend; scope } Some { output_dir; complete_filename; base_filename; backend; scope }
(** [readdir_sort dirname] returns the sorted subdirectories of [dirname] in an array or an empty (** [readdir_sort dirname] returns the sorted subdirectories of [dirname] in an
array if the [dirname] doesn't exist. *) array or an empty array if the [dirname] doesn't exist. *)
let readdir_sort (dirname : string) : string array = let readdir_sort (dirname : string) : string array =
try try
let dirs = Sys.readdir dirname in let dirs = Sys.readdir dirname in
@ -183,8 +201,8 @@ let readdir_sort (dirname : string) : string array =
dirs dirs
with Sys_error _ -> Array.make 0 "" with Sys_error _ -> Array.make 0 ""
(** Given a file, looks in the relative [output] directory if there are files with the same base (** Given a file, looks in the relative [output] directory if there are files
name that contain expected outputs for different *) with the same base name that contain expected outputs for different *)
let search_for_expected_outputs (file : string) : expected_output_descr list = let search_for_expected_outputs (file : string) : expected_output_descr list =
let output_dir = Filename.dirname file ^ Filename.dir_sep ^ "output/" in let output_dir = Filename.dirname file ^ Filename.dir_sep ^ "output/" in
let output_files = readdir_sort output_dir in let output_files = readdir_sort output_dir in
@ -193,13 +211,17 @@ let search_for_expected_outputs (file : string) : expected_output_descr list =
match filename_to_expected_output_descr output_dir output_file with match filename_to_expected_output_descr output_dir output_file with
| None -> None | None -> None
| Some expected_output -> | Some expected_output ->
if expected_output.base_filename = Filename.basename file then Some expected_output if expected_output.base_filename = Filename.basename file then
Some expected_output
else None) else None)
(Array.to_list output_files) (Array.to_list output_files)
let add_reset_rules_aux ~(redirect : string) ~(with_scope_output_rule : string) let add_reset_rules_aux
~(without_scope_output_rule : string) (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : ~(redirect : string)
Rule.t Nj.RuleMap.t = ~(with_scope_output_rule : string)
~(without_scope_output_rule : string)
(catala_exe_opts : string)
(rules : Rule.t Nj.RuleMap.t) : Rule.t Nj.RuleMap.t =
let reset_common_cmd_exprs = let reset_common_cmd_exprs =
Nj.Expr. Nj.Expr.
[ [
@ -215,7 +237,10 @@ let add_reset_rules_aux ~(redirect : string) ~(with_scope_output_rule : string)
let reset_with_scope_rule = let reset_with_scope_rule =
Nj.Rule.make with_scope_output_rule Nj.Rule.make with_scope_output_rule
~command: ~command:
Nj.Expr.(Seq ([ Lit catala_exe_opts; Lit "-s"; Var "scope" ] @ reset_common_cmd_exprs)) Nj.Expr.(
Seq
([ Lit catala_exe_opts; Lit "-s"; Var "scope" ]
@ reset_common_cmd_exprs))
~description: ~description:
Nj.Expr.( Nj.Expr.(
Seq Seq
@ -249,13 +274,19 @@ let add_reset_rules_aux ~(redirect : string) ~(with_scope_output_rule : string)
|> add reset_with_scope_rule.name reset_with_scope_rule |> add reset_with_scope_rule.name reset_with_scope_rule
|> add reset_without_scope_rule.name reset_without_scope_rule) |> add reset_without_scope_rule.name reset_without_scope_rule)
let add_test_rules_aux ~(test_common_cmd_exprs : Expr.t list) ~(with_scope_output_rule : string) let add_test_rules_aux
~(without_scope_output_rule : string) (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : ~(test_common_cmd_exprs : Expr.t list)
Rule.t Nj.RuleMap.t = ~(with_scope_output_rule : string)
~(without_scope_output_rule : string)
(catala_exe_opts : string)
(rules : Rule.t Nj.RuleMap.t) : Rule.t Nj.RuleMap.t =
let test_with_scope_rule = let test_with_scope_rule =
Nj.Rule.make with_scope_output_rule Nj.Rule.make with_scope_output_rule
~command: ~command:
Nj.Expr.(Seq ([ Lit catala_exe_opts; Lit "-s"; Var "scope" ] @ test_common_cmd_exprs)) Nj.Expr.(
Seq
([ Lit catala_exe_opts; Lit "-s"; Var "scope" ]
@ test_common_cmd_exprs))
~description: ~description:
Nj.Expr.( Nj.Expr.(
Seq Seq
@ -276,7 +307,11 @@ let add_test_rules_aux ~(test_common_cmd_exprs : Expr.t list) ~(with_scope_outpu
Nj.Expr.( Nj.Expr.(
Seq Seq
[ [
Lit "TEST on file"; Var "tested_file"; Lit "with the"; Var "catala_cmd"; Lit "command"; Lit "TEST on file";
Var "tested_file";
Lit "with the";
Var "catala_cmd";
Lit "command";
]) ])
in in
Nj.RuleMap.( Nj.RuleMap.(
@ -284,15 +319,18 @@ let add_test_rules_aux ~(test_common_cmd_exprs : Expr.t list) ~(with_scope_outpu
|> add test_with_scope_rule.name test_with_scope_rule |> add test_with_scope_rule.name test_with_scope_rule
|> add test_without_scope_rule.name test_without_scope_rule) |> add test_without_scope_rule.name test_without_scope_rule)
(** [add_reset_rules catala_exe_opts rules] adds ninja rules used to reset test files into [rules] (** [add_reset_rules catala_exe_opts rules] adds ninja rules used to reset test
and returns it.*) files into [rules] and returns it.*)
let add_reset_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : Rule.t Nj.RuleMap.t = let add_reset_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) :
Rule.t Nj.RuleMap.t =
add_reset_rules_aux ~with_scope_output_rule:"reset_with_scope" add_reset_rules_aux ~with_scope_output_rule:"reset_with_scope"
~without_scope_output_rule:"reset_without_scope" ~redirect:">" catala_exe_opts rules ~without_scope_output_rule:"reset_without_scope" ~redirect:">"
catala_exe_opts rules
(** [add_test_rules catala_exe_opts rules] adds ninja rules used to test files into [rules] and (** [add_test_rules catala_exe_opts rules] adds ninja rules used to test files
returns it.*) into [rules] and returns it.*)
let add_test_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : Rule.t Nj.RuleMap.t = let add_test_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) :
Rule.t Nj.RuleMap.t =
let test_common_cmd_exprs = let test_common_cmd_exprs =
Nj.Expr. Nj.Expr.
[ [
@ -305,19 +343,23 @@ let add_test_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : Ru
Lit "-"; Lit "-";
] ]
in in
add_test_rules_aux ~test_common_cmd_exprs ~with_scope_output_rule:"test_with_scope" add_test_rules_aux ~test_common_cmd_exprs
~with_scope_output_rule:"test_with_scope"
~without_scope_output_rule:"test_without_scope" catala_exe_opts rules ~without_scope_output_rule:"test_without_scope" catala_exe_opts rules
(** [add_reset_with_ouput_rules catala_exe_opts rules] adds ninja rules used to reset test files (** [add_reset_with_ouput_rules catala_exe_opts rules] adds ninja rules used to
using an output flag into [rules] and returns it.*) reset test files using an output flag into [rules] and returns it.*)
let add_reset_with_output_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : let add_reset_with_output_rules
(catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) :
Rule.t Nj.RuleMap.t = Rule.t Nj.RuleMap.t =
add_reset_rules_aux ~with_scope_output_rule:"reset_with_scope_and_output" add_reset_rules_aux ~with_scope_output_rule:"reset_with_scope_and_output"
~without_scope_output_rule:"reset_without_scope_and_output" ~redirect:"-o" catala_exe_opts rules ~without_scope_output_rule:"reset_without_scope_and_output" ~redirect:"-o"
catala_exe_opts rules
(** [add_test_with_output_rules catala_exe_opts rules] adds ninja rules used to test files using an (** [add_test_with_output_rules catala_exe_opts rules] adds ninja rules used to
output flag into [rules] and returns it.*) test files using an output flag into [rules] and returns it.*)
let add_test_with_output_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : let add_test_with_output_rules
(catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) :
Rule.t Nj.RuleMap.t = Rule.t Nj.RuleMap.t =
let test_common_cmd_exprs = let test_common_cmd_exprs =
Nj.Expr. Nj.Expr.
@ -333,35 +375,44 @@ let add_test_with_output_rules (catala_exe_opts : string) (rules : Rule.t Nj.Rul
Var "tmp_file"; Var "tmp_file";
] ]
in in
add_test_rules_aux ~test_common_cmd_exprs ~with_scope_output_rule:"test_with_scope_and_output" add_test_rules_aux ~test_common_cmd_exprs
~without_scope_output_rule:"test_without_scope_and_output" catala_exe_opts rules ~with_scope_output_rule:"test_with_scope_and_output"
~without_scope_output_rule:"test_without_scope_and_output" catala_exe_opts
rules
(** [ninja_start catala_exe] returns the inital [ninja] data structure with rules needed to reset (** [ninja_start catala_exe] returns the inital [ninja] data structure with
and test files. *) rules needed to reset and test files. *)
let ninja_start (catala_exe : string) (catala_opts : string) : ninja = let ninja_start (catala_exe : string) (catala_opts : string) : ninja =
let catala_exe_opts = catala_exe ^ " " ^ catala_opts in let catala_exe_opts = catala_exe ^ " " ^ catala_opts in
let run_and_display_final_message = let run_and_display_final_message =
Nj.Rule.make "run_and_display_final_message" Nj.Rule.make "run_and_display_final_message"
~command:Nj.Expr.(Seq [ Lit ":" ]) ~command:Nj.Expr.(Seq [ Lit ":" ])
~description:Nj.Expr.(Seq [ Lit "All tests"; Var "test_file_or_folder"; Lit "passed!" ]) ~description:
Nj.Expr.(
Seq [ Lit "All tests"; Var "test_file_or_folder"; Lit "passed!" ])
in in
{ {
rules = rules =
Nj.RuleMap.( Nj.RuleMap.(
empty |> add_reset_rules catala_exe_opts |> add_test_rules catala_exe_opts empty
|> add_reset_rules catala_exe_opts
|> add_test_rules catala_exe_opts
|> add_test_with_output_rules catala_exe_opts |> add_test_with_output_rules catala_exe_opts
|> add_reset_with_output_rules catala_exe_opts |> add_reset_with_output_rules catala_exe_opts
|> add run_and_display_final_message.name run_and_display_final_message); |> add run_and_display_final_message.name run_and_display_final_message);
builds = Nj.BuildMap.empty; builds = Nj.BuildMap.empty;
} }
(** [collect_all_ninja_build ninja tested_file catala_exe catala_opts reset_test_outputs] creates (** [collect_all_ninja_build ninja tested_file catala_exe catala_opts reset_test_outputs]
and returns all ninja build statements needed to test the [tested_file]. *) creates and returns all ninja build statements needed to test the
let collect_all_ninja_build (ninja : ninja) (tested_file : string) (reset_test_outputs : bool) : [tested_file]. *)
let collect_all_ninja_build
(ninja : ninja) (tested_file : string) (reset_test_outputs : bool) :
(string * ninja) option = (string * ninja) option =
let expected_outputs = search_for_expected_outputs tested_file in let expected_outputs = search_for_expected_outputs tested_file in
if List.length expected_outputs = 0 then ( if List.length expected_outputs = 0 then (
Cli.debug_print "No expected outputs were found for test file %s" tested_file; Cli.debug_print "No expected outputs were found for test file %s"
tested_file;
None) None)
else else
let ninja, test_names = let ninja, test_names =
@ -369,58 +420,80 @@ let collect_all_ninja_build (ninja : ninja) (tested_file : string) (reset_test_o
(fun (ninja, test_names) expected_output -> (fun (ninja, test_names) expected_output ->
let vars = let vars =
[ [
("catala_cmd", Nj.Expr.Lit (catala_backend_to_string expected_output.backend)); ( "catala_cmd",
Nj.Expr.Lit (catala_backend_to_string expected_output.backend)
);
("tested_file", Nj.Expr.Lit tested_file); ("tested_file", Nj.Expr.Lit tested_file);
( "expected_output", ( "expected_output",
Nj.Expr.Lit (expected_output.output_dir ^ expected_output.complete_filename) ); Nj.Expr.Lit
(expected_output.output_dir
^ expected_output.complete_filename) );
] ]
in in
let output_build_kind = if reset_test_outputs then "reset" else "test" in let output_build_kind =
let catala_backend = catala_backend_to_string expected_output.backend in if reset_test_outputs then "reset" else "test"
in
let catala_backend =
catala_backend_to_string expected_output.backend
in
let get_rule_infos ?(rule_postfix = "") : let get_rule_infos ?(rule_postfix = "") :
string option -> string * string * (string * Nj.Expr.t) list = function string option -> string * string * (string * Nj.Expr.t) list =
function
| Some scope -> | Some scope ->
( Printf.sprintf "%s_%s_%s_%s" output_build_kind scope catala_backend tested_file ( Printf.sprintf "%s_%s_%s_%s" output_build_kind scope
catala_backend tested_file
|> Nj.Build.unpath, |> Nj.Build.unpath,
output_build_kind ^ "_with_scope" ^ rule_postfix, output_build_kind ^ "_with_scope" ^ rule_postfix,
("scope", Nj.Expr.Lit scope) :: vars ) ("scope", Nj.Expr.Lit scope) :: vars )
| None -> | None ->
( Printf.sprintf "%s_%s_%s" output_build_kind catala_backend tested_file ( Printf.sprintf "%s_%s_%s" output_build_kind catala_backend
tested_file
|> Nj.Build.unpath, |> Nj.Build.unpath,
output_build_kind ^ "_without_scope" ^ rule_postfix, output_build_kind ^ "_without_scope" ^ rule_postfix,
vars ) vars )
in in
let ninja_add_new_rule (rule_output : string) (rule : string) let ninja_add_new_rule
(vars : (string * Nj.Expr.t) list) (ninja : ninja) : ninja = (rule_output : string)
(rule : string)
(vars : (string * Nj.Expr.t) list)
(ninja : ninja) : ninja =
{ {
ninja with ninja with
builds = builds =
Nj.BuildMap.add rule_output Nj.BuildMap.add rule_output
(Nj.Build.make_with_vars ~outputs:[ Nj.Expr.Lit rule_output ] ~rule ~vars) (Nj.Build.make_with_vars
~outputs:[ Nj.Expr.Lit rule_output ]
~rule ~vars)
ninja.builds; ninja.builds;
} }
in in
match expected_output.backend with match expected_output.backend with
| Cli.Interpret | Cli.Proof | Cli.Typecheck | Cli.Dcalc | Cli.Scopelang | Cli.Scalc | Cli.Interpret | Cli.Proof | Cli.Typecheck | Cli.Dcalc
| Cli.Lcalc -> | Cli.Scopelang | Cli.Scalc | Cli.Lcalc ->
let rule_output, rule_name, rule_vars = get_rule_infos expected_output.scope in let rule_output, rule_name, rule_vars =
get_rule_infos expected_output.scope
in
let rule_vars = let rule_vars =
match expected_output.backend with match expected_output.backend with
| Cli.Proof -> | Cli.Proof ->
("extra_flags", Nj.Expr.Lit "--disable_counterexamples") :: rule_vars ("extra_flags", Nj.Expr.Lit "--disable_counterexamples")
(* Counterexamples can be different at each call because of the randomness :: rule_vars
inside SMT solver, so we can't expect their value to remain constant. Hence (* Counterexamples can be different at each call because of
we disable the counterexamples when testing the replication of failed the randomness inside SMT solver, so we can't expect
their value to remain constant. Hence we disable the
counterexamples when testing the replication of failed
proofs. *) proofs. *)
| _ -> rule_vars | _ -> rule_vars
in in
( ninja_add_new_rule rule_output rule_name rule_vars ninja, ( ninja_add_new_rule rule_output rule_name rule_vars ninja,
test_names ^ " $\n " ^ rule_output ) test_names ^ " $\n " ^ rule_output )
| Cli.Python | Cli.OCaml | Cli.Latex | Cli.Html | Cli.Makefile -> | Cli.Python | Cli.OCaml | Cli.Latex | Cli.Html | Cli.Makefile ->
let tmp_file = Filename.temp_file "clerk_" ("_" ^ catala_backend) in let tmp_file =
Filename.temp_file "clerk_" ("_" ^ catala_backend)
in
let rule_output, rule_name, rule_vars = let rule_output, rule_name, rule_vars =
get_rule_infos ~rule_postfix:"_and_output" expected_output.scope get_rule_infos ~rule_postfix:"_and_output" expected_output.scope
in in
@ -441,36 +514,50 @@ let collect_all_ninja_build (ninja : ninja) (tested_file : string) (reset_test_o
ninja with ninja with
builds = builds =
Nj.BuildMap.add test_name Nj.BuildMap.add test_name
(Nj.Build.make_with_inputs ~outputs:[ Nj.Expr.Lit test_name ] ~rule:"phony" (Nj.Build.make_with_inputs ~outputs:[ Nj.Expr.Lit test_name ]
~inputs:[ Nj.Expr.Lit test_names ]) ~rule:"phony" ~inputs:[ Nj.Expr.Lit test_names ])
ninja.builds; ninja.builds;
} ) } )
(** [add_root_test_build ninja all_file_names all_test_builds] add the 'test' ninja build (** [add_root_test_build ninja all_file_names all_test_builds] add the 'test'
declaration calling the rule 'run_and_display_final_message' for [all_test_builds] which ninja build declaration calling the rule 'run_and_display_final_message' for
correspond to [all_file_names]. *) [all_test_builds] which correspond to [all_file_names]. *)
let add_root_test_build (ninja : ninja) (all_file_names : string list) (all_test_builds : string) : let add_root_test_build
(ninja : ninja) (all_file_names : string list) (all_test_builds : string) :
ninja = ninja =
let file_names_str = let file_names_str =
List.hd all_file_names ^ "" List.hd all_file_names ^ ""
^ List.fold_left (fun acc name -> acc ^ "; " ^ name) "" (List.tl all_file_names) ^ List.fold_left
(fun acc name -> acc ^ "; " ^ name)
"" (List.tl all_file_names)
in in
{ {
ninja with ninja with
builds = builds =
Nj.BuildMap.add "test" Nj.BuildMap.add "test"
(Nj.Build.make_with_vars_and_inputs ~outputs:[ Nj.Expr.Lit "test" ] (Nj.Build.make_with_vars_and_inputs ~outputs:[ Nj.Expr.Lit "test" ]
~rule:"run_and_display_final_message" ~inputs:[ Nj.Expr.Lit all_test_builds ] ~rule:"run_and_display_final_message"
~vars:[ ("test_file_or_folder", Nj.Expr.Lit ("in [ " ^ file_names_str ^ " ]")) ]) ~inputs:[ Nj.Expr.Lit all_test_builds ]
~vars:
[
( "test_file_or_folder",
Nj.Expr.Lit ("in [ " ^ file_names_str ^ " ]") );
])
ninja.builds; ninja.builds;
} }
(**{1 Running}*) (**{1 Running}*)
let run_file (file : string) (catala_exe : string) (catala_opts : string) (scope : string) : int = let run_file
(file : string)
(catala_exe : string)
(catala_opts : string)
(scope : string) : int =
let command = let command =
String.concat " " String.concat " "
(List.filter (fun s -> s <> "") [ catala_exe; catala_opts; "-s " ^ scope; "Interpret"; file ]) (List.filter
(fun s -> s <> "")
[ catala_exe; catala_opts; "-s " ^ scope; "Interpret"; file ])
in in
Cli.debug_print "Running: %s" command; Cli.debug_print "Running: %s" command;
Sys.command command Sys.command command
@ -503,11 +590,13 @@ type ninja_building_context = {
all_test_builds : string; all_test_builds : string;
all_failed_names : string list; all_failed_names : string list;
} }
(** Record used to keep tracks of the current context while building the [Ninja_utils.ninja].*) (** Record used to keep tracks of the current context while building the
[Ninja_utils.ninja].*)
(** [ninja_building_context_init ninja_init] returns the empty context corresponding to (** [ninja_building_context_init ninja_init] returns the empty context
[ninja_init]. *) corresponding to [ninja_init]. *)
let ninja_building_context_init (ninja_init : Nj.ninja) : ninja_building_context = let ninja_building_context_init (ninja_init : Nj.ninja) : ninja_building_context
=
{ {
last_valid_ninja = ninja_init; last_valid_ninja = ninja_init;
curr_ninja = Some ninja_init; curr_ninja = Some ninja_init;
@ -516,9 +605,13 @@ let ninja_building_context_init (ninja_init : Nj.ninja) : ninja_building_context
all_failed_names = []; all_failed_names = [];
} }
(** [collect_in_directory ctx file_or_folder ninja_start reset_test_outputs] updates the building (** [collect_in_directory ctx file_or_folder ninja_start reset_test_outputs]
context [ctx] by adding new ninja build statements needed to test files in [folder].*) updates the building context [ctx] by adding new ninja build statements
let collect_in_folder (ctx : ninja_building_context) (folder : string) (ninja_start : Nj.ninja) needed to test files in [folder].*)
let collect_in_folder
(ctx : ninja_building_context)
(folder : string)
(ninja_start : Nj.ninja)
(reset_test_outputs : bool) : ninja_building_context = (reset_test_outputs : bool) : ninja_building_context =
let ninja, test_file_names = let ninja, test_file_names =
List.fold_left List.fold_left
@ -527,11 +620,14 @@ let collect_in_folder (ctx : ninja_building_context) (folder : string) (ninja_st
| None -> | None ->
(* Skips none Catala file. *) (* Skips none Catala file. *)
(ninja, test_file_names) (ninja, test_file_names)
| Some (test_file_name, ninja) -> (ninja, test_file_names ^ " $\n " ^ test_file_name)) | Some (test_file_name, ninja) ->
(ninja, test_file_names ^ " $\n " ^ test_file_name))
(ninja_start, "") (ninja_start, "")
(get_catala_files_in_folder folder) (get_catala_files_in_folder folder)
in in
let test_dir_name = Printf.sprintf "test_dir_%s" (folder |> Nj.Build.unpath) in let test_dir_name =
Printf.sprintf "test_dir_%s" (folder |> Nj.Build.unpath)
in
let curr_ninja = let curr_ninja =
if 0 = String.length test_file_names then None if 0 = String.length test_file_names then None
else else
@ -540,9 +636,15 @@ let collect_in_folder (ctx : ninja_building_context) (folder : string) (ninja_st
ninja with ninja with
builds = builds =
Nj.BuildMap.add test_dir_name Nj.BuildMap.add test_dir_name
(Nj.Build.make_with_vars_and_inputs ~outputs:[ Nj.Expr.Lit test_dir_name ] (Nj.Build.make_with_vars_and_inputs
~rule:"run_and_display_final_message" ~inputs:[ Nj.Expr.Lit test_file_names ] ~outputs:[ Nj.Expr.Lit test_dir_name ]
~vars:[ ("test_file_or_folder", Nj.Expr.Lit ("in folder '" ^ folder ^ "'")) ]) ~rule:"run_and_display_final_message"
~inputs:[ Nj.Expr.Lit test_file_names ]
~vars:
[
( "test_file_or_folder",
Nj.Expr.Lit ("in folder '" ^ folder ^ "'") );
])
ninja.builds; ninja.builds;
} }
in in
@ -562,9 +664,13 @@ let collect_in_folder (ctx : ninja_building_context) (folder : string) (ninja_st
all_failed_names = folder :: ctx.all_failed_names; all_failed_names = folder :: ctx.all_failed_names;
} }
(** [collect_in_file ctx file_or_folder ninja_start reset_test_outputs] updates the building context (** [collect_in_file ctx file_or_folder ninja_start reset_test_outputs] updates
[ctx] by adding new ninja build statements needed to test the [tested_file].*) the building context [ctx] by adding new ninja build statements needed to
let collect_in_file (ctx : ninja_building_context) (tested_file : string) (ninja_start : Nj.ninja) test the [tested_file].*)
let collect_in_file
(ctx : ninja_building_context)
(tested_file : string)
(ninja_start : Nj.ninja)
(reset_test_outputs : bool) : ninja_building_context = (reset_test_outputs : bool) : ninja_building_context =
match collect_all_ninja_build ninja_start tested_file reset_test_outputs with match collect_all_ninja_build ninja_start tested_file reset_test_outputs with
| Some (test_file_name, ninja) -> | Some (test_file_name, ninja) ->
@ -586,35 +692,47 @@ let collect_in_file (ctx : ninja_building_context) (tested_file : string) (ninja
(** {1 Return code values} *) (** {1 Return code values} *)
let return_ok = 0 let return_ok = 0
let return_err = 1 let return_err = 1
(** {1 Driver} *) (** {1 Driver} *)
(** [add_root_test_build ctx files_or_folders reset_test_outputs] updates the [ctx] by adding ninja (** [add_root_test_build ctx files_or_folders reset_test_outputs] updates the
build statements needed to test or [reset_test_outputs] [files_or_folders]. *) [ctx] by adding ninja build statements needed to test or
let add_test_builds (ctx : ninja_building_context) (files_or_folders : string list) [reset_test_outputs] [files_or_folders]. *)
let add_test_builds
(ctx : ninja_building_context)
(files_or_folders : string list)
(reset_test_outputs : bool) : ninja_building_context = (reset_test_outputs : bool) : ninja_building_context =
files_or_folders files_or_folders
|> List.fold_left |> List.fold_left
(fun ctx file_or_folder -> (fun ctx file_or_folder ->
let curr_ninja = let curr_ninja =
match ctx.curr_ninja with Some ninja -> ninja | None -> ctx.last_valid_ninja match ctx.curr_ninja with
| Some ninja -> ninja
| None -> ctx.last_valid_ninja
in in
if Sys.is_directory file_or_folder then if Sys.is_directory file_or_folder then
collect_in_folder ctx file_or_folder curr_ninja reset_test_outputs collect_in_folder ctx file_or_folder curr_ninja reset_test_outputs
else collect_in_file ctx file_or_folder curr_ninja reset_test_outputs) else collect_in_file ctx file_or_folder curr_ninja reset_test_outputs)
ctx ctx
let driver (files_or_folders : string list) (command : string) (catala_exe : string option) let driver
(catala_opts : string option) (debug : bool) (scope : string option) (reset_test_outputs : bool) (files_or_folders : string list)
(command : string)
(catala_exe : string option)
(catala_opts : string option)
(debug : bool)
(scope : string option)
(reset_test_outputs : bool)
(ninja_output : string option) : int = (ninja_output : string option) : int =
if debug then Cli.debug_flag := true; if debug then Cli.debug_flag := true;
let files_or_folders = List.sort_uniq String.compare files_or_folders let files_or_folders = List.sort_uniq String.compare files_or_folders
and catala_exe = Option.fold ~none:"catala" ~some:Fun.id catala_exe and catala_exe = Option.fold ~none:"catala" ~some:Fun.id catala_exe
and catala_opts = Option.fold ~none:"" ~some:Fun.id catala_opts and catala_opts = Option.fold ~none:"" ~some:Fun.id catala_opts
and ninja_output = and ninja_output =
Option.fold ~none:(Filename.temp_file "clerk_build" ".ninja") ~some:Fun.id ninja_output Option.fold
~none:(Filename.temp_file "clerk_build" ".ninja")
~some:Fun.id ninja_output
in in
match String.lowercase_ascii command with match String.lowercase_ascii command with
| "test" -> ( | "test" -> (
@ -625,7 +743,11 @@ let driver (files_or_folders : string list) (command : string) (catala_exe : str
files_or_folders reset_test_outputs files_or_folders reset_test_outputs
in in
let there_is_some_fails = 0 <> List.length ctx.all_failed_names in let there_is_some_fails = 0 <> List.length ctx.all_failed_names in
let ninja = match ctx.curr_ninja with Some ninja -> ninja | None -> ctx.last_valid_ninja in let ninja =
match ctx.curr_ninja with
| Some ninja -> ninja
| None -> ctx.last_valid_ninja
in
if there_is_some_fails then if there_is_some_fails then
List.iter List.iter
(fun f -> (fun f ->
@ -633,7 +755,8 @@ let driver (files_or_folders : string list) (command : string) (catala_exe : str
|> Cli.with_style [ ANSITerminal.magenta ] "%s" |> Cli.with_style [ ANSITerminal.magenta ] "%s"
|> Cli.warning_print "No test case found for %s") |> Cli.warning_print "No test case found for %s")
ctx.all_failed_names; ctx.all_failed_names;
if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then return_ok if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then
return_ok
else else
try try
let out = open_out ninja_output in let out = open_out ninja_output in

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala build system, a specification language for tax and social (* This file is part of the Catala build system, a specification language for
benefits computation rules. Copyright (C) 2020 Inria, contributor: Emile Rolley tax and social benefits computation rules. Copyright (C) 2020 Inria,
<emile.rolley@tuta.io> contributor: Emile Rolley <emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module Expr = struct module Expr = struct
@ -23,7 +25,8 @@ module Expr = struct
and format_list fmt = function and format_list fmt = function
| hd :: tl -> | hd :: tl ->
Format.fprintf fmt "%a%a" format hd Format.fprintf fmt "%a%a" format hd
(fun fmt tl -> tl |> List.iter (fun s -> Format.fprintf fmt " %a" format s)) (fun fmt tl ->
tl |> List.iter (fun s -> Format.fprintf fmt " %a" format s))
tl tl
| [] -> () | [] -> ()
end end
@ -31,15 +34,16 @@ end
module Rule = struct module Rule = struct
type t = { name : string; command : Expr.t; description : Expr.t option } type t = { name : string; command : Expr.t; description : Expr.t option }
let make name ~command ~description = { name; command; description = Option.some description } let make name ~command ~description =
{ name; command; description = Option.some description }
let format fmt rule = let format fmt rule =
let format_description fmt = function let format_description fmt = function
| Some e -> Format.fprintf fmt " description = %a\n" Expr.format e | Some e -> Format.fprintf fmt " description = %a\n" Expr.format e
| None -> Format.fprintf fmt "\n" | None -> Format.fprintf fmt "\n"
in in
Format.fprintf fmt "rule %s\n command = %a\n%a" rule.name Expr.format rule.command Format.fprintf fmt "rule %s\n command = %a\n%a" rule.name Expr.format
format_description rule.description rule.command format_description rule.description
end end
module Build = struct module Build = struct
@ -52,7 +56,8 @@ module Build = struct
let make ~outputs ~rule = { outputs; rule; inputs = Option.none; vars = [] } let make ~outputs ~rule = { outputs; rule; inputs = Option.none; vars = [] }
let make_with_vars ~outputs ~rule ~vars = { outputs; rule; inputs = Option.none; vars } let make_with_vars ~outputs ~rule ~vars =
{ outputs; rule; inputs = Option.none; vars }
let make_with_inputs ~outputs ~rule ~inputs = let make_with_inputs ~outputs ~rule ~inputs =
{ outputs; rule; inputs = Option.some inputs; vars = [] } { outputs; rule; inputs = Option.some inputs; vars = [] }
@ -62,21 +67,24 @@ module Build = struct
let empty = make ~outputs:[ Expr.Lit "empty" ] ~rule:"phony" let empty = make ~outputs:[ Expr.Lit "empty" ] ~rule:"phony"
let unpath ?(sep = "-") path = Re.Pcre.(substitute ~rex:(regexp "/") ~subst:(fun _ -> sep)) path let unpath ?(sep = "-") path =
Re.Pcre.(substitute ~rex:(regexp "/") ~subst:(fun _ -> sep)) path
let format fmt build = let format fmt build =
let format_inputs fmt = function let format_inputs fmt = function
| Some exs -> Format.fprintf fmt " %a" Expr.format_list exs | Some exs -> Format.fprintf fmt " %a" Expr.format_list exs
| None -> () | None -> ()
and format_vars fmt vars = and format_vars fmt vars =
List.iter (fun (name, exp) -> Format.fprintf fmt " %s = %a\n" name Expr.format exp) vars List.iter
(fun (name, exp) ->
Format.fprintf fmt " %s = %a\n" name Expr.format exp)
vars
in in
Format.fprintf fmt "build %a: %s%a\n%a" Expr.format_list build.outputs build.rule format_inputs Format.fprintf fmt "build %a: %s%a\n%a" Expr.format_list build.outputs
build.inputs format_vars build.vars build.rule format_inputs build.inputs format_vars build.vars
end end
module RuleMap : Map.S with type key = String.t = Map.Make (String) module RuleMap : Map.S with type key = String.t = Map.Make (String)
module BuildMap : Map.S with type key = String.t = Map.Make (String) module BuildMap : Map.S with type key = String.t = Map.Make (String)
type ninja = { rules : Rule.t RuleMap.t; builds : Build.t BuildMap.t } type ninja = { rules : Rule.t RuleMap.t; builds : Build.t BuildMap.t }
@ -84,6 +92,8 @@ type ninja = { rules : Rule.t RuleMap.t; builds : Build.t BuildMap.t }
let empty = { rules = RuleMap.empty; builds = BuildMap.empty } let empty = { rules = RuleMap.empty; builds = BuildMap.empty }
let format fmt ninja = let format fmt ninja =
let format_for_all iter format = iter (fun _name rule -> Format.fprintf fmt "%a\n" format rule) in let format_for_all iter format =
iter (fun _name rule -> Format.fprintf fmt "%a\n" format rule)
in
format_for_all RuleMap.iter Rule.format ninja.rules; format_for_all RuleMap.iter Rule.format ninja.rules;
format_for_all BuildMap.iter Build.format ninja.builds format_for_all BuildMap.iter Build.format ninja.builds

View File

@ -4,15 +4,12 @@ module Nj = Ninja_utils
module To_test = struct module To_test = struct
let ninja_start = D.ninja_start let ninja_start = D.ninja_start
let add_test_builds = D.add_test_builds let add_test_builds = D.add_test_builds
end end
(* cwd: _build/default/build_system/tests/ *) (* cwd: _build/default/build_system/tests/ *)
let test_files_dir = "../../../../build_system/tests/catala_files/" let test_files_dir = "../../../../build_system/tests/catala_files/"
let ninja_start = To_test.ninja_start "catala" "" let ninja_start = To_test.ninja_start "catala" ""
let al_assert msg = Al.(check bool) msg true let al_assert msg = Al.(check bool) msg true
let test_ninja_start () = let test_ninja_start () =
@ -22,12 +19,16 @@ let test_ninja_start () =
"rule reset_with_scope\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled > $expected_output 2>&1\n description = RESET scope $scope of file $tested_file with the $catala_cmd command\n\nrule reset_with_scope_and_output\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled -o $expected_output 2>&1\n description = RESET scope $scope of file $tested_file with the $catala_cmd command\n\nrule reset_without_scope\n command = catala $catala_cmd $tested_file $extra_flags --unstyled > $expected_output 2>&1\n description = RESET file $tested_file with the $catala_cmd command\n\nrule reset_without_scope_and_output\n command = catala $catala_cmd $tested_file $extra_flags --unstyled -o $expected_output 2>&1\n description = RESET file $tested_file with the $catala_cmd command\n\nrule run_and_display_final_message\n command = :\n description = All tests $test_file_or_folder passed!\n\nrule test_with_scope\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled 2>&1 | colordiff -u -b $expected_output -\n description = TEST scope $scope of file $tested_file with the $catala_cmd command\n\nrule test_with_scope_and_output\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled -o $tmp_file ; colordiff -u -b $expected_output $tmp_file\n description = TEST scope $scope of file $tested_file with the $catala_cmd command\n\nrule test_without_scope\n command = catala $catala_cmd $tested_file $extra_flags --unstyled 2>&1 | colordiff -u -b $expected_output -\n description = TEST on file $tested_file with the $catala_cmd command\n\nrule test_without_scope_and_output\n command = catala $catala_cmd $tested_file $extra_flags --unstyled -o $tmp_file ; colordiff -u -b $expected_output $tmp_file\n description = TEST on file $tested_file with the $catala_cmd command\n\n"[@ocamlformat "disable"] "rule reset_with_scope\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled > $expected_output 2>&1\n description = RESET scope $scope of file $tested_file with the $catala_cmd command\n\nrule reset_with_scope_and_output\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled -o $expected_output 2>&1\n description = RESET scope $scope of file $tested_file with the $catala_cmd command\n\nrule reset_without_scope\n command = catala $catala_cmd $tested_file $extra_flags --unstyled > $expected_output 2>&1\n description = RESET file $tested_file with the $catala_cmd command\n\nrule reset_without_scope_and_output\n command = catala $catala_cmd $tested_file $extra_flags --unstyled -o $expected_output 2>&1\n description = RESET file $tested_file with the $catala_cmd command\n\nrule run_and_display_final_message\n command = :\n description = All tests $test_file_or_folder passed!\n\nrule test_with_scope\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled 2>&1 | colordiff -u -b $expected_output -\n description = TEST scope $scope of file $tested_file with the $catala_cmd command\n\nrule test_with_scope_and_output\n command = catala -s $scope $catala_cmd $tested_file $extra_flags --unstyled -o $tmp_file ; colordiff -u -b $expected_output $tmp_file\n description = TEST scope $scope of file $tested_file with the $catala_cmd command\n\nrule test_without_scope\n command = catala $catala_cmd $tested_file $extra_flags --unstyled 2>&1 | colordiff -u -b $expected_output -\n description = TEST on file $tested_file with the $catala_cmd command\n\nrule test_without_scope_and_output\n command = catala $catala_cmd $tested_file $extra_flags --unstyled -o $tmp_file ; colordiff -u -b $expected_output $tmp_file\n description = TEST on file $tested_file with the $catala_cmd command\n\n"[@ocamlformat "disable"]
in in
let actual_format = Buffer.contents Format.stdbuf in let actual_format = Buffer.contents Format.stdbuf in
Al.(check string) "both formated strings should equal" expected_format actual_format Al.(check string)
"both formated strings should equal" expected_format actual_format
let test_add_test_builds_for_folder () = let test_add_test_builds_for_folder () =
let ctx = D.ninja_building_context_init ninja_start in let ctx = D.ninja_building_context_init ninja_start in
let nj_building_ctx = To_test.add_test_builds ctx [ test_files_dir ^ "folder" ] false in let nj_building_ctx =
al_assert "a test case should be found" (Option.is_some nj_building_ctx.curr_ninja); To_test.add_test_builds ctx [ test_files_dir ^ "folder" ] false
in
al_assert "a test case should be found"
(Option.is_some nj_building_ctx.curr_ninja);
let expected_format = let expected_format =
"build test_A_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en: test_with_scope\n scope = A\n catala_cmd = Interpret\n tested_file = ../../../../build_system/tests/catala_files/folder/file1.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file1.catala_en.A.Interpret\nbuild test_B_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en: test_with_scope\n scope = B\n catala_cmd = Interpret\n tested_file = ../../../../build_system/tests/catala_files/folder/file1.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file1.catala_en.B.Interpret\nbuild test_Proof_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en: test_without_scope\n extra_flags = --disable_counterexamples\n catala_cmd = Proof\n tested_file = ../../../../build_system/tests/catala_files/folder/file3.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file3.catala_en.Proof\nbuild test_Typecheck_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en: test_without_scope\n catala_cmd = Typecheck\n tested_file = ../../../../build_system/tests/catala_files/folder/file2.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file2.catala_en.Typecheck\nbuild test_dir_..-..-..-..-build_system-tests-catala_files-folder: run_and_display_final_message $\n test_file_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en $\n test_file_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en $\n test_file_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en\n test_file_or_folder = in folder '../../../../build_system/tests/catala_files/folder'\nbuild test_file_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en: phony $\n test_A_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en $\n test_B_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en\nbuild test_file_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en: phony $\n test_Typecheck_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en\nbuild test_file_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en: phony $\n test_Proof_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en\n"[@ocamlformat "disable"] "build test_A_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en: test_with_scope\n scope = A\n catala_cmd = Interpret\n tested_file = ../../../../build_system/tests/catala_files/folder/file1.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file1.catala_en.A.Interpret\nbuild test_B_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en: test_with_scope\n scope = B\n catala_cmd = Interpret\n tested_file = ../../../../build_system/tests/catala_files/folder/file1.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file1.catala_en.B.Interpret\nbuild test_Proof_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en: test_without_scope\n extra_flags = --disable_counterexamples\n catala_cmd = Proof\n tested_file = ../../../../build_system/tests/catala_files/folder/file3.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file3.catala_en.Proof\nbuild test_Typecheck_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en: test_without_scope\n catala_cmd = Typecheck\n tested_file = ../../../../build_system/tests/catala_files/folder/file2.catala_en\n expected_output = ../../../../build_system/tests/catala_files/folder/output/file2.catala_en.Typecheck\nbuild test_dir_..-..-..-..-build_system-tests-catala_files-folder: run_and_display_final_message $\n test_file_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en $\n test_file_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en $\n test_file_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en\n test_file_or_folder = in folder '../../../../build_system/tests/catala_files/folder'\nbuild test_file_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en: phony $\n test_A_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en $\n test_B_Interpret_..-..-..-..-build_system-tests-catala_files-folder-file1.catala_en\nbuild test_file_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en: phony $\n test_Typecheck_..-..-..-..-build_system-tests-catala_files-folder-file2.catala_en\nbuild test_file_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en: phony $\n test_Proof_..-..-..-..-build_system-tests-catala_files-folder-file3.catala_en\n"[@ocamlformat "disable"]
@ -35,44 +36,63 @@ let test_add_test_builds_for_folder () =
let actual_format = let actual_format =
let ninja = Option.get nj_building_ctx.curr_ninja in let ninja = Option.get nj_building_ctx.curr_ninja in
Buffer.clear Format.stdbuf; Buffer.clear Format.stdbuf;
Nj.BuildMap.iter (fun _ b -> Nj.Build.format Format.str_formatter b) ninja.builds; Nj.BuildMap.iter
(fun _ b -> Nj.Build.format Format.str_formatter b)
ninja.builds;
Buffer.contents Format.stdbuf Buffer.contents Format.stdbuf
in in
Al.(check string) "both formated strings should equal" expected_format actual_format Al.(check string)
"both formated strings should equal" expected_format actual_format
let test_add_test_builds_for_untested_file () = let test_add_test_builds_for_untested_file () =
let untested_file = test_files_dir ^ "untested_file.catala_en" in let untested_file = test_files_dir ^ "untested_file.catala_en" in
let ctx = D.ninja_building_context_init Nj.empty in let ctx = D.ninja_building_context_init Nj.empty in
let nj_building_ctx = To_test.add_test_builds ctx [ untested_file ] false in let nj_building_ctx = To_test.add_test_builds ctx [ untested_file ] false in
al_assert "no test cases should be found" (Option.is_none nj_building_ctx.curr_ninja); al_assert "no test cases should be found"
(Option.is_none nj_building_ctx.curr_ninja);
al_assert "ninja_start should be the last valid ninja" al_assert "ninja_start should be the last valid ninja"
(Nj.empty = nj_building_ctx.last_valid_ninja) (Nj.empty = nj_building_ctx.last_valid_ninja)
(* Test without comparing formated ninja. *) (* Test without comparing formated ninja. *)
let test_add_test_builds_for_simple_interpret_scope_file () = let test_add_test_builds_for_simple_interpret_scope_file () =
let simple_interpret_scope_file = test_files_dir ^ "simple_interpret_scope_file.catala_en" in let simple_interpret_scope_file =
test_files_dir ^ "simple_interpret_scope_file.catala_en"
in
let ctx = D.ninja_building_context_init ninja_start in let ctx = D.ninja_building_context_init ninja_start in
let nj_building_ctx = To_test.add_test_builds ctx [ simple_interpret_scope_file ] false in let nj_building_ctx =
al_assert "a test case should be found" (Option.is_some nj_building_ctx.curr_ninja); To_test.add_test_builds ctx [ simple_interpret_scope_file ] false
in
al_assert "a test case should be found"
(Option.is_some nj_building_ctx.curr_ninja);
let expected_format = let expected_format =
let open Nj in let open Nj in
let test_file_output = "test_file_" ^ Nj.Build.unpath simple_interpret_scope_file in let test_file_output =
let test_A_file_output = "test_A_Interpret_" ^ Nj.Build.unpath simple_interpret_scope_file in "test_file_" ^ Nj.Build.unpath simple_interpret_scope_file
in
let test_A_file_output =
"test_A_Interpret_" ^ Nj.Build.unpath simple_interpret_scope_file
in
let test_A_file = let test_A_file =
Build.make_with_vars ~outputs:[ Expr.Lit test_A_file_output ] ~rule:"test_with_scope" Build.make_with_vars
~outputs:[ Expr.Lit test_A_file_output ]
~rule:"test_with_scope"
~vars: ~vars:
[ [
("scope", Lit "A"); ("scope", Lit "A");
("catala_cmd", Lit "Interpret"); ("catala_cmd", Lit "Interpret");
("tested_file", Lit simple_interpret_scope_file); ("tested_file", Lit simple_interpret_scope_file);
( "expected_output", ( "expected_output",
Lit (test_files_dir ^ "output/simple_interpret_scope_file.catala_en.A.Interpret") ); Lit
(test_files_dir
^ "output/simple_interpret_scope_file.catala_en.A.Interpret") );
] ]
in in
let test_file = let test_file =
Build.make_with_inputs ~outputs:[ Expr.Lit test_file_output ] ~rule:"phony" Build.make_with_inputs
~outputs:[ Expr.Lit test_file_output ]
~rule:"phony"
~inputs:[ Expr.Lit (" $\n " ^ test_A_file_output) ] ~inputs:[ Expr.Lit (" $\n " ^ test_A_file_output) ]
in in
BuildMap.empty BuildMap.empty
@ -85,21 +105,28 @@ let test_add_test_builds_for_simple_interpret_scope_file () =
let actual_format = let actual_format =
let ninja = Option.get nj_building_ctx.curr_ninja in let ninja = Option.get nj_building_ctx.curr_ninja in
Buffer.clear Format.stdbuf; Buffer.clear Format.stdbuf;
Nj.BuildMap.iter (fun _ b -> Nj.Build.format Format.str_formatter b) ninja.builds; Nj.BuildMap.iter
(fun _ b -> Nj.Build.format Format.str_formatter b)
ninja.builds;
Buffer.contents Format.stdbuf Buffer.contents Format.stdbuf
in in
Al.(check string) "both formated strings should equal" expected_format actual_format Al.(check string)
"both formated strings should equal" expected_format actual_format
let () = let () =
Al.run "Clerk_driver" Al.run "Clerk_driver"
Al. Al.
[ [
( "Test ninja_start", ( "Test ninja_start",
[ test_case "initial ninja rules should be present" `Quick test_ninja_start ] ); [
test_case "initial ninja rules should be present" `Quick
test_ninja_start;
] );
( "Test add_test_builds", ( "Test add_test_builds",
[ [
test_case "an untested file" `Quick test_add_test_builds_for_untested_file; test_case "an untested file" `Quick
test_add_test_builds_for_untested_file;
test_case "a simple Interpret scope" `Quick test_case "a simple Interpret scope" `Quick
test_add_test_builds_for_simple_interpret_scope_file; test_add_test_builds_for_simple_interpret_scope_file;
test_case "a simple folder" `Quick test_add_test_builds_for_folder; test_case "a simple folder" `Quick test_add_test_builds_for_folder;

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
let _ = Driver.main () let _ = Driver.main ()

View File

@ -4,8 +4,11 @@ open Js_of_ocaml
let _ = let _ =
Js.export_all Js.export_all
(object%js (object%js
method interpret (contents : Js.js_string Js.t) (scope : Js.js_string Js.t) method interpret
(language : Js.js_string Js.t) (trace : bool) = (contents : Js.js_string Js.t)
(scope : Js.js_string Js.t)
(language : Js.js_string Js.t)
(trace : bool) =
driver driver
(Contents (Js.to_string contents)) (Contents (Js.to_string contents))
false false false false "Interpret" false false false false "Interpret"

View File

@ -1,31 +1,36 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
[@@@ocaml.warning "-7-34"] [@@@ocaml.warning "-7-34"]
open Utils open Utils
module ScopeName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module ScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module StructName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info = module StructFieldName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) () Uid.Make (Uid.MarkedString) ()
module StructMap : Map.S with type key = StructName.t = Map.Make (StructName) module StructMap : Map.S with type key = StructName.t = Map.Make (StructName)
module EnumName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module EnumName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info = module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) () Uid.Make (Uid.MarkedString) ()
@ -33,9 +38,7 @@ module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info =
module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName) module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName)
type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration type typ_lit = TBool | TUnit | TInt | TRat | TMoney | TDate | TDuration
type struct_name = StructName.t type struct_name = StructName.t
type enum_name = EnumName.t type enum_name = EnumName.t
type typ = type typ =
@ -47,13 +50,9 @@ type typ =
| TAny | TAny
type date = Runtime.date type date = Runtime.date
type duration = Runtime.duration type duration = Runtime.duration
type integer = Runtime.integer type integer = Runtime.integer
type decimal = Runtime.decimal type decimal = Runtime.decimal
type money = Runtime.money type money = Runtime.money
type lit = type lit =
@ -67,7 +66,6 @@ type lit =
| LDuration of duration | LDuration of duration
type op_kind = KInt | KRat | KMoney | KDate | KDuration type op_kind = KInt | KRat | KMoney | KDate | KDuration
type ternop = Fold type ternop = Fold
type binop = type binop =
@ -105,12 +103,14 @@ type operator = Ternop of ternop | Binop of binop | Unop of unop
type expr = type expr =
| EVar of expr Bindlib.var Pos.marked | EVar of expr Bindlib.var Pos.marked
| ETuple of expr Pos.marked list * struct_name option | ETuple of expr Pos.marked list * struct_name option
| ETupleAccess of expr Pos.marked * int * struct_name option * typ Pos.marked list | ETupleAccess of
expr Pos.marked * int * struct_name option * typ Pos.marked list
| EInj of expr Pos.marked * int * enum_name * typ Pos.marked list | EInj of expr Pos.marked * int * enum_name * typ Pos.marked list
| EMatch of expr Pos.marked * expr Pos.marked list * enum_name | EMatch of expr Pos.marked * expr Pos.marked list * enum_name
| EArray of expr Pos.marked list | EArray of expr Pos.marked list
| ELit of lit | ELit of lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked * typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EAssert of expr Pos.marked | EAssert of expr Pos.marked
| EOp of operator | EOp of operator
@ -119,11 +119,8 @@ type expr =
| ErrorOnEmpty of expr Pos.marked | ErrorOnEmpty of expr Pos.marked
type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t
type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t
type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx } type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx }
type binder = (expr, expr Pos.marked) Bindlib.binder type binder = (expr, expr Pos.marked) Bindlib.binder
type scope_let_kind = type scope_let_kind =
@ -143,13 +140,17 @@ type scope_let = {
type scope_body = { type scope_body = {
scope_body_lets : scope_let list; scope_body_lets : scope_let list;
scope_body_result : expr Pos.marked Bindlib.box; (** {x1 = x1; x2 = x2; x3 = x3; ... } *) scope_body_result : expr Pos.marked Bindlib.box;
(** {x1 = x1; x2 = x2; x3 = x3; ... } *)
scope_body_arg : expr Bindlib.var; (** x: input_struct *) scope_body_arg : expr Bindlib.var; (** x: input_struct *)
scope_body_input_struct : StructName.t; scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t; scope_body_output_struct : StructName.t;
} }
type program = { decl_ctx : decl_ctx; scopes : (ScopeName.t * expr Bindlib.var * scope_body) list } type program = {
decl_ctx : decl_ctx;
scopes : (ScopeName.t * expr Bindlib.var * scope_body) list;
}
module Var = struct module Var = struct
type t = expr Bindlib.var type t = expr Bindlib.var
@ -164,21 +165,28 @@ end
module VarMap = Map.Make (Var) module VarMap = Map.Make (Var)
let union : unit VarMap.t -> unit VarMap.t -> unit VarMap.t = VarMap.union (fun _ _ _ -> Some ()) let union : unit VarMap.t -> unit VarMap.t -> unit VarMap.t =
VarMap.union (fun _ _ _ -> Some ())
let rec free_vars_set (e : expr Pos.marked) : unit VarMap.t = let rec free_vars_set (e : expr Pos.marked) : unit VarMap.t =
match Pos.unmark e with match Pos.unmark e with
| EVar (v, _) -> VarMap.singleton v () | EVar (v, _) -> VarMap.singleton v ()
| ETuple (es, _) | EArray es -> es |> List.map free_vars_set |> List.fold_left union VarMap.empty | ETuple (es, _) | EArray es ->
| ETupleAccess (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 | EInj (e1, _, _, _) -> es |> List.map free_vars_set |> List.fold_left union VarMap.empty
| ETupleAccess (e1, _, _, _)
| EAssert e1
| ErrorOnEmpty e1
| EInj (e1, _, _, _) ->
free_vars_set e1 free_vars_set e1
| EApp (e1, es) | EMatch (e1, es, _) -> | EApp (e1, es) | EMatch (e1, es, _) ->
e1 :: es |> List.map free_vars_set |> List.fold_left union VarMap.empty e1 :: es |> List.map free_vars_set |> List.fold_left union VarMap.empty
| EDefault (es, ejust, econs) -> | EDefault (es, ejust, econs) ->
ejust :: econs :: es |> List.map free_vars_set |> List.fold_left union VarMap.empty ejust :: econs :: es |> List.map free_vars_set
|> List.fold_left union VarMap.empty
| EOp _ | ELit _ -> VarMap.empty | EOp _ | ELit _ -> VarMap.empty
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
[ e1; e2; e3 ] |> List.map free_vars_set |> List.fold_left union VarMap.empty [ e1; e2; e3 ] |> List.map free_vars_set
|> List.fold_left union VarMap.empty
| EAbs ((binder, _), _) -> | EAbs ((binder, _), _) ->
let vs, body = Bindlib.unmbind binder in let vs, body = Bindlib.unmbind binder in
Array.fold_right VarMap.remove vs (free_vars_set body) Array.fold_right VarMap.remove vs (free_vars_set body)
@ -191,16 +199,28 @@ type vars = expr Bindlib.mvar
let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box =
Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x) Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x)
let make_abs (xs : vars) (e : expr Pos.marked Bindlib.box) (pos_binder : Pos.t) let make_abs
(taus : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = (xs : vars)
Bindlib.box_apply (fun b -> (EAbs ((b, pos_binder), taus), pos)) (Bindlib.bind_mvar xs e) (e : expr Pos.marked Bindlib.box)
(pos_binder : Pos.t)
(taus : typ Pos.marked list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply
(fun b -> (EAbs ((b, pos_binder), taus), pos))
(Bindlib.bind_mvar xs e)
let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) let make_app
: expr Pos.marked Bindlib.box = (e : expr Pos.marked Bindlib.box)
(u : expr Pos.marked Bindlib.box list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u)
let make_let_in (x : Var.t) (tau : typ Pos.marked) (e1 : expr Pos.marked Bindlib.box) let make_let_in
(e2 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box = (x : Var.t)
(tau : typ Pos.marked)
(e1 : expr Pos.marked Bindlib.box)
(e2 : expr Pos.marked Bindlib.box)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos
let empty_thunked_term : expr Pos.marked = let empty_thunked_term : expr Pos.marked =
@ -209,12 +229,15 @@ let empty_thunked_term : expr Pos.marked =
(make_abs (make_abs
(Array.of_list [ silent ]) (Array.of_list [ silent ])
(Bindlib.box (ELit LEmptyError, Pos.no_pos)) (Bindlib.box (ELit LEmptyError, Pos.no_pos))
Pos.no_pos [ (TLit TUnit, Pos.no_pos) ] Pos.no_pos) Pos.no_pos
[ (TLit TUnit, Pos.no_pos) ]
Pos.no_pos)
let is_value (e : expr Pos.marked) : bool = let is_value (e : expr Pos.marked) : bool =
match Pos.unmark e with ELit _ | EAbs _ | EOp _ -> true | _ -> false match Pos.unmark e with ELit _ | EAbs _ | EOp _ -> true | _ -> false
let build_whole_scope_expr (ctx : decl_ctx) (body : scope_body) (pos_scope : Pos.t) = let build_whole_scope_expr
(ctx : decl_ctx) (body : scope_body) (pos_scope : Pos.t) =
let body_expr = let body_expr =
List.fold_right List.fold_right
(fun scope_let acc -> (fun scope_let acc ->
@ -229,25 +252,37 @@ let build_whole_scope_expr (ctx : decl_ctx) (body : scope_body) (pos_scope : Pos
body_expr pos_scope body_expr pos_scope
[ [
( TTuple ( TTuple
( List.map snd (StructMap.find body.scope_body_input_struct ctx.ctx_structs), ( List.map snd
(StructMap.find body.scope_body_input_struct ctx.ctx_structs),
Some body.scope_body_input_struct ), Some body.scope_body_input_struct ),
pos_scope ); pos_scope );
] ]
pos_scope pos_scope
let build_scope_typ_from_sig (ctx : decl_ctx) (scope_input_struct_name : StructName.t) let build_scope_typ_from_sig
(scope_return_struct_name : StructName.t) (pos : Pos.t) : typ Pos.marked = (ctx : decl_ctx)
(scope_input_struct_name : StructName.t)
(scope_return_struct_name : StructName.t)
(pos : Pos.t) : typ Pos.marked =
let scope_sig = StructMap.find scope_input_struct_name ctx.ctx_structs in let scope_sig = StructMap.find scope_input_struct_name ctx.ctx_structs in
let scope_return_typ = StructMap.find scope_return_struct_name ctx.ctx_structs in let scope_return_typ =
let result_typ = (TTuple (List.map snd scope_return_typ, Some scope_return_struct_name), pos) in StructMap.find scope_return_struct_name ctx.ctx_structs
let input_typ = (TTuple (List.map snd scope_sig, Some scope_input_struct_name), pos) in in
let result_typ =
(TTuple (List.map snd scope_return_typ, Some scope_return_struct_name), pos)
in
let input_typ =
(TTuple (List.map snd scope_sig, Some scope_input_struct_name), pos)
in
(TArrow (input_typ, result_typ), pos) (TArrow (input_typ, result_typ), pos)
let build_whole_program_expr (p : program) (main_scope : ScopeName.t) = let build_whole_program_expr (p : program) (main_scope : ScopeName.t) =
let end_result = let end_result =
make_var make_var
(let _, x, _ = (let _, x, _ =
List.find (fun (s_name, _, _) -> ScopeName.compare main_scope s_name = 0) p.scopes List.find
(fun (s_name, _, _) -> ScopeName.compare main_scope s_name = 0)
p.scopes
in in
(x, Pos.no_pos)) (x, Pos.no_pos))
in in
@ -264,11 +299,18 @@ let build_whole_program_expr (p : program) (main_scope : ScopeName.t) =
let rec expr_size (e : expr Pos.marked) : int = let rec expr_size (e : expr Pos.marked) : int =
match Pos.unmark e with match Pos.unmark e with
| EVar _ | ELit _ | EOp _ -> 1 | EVar _ | ELit _ | EOp _ -> 1
| ETuple (args, _) | EArray args -> List.fold_left (fun acc arg -> acc + expr_size arg) 1 args | ETuple (args, _) | EArray args ->
| ETupleAccess (e1, _, _, _) | EInj (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 -> List.fold_left (fun acc arg -> acc + expr_size arg) 1 args
| ETupleAccess (e1, _, _, _)
| EInj (e1, _, _, _)
| EAssert e1
| ErrorOnEmpty e1 ->
expr_size e1 + 1 expr_size e1 + 1
| EMatch (arg, args, _) | EApp (arg, args) -> | EMatch (arg, args, _) | EApp (arg, args) ->
List.fold_left (fun acc arg -> acc + expr_size arg) (1 + expr_size arg) args List.fold_left
(fun acc arg -> acc + expr_size arg)
(1 + expr_size arg)
args
| EAbs ((binder, _), _) -> | EAbs ((binder, _), _) ->
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
1 + expr_size body 1 + expr_size body
@ -284,6 +326,8 @@ let variable_types (p : program) : typ Pos.marked VarMap.t =
(fun acc (_, _, scope) -> (fun acc (_, _, scope) ->
List.fold_left List.fold_left
(fun acc scope_let -> (fun acc scope_let ->
VarMap.add (Pos.unmark scope_let.scope_let_var) scope_let.scope_let_typ acc) VarMap.add
(Pos.unmark scope_let.scope_let_var)
scope_let.scope_let_typ acc)
acc scope.scope_body_lets) acc scope.scope_body_lets)
VarMap.empty p.scopes VarMap.empty p.scopes

View File

@ -1,33 +1,28 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Abstract syntax tree of the default calculus intermediate representation *) (** Abstract syntax tree of the default calculus intermediate representation *)
open Utils open Utils
module ScopeName : Uid.Id with type info = Uid.MarkedString.info module ScopeName : Uid.Id with type info = Uid.MarkedString.info
module StructName : Uid.Id with type info = Uid.MarkedString.info module StructName : Uid.Id with type info = Uid.MarkedString.info
module StructFieldName : Uid.Id with type info = Uid.MarkedString.info module StructFieldName : Uid.Id with type info = Uid.MarkedString.info
module StructMap : Map.S with type key = StructName.t module StructMap : Map.S with type key = StructName.t
module EnumName : Uid.Id with type info = Uid.MarkedString.info module EnumName : Uid.Id with type info = Uid.MarkedString.info
module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info
module EnumMap : Map.S with type key = EnumName.t module EnumMap : Map.S with type key = EnumName.t
(** Abstract syntax tree for the default calculus *) (** Abstract syntax tree for the default calculus *)
@ -45,7 +40,6 @@ type typ =
| TAny | TAny
type date = Runtime.date type date = Runtime.date
type duration = Runtime.duration type duration = Runtime.duration
type lit = type lit =
@ -87,8 +81,8 @@ type binop =
type log_entry = type log_entry =
| VarDef of typ | VarDef of typ
(** During code generation, we need to know the type of the variable being logged for (** During code generation, we need to know the type of the variable being
embedding *) logged for embedding *)
| BeginCall | BeginCall
| EndCall | EndCall
| PosRecordIfTrueBool | PosRecordIfTrueBool
@ -105,13 +99,14 @@ type unop =
type operator = Ternop of ternop | Binop of binop | Unop of unop type operator = Ternop of ternop | Binop of binop | Unop of unop
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib} library, based on (** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
higher-order abstract syntax*) library, based on higher-order abstract syntax*)
type expr = type expr =
| EVar of expr Bindlib.var Pos.marked | EVar of expr Bindlib.var Pos.marked
| ETuple of expr Pos.marked list * StructName.t option | ETuple of expr Pos.marked list * StructName.t option
(** The [MarkedString.info] is the former struct field name*) (** The [MarkedString.info] is the former struct field name*)
| ETupleAccess of expr Pos.marked * int * StructName.t option * typ Pos.marked list | ETupleAccess of
expr Pos.marked * int * StructName.t option * typ Pos.marked list
(** The [MarkedString.info] is the former struct field name *) (** The [MarkedString.info] is the former struct field name *)
| EInj of expr Pos.marked * int * EnumName.t * typ Pos.marked list | EInj of expr Pos.marked * int * EnumName.t * typ Pos.marked list
(** The [MarkedString.info] is the former enum case name *) (** The [MarkedString.info] is the former enum case name *)
@ -119,7 +114,9 @@ type expr =
(** The [MarkedString.info] is the former enum case name *) (** The [MarkedString.info] is the former enum case name *)
| EArray of expr Pos.marked list | EArray of expr Pos.marked list
| ELit of lit | ELit of lit
| EAbs of ((expr, expr Pos.marked) Bindlib.mbinder[@opaque]) Pos.marked * typ Pos.marked list | EAbs of
((expr, expr Pos.marked) Bindlib.mbinder[@opaque]) Pos.marked
* typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EAssert of expr Pos.marked | EAssert of expr Pos.marked
| EOp of operator | EOp of operator
@ -128,20 +125,19 @@ type expr =
| ErrorOnEmpty of expr Pos.marked | ErrorOnEmpty of expr Pos.marked
type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t
type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t
type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx } type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx }
type binder = (expr, expr Pos.marked) Bindlib.binder type binder = (expr, expr Pos.marked) Bindlib.binder
(** This kind annotation signals that the let-binding respects a structural invariant. These (** This kind annotation signals that the let-binding respects a structural
invariants concern the shape of the expression in the let-binding, and are documented below. *) invariant. These invariants concern the shape of the expression in the
let-binding, and are documented below. *)
type scope_let_kind = type scope_let_kind =
| DestructuringInputStruct (** [let x = input.field]*) | DestructuringInputStruct (** [let x = input.field]*)
| ScopeVarDefinition (** [let x = error_on_empty e]*) | ScopeVarDefinition (** [let x = error_on_empty e]*)
| SubScopeVarDefinition | SubScopeVarDefinition
(** [let s.x = fun _ -> e] or [let s.x = error_on_empty e] for input-only subscope variables. *) (** [let s.x = fun _ -> e] or [let s.x = error_on_empty e] for input-only
subscope variables. *)
| CallingSubScope (** [let result = s ({ x = s.x; y = s.x; ...}) ]*) | CallingSubScope (** [let result = s ({ x = s.x; y = s.x; ...}) ]*)
| DestructuringSubScopeResults (** [let s.x = result.x ]**) | DestructuringSubScopeResults (** [let s.x = result.x ]**)
| Assertion (** [let _ = assert e]*) | Assertion (** [let _ = assert e]*)
@ -152,9 +148,9 @@ type scope_let = {
scope_let_typ : typ Pos.marked; scope_let_typ : typ Pos.marked;
scope_let_expr : expr Pos.marked Bindlib.box; scope_let_expr : expr Pos.marked Bindlib.box;
} }
(** A scope let-binding has all the information necessary to make a proper let-binding expression, (** A scope let-binding has all the information necessary to make a proper
plus an annotation for the kind of the let-binding that comes from the compilation of a let-binding expression, plus an annotation for the kind of the let-binding
{!module: Scopelang.Ast} statement. *) that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
type scope_body = { type scope_body = {
scope_body_lets : scope_let list; scope_body_lets : scope_let list;
@ -163,11 +159,14 @@ type scope_body = {
scope_body_input_struct : StructName.t; scope_body_input_struct : StructName.t;
scope_body_output_struct : StructName.t; scope_body_output_struct : StructName.t;
} }
(** Instead of being a single expression, we give a little more ad-hoc structure to the scope body (** Instead of being a single expression, we give a little more ad-hoc structure
by decomposing it in an ordered list of let-bindings, and a result expression that uses the to the scope body by decomposing it in an ordered list of let-bindings, and
let-binded variables. *) a result expression that uses the let-binded variables. *)
type program = { decl_ctx : decl_ctx; scopes : (ScopeName.t * expr Bindlib.var * scope_body) list } type program = {
decl_ctx : decl_ctx;
scopes : (ScopeName.t * expr Bindlib.var * scope_body) list;
}
(** {1 Helpers} *) (** {1 Helpers} *)
@ -177,14 +176,12 @@ module Var : sig
type t = expr Bindlib.var type t = expr Bindlib.var
val make : string Pos.marked -> t val make : string Pos.marked -> t
val compare : t -> t -> int val compare : t -> t -> int
end end
module VarMap : Map.S with type key = Var.t module VarMap : Map.S with type key = Var.t
val free_vars_set : expr Pos.marked -> unit VarMap.t val free_vars_set : expr Pos.marked -> unit VarMap.t
val free_vars_list : expr Pos.marked -> Var.t list val free_vars_list : expr Pos.marked -> Var.t list
type vars = expr Bindlib.mvar type vars = expr Bindlib.mvar
@ -216,22 +213,26 @@ val make_let_in :
(**{2 Other}*) (**{2 Other}*)
val empty_thunked_term : expr Pos.marked val empty_thunked_term : expr Pos.marked
val is_value : expr Pos.marked -> bool val is_value : expr Pos.marked -> bool
(** {1 AST manipulation helpers}*) (** {1 AST manipulation helpers}*)
val build_whole_scope_expr : decl_ctx -> scope_body -> Pos.t -> expr Pos.marked Bindlib.box val build_whole_scope_expr :
(** Usage: [build_whole_scope_expr ctx body scope_position] where [scope_position] corresponds to decl_ctx -> scope_body -> Pos.t -> expr Pos.marked Bindlib.box
the line of the scope declaration for instance. *) (** Usage: [build_whole_scope_expr ctx body scope_position] where
[scope_position] corresponds to the line of the scope declaration for
instance. *)
val build_whole_program_expr : program -> ScopeName.t -> expr Pos.marked Bindlib.box val build_whole_program_expr :
(** Usage: [build_whole_program_expr program main_scope] builds an expression corresponding to the program -> ScopeName.t -> expr Pos.marked Bindlib.box
main program and returning the main scope as a function. *) (** Usage: [build_whole_program_expr program main_scope] builds an expression
corresponding to the main program and returning the main scope as a
function. *)
val expr_size : expr Pos.marked -> int val expr_size : expr Pos.marked -> int
(** Used by the optimizer to know when to stop *) (** Used by the optimizer to know when to stop *)
val variable_types : program -> typ Pos.marked VarMap.t val variable_types : program -> typ Pos.marked VarMap.t
(** Traverses all the scopes and retrieves all the types for the variables that may appear in scope (** Traverses all the scopes and retrieves all the types for the variables that
or subscope variable definitions, giving them as a big map. *) may appear in scope or subscope variable definitions, giving them as a big
map. *)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020-2022 Inria, contributor: Alain Delaët-Tixeuil and social benefits computation rules. Copyright (C) 2020-2022 Inria,
<alain.delaet--tixeuil@inria.fr> contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -47,7 +49,8 @@ let rec free_vars_set_scope_lets (scope_lets : scope_lets) : unit D.VarMap.t =
| Result e -> D.free_vars_set e | Result e -> D.free_vars_set e
| ScopeLet { scope_let_expr = e; scope_let_next = next; _ } -> | ScopeLet { scope_let_expr = e; scope_let_next = next; _ } ->
let v, body = Bindlib.unbind next in let v, body = Bindlib.unbind next in
union (D.free_vars_set e) (D.VarMap.remove v (free_vars_set_scope_lets body)) union (D.free_vars_set e)
(D.VarMap.remove v (free_vars_set_scope_lets body))
let free_vars_set_scope_body (scope_body : scope_body) : unit D.VarMap.t = let free_vars_set_scope_body (scope_body : scope_body) : unit D.VarMap.t =
let { scope_body_result = binder; _ } = scope_body in let { scope_body_result = binder; _ } = scope_body in
@ -60,7 +63,9 @@ let rec free_vars_set_scopes (scopes : scopes) : unit D.VarMap.t =
| ScopeDef { scope_body = body; scope_next = next; _ } -> | ScopeDef { scope_body = body; scope_next = next; _ } ->
let v, next = Bindlib.unbind next in let v, next = Bindlib.unbind next in
union (D.VarMap.remove v (free_vars_set_scopes next)) (free_vars_set_scope_body body) union
(D.VarMap.remove v (free_vars_set_scopes next))
(free_vars_set_scope_body body)
let free_vars_list_scope_lets (scope_lets : scope_lets) : D.Var.t list = let free_vars_list_scope_lets (scope_lets : scope_lets) : D.Var.t list =
free_vars_set_scope_lets scope_lets |> D.VarMap.bindings |> List.map fst free_vars_set_scope_lets scope_lets |> D.VarMap.bindings |> List.map fst
@ -76,13 +81,14 @@ let bind_scope_lets (acc : scope_lets Bindlib.box) (scope_let : D.scope_let) :
scope_lets Bindlib.box = scope_lets Bindlib.box =
let pos = snd scope_let.D.scope_let_var in let pos = snd scope_let.D.scope_let_var in
(* Cli.debug_print @@ Format.asprintf "binding let %a. Variable occurs = %b" Print.format_var (fst (* Cli.debug_print @@ Format.asprintf "binding let %a. Variable occurs = %b"
scope_let.D.scope_let_var) (Bindlib.occur (fst scope_let.D.scope_let_var) acc); *) Print.format_var (fst scope_let.D.scope_let_var) (Bindlib.occur (fst
scope_let.D.scope_let_var) acc); *)
let binder = Bindlib.bind_var (fst scope_let.D.scope_let_var) acc in let binder = Bindlib.bind_var (fst scope_let.D.scope_let_var) acc in
Bindlib.box_apply2 Bindlib.box_apply2
(fun expr binder -> (fun expr binder ->
(* Cli.debug_print @@ Format.asprintf "free variables in expression: %a" (Format.pp_print_list (* Cli.debug_print @@ Format.asprintf "free variables in expression: %a"
Print.format_var) (D.free_vars_list expr); *) (Format.pp_print_list Print.format_var) (D.free_vars_list expr); *)
ScopeLet ScopeLet
{ {
scope_let_kind = scope_let.D.scope_let_kind; scope_let_kind = scope_let.D.scope_let_kind;
@ -101,15 +107,16 @@ let bind_scope_body (body : D.scope_body) : scope_body Bindlib.box =
~f:(Fun.flip bind_scope_lets) ~f:(Fun.flip bind_scope_lets)
in in
(* Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var body.D.scope_body_arg; *) (* Cli.debug_print @@ Format.asprintf "binding arg %a" Print.format_var
body.D.scope_body_arg; *)
let scope_body_result = Bindlib.bind_var body.D.scope_body_arg body_result in let scope_body_result = Bindlib.bind_var body.D.scope_body_arg body_result in
(* Cli.debug_print @@ Format.asprintf "isfinal term is closed: %b" (Bindlib.is_closed (* Cli.debug_print @@ Format.asprintf "isfinal term is closed: %b"
scope_body_result); *) (Bindlib.is_closed scope_body_result); *)
Bindlib.box_apply Bindlib.box_apply
(fun scope_body_result -> (fun scope_body_result ->
(* Cli.debug_print @@ Format.asprintf "rank of the final term: %i" (Bindlib.binder_rank (* Cli.debug_print @@ Format.asprintf "rank of the final term: %i"
scope_body_result); *) (Bindlib.binder_rank scope_body_result); *)
{ {
scope_body_output_struct = body.D.scope_body_output_struct; scope_body_output_struct = body.D.scope_body_output_struct;
scope_body_input_struct = body.D.scope_body_input_struct; scope_body_input_struct = body.D.scope_body_input_struct;
@ -118,15 +125,22 @@ let bind_scope_body (body : D.scope_body) : scope_body Bindlib.box =
scope_body_result scope_body_result
let bind_scope let bind_scope
((scope_name, scope_var, scope_body) : D.ScopeName.t * D.expr Bindlib.var * D.scope_body) ((scope_name, scope_var, scope_body) :
D.ScopeName.t * D.expr Bindlib.var * D.scope_body)
(acc : scopes Bindlib.box) : scopes Bindlib.box = (acc : scopes Bindlib.box) : scopes Bindlib.box =
Bindlib.box_apply2 Bindlib.box_apply2
(fun scope_body scope_next -> ScopeDef { scope_name; scope_body; scope_next }) (fun scope_body scope_next ->
(bind_scope_body scope_body) (Bindlib.bind_var scope_var acc) ScopeDef { scope_name; scope_body; scope_next })
(bind_scope_body scope_body)
(Bindlib.bind_var scope_var acc)
let bind_scopes (scopes : (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list) : let bind_scopes
(scopes : (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list) :
scopes Bindlib.box = scopes Bindlib.box =
let result = ListLabels.fold_right scopes ~init:(Bindlib.box Nil) ~f:bind_scope in let result =
(* Cli.debug_print @@ Format.asprintf "free variable in the program : [%a]" (Format.pp_print_list ListLabels.fold_right scopes ~init:(Bindlib.box Nil) ~f:bind_scope
Print.format_var) (free_vars_list_scopes (Bindlib.unbox result)); *) in
(* Cli.debug_print @@ Format.asprintf "free variable in the program : [%a]"
(Format.pp_print_list Print.format_var) (free_vars_list_scopes
(Bindlib.unbox result)); *)
result result

View File

@ -1,25 +1,28 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020-2022 Inria, contributor: Alain Delaët-Tixeuil and social benefits computation rules. Copyright (C) 2020-2022 Inria,
<alain.delaet--tixeuil@inria.fr> contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module D = Ast module D = Ast
(** Alternative representation of the Dcalc Ast. It is currently used in the transformation without (** Alternative representation of the Dcalc Ast. It is currently used in the
exceptions. We make heavy use of bindlib, binding each scope-let-variable and each scope transformation without exceptions. We make heavy use of bindlib, binding
explicitly. *) each scope-let-variable and each scope explicitly. *)
(** In [Ast], [Ast.scope_lets] is defined as a list of kind, var, and boxed expression. This (** In [Ast], [Ast.scope_lets] is defined as a list of kind, var, and boxed
representation binds using bindlib the tail of the list with the variable defined in the let. *) expression. This representation binds using bindlib the tail of the list
with the variable defined in the let. *)
type scope_lets = type scope_lets =
| Result of D.expr Utils.Pos.marked | Result of D.expr Utils.Pos.marked
| ScopeLet of { | ScopeLet of {
@ -35,12 +38,12 @@ type scope_body = {
scope_body_output_struct : D.StructName.t; scope_body_output_struct : D.StructName.t;
scope_body_result : (D.expr, scope_lets) Bindlib.binder; scope_body_result : (D.expr, scope_lets) Bindlib.binder;
} }
(** As a consequence, the scope_body contains only a result and input/output signature, as the other (** As a consequence, the scope_body contains only a result and input/output
elements are stored inside the scope_let. The binder present is the argument of type signature, as the other elements are stored inside the scope_let. The binder
[scope_body_input_struct]. *) present is the argument of type [scope_body_input_struct]. *)
(** Finally, we do the same transformation for the whole program for the kinded lets. This permit us (** Finally, we do the same transformation for the whole program for the kinded
to use bindlib variables for scopes names. *) lets. This permit us to use bindlib variables for scopes names. *)
type scopes = type scopes =
| Nil | Nil
| ScopeDef of { | ScopeDef of {
@ -58,6 +61,8 @@ val free_vars_list_scope_body : scope_body -> D.Var.t list
val free_vars_list_scopes : scopes -> D.Var.t list val free_vars_list_scopes : scopes -> D.Var.t list
(** List of variables not binded inside scopes*) (** List of variables not binded inside scopes*)
val bind_scopes : (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list -> scopes Bindlib.box val bind_scopes :
(** Transform a list of scopes into our representation of scopes. It requires that scopes are (D.ScopeName.t * D.expr Bindlib.var * D.scope_body) list -> scopes Bindlib.box
topologically-well-ordered, and ensure there is no free variables in the returned [scopes] *) (** Transform a list of scopes into our representation of scopes. It requires
that scopes are topologically-well-ordered, and ensure there is no free
variables in the returned [scopes] *)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Reference interpreter for the default calculus *) (** Reference interpreter for the default calculus *)
@ -26,11 +28,14 @@ let log_indent = ref 0
(** {1 Evaluation} *) (** {1 Evaluation} *)
let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked) let rec evaluate_operator
(ctx : Ast.decl_ctx)
(op : A.operator Pos.marked)
(args : A.expr Pos.marked list) : A.expr Pos.marked = (args : A.expr Pos.marked list) : A.expr Pos.marked =
(* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use [op] to raise (* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use
multispanned errors. *) [op] to raise multispanned errors. *)
let apply_div_or_raise_err (div : unit -> A.expr) (op : A.operator Pos.marked) : A.expr = let apply_div_or_raise_err (div : unit -> A.expr) (op : A.operator Pos.marked)
: A.expr =
try div () try div ()
with Division_by_zero -> with Division_by_zero ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
@ -40,16 +45,22 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
] ]
"division by zero at runtime" "division by zero at runtime"
in in
let get_binop_args_pos (args : (A.expr * Pos.t) list) : (string option * Pos.t) list = let get_binop_args_pos (args : (A.expr * Pos.t) list) :
[ (None, Pos.get_position (List.nth args 0)); (None, Pos.get_position (List.nth args 1)) ] (string option * Pos.t) list =
[
(None, Pos.get_position (List.nth args 0));
(None, Pos.get_position (List.nth args 1));
]
in in
(* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched, use [args] to raise (* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched,
multispanned errors. *) use [args] to raise multispanned errors. *)
let apply_cmp_or_raise_err (cmp : unit -> A.expr) (args : (A.expr * Pos.t) list) : A.expr = let apply_cmp_or_raise_err
(cmp : unit -> A.expr) (args : (A.expr * Pos.t) list) : A.expr =
try cmp () try cmp ()
with Runtime.UncomparableDurations -> with Runtime.UncomparableDurations ->
Errors.raise_multispanned_error (get_binop_args_pos args) Errors.raise_multispanned_error (get_binop_args_pos args)
"Cannot compare together durations that cannot be converted to a precise number of days" "Cannot compare together durations that cannot be converted to a \
precise number of days"
in in
Pos.same_pos_as Pos.same_pos_as
(match (Pos.unmark op, List.map Pos.unmark args) with (match (Pos.unmark op, List.map Pos.unmark args) with
@ -57,19 +68,29 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
Pos.unmark Pos.unmark
(List.fold_left (List.fold_left
(fun acc e' -> (fun acc e' ->
evaluate_expr ctx (Pos.same_pos_as (A.EApp (List.nth args 0, [ acc; e' ])) e')) evaluate_expr ctx
(Pos.same_pos_as (A.EApp (List.nth args 0, [ acc; e' ])) e'))
(List.nth args 1) es) (List.nth args 1) es)
| A.Binop A.And, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 && b2)) | A.Binop A.And, [ ELit (LBool b1); ELit (LBool b2) ] ->
| A.Binop A.Or, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 || b2)) A.ELit (LBool (b1 && b2))
| A.Binop A.Xor, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 <> b2)) | A.Binop A.Or, [ ELit (LBool b1); ELit (LBool b2) ] ->
| A.Binop (A.Add KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt Runtime.(i1 +! i2)) A.ELit (LBool (b1 || b2))
| A.Binop (A.Sub KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt Runtime.(i1 -! i2)) | A.Binop A.Xor, [ ELit (LBool b1); ELit (LBool b2) ] ->
| A.Binop (A.Mult KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt Runtime.(i1 *! i2)) A.ELit (LBool (b1 <> b2))
| A.Binop (A.Add KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
A.ELit (LInt Runtime.(i1 +! i2))
| A.Binop (A.Sub KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
A.ELit (LInt Runtime.(i1 -! i2))
| A.Binop (A.Mult KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
A.ELit (LInt Runtime.(i1 *! i2))
| A.Binop (A.Div KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> | A.Binop (A.Div KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
apply_div_or_raise_err (fun _ -> A.ELit (LInt Runtime.(i1 /! i2))) op apply_div_or_raise_err (fun _ -> A.ELit (LInt Runtime.(i1 /! i2))) op
| A.Binop (A.Add KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LRat Runtime.(i1 +& i2)) | A.Binop (A.Add KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
| A.Binop (A.Sub KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LRat Runtime.(i1 -& i2)) A.ELit (LRat Runtime.(i1 +& i2))
| A.Binop (A.Mult KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LRat Runtime.(i1 *& i2)) | A.Binop (A.Sub KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LRat Runtime.(i1 -& i2))
| A.Binop (A.Mult KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LRat Runtime.(i1 *& i2))
| A.Binop (A.Div KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> | A.Binop (A.Div KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(i1 /& i2))) op apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(i1 /& i2))) op
| A.Binop (A.Add KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> | A.Binop (A.Add KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] ->
@ -94,16 +115,25 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
try A.ELit (LRat Runtime.(d1 /^ d2)) try A.ELit (LRat Runtime.(d1 /^ d2))
with Runtime.IndivisableDurations -> with Runtime.IndivisableDurations ->
Errors.raise_multispanned_error (get_binop_args_pos args) Errors.raise_multispanned_error (get_binop_args_pos args)
"Cannot divide durations that cannot be converted to a precise number of days") "Cannot divide durations that cannot be converted to a precise \
number of days")
op op
| A.Binop (A.Lt KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool Runtime.(i1 <! i2)) | A.Binop (A.Lt KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
| A.Binop (A.Lte KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool Runtime.(i1 <=! i2)) A.ELit (LBool Runtime.(i1 <! i2))
| A.Binop (A.Gt KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool Runtime.(i1 >! i2)) | A.Binop (A.Lte KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
| A.Binop (A.Gte KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool Runtime.(i1 >=! i2)) A.ELit (LBool Runtime.(i1 <=! i2))
| A.Binop (A.Lt KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LBool Runtime.(i1 <& i2)) | A.Binop (A.Gt KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
| A.Binop (A.Lte KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LBool Runtime.(i1 <=& i2)) A.ELit (LBool Runtime.(i1 >! i2))
| A.Binop (A.Gt KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LBool Runtime.(i1 >& i2)) | A.Binop (A.Gte KInt), [ ELit (LInt i1); ELit (LInt i2) ] ->
| A.Binop (A.Gte KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LBool Runtime.(i1 >=& i2)) A.ELit (LBool Runtime.(i1 >=! i2))
| A.Binop (A.Lt KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LBool Runtime.(i1 <& i2))
| A.Binop (A.Lte KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LBool Runtime.(i1 <=& i2))
| A.Binop (A.Gt KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LBool Runtime.(i1 >& i2))
| A.Binop (A.Gte KRat), [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LBool Runtime.(i1 >=& i2))
| A.Binop (A.Lt KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> | A.Binop (A.Lt KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] ->
A.ELit (LBool Runtime.(m1 <$ m2)) A.ELit (LBool Runtime.(m1 <$ m2))
| A.Binop (A.Lte KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> | A.Binop (A.Lte KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] ->
@ -115,11 +145,15 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
| A.Binop (A.Lt KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> | A.Binop (A.Lt KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <^ d2))) args apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <^ d2))) args
| A.Binop (A.Lte KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> | A.Binop (A.Lte KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <=^ d2))) args apply_cmp_or_raise_err
(fun _ -> A.ELit (LBool Runtime.(d1 <=^ d2)))
args
| A.Binop (A.Gt KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> | A.Binop (A.Gt KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >^ d2))) args apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >^ d2))) args
| A.Binop (A.Gte KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> | A.Binop (A.Gte KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] ->
apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >=^ d2))) args apply_cmp_or_raise_err
(fun _ -> A.ELit (LBool Runtime.(d1 >=^ d2)))
args
| A.Binop (A.Lt KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> | A.Binop (A.Lt KDate), [ ELit (LDate d1); ELit (LDate d2) ] ->
A.ELit (LBool Runtime.(d1 <@ d2)) A.ELit (LBool Runtime.(d1 <@ d2))
| A.Binop (A.Lte KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> | A.Binop (A.Lte KDate), [ ELit (LDate d1); ELit (LDate d2) ] ->
@ -131,11 +165,16 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
| A.Binop A.Eq, [ ELit LUnit; ELit LUnit ] -> A.ELit (LBool true) | A.Binop A.Eq, [ ELit LUnit; ELit LUnit ] -> A.ELit (LBool true)
| A.Binop A.Eq, [ ELit (LDuration d1); ELit (LDuration d2) ] -> | A.Binop A.Eq, [ ELit (LDuration d1); ELit (LDuration d2) ] ->
A.ELit (LBool Runtime.(d1 =^ d2)) A.ELit (LBool Runtime.(d1 =^ d2))
| A.Binop A.Eq, [ ELit (LDate d1); ELit (LDate d2) ] -> A.ELit (LBool Runtime.(d1 =@ d2)) | A.Binop A.Eq, [ ELit (LDate d1); ELit (LDate d2) ] ->
| A.Binop A.Eq, [ ELit (LMoney m1); ELit (LMoney m2) ] -> A.ELit (LBool Runtime.(m1 =$ m2)) A.ELit (LBool Runtime.(d1 =@ d2))
| A.Binop A.Eq, [ ELit (LRat i1); ELit (LRat i2) ] -> A.ELit (LBool Runtime.(i1 =& i2)) | A.Binop A.Eq, [ ELit (LMoney m1); ELit (LMoney m2) ] ->
| A.Binop A.Eq, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool Runtime.(i1 =! i2)) A.ELit (LBool Runtime.(m1 =$ m2))
| A.Binop A.Eq, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 = b2)) | A.Binop A.Eq, [ ELit (LRat i1); ELit (LRat i2) ] ->
A.ELit (LBool Runtime.(i1 =& i2))
| A.Binop A.Eq, [ ELit (LInt i1); ELit (LInt i2) ] ->
A.ELit (LBool Runtime.(i1 =! i2))
| A.Binop A.Eq, [ ELit (LBool b1); ELit (LBool b2) ] ->
A.ELit (LBool (b1 = b2))
| A.Binop A.Eq, [ EArray es1; EArray es2 ] -> | A.Binop A.Eq, [ EArray es1; EArray es2 ] ->
A.ELit A.ELit
(LBool (LBool
@ -155,7 +194,9 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
s1 = s2 s1 = s2
&& List.for_all2 && List.for_all2
(fun e1 e2 -> (fun e1 e2 ->
match Pos.unmark (evaluate_operator ctx op [ e1; e2 ]) with match
Pos.unmark (evaluate_operator ctx op [ e1; e2 ])
with
| A.ELit (LBool b) -> b | A.ELit (LBool b) -> b
| _ -> assert false | _ -> assert false
(* should not happen *)) (* should not happen *))
@ -172,54 +213,76 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
| _ -> assert false | _ -> assert false
(* should not happen *) (* should not happen *)
with Invalid_argument _ -> false)) with Invalid_argument _ -> false))
| A.Binop A.Eq, [ _; _ ] -> A.ELit (LBool false) (* comparing anything else return false *) | A.Binop A.Eq, [ _; _ ] ->
A.ELit (LBool false) (* comparing anything else return false *)
| A.Binop A.Neq, [ _; _ ] -> ( | A.Binop A.Neq, [ _; _ ] -> (
match Pos.unmark (evaluate_operator ctx (Pos.same_pos_as (A.Binop A.Eq) op) args) with match
Pos.unmark
(evaluate_operator ctx (Pos.same_pos_as (A.Binop A.Eq) op) args)
with
| A.ELit (A.LBool b) -> A.ELit (A.LBool (not b)) | A.ELit (A.LBool b) -> A.ELit (A.LBool (not b))
| _ -> assert false (*should not happen *)) | _ -> assert false (*should not happen *))
| A.Binop A.Concat, [ A.EArray es1; A.EArray es2 ] -> A.EArray (es1 @ es2) | A.Binop A.Concat, [ A.EArray es1; A.EArray es2 ] -> A.EArray (es1 @ es2)
| A.Binop A.Map, [ _; A.EArray es ] -> | A.Binop A.Map, [ _; A.EArray es ] ->
A.EArray A.EArray
(List.map (List.map
(fun e' -> evaluate_expr ctx (Pos.same_pos_as (A.EApp (List.nth args 0, [ e' ])) e')) (fun e' ->
evaluate_expr ctx
(Pos.same_pos_as (A.EApp (List.nth args 0, [ e' ])) e'))
es) es)
| A.Binop A.Filter, [ _; A.EArray es ] -> | A.Binop A.Filter, [ _; A.EArray es ] ->
A.EArray A.EArray
(List.filter (List.filter
(fun e' -> (fun e' ->
match evaluate_expr ctx (Pos.same_pos_as (A.EApp (List.nth args 0, [ e' ])) e') with match
evaluate_expr ctx
(Pos.same_pos_as (A.EApp (List.nth args 0, [ e' ])) e')
with
| A.ELit (A.LBool b), _ -> b | A.ELit (A.LBool b), _ -> b
| _ -> | _ ->
Errors.raise_spanned_error Errors.raise_spanned_error
(Pos.get_position (List.nth args 0)) (Pos.get_position (List.nth args 0))
"This predicate evaluated to something else than a boolean (should not happen \ "This predicate evaluated to something else than a \
if the term was well-typed)") boolean (should not happen if the term was well-typed)")
es) es)
| A.Binop _, ([ ELit LEmptyError; _ ] | [ _; ELit LEmptyError ]) -> A.ELit LEmptyError | A.Binop _, ([ ELit LEmptyError; _ ] | [ _; ELit LEmptyError ]) ->
| A.Unop (A.Minus KInt), [ ELit (LInt i) ] -> A.ELit (LInt Runtime.(integer_of_int 0 -! i)) A.ELit LEmptyError
| A.Unop (A.Minus KRat), [ ELit (LRat i) ] -> A.ELit (LRat Runtime.(decimal_of_string "0" -& i)) | A.Unop (A.Minus KInt), [ ELit (LInt i) ] ->
A.ELit (LInt Runtime.(integer_of_int 0 -! i))
| A.Unop (A.Minus KRat), [ ELit (LRat i) ] ->
A.ELit (LRat Runtime.(decimal_of_string "0" -& i))
| A.Unop (A.Minus KMoney), [ ELit (LMoney i) ] -> | A.Unop (A.Minus KMoney), [ ELit (LMoney i) ] ->
A.ELit (LMoney Runtime.(money_of_units_int 0 -$ i)) A.ELit (LMoney Runtime.(money_of_units_int 0 -$ i))
| A.Unop (A.Minus KDuration), [ ELit (LDuration i) ] -> A.ELit (LDuration Runtime.(~-^i)) | A.Unop (A.Minus KDuration), [ ELit (LDuration i) ] ->
A.ELit (LDuration Runtime.(~-^i))
| A.Unop A.Not, [ ELit (LBool b) ] -> A.ELit (LBool (not b)) | A.Unop A.Not, [ ELit (LBool b) ] -> A.ELit (LBool (not b))
| A.Unop A.Length, [ EArray es ] -> A.ELit (LInt (Runtime.integer_of_int (List.length es))) | A.Unop A.Length, [ EArray es ] ->
| A.Unop A.GetDay, [ ELit (LDate d) ] -> A.ELit (LInt Runtime.(day_of_month_of_date d)) A.ELit (LInt (Runtime.integer_of_int (List.length es)))
| A.Unop A.GetMonth, [ ELit (LDate d) ] -> A.ELit (LInt Runtime.(month_number_of_date d)) | A.Unop A.GetDay, [ ELit (LDate d) ] ->
| A.Unop A.GetYear, [ ELit (LDate d) ] -> A.ELit (LInt Runtime.(year_of_date d)) A.ELit (LInt Runtime.(day_of_month_of_date d))
| A.Unop A.IntToRat, [ ELit (LInt i) ] -> A.ELit (LRat Runtime.(decimal_of_integer i)) | A.Unop A.GetMonth, [ ELit (LDate d) ] ->
A.ELit (LInt Runtime.(month_number_of_date d))
| A.Unop A.GetYear, [ ELit (LDate d) ] ->
A.ELit (LInt Runtime.(year_of_date d))
| A.Unop A.IntToRat, [ ELit (LInt i) ] ->
A.ELit (LRat Runtime.(decimal_of_integer i))
| A.Unop (A.Log (entry, infos)), [ e' ] -> | A.Unop (A.Log (entry, infos)), [ e' ] ->
if !Cli.trace_flag then ( if !Cli.trace_flag then (
match entry with match entry with
| VarDef _ -> | VarDef _ ->
(* TODO: this usage of Format is broken, Formatting requires that all is formatted in (* TODO: this usage of Format is broken, Formatting requires that
one pass, without going through intermediate "%s" *) all is formatted in one pass, without going through
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.format_log_entry entry intermediate "%s" *)
Print.format_uid_list infos Cli.log_format "%*s%a %a: %s" (!log_indent * 2) ""
Print.format_log_entry entry Print.format_uid_list infos
(match e' with (match e' with
(* | Ast.EAbs _ -> Cli.with_style [ ANSITerminal.green ] "<function>" *) (* | Ast.EAbs _ -> Cli.with_style [ ANSITerminal.green ]
"<function>" *)
| _ -> | _ ->
let expr_str = let expr_str =
Format.asprintf "%a" (Print.format_expr ctx ~debug:false) (e', Pos.no_pos) Format.asprintf "%a"
(Print.format_expr ctx ~debug:false)
(e', Pos.no_pos)
in in
let expr_str = let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
@ -231,19 +294,20 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
let pos = Pos.get_position op in let pos = Pos.get_position op in
match (pos <> Pos.no_pos, e') with match (pos <> Pos.no_pos, e') with
| true, ELit (LBool true) -> | true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.format_log_entry entry Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) ""
Print.format_log_entry entry
(Cli.with_style [ ANSITerminal.green ] "Definition applied") (Cli.with_style [ ANSITerminal.green ] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ -> (Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos)
Format.asprintf "%*s" (!log_indent * 2) "")) (fun _ -> Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ()) | _ -> ())
| BeginCall -> | BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.format_log_entry entry Cli.log_format "%*s%a %a" (!log_indent * 2) ""
Print.format_uid_list infos; Print.format_log_entry entry Print.format_uid_list infos;
log_indent := !log_indent + 1 log_indent := !log_indent + 1
| EndCall -> | EndCall ->
log_indent := !log_indent - 1; log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.format_log_entry entry Cli.log_format "%*s%a %a" (!log_indent * 2) ""
Print.format_uid_list infos) Print.format_log_entry entry Print.format_uid_list infos)
else (); else ();
e' e'
| A.Unop _, [ ELit LEmptyError ] -> A.ELit LEmptyError | A.Unop _, [ ELit LEmptyError ] -> A.ELit LEmptyError
@ -258,36 +322,44 @@ let rec evaluate_operator (ctx : Ast.decl_ctx) (op : A.operator Pos.marked)
arg), arg),
Pos.get_position arg )) Pos.get_position arg ))
args) args)
"Operator applied to the wrong arguments\n(should not happen if the term was well-typed)") "Operator applied to the wrong arguments\n\
(should not happen if the term was well-typed)")
op op
and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.marked = and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) :
A.expr Pos.marked =
match Pos.unmark e with match Pos.unmark e with
| EVar _ -> | EVar _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"free variable found at evaluation (should not happen if term was well-typed" "free variable found at evaluation (should not happen if term was \
well-typed"
| EApp (e1, args) -> ( | EApp (e1, args) -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
let args = List.map (evaluate_expr ctx) args in let args = List.map (evaluate_expr ctx) args in
match Pos.unmark e1 with match Pos.unmark e1 with
| EAbs ((binder, _), _) -> | EAbs ((binder, _), _) ->
if Bindlib.mbinder_arity binder = List.length args then if Bindlib.mbinder_arity binder = List.length args then
evaluate_expr ctx (Bindlib.msubst binder (Array.of_list (List.map Pos.unmark args))) evaluate_expr ctx
(Bindlib.msubst binder (Array.of_list (List.map Pos.unmark args)))
else else
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"wrong function call, expected %d arguments, got %d" (Bindlib.mbinder_arity binder) "wrong function call, expected %d arguments, got %d"
(Bindlib.mbinder_arity binder)
(List.length args) (List.length args)
| EOp op -> | EOp op ->
Pos.same_pos_as (Pos.unmark (evaluate_operator ctx (Pos.same_pos_as op e1) args)) e Pos.same_pos_as
(Pos.unmark (evaluate_operator ctx (Pos.same_pos_as op e1) args))
e
| ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"function has not been reduced to a lambda at evaluation (should not happen if the \ "function has not been reduced to a lambda at evaluation (should \
term was well-typed") not happen if the term was well-typed")
| EAbs _ | ELit _ | EOp _ -> e (* these are values *) | EAbs _ | ELit _ | EOp _ -> e (* these are values *)
| ETuple (es, s) -> | ETuple (es, s) ->
let new_es = List.map (evaluate_expr ctx) es in let new_es = List.map (evaluate_expr ctx) es in
if List.exists is_empty_error new_es then Pos.same_pos_as (A.ELit LEmptyError) e if List.exists is_empty_error new_es then
Pos.same_pos_as (A.ELit LEmptyError) e
else Pos.same_pos_as (A.ETuple (new_es, s)) e else Pos.same_pos_as (A.ETuple (new_es, s)) e
| ETupleAccess (e1, n, s, _) -> ( | ETupleAccess (e1, n, s, _) -> (
let e1 = evaluate_expr ctx e1 in let e1 = evaluate_expr ctx e1 in
@ -299,20 +371,20 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark
| _ -> | _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ (None, Pos.get_position e); (None, Pos.get_position e1) ] [ (None, Pos.get_position e); (None, Pos.get_position e1) ]
"Error during tuple access: not the same structs (should not happen if the term \ "Error during tuple access: not the same structs (should not \
was well-typed)"); happen if the term was well-typed)");
match List.nth_opt es n with match List.nth_opt es n with
| Some e' -> e' | Some e' -> e'
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e1) Errors.raise_spanned_error (Pos.get_position e1)
"The tuple has %d components but the %i-th element was requested (should not \ "The tuple has %d components but the %i-th element was \
happen if the term was well-type)" requested (should not happen if the term was well-type)"
(List.length es) n) (List.length es) n)
| ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e1) Errors.raise_spanned_error (Pos.get_position e1)
"The expression %a should be a tuple with %d components but is not (should not happen \ "The expression %a should be a tuple with %d components but is not \
if the term was well-typed)" (should not happen if the term was well-typed)"
(Print.format_expr ctx ~debug:true) (Print.format_expr ctx ~debug:true)
e n) e n)
| EInj (e1, n, en, ts) -> | EInj (e1, n, en, ts) ->
@ -326,22 +398,23 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark
if e_name <> e_name' then if e_name <> e_name' then
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ (None, Pos.get_position e); (None, Pos.get_position e1) ] [ (None, Pos.get_position e); (None, Pos.get_position e1) ]
"Error during match: two different enums found (should not happend if the term was \ "Error during match: two different enums found (should not \
well-typed)"; happend if the term was well-typed)";
let es_n = let es_n =
match List.nth_opt es n with match List.nth_opt es n with
| Some es_n -> es_n | Some es_n -> es_n
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"sum type index error (should not happend if the term was well-typed)" "sum type index error (should not happend if the term was \
well-typed)"
in in
let new_e = Pos.same_pos_as (A.EApp (es_n, [ e1 ])) e in let new_e = Pos.same_pos_as (A.EApp (es_n, [ e1 ])) e in
evaluate_expr ctx new_e evaluate_expr ctx new_e
| A.ELit A.LEmptyError -> Pos.same_pos_as (A.ELit A.LEmptyError) e | A.ELit A.LEmptyError -> Pos.same_pos_as (A.ELit A.LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e1) Errors.raise_spanned_error (Pos.get_position e1)
"Expected a term having a sum type as an argument to a match (should not happend if \ "Expected a term having a sum type as an argument to a match \
the term was well-typed") (should not happend if the term was well-typed")
| EDefault (exceptions, just, cons) -> ( | EDefault (exceptions, just, cons) -> (
let exceptions = List.map (evaluate_expr ctx) exceptions in let exceptions = List.map (evaluate_expr ctx) exceptions in
let empty_count = List.length (List.filter is_empty_error exceptions) in let empty_count = List.length (List.filter is_empty_error exceptions) in
@ -354,17 +427,18 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark
| ELit (LBool false) -> Pos.same_pos_as (A.ELit LEmptyError) e | ELit (LBool false) -> Pos.same_pos_as (A.ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Default justification has not been reduced to a boolean at evaluation (should not \ "Default justification has not been reduced to a boolean at \
happen if the term was well-typed") evaluation (should not happen if the term was well-typed")
| 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions | 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions
| _ -> | _ ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
(List.map (List.map
(fun except -> (fun except ->
(Some "This consequence has a valid justification:", Pos.get_position except)) ( Some "This consequence has a valid justification:",
Pos.get_position except ))
(List.filter (fun sub -> not (is_empty_error sub)) exceptions)) (List.filter (fun sub -> not (is_empty_error sub)) exceptions))
"There is a conflict between multiple validd consequences for assigning the same \ "There is a conflict between multiple validd consequences for \
variable.") assigning the same variable.")
| EIfThenElse (cond, et, ef) -> ( | EIfThenElse (cond, et, ef) -> (
match Pos.unmark (evaluate_expr ctx cond) with match Pos.unmark (evaluate_expr ctx cond) with
| ELit (LBool true) -> evaluate_expr ctx et | ELit (LBool true) -> evaluate_expr ctx et
@ -372,36 +446,42 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark
| ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position cond) Errors.raise_spanned_error (Pos.get_position cond)
"Expected a boolean literal for the result of this condition (should not happen if the \ "Expected a boolean literal for the result of this condition \
term was well-typed)") (should not happen if the term was well-typed)")
| EArray es -> | EArray es ->
let new_es = List.map (evaluate_expr ctx) es in let new_es = List.map (evaluate_expr ctx) es in
if List.exists is_empty_error new_es then Pos.same_pos_as (A.ELit LEmptyError) e if List.exists is_empty_error new_es then
Pos.same_pos_as (A.ELit LEmptyError) e
else Pos.same_pos_as (A.EArray new_es) e else Pos.same_pos_as (A.EArray new_es) e
| ErrorOnEmpty e' -> | ErrorOnEmpty e' ->
let e' = evaluate_expr ctx e' in let e' = evaluate_expr ctx e' in
if Pos.unmark e' = A.ELit LEmptyError then if Pos.unmark e' = A.ELit LEmptyError then
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"This variable evaluated to an empty term (no rule that defined it applied in this \ "This variable evaluated to an empty term (no rule that defined it \
situation)" applied in this situation)"
else e' else e'
| EAssert e' -> ( | EAssert e' -> (
match Pos.unmark (evaluate_expr ctx e') with match Pos.unmark (evaluate_expr ctx e') with
| ELit (LBool true) -> Pos.same_pos_as (Ast.ELit LUnit) e' | ELit (LBool true) -> Pos.same_pos_as (Ast.ELit LUnit) e'
| ELit (LBool false) -> ( | ELit (LBool false) -> (
match Pos.unmark e' with match Pos.unmark e' with
| EApp ((Ast.EOp (Binop op), pos_op), [ ((ELit _, _) as e1); ((ELit _, _) as e2) ]) -> | EApp
Errors.raise_spanned_error (Pos.get_position e') "Assertion failed: %a %a %a" ( (Ast.EOp (Binop op), pos_op),
[ ((ELit _, _) as e1); ((ELit _, _) as e2) ] ) ->
Errors.raise_spanned_error (Pos.get_position e')
"Assertion failed: %a %a %a"
(Print.format_expr ctx ~debug:false) (Print.format_expr ctx ~debug:false)
e1 Print.format_binop (op, pos_op) e1 Print.format_binop (op, pos_op)
(Print.format_expr ctx ~debug:false) (Print.format_expr ctx ~debug:false)
e2 e2
| _ -> Errors.raise_spanned_error (Pos.get_position e') "Assertion failed") | _ ->
Errors.raise_spanned_error (Pos.get_position e')
"Assertion failed")
| ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e') Errors.raise_spanned_error (Pos.get_position e')
"Expected a boolean literal for the result of this assertion (should not happen if the \ "Expected a boolean literal for the result of this assertion \
term was well-typed)") (should not happen if the term was well-typed)")
(** {1 API} *) (** {1 API} *)
@ -411,7 +491,9 @@ let interpret_program (ctx : Ast.decl_ctx) (e : Ast.expr Pos.marked) :
| Ast.EAbs (_, [ (Ast.TTuple (taus, Some s_in), _) ]) -> ( | Ast.EAbs (_, [ (Ast.TTuple (taus, Some s_in), _) ]) -> (
let application_term = List.map (fun _ -> Ast.empty_thunked_term) taus in let application_term = List.map (fun _ -> Ast.empty_thunked_term) taus in
let to_interpret = let to_interpret =
(Ast.EApp (e, [ (Ast.ETuple (application_term, Some s_in), Pos.no_pos) ]), Pos.no_pos) ( Ast.EApp
(e, [ (Ast.ETuple (application_term, Some s_in), Pos.no_pos) ]),
Pos.no_pos )
in in
match Pos.unmark (evaluate_expr ctx to_interpret) with match Pos.unmark (evaluate_expr ctx to_interpret) with
| Ast.ETuple (args, Some s_out) -> | Ast.ETuple (args, Some s_out) ->
@ -423,8 +505,9 @@ let interpret_program (ctx : Ast.decl_ctx) (e : Ast.expr Pos.marked) :
List.map2 (fun arg var -> (var, arg)) args s_out_fields List.map2 (fun arg var -> (var, arg)) args s_out_fields
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"The interpretation of a program should always yield a struct corresponding to the \ "The interpretation of a program should always yield a struct \
scope variables") corresponding to the scope variables")
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"The interpreter can only interpret terms starting with functions having thunked arguments" "The interpreter can only interpret terms starting with functions \
having thunked arguments"

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Reference interpreter for the default calculus *) (** Reference interpreter for the default calculus *)
@ -20,7 +22,10 @@ val evaluate_expr : Ast.decl_ctx -> Ast.expr Pos.marked -> Ast.expr Pos.marked
(** Evaluates an expression according to the semantics of the default calculus. *) (** Evaluates an expression according to the semantics of the default calculus. *)
val interpret_program : val interpret_program :
Ast.decl_ctx -> Ast.expr Pos.marked -> (Uid.MarkedString.info * Ast.expr Pos.marked) list Ast.decl_ctx ->
(** Interprets a program. This function expects an expression typed as a function whose argument are Ast.expr Pos.marked ->
all thunked. The function is executed by providing for each argument a thunked empty default. (Uid.MarkedString.info * Ast.expr Pos.marked) list
Returns a list of all the computed values for the scope variables of the executed scope. *) (** Interprets a program. This function expects an expression typed as a
function whose argument are all thunked. The function is executed by
providing for each argument a thunked empty default. Returns a list of all
the computed values for the scope variables of the executed scope. *)

View File

@ -1,28 +1,35 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributors: Alain Delaët and social benefits computation rules. Copyright (C) 2022 Inria,
<alain.delaet--tixeuil@inria.fr>, Denis Merigoux <denis.merigoux@inria.fr> contributors: Alain Delaët <alain.delaet--tixeuil@inria.fr>, Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
open Ast open Ast
type partial_evaluation_ctx = { var_values : expr Pos.marked Ast.VarMap.t; decl_ctx : decl_ctx } type partial_evaluation_ctx = {
var_values : expr Pos.marked Ast.VarMap.t;
decl_ctx : decl_ctx;
}
let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked) : let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
expr Pos.marked Bindlib.box = : expr Pos.marked Bindlib.box =
let pos = Pos.get_position e in let pos = Pos.get_position e in
let rec_helper = partial_evaluation ctx in let rec_helper = partial_evaluation ctx in
match Pos.unmark e with match Pos.unmark e with
| EApp | EApp
( ((EOp (Unop Not), _ | EApp ((EOp (Unop (Log _)), _), [ (EOp (Unop Not), _) ]), _) as op), ( (( EOp (Unop Not), _
| EApp ((EOp (Unop (Log _)), _), [ (EOp (Unop Not), _) ]), _ ) as op),
[ e1 ] ) -> [ e1 ] ) ->
(* reduction of logical not *) (* reduction of logical not *)
(Bindlib.box_apply (fun e1 -> (Bindlib.box_apply (fun e1 ->
@ -32,23 +39,29 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
| _ -> (EApp (op, [ e1 ]), pos))) | _ -> (EApp (op, [ e1 ]), pos)))
(rec_helper e1) (rec_helper e1)
| EApp | EApp
( ((EOp (Binop Or), _ | EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop Or), _) ]), _) as op), ( (( EOp (Binop Or), _
| EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop Or), _) ]), _ ) as op),
[ e1; e2 ] ) -> [ e1; e2 ] ) ->
(* reduction of logical or *) (* reduction of logical or *)
(Bindlib.box_apply2 (fun e1 e2 -> (Bindlib.box_apply2 (fun e1 e2 ->
match (e1, e2) with match (e1, e2) with
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) -> new_e | (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) ->
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) -> (ELit (LBool true), pos) new_e
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) ->
(ELit (LBool true), pos)
| _ -> (EApp (op, [ e1; e2 ]), pos))) | _ -> (EApp (op, [ e1; e2 ]), pos)))
(rec_helper e1) (rec_helper e2) (rec_helper e1) (rec_helper e2)
| EApp | EApp
( ((EOp (Binop And), _ | EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop And), _) ]), _) as op), ( (( EOp (Binop And), _
| EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop And), _) ]), _ ) as op),
[ e1; e2 ] ) -> [ e1; e2 ] ) ->
(* reduction of logical and *) (* reduction of logical and *)
(Bindlib.box_apply2 (fun e1 e2 -> (Bindlib.box_apply2 (fun e1 e2 ->
match (e1, e2) with match (e1, e2) with
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) -> new_e | (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) ->
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) -> (ELit (LBool false), pos) new_e
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
(ELit (LBool false), pos)
| _ -> (EApp (op, [ e1; e2 ]), pos))) | _ -> (EApp (op, [ e1; e2 ]), pos)))
(rec_helper e1) (rec_helper e2) (rec_helper e1) (rec_helper e2)
| EVar (x, _) -> Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x) | EVar (x, _) -> Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x)
@ -57,14 +70,19 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
(fun args -> (ETuple (args, s_name), pos)) (fun args -> (ETuple (args, s_name), pos))
(List.map rec_helper args |> Bindlib.box_list) (List.map rec_helper args |> Bindlib.box_list)
| ETupleAccess (arg, i, s_name, typs) -> | ETupleAccess (arg, i, s_name, typs) ->
Bindlib.box_apply (fun arg -> (ETupleAccess (arg, i, s_name, typs), pos)) (rec_helper arg) Bindlib.box_apply
(fun arg -> (ETupleAccess (arg, i, s_name, typs), pos))
(rec_helper arg)
| EInj (arg, i, e_name, typs) -> | EInj (arg, i, e_name, typs) ->
Bindlib.box_apply (fun arg -> (EInj (arg, i, e_name, typs), pos)) (rec_helper arg) Bindlib.box_apply
(fun arg -> (EInj (arg, i, e_name, typs), pos))
(rec_helper arg)
| EMatch (arg, arms, e_name) -> | EMatch (arg, arms, e_name) ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun arg arms -> (fun arg arms ->
match (arg, arms) with match (arg, arms) with
| (EInj (e1, i, e_name', _ts), _), _ when Ast.EnumName.compare e_name e_name' = 0 -> | (EInj (e1, i, e_name', _ts), _), _
when Ast.EnumName.compare e_name e_name' = 0 ->
(* iota reduction *) (* iota reduction *)
(EApp (List.nth arms i, [ e1 ]), pos) (EApp (List.nth arms i, [ e1 ]), pos)
| _ -> (EMatch (arg, arms, e_name), pos)) | _ -> (EMatch (arg, arms, e_name), pos))
@ -79,7 +97,9 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let new_body = rec_helper body in let new_body = rec_helper body in
let new_binder = Bindlib.bind_mvar vars new_body in let new_binder = Bindlib.bind_mvar vars new_body in
Bindlib.box_apply (fun binder -> (EAbs ((binder, binder_pos), typs), pos)) new_binder Bindlib.box_apply
(fun binder -> (EAbs ((binder, binder_pos), typs), pos))
new_binder
| EApp (f, args) -> | EApp (f, args) ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun f args -> (fun f args ->
@ -90,7 +110,8 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
| _ -> (EApp (f, args), pos)) | _ -> (EApp (f, args), pos))
(rec_helper f) (rec_helper f)
(List.map rec_helper args |> Bindlib.box_list) (List.map rec_helper args |> Bindlib.box_list)
| EAssert e1 -> Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) (rec_helper e1) | EAssert e1 ->
Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) (rec_helper e1)
| EOp op -> Bindlib.box (EOp op, pos) | EOp op -> Bindlib.box (EOp op, pos)
| EDefault (exceptions, just, cons) -> | EDefault (exceptions, just, cons) ->
Bindlib.box_apply3 Bindlib.box_apply3
@ -98,7 +119,10 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
(* TODO: mechanically prove each of these optimizations correct :) *) (* TODO: mechanically prove each of these optimizations correct :) *)
match match
( List.filter ( List.filter
(fun except -> match Pos.unmark except with ELit LEmptyError -> false | _ -> true) (fun except ->
match Pos.unmark except with
| ELit LEmptyError -> false
| _ -> true)
exceptions exceptions
(* we can discard the exceptions that are always empty error *), (* we can discard the exceptions that are always empty error *),
just, just,
@ -109,25 +133,33 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
(fun nb except -> if is_value except then nb + 1 else nb) (fun nb except -> if is_value except then nb + 1 else nb)
0 exceptions 0 exceptions
> 1 -> > 1 ->
(* at this point we know a conflict error will be triggered so we just feed the (* at this point we know a conflict error will be triggered so we
expression to the interpreter that will print the beautiful right error message *) just feed the expression to the interpreter that will print the
Interpreter.evaluate_expr ctx.decl_ctx (EDefault (exceptions, just, cons), pos) beautiful right error message *)
Interpreter.evaluate_expr ctx.decl_ctx
(EDefault (exceptions, just, cons), pos)
| [ except ], _, _ when is_value except -> | [ except ], _, _ when is_value except ->
(* if there is only one exception and it is a non-empty value it is always chosen *) (* if there is only one exception and it is a non-empty value it
is always chosen *)
except except
| ( [], | ( [],
((ELit (LBool true) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ])), _), ( ( ELit (LBool true)
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) ),
_ ),
cons ) -> cons ) ->
cons cons
| ( [], | ( [],
((ELit (LBool false) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ])), _), ( ( ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) ),
_ ),
_ ) -> _ ) ->
(ELit LEmptyError, pos) (ELit LEmptyError, pos)
| [], just, cons when not !Cli.avoid_exceptions_flag -> | [], just, cons when not !Cli.avoid_exceptions_flag ->
(* without exceptions, a default is just an [if then else] raising an error in the (* without exceptions, a default is just an [if then else] raising
else case. This exception is only valid in the context of an error in the else case. This exception is only valid in the
compilation_with_exceptions, so we desactivate with a global flag to know if we context of compilation_with_exceptions, so we desactivate with
will be compiling using exceptions or the option monad. *) a global flag to know if we will be compiling using exceptions
or the option monad. *)
(EIfThenElse (just, cons, (ELit LEmptyError, pos)), pos) (EIfThenElse (just, cons, (ELit LEmptyError, pos)), pos)
| exceptions, just, cons -> (EDefault (exceptions, just, cons), pos)) | exceptions, just, cons -> (EDefault (exceptions, just, cons), pos))
(List.map rec_helper exceptions |> Bindlib.box_list) (List.map rec_helper exceptions |> Bindlib.box_list)
@ -143,19 +175,24 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked)
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]), _, _ -> | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]), _, _ ->
e3 e3
| ( _, | ( _,
(ELit (LBool true) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ])), ( ELit (LBool true)
(ELit (LBool false) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ])) ) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) ),
( ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) ) )
-> ->
e1 e1
| _ -> (EIfThenElse (e1, e2, e3), pos)) | _ -> (EIfThenElse (e1, e2, e3), pos))
(rec_helper e1) (rec_helper e2) (rec_helper e3) (rec_helper e1) (rec_helper e2) (rec_helper e3)
| ErrorOnEmpty e1 -> Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) (rec_helper e1) | ErrorOnEmpty e1 ->
Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) (rec_helper e1)
let optimize_expr (decl_ctx : decl_ctx) (e : expr Pos.marked) = let optimize_expr (decl_ctx : decl_ctx) (e : expr Pos.marked) =
partial_evaluation { var_values = VarMap.empty; decl_ctx } e partial_evaluation { var_values = VarMap.empty; decl_ctx } e
let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx : 'a) (p : program) let program_map
: program = (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box)
(ctx : 'a)
(p : program) : program =
{ {
p with p with
scopes = scopes =
@ -170,7 +207,8 @@ let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx
{ {
scope_let with scope_let with
scope_let_expr = scope_let_expr =
Bindlib.unbox (Bindlib.box_apply (t ctx) scope_let.scope_let_expr); Bindlib.unbox
(Bindlib.box_apply (t ctx) scope_let.scope_let_expr);
}) })
s_body.scope_body_lets; s_body.scope_body_lets;
} }
@ -180,7 +218,9 @@ let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx
} }
let optimize_program (p : program) : program = let optimize_program (p : program) : program =
program_map partial_evaluation { var_values = VarMap.empty; decl_ctx = p.decl_ctx } p program_map partial_evaluation
{ var_values = VarMap.empty; decl_ctx = p.decl_ctx }
p
let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box = let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
let pos = Pos.get_position e in let pos = Pos.get_position e in
@ -192,9 +232,13 @@ let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
(fun args -> (ETuple (args, s_name), pos)) (fun args -> (ETuple (args, s_name), pos))
(List.map rec_helper args |> Bindlib.box_list) (List.map rec_helper args |> Bindlib.box_list)
| ETupleAccess (arg, i, s_name, typs) -> | ETupleAccess (arg, i, s_name, typs) ->
Bindlib.box_apply (fun arg -> (ETupleAccess (arg, i, s_name, typs), pos)) (rec_helper arg) Bindlib.box_apply
(fun arg -> (ETupleAccess (arg, i, s_name, typs), pos))
(rec_helper arg)
| EInj (arg, i, e_name, typs) -> | EInj (arg, i, e_name, typs) ->
Bindlib.box_apply (fun arg -> (EInj (arg, i, e_name, typs), pos)) (rec_helper arg) Bindlib.box_apply
(fun arg -> (EInj (arg, i, e_name, typs), pos))
(rec_helper arg)
| EMatch (arg, arms, e_name) -> | EMatch (arg, arms, e_name) ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun arg arms -> (EMatch (arg, arms, e_name), pos)) (fun arg arms -> (EMatch (arg, arms, e_name), pos))
@ -209,7 +253,9 @@ let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let new_body = rec_helper body in let new_body = rec_helper body in
let new_binder = Bindlib.bind_mvar vars new_body in let new_binder = Bindlib.bind_mvar vars new_body in
Bindlib.box_apply (fun binder -> (EAbs ((binder, binder_pos), typs), pos)) new_binder Bindlib.box_apply
(fun binder -> (EAbs ((binder, binder_pos), typs), pos))
new_binder
| EApp (f, args) -> | EApp (f, args) ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun f args -> (fun f args ->
@ -218,7 +264,8 @@ let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
| _ -> (EApp (f, args), pos)) | _ -> (EApp (f, args), pos))
(rec_helper f) (rec_helper f)
(List.map rec_helper args |> Bindlib.box_list) (List.map rec_helper args |> Bindlib.box_list)
| EAssert e1 -> Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) (rec_helper e1) | EAssert e1 ->
Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) (rec_helper e1)
| EOp op -> Bindlib.box (EOp op, pos) | EOp op -> Bindlib.box (EOp op, pos)
| EDefault (exceptions, just, cons) -> | EDefault (exceptions, just, cons) ->
Bindlib.box_apply3 Bindlib.box_apply3
@ -229,4 +276,5 @@ let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
Bindlib.box_apply3 Bindlib.box_apply3
(fun e1 e2 e3 -> (EIfThenElse (e1, e2, e3), pos)) (fun e1 e2 e3 -> (EIfThenElse (e1, e2, e3), pos))
(rec_helper e1) (rec_helper e2) (rec_helper e3) (rec_helper e1) (rec_helper e2) (rec_helper e3)
| ErrorOnEmpty e1 -> Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) (rec_helper e1) | ErrorOnEmpty e1 ->
Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) (rec_helper e1)

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributors: Alain Delaët and social benefits computation rules. Copyright (C) 2022 Inria,
<alain.delaet--tixeuil@inria.fr>, Denis Merigoux <denis.merigoux@inria.fr> contributors: Alain Delaët <alain.delaet--tixeuil@inria.fr>, Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Optimization passes for default calculus programs and expressions *) (** Optimization passes for default calculus programs and expressions *)
@ -18,7 +21,5 @@ open Utils
open Ast open Ast
val optimize_expr : decl_ctx -> expr Pos.marked -> expr Pos.marked Bindlib.box val optimize_expr : decl_ctx -> expr Pos.marked -> expr Pos.marked Bindlib.box
val optimize_program : program -> program val optimize_program : program -> program
val remove_all_logs : expr Pos.marked -> expr Pos.marked Bindlib.box val remove_all_logs : expr Pos.marked -> expr Pos.marked Bindlib.box

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -30,14 +32,17 @@ let begins_with_uppercase (s : string) : bool =
let first_letter = CamomileLibraryDefault.Camomile.UTF8.get s 0 in let first_letter = CamomileLibraryDefault.Camomile.UTF8.get s 0 in
is_uppercase first_letter is_uppercase first_letter
let format_uid_list (fmt : Format.formatter) (infos : Uid.MarkedString.info list) : unit = let format_uid_list
(fmt : Format.formatter) (infos : Uid.MarkedString.info list) : unit =
Format.fprintf fmt "%a" Format.fprintf fmt "%a"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ".") ~pp_sep:(fun fmt () -> Format.fprintf fmt ".")
(fun fmt info -> (fun fmt info ->
Format.fprintf fmt "%a" Format.fprintf fmt "%a"
(Utils.Cli.format_with_style (Utils.Cli.format_with_style
(if begins_with_uppercase (Pos.unmark info) then [ ANSITerminal.red ] else [])) (if begins_with_uppercase (Pos.unmark info) then
[ ANSITerminal.red ]
else []))
(Format.asprintf "%a" Utils.Uid.MarkedString.format_info info))) (Format.asprintf "%a" Utils.Uid.MarkedString.format_info info)))
infos infos
@ -45,7 +50,9 @@ let format_keyword (fmt : Format.formatter) (s : string) : unit =
Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.red ]) s Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.red ]) s
let format_base_type (fmt : Format.formatter) (s : string) : unit = let format_base_type (fmt : Format.formatter) (s : string) : unit =
Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) s Format.fprintf fmt "%a"
(Utils.Cli.format_with_style [ ANSITerminal.yellow ])
s
let format_punctuation (fmt : Format.formatter) (s : string) : unit = let format_punctuation (fmt : Format.formatter) (s : string) : unit =
Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.cyan ]) s Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.cyan ]) s
@ -54,7 +61,9 @@ let format_operator (fmt : Format.formatter) (s : string) : unit =
Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.green ]) s Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.green ]) s
let format_lit_style (fmt : Format.formatter) (s : string) : unit = let format_lit_style (fmt : Format.formatter) (s : string) : unit =
Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) s Format.fprintf fmt "%a"
(Utils.Cli.format_with_style [ ANSITerminal.yellow ])
s
let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit =
format_base_type fmt format_base_type fmt
@ -67,12 +76,15 @@ let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit =
| TDuration -> "duration" | TDuration -> "duration"
| TDate -> "date") | TDate -> "date")
let format_enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = let format_enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) :
unit =
Format.fprintf fmt "%a" Format.fprintf fmt "%a"
(Utils.Cli.format_with_style [ ANSITerminal.magenta ]) (Utils.Cli.format_with_style [ ANSITerminal.magenta ])
(Format.asprintf "%a" EnumConstructor.format_t c) (Format.asprintf "%a" EnumConstructor.format_t c)
let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ Pos.marked) : unit = let rec format_typ
(ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ Pos.marked) : unit
=
let format_typ = format_typ ctx in let format_typ = format_typ ctx in
let format_typ_with_parens (fmt : Format.formatter) (t : typ Pos.marked) = let format_typ_with_parens (fmt : Format.formatter) (t : typ Pos.marked) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
@ -83,30 +95,39 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ Pos.
| TTuple (ts, None) -> | TTuple (ts, None) ->
Format.fprintf fmt "@[<hov 2>(%a)@]" Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " format_operator "*") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "@ %a@ " format_operator "*")
(fun fmt t -> Format.fprintf fmt "%a" format_typ t)) (fun fmt t -> Format.fprintf fmt "%a" format_typ t))
ts ts
| TTuple (_args, Some s) -> | TTuple (_args, Some s) ->
Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" Ast.StructName.format_t s format_punctuation "{" Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" Ast.StructName.format_t s
format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " format_punctuation ";") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " format_punctuation ";")
(fun fmt (field, typ) -> (fun fmt (field, typ) ->
Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" StructFieldName.format_t Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\""
field format_punctuation "\"" format_punctuation ":" format_typ typ)) StructFieldName.format_t field format_punctuation "\""
format_punctuation ":" format_typ typ))
(StructMap.find s ctx.ctx_structs) (StructMap.find s ctx.ctx_structs)
format_punctuation "}" format_punctuation "}"
| TEnum (_, e) -> | TEnum (_, e) ->
Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" Ast.EnumName.format_t e format_punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a%a@]" Ast.EnumName.format_t e
format_punctuation "["
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " format_punctuation "|") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "@ %a@ " format_punctuation "|")
(fun fmt (case, typ) -> (fun fmt (case, typ) ->
Format.fprintf fmt "%a%a@ %a" format_enum_constructor case format_punctuation ":" Format.fprintf fmt "%a%a@ %a" format_enum_constructor case
format_typ typ)) format_punctuation ":" format_typ typ))
(EnumMap.find e ctx.ctx_enums) format_punctuation "]" (EnumMap.find e ctx.ctx_enums)
format_punctuation "]"
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_typ_with_parens t1 format_operator "" Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_typ_with_parens t1
format_typ t2 format_operator "" format_typ t2
| TArray t1 -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_base_type "array" format_typ t1 | TArray t1 ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_base_type "array" format_typ
t1
| TAny -> format_base_type fmt "any" | TAny -> format_base_type fmt "any"
(* (EmileRolley) NOTE: seems to be factorizable with Lcalc.Print.format_lit. *) (* (EmileRolley) NOTE: seems to be factorizable with Lcalc.Print.format_lit. *)
@ -117,18 +138,30 @@ let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
| LEmptyError -> format_lit_style fmt "" | LEmptyError -> format_lit_style fmt ""
| LUnit -> format_lit_style fmt "()" | LUnit -> format_lit_style fmt "()"
| LRat i -> | LRat i ->
format_lit_style fmt (Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) format_lit_style fmt
(Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i)
| LMoney e -> ( | LMoney e -> (
match !Utils.Cli.locale_lang with match !Utils.Cli.locale_lang with
| En -> format_lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e)) | En ->
| Fr -> format_lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e)) format_lit_style fmt
| Pl -> format_lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e))) (Format.asprintf "$%s" (Runtime.money_to_string e))
| Fr ->
format_lit_style fmt
(Format.asprintf "%s €" (Runtime.money_to_string e))
| Pl ->
format_lit_style fmt
(Format.asprintf "%s PLN" (Runtime.money_to_string e)))
| LDate d -> format_lit_style fmt (Runtime.date_to_string d) | LDate d -> format_lit_style fmt (Runtime.date_to_string d)
| LDuration d -> format_lit_style fmt (Runtime.duration_to_string d) | LDuration d -> format_lit_style fmt (Runtime.duration_to_string d)
let format_op_kind (fmt : Format.formatter) (k : op_kind) = let format_op_kind (fmt : Format.formatter) (k : op_kind) =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(match k with KInt -> "" | KRat -> "." | KMoney -> "$" | KDate -> "@" | KDuration -> "^") (match k with
| KInt -> ""
| KRat -> "."
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let format_binop (fmt : Format.formatter) (op : binop Pos.marked) : unit = let format_binop (fmt : Format.formatter) (op : binop Pos.marked) : unit =
format_operator fmt format_operator fmt
@ -184,12 +217,16 @@ let needs_parens (e : expr Pos.marked) : bool =
let format_var (fmt : Format.formatter) (v : Var.t) : unit = let format_var (fmt : Format.formatter) (v : Var.t) : unit =
Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v) Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v)
let rec format_expr ?(debug : bool = false) (ctx : Ast.decl_ctx) (fmt : Format.formatter) let rec format_expr
?(debug : bool = false)
(ctx : Ast.decl_ctx)
(fmt : Format.formatter)
(e : expr Pos.marked) : unit = (e : expr Pos.marked) : unit =
let format_expr = format_expr ~debug ctx in let format_expr = format_expr ~debug ctx in
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
if needs_parens e then if needs_parens e then
Format.fprintf fmt "%a%a%a" format_punctuation "(" format_expr e format_punctuation ")" Format.fprintf fmt "%a%a%a" format_punctuation "(" format_expr e
format_punctuation ")"
else Format.fprintf fmt "%a" format_expr e else Format.fprintf fmt "%a" format_expr e
in in
match Pos.unmark e with match Pos.unmark e with
@ -201,13 +238,15 @@ let rec format_expr ?(debug : bool = false) (ctx : Ast.decl_ctx) (fmt : Format.f
(fun fmt e -> Format.fprintf fmt "%a" format_expr e)) (fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es format_punctuation ")" es format_punctuation ")"
| ETuple (es, Some s) -> | ETuple (es, Some s) ->
Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]" Ast.StructName.format_t s Format.fprintf fmt "@[<hov 2>%a@ @[<hov 2>%a%a%a@]@]"
format_punctuation "{" Ast.StructName.format_t s format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " format_punctuation ";") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " format_punctuation ";")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" Ast.StructFieldName.format_t Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\""
struct_field format_punctuation "\"" format_punctuation "=" format_expr e)) Ast.StructFieldName.format_t struct_field format_punctuation "\""
format_punctuation "=" format_expr e))
(List.combine es (List.map fst (Ast.StructMap.find s ctx.ctx_structs))) (List.combine es (List.map fst (Ast.StructMap.find s ctx.ctx_structs)))
format_punctuation "}" format_punctuation "}"
| EArray es -> | EArray es ->
@ -218,10 +257,11 @@ let rec format_expr ?(debug : bool = false) (ctx : Ast.decl_ctx) (fmt : Format.f
es format_punctuation "]" es format_punctuation "]"
| ETupleAccess (e1, n, s, _ts) -> ( | ETupleAccess (e1, n, s, _ts) -> (
match s with match s with
| None -> Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n | None ->
Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n
| Some s -> | Some s ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "." format_punctuation "\"" Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "."
Ast.StructFieldName.format_t format_punctuation "\"" Ast.StructFieldName.format_t
(fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n)) (fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n))
format_punctuation "\"") format_punctuation "\"")
| EInj (e, n, en, _ts) -> | EInj (e, n, en, _ts) ->
@ -229,8 +269,8 @@ let rec format_expr ?(debug : bool = false) (ctx : Ast.decl_ctx) (fmt : Format.f
(fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n)) (fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n))
format_expr e format_expr e
| EMatch (e, es, e_name) -> | EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword "match" format_expr e Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" format_keyword
format_keyword "with" "match" format_expr e format_keyword "with"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, c) -> (fun fmt (e, c) ->
@ -241,63 +281,82 @@ let rec format_expr ?(debug : bool = false) (ctx : Ast.decl_ctx) (fmt : Format.f
| EApp ((EAbs ((binder, _), taus), _), args) -> | EApp ((EAbs ((binder, _), taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in let xs_tau_arg =
List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args
in
Format.fprintf fmt "%a%a" Format.fprintf fmt "%a%a"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "") ~pp_sep:(fun fmt () -> Format.fprintf fmt "")
(fun fmt (x, tau, arg) -> (fun fmt (x, tau, arg) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" format_keyword "let" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n"
format_var x format_punctuation ":" (format_typ ctx) tau format_punctuation "=" format_keyword "let" format_var x format_punctuation ":"
format_expr arg format_keyword "in")) (format_typ ctx) tau format_punctuation "=" format_expr arg
format_keyword "in"))
xs_tau_arg format_expr body xs_tau_arg format_expr body
| EAbs ((binder, _), taus) -> | EAbs ((binder, _), taus) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" format_punctuation "λ" Format.fprintf fmt "@[<hov 2>%a @[<hov 2>%a@] %a@ %a@]" format_punctuation
"λ"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (x, tau) -> (fun fmt (x, tau) ->
Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x format_punctuation Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var
":" (format_typ ctx) tau format_punctuation ")")) x format_punctuation ":" (format_typ ctx) tau format_punctuation
")"))
xs_tau format_punctuation "" format_expr body xs_tau format_punctuation "" format_expr body
| EApp ((EOp (Binop ((Ast.Map | Ast.Filter) as op)), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop ((Ast.Map | Ast.Filter) as op)), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos) format_with_parens Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos)
arg1 format_with_parens arg2 format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 format_binop Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
(op, Pos.no_pos) format_with_parens arg2 format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> format_expr fmt arg1 | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug ->
format_expr fmt arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) -> | EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos) format_with_parens arg1 Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args args
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if" format_expr e1 Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if"
format_keyword "then" format_expr e2 format_keyword "else" format_expr e3 format_expr e1 format_keyword "then" format_expr e2 format_keyword
"else" format_expr e3
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
| EDefault (exceptions, just, cons) -> | EDefault (exceptions, just, cons) ->
if List.length exceptions = 0 then if List.length exceptions = 0 then
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" format_punctuation "" format_expr just Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" format_punctuation ""
format_punctuation "" format_expr cons format_punctuation "" format_expr just format_punctuation "" format_expr cons
else
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a@ %a@ %a%a@]" format_punctuation ""
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " format_punctuation ",")
format_expr)
exceptions format_punctuation "|" format_expr just format_punctuation "" format_expr cons
format_punctuation "" format_punctuation ""
else
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a@ %a@ %a%a@]"
format_punctuation ""
(Format.pp_print_list
~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " format_punctuation ",")
format_expr)
exceptions format_punctuation "|" format_expr just format_punctuation
"" format_expr cons format_punctuation ""
| ErrorOnEmpty e' -> | ErrorOnEmpty e' ->
Format.fprintf fmt "%a@ %a" format_operator "error_empty" format_with_parens e' Format.fprintf fmt "%a@ %a" format_operator "error_empty"
format_with_parens e'
| EAssert e' -> | EAssert e' ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" format_keyword "assert" format_punctuation "(" Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" format_keyword "assert"
format_expr e' format_punctuation ")" format_punctuation "(" format_expr e' format_punctuation ")"
let format_scope ?(debug : bool = false) (ctx : decl_ctx) (fmt : Format.formatter) let format_scope
?(debug : bool = false)
(ctx : decl_ctx)
(fmt : Format.formatter)
((n, s) : Ast.ScopeName.t * scope_body) = ((n, s) : Ast.ScopeName.t * scope_body) =
Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" format_keyword "let" Ast.ScopeName.format_t n Format.fprintf fmt "@[<hov 2>%a %a =@ %a@]" format_keyword "let"
(format_expr ctx ~debug) Ast.ScopeName.format_t n (format_expr ctx ~debug)
(Bindlib.unbox (Ast.build_whole_scope_expr ctx s (Pos.get_position (Ast.ScopeName.get_info n)))) (Bindlib.unbox
(Ast.build_whole_scope_expr ctx s
(Pos.get_position (Ast.ScopeName.get_info n))))

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Printing functions for the default calculus AST *) (** Printing functions for the default calculus AST *)
@ -19,43 +21,28 @@ open Utils
(** {1 Helpers} *) (** {1 Helpers} *)
val is_uppercase : CamomileLibraryDefault.Camomile.UChar.t -> bool val is_uppercase : CamomileLibraryDefault.Camomile.UChar.t -> bool
val begins_with_uppercase : string -> bool val begins_with_uppercase : string -> bool
(** {1 Common syntax highlighting helpers}*) (** {1 Common syntax highlighting helpers}*)
val format_base_type : Format.formatter -> string -> unit val format_base_type : Format.formatter -> string -> unit
val format_keyword : Format.formatter -> string -> unit val format_keyword : Format.formatter -> string -> unit
val format_punctuation : Format.formatter -> string -> unit val format_punctuation : Format.formatter -> string -> unit
val format_operator : Format.formatter -> string -> unit val format_operator : Format.formatter -> string -> unit
val format_lit_style : Format.formatter -> string -> unit val format_lit_style : Format.formatter -> string -> unit
(** {1 Formatters} *) (** {1 Formatters} *)
val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit val format_uid_list : Format.formatter -> Uid.MarkedString.info list -> unit
val format_enum_constructor : Format.formatter -> Ast.EnumConstructor.t -> unit val format_enum_constructor : Format.formatter -> Ast.EnumConstructor.t -> unit
val format_tlit : Format.formatter -> Ast.typ_lit -> unit val format_tlit : Format.formatter -> Ast.typ_lit -> unit
val format_typ : Ast.decl_ctx -> Format.formatter -> Ast.typ Pos.marked -> unit val format_typ : Ast.decl_ctx -> Format.formatter -> Ast.typ Pos.marked -> unit
val format_lit : Format.formatter -> Ast.lit Pos.marked -> unit val format_lit : Format.formatter -> Ast.lit Pos.marked -> unit
val format_op_kind : Format.formatter -> Ast.op_kind -> unit val format_op_kind : Format.formatter -> Ast.op_kind -> unit
val format_binop : Format.formatter -> Ast.binop Pos.marked -> unit val format_binop : Format.formatter -> Ast.binop Pos.marked -> unit
val format_ternop : Format.formatter -> Ast.ternop Pos.marked -> unit val format_ternop : Format.formatter -> Ast.ternop Pos.marked -> unit
val format_log_entry : Format.formatter -> Ast.log_entry -> unit val format_log_entry : Format.formatter -> Ast.log_entry -> unit
val format_unop : Format.formatter -> Ast.unop Pos.marked -> unit val format_unop : Format.formatter -> Ast.unop Pos.marked -> unit
val format_var : Format.formatter -> Ast.Var.t -> unit val format_var : Format.formatter -> Ast.Var.t -> unit
val format_expr : val format_expr :

View File

@ -1,19 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Typing for the default calculus. Because of the error terms, we perform type inference using the (** Typing for the default calculus. Because of the error terms, we perform type
classical W algorithm with union-find unification. *) inference using the classical W algorithm with union-find unification. *)
open Utils open Utils
module A = Ast module A = Ast
@ -29,8 +31,9 @@ module Any =
end) end)
() ()
(** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new [TAny] variant. Indeed, (** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new
error terms can have any type and this has to be captured by the type sytem. *) [TAny] variant. Indeed, error terms can have any type and this has to be
captured by the type sytem. *)
type typ = type typ =
| TLit of A.typ_lit | TLit of A.typ_lit
| TArrow of typ Pos.marked UnionFind.elem * typ Pos.marked UnionFind.elem | TArrow of typ Pos.marked UnionFind.elem * typ Pos.marked UnionFind.elem
@ -43,10 +46,13 @@ let typ_needs_parens (t : typ Pos.marked UnionFind.elem) : bool =
let t = UnionFind.get (UnionFind.find t) in let t = UnionFind.get (UnionFind.find t) in
match Pos.unmark t with TArrow _ | TArray _ -> true | _ -> false match Pos.unmark t with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter) let rec format_typ
(ctx : Ast.decl_ctx)
(fmt : Format.formatter)
(typ : typ Pos.marked UnionFind.elem) : unit = (typ : typ Pos.marked UnionFind.elem) : unit =
let format_typ = format_typ ctx in let format_typ = format_typ ctx in
let format_typ_with_parens (fmt : Format.formatter) (t : typ Pos.marked UnionFind.elem) = let format_typ_with_parens
(fmt : Format.formatter) (t : typ Pos.marked UnionFind.elem) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
@ -62,21 +68,24 @@ let rec format_typ (ctx : Ast.decl_ctx) (fmt : Format.formatter)
| TTuple (_ts, Some s) -> Format.fprintf fmt "%a" Ast.StructName.format_t s | TTuple (_ts, Some s) -> Format.fprintf fmt "%a" Ast.StructName.format_t s
| TEnum (_ts, e) -> Format.fprintf fmt "%a" Ast.EnumName.format_t e | TEnum (_ts, e) -> Format.fprintf fmt "%a" Ast.EnumName.format_t e
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1 format_typ t2 Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1
format_typ t2
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ t1 | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ t1
| TAny d -> Format.fprintf fmt "any[%d]" (Any.hash d) | TAny d -> Format.fprintf fmt "any[%d]" (Any.hash d)
(** Raises an error if unification cannot be performed *) (** Raises an error if unification cannot be performed *)
let rec unify (ctx : Ast.decl_ctx) (t1 : typ Pos.marked UnionFind.elem) let rec unify
(ctx : Ast.decl_ctx)
(t1 : typ Pos.marked UnionFind.elem)
(t2 : typ Pos.marked UnionFind.elem) : unit = (t2 : typ Pos.marked UnionFind.elem) : unit =
let unify = unify ctx in let unify = unify ctx in
(* Cli.debug_print (Format.asprintf "Unifying %a and %a" (format_typ ctx) t1 (format_typ ctx) (* Cli.debug_print (Format.asprintf "Unifying %a and %a" (format_typ ctx) t1
t2); *) (format_typ ctx) t2); *)
let t1_repr = UnionFind.get (UnionFind.find t1) in let t1_repr = UnionFind.get (UnionFind.find t1) in
let t2_repr = UnionFind.get (UnionFind.find t2) in let t2_repr = UnionFind.get (UnionFind.find t2) in
let raise_type_error (t1_pos : Pos.t) (t2_pos : Pos.t) : 'a = let raise_type_error (t1_pos : Pos.t) (t2_pos : Pos.t) : 'a =
(* TODO: if we get weird error messages, then it means that we should use the persistent version (* TODO: if we get weird error messages, then it means that we should use
of the union-find data structure. *) the persistent version of the union-find data structure. *)
let t1_s = let t1_s =
Cli.with_style [ ANSITerminal.yellow ] "%s" Cli.with_style [ ANSITerminal.yellow ] "%s"
(Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") (Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
@ -129,9 +138,10 @@ let rec unify (ctx : Ast.decl_ctx) (t1 : typ Pos.marked UnionFind.elem)
let t_union = UnionFind.union t1 t2 in let t_union = UnionFind.union t1 t2 in
match repr with None -> () | Some t_repr -> UnionFind.set t_union t_repr match repr with None -> () | Some t_repr -> UnionFind.set t_union t_repr
(** Operators have a single type, instead of being polymorphic with constraints. This allows us to (** Operators have a single type, instead of being polymorphic with constraints.
have a simpler type system, while we argue the syntactic burden of operator annotations helps This allows us to have a simpler type system, while we argue the syntactic
the programmer visualize the type flow in the code. *) burden of operator annotations helps the programmer visualize the type flow
in the code. *)
let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem = let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem =
let pos = Pos.get_position op in let pos = Pos.get_position op in
let bt = UnionFind.make (TLit TBool, pos) in let bt = UnionFind.make (TLit TBool, pos) in
@ -146,10 +156,13 @@ let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem =
let array_any2 = UnionFind.make (TArray any2, pos) in let array_any2 = UnionFind.make (TArray any2, pos) in
let arr x y = UnionFind.make (TArrow (x, y), pos) in let arr x y = UnionFind.make (TArrow (x, y), pos) in
match Pos.unmark op with match Pos.unmark op with
| A.Ternop A.Fold -> arr (arr any2 (arr any any2)) (arr any2 (arr array_any any2)) | A.Ternop A.Fold ->
arr (arr any2 (arr any any2)) (arr any2 (arr array_any any2))
| A.Binop (A.And | A.Or | A.Xor) -> arr bt (arr bt bt) | A.Binop (A.And | A.Or | A.Xor) -> arr bt (arr bt bt)
| A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) -> arr it (arr it it) | A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) ->
| A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) -> arr rt (arr rt rt) arr it (arr it it)
| A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) ->
arr rt (arr rt rt)
| A.Binop (A.Add KMoney | A.Sub KMoney) -> arr mt (arr mt mt) | A.Binop (A.Add KMoney | A.Sub KMoney) -> arr mt (arr mt mt)
| A.Binop (A.Add KDuration | A.Sub KDuration) -> arr dut (arr dut dut) | A.Binop (A.Add KDuration | A.Sub KDuration) -> arr dut (arr dut dut)
| A.Binop (A.Sub KDate) -> arr dat (arr dat dut) | A.Binop (A.Sub KDate) -> arr dat (arr dat dut)
@ -157,11 +170,16 @@ let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem =
| A.Binop (A.Div KDuration) -> arr dut (arr dut rt) | A.Binop (A.Div KDuration) -> arr dut (arr dut rt)
| A.Binop (A.Div KMoney) -> arr mt (arr mt rt) | A.Binop (A.Div KMoney) -> arr mt (arr mt rt)
| A.Binop (A.Mult KMoney) -> arr mt (arr rt mt) | A.Binop (A.Mult KMoney) -> arr mt (arr rt mt)
| A.Binop (A.Lt KInt | A.Lte KInt | A.Gt KInt | A.Gte KInt) -> arr it (arr it bt) | A.Binop (A.Lt KInt | A.Lte KInt | A.Gt KInt | A.Gte KInt) ->
| A.Binop (A.Lt KRat | A.Lte KRat | A.Gt KRat | A.Gte KRat) -> arr rt (arr rt bt) arr it (arr it bt)
| A.Binop (A.Lt KMoney | A.Lte KMoney | A.Gt KMoney | A.Gte KMoney) -> arr mt (arr mt bt) | A.Binop (A.Lt KRat | A.Lte KRat | A.Gt KRat | A.Gte KRat) ->
| A.Binop (A.Lt KDate | A.Lte KDate | A.Gt KDate | A.Gte KDate) -> arr dat (arr dat bt) arr rt (arr rt bt)
| A.Binop (A.Lt KDuration | A.Lte KDuration | A.Gt KDuration | A.Gte KDuration) -> | A.Binop (A.Lt KMoney | A.Lte KMoney | A.Gt KMoney | A.Gte KMoney) ->
arr mt (arr mt bt)
| A.Binop (A.Lt KDate | A.Lte KDate | A.Gt KDate | A.Gte KDate) ->
arr dat (arr dat bt)
| A.Binop (A.Lt KDuration | A.Lte KDuration | A.Gt KDuration | A.Gte KDuration)
->
arr dut (arr dut bt) arr dut (arr dut bt)
| A.Binop (A.Eq | A.Neq) -> arr any (arr any bt) | A.Binop (A.Eq | A.Neq) -> arr any (arr any bt)
| A.Binop A.Map -> arr (arr any any2) (arr array_any array_any2) | A.Binop A.Map -> arr (arr any any2) (arr array_any array_any2)
@ -190,9 +208,13 @@ let rec ast_to_typ (ty : A.typ) : typ =
( UnionFind.make (Pos.map_under_mark ast_to_typ t1), ( UnionFind.make (Pos.map_under_mark ast_to_typ t1),
UnionFind.make (Pos.map_under_mark ast_to_typ t2) ) UnionFind.make (Pos.map_under_mark ast_to_typ t2) )
| A.TTuple (ts, s) -> | A.TTuple (ts, s) ->
TTuple (List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts, s) TTuple
( List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts,
s )
| A.TEnum (ts, e) -> | A.TEnum (ts, e) ->
TEnum (List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts, e) TEnum
( List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts,
e )
| A.TArray t -> TArray (UnionFind.make (Pos.map_under_mark ast_to_typ t)) | A.TArray t -> TArray (UnionFind.make (Pos.map_under_mark ast_to_typ t))
| A.TAny -> TAny (Any.fresh ()) | A.TAny -> TAny (Any.fresh ())
@ -213,9 +235,11 @@ let rec typ_to_ast (ty : typ Pos.marked UnionFind.elem) : A.typ Pos.marked =
type env = typ Pos.marked UnionFind.elem A.VarMap.t type env = typ Pos.marked UnionFind.elem A.VarMap.t
(** Infers the most permissive type from an expression *) (** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.marked) : let rec typecheck_expr_bottom_up
(ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.marked) :
typ Pos.marked UnionFind.elem = typ Pos.marked UnionFind.elem =
(* Cli.debug_print (Format.asprintf "Looking for type of %a" (Print.format_expr ctx) e); *) (* Cli.debug_print (Format.asprintf "Looking for type of %a"
(Print.format_expr ctx) e); *)
try try
let out = let out =
match Pos.unmark e with match Pos.unmark e with
@ -230,46 +254,66 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po
| ELit (LRat _) -> UnionFind.make (Pos.same_pos_as (TLit TRat) e) | ELit (LRat _) -> UnionFind.make (Pos.same_pos_as (TLit TRat) e)
| ELit (LMoney _) -> UnionFind.make (Pos.same_pos_as (TLit TMoney) e) | ELit (LMoney _) -> UnionFind.make (Pos.same_pos_as (TLit TMoney) e)
| ELit (LDate _) -> UnionFind.make (Pos.same_pos_as (TLit TDate) e) | ELit (LDate _) -> UnionFind.make (Pos.same_pos_as (TLit TDate) e)
| ELit (LDuration _) -> UnionFind.make (Pos.same_pos_as (TLit TDuration) e) | ELit (LDuration _) ->
UnionFind.make (Pos.same_pos_as (TLit TDuration) e)
| ELit LUnit -> UnionFind.make (Pos.same_pos_as (TLit TUnit) e) | ELit LUnit -> UnionFind.make (Pos.same_pos_as (TLit TUnit) e)
| ELit LEmptyError -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) | ELit LEmptyError ->
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)
| ETuple (es, s) -> | ETuple (es, s) ->
let ts = List.map (typecheck_expr_bottom_up ctx env) es in let ts = List.map (typecheck_expr_bottom_up ctx env) es in
UnionFind.make (Pos.same_pos_as (TTuple (ts, s)) e) UnionFind.make (Pos.same_pos_as (TTuple (ts, s)) e)
| ETupleAccess (e1, n, s, typs) -> ( | ETupleAccess (e1, n, s, typs) -> (
let typs = let typs =
List.map (fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ)) typs List.map
(fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ))
typs
in in
typecheck_expr_top_down ctx env e1 (UnionFind.make (TTuple (typs, s), Pos.get_position e)); typecheck_expr_top_down ctx env e1
(UnionFind.make (TTuple (typs, s), Pos.get_position e));
match List.nth_opt typs n with match List.nth_opt typs n with
| Some t' -> t' | Some t' -> t'
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e1) Errors.raise_spanned_error (Pos.get_position e1)
"Expression should have a tuple type with at least %d elements but only has %d" n "Expression should have a tuple type with at least %d elements \
(List.length typs)) but only has %d"
n (List.length typs))
| EInj (e1, n, e_name, ts) -> | EInj (e1, n, e_name, ts) ->
let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in let ts =
List.map
(fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t))
ts
in
let ts_n = let ts_n =
match List.nth_opt ts n with match List.nth_opt ts n with
| Some ts_n -> ts_n | Some ts_n -> ts_n
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Expression should have a sum type with at least %d cases but only has %d" n "Expression should have a sum type with at least %d cases \
(List.length ts) but only has %d"
n (List.length ts)
in in
typecheck_expr_top_down ctx env e1 ts_n; typecheck_expr_top_down ctx env e1 ts_n;
UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e) UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e)
| EMatch (e1, es, e_name) -> | EMatch (e1, es, e_name) ->
let enum_cases = let enum_cases =
List.map (fun e' -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) es List.map
(fun e' ->
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e'))
es
in
let t_e1 =
UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1)
in in
let t_e1 = UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1) in
typecheck_expr_top_down ctx env e1 t_e1; typecheck_expr_top_down ctx env e1 t_e1;
let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in let t_ret =
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)
in
List.iteri List.iteri
(fun i es' -> (fun i es' ->
let enum_t = List.nth enum_cases i in let enum_t = List.nth enum_cases i in
let t_es' = UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') in let t_es' =
UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es')
in
typecheck_expr_top_down ctx env es' t_es') typecheck_expr_top_down ctx env es' t_es')
es; es;
t_ret t_ret
@ -279,10 +323,16 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po
let xstaus = let xstaus =
List.map2 List.map2
(fun x tau -> (fun x tau ->
(x, UnionFind.make (ast_to_typ (Pos.unmark tau), Pos.get_position tau))) ( x,
UnionFind.make
(ast_to_typ (Pos.unmark tau), Pos.get_position tau) ))
(Array.to_list xs) taus (Array.to_list xs) taus
in in
let env = List.fold_left (fun env (x, tau) -> A.VarMap.add x tau env) env xstaus in let env =
List.fold_left
(fun env (x, tau) -> A.VarMap.add x tau env)
env xstaus
in
List.fold_right List.fold_right
(fun (_, t_arg) (acc : typ Pos.marked UnionFind.elem) -> (fun (_, t_arg) (acc : typ Pos.marked UnionFind.elem) ->
UnionFind.make (TArrow (t_arg, acc), pos_binder)) UnionFind.make (TArrow (t_arg, acc), pos_binder))
@ -290,35 +340,45 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po
(typecheck_expr_bottom_up ctx env body) (typecheck_expr_bottom_up ctx env body)
else else
Errors.raise_spanned_error pos_binder Errors.raise_spanned_error pos_binder
"function has %d variables but was supplied %d types" (Array.length xs) "function has %d variables but was supplied %d types"
(List.length taus) (Array.length xs) (List.length taus)
| EApp (e1, args) -> | EApp (e1, args) ->
let t_args = List.map (typecheck_expr_bottom_up ctx env) args in let t_args = List.map (typecheck_expr_bottom_up ctx env) args in
let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in let t_ret =
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)
in
let t_app = let t_app =
List.fold_right List.fold_right
(fun t_arg acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) (fun t_arg acc ->
UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e))
t_args t_ret t_args t_ret
in in
typecheck_expr_top_down ctx env e1 t_app; typecheck_expr_top_down ctx env e1 t_app;
t_ret t_ret
| EOp op -> op_type (Pos.same_pos_as op e) | EOp op -> op_type (Pos.same_pos_as op e)
| EDefault (excepts, just, cons) -> | EDefault (excepts, just, cons) ->
typecheck_expr_top_down ctx env just (UnionFind.make (Pos.same_pos_as (TLit TBool) just)); typecheck_expr_top_down ctx env just
(UnionFind.make (Pos.same_pos_as (TLit TBool) just));
let tcons = typecheck_expr_bottom_up ctx env cons in let tcons = typecheck_expr_bottom_up ctx env cons in
List.iter (fun except -> typecheck_expr_top_down ctx env except tcons) excepts; List.iter
(fun except -> typecheck_expr_top_down ctx env except tcons)
excepts;
tcons tcons
| EIfThenElse (cond, et, ef) -> | EIfThenElse (cond, et, ef) ->
typecheck_expr_top_down ctx env cond (UnionFind.make (Pos.same_pos_as (TLit TBool) cond)); typecheck_expr_top_down ctx env cond
(UnionFind.make (Pos.same_pos_as (TLit TBool) cond));
let tt = typecheck_expr_bottom_up ctx env et in let tt = typecheck_expr_bottom_up ctx env et in
typecheck_expr_top_down ctx env ef tt; typecheck_expr_top_down ctx env ef tt;
tt tt
| EAssert e' -> | EAssert e' ->
typecheck_expr_top_down ctx env e' (UnionFind.make (Pos.same_pos_as (TLit TBool) e')); typecheck_expr_top_down ctx env e'
(UnionFind.make (Pos.same_pos_as (TLit TBool) e'));
UnionFind.make (Pos.same_pos_as (TLit TUnit) e') UnionFind.make (Pos.same_pos_as (TLit TUnit) e')
| ErrorOnEmpty e' -> typecheck_expr_bottom_up ctx env e' | ErrorOnEmpty e' -> typecheck_expr_bottom_up ctx env e'
| EArray es -> | EArray es ->
let cell_type = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in let cell_type =
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)
in
List.iter List.iter
(fun e' -> (fun e' ->
let t_e' = typecheck_expr_bottom_up ctx env e' in let t_e' = typecheck_expr_bottom_up ctx env e' in
@ -326,21 +386,25 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po
es; es;
UnionFind.make (Pos.same_pos_as (TArray cell_type) e) UnionFind.make (Pos.same_pos_as (TArray cell_type) e)
in in
(* Cli.debug_print (Format.asprintf "Found type of %a: %a" (Print.format_expr ctx) e (format_typ (* Cli.debug_print (Format.asprintf "Found type of %a: %a"
ctx) out); *) (Print.format_expr ctx) e (format_typ ctx) out); *)
out out
with Errors.StructuredError (msg, err_pos) when List.length err_pos = 2 -> with Errors.StructuredError (msg, err_pos) when List.length err_pos = 2 ->
raise raise
(Errors.StructuredError (Errors.StructuredError
( msg, ( msg,
(Some "Error coming from typechecking the following expression:", Pos.get_position e) ( Some "Error coming from typechecking the following expression:",
Pos.get_position e )
:: err_pos )) :: err_pos ))
(** Checks whether the expression can be typed with the provided type *) (** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.marked) and typecheck_expr_top_down
(ctx : Ast.decl_ctx)
(env : env)
(e : A.expr Pos.marked)
(tau : typ Pos.marked UnionFind.elem) : unit = (tau : typ Pos.marked UnionFind.elem) : unit =
(* Cli.debug_print (Format.asprintf "Typechecking %a : %a" (Print.format_expr ctx) e (format_typ (* Cli.debug_print (Format.asprintf "Typechecking %a : %a" (Print.format_expr
ctx) tau); *) ctx) e (format_typ ctx) tau); *)
try try
match Pos.unmark e with match Pos.unmark e with
| EVar v -> ( | EVar v -> (
@ -349,52 +413,80 @@ and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.mar
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Variable not found in the current context") "Variable not found in the current context")
| ELit (LBool _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TBool) e)) | ELit (LBool _) ->
| ELit (LInt _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TInt) e)) unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TBool) e))
| ELit (LRat _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TRat) e)) | ELit (LInt _) ->
| ELit (LMoney _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TMoney) e)) unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TInt) e))
| ELit (LDate _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDate) e)) | ELit (LRat _) ->
| ELit (LDuration _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDuration) e)) unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TRat) e))
| ELit LUnit -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e)) | ELit (LMoney _) ->
| ELit LEmptyError -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)) unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TMoney) e))
| ELit (LDate _) ->
unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDate) e))
| ELit (LDuration _) ->
unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDuration) e))
| ELit LUnit ->
unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e))
| ELit LEmptyError ->
unify ctx tau (UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e))
| ETuple (es, s) -> | ETuple (es, s) ->
let t_es = let t_es =
UnionFind.make UnionFind.make
(Pos.same_pos_as (TTuple (List.map (typecheck_expr_bottom_up ctx env) es, s)) e) (Pos.same_pos_as
(TTuple (List.map (typecheck_expr_bottom_up ctx env) es, s))
e)
in in
unify ctx tau t_es unify ctx tau t_es
| ETupleAccess (e1, n, s, typs) -> ( | ETupleAccess (e1, n, s, typs) -> (
let typs = List.map (fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ)) typs in let typs =
typecheck_expr_top_down ctx env e1 (UnionFind.make (TTuple (typs, s), Pos.get_position e)); List.map
(fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ))
typs
in
typecheck_expr_top_down ctx env e1
(UnionFind.make (TTuple (typs, s), Pos.get_position e));
match List.nth_opt typs n with match List.nth_opt typs n with
| Some t1n -> unify ctx t1n tau | Some t1n -> unify ctx t1n tau
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e1) Errors.raise_spanned_error (Pos.get_position e1)
"Expression should have a tuple type with at least %d elements but only has %d" n "Expression should have a tuple type with at least %d elements \
(List.length typs)) but only has %d"
n (List.length typs))
| EInj (e1, n, e_name, ts) -> | EInj (e1, n, e_name, ts) ->
let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in let ts =
List.map
(fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t))
ts
in
let ts_n = let ts_n =
match List.nth_opt ts n with match List.nth_opt ts n with
| Some ts_n -> ts_n | Some ts_n -> ts_n
| None -> | None ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Expression should have a sum type with at least %d cases but only has %d" n "Expression should have a sum type with at least %d cases but \
(List.length ts) only has %d"
n (List.length ts)
in in
typecheck_expr_top_down ctx env e1 ts_n; typecheck_expr_top_down ctx env e1 ts_n;
unify ctx (UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e)) tau unify ctx (UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e)) tau
| EMatch (e1, es, e_name) -> | EMatch (e1, es, e_name) ->
let enum_cases = let enum_cases =
List.map (fun e' -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) es List.map
(fun e' ->
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e'))
es
in
let t_e1 =
UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1)
in in
let t_e1 = UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1) in
typecheck_expr_top_down ctx env e1 t_e1; typecheck_expr_top_down ctx env e1 t_e1;
let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in
List.iteri List.iteri
(fun i es' -> (fun i es' ->
let enum_t = List.nth enum_cases i in let enum_t = List.nth enum_cases i in
let t_es' = UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') in let t_es' =
UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es')
in
typecheck_expr_top_down ctx env es' t_es') typecheck_expr_top_down ctx env es' t_es')
es; es;
unify ctx tau t_ret unify ctx tau t_ret
@ -403,27 +495,34 @@ and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.mar
if Array.length xs = List.length t_args then if Array.length xs = List.length t_args then
let xstaus = let xstaus =
List.map2 List.map2
(fun x t_arg -> (x, UnionFind.make (Pos.map_under_mark ast_to_typ t_arg))) (fun x t_arg ->
(x, UnionFind.make (Pos.map_under_mark ast_to_typ t_arg)))
(Array.to_list xs) t_args (Array.to_list xs) t_args
in in
let env = List.fold_left (fun env (x, t_arg) -> A.VarMap.add x t_arg env) env xstaus in let env =
List.fold_left
(fun env (x, t_arg) -> A.VarMap.add x t_arg env)
env xstaus
in
let t_out = typecheck_expr_bottom_up ctx env body in let t_out = typecheck_expr_bottom_up ctx env body in
let t_func = let t_func =
List.fold_right List.fold_right
(fun (_, t_arg) acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) (fun (_, t_arg) acc ->
UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e))
xstaus t_out xstaus t_out
in in
unify ctx t_func tau unify ctx t_func tau
else else
Errors.raise_spanned_error pos_binder Errors.raise_spanned_error pos_binder
"function has %d variables but was supplied %d types" (Array.length xs) "function has %d variables but was supplied %d types"
(List.length t_args) (Array.length xs) (List.length t_args)
| EApp (e1, args) -> | EApp (e1, args) ->
let t_args = List.map (typecheck_expr_bottom_up ctx env) args in let t_args = List.map (typecheck_expr_bottom_up ctx env) args in
let te1 = typecheck_expr_bottom_up ctx env e1 in let te1 = typecheck_expr_bottom_up ctx env e1 in
let t_func = let t_func =
List.fold_right List.fold_right
(fun t_arg acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) (fun t_arg acc ->
UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e))
t_args tau t_args tau
in in
unify ctx te1 t_func unify ctx te1 t_func
@ -431,19 +530,26 @@ and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.mar
let op_typ = op_type (Pos.same_pos_as op e) in let op_typ = op_type (Pos.same_pos_as op e) in
unify ctx op_typ tau unify ctx op_typ tau
| EDefault (excepts, just, cons) -> | EDefault (excepts, just, cons) ->
typecheck_expr_top_down ctx env just (UnionFind.make (Pos.same_pos_as (TLit TBool) just)); typecheck_expr_top_down ctx env just
(UnionFind.make (Pos.same_pos_as (TLit TBool) just));
typecheck_expr_top_down ctx env cons tau; typecheck_expr_top_down ctx env cons tau;
List.iter (fun except -> typecheck_expr_top_down ctx env except tau) excepts List.iter
(fun except -> typecheck_expr_top_down ctx env except tau)
excepts
| EIfThenElse (cond, et, ef) -> | EIfThenElse (cond, et, ef) ->
typecheck_expr_top_down ctx env cond (UnionFind.make (Pos.same_pos_as (TLit TBool) cond)); typecheck_expr_top_down ctx env cond
(UnionFind.make (Pos.same_pos_as (TLit TBool) cond));
typecheck_expr_top_down ctx env et tau; typecheck_expr_top_down ctx env et tau;
typecheck_expr_top_down ctx env ef tau typecheck_expr_top_down ctx env ef tau
| EAssert e' -> | EAssert e' ->
typecheck_expr_top_down ctx env e' (UnionFind.make (Pos.same_pos_as (TLit TBool) e')); typecheck_expr_top_down ctx env e'
(UnionFind.make (Pos.same_pos_as (TLit TBool) e'));
unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e')) unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e'))
| ErrorOnEmpty e' -> typecheck_expr_top_down ctx env e' tau | ErrorOnEmpty e' -> typecheck_expr_top_down ctx env e' tau
| EArray es -> | EArray es ->
let cell_type = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in let cell_type =
UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)
in
List.iter List.iter
(fun e' -> (fun e' ->
let t_e' = typecheck_expr_bottom_up ctx env e' in let t_e' = typecheck_expr_bottom_up ctx env e' in
@ -454,7 +560,8 @@ and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.mar
raise raise
(Errors.StructuredError (Errors.StructuredError
( msg, ( msg,
(Some "Error coming from typechecking the following expression:", Pos.get_position e) ( Some "Error coming from typechecking the following expression:",
Pos.get_position e )
:: err_pos )) :: err_pos ))
(** {1 API} *) (** {1 API} *)
@ -465,5 +572,7 @@ let infer_type (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.typ Pos.marked =
typ_to_ast ty typ_to_ast ty
(** Typechecks an expression given an expected type *) (** Typechecks an expression given an expected type *)
let check_type (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) (tau : A.typ Pos.marked) = let check_type
typecheck_expr_top_down ctx A.VarMap.empty e (UnionFind.make (Pos.map_under_mark ast_to_typ tau)) (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) (tau : A.typ Pos.marked) =
typecheck_expr_top_down ctx A.VarMap.empty e
(UnionFind.make (Pos.map_under_mark ast_to_typ tau))

View File

@ -1,20 +1,24 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Typing for the default calculus. Because of the error terms, we perform type inference using the (** Typing for the default calculus. Because of the error terms, we perform type
classical W algorithm with union-find unification. *) inference using the classical W algorithm with union-find unification. *)
val infer_type : Ast.decl_ctx -> Ast.expr Utils.Pos.marked -> Ast.typ Utils.Pos.marked val infer_type :
Ast.decl_ctx -> Ast.expr Utils.Pos.marked -> Ast.typ Utils.Pos.marked
val check_type : Ast.decl_ctx -> Ast.expr Utils.Pos.marked -> Ast.typ Utils.Pos.marked -> unit val check_type :
Ast.decl_ctx -> Ast.expr Utils.Pos.marked -> Ast.typ Utils.Pos.marked -> unit

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Nicolas Chataing <nicolas.chataing@ens.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Abstract syntax tree of the desugared representation *) (** Abstract syntax tree of the desugared representation *)
@ -20,33 +22,35 @@ open Utils
module IdentMap : Map.S with type key = String.t = Map.Make (String) module IdentMap : Map.S with type key = String.t = Map.Make (String)
module RuleName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module RuleName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName) module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName)
module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName) module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName)
module LabelName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module LabelName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName) module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName)
module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName) module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName)
module StateName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module StateName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module ScopeVar : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar) module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar)
module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar) module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar)
(** Inside a scope, a definition can refer either to a scope def, or a subscope def *) (** Inside a scope, a definition can refer either to a scope def, or a subscope
def *)
module ScopeDef = struct module ScopeDef = struct
type t = type t =
| Var of ScopeVar.t * StateName.t option | Var of ScopeVar.t * StateName.t option
| SubScopeVar of Scopelang.Ast.SubScopeName.t * ScopeVar.t | SubScopeVar of Scopelang.Ast.SubScopeName.t * ScopeVar.t
(** In this case, the [ScopeVar.t] lives inside the context of the subscope's original (** In this case, the [ScopeVar.t] lives inside the context of the
declaration *) subscope's original declaration *)
let compare x y = let compare x y =
match (x, y) with match (x, y) with
@ -67,24 +71,27 @@ module ScopeDef = struct
match x with match x with
| Var (x, None) -> Pos.get_position (ScopeVar.get_info x) | Var (x, None) -> Pos.get_position (ScopeVar.get_info x)
| Var (_, Some sx) -> Pos.get_position (StateName.get_info sx) | Var (_, Some sx) -> Pos.get_position (StateName.get_info sx)
| SubScopeVar (x, _) -> Pos.get_position (Scopelang.Ast.SubScopeName.get_info x) | SubScopeVar (x, _) ->
Pos.get_position (Scopelang.Ast.SubScopeName.get_info x)
let format_t fmt x = let format_t fmt x =
match x with match x with
| Var (v, None) -> ScopeVar.format_t fmt v | Var (v, None) -> ScopeVar.format_t fmt v
| Var (v, Some sv) -> Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t sv | Var (v, Some sv) ->
Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t sv
| SubScopeVar (s, v) -> | SubScopeVar (s, v) ->
Format.fprintf fmt "%a.%a" Scopelang.Ast.SubScopeName.format_t s ScopeVar.format_t v Format.fprintf fmt "%a.%a" Scopelang.Ast.SubScopeName.format_t s
ScopeVar.format_t v
let hash x = let hash x =
match x with match x with
| Var (v, None) -> ScopeVar.hash v | Var (v, None) -> ScopeVar.hash v
| Var (v, Some sv) -> Int.logxor (ScopeVar.hash v) (StateName.hash sv) | Var (v, Some sv) -> Int.logxor (ScopeVar.hash v) (StateName.hash sv)
| SubScopeVar (w, v) -> Int.logxor (Scopelang.Ast.SubScopeName.hash w) (ScopeVar.hash v) | SubScopeVar (w, v) ->
Int.logxor (Scopelang.Ast.SubScopeName.hash w) (ScopeVar.hash v)
end end
module ScopeDefMap : Map.S with type key = ScopeDef.t = Map.Make (ScopeDef) module ScopeDefMap : Map.S with type key = ScopeDef.t = Map.Make (ScopeDef)
module ScopeDefSet : Set.S with type elt = ScopeDef.t = Set.Make (ScopeDef) module ScopeDefSet : Set.S with type elt = ScopeDef.t = Set.Make (ScopeDef)
(** {1 AST} *) (** {1 AST} *)
@ -92,9 +99,12 @@ module ScopeDefSet : Set.S with type elt = ScopeDef.t = Set.Make (ScopeDef)
type location = type location =
| ScopeVar of ScopeVar.t Pos.marked * StateName.t option | ScopeVar of ScopeVar.t Pos.marked * StateName.t option
| SubScopeVar of | SubScopeVar of
Scopelang.Ast.ScopeName.t * Scopelang.Ast.SubScopeName.t Pos.marked * ScopeVar.t Pos.marked Scopelang.Ast.ScopeName.t
* Scopelang.Ast.SubScopeName.t Pos.marked
* ScopeVar.t Pos.marked
module LocationSet : Set.S with type elt = location Pos.marked = Set.Make (struct module LocationSet : Set.S with type elt = location Pos.marked =
Set.Make (struct
type t = location Pos.marked type t = location Pos.marked
let compare x y = let compare x y =
@ -106,28 +116,38 @@ module LocationSet : Set.S with type elt = location Pos.marked = Set.Make (struc
| ScopeVar ((x, _), Some sx), ScopeVar ((y, _), Some sy) -> | ScopeVar ((x, _), Some sx), ScopeVar ((y, _), Some sy) ->
let cmp = ScopeVar.compare x y in let cmp = ScopeVar.compare x y in
if cmp = 0 then StateName.compare sx sy else cmp if cmp = 0 then StateName.compare sx sy else cmp
| SubScopeVar (_, (xsubindex, _), (xsubvar, _)), SubScopeVar (_, (ysubindex, _), (ysubvar, _)) | ( SubScopeVar (_, (xsubindex, _), (xsubvar, _)),
-> SubScopeVar (_, (ysubindex, _), (ysubvar, _)) ) ->
let c = Scopelang.Ast.SubScopeName.compare xsubindex ysubindex in let c = Scopelang.Ast.SubScopeName.compare xsubindex ysubindex in
if c = 0 then ScopeVar.compare xsubvar ysubvar else c if c = 0 then ScopeVar.compare xsubvar ysubvar else c
| ScopeVar _, SubScopeVar _ -> -1 | ScopeVar _, SubScopeVar _ -> -1
| SubScopeVar _, ScopeVar _ -> 1 | SubScopeVar _, ScopeVar _ -> 1
end) end)
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib} library, based on (** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
higher-order abstract syntax*) library, based on higher-order abstract syntax*)
type expr = type expr =
| ELocation of location | ELocation of location
| EVar of expr Bindlib.var Pos.marked | EVar of expr Bindlib.var Pos.marked
| EStruct of Scopelang.Ast.StructName.t * expr Pos.marked Scopelang.Ast.StructFieldMap.t | EStruct of
| EStructAccess of expr Pos.marked * Scopelang.Ast.StructFieldName.t * Scopelang.Ast.StructName.t Scopelang.Ast.StructName.t
| EEnumInj of expr Pos.marked * Scopelang.Ast.EnumConstructor.t * Scopelang.Ast.EnumName.t * expr Pos.marked Scopelang.Ast.StructFieldMap.t
| EStructAccess of
expr Pos.marked
* Scopelang.Ast.StructFieldName.t
* Scopelang.Ast.StructName.t
| EEnumInj of
expr Pos.marked
* Scopelang.Ast.EnumConstructor.t
* Scopelang.Ast.EnumName.t
| EMatch of | EMatch of
expr Pos.marked expr Pos.marked
* Scopelang.Ast.EnumName.t * Scopelang.Ast.EnumName.t
* expr Pos.marked Scopelang.Ast.EnumConstructorMap.t * expr Pos.marked Scopelang.Ast.EnumConstructorMap.t
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * Scopelang.Ast.typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked
* Scopelang.Ast.typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EOp of Dcalc.Ast.operator | EOp of Dcalc.Ast.operator
| EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked | EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked
@ -156,30 +176,36 @@ type rule = {
rule_exception_to_rules : RuleSet.t Pos.marked; rule_exception_to_rules : RuleSet.t Pos.marked;
} }
let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule = let empty_rule
(pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule
=
{ {
rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos);
rule_cons = Bindlib.box (ELit Dcalc.Ast.LEmptyError, pos); rule_cons = Bindlib.box (ELit Dcalc.Ast.LEmptyError, pos);
rule_parameter = rule_parameter =
(match have_parameter with Some typ -> Some (Var.make ("dummy", pos), typ) | None -> None); (match have_parameter with
| Some typ -> Some (Var.make ("dummy", pos), typ)
| None -> None);
rule_exception_to_rules = (RuleSet.empty, pos); rule_exception_to_rules = (RuleSet.empty, pos);
rule_id = RuleName.fresh ("empty", pos); rule_id = RuleName.fresh ("empty", pos);
} }
let always_false_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule = let always_false_rule
(pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule
=
{ {
rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool true), pos); rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool true), pos);
rule_cons = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); rule_cons = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos);
rule_parameter = rule_parameter =
(match have_parameter with Some typ -> Some (Var.make ("dummy", pos), typ) | None -> None); (match have_parameter with
| Some typ -> Some (Var.make ("dummy", pos), typ)
| None -> None);
rule_exception_to_rules = (RuleSet.empty, pos); rule_exception_to_rules = (RuleSet.empty, pos);
rule_id = RuleName.fresh ("always_false", pos); rule_id = RuleName.fresh ("always_false", pos);
} }
type assertion = expr Pos.marked Bindlib.box type assertion = expr Pos.marked Bindlib.box
type variation_typ = Increasing | Decreasing type variation_typ = Increasing | Decreasing
type reference_typ = Decree | Law type reference_typ = Decree | Law
type meta_assertion = type meta_assertion =
@ -241,11 +267,14 @@ let rec locations_used (e : expr Pos.marked) : LocationSet.t =
(LocationSet.union (locations_used just) (locations_used cons)) (LocationSet.union (locations_used just) (locations_used cons))
excepts excepts
| EArray es -> | EArray es ->
List.fold_left (fun acc e' -> LocationSet.union acc (locations_used e')) LocationSet.empty es List.fold_left
(fun acc e' -> LocationSet.union acc (locations_used e'))
LocationSet.empty es
| ErrorOnEmpty e' -> locations_used e' | ErrorOnEmpty e' -> locations_used e'
let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) : Pos.t ScopeDefMap.t = let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) :
Pos.t ScopeDefMap.t =
LocationSet.fold LocationSet.fold
(fun (loc, loc_pos) acc -> (fun (loc, loc_pos) acc ->
ScopeDefMap.add ScopeDefMap.add
@ -269,15 +298,26 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t =
let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box =
Bindlib.box_apply (fun v -> (v, pos)) (Bindlib.box_var x) Bindlib.box_apply (fun v -> (v, pos)) (Bindlib.box_var x)
let make_abs (xs : vars) (e : expr Pos.marked Bindlib.box) (pos_binder : Pos.t) let make_abs
(taus : Scopelang.Ast.typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = (xs : vars)
Bindlib.box_apply (fun b -> (EAbs ((b, pos_binder), taus), pos)) (Bindlib.bind_mvar xs e) (e : expr Pos.marked Bindlib.box)
(pos_binder : Pos.t)
(taus : Scopelang.Ast.typ Pos.marked list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply
(fun b -> (EAbs ((b, pos_binder), taus), pos))
(Bindlib.bind_mvar xs e)
let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) let make_app
: expr Pos.marked Bindlib.box = (e : expr Pos.marked Bindlib.box)
(u : expr Pos.marked Bindlib.box list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u)
let make_let_in (x : Var.t) (tau : Scopelang.Ast.typ Pos.marked) (e1 : expr Pos.marked Bindlib.box) let make_let_in
(x : Var.t)
(tau : Scopelang.Ast.typ Pos.marked)
(e1 : expr Pos.marked Bindlib.box)
(e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = (e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2 Bindlib.box_apply2
(fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2))) (fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2)))

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Nicolas Chataing <nicolas.chataing@ens.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Abstract syntax tree of the desugared representation *) (** Abstract syntax tree of the desugared representation *)
@ -19,44 +21,31 @@ open Utils
(** {1 Names, Maps and Keys} *) (** {1 Names, Maps and Keys} *)
module IdentMap : Map.S with type key = String.t module IdentMap : Map.S with type key = String.t
module RuleName : Uid.Id with type info = Uid.MarkedString.info module RuleName : Uid.Id with type info = Uid.MarkedString.info
module RuleMap : Map.S with type key = RuleName.t module RuleMap : Map.S with type key = RuleName.t
module RuleSet : Set.S with type elt = RuleName.t module RuleSet : Set.S with type elt = RuleName.t
module LabelName : Uid.Id with type info = Uid.MarkedString.info module LabelName : Uid.Id with type info = Uid.MarkedString.info
module LabelMap : Map.S with type key = LabelName.t module LabelMap : Map.S with type key = LabelName.t
module LabelSet : Set.S with type elt = LabelName.t module LabelSet : Set.S with type elt = LabelName.t
module StateName : Uid.Id with type info = Uid.MarkedString.info module StateName : Uid.Id with type info = Uid.MarkedString.info
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info module ScopeVar : Uid.Id with type info = Uid.MarkedString.info
module ScopeVarSet : Set.S with type elt = ScopeVar.t module ScopeVarSet : Set.S with type elt = ScopeVar.t
module ScopeVarMap : Map.S with type key = ScopeVar.t module ScopeVarMap : Map.S with type key = ScopeVar.t
(** Inside a scope, a definition can refer either to a scope def, or a subscope def *) (** Inside a scope, a definition can refer either to a scope def, or a subscope
def *)
module ScopeDef : sig module ScopeDef : sig
type t = type t =
| Var of ScopeVar.t * StateName.t option | Var of ScopeVar.t * StateName.t option
| SubScopeVar of Scopelang.Ast.SubScopeName.t * ScopeVar.t | SubScopeVar of Scopelang.Ast.SubScopeName.t * ScopeVar.t
val compare : t -> t -> int val compare : t -> t -> int
val get_position : t -> Pos.t val get_position : t -> Pos.t
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val hash : t -> int val hash : t -> int
end end
module ScopeDefMap : Map.S with type key = ScopeDef.t module ScopeDefMap : Map.S with type key = ScopeDef.t
module ScopeDefSet : Set.S with type elt = ScopeDef.t module ScopeDefSet : Set.S with type elt = ScopeDef.t
(** {1 AST} *) (** {1 AST} *)
@ -65,24 +54,36 @@ module ScopeDefSet : Set.S with type elt = ScopeDef.t
type location = type location =
| ScopeVar of ScopeVar.t Pos.marked * StateName.t option | ScopeVar of ScopeVar.t Pos.marked * StateName.t option
| SubScopeVar of | SubScopeVar of
Scopelang.Ast.ScopeName.t * Scopelang.Ast.SubScopeName.t Pos.marked * ScopeVar.t Pos.marked Scopelang.Ast.ScopeName.t
* Scopelang.Ast.SubScopeName.t Pos.marked
* ScopeVar.t Pos.marked
module LocationSet : Set.S with type elt = location Pos.marked module LocationSet : Set.S with type elt = location Pos.marked
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib} library, based on (** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
higher-order abstract syntax*) library, based on higher-order abstract syntax*)
type expr = type expr =
| ELocation of location | ELocation of location
| EVar of expr Bindlib.var Pos.marked | EVar of expr Bindlib.var Pos.marked
| EStruct of Scopelang.Ast.StructName.t * expr Pos.marked Scopelang.Ast.StructFieldMap.t | EStruct of
| EStructAccess of expr Pos.marked * Scopelang.Ast.StructFieldName.t * Scopelang.Ast.StructName.t Scopelang.Ast.StructName.t
| EEnumInj of expr Pos.marked * Scopelang.Ast.EnumConstructor.t * Scopelang.Ast.EnumName.t * expr Pos.marked Scopelang.Ast.StructFieldMap.t
| EStructAccess of
expr Pos.marked
* Scopelang.Ast.StructFieldName.t
* Scopelang.Ast.StructName.t
| EEnumInj of
expr Pos.marked
* Scopelang.Ast.EnumConstructor.t
* Scopelang.Ast.EnumName.t
| EMatch of | EMatch of
expr Pos.marked expr Pos.marked
* Scopelang.Ast.EnumName.t * Scopelang.Ast.EnumName.t
* expr Pos.marked Scopelang.Ast.EnumConstructorMap.t * expr Pos.marked Scopelang.Ast.EnumConstructorMap.t
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * Scopelang.Ast.typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked
* Scopelang.Ast.typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EOp of Dcalc.Ast.operator | EOp of Dcalc.Ast.operator
| EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked | EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked
@ -96,7 +97,6 @@ module Var : sig
type t = expr Bindlib.var type t = expr Bindlib.var
val make : string Pos.marked -> t val make : string Pos.marked -> t
val compare : t -> t -> int val compare : t -> t -> int
end end
@ -138,13 +138,10 @@ type rule = {
} }
val empty_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule val empty_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule
val always_false_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule val always_false_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule
type assertion = expr Pos.marked Bindlib.box type assertion = expr Pos.marked Bindlib.box
type variation_typ = Increasing | Decreasing type variation_typ = Increasing | Decreasing
type reference_typ = Decree | Law type reference_typ = Decree | Law
type meta_assertion = type meta_assertion =
@ -179,5 +176,4 @@ type program = {
(** {1 Helpers} *) (** {1 Helpers} *)
val locations_used : expr Pos.marked -> LocationSet.t val locations_used : expr Pos.marked -> LocationSet.t
val free_variables : rule RuleMap.t -> Pos.t ScopeDefMap.t val free_variables : rule RuleMap.t -> Pos.t ScopeDefMap.t

View File

@ -1,18 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Nicolas Chataing <nicolas.chataing@ens.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/} OCamlgraph} *) (** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
OCamlgraph} *)
open Utils open Utils
@ -36,7 +39,8 @@ module Vertex = struct
let hash x = let hash x =
match x with match x with
| Var (x, None) -> Ast.ScopeVar.hash x | Var (x, None) -> Ast.ScopeVar.hash x
| Var (x, Some sx) -> Int.logxor (Ast.ScopeVar.hash x) (Ast.StateName.hash sx) | Var (x, Some sx) ->
Int.logxor (Ast.ScopeVar.hash x) (Ast.StateName.hash sx)
| SubScope x -> Scopelang.Ast.SubScopeName.hash x | SubScope x -> Scopelang.Ast.SubScopeName.hash x
let compare = compare let compare = compare
@ -53,21 +57,23 @@ module Vertex = struct
match x with match x with
| Var (v, None) -> Ast.ScopeVar.format_t fmt v | Var (v, None) -> Ast.ScopeVar.format_t fmt v
| Var (v, Some sv) -> | Var (v, Some sv) ->
Format.fprintf fmt "%a.%a" Ast.ScopeVar.format_t v Ast.StateName.format_t sv Format.fprintf fmt "%a.%a" Ast.ScopeVar.format_t v
Ast.StateName.format_t sv
| SubScope v -> Scopelang.Ast.SubScopeName.format_t fmt v | SubScope v -> Scopelang.Ast.SubScopeName.format_t fmt v
end end
(** On the edges, the label is the position of the expression responsible for the use of the (** On the edges, the label is the position of the expression responsible for
variable. In the graph, [x -> y] if [x] is used in the definition of [y].*) the use of the variable. In the graph, [x -> y] if [x] is used in the
definition of [y].*)
module Edge = struct module Edge = struct
type t = Pos.t type t = Pos.t
let compare = compare let compare = compare
let default = Pos.no_pos let default = Pos.no_pos
end end
module ScopeDependencies = Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (Vertex) (Edge) module ScopeDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (Vertex) (Edge)
(** Module of the graph, provided by OCamlGraph *) (** Module of the graph, provided by OCamlGraph *)
module TopologicalTraversal = Graph.Topological.Make (ScopeDependencies) module TopologicalTraversal = Graph.Topological.Make (ScopeDependencies)
@ -78,14 +84,15 @@ module SCC = Graph.Components.Make (ScopeDependencies)
(** {2 Graph computations} *) (** {2 Graph computations} *)
(** Returns an ordering of the scope variables and subscope compatible with the dependencies of the (** Returns an ordering of the scope variables and subscope compatible with the
computation *) dependencies of the computation *)
let correct_computation_ordering (g : ScopeDependencies.t) : Vertex.t list = let correct_computation_ordering (g : ScopeDependencies.t) : Vertex.t list =
List.rev (TopologicalTraversal.fold (fun sd acc -> sd :: acc) g []) List.rev (TopologicalTraversal.fold (fun sd acc -> sd :: acc) g [])
(** Outputs an error in case of cycles. *) (** Outputs an error in case of cycles. *)
let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit = let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit =
(* if there is a cycle, there will be an strongly connected component of cardinality > 1 *) (* if there is a cycle, there will be an strongly connected component of
cardinality > 1 *)
let sccs = SCC.scc_list g in let sccs = SCC.scc_list g in
if List.length sccs < ScopeDependencies.nb_vertex g then if List.length sccs < ScopeDependencies.nb_vertex g then
let scc = List.find (fun scc -> List.length scc > 1) sccs in let scc = List.find (fun scc -> List.length scc > 1) sccs in
@ -96,33 +103,43 @@ let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit =
let var_str, var_info = let var_str, var_info =
match v with match v with
| Vertex.Var (v, None) -> | Vertex.Var (v, None) ->
(Format.asprintf "%a" Ast.ScopeVar.format_t v, Ast.ScopeVar.get_info v) ( Format.asprintf "%a" Ast.ScopeVar.format_t v,
Ast.ScopeVar.get_info v )
| Vertex.Var (v, Some sv) -> | Vertex.Var (v, Some sv) ->
( Format.asprintf "%a.%a" Ast.ScopeVar.format_t v Ast.StateName.format_t sv, ( Format.asprintf "%a.%a" Ast.ScopeVar.format_t v
Ast.StateName.format_t sv,
Ast.StateName.get_info sv ) Ast.StateName.get_info sv )
| Vertex.SubScope v -> | Vertex.SubScope v ->
( Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v, ( Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v,
Scopelang.Ast.SubScopeName.get_info v ) Scopelang.Ast.SubScopeName.get_info v )
in in
let succs = ScopeDependencies.succ_e g v in let succs = ScopeDependencies.succ_e g v in
let _, edge_pos, succ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
let succ_str = let succ_str =
match succ with match succ with
| Vertex.Var (v, None) -> Format.asprintf "%a" Ast.ScopeVar.format_t v | Vertex.Var (v, None) ->
Format.asprintf "%a" Ast.ScopeVar.format_t v
| Vertex.Var (v, Some sv) -> | Vertex.Var (v, Some sv) ->
Format.asprintf "%a.%a" Ast.ScopeVar.format_t v Ast.StateName.format_t sv Format.asprintf "%a.%a" Ast.ScopeVar.format_t v
| Vertex.SubScope v -> Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v Ast.StateName.format_t sv
| Vertex.SubScope v ->
Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v
in in
[ [
(Some ("Cycle variable " ^ var_str ^ ", declared:"), Pos.get_position var_info); ( Some ("Cycle variable " ^ var_str ^ ", declared:"),
( Some ("Used here in the definition of another cycle variable " ^ succ_str ^ ":"), Pos.get_position var_info );
( Some
("Used here in the definition of another cycle variable "
^ succ_str ^ ":"),
edge_pos ); edge_pos );
]) ])
scc) scc)
in in
Errors.raise_multispanned_error spans Errors.raise_multispanned_error spans
"Cyclic dependency detected between variables of scope %a!" Scopelang.Ast.ScopeName.format_t "Cyclic dependency detected between variables of scope %a!"
scope.scope_uid Scopelang.Ast.ScopeName.format_t scope.scope_uid
(** Builds the dependency graph of a particular scope *) (** Builds the dependency graph of a particular scope *)
let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
@ -135,7 +152,8 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
| Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None)) | Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None))
| Ast.States states -> | Ast.States states ->
List.fold_left List.fold_left
(fun g state -> ScopeDependencies.add_vertex g (Vertex.Var (v, Some state))) (fun g state ->
ScopeDependencies.add_vertex g (Vertex.Var (v, Some state)))
g states) g states)
scope.scope_vars g scope.scope_vars g
in in
@ -153,13 +171,14 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
Ast.ScopeDefMap.fold Ast.ScopeDefMap.fold
(fun fv_def fv_def_pos g -> (fun fv_def fv_def_pos g ->
match (def_key, fv_def) with match (def_key, fv_def) with
| Ast.ScopeDef.Var (v_defined, s_defined), Ast.ScopeDef.Var (v_used, s_used) -> | ( Ast.ScopeDef.Var (v_defined, s_defined),
Ast.ScopeDef.Var (v_used, s_used) ) ->
(* simple case *) (* simple case *)
if v_used = v_defined && s_used = s_defined then if v_used = v_defined && s_used = s_defined then
(* variable definitions cannot be recursive *) (* variable definitions cannot be recursive *)
Errors.raise_spanned_error fv_def_pos Errors.raise_spanned_error fv_def_pos
"The variable %a is used in one of its definitions, but recursion is forbidden \ "The variable %a is used in one of its definitions, but \
in Catala" recursion is forbidden in Catala"
Ast.ScopeDef.format_t def_key Ast.ScopeDef.format_t def_key
else else
let edge = let edge =
@ -169,21 +188,25 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
(Vertex.Var (v_defined, s_defined)) (Vertex.Var (v_defined, s_defined))
in in
ScopeDependencies.add_edge_e g edge ScopeDependencies.add_edge_e g edge
| Ast.ScopeDef.SubScopeVar (defined, _), Ast.ScopeDef.Var (v_used, s_used) -> | ( Ast.ScopeDef.SubScopeVar (defined, _),
(* here we are defining the input of a subscope using a var of the scope *) Ast.ScopeDef.Var (v_used, s_used) ) ->
(* here we are defining the input of a subscope using a var of
the scope *)
let edge = let edge =
ScopeDependencies.E.create ScopeDependencies.E.create
(Vertex.Var (v_used, s_used)) (Vertex.Var (v_used, s_used))
fv_def_pos (Vertex.SubScope defined) fv_def_pos (Vertex.SubScope defined)
in in
ScopeDependencies.add_edge_e g edge ScopeDependencies.add_edge_e g edge
| Ast.ScopeDef.SubScopeVar (defined, _), Ast.ScopeDef.SubScopeVar (used, _) -> | ( Ast.ScopeDef.SubScopeVar (defined, _),
(* here we are defining the input of a scope with the output of another subscope *) Ast.ScopeDef.SubScopeVar (used, _) ) ->
(* here we are defining the input of a scope with the output of
another subscope *)
if used = defined then if used = defined then
(* subscopes are not recursive functions *) (* subscopes are not recursive functions *)
Errors.raise_spanned_error fv_def_pos Errors.raise_spanned_error fv_def_pos
"The subscope %a is used when defining one of its inputs, but recursion is \ "The subscope %a is used when defining one of its inputs, \
forbidden in Catala" but recursion is forbidden in Catala"
Scopelang.Ast.SubScopeName.format_t defined Scopelang.Ast.SubScopeName.format_t defined
else else
let edge = let edge =
@ -191,8 +214,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
(Vertex.SubScope defined) (Vertex.SubScope defined)
in in
ScopeDependencies.add_edge_e g edge ScopeDependencies.add_edge_e g edge
| Ast.ScopeDef.Var (v_defined, s_defined), Ast.ScopeDef.SubScopeVar (used, _) -> | ( Ast.ScopeDef.Var (v_defined, s_defined),
(* finally we define a scope var with the output of a subscope *) Ast.ScopeDef.SubScopeVar (used, _) ) ->
(* finally we define a scope var with the output of a
subscope *)
let edge = let edge =
ScopeDependencies.E.create (Vertex.SubScope used) fv_def_pos ScopeDependencies.E.create (Vertex.SubScope used) fv_def_pos
(Vertex.Var (v_defined, s_defined)) (Vertex.Var (v_defined, s_defined))
@ -210,33 +235,38 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
module ExceptionVertex = struct module ExceptionVertex = struct
include Ast.RuleSet include Ast.RuleSet
let hash (x : t) : int = Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0 let hash (x : t) : int =
Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0
let equal x y = compare x y = 0 let equal x y = compare x y = 0
end end
module ExceptionsDependencies = module ExceptionsDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (ExceptionVertex) (Edge) Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (ExceptionVertex) (Edge)
(** Module of the graph, provided by OCamlGraph. [x -> y] if [y] is an exception to [x] *) (** Module of the graph, provided by OCamlGraph. [x -> y] if [y] is an exception
to [x] *)
module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies) module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies)
(** Tarjan's stongly connected components algorithm, provided by OCamlGraph *) (** Tarjan's stongly connected components algorithm, provided by OCamlGraph *)
(** {2 Graph computations} *) (** {2 Graph computations} *)
let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeDef.t) : let build_exceptions_graph
(def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeDef.t) :
ExceptionsDependencies.t = ExceptionsDependencies.t =
(* first we collect all the rule sets referred by exceptions *) (* first we collect all the rule sets referred by exceptions *)
let all_rule_sets_pointed_to_by_exceptions : Ast.RuleSet.t list = let all_rule_sets_pointed_to_by_exceptions : Ast.RuleSet.t list =
Ast.RuleMap.fold Ast.RuleMap.fold
(fun _rule_name rule acc -> (fun _rule_name rule acc ->
if Ast.RuleSet.is_empty (Pos.unmark rule.Ast.rule_exception_to_rules) then acc if Ast.RuleSet.is_empty (Pos.unmark rule.Ast.rule_exception_to_rules)
then acc
else Pos.unmark rule.Ast.rule_exception_to_rules :: acc) else Pos.unmark rule.Ast.rule_exception_to_rules :: acc)
def [] def []
in in
(* we make sure these sets are either disjoint or equal ; should be a syntactic invariant since (* we make sure these sets are either disjoint or equal ; should be a
you currently can't assign two labels to a single rule but an extra check is valuable since syntactic invariant since you currently can't assign two labels to a single
this is a required invariant for the graph to be sound *) rule but an extra check is valuable since this is a required invariant for
the graph to be sound *)
List.iter List.iter
(fun rule_set1 -> (fun rule_set1 ->
List.iter List.iter
@ -259,12 +289,13 @@ let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeD
(Ast.RuleSet.to_seq rule_set2)) (Ast.RuleSet.to_seq rule_set2))
in in
Errors.raise_multispanned_error spans Errors.raise_multispanned_error spans
"Definitions or rules grouped by different labels overlap, whereas these groups \ "Definitions or rules grouped by different labels overlap, \
shoule be disjoint") whereas these groups shoule be disjoint")
all_rule_sets_pointed_to_by_exceptions) all_rule_sets_pointed_to_by_exceptions)
all_rule_sets_pointed_to_by_exceptions; all_rule_sets_pointed_to_by_exceptions;
(* Then we add the exception graph vertices by taking all those sets of rules pointed to by (* Then we add the exception graph vertices by taking all those sets of rules
exceptions, and adding the remaining rules not pointed as separate singleton set vertices *) pointed to by exceptions, and adding the remaining rules not pointed as
separate singleton set vertices *)
let g = let g =
List.fold_left List.fold_left
(fun g rule_set -> ExceptionsDependencies.add_vertex g rule_set) (fun g rule_set -> ExceptionsDependencies.add_vertex g rule_set)
@ -279,30 +310,34 @@ let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeD
Ast.RuleSet.mem rule_name rule_set_pointed_to_by_exceptions) Ast.RuleSet.mem rule_name rule_set_pointed_to_by_exceptions)
all_rule_sets_pointed_to_by_exceptions all_rule_sets_pointed_to_by_exceptions
then g then g
else ExceptionsDependencies.add_vertex g (Ast.RuleSet.singleton rule_name)) else
ExceptionsDependencies.add_vertex g (Ast.RuleSet.singleton rule_name))
def g def g
in in
(* then we add the edges *) (* then we add the edges *)
let g = let g =
Ast.RuleMap.fold Ast.RuleMap.fold
(fun rule_name rule g -> (fun rule_name rule g ->
(* Right now, exceptions can only consist of one rule, we may want to relax that constraint (* Right now, exceptions can only consist of one rule, we may want to
later in the development of Catala. *) relax that constraint later in the development of Catala. *)
let exception_to_ruleset, pos = rule.Ast.rule_exception_to_rules in let exception_to_ruleset, pos = rule.Ast.rule_exception_to_rules in
if Ast.RuleSet.is_empty exception_to_ruleset then g (* we don't add an edge*) if Ast.RuleSet.is_empty exception_to_ruleset then g
(* we don't add an edge*)
else if ExceptionsDependencies.mem_vertex g exception_to_ruleset then else if ExceptionsDependencies.mem_vertex g exception_to_ruleset then
if exception_to_ruleset = Ast.RuleSet.singleton rule_name then if exception_to_ruleset = Ast.RuleSet.singleton rule_name then
Errors.raise_spanned_error pos "Cannot define rule as an exception to itself" Errors.raise_spanned_error pos
"Cannot define rule as an exception to itself"
else else
let edge = let edge =
ExceptionsDependencies.E.create (Ast.RuleSet.singleton rule_name) pos ExceptionsDependencies.E.create
exception_to_ruleset (Ast.RuleSet.singleton rule_name)
pos exception_to_ruleset
in in
ExceptionsDependencies.add_edge_e g edge ExceptionsDependencies.add_edge_e g edge
else else
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"This rule has been declared as an exception to an incorrect label: this label is not \ "This rule has been declared as an exception to an incorrect \
attached to a definition of \"%a\"" label: this label is not attached to a definition of \"%a\""
Ast.ScopeDef.format_t def_info) Ast.ScopeDef.format_t def_info)
def g def g
in in
@ -310,7 +345,8 @@ let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeD
(** Outputs an error in case of cycles. *) (** Outputs an error in case of cycles. *)
let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit = let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
(* if there is a cycle, there will be an strongly connected component of cardinality > 1 *) (* if there is a cycle, there will be an strongly connected component of
cardinality > 1 *)
let sccs = ExceptionsSCC.scc_list g in let sccs = ExceptionsSCC.scc_list g in
if List.length sccs < ExceptionsDependencies.nb_vertex g then if List.length sccs < ExceptionsDependencies.nb_vertex g then
let scc = List.find (fun scc -> List.length scc > 1) sccs in let scc = List.find (fun scc -> List.length scc > 1) sccs in
@ -320,20 +356,24 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
(fun (vs : Ast.RuleSet.t) -> (fun (vs : Ast.RuleSet.t) ->
let v = Ast.RuleSet.choose vs in let v = Ast.RuleSet.choose vs in
let var_str, var_info = let var_str, var_info =
(Format.asprintf "%a" Ast.RuleName.format_t v, Ast.RuleName.get_info v) ( Format.asprintf "%a" Ast.RuleName.format_t v,
Ast.RuleName.get_info v )
in in
let succs = ExceptionsDependencies.succ_e g vs in let succs = ExceptionsDependencies.succ_e g vs in
let _, edge_pos, _ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in let _, edge_pos, _ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
[ [
( Some ( Some
("Cyclic exception for definition of variable \"" ^ var_str ("Cyclic exception for definition of variable \"" ^ var_str
^ "\", declared here:"), ^ "\", declared here:"),
Pos.get_position var_info ); Pos.get_position var_info );
( Some ( Some
("Used here in the definition of another cyclic exception for defining \"" ("Used here in the definition of another cyclic exception \
^ var_str ^ "\":"), for defining \"" ^ var_str ^ "\":"),
edge_pos ); edge_pos );
]) ])
scc) scc)
in in
Errors.raise_multispanned_error spans "Cyclic dependency detected between exceptions!" Errors.raise_multispanned_error spans
"Cyclic dependency detected between exceptions!"

View File

@ -1,18 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Nicolas Chataing <nicolas.chataing@ens.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/} OCamlgraph} *) (** Scope dependencies computations using {{:http://ocamlgraph.lri.fr/}
OCamlgraph} *)
open Utils open Utils
@ -40,20 +43,22 @@ module Vertex : sig
end end
module Edge : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t module Edge : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t
(** On the edges, the label is the position of the expression responsible for the use of the (** On the edges, the label is the position of the expression responsible for
variable. In the graph, [x -> y] if [x] is used in the definition of [y].*) the use of the variable. In the graph, [x -> y] if [x] is used in the
definition of [y].*)
(** Module of the graph, provided by OCamlGraph *) (** Module of the graph, provided by OCamlGraph *)
module ScopeDependencies : Graph.Sig.P with type V.t = Vertex.t and type E.label = Edge.t module ScopeDependencies :
Graph.Sig.P with type V.t = Vertex.t and type E.label = Edge.t
(** {2 Graph computations} *) (** {2 Graph computations} *)
(** Returns an ordering of the scope variables and subscope compatible with the dependencies of the (** Returns an ordering of the scope variables and subscope compatible with the
computation *) dependencies of the computation *)
val correct_computation_ordering : ScopeDependencies.t -> Vertex.t list val correct_computation_ordering : ScopeDependencies.t -> Vertex.t list
(** Returns an ordering of the scope variables and subscope compatible with the dependencies of the (** Returns an ordering of the scope variables and subscope compatible with the
computation *) dependencies of the computation *)
val check_for_cycle : Ast.scope -> ScopeDependencies.t -> unit val check_for_cycle : Ast.scope -> ScopeDependencies.t -> unit
(** Outputs an error in case of cycles. *) (** Outputs an error in case of cycles. *)
@ -63,8 +68,10 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t
(** {1 Exceptions dependency graph} *) (** {1 Exceptions dependency graph} *)
module ExceptionsDependencies : Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = Edge.t module ExceptionsDependencies :
Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = Edge.t
val build_exceptions_graph : Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t val build_exceptions_graph :
Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
val check_for_exception_cycle : ExceptionsDependencies.t -> unit val check_for_exception_cycle : ExceptionsDependencies.t -> unit

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)
@ -31,20 +33,24 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) :
Scopelang.Ast.expr Pos.marked Bindlib.box = Scopelang.Ast.expr Pos.marked Bindlib.box =
match Pos.unmark e with match Pos.unmark e with
| Ast.ELocation (SubScopeVar (s_name, ss_name, s_var)) -> | Ast.ELocation (SubScopeVar (s_name, ss_name, s_var)) ->
(* When referring to a subscope variable in an expression, we are referring to the output, (* When referring to a subscope variable in an expression, we are
hence we take the last state. *) referring to the output, hence we take the last state. *)
let new_s_var = let new_s_var =
match Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping with match Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping with
| WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var | WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var
| States states -> Pos.same_pos_as (snd (List.hd (List.rev states))) s_var | States states ->
Pos.same_pos_as (snd (List.hd (List.rev states))) s_var
in in
Bindlib.box Bindlib.box
(Scopelang.Ast.ELocation (SubScopeVar (s_name, ss_name, new_s_var)), Pos.get_position e) ( Scopelang.Ast.ELocation (SubScopeVar (s_name, ss_name, new_s_var)),
Pos.get_position e )
| Ast.ELocation (ScopeVar (s_var, None)) -> | Ast.ELocation (ScopeVar (s_var, None)) ->
Bindlib.box Bindlib.box
( Scopelang.Ast.ELocation ( Scopelang.Ast.ELocation
(ScopeVar (ScopeVar
(match Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping with (match
Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping
with
| WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var | WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var
| States _ -> failwith "should not happen")), | States _ -> failwith "should not happen")),
Pos.get_position e ) Pos.get_position e )
@ -52,9 +58,12 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) :
Bindlib.box Bindlib.box
( Scopelang.Ast.ELocation ( Scopelang.Ast.ELocation
(ScopeVar (ScopeVar
(match Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping with (match
Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping
with
| WholeVar _ -> failwith "should not happen" | WholeVar _ -> failwith "should not happen"
| States states -> Pos.same_pos_as (List.assoc state states) s_var)), | States states ->
Pos.same_pos_as (List.assoc state states) s_var)),
Pos.get_position e ) Pos.get_position e )
| Ast.EVar v -> | Ast.EVar v ->
Bindlib.box_apply Bindlib.box_apply
@ -62,16 +71,20 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) :
(Bindlib.box_var (Ast.VarMap.find (Pos.unmark v) ctx.var_mapping)) (Bindlib.box_var (Ast.VarMap.find (Pos.unmark v) ctx.var_mapping))
| EStruct (s_name, fields) -> | EStruct (s_name, fields) ->
Bindlib.box_apply Bindlib.box_apply
(fun new_fields -> (Scopelang.Ast.EStruct (s_name, new_fields), Pos.get_position e)) (fun new_fields ->
(Scopelang.Ast.EStruct (s_name, new_fields), Pos.get_position e))
(Scopelang.Ast.StructFieldMapLift.lift_box (Scopelang.Ast.StructFieldMapLift.lift_box
(Scopelang.Ast.StructFieldMap.map (translate_expr ctx) fields)) (Scopelang.Ast.StructFieldMap.map (translate_expr ctx) fields))
| EStructAccess (e1, s_name, f_name) -> | EStructAccess (e1, s_name, f_name) ->
Bindlib.box_apply Bindlib.box_apply
(fun new_e1 -> (Scopelang.Ast.EStructAccess (new_e1, s_name, f_name), Pos.get_position e)) (fun new_e1 ->
( Scopelang.Ast.EStructAccess (new_e1, s_name, f_name),
Pos.get_position e ))
(translate_expr ctx e1) (translate_expr ctx e1)
| EEnumInj (e1, cons, e_name) -> | EEnumInj (e1, cons, e_name) ->
Bindlib.box_apply Bindlib.box_apply
(fun new_e1 -> (Scopelang.Ast.EEnumInj (new_e1, cons, e_name), Pos.get_position e)) (fun new_e1 ->
(Scopelang.Ast.EEnumInj (new_e1, cons, e_name), Pos.get_position e))
(translate_expr ctx e1) (translate_expr ctx e1)
| EMatch (e1, e_name, arms) -> | EMatch (e1, e_name, arms) ->
Bindlib.box_apply2 Bindlib.box_apply2
@ -84,34 +97,43 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) :
| EAbs ((binder, binder_pos), typs) -> | EAbs ((binder, binder_pos), typs) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let new_vars = let new_vars =
Array.map (fun var -> Scopelang.Ast.Var.make (Bindlib.name_of var, binder_pos)) vars Array.map
(fun var -> Scopelang.Ast.Var.make (Bindlib.name_of var, binder_pos))
vars
in in
let ctx = let ctx =
List.fold_left2 List.fold_left2
(fun ctx var new_var -> (fun ctx var new_var ->
{ ctx with var_mapping = Ast.VarMap.add var new_var ctx.var_mapping }) {
ctx with
var_mapping = Ast.VarMap.add var new_var ctx.var_mapping;
})
ctx (Array.to_list vars) (Array.to_list new_vars) ctx (Array.to_list vars) (Array.to_list new_vars)
in in
Bindlib.box_apply Bindlib.box_apply
(fun new_binder -> (fun new_binder ->
(Scopelang.Ast.EAbs ((new_binder, binder_pos), typs), Pos.get_position e)) ( Scopelang.Ast.EAbs ((new_binder, binder_pos), typs),
Pos.get_position e ))
(Bindlib.bind_mvar new_vars (translate_expr ctx body)) (Bindlib.bind_mvar new_vars (translate_expr ctx body))
| EApp (e1, args) -> | EApp (e1, args) ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun new_e1 new_args -> (Scopelang.Ast.EApp (new_e1, new_args), Pos.get_position e)) (fun new_e1 new_args ->
(Scopelang.Ast.EApp (new_e1, new_args), Pos.get_position e))
(translate_expr ctx e1) (translate_expr ctx e1)
(Bindlib.box_list (List.map (translate_expr ctx) args)) (Bindlib.box_list (List.map (translate_expr ctx) args))
| EOp op -> Bindlib.box (Scopelang.Ast.EOp op, Pos.get_position e) | EOp op -> Bindlib.box (Scopelang.Ast.EOp op, Pos.get_position e)
| EDefault (excepts, just, cons) -> | EDefault (excepts, just, cons) ->
Bindlib.box_apply3 Bindlib.box_apply3
(fun new_excepts new_just new_cons -> (fun new_excepts new_just new_cons ->
(Scopelang.Ast.EDefault (new_excepts, new_just, new_cons), Pos.get_position e)) ( Scopelang.Ast.EDefault (new_excepts, new_just, new_cons),
Pos.get_position e ))
(Bindlib.box_list (List.map (translate_expr ctx) excepts)) (Bindlib.box_list (List.map (translate_expr ctx) excepts))
(translate_expr ctx just) (translate_expr ctx cons) (translate_expr ctx just) (translate_expr ctx cons)
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
Bindlib.box_apply3 Bindlib.box_apply3
(fun new_e1 new_e2 new_e3 -> (fun new_e1 new_e2 new_e3 ->
(Scopelang.Ast.EIfThenElse (new_e1, new_e2, new_e3), Pos.get_position e)) ( Scopelang.Ast.EIfThenElse (new_e1, new_e2, new_e3),
Pos.get_position e ))
(translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3) (translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3)
| EArray args -> | EArray args ->
Bindlib.box_apply Bindlib.box_apply
@ -124,29 +146,39 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) :
(** {1 Rule tree construction} *) (** {1 Rule tree construction} *)
(** Intermediate representation for the exception tree of rules for a particular scope definition. *) (** Intermediate representation for the exception tree of rules for a particular
scope definition. *)
type rule_tree = type rule_tree =
| Leaf of Ast.rule list (** Rules defining a base case piecewise. List is non-empty. *) | Leaf of Ast.rule list
(** Rules defining a base case piecewise. List is non-empty. *)
| Node of rule_tree list * Ast.rule list | Node of rule_tree list * Ast.rule list
(** A list of exceptions to a non-empty list of rules defining a base case piecewise. *) (** A list of exceptions to a non-empty list of rules defining a base case
piecewise. *)
(** Transforms a flat list of rules into a tree, taking into account the priorities declared between (** Transforms a flat list of rules into a tree, taking into account the
rules *) priorities declared between rules *)
let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) : rule_tree list = let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) :
rule_tree list =
let exc_graph = Dependency.build_exceptions_graph def def_info in let exc_graph = Dependency.build_exceptions_graph def def_info in
Dependency.check_for_exception_cycle exc_graph; Dependency.check_for_exception_cycle exc_graph;
(* we start by the base cases: they are the vertices which have no successors *) (* we start by the base cases: they are the vertices which have no
successors *)
let base_cases = let base_cases =
Dependency.ExceptionsDependencies.fold_vertex Dependency.ExceptionsDependencies.fold_vertex
(fun v base_cases -> (fun v base_cases ->
if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then v :: base_cases if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then
v :: base_cases
else base_cases) else base_cases)
exc_graph [] exc_graph []
in in
let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree = let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree =
let exceptions = Dependency.ExceptionsDependencies.pred exc_graph base_cases in let exceptions =
Dependency.ExceptionsDependencies.pred exc_graph base_cases
in
let base_case_as_rule_list = let base_case_as_rule_list =
List.map (fun r -> Ast.RuleMap.find r def) (List.of_seq (Ast.RuleSet.to_seq base_cases)) List.map
(fun r -> Ast.RuleMap.find r def)
(List.of_seq (Ast.RuleSet.to_seq base_cases))
in in
match exceptions with match exceptions with
| [] -> Leaf base_case_as_rule_list | [] -> Leaf base_case_as_rule_list
@ -154,24 +186,31 @@ let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) :
in in
List.map build_tree base_cases List.map build_tree base_cases
(** From the {!type: rule_tree}, builds an {!constructor: Dcalc.Ast.EDefault} expression in the (** From the {!type: rule_tree}, builds an {!constructor: Dcalc.Ast.EDefault}
scope language. The [~toplevel] parameter is used to know when to place the toplevel binding in expression in the scope language. The [~toplevel] parameter is used to know
the case of functions. *) when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t) let rec rule_tree_to_expr
(is_func : Ast.Var.t option) (tree : rule_tree) : Scopelang.Ast.expr Pos.marked Bindlib.box = ~(toplevel : bool)
(ctx : ctx)
(def_pos : Pos.t)
(is_func : Ast.Var.t option)
(tree : rule_tree) : Scopelang.Ast.expr Pos.marked Bindlib.box =
let exceptions, base_rules = let exceptions, base_rules =
match tree with Leaf r -> ([], r) | Node (exceptions, r) -> (exceptions, r) match tree with Leaf r -> ([], r) | Node (exceptions, r) -> (exceptions, r)
in in
(* because each rule has its own variable parameter and we want to convert the whole rule tree (* because each rule has its own variable parameter and we want to convert the
into a function, we need to perform some alpha-renaming of all the expressions *) whole rule tree into a function, we need to perform some alpha-renaming of
let substitute_parameter (e : Ast.expr Pos.marked Bindlib.box) (rule : Ast.rule) : all the expressions *)
let substitute_parameter
(e : Ast.expr Pos.marked Bindlib.box) (rule : Ast.rule) :
Ast.expr Pos.marked Bindlib.box = Ast.expr Pos.marked Bindlib.box =
match (is_func, rule.Ast.rule_parameter) with match (is_func, rule.Ast.rule_parameter) with
| Some new_param, Some (old_param, _) -> | Some new_param, Some (old_param, _) ->
let binder = Bindlib.bind_var old_param e in let binder = Bindlib.bind_var old_param e in
Bindlib.box_apply2 Bindlib.box_apply2
(fun binder new_param -> Bindlib.subst binder new_param) (fun binder new_param -> Bindlib.subst binder new_param)
binder (Bindlib.box_var new_param) binder
(Bindlib.box_var new_param)
| None, None -> e | None, None -> e
| _ -> assert false | _ -> assert false
(* should not happen *) (* should not happen *)
@ -182,27 +221,38 @@ let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t)
| Some new_param -> ( | Some new_param -> (
match Ast.VarMap.find_opt new_param ctx.var_mapping with match Ast.VarMap.find_opt new_param ctx.var_mapping with
| None -> | None ->
let new_param_scope = Scopelang.Ast.Var.make (Bindlib.name_of new_param, def_pos) in let new_param_scope =
{ ctx with var_mapping = Ast.VarMap.add new_param new_param_scope ctx.var_mapping } Scopelang.Ast.Var.make (Bindlib.name_of new_param, def_pos)
in
{
ctx with
var_mapping =
Ast.VarMap.add new_param new_param_scope ctx.var_mapping;
}
| Some _ -> | Some _ ->
(* We only create a mapping if none exists because [rule_tree_to_expr] is called (* We only create a mapping if none exists because
recursively on the exceptions of the tree and we don't want to create a new Scopelang [rule_tree_to_expr] is called recursively on the exceptions of
variable for the parameter at each tree level. *) the tree and we don't want to create a new Scopelang variable for
the parameter at each tree level. *)
ctx) ctx)
in in
let base_just_list = let base_just_list =
List.map (fun rule -> substitute_parameter rule.Ast.rule_just rule) base_rules List.map
(fun rule -> substitute_parameter rule.Ast.rule_just rule)
base_rules
in in
let base_cons_list = let base_cons_list =
List.map (fun rule -> substitute_parameter rule.Ast.rule_cons rule) base_rules List.map
(fun rule -> substitute_parameter rule.Ast.rule_cons rule)
base_rules
in in
let translate_and_unbox_list (list : Ast.expr Pos.marked Bindlib.box list) : let translate_and_unbox_list (list : Ast.expr Pos.marked Bindlib.box list) :
Scopelang.Ast.expr Pos.marked Bindlib.box list = Scopelang.Ast.expr Pos.marked Bindlib.box list =
List.map List.map
(fun e -> (fun e ->
(* There are two levels of boxing here, the outermost is introduced by the [translate_expr] (* There are two levels of boxing here, the outermost is introduced by
function for which all of the bindings should have been closed by now, so we can safely the [translate_expr] function for which all of the bindings should
unbox. *) have been closed by now, so we can safely unbox. *)
Bindlib.unbox (Bindlib.box_apply (translate_expr ctx) e)) Bindlib.unbox (Bindlib.box_apply (translate_expr ctx) e))
list list
in in
@ -212,7 +262,8 @@ let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t)
( Scopelang.Ast.EDefault ( Scopelang.Ast.EDefault
( List.map2 ( List.map2
(fun base_just base_cons -> (fun base_just base_cons ->
(Scopelang.Ast.EDefault ([], base_just, base_cons), Pos.get_position base_just)) ( Scopelang.Ast.EDefault ([], base_just, base_cons),
Pos.get_position base_just ))
base_just_list base_cons_list, base_just_list base_cons_list,
(Scopelang.Ast.ELit (Dcalc.Ast.LBool false), def_pos), (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), def_pos),
(Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, def_pos) ), (Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, def_pos) ),
@ -221,7 +272,10 @@ let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t)
(Bindlib.box_list (translate_and_unbox_list base_cons_list)) (Bindlib.box_list (translate_and_unbox_list base_cons_list))
in in
let exceptions = let exceptions =
Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions) Bindlib.box_list
(List.map
(rule_tree_to_expr ~toplevel:false ctx def_pos is_func)
exceptions)
in in
let default = let default =
Bindlib.box_apply2 Bindlib.box_apply2
@ -237,8 +291,8 @@ let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t)
| None, None -> default | None, None -> default
| Some new_param, Some (_, typ) -> | Some new_param, Some (_, typ) ->
if toplevel then if toplevel then
(* When we're creating a function from multiple defaults, we must check that the result (* When we're creating a function from multiple defaults, we must check
returned by the function is not empty *) that the result returned by the function is not empty *)
let default = let default =
Bindlib.box_apply Bindlib.box_apply
(fun (default : Scopelang.Ast.expr * Pos.t) -> (fun (default : Scopelang.Ast.expr * Pos.t) ->
@ -253,74 +307,98 @@ let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t)
(** {1 AST translation} *) (** {1 AST translation} *)
(** Translates a definition inside a scope, the resulting expression should be an {!constructor: (** Translates a definition inside a scope, the resulting expression should be
Dcalc.Ast.EDefault} *) an {!constructor: Dcalc.Ast.EDefault} *)
let translate_def (ctx : ctx) (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) let translate_def
(typ : Scopelang.Ast.typ Pos.marked) (io : Scopelang.Ast.io) ~(is_cond : bool) (ctx : ctx)
(def_info : Ast.ScopeDef.t)
(def : Ast.rule Ast.RuleMap.t)
(typ : Scopelang.Ast.typ Pos.marked)
(io : Scopelang.Ast.io)
~(is_cond : bool)
~(is_subscope_var : bool) : Scopelang.Ast.expr Pos.marked = ~(is_subscope_var : bool) : Scopelang.Ast.expr Pos.marked =
(* Here, we have to transform this list of rules into a default tree. *) (* Here, we have to transform this list of rules into a default tree. *)
let is_def_func = match Pos.unmark typ with Scopelang.Ast.TArrow (_, _) -> true | _ -> false in let is_def_func =
let is_rule_func _ (r : Ast.rule) : bool = Option.is_some r.Ast.rule_parameter in match Pos.unmark typ with Scopelang.Ast.TArrow (_, _) -> true | _ -> false
in
let is_rule_func _ (r : Ast.rule) : bool =
Option.is_some r.Ast.rule_parameter
in
let all_rules_func = Ast.RuleMap.for_all is_rule_func def in let all_rules_func = Ast.RuleMap.for_all is_rule_func def in
let all_rules_not_func = Ast.RuleMap.for_all (fun n r -> not (is_rule_func n r)) def in let all_rules_not_func =
Ast.RuleMap.for_all (fun n r -> not (is_rule_func n r)) def
in
let is_def_func_param_typ : Scopelang.Ast.typ Pos.marked option = let is_def_func_param_typ : Scopelang.Ast.typ Pos.marked option =
if is_def_func && all_rules_func then if is_def_func && all_rules_func then
match Pos.unmark typ with match Pos.unmark typ with
| Scopelang.Ast.TArrow (t_param, _) -> Some t_param | Scopelang.Ast.TArrow (t_param, _) -> Some t_param
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position typ) Errors.raise_spanned_error (Pos.get_position typ)
"The definitions of %a are function but its type, %a, is not a function type" "The definitions of %a are function but its type, %a, is not a \
function type"
Ast.ScopeDef.format_t def_info Scopelang.Print.format_typ typ Ast.ScopeDef.format_t def_info Scopelang.Print.format_typ typ
else if (not is_def_func) && all_rules_not_func then None else if (not is_def_func) && all_rules_not_func then None
else else
let spans = let spans =
List.map List.map
(fun (_, r) -> (fun (_, r) ->
(Some "This definition is a function:", Pos.get_position (Bindlib.unbox r.Ast.rule_cons))) ( Some "This definition is a function:",
Pos.get_position (Bindlib.unbox r.Ast.rule_cons) ))
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_rule_func def)) (Ast.RuleMap.bindings (Ast.RuleMap.filter is_rule_func def))
@ List.map @ List.map
(fun (_, r) -> (fun (_, r) ->
( Some "This definition is not a function:", ( Some "This definition is not a function:",
Pos.get_position (Bindlib.unbox r.Ast.rule_cons) )) Pos.get_position (Bindlib.unbox r.Ast.rule_cons) ))
(Ast.RuleMap.bindings (Ast.RuleMap.filter (fun n r -> not (is_rule_func n r)) def)) (Ast.RuleMap.bindings
(Ast.RuleMap.filter (fun n r -> not (is_rule_func n r)) def))
in in
Errors.raise_multispanned_error spans Errors.raise_multispanned_error spans
"some definitions of the same variable are functions while others aren't" "some definitions of the same variable are functions while others \
aren't"
in in
let top_list = def_map_to_tree def_info def in let top_list = def_map_to_tree def_info def in
let top_value = let top_value =
(if is_cond then Ast.always_false_rule else Ast.empty_rule) Pos.no_pos is_def_func_param_typ (if is_cond then Ast.always_false_rule else Ast.empty_rule)
Pos.no_pos is_def_func_param_typ
in in
if if
Ast.RuleMap.cardinal def = 0 Ast.RuleMap.cardinal def = 0
&& is_subscope_var && is_subscope_var
(* Here we have a special case for the empty definitions. Indeed, we could use the code for the (* Here we have a special case for the empty definitions. Indeed, we could
regular case below that would create a convoluted default always returning empty error, and use the code for the regular case below that would create a convoluted
this would be correct. But it gets more complicated with functions. Indeed, if we create an default always returning empty error, and this would be correct. But it
empty definition for a subscope argument whose type is a function, we get something like [fun gets more complicated with functions. Indeed, if we create an empty
() -> (fun real_param -> < ... >)] that is passed as an argument to the subscope. The definition for a subscope argument whose type is a function, we get
sub-scope de-thunks but the de-thunking does not return empty error, signalling there is not something like [fun () -> (fun real_param -> < ... >)] that is passed as
reentrant variable, because functions are values! So the subscope does not see that there is an argument to the subscope. The sub-scope de-thunks but the de-thunking
not reentrant variable and does not pick its internal definition instead. See does not return empty error, signalling there is not reentrant variable,
[test/test_scope/subscope_function_arg_not_defined.catala_en] for a test case exercising that because functions are values! So the subscope does not see that there is
subtlety. not reentrant variable and does not pick its internal definition instead.
See [test/test_scope/subscope_function_arg_not_defined.catala_en] for a
test case exercising that subtlety.
To avoid this complication we special case here and put an empty error for all subscope To avoid this complication we special case here and put an empty error
variables that are not defined. It covers the subtlety with functions described above but for all subscope variables that are not defined. It covers the subtlety
also conditions with the false default value. *) with functions described above but also conditions with the false default
value. *)
&& not && not
(is_cond (is_cond
&& match Pos.unmark io.Scopelang.Ast.io_input with OnlyInput -> true | _ -> false) &&
(* However, this special case suffers from an exception: when a condition is defined as an match Pos.unmark io.Scopelang.Ast.io_input with
OnlyInput to a subscope, since the [false] default value will not be provided by the calee | OnlyInput -> true
scope, it has to be placed in the caller. *) | _ -> false)
(* However, this special case suffers from an exception: when a condition is
defined as an OnlyInput to a subscope, since the [false] default value
will not be provided by the calee scope, it has to be placed in the
caller. *)
then (ELit LEmptyError, Pos.no_pos) then (ELit LEmptyError, Pos.no_pos)
else else
Bindlib.unbox Bindlib.unbox
(rule_tree_to_expr ~toplevel:true ctx (rule_tree_to_expr ~toplevel:true ctx
(Ast.ScopeDef.get_position def_info) (Ast.ScopeDef.get_position def_info)
(Option.map (Option.map
(fun _ -> Ast.Var.make ("param", Ast.ScopeDef.get_position def_info)) (fun _ ->
Ast.Var.make ("param", Ast.ScopeDef.get_position def_info))
is_def_func_param_typ) is_def_func_param_typ)
(match top_list with (match top_list with
| [] -> | [] ->
@ -332,7 +410,9 @@ let translate_def (ctx : ctx) (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.Ru
let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl = let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
let scope_dependencies = Dependency.build_scope_dependencies scope in let scope_dependencies = Dependency.build_scope_dependencies scope in
Dependency.check_for_cycle scope scope_dependencies; Dependency.check_for_cycle scope scope_dependencies;
let scope_ordering = Dependency.correct_computation_ordering scope_dependencies in let scope_ordering =
Dependency.correct_computation_ordering scope_dependencies
in
let scope_decl_rules = let scope_decl_rules =
List.flatten List.flatten
(List.map (List.map
@ -340,31 +420,42 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
match vertex with match vertex with
| Dependency.Vertex.Var (var, state) -> ( | Dependency.Vertex.Var (var, state) -> (
let scope_def = let scope_def =
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, state)) scope.scope_defs Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, state))
scope.scope_defs
in in
let var_def = scope_def.scope_def_rules in let var_def = scope_def.scope_def_rules in
let var_typ = scope_def.scope_def_typ in let var_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in let is_cond = scope_def.scope_def_is_condition in
match Pos.unmark scope_def.Ast.scope_def_io.io_input with match Pos.unmark scope_def.Ast.scope_def_io.io_input with
| OnlyInput when not (Ast.RuleMap.is_empty var_def) -> | OnlyInput when not (Ast.RuleMap.is_empty var_def) ->
(* If the variable is tagged as input, then it shall not be redefined. *) (* If the variable is tagged as input, then it shall not be
redefined. *)
Errors.raise_multispanned_error Errors.raise_multispanned_error
((Some "Incriminated variable:", Pos.get_position (Ast.ScopeVar.get_info var)) (( Some "Incriminated variable:",
Pos.get_position (Ast.ScopeVar.get_info var) )
:: List.map :: List.map
(fun (rule, _) -> (fun (rule, _) ->
( Some "Incriminated variable definition:", ( Some "Incriminated variable definition:",
Pos.get_position (Ast.RuleName.get_info rule) )) Pos.get_position (Ast.RuleName.get_info rule) ))
(Ast.RuleMap.bindings var_def)) (Ast.RuleMap.bindings var_def))
"It is impossible to give a definition to a scope variable tagged as input." "It is impossible to give a definition to a scope \
| OnlyInput -> [] (* we do not provide any definition for an input-only variable *) variable tagged as input."
| OnlyInput ->
[]
(* we do not provide any definition for an input-only
variable *)
| _ -> | _ ->
let expr_def = let expr_def =
translate_def ctx translate_def ctx
(Ast.ScopeDef.Var (var, state)) (Ast.ScopeDef.Var (var, state))
var_def var_typ scope_def.Ast.scope_def_io ~is_cond ~is_subscope_var:false var_def var_typ scope_def.Ast.scope_def_io ~is_cond
~is_subscope_var:false
in in
let scope_var = let scope_var =
match (Ast.ScopeVarMap.find var ctx.scope_var_mapping, state) with match
(Ast.ScopeVarMap.find var ctx.scope_var_mapping, state)
with
| WholeVar v, None -> v | WholeVar v, None -> v
| States states, Some state -> List.assoc state states | States states, Some state -> List.assoc state states
| _ -> failwith "should not happen" | _ -> failwith "should not happen"
@ -373,17 +464,20 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
Scopelang.Ast.Definition Scopelang.Ast.Definition
( ( Scopelang.Ast.ScopeVar ( ( Scopelang.Ast.ScopeVar
( scope_var, ( scope_var,
Pos.get_position (Scopelang.Ast.ScopeVar.get_info scope_var) ), Pos.get_position
Pos.get_position (Scopelang.Ast.ScopeVar.get_info scope_var) ), (Scopelang.Ast.ScopeVar.get_info scope_var) ),
Pos.get_position
(Scopelang.Ast.ScopeVar.get_info scope_var) ),
var_typ, var_typ,
scope_def.Ast.scope_def_io, scope_def.Ast.scope_def_io,
expr_def ); expr_def );
]) ])
| Dependency.Vertex.SubScope sub_scope_index -> | Dependency.Vertex.SubScope sub_scope_index ->
(* Before calling the sub_scope, we need to include all the re-definitions of (* Before calling the sub_scope, we need to include all the
subscope parameters*) re-definitions of subscope parameters*)
let sub_scope = let sub_scope =
Scopelang.Ast.SubScopeMap.find sub_scope_index scope.scope_sub_scopes Scopelang.Ast.SubScopeMap.find sub_scope_index
scope.scope_sub_scopes
in in
let sub_scope_vars_redefs_candidates = let sub_scope_vars_redefs_candidates =
Ast.ScopeDefMap.filter Ast.ScopeDefMap.filter
@ -392,13 +486,17 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
| Ast.ScopeDef.Var _ -> false | Ast.ScopeDef.Var _ -> false
| Ast.ScopeDef.SubScopeVar (sub_scope_index', _) -> | Ast.ScopeDef.SubScopeVar (sub_scope_index', _) ->
sub_scope_index = sub_scope_index' sub_scope_index = sub_scope_index'
(* We exclude subscope variables that have 0 re-definitions and are not (* We exclude subscope variables that have 0
visible in the input of the subscope *) re-definitions and are not visible in the input of
the subscope *)
&& not && not
((match Pos.unmark scope_def.Ast.scope_def_io.io_input with ((match
Pos.unmark scope_def.Ast.scope_def_io.io_input
with
| Scopelang.Ast.NoInput -> true | Scopelang.Ast.NoInput -> true
| _ -> false) | _ -> false)
&& Ast.RuleMap.is_empty scope_def.scope_def_rules)) && Ast.RuleMap.is_empty scope_def.scope_def_rules
))
scope.scope_defs scope.scope_defs
in in
let sub_scope_vars_redefs = let sub_scope_vars_redefs =
@ -408,57 +506,78 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
let def_typ = scope_def.scope_def_typ in let def_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in let is_cond = scope_def.scope_def_is_condition in
match def_key with match def_key with
| Ast.ScopeDef.Var _ -> assert false (* should not happen *) | Ast.ScopeDef.Var _ ->
assert false (* should not happen *)
| Ast.ScopeDef.SubScopeVar (_, sub_scope_var) -> | Ast.ScopeDef.SubScopeVar (_, sub_scope_var) ->
(* This definition redefines a variable of the correct subscope. But we (* This definition redefines a variable of the correct
have to check that this redefinition is allowed with respect to the io subscope. But we have to check that this
redefinition is allowed with respect to the io
parameters of that subscope variable. *) parameters of that subscope variable. *)
(match Pos.unmark scope_def.Ast.scope_def_io.io_input with (match
Pos.unmark scope_def.Ast.scope_def_io.io_input
with
| Scopelang.Ast.NoInput -> | Scopelang.Ast.NoInput ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
((Some "Incriminated subscope:", Ast.ScopeDef.get_position def_key) (( Some "Incriminated subscope:",
Ast.ScopeDef.get_position def_key )
:: ( Some "Incriminated variable:", :: ( Some "Incriminated variable:",
Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) ) Pos.get_position
(Ast.ScopeVar.get_info sub_scope_var) )
:: List.map :: List.map
(fun (rule, _) -> (fun (rule, _) ->
( Some "Incriminated subscope variable definition:", ( Some
Pos.get_position (Ast.RuleName.get_info rule) )) "Incriminated subscope variable \
definition:",
Pos.get_position
(Ast.RuleName.get_info rule) ))
(Ast.RuleMap.bindings def)) (Ast.RuleMap.bindings def))
"It is impossible to give a definition to a subscope variable not \ "It is impossible to give a definition to a \
tagged as input or context." subscope variable not tagged as input or \
| OnlyInput when Ast.RuleMap.is_empty def && not is_cond -> context."
(* If the subscope variable is tagged as input, then it shall be | OnlyInput
defined. *) when Ast.RuleMap.is_empty def && not is_cond ->
(* If the subscope variable is tagged as input,
then it shall be defined. *)
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
(Some "Incriminated subscope:", Ast.ScopeDef.get_position def_key); ( Some "Incriminated subscope:",
Ast.ScopeDef.get_position def_key );
( Some "Incriminated variable:", ( Some "Incriminated variable:",
Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) ); Pos.get_position
(Ast.ScopeVar.get_info sub_scope_var) );
] ]
"This subscope variable is a mandatory input but no definition was \ "This subscope variable is a mandatory input \
provided." but no definition was provided."
| _ -> ()); | _ -> ());
(* Now that all is good, we can proceed with translating this redefinition (* Now that all is good, we can proceed with
to a proper Scopelang term. *) translating this redefinition to a proper Scopelang
term. *)
let expr_def = let expr_def =
translate_def ctx def_key def def_typ scope_def.Ast.scope_def_io ~is_cond translate_def ctx def_key def def_typ
scope_def.Ast.scope_def_io ~is_cond
~is_subscope_var:true ~is_subscope_var:true
in in
let subscop_real_name = let subscop_real_name =
Scopelang.Ast.SubScopeMap.find sub_scope_index scope.scope_sub_scopes Scopelang.Ast.SubScopeMap.find sub_scope_index
scope.scope_sub_scopes
in
let var_pos =
Pos.get_position
(Ast.ScopeVar.get_info sub_scope_var)
in in
let var_pos = Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) in
Scopelang.Ast.Definition Scopelang.Ast.Definition
( ( Scopelang.Ast.SubScopeVar ( ( Scopelang.Ast.SubScopeVar
( subscop_real_name, ( subscop_real_name,
(sub_scope_index, var_pos), (sub_scope_index, var_pos),
match match
Ast.ScopeVarMap.find sub_scope_var ctx.scope_var_mapping Ast.ScopeVarMap.find sub_scope_var
ctx.scope_var_mapping
with with
| WholeVar v -> (v, var_pos) | WholeVar v -> (v, var_pos)
| States states -> | States states ->
(* When defining a sub-scope variable, we always define its (* When defining a sub-scope variable, we
first state in the sub-scope. *) always define its first state in the
sub-scope. *)
(snd (List.hd states), var_pos) ), (snd (List.hd states), var_pos) ),
var_pos ), var_pos ),
def_typ, def_typ,
@ -469,17 +588,22 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
let sub_scope_vars_redefs = let sub_scope_vars_redefs =
List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs) List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs)
in in
sub_scope_vars_redefs @ [ Scopelang.Ast.Call (sub_scope, sub_scope_index) ]) sub_scope_vars_redefs
@ [ Scopelang.Ast.Call (sub_scope, sub_scope_index) ])
scope_ordering) scope_ordering)
in in
(* Then, after having computed all the scopes variables, we add the assertions. TODO: the (* Then, after having computed all the scopes variables, we add the
assertions should be interleaved with the definitions! *) assertions. TODO: the assertions should be interleaved with the
definitions! *)
let scope_decl_rules = let scope_decl_rules =
scope_decl_rules scope_decl_rules
@ List.map @ List.map
(fun e -> (fun e ->
let scope_e = translate_expr ctx e in let scope_e = translate_expr ctx e in
Bindlib.unbox (Bindlib.box_apply (fun scope_e -> Scopelang.Ast.Assertion scope_e) scope_e)) Bindlib.unbox
(Bindlib.box_apply
(fun scope_e -> Scopelang.Ast.Assertion scope_e)
scope_e))
(Bindlib.unbox (Bindlib.box_list scope.Ast.scope_assertions)) (Bindlib.unbox (Bindlib.box_list scope.Ast.scope_assertions))
in in
let scope_sig = let scope_sig =
@ -487,20 +611,28 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
(fun var (states : Ast.var_or_states) acc -> (fun var (states : Ast.var_or_states) acc ->
match states with match states with
| WholeVar -> | WholeVar ->
let scope_def = Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs in let scope_def =
Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, None))
scope.scope_defs
in
let typ = scope_def.scope_def_typ in let typ = scope_def.scope_def_typ in
Scopelang.Ast.ScopeVarMap.add Scopelang.Ast.ScopeVarMap.add
(match Ast.ScopeVarMap.find var ctx.scope_var_mapping with (match Ast.ScopeVarMap.find var ctx.scope_var_mapping with
| WholeVar v -> v | WholeVar v -> v
| States _ -> failwith "should not happen") | States _ -> failwith "should not happen")
(typ, scope_def.scope_def_io) acc (typ, scope_def.scope_def_io)
acc
| States states -> | States states ->
(* What happens in the case of variables with multiple states is interesting. We need to (* What happens in the case of variables with multiple states is
create as many Scopelang.Var entries in the scope signature as there are states. *) interesting. We need to create as many Scopelang.Var entries in
the scope signature as there are states. *)
List.fold_left List.fold_left
(fun acc (state : Ast.StateName.t) -> (fun acc (state : Ast.StateName.t) ->
let scope_def = let scope_def =
Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, Some state)) scope.scope_defs Ast.ScopeDefMap.find
(Ast.ScopeDef.Var (var, Some state))
scope.scope_defs
in in
Scopelang.Ast.ScopeVarMap.add Scopelang.Ast.ScopeVarMap.add
(match Ast.ScopeVarMap.find var ctx.scope_var_mapping with (match Ast.ScopeVarMap.find var ctx.scope_var_mapping with
@ -520,8 +652,9 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl =
(** {1 API} *) (** {1 API} *)
let translate_program (pgrm : Ast.program) : Scopelang.Ast.program = let translate_program (pgrm : Ast.program) : Scopelang.Ast.program =
(* First we give mappings to all the locations between Desugared and Scopelang. This involves (* First we give mappings to all the locations between Desugared and
creating a new Scopelang scope variable for every state of a Desugared variable. *) Scopelang. This involves creating a new Scopelang scope variable for every
state of a Desugared variable. *)
let ctx = let ctx =
Scopelang.Ast.ScopeMap.fold Scopelang.Ast.ScopeMap.fold
(fun _scope scope_decl ctx -> (fun _scope scope_decl ctx ->
@ -533,7 +666,9 @@ let translate_program (pgrm : Ast.program) : Scopelang.Ast.program =
ctx with ctx with
scope_var_mapping = scope_var_mapping =
Ast.ScopeVarMap.add scope_var Ast.ScopeVarMap.add scope_var
(WholeVar (Scopelang.Ast.ScopeVar.fresh (Ast.ScopeVar.get_info scope_var))) (WholeVar
(Scopelang.Ast.ScopeVar.fresh
(Ast.ScopeVar.get_info scope_var)))
ctx.scope_var_mapping; ctx.scope_var_mapping;
} }
| States states -> | States states ->
@ -546,15 +681,21 @@ let translate_program (pgrm : Ast.program) : Scopelang.Ast.program =
(fun state -> (fun state ->
( state, ( state,
Scopelang.Ast.ScopeVar.fresh Scopelang.Ast.ScopeVar.fresh
(let state_name, state_pos = Ast.StateName.get_info state in (let state_name, state_pos =
( Pos.unmark (Ast.ScopeVar.get_info scope_var) ^ "_" ^ state_name, Ast.StateName.get_info state
in
( Pos.unmark (Ast.ScopeVar.get_info scope_var)
^ "_" ^ state_name,
state_pos )) )) state_pos )) ))
states)) states))
ctx.scope_var_mapping; ctx.scope_var_mapping;
}) })
scope_decl.Ast.scope_vars ctx) scope_decl.Ast.scope_vars ctx)
pgrm.Ast.program_scopes pgrm.Ast.program_scopes
{ scope_var_mapping = Ast.ScopeVarMap.empty; var_mapping = Ast.VarMap.empty } {
scope_var_mapping = Ast.ScopeVarMap.empty;
var_mapping = Ast.VarMap.empty;
}
in in
{ {
Scopelang.Ast.program_scopes = Scopelang.Ast.program_scopes =

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *)

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module Cli = Utils.Cli module Cli = Utils.Cli
@ -19,16 +22,27 @@ module Pos = Utils.Pos
(** Associates a {!type: Cli.backend_lang} with its string represtation. *) (** Associates a {!type: Cli.backend_lang} with its string represtation. *)
let languages = [ ("en", Cli.En); ("fr", Cli.Fr); ("pl", Cli.Pl) ] let languages = [ ("en", Cli.En); ("fr", Cli.Fr); ("pl", Cli.Pl) ]
(** Associates a file extension with its corresponding {!type: Cli.backend_lang} string (** Associates a file extension with its corresponding {!type: Cli.backend_lang}
representation. *) string representation. *)
let extensions = [ (".catala_fr", "fr"); (".catala_en", "en"); (".catala_pl", "pl") ] let extensions =
[ (".catala_fr", "fr"); (".catala_en", "en"); (".catala_pl", "pl") ]
(** Entry function for the executable. Returns a negative number in case of error. Usage: (** Entry function for the executable. Returns a negative number in case of
error. Usage:
[driver source_file debug dcalc unstyled wrap_weaved_output backend language max_prec_digits trace optimize scope_to_execute output_file]*) [driver source_file debug dcalc unstyled wrap_weaved_output backend language max_prec_digits trace optimize scope_to_execute output_file]*)
let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool) let driver
(wrap_weaved_output : bool) (avoid_exceptions : bool) (backend : string) (source_file : Pos.input_file)
(language : string option) (max_prec_digits : int option) (trace : bool) (debug : bool)
(disable_counterexamples : bool) (optimize : bool) (ex_scope : string option) (unstyled : bool)
(wrap_weaved_output : bool)
(avoid_exceptions : bool)
(backend : string)
(language : string option)
(max_prec_digits : int option)
(trace : bool)
(disable_counterexamples : bool)
(optimize : bool)
(ex_scope : string option)
(output_file : string option) : int = (output_file : string option) : int =
try try
Cli.debug_flag := debug; Cli.debug_flag := debug;
@ -39,8 +53,12 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
Cli.avoid_exceptions_flag := avoid_exceptions; Cli.avoid_exceptions_flag := avoid_exceptions;
Cli.debug_print "Reading files..."; Cli.debug_print "Reading files...";
let filename = ref "" in let filename = ref "" in
(match source_file with FileName f -> filename := f | Contents c -> Cli.contents := c); (match source_file with
(match max_prec_digits with None -> () | Some i -> Cli.max_prec_digits := i); | FileName f -> filename := f
| Contents c -> Cli.contents := c);
(match max_prec_digits with
| None -> ()
| Some i -> Cli.max_prec_digits := i);
let l = let l =
match language with match language with
| Some l -> l | Some l -> l
@ -49,15 +67,16 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
let ext = Filename.extension !filename in let ext = Filename.extension !filename in
if ext = "" then if ext = "" then
Errors.raise_error Errors.raise_error
"No file extension found for the file '%s'. (Try to add one or to specify the -l \ "No file extension found for the file '%s'. (Try to add one or \
flag)" to specify the -l flag)"
!filename; !filename;
try List.assoc ext extensions with Not_found -> ext) try List.assoc ext extensions with Not_found -> ext)
in in
let language = let language =
try List.assoc l languages try List.assoc l languages
with Not_found -> with Not_found ->
Errors.raise_error "The selected language (%s) is not supported by Catala" l Errors.raise_error
"The selected language (%s) is not supported by Catala" l
in in
Cli.locale_lang := language; Cli.locale_lang := language;
let backend = let backend =
@ -74,9 +93,13 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
else if backend = "typecheck" then Cli.Typecheck else if backend = "typecheck" then Cli.Typecheck
else if backend = "lcalc" then Cli.Lcalc else if backend = "lcalc" then Cli.Lcalc
else if backend = "scalc" then Cli.Scalc else if backend = "scalc" then Cli.Scalc
else Errors.raise_error "The selected backend (%s) is not supported by Catala" backend else
Errors.raise_error
"The selected backend (%s) is not supported by Catala" backend
in
let prgm =
Surface.Parser_driver.parse_top_level_file source_file language
in in
let prgm = Surface.Parser_driver.parse_top_level_file source_file language in
let prgm = Surface.Fill_positions.fill_pos_with_legislative_info prgm in let prgm = Surface.Fill_positions.fill_pos_with_legislative_info prgm in
match backend with match backend with
| Cli.Makefile -> | Cli.Makefile ->
@ -85,7 +108,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
match source_file with match source_file with
| FileName f -> f | FileName f -> f
| Contents _ -> | Contents _ ->
Errors.raise_error "The Makefile backend does not work if the input is not a file" Errors.raise_error
"The Makefile backend does not work if the input is not a file"
in in
let output_file = let output_file =
match output_file with match output_file with
@ -109,7 +133,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
| FileName f -> f | FileName f -> f
| Contents _ -> | Contents _ ->
Errors.raise_error Errors.raise_error
"The literate programming backends do not work if the input is not a file" "The literate programming backends do not work if the input is \
not a file"
in in
Cli.debug_print "Weaving literate program into %s" Cli.debug_print "Weaving literate program into %s"
(match backend with (match backend with
@ -122,7 +147,10 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
| None -> ( | None -> (
Filename.remove_extension source_file Filename.remove_extension source_file
^ ^
match backend with Cli.Latex -> ".tex" | Cli.Html -> ".html" | _ -> assert false match backend with
| Cli.Latex -> ".tex"
| Cli.Html -> ".html"
| _ -> assert false
(* should not happen *)) (* should not happen *))
in in
let oc = open_out output_file in let oc = open_out output_file in
@ -138,11 +166,11 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
if wrap_weaved_output then if wrap_weaved_output then
match backend with match backend with
| Cli.Latex -> | Cli.Latex ->
Literate.Latex.wrap_latex prgm.Surface.Ast.program_source_files language fmt Literate.Latex.wrap_latex prgm.Surface.Ast.program_source_files
(fun fmt -> weave_output fmt prgm) language fmt (fun fmt -> weave_output fmt prgm)
| Cli.Html -> | Cli.Html ->
Literate.Html.wrap_html prgm.Surface.Ast.program_source_files language fmt (fun fmt -> Literate.Html.wrap_html prgm.Surface.Ast.program_source_files
weave_output fmt prgm) language fmt (fun fmt -> weave_output fmt prgm)
| _ -> assert false (* should not happen *) | _ -> assert false (* should not happen *)
else weave_output fmt prgm; else weave_output fmt prgm;
close_out oc; close_out oc;
@ -152,14 +180,19 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
let ctxt = Surface.Name_resolution.form_context prgm in let ctxt = Surface.Name_resolution.form_context prgm in
let scope_uid = let scope_uid =
match (ex_scope, backend) with match (ex_scope, backend) with
| None, Cli.Interpret -> Errors.raise_error "No scope was provided for execution." | None, Cli.Interpret ->
Errors.raise_error "No scope was provided for execution."
| None, _ -> | None, _ ->
snd snd
(try Desugared.Ast.IdentMap.choose ctxt.scope_idmap (try Desugared.Ast.IdentMap.choose ctxt.scope_idmap
with Not_found -> Errors.raise_error "There isn't any scope inside the program.") with Not_found ->
Errors.raise_error
"There isn't any scope inside the program.")
| Some name, _ -> ( | Some name, _ -> (
match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with
| None -> Errors.raise_error "There is no scope \"%s\" inside the program." name | None ->
Errors.raise_error
"There is no scope \"%s\" inside the program." name
| Some uid -> uid) | Some uid -> uid)
in in
Cli.debug_print "Desugaring..."; Cli.debug_print "Desugaring...";
@ -176,13 +209,16 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
in in
if Option.is_some ex_scope then if Option.is_some ex_scope then
Format.fprintf fmt "%a\n" Scopelang.Print.format_scope Format.fprintf fmt "%a\n" Scopelang.Print.format_scope
(scope_uid, Scopelang.Ast.ScopeMap.find scope_uid prgm.program_scopes) ( scope_uid,
Scopelang.Ast.ScopeMap.find scope_uid prgm.program_scopes )
else Format.fprintf fmt "%a\n" Scopelang.Print.format_program prgm; else Format.fprintf fmt "%a\n" Scopelang.Print.format_program prgm;
at_end (); at_end ();
exit 0 exit 0
end; end;
Cli.debug_print "Translating to default calculus..."; Cli.debug_print "Translating to default calculus...";
let prgm, type_ordering = Scopelang.Scope_to_dcalc.translate_program prgm in let prgm, type_ordering =
Scopelang.Scope_to_dcalc.translate_program prgm
in
let prgm = let prgm =
if optimize then begin if optimize then begin
Cli.debug_print "Optimizing default calculus..."; Cli.debug_print "Optimizing default calculus...";
@ -190,7 +226,9 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
end end
else prgm else prgm
in in
let prgrm_dcalc_expr = Bindlib.unbox (Dcalc.Ast.build_whole_program_expr prgm scope_uid) in let prgrm_dcalc_expr =
Bindlib.unbox (Dcalc.Ast.build_whole_program_expr prgm scope_uid)
in
if backend = Cli.Dcalc then begin if backend = Cli.Dcalc then begin
let fmt, at_end = let fmt, at_end =
match output_file with match output_file with
@ -202,38 +240,51 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
if Option.is_some ex_scope then if Option.is_some ex_scope then
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Dcalc.Print.format_scope ~debug prgm.decl_ctx) (Dcalc.Print.format_scope ~debug prgm.decl_ctx)
(let _, _, s = List.find (fun (name, _, _) -> name = scope_uid) prgm.scopes in (let _, _, s =
List.find (fun (name, _, _) -> name = scope_uid) prgm.scopes
in
(scope_uid, s)) (scope_uid, s))
else Format.fprintf fmt "%a\n" (Dcalc.Print.format_expr prgm.decl_ctx) prgrm_dcalc_expr; else
Format.fprintf fmt "%a\n"
(Dcalc.Print.format_expr prgm.decl_ctx)
prgrm_dcalc_expr;
at_end (); at_end ();
exit 0 exit 0
end; end;
Cli.debug_print "Typechecking..."; Cli.debug_print "Typechecking...";
let _typ = Dcalc.Typing.infer_type prgm.decl_ctx prgrm_dcalc_expr in let _typ = Dcalc.Typing.infer_type prgm.decl_ctx prgrm_dcalc_expr in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a" (Dcalc.Print.format_typ (* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
prgm.decl_ctx) typ); *) (Dcalc.Print.format_typ prgm.decl_ctx) typ); *)
match backend with match backend with
| Cli.Typecheck -> | Cli.Typecheck ->
(* That's it! *) (* That's it! *)
Cli.result_print "Typechecking successful!"; Cli.result_print "Typechecking successful!";
0 0
| Cli.Proof -> | Cli.Proof ->
let vcs = Verification.Conditions.generate_verification_conditions prgm in let vcs =
Verification.Conditions.generate_verification_conditions prgm
in
Verification.Solver.solve_vc prgm prgm.decl_ctx vcs; Verification.Solver.solve_vc prgm prgm.decl_ctx vcs;
0 0
| Cli.Interpret -> | Cli.Interpret ->
Cli.debug_print "Starting interpretation..."; Cli.debug_print "Starting interpretation...";
let results = Dcalc.Interpreter.interpret_program prgm.decl_ctx prgrm_dcalc_expr in let results =
Dcalc.Interpreter.interpret_program prgm.decl_ctx prgrm_dcalc_expr
in
let out_regex = Re.Pcre.regexp "\\_out$" in let out_regex = Re.Pcre.regexp "\\_out$" in
let results = let results =
List.map List.map
(fun ((v1, v1_pos), e1) -> (fun ((v1, v1_pos), e1) ->
let v1 = Re.Pcre.substitute ~rex:out_regex ~subst:(fun _ -> "") v1 in let v1 =
Re.Pcre.substitute ~rex:out_regex ~subst:(fun _ -> "") v1
in
((v1, v1_pos), e1)) ((v1, v1_pos), e1))
results results
in in
let results = let results =
List.sort (fun ((v1, _), _) ((v2, _), _) -> String.compare v1 v2) results List.sort
(fun ((v1, _), _) ((v2, _), _) -> String.compare v1 v2)
results
in in
Cli.debug_print "End of interpretation"; Cli.debug_print "End of interpretation";
Cli.result_print "Computation successful!%s" Cli.result_print "Computation successful!%s"
@ -248,7 +299,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
| Cli.OCaml | Cli.Python | Cli.Lcalc | Cli.Scalc -> | Cli.OCaml | Cli.Python | Cli.Lcalc | Cli.Scalc ->
Cli.debug_print "Compiling program into lambda calculus..."; Cli.debug_print "Compiling program into lambda calculus...";
let prgm = let prgm =
if avoid_exceptions then Lcalc.Compile_without_exceptions.translate_program prgm if avoid_exceptions then
Lcalc.Compile_without_exceptions.translate_program prgm
else Lcalc.Compile_with_exceptions.translate_program prgm else Lcalc.Compile_with_exceptions.translate_program prgm
in in
let prgm = let prgm =
@ -270,14 +322,17 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Lcalc.Print.format_scope ~debug prgm.decl_ctx) (Lcalc.Print.format_scope ~debug prgm.decl_ctx)
(let body = (let body =
List.find (fun body -> body.Lcalc.Ast.scope_body_name = scope_uid) prgm.scopes List.find
(fun body -> body.Lcalc.Ast.scope_body_name = scope_uid)
prgm.scopes
in in
body) body)
else else
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(fun fmt scope -> (Lcalc.Print.format_scope prgm.decl_ctx) fmt scope)) (fun fmt scope ->
(Lcalc.Print.format_scope prgm.decl_ctx) fmt scope))
prgm.scopes; prgm.scopes;
at_end (); at_end ();
exit 0 exit 0
@ -286,7 +341,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
match source_file with match source_file with
| FileName f -> f | FileName f -> f
| Contents _ -> | Contents _ ->
Errors.raise_error "This backend does not work if the input is not a file" Errors.raise_error
"This backend does not work if the input is not a file"
in in
let new_output_file (extension : string) : string = let new_output_file (extension : string) : string =
match output_file with match output_file with
@ -309,7 +365,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
match output_file with match output_file with
| Some f -> | Some f ->
let oc = open_out f in let oc = open_out f in
(Format.formatter_of_out_channel oc, fun _ -> close_out oc) ( Format.formatter_of_out_channel oc,
fun _ -> close_out oc )
| None -> (Format.std_formatter, fun _ -> ()) | None -> (Format.std_formatter, fun _ -> ())
in in
if Option.is_some ex_scope then if Option.is_some ex_scope then
@ -317,7 +374,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
(Scalc.Print.format_scope ~debug prgm.decl_ctx) (Scalc.Print.format_scope ~debug prgm.decl_ctx)
(let body = (let body =
List.find List.find
(fun body -> body.Scalc.Ast.scope_body_name = scope_uid) (fun body ->
body.Scalc.Ast.scope_body_name = scope_uid)
prgm.scopes prgm.scopes
in in
body) body)
@ -325,7 +383,8 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(fun fmt scope -> (Scalc.Print.format_scope prgm.decl_ctx) fmt scope)) (fun fmt scope ->
(Scalc.Print.format_scope prgm.decl_ctx) fmt scope))
prgm.scopes; prgm.scopes;
at_end (); at_end ();
exit 0 exit 0
@ -350,7 +409,9 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool)
-1 -1
let main () = let main () =
let return_code = Cmdliner.Term.eval (Cli.catala_t (fun f -> driver (FileName f)), Cli.info) in let return_code =
Cmdliner.Term.eval (Cli.catala_t (fun f -> driver (FileName f)), Cli.info)
in
match return_code with match return_code with
| `Ok 0 -> Cmdliner.Term.exit (`Ok 0) | `Ok 0 -> Cmdliner.Term.exit (`Ok 0)
| _ -> Cmdliner.Term.exit (`Error `Term) | _ -> Cmdliner.Term.exit (`Error `Term)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -30,7 +32,8 @@ type expr =
| EVar of expr Bindlib.var Pos.marked | EVar of expr Bindlib.var Pos.marked
| ETuple of expr Pos.marked list * D.StructName.t option | ETuple of expr Pos.marked list * D.StructName.t option
(** The [MarkedString.info] is the former struct field name*) (** The [MarkedString.info] is the former struct field name*)
| ETupleAccess of expr Pos.marked * int * D.StructName.t option * D.typ Pos.marked list | ETupleAccess of
expr Pos.marked * int * D.StructName.t option * D.typ Pos.marked list
(** The [MarkedString.info] is the former struct field name *) (** The [MarkedString.info] is the former struct field name *)
| EInj of expr Pos.marked * int * D.EnumName.t * D.typ Pos.marked list | EInj of expr Pos.marked * int * D.EnumName.t * D.typ Pos.marked list
(** The [MarkedString.info] is the former enum case name *) (** The [MarkedString.info] is the former enum case name *)
@ -38,7 +41,8 @@ type expr =
(** The [MarkedString.info] is the former enum case name *) (** The [MarkedString.info] is the former enum case name *)
| EArray of expr Pos.marked list | EArray of expr Pos.marked list
| ELit of lit | ELit of lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * D.typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked * D.typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EAssert of expr Pos.marked | EAssert of expr Pos.marked
| EOp of D.operator | EOp of D.operator
@ -64,37 +68,55 @@ type vars = expr Bindlib.mvar
let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box =
Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x) Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x)
let make_abs (xs : vars) (e : expr Pos.marked Bindlib.box) (pos_binder : Pos.t) let make_abs
(taus : D.typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = (xs : vars)
Bindlib.box_apply (fun b -> (EAbs ((b, pos_binder), taus), pos)) (Bindlib.bind_mvar xs e) (e : expr Pos.marked Bindlib.box)
(pos_binder : Pos.t)
(taus : D.typ Pos.marked list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply
(fun b -> (EAbs ((b, pos_binder), taus), pos))
(Bindlib.bind_mvar xs e)
let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) let make_app
: expr Pos.marked Bindlib.box = (e : expr Pos.marked Bindlib.box)
(u : expr Pos.marked Bindlib.box list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u)
let make_let_in (x : Var.t) (tau : D.typ Pos.marked) (e1 : expr Pos.marked Bindlib.box) let make_let_in
(x : Var.t)
(tau : D.typ Pos.marked)
(e1 : expr Pos.marked Bindlib.box)
(e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = (e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
let pos = Pos.get_position (Bindlib.unbox e2) in let pos = Pos.get_position (Bindlib.unbox e2) in
make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos
let ( let+ ) x f = Bindlib.box_apply f x let ( let+ ) x f = Bindlib.box_apply f x
let ( and+ ) x y = Bindlib.box_pair x y let ( and+ ) x y = Bindlib.box_pair x y
let option_enum : D.EnumName.t = D.EnumName.fresh ("eoption", Pos.no_pos) let option_enum : D.EnumName.t = D.EnumName.fresh ("eoption", Pos.no_pos)
let none_constr : D.EnumConstructor.t = D.EnumConstructor.fresh ("ENone", Pos.no_pos) let none_constr : D.EnumConstructor.t =
D.EnumConstructor.fresh ("ENone", Pos.no_pos)
let some_constr : D.EnumConstructor.t = D.EnumConstructor.fresh ("ESome", Pos.no_pos) let some_constr : D.EnumConstructor.t =
D.EnumConstructor.fresh ("ESome", Pos.no_pos)
let option_enum_config : (D.EnumConstructor.t * D.typ Pos.marked) list = let option_enum_config : (D.EnumConstructor.t * D.typ Pos.marked) list =
[ (none_constr, (D.TLit D.TUnit, Pos.no_pos)); (some_constr, (D.TAny, Pos.no_pos)) ] [
(none_constr, (D.TLit D.TUnit, Pos.no_pos));
(some_constr, (D.TAny, Pos.no_pos));
]
let make_none (pos : Pos.t) : expr Pos.marked Bindlib.box = let make_none (pos : Pos.t) : expr Pos.marked Bindlib.box =
let mark : 'a -> 'a Pos.marked = Pos.mark pos in let mark : 'a -> 'a Pos.marked = Pos.mark pos in
Bindlib.box @@ mark Bindlib.box @@ mark
@@ EInj (mark @@ ELit LUnit, 0, option_enum, [ (D.TLit D.TUnit, pos); (D.TAny, pos) ]) @@ EInj
( mark @@ ELit LUnit,
0,
option_enum,
[ (D.TLit D.TUnit, pos); (D.TAny, pos) ] )
let make_some (e : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = let make_some (e : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
let pos = Pos.get_position @@ Bindlib.unbox e in let pos = Pos.get_position @@ Bindlib.unbox e in
@ -103,11 +125,12 @@ let make_some (e : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
mark @@ EInj (e, 1, option_enum, [ (D.TLit D.TUnit, pos); (D.TAny, pos) ]) mark @@ EInj (e, 1, option_enum, [ (D.TLit D.TUnit, pos); (D.TAny, pos) ])
(** [make_matchopt_with_abs_arms arg e_none e_some] build an expression (** [make_matchopt_with_abs_arms arg e_none e_some] build an expression
[match arg with |None -> e_none | Some -> e_some] and requires e_some and e_none to be in the [match arg with |None -> e_none | Some -> e_some] and requires e_some and
form [EAbs ...].*) e_none to be in the form [EAbs ...].*)
let make_matchopt_with_abs_arms (arg : expr Pos.marked Bindlib.box) let make_matchopt_with_abs_arms
(e_none : expr Pos.marked Bindlib.box) (e_some : expr Pos.marked Bindlib.box) : (arg : expr Pos.marked Bindlib.box)
expr Pos.marked Bindlib.box = (e_none : expr Pos.marked Bindlib.box)
(e_some : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
let pos = Pos.get_position @@ Bindlib.unbox arg in let pos = Pos.get_position @@ Bindlib.unbox arg in
let mark : 'a -> 'a Pos.marked = Pos.mark pos in let mark : 'a -> 'a Pos.marked = Pos.mark pos in
@ -116,10 +139,15 @@ let make_matchopt_with_abs_arms (arg : expr Pos.marked Bindlib.box)
mark @@ EMatch (arg, [ e_none; e_some ], option_enum) mark @@ EMatch (arg, [ e_none; e_some ], option_enum)
(** [make_matchopt pos v tau arg e_none e_some] builds an expression (** [make_matchopt pos v tau arg e_none e_some] builds an expression
[match arg with | None () -> e_none | Some v -> e_some]. It binds v to e_some, permitting it to [match arg with | None () -> e_none | Some v -> e_some]. It binds v to
be used inside the expression. There is no requirements on the form of both e_some and e_none. *) e_some, permitting it to be used inside the expression. There is no
let make_matchopt (pos : Pos.t) (v : Var.t) (tau : D.typ Pos.marked) requirements on the form of both e_some and e_none. *)
(arg : expr Pos.marked Bindlib.box) (e_none : expr Pos.marked Bindlib.box) let make_matchopt
(pos : Pos.t)
(v : Var.t)
(tau : D.typ Pos.marked)
(arg : expr Pos.marked Bindlib.box)
(e_none : expr Pos.marked Bindlib.box)
(e_some : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = (e_some : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
let x = Var.make ("_", pos) in let x = Var.make ("_", pos) in
@ -128,7 +156,6 @@ let make_matchopt (pos : Pos.t) (v : Var.t) (tau : D.typ Pos.marked)
(make_abs (Array.of_list [ v ]) e_some pos [ tau ] pos) (make_abs (Array.of_list [ v ]) e_some pos [ tau ] pos)
let handle_default = Var.make ("handle_default", Pos.no_pos) let handle_default = Var.make ("handle_default", Pos.no_pos)
let handle_default_opt = Var.make ("handle_default_opt", Pos.no_pos) let handle_default_opt = Var.make ("handle_default_opt", Pos.no_pos)
type binder = (expr, expr Pos.marked) Bindlib.binder type binder = (expr, expr Pos.marked) Bindlib.binder

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -18,8 +20,8 @@ open Utils
(** {1 Abstract syntax tree} *) (** {1 Abstract syntax tree} *)
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib} library, based on (** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
higher-order abstract syntax*) library, based on higher-order abstract syntax*)
type lit = type lit =
| LBool of bool | LBool of bool
@ -37,15 +39,24 @@ type expr =
| ETuple of expr Pos.marked list * Dcalc.Ast.StructName.t option | ETuple of expr Pos.marked list * Dcalc.Ast.StructName.t option
(** The [MarkedString.info] is the former struct field name*) (** The [MarkedString.info] is the former struct field name*)
| ETupleAccess of | ETupleAccess of
expr Pos.marked * int * Dcalc.Ast.StructName.t option * Dcalc.Ast.typ Pos.marked list expr Pos.marked
* int
* Dcalc.Ast.StructName.t option
* Dcalc.Ast.typ Pos.marked list
(** The [MarkedString.info] is the former struct field name *) (** The [MarkedString.info] is the former struct field name *)
| EInj of expr Pos.marked * int * Dcalc.Ast.EnumName.t * Dcalc.Ast.typ Pos.marked list | EInj of
expr Pos.marked
* int
* Dcalc.Ast.EnumName.t
* Dcalc.Ast.typ Pos.marked list
(** The [MarkedString.info] is the former enum case name *) (** The [MarkedString.info] is the former enum case name *)
| EMatch of expr Pos.marked * expr Pos.marked list * Dcalc.Ast.EnumName.t | EMatch of expr Pos.marked * expr Pos.marked list * Dcalc.Ast.EnumName.t
(** The [MarkedString.info] is the former enum case name *) (** The [MarkedString.info] is the former enum case name *)
| EArray of expr Pos.marked list | EArray of expr Pos.marked list
| ELit of lit | ELit of lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * Dcalc.Ast.typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked
* Dcalc.Ast.typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EAssert of expr Pos.marked | EAssert of expr Pos.marked
| EOp of Dcalc.Ast.operator | EOp of Dcalc.Ast.operator
@ -59,7 +70,6 @@ module Var : sig
type t = expr Bindlib.var type t = expr Bindlib.var
val make : string Pos.marked -> t val make : string Pos.marked -> t
val compare : t -> t -> int val compare : t -> t -> int
end end
@ -91,15 +101,13 @@ val make_let_in :
expr Pos.marked Bindlib.box expr Pos.marked Bindlib.box
val option_enum : Dcalc.Ast.EnumName.t val option_enum : Dcalc.Ast.EnumName.t
val none_constr : Dcalc.Ast.EnumConstructor.t val none_constr : Dcalc.Ast.EnumConstructor.t
val some_constr : Dcalc.Ast.EnumConstructor.t val some_constr : Dcalc.Ast.EnumConstructor.t
val option_enum_config : (Dcalc.Ast.EnumConstructor.t * Dcalc.Ast.typ Pos.marked) list val option_enum_config :
(Dcalc.Ast.EnumConstructor.t * Dcalc.Ast.typ Pos.marked) list
val make_none : Pos.t -> expr Pos.marked Bindlib.box val make_none : Pos.t -> expr Pos.marked Bindlib.box
val make_some : expr Pos.marked Bindlib.box -> expr Pos.marked Bindlib.box val make_some : expr Pos.marked Bindlib.box -> expr Pos.marked Bindlib.box
val make_matchopt_with_abs_arms : val make_matchopt_with_abs_arms :
@ -116,11 +124,10 @@ val make_matchopt :
expr Pos.marked Bindlib.box -> expr Pos.marked Bindlib.box ->
expr Pos.marked Bindlib.box -> expr Pos.marked Bindlib.box ->
expr Pos.marked Bindlib.box expr Pos.marked Bindlib.box
(** [e' = make_matchopt'' pos v e e_none e_some] Builds the term corresponding to (** [e' = make_matchopt'' pos v e e_none e_some] Builds the term corresponding
[match e with | None -> fun () -> e_none |Some -> fun v -> e_some]. *) to [match e with | None -> fun () -> e_none |Some -> fun v -> e_some]. *)
val handle_default : Var.t val handle_default : Var.t
val handle_default_opt : Var.t val handle_default_opt : Var.t
type binder = (expr, expr Pos.marked) Bindlib.binder type binder = (expr, expr Pos.marked) Bindlib.binder

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2021 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2021 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
let to_ascii (s : string) : string = let to_ascii (s : string) : string =
@ -48,7 +50,8 @@ let to_lowercase (s : string) : string =
out := out :=
!out !out
^ (if is_uppercase && not !is_first then "_" else "") ^ (if is_uppercase && not !is_first then "_" else "")
^ String.lowercase_ascii (String.make 1 (CamomileLibraryDefault.Camomile.UChar.char_of c)); ^ String.lowercase_ascii
(String.make 1 (CamomileLibraryDefault.Camomile.UChar.char_of c));
is_first := false) is_first := false)
s; s;
!out !out
@ -59,13 +62,18 @@ let to_uppercase (s : string) : string =
let out = ref "" in let out = ref "" in
CamomileLibraryDefault.Camomile.UTF8.iter CamomileLibraryDefault.Camomile.UTF8.iter
(fun c -> (fun c ->
let is_underscore = c = CamomileLibraryDefault.Camomile.UChar.of_char '_' in let is_underscore =
let c_string = String.make 1 (CamomileLibraryDefault.Camomile.UChar.char_of c) in c = CamomileLibraryDefault.Camomile.UChar.of_char '_'
in
let c_string =
String.make 1 (CamomileLibraryDefault.Camomile.UChar.char_of c)
in
out := out :=
!out !out
^ ^
if is_underscore then "" if is_underscore then ""
else if !last_was_underscore || !is_first then String.uppercase_ascii c_string else if !last_was_underscore || !is_first then
String.uppercase_ascii c_string
else c_string; else c_string;
last_was_underscore := is_underscore; last_was_underscore := is_underscore;
is_first := false) is_first := false)

View File

@ -1,22 +1,24 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2021 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2021 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Helper functions common to all Catala compiler backends *) (** Helper functions common to all Catala compiler backends *)
val to_ascii : string -> string val to_ascii : string -> string
(** Removes all non-ASCII diacritics from a string by converting them to their base letter in the (** Removes all non-ASCII diacritics from a string by converting them to their
Latin alphabet *) base letter in the Latin alphabet *)
val to_lowercase : string -> string val to_lowercase : string -> string
(** Converts CamlCase into snake_case *) (** Converts CamlCase into snake_case *)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -17,8 +19,8 @@ module D = Dcalc.Ast
module A = Ast module A = Ast
type ctx = A.expr Pos.marked Bindlib.box D.VarMap.t type ctx = A.expr Pos.marked Bindlib.box D.VarMap.t
(** This environment contains a mapping between the variables in Dcalc and their correspondance in (** This environment contains a mapping between the variables in Dcalc and their
Lcalc. *) correspondance in Lcalc. *)
let translate_lit (l : D.lit) : A.expr = let translate_lit (l : D.lit) : A.expr =
match l with match l with
@ -31,15 +33,21 @@ let translate_lit (l : D.lit) : A.expr =
| D.LDuration d -> A.ELit (A.LDuration d) | D.LDuration d -> A.ELit (A.LDuration d)
| D.LEmptyError -> A.ERaise A.EmptyError | D.LEmptyError -> A.ERaise A.EmptyError
let thunk_expr (e : A.expr Pos.marked Bindlib.box) (pos : Pos.t) : A.expr Pos.marked Bindlib.box = let thunk_expr (e : A.expr Pos.marked Bindlib.box) (pos : Pos.t) :
A.expr Pos.marked Bindlib.box =
let dummy_var = A.Var.make ("_", pos) in let dummy_var = A.Var.make ("_", pos) in
A.make_abs [| dummy_var |] e pos [ (D.TAny, pos) ] pos A.make_abs [| dummy_var |] e pos [ (D.TAny, pos) ] pos
let rec translate_default (ctx : ctx) (exceptions : D.expr Pos.marked list) let rec translate_default
(just : D.expr Pos.marked) (cons : D.expr Pos.marked) (pos_default : Pos.t) : (ctx : ctx)
A.expr Pos.marked Bindlib.box = (exceptions : D.expr Pos.marked list)
(just : D.expr Pos.marked)
(cons : D.expr Pos.marked)
(pos_default : Pos.t) : A.expr Pos.marked Bindlib.box =
let exceptions = let exceptions =
List.map (fun except -> thunk_expr (translate_expr ctx except) pos_default) exceptions List.map
(fun except -> thunk_expr (translate_expr ctx except) pos_default)
exceptions
in in
let exceptions = let exceptions =
A.make_app A.make_app
@ -55,7 +63,8 @@ let rec translate_default (ctx : ctx) (exceptions : D.expr Pos.marked list)
in in
exceptions exceptions
and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : A.expr Pos.marked Bindlib.box = and translate_expr (ctx : ctx) (e : D.expr Pos.marked) :
A.expr Pos.marked Bindlib.box =
match Pos.unmark e with match Pos.unmark e with
| D.EVar v -> D.VarMap.find (Pos.unmark v) ctx | D.EVar v -> D.VarMap.find (Pos.unmark v) ctx
| D.ETuple (args, s) -> | D.ETuple (args, s) ->
@ -86,12 +95,17 @@ and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : A.expr Pos.marked Bindl
(fun e1 e2 e3 -> Pos.same_pos_as (A.EIfThenElse (e1, e2, e3)) e) (fun e1 e2 e3 -> Pos.same_pos_as (A.EIfThenElse (e1, e2, e3)) e)
(translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3) (translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3)
| D.EAssert e1 -> | D.EAssert e1 ->
Bindlib.box_apply (fun e1 -> Pos.same_pos_as (A.EAssert e1) e) (translate_expr ctx e1) Bindlib.box_apply
(fun e1 -> Pos.same_pos_as (A.EAssert e1) e)
(translate_expr ctx e1)
| D.ErrorOnEmpty arg -> | D.ErrorOnEmpty arg ->
Bindlib.box_apply Bindlib.box_apply
(fun arg -> (fun arg ->
Pos.same_pos_as Pos.same_pos_as
(A.ECatch (arg, A.EmptyError, Pos.same_pos_as (A.ERaise A.NoValueProvided) e)) (A.ECatch
( arg,
A.EmptyError,
Pos.same_pos_as (A.ERaise A.NoValueProvided) e ))
e) e)
(translate_expr ctx arg) (translate_expr ctx arg)
| D.EApp (e1, args) -> | D.EApp (e1, args) ->
@ -113,7 +127,8 @@ and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : A.expr Pos.marked Bindl
let new_body = translate_expr ctx body in let new_body = translate_expr ctx body in
let new_binder = Bindlib.bind_mvar lc_vars new_body in let new_binder = Bindlib.bind_mvar lc_vars new_body in
Bindlib.box_apply Bindlib.box_apply
(fun new_binder -> Pos.same_pos_as (A.EAbs ((new_binder, pos_binder), ts)) e) (fun new_binder ->
Pos.same_pos_as (A.EAbs ((new_binder, pos_binder), ts)) e)
new_binder new_binder
| D.EDefault ([ exn ], just, cons) when !Cli.optimize_flag -> | D.EDefault ([ exn ], just, cons) when !Cli.optimize_flag ->
Bindlib.box_apply3 Bindlib.box_apply3
@ -123,10 +138,12 @@ and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : A.expr Pos.marked Bindl
( exn, ( exn,
A.EmptyError, A.EmptyError,
Pos.same_pos_as Pos.same_pos_as
(A.EIfThenElse (just, cons, Pos.same_pos_as (A.ERaise A.EmptyError) e)) (A.EIfThenElse
(just, cons, Pos.same_pos_as (A.ERaise A.EmptyError) e))
e )) e ))
e) e)
(translate_expr ctx exn) (translate_expr ctx just) (translate_expr ctx cons) (translate_expr ctx exn) (translate_expr ctx just)
(translate_expr ctx cons)
| D.EDefault (exceptions, just, cons) -> | D.EDefault (exceptions, just, cons) ->
translate_default ctx exceptions just cons (Pos.get_position e) translate_default ctx exceptions just cons (Pos.get_position e)
@ -147,7 +164,8 @@ let translate_program (prgm : D.program) : A.program =
(D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx) (D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx)
(Bindlib.unbox (Bindlib.unbox
(D.build_whole_scope_expr prgm.decl_ctx e (D.build_whole_scope_expr prgm.decl_ctx e
(Pos.get_position (Dcalc.Ast.ScopeName.get_info scope_name))))); (Pos.get_position
(Dcalc.Ast.ScopeName.get_info scope_name)))));
} }
:: acc :: acc
in in

View File

@ -1,18 +1,20 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Translation from the default calculus to the lambda calculus. This translation uses exceptions (** Translation from the default calculus to the lambda calculus. This
handle empty default terms. *) translation uses exceptions handle empty default terms. *)
val translate_program : Dcalc.Ast.program -> Ast.program val translate_program : Dcalc.Ast.program -> Ast.program

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020-2022 Inria, contributor: Alain Delaët-Tixeuil and social benefits computation rules. Copyright (C) 2020-2022 Inria,
<alain.delaet--tixeuil@inria.fr> contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -17,39 +19,50 @@ module D = Dcalc.Ast
module A = Ast module A = Ast
open Dcalc.Binded_representation open Dcalc.Binded_representation
(** The main idea around this pass is to compile Dcalc to Lcalc without using [raise EmptyError] nor (** The main idea around this pass is to compile Dcalc to Lcalc without using
[try _ with EmptyError -> _]. To do so, we use the same technique as in rust or erlang to handle [raise EmptyError] nor [try _ with EmptyError -> _]. To do so, we use the
this kind of exceptions. Each [raise EmptyError] will be translated as [None] and each same technique as in rust or erlang to handle this kind of exceptions. Each
[try e1 with EmtpyError -> e2] as [match e1 with | None -> e2 | Some x -> x]. [raise EmptyError] will be translated as [None] and each
[try e1 with EmtpyError -> e2] as
[match e1 with | None -> e2 | Some x -> x].
When doing this naively, this requires to add matches and Some constructor everywhere. We apply When doing this naively, this requires to add matches and Some constructor
here an other technique where we generate what we call `hoists`. Hoists are expression whom everywhere. We apply here an other technique where we generate what we call
could minimally [raise EmptyError]. For instance in `hoists`. Hoists are expression whom could minimally [raise EmptyError]. For
[let x = <e1, e2, ..., en| e_just :- e_cons> * 3 in x + 1], the sub-expression instance in [let x = <e1, e2, ..., en| e_just :- e_cons> * 3 in x + 1], the
[<e1, e2, ..., en| e_just :- e_cons>] can produce an empty error. So we make a hoist with a new sub-expression [<e1, e2, ..., en| e_just :- e_cons>] can produce an empty
variable [y] linked to the Dcalc expression [<e1, e2, ..., en| e_just :- e_cons>], and we return error. So we make a hoist with a new variable [y] linked to the Dcalc
as the translated expression [let x = y * 3 in x + 1]. expression [<e1, e2, ..., en| e_just :- e_cons>], and we return as the
translated expression [let x = y * 3 in x + 1].
The compilation of expressions is found in the functions [translate_and_hoist ctx e] and The compilation of expressions is found in the functions
[translate_expr ctx e]. Every option-generating expression when calling [translate_and_hoist] [translate_and_hoist ctx e] and [translate_expr ctx e]. Every
will be hoisted and later handled by the [translate_expr] function. Every other cases is found option-generating expression when calling [translate_and_hoist] will be
in the translate_and_hoist function. *) hoisted and later handled by the [translate_expr] function. Every other
cases is found in the translate_and_hoist function. *)
type hoists = D.expr Pos.marked A.VarMap.t type hoists = D.expr Pos.marked A.VarMap.t
(** Hoists definition. It represent bindings between [A.Var.t] and [D.expr]. *) (** Hoists definition. It represent bindings between [A.Var.t] and [D.expr]. *)
type info = { expr : A.expr Pos.marked Bindlib.box; var : A.expr Bindlib.var; is_pure : bool } type info = {
(** Information about each encontered Dcalc variable is stored inside a context : what is the expr : A.expr Pos.marked Bindlib.box;
corresponding LCalc variable; an expression corresponding to the variable build correctly using var : A.expr Bindlib.var;
Bindlib, and a boolean `is_pure` indicating whenever the variable can be an EmptyError and hence is_pure : bool;
should be matched (false) or if it never can be EmptyError (true). *) }
(** Information about each encontered Dcalc variable is stored inside a context
: what is the corresponding LCalc variable; an expression corresponding to
the variable build correctly using Bindlib, and a boolean `is_pure`
indicating whenever the variable can be an EmptyError and hence should be
matched (false) or if it never can be EmptyError (true). *)
let pp_info (fmt : Format.formatter) (info : info) = let pp_info (fmt : Format.formatter) (info : info) =
Format.fprintf fmt "{var: %a; is_pure: %b}" Print.format_var info.var info.is_pure Format.fprintf fmt "{var: %a; is_pure: %b}" Print.format_var info.var
info.is_pure
type ctx = { type ctx = {
decl_ctx : D.decl_ctx; decl_ctx : D.decl_ctx;
vars : info D.VarMap.t; (** information context about variables in the current scope *) vars : info D.VarMap.t;
(** information context about variables in the current scope *)
} }
let _pp_ctx (fmt : Format.formatter) (ctx : ctx) = let _pp_ctx (fmt : Format.formatter) (ctx : ctx) =
@ -58,38 +71,48 @@ let _pp_ctx (fmt : Format.formatter) (ctx : ctx) =
in in
let pp_bindings = let pp_bindings =
Format.pp_print_list ~pp_sep:(fun fmt () -> Format.pp_print_string fmt "; ") pp_binding Format.pp_print_list
~pp_sep:(fun fmt () -> Format.pp_print_string fmt "; ")
pp_binding
in in
Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx.vars) Format.fprintf fmt "@[<2>[%a]@]" pp_bindings (D.VarMap.bindings ctx.vars)
(** [find ~info n ctx] is a warpper to ocaml's Map.find that handle errors in a slightly better way. *) (** [find ~info n ctx] is a warpper to ocaml's Map.find that handle errors in a
slightly better way. *)
let find ?(info : string = "none") (n : D.Var.t) (ctx : ctx) : info = let find ?(info : string = "none") (n : D.Var.t) (ctx : ctx) : info =
(* let _ = Format.asprintf "Searching for variable %a inside context %a" Dcalc.Print.format_var n (* let _ = Format.asprintf "Searching for variable %a inside context %a"
pp_ctx ctx |> Cli.debug_print in *) Dcalc.Print.format_var n pp_ctx ctx |> Cli.debug_print in *)
try D.VarMap.find n ctx.vars try D.VarMap.find n ctx.vars
with Not_found -> with Not_found ->
Errors.raise_spanned_error Pos.no_pos Errors.raise_spanned_error Pos.no_pos
"Internal Error: Variable %a was not found in the current environment. Additional \ "Internal Error: Variable %a was not found in the current environment. \
informations : %s." Additional informations : %s."
Dcalc.Print.format_var n info Dcalc.Print.format_var n info
(** [add_var pos var is_pure ctx] add to the context [ctx] the Dcalc variable var, creating a unique (** [add_var pos var is_pure ctx] add to the context [ctx] the Dcalc variable
corresponding variable in Lcalc, with the corresponding expression, and the boolean is_pure. It var, creating a unique corresponding variable in Lcalc, with the
is usefull for debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *) corresponding expression, and the boolean is_pure. It is usefull for
debuging purposes as it printing each of the Dcalc/Lcalc variable pairs. *)
let add_var (pos : Pos.t) (var : D.Var.t) (is_pure : bool) (ctx : ctx) : ctx = let add_var (pos : Pos.t) (var : D.Var.t) (is_pure : bool) (ctx : ctx) : ctx =
let new_var = A.Var.make (Bindlib.name_of var, pos) in let new_var = A.Var.make (Bindlib.name_of var, pos) in
let expr = A.make_var (new_var, pos) in let expr = A.make_var (new_var, pos) in
(* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var var Print.format_var (* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var
new_var; *) var Print.format_var new_var; *)
{ ctx with vars = D.VarMap.update var (fun _ -> Some { expr; var = new_var; is_pure }) ctx.vars } {
ctx with
vars =
D.VarMap.update var
(fun _ -> Some { expr; var = new_var; is_pure })
ctx.vars;
}
(** [tau' = translate_typ tau] translate the a dcalc type into a lcalc type. (** [tau' = translate_typ tau] translate the a dcalc type into a lcalc type.
Since positions where there is thunked expressions is exactly where we will put option Since positions where there is thunked expressions is exactly where we will
expressions. Hence, the transformation simply reduce [unit -> 'a] into ['a option] recursivly. put option expressions. Hence, the transformation simply reduce [unit -> 'a]
There is no polymorphism inside catala. *) into ['a option] recursivly. There is no polymorphism inside catala. *)
let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked = let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked =
(Fun.flip Pos.same_pos_as) tau (Fun.flip Pos.same_pos_as) tau
begin begin
@ -101,7 +124,9 @@ let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked =
| D.TArray ts -> D.TArray (translate_typ ts) | D.TArray ts -> D.TArray (translate_typ ts)
(* catala is not polymorphic *) (* catala is not polymorphic *)
| D.TArrow ((D.TLit D.TUnit, pos_unit), t2) -> | D.TArrow ((D.TLit D.TUnit, pos_unit), t2) ->
D.TEnum ([ (D.TLit D.TUnit, pos_unit); translate_typ t2 ], A.option_enum) (* D.TAny *) D.TEnum
([ (D.TLit D.TUnit, pos_unit); translate_typ t2 ], A.option_enum)
(* D.TAny *)
| D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2) | D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2)
end end
@ -116,45 +141,53 @@ let translate_lit (l : D.lit) (pos : Pos.t) : A.lit =
| D.LDuration d -> A.LDuration d | D.LDuration d -> A.LDuration d
| D.LEmptyError -> | D.LEmptyError ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: An empty error was found in a place that shouldn't be possible." "Internal Error: An empty error was found in a place that shouldn't be \
possible."
(** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps. Raises an internal (** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps.
error if there is two identicals keys in differnts parts. *) Raises an internal error if there is two identicals keys in differnts parts. *)
let disjoint_union_maps (pos : Pos.t) (cs : 'a A.VarMap.t list) : 'a A.VarMap.t = let disjoint_union_maps (pos : Pos.t) (cs : 'a A.VarMap.t list) : 'a A.VarMap.t
=
let disjoint_union = let disjoint_union =
A.VarMap.union (fun _ _ _ -> A.VarMap.union (fun _ _ _ ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: Two supposed to be disjoints maps have one shared key.") "Internal Error: Two supposed to be disjoints maps have one shared \
key.")
in in
List.fold_left disjoint_union A.VarMap.empty cs List.fold_left disjoint_union A.VarMap.empty cs
(** [e' = translate_and_hoist ctx e ] Translate the Dcalc expression e into an expression in Lcalc, (** [e' = translate_and_hoist ctx e ] Translate the Dcalc expression e into an
given we translate each hoists correctly. It ensures the equivalence between the execution of e expression in Lcalc, given we translate each hoists correctly. It ensures
and the execution of e' are equivalent in an environement where each variable v, where (v, e_v) the equivalence between the execution of e and the execution of e' are
is in hoists, has the non-empty value in e_v. *) equivalent in an environement where each variable v, where (v, e_v) is in
hoists, has the non-empty value in e_v. *)
let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) : let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
A.expr Pos.marked Bindlib.box * hoists = A.expr Pos.marked Bindlib.box * hoists =
let pos = Pos.get_position e in let pos = Pos.get_position e in
match Pos.unmark e with match Pos.unmark e with
(* empty-producing/using terms. We hoist those. (D.EVar in some cases, EApp(D.EVar _, [ELit (* empty-producing/using terms. We hoist those. (D.EVar in some cases,
LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure about assert. *) EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure
about assert. *)
| D.EVar v -> | D.EVar v ->
(* todo: for now, every unpure (such that [is_pure] is [false] in the current context) is (* todo: for now, every unpure (such that [is_pure] is [false] in the
thunked, hence matched in the next case. This assumption can change in the future, and this current context) is thunked, hence matched in the next case. This
case is here for this reason. *) assumption can change in the future, and this case is here for this
reason. *)
let v, pos_v = v in let v, pos_v = v in
if not (find ~info:"search for a variable" v ctx).is_pure then if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = A.Var.make (Bindlib.name_of v, pos_v) in let v' = A.Var.make (Bindlib.name_of v, pos_v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
replace it" Dcalc.Print.format_var v Print.format_var v'; *) created a variable %a to replace it" Dcalc.Print.format_var v
Print.format_var v'; *)
(A.make_var (v', pos), A.VarMap.singleton v' e) (A.make_var (v', pos), A.VarMap.singleton v' e)
else ((find ~info:"should never happend" v ctx).expr, A.VarMap.empty) else ((find ~info:"should never happend" v ctx).expr, A.VarMap.empty)
| D.EApp ((D.EVar (v, pos_v), p), [ (D.ELit D.LUnit, _) ]) -> | D.EApp ((D.EVar (v, pos_v), p), [ (D.ELit D.LUnit, _) ]) ->
if not (find ~info:"search for a variable" v ctx).is_pure then if not (find ~info:"search for a variable" v ctx).is_pure then
let v' = A.Var.make (Bindlib.name_of v, pos_v) in let v' = A.Var.make (Bindlib.name_of v, pos_v) in
(* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
replace it" Dcalc.Print.format_var v Print.format_var v'; *) created a variable %a to replace it" Dcalc.Print.format_var v
Print.format_var v'; *)
(A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p)) (A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p))
else else
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
@ -165,10 +198,11 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
| D.ELit D.LEmptyError -> | D.ELit D.LEmptyError ->
let v' = A.Var.make ("empty_litteral", pos) in let v' = A.Var.make ("empty_litteral", pos) in
(A.make_var (v', pos), A.VarMap.singleton v' e) (A.make_var (v', pos), A.VarMap.singleton v' e)
(* This one is a very special case. It transform an unpure expression environement to a pure (* This one is a very special case. It transform an unpure expression
expression. *) environement to a pure expression. *)
| ErrorOnEmpty arg -> | ErrorOnEmpty arg ->
(* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *) (* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }}
] *)
let silent_var = A.Var.make ("_", pos) in let silent_var = A.Var.make ("_", pos) in
let x = A.Var.make ("non_empty_argument", pos) in let x = A.Var.make ("non_empty_argument", pos) in
@ -177,7 +211,9 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
( A.make_matchopt_with_abs_arms arg' ( A.make_matchopt_with_abs_arms arg'
(A.make_abs [| silent_var |] (A.make_abs [| silent_var |]
(Bindlib.box (A.ERaise A.NoValueProvided, pos)) (Bindlib.box (A.ERaise A.NoValueProvided, pos))
pos [ (D.TAny, pos) ] pos) pos
[ (D.TAny, pos) ]
pos)
(A.make_abs [| x |] (A.make_var (x, pos)) pos [ (D.TAny, pos) ] pos), (A.make_abs [| x |] (A.make_var (x, pos)) pos [ (D.TAny, pos) ] pos),
A.VarMap.empty ) A.VarMap.empty )
(* pure terms *) (* pure terms *)
@ -188,43 +224,51 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
let e3', h3 = translate_and_hoist ctx e3 in let e3', h3 = translate_and_hoist ctx e3 in
let e' = let e' =
Bindlib.box_apply3 (fun e1' e2' e3' -> (A.EIfThenElse (e1', e2', e3'), pos)) e1' e2' e3' Bindlib.box_apply3
(fun e1' e2' e3' -> (A.EIfThenElse (e1', e2', e3'), pos))
e1' e2' e3'
in in
(*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' = e3' in (*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3'
(A.EIfThenElse (e1', e2', e3'), pos) in *) = e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *)
(e', disjoint_union_maps pos [ h1; h2; h3 ]) (e', disjoint_union_maps pos [ h1; h2; h3 ])
| D.EAssert e1 -> | D.EAssert e1 ->
(* same behavior as in the ICFP paper: if e1 is empty, then no error is raised. *) (* same behavior as in the ICFP paper: if e1 is empty, then no error is
raised. *)
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
(Bindlib.box_apply (fun e1' -> (A.EAssert e1', pos)) e1', h1) (Bindlib.box_apply (fun e1' -> (A.EAssert e1', pos)) e1', h1)
| D.EAbs ((binder, pos_binder), ts) -> | D.EAbs ((binder, pos_binder), ts) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let ctx, lc_vars = let ctx, lc_vars =
ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) -> ArrayLabels.fold_right vars ~init:(ctx, [])
(* we suppose the invariant that when applying a function, its arguments cannot be of ~f:(fun var (ctx, lc_vars) ->
the type "option". (* we suppose the invariant that when applying a function, its
arguments cannot be of the type "option".
The code should behave correctly in the without this assumption if we put here an The code should behave correctly in the without this assumption
is_pure=false, but the types are more compilcated. (unimplemented for now) *) if we put here an is_pure=false, but the types are more
compilcated. (unimplemented for now) *)
let ctx = add_var pos var true ctx in let ctx = add_var pos var true ctx in
let lc_var = (find var ctx).var in let lc_var = (find var ctx).var in
(ctx, lc_var :: lc_vars)) (ctx, lc_var :: lc_vars))
in in
let lc_vars = Array.of_list lc_vars in let lc_vars = Array.of_list lc_vars in
(* here we take the guess that if we cannot build the closure because one of the variable is (* here we take the guess that if we cannot build the closure because one
empty, then we cannot build the function. *) of the variable is empty, then we cannot build the function. *)
let new_body, hoists = translate_and_hoist ctx body in let new_body, hoists = translate_and_hoist ctx body in
let new_binder = Bindlib.bind_mvar lc_vars new_body in let new_binder = Bindlib.bind_mvar lc_vars new_body in
( Bindlib.box_apply ( Bindlib.box_apply
(fun new_binder -> (A.EAbs ((new_binder, pos_binder), List.map translate_typ ts), pos)) (fun new_binder ->
(A.EAbs ((new_binder, pos_binder), List.map translate_typ ts), pos))
new_binder, new_binder,
hoists ) hoists )
| EApp (e1, args) -> | EApp (e1, args) ->
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
let args', h_args = args |> List.map (translate_and_hoist ctx) |> List.split in let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split
in
let hoists = disjoint_union_maps pos (h1 :: h_args) in let hoists = disjoint_union_maps pos (h1 :: h_args) in
let e' = let e' =
@ -234,21 +278,32 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
in in
(e', hoists) (e', hoists)
| ETuple (args, s) -> | ETuple (args, s) ->
let args', h_args = args |> List.map (translate_and_hoist ctx) |> List.split in let args', h_args =
args |> List.map (translate_and_hoist ctx) |> List.split
in
let hoists = disjoint_union_maps pos h_args in let hoists = disjoint_union_maps pos h_args in
(Bindlib.box_apply (fun args' -> (A.ETuple (args', s), pos)) (Bindlib.box_list args'), hoists) ( Bindlib.box_apply
(fun args' -> (A.ETuple (args', s), pos))
(Bindlib.box_list args'),
hoists )
| ETupleAccess (e1, i, s, ts) -> | ETupleAccess (e1, i, s, ts) ->
let e1', hoists = translate_and_hoist ctx e1 in let e1', hoists = translate_and_hoist ctx e1 in
let e1' = Bindlib.box_apply (fun e1' -> (A.ETupleAccess (e1', i, s, ts), pos)) e1' in let e1' =
Bindlib.box_apply (fun e1' -> (A.ETupleAccess (e1', i, s, ts), pos)) e1'
in
(e1', hoists) (e1', hoists)
| EInj (e1, i, en, ts) -> | EInj (e1, i, en, ts) ->
let e1', hoists = translate_and_hoist ctx e1 in let e1', hoists = translate_and_hoist ctx e1 in
let e1' = Bindlib.box_apply (fun e1' -> (A.EInj (e1', i, en, ts), pos)) e1' in let e1' =
Bindlib.box_apply (fun e1' -> (A.EInj (e1', i, en, ts), pos)) e1'
in
(e1', hoists) (e1', hoists)
| EMatch (e1, cases, en) -> | EMatch (e1, cases, en) ->
let e1', h1 = translate_and_hoist ctx e1 in let e1', h1 = translate_and_hoist ctx e1 in
let cases', h_cases = cases |> List.map (translate_and_hoist ctx) |> List.split in let cases', h_cases =
cases |> List.map (translate_and_hoist ctx) |> List.split
in
let hoists = disjoint_union_maps pos (h1 :: h_cases) in let hoists = disjoint_union_maps pos (h1 :: h_cases) in
let e' = let e' =
@ -258,7 +313,9 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
in in
(e', hoists) (e', hoists)
| EArray es -> | EArray es ->
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in let es', hoists =
es |> List.map (translate_and_hoist ctx) |> List.split
in
( Bindlib.box_apply (fun es' -> (A.EArray es', pos)) (Bindlib.box_list es'), ( Bindlib.box_apply (fun es' -> (A.EArray es', pos)) (Bindlib.box_list es'),
disjoint_union_maps pos hoists ) disjoint_union_maps pos hoists )
@ -272,17 +329,19 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
let _pos = Pos.get_position e in let _pos = Pos.get_position e in
(* build the hoists *) (* build the hoists *)
(* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]" (Format.pp_print_list (* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]"
Print.format_var) (List.map fst hoists); *) (Format.pp_print_list Print.format_var) (List.map fst hoists); *)
ListLabels.fold_left hoists ListLabels.fold_left hoists
~init:(if append_esome then A.make_some e' else e') ~init:(if append_esome then A.make_some e' else e')
~f:(fun acc (v, (hoist, pos_hoist)) -> ~f:(fun acc (v, (hoist, pos_hoist)) ->
(* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var v; *) (* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var
v; *)
let c' : A.expr Pos.marked Bindlib.box = let c' : A.expr Pos.marked Bindlib.box =
match hoist with match hoist with
(* Here we have to handle only the cases appearing in hoists, as defined the (* Here we have to handle only the cases appearing in hoists, as defined
[translate_and_hoist] function. *) the [translate_and_hoist] function. *)
| D.EVar v -> (find ~info:"should never happend" (Pos.unmark v) ctx).expr | D.EVar v ->
(find ~info:"should never happend" (Pos.unmark v) ctx).expr
| D.EDefault (excep, just, cons) -> | D.EDefault (excep, just, cons) ->
let excep' = List.map (translate_expr ctx) excep in let excep' = List.map (translate_expr ctx) excep in
let just' = translate_expr ctx just in let just' = translate_expr ctx just in
@ -302,27 +361,36 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
| D.EAssert arg -> | D.EAssert arg ->
let arg' = translate_expr ctx arg in let arg' = translate_expr ctx arg in
(* [ match arg with | None -> raise NoValueProvided | Some v -> assert {{ v }} ] *) (* [ match arg with | None -> raise NoValueProvided | Some v ->
assert {{ v }} ] *)
let silent_var = A.Var.make ("_", pos_hoist) in let silent_var = A.Var.make ("_", pos_hoist) in
let x = A.Var.make ("assertion_argument", pos_hoist) in let x = A.Var.make ("assertion_argument", pos_hoist) in
A.make_matchopt_with_abs_arms arg' A.make_matchopt_with_abs_arms arg'
(A.make_abs [| silent_var |] (A.make_abs [| silent_var |]
(Bindlib.box (A.ERaise A.NoValueProvided, pos_hoist)) (Bindlib.box (A.ERaise A.NoValueProvided, pos_hoist))
pos_hoist [ (D.TAny, pos_hoist) ] pos_hoist) pos_hoist
[ (D.TAny, pos_hoist) ]
pos_hoist)
(A.make_abs [| x |] (A.make_abs [| x |]
(Bindlib.box_apply (Bindlib.box_apply
(fun arg -> (A.EAssert arg, pos_hoist)) (fun arg -> (A.EAssert arg, pos_hoist))
(A.make_var (x, pos_hoist))) (A.make_var (x, pos_hoist)))
pos_hoist [ (D.TAny, pos_hoist) ] pos_hoist) pos_hoist
[ (D.TAny, pos_hoist) ]
pos_hoist)
| _ -> | _ ->
Errors.raise_spanned_error pos_hoist Errors.raise_spanned_error pos_hoist
"Internal Error: An term was found in a position where it should not be" "Internal Error: An term was found in a position where it should \
not be"
in in
(* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end ] *) (* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end
(* Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v; *) ] *)
A.make_matchopt pos_hoist v (D.TAny, pos_hoist) c' (A.make_none pos_hoist) acc) (* Cli.debug_print @@ Format.asprintf "build matchopt using %a"
Print.format_var v; *)
A.make_matchopt pos_hoist v (D.TAny, pos_hoist) c' (A.make_none pos_hoist)
acc)
let rec translate_scope_let (ctx : ctx) (lets : scope_lets) = let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
match lets with match lets with
@ -335,14 +403,18 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos; scope_let_pos = pos;
} -> } ->
(* special case : the subscope variable is thunked (context i/o). We remove this thunking. *) (* special case : the subscope variable is thunked (context i/o). We
remove this thunking. *)
let _, expr = Bindlib.unmbind binder in let _, expr = Bindlib.unmbind binder in
let var_is_pure = true in let var_is_pure = true in
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *) (* Cli.debug_print @@ Format.asprintf "unbinding %a"
Dcalc.Print.format_var var; *)
let ctx' = add_var pos var var_is_pure ctx in let ctx' = add_var pos var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var =
(find ~info:"variable that was just created" var ctx').var
in
A.make_let_in new_var (translate_typ typ) A.make_let_in new_var (translate_typ typ)
(translate_expr ctx ~append_esome:false expr) (translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next) (translate_scope_let ctx' next)
@ -357,17 +429,26 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
(* special case: regular input to the subscope *) (* special case: regular input to the subscope *)
let var_is_pure = true in let var_is_pure = true in
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *) (* Cli.debug_print @@ Format.asprintf "unbinding %a"
Dcalc.Print.format_var var; *)
let ctx' = add_var pos var var_is_pure ctx in let ctx' = add_var pos var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var =
(find ~info:"variable that was just created" var ctx').var
in
A.make_let_in new_var (translate_typ typ) A.make_let_in new_var (translate_typ typ)
(translate_expr ctx ~append_esome:false expr) (translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next) (translate_scope_let ctx' next)
| ScopeLet | ScopeLet
{ scope_let_kind = SubScopeVarDefinition; scope_let_pos = pos; scope_let_expr = expr; _ } -> {
scope_let_kind = SubScopeVarDefinition;
scope_let_pos = pos;
scope_let_expr = expr;
_;
} ->
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: found an SubScopeVarDefinition that does not satisfy the invariants when \ "Internal Error: found an SubScopeVarDefinition that does not satisfy \
translating Dcalc to Lcalc without exceptions: @[<hov 2>%a@]" the invariants when translating Dcalc to Lcalc without exceptions: \
@[<hov 2>%a@]"
(Dcalc.Print.format_expr ctx.decl_ctx) (Dcalc.Print.format_expr ctx.decl_ctx)
expr expr
| ScopeLet | ScopeLet
@ -381,18 +462,24 @@ let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
let var_is_pure = let var_is_pure =
match kind with match kind with
| DestructuringInputStruct -> ( | DestructuringInputStruct -> (
(* Here, we have to distinguish between context and input variables. We can do so by (* Here, we have to distinguish between context and input variables.
looking at the typ of the destructuring: if it's thunked, then the variable is We can do so by looking at the typ of the destructuring: if it's
context. If it's not thunked, it's a regular input. *) thunked, then the variable is context. If it's not thunked, it's
match Pos.unmark typ with D.TArrow ((D.TLit D.TUnit, _), _) -> false | _ -> true) a regular input. *)
match Pos.unmark typ with
| D.TArrow ((D.TLit D.TUnit, _), _) -> false
| _ -> true)
| ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
| DestructuringSubScopeResults | Assertion -> | DestructuringSubScopeResults | Assertion ->
true true
in in
let var, next = Bindlib.unbind next in let var, next = Bindlib.unbind next in
(* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *) (* Cli.debug_print @@ Format.asprintf "unbinding %a"
Dcalc.Print.format_var var; *)
let ctx' = add_var pos var var_is_pure ctx in let ctx' = add_var pos var var_is_pure ctx in
let new_var = (find ~info:"variable that was just created" var ctx').var in let new_var =
(find ~info:"variable that was just created" var ctx').var
in
A.make_let_in new_var (translate_typ typ) A.make_let_in new_var (translate_typ typ)
(translate_expr ctx ~append_esome:false expr) (translate_expr ctx ~append_esome:false expr)
(translate_scope_let ctx' next) (translate_scope_let ctx' next)
@ -409,17 +496,22 @@ let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : scope_body) :
let ctx' = add_var scope_pos v true ctx in let ctx' = add_var scope_pos v true ctx in
let v' = (find ~info:"variable that was just created" v ctx').var in let v' = (find ~info:"variable that was just created" v ctx').var in
A.make_abs [| v' |] (translate_scope_let ctx' lets) Pos.no_pos A.make_abs [| v' |]
(translate_scope_let ctx' lets)
Pos.no_pos
[ (D.TTuple ([], Some input_struct), Pos.no_pos) ] [ (D.TTuple ([], Some input_struct), Pos.no_pos) ]
Pos.no_pos Pos.no_pos
let rec translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list Bindlib.box = let rec translate_scopes (ctx : ctx) (scopes : scopes) :
Ast.scope_body list Bindlib.box =
match scopes with match scopes with
| Nil -> Bindlib.box [] | Nil -> Bindlib.box []
| ScopeDef { scope_name; scope_body; scope_next } -> | ScopeDef { scope_name; scope_body; scope_next } ->
let scope_var, next = Bindlib.unbind scope_next in let scope_var, next = Bindlib.unbind scope_next in
let new_ctx = add_var Pos.no_pos scope_var true ctx in let new_ctx = add_var Pos.no_pos scope_var true ctx in
let new_scope_name = (find ~info:"variable that was just created" scope_var new_ctx).var in let new_scope_name =
(find ~info:"variable that was just created" scope_var new_ctx).var
in
let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in
@ -445,12 +537,14 @@ let translate_program (prgm : D.program) : A.program =
body.D.scope_body_input_struct :: acc) body.D.scope_body_input_struct :: acc)
in in
(* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]" (Format.pp_print_list (* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]"
D.StructName.format_t) inputs_structs; *) (Format.pp_print_list D.StructName.format_t) inputs_structs; *)
let decl_ctx = let decl_ctx =
{ {
prgm.decl_ctx with prgm.decl_ctx with
D.ctx_enums = prgm.decl_ctx.ctx_enums |> D.EnumMap.add A.option_enum A.option_enum_config; D.ctx_enums =
prgm.decl_ctx.ctx_enums
|> D.EnumMap.add A.option_enum A.option_enum_config;
} }
in in
let decl_ctx = let decl_ctx =
@ -461,9 +555,11 @@ let translate_program (prgm : D.program) : A.program =
|> D.StructMap.mapi (fun n l -> |> D.StructMap.mapi (fun n l ->
if List.mem n inputs_structs then if List.mem n inputs_structs then
ListLabels.map l ~f:(fun (n, tau) -> ListLabels.map l ~f:(fun (n, tau) ->
(* Cli.debug_print @@ Format.asprintf "Input type: %a" (Dcalc.Print.format_typ (* Cli.debug_print @@ Format.asprintf "Input type: %a"
decl_ctx) tau; Cli.debug_print @@ Format.asprintf "Output type: %a" (Dcalc.Print.format_typ decl_ctx) tau; Cli.debug_print
(Dcalc.Print.format_typ decl_ctx) (translate_typ tau); *) @@ Format.asprintf "Output type: %a"
(Dcalc.Print.format_typ decl_ctx) (translate_typ
tau); *)
(n, translate_typ tau)) (n, translate_typ tau))
else l); else l);
} }

View File

@ -1,19 +1,22 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020-2022 Inria, contributor: Alain Delaët-Tixeuil and social benefits computation rules. Copyright (C) 2020-2022 Inria,
<alain.delaet--tixeuil@inria.fr> contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Translation from the default calculus to the lambda calculus. This translation uses an option (** Translation from the default calculus to the lambda calculus. This
monad to handle empty defaults terms. This transformation is one piece to permit to compile translation uses an option monad to handle empty defaults terms. This
toward legacy languages that does not contains exceptions. *) transformation is one piece to permit to compile toward legacy languages
that does not contains exceptions. *)
val translate_program : Dcalc.Ast.program -> Ast.program val translate_program : Dcalc.Ast.program -> Ast.program

View File

@ -1,27 +1,30 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
open Ast open Ast
let ( let+ ) x f = Bindlib.box_apply f x let ( let+ ) x f = Bindlib.box_apply f x
let ( and+ ) x y = Bindlib.box_pair x y let ( and+ ) x y = Bindlib.box_pair x y
let visitor_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx : 'a) let visitor_map
(t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box)
(ctx : 'a)
(e : expr Pos.marked) : expr Pos.marked Bindlib.box = (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
(* calls [t ctx] on every direct childs of [e], then rebuild an abstract syntax tree modified. (* calls [t ctx] on every direct childs of [e], then rebuild an abstract
Used in other transformations. *) syntax tree modified. Used in other transformations. *)
let default_mark e' = Pos.same_pos_as e' e in let default_mark e' = Pos.same_pos_as e' e in
match Pos.unmark e with match Pos.unmark e with
| EVar (v, pos) -> | EVar (v, pos) ->
@ -37,7 +40,8 @@ let visitor_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx
let+ e1 = t ctx e1 in let+ e1 = t ctx e1 in
default_mark @@ EInj (e1, i, n, ts) default_mark @@ EInj (e1, i, n, ts)
| EMatch (arg, cases, n) -> | EMatch (arg, cases, n) ->
let+ arg = t ctx arg and+ cases = cases |> List.map (t ctx) |> Bindlib.box_list in let+ arg = t ctx arg
and+ cases = cases |> List.map (t ctx) |> Bindlib.box_list in
default_mark @@ EMatch (arg, cases, n) default_mark @@ EMatch (arg, cases, n)
| EArray args -> | EArray args ->
let+ args = args |> List.map (t ctx) |> Bindlib.box_list in let+ args = args |> List.map (t ctx) |> Bindlib.box_list in
@ -48,7 +52,8 @@ let visitor_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx
let+ binder = Bindlib.bind_mvar vars body in let+ binder = Bindlib.bind_mvar vars body in
default_mark @@ EAbs ((binder, pos_binder), ts) default_mark @@ EAbs ((binder, pos_binder), ts)
| EApp (e1, args) -> | EApp (e1, args) ->
let+ e1 = t ctx e1 and+ args = args |> List.map (t ctx) |> Bindlib.box_list in let+ e1 = t ctx e1
and+ args = args |> List.map (t ctx) |> Bindlib.box_list in
default_mark @@ EApp (e1, args) default_mark @@ EApp (e1, args)
| EAssert e1 -> | EAssert e1 ->
let+ e1 = t ctx e1 in let+ e1 = t ctx e1 in
@ -61,10 +66,12 @@ let visitor_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx
default_mark @@ ECatch (e1, exn, e2) default_mark @@ ECatch (e1, exn, e2)
| ERaise _ | ELit _ | EOp _ -> Bindlib.box e | ERaise _ | ELit _ | EOp _ -> Bindlib.box e
let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box = let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box
=
let default_mark e' = Pos.mark (Pos.get_position e) e' in let default_mark e' = Pos.mark (Pos.get_position e) e' in
match Pos.unmark e with match Pos.unmark e with
| EMatch ((EInj (e1, i, n', _ts), _), cases, n) when Dcalc.Ast.EnumName.compare n n' = 0 -> | EMatch ((EInj (e1, i, n', _ts), _), cases, n)
when Dcalc.Ast.EnumName.compare n n' = 0 ->
let+ e1 = visitor_map iota_expr () e1 let+ e1 = visitor_map iota_expr () e1
and+ case = visitor_map iota_expr () (List.nth cases i) in and+ case = visitor_map iota_expr () (List.nth cases i) in
default_mark @@ EApp (case, [ e1 ]) default_mark @@ EApp (case, [ e1 ])
@ -79,11 +86,13 @@ let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box
visitor_map iota_expr () e' visitor_map iota_expr () e'
| _ -> visitor_map iota_expr () e | _ -> visitor_map iota_expr () e
let rec beta_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box = let rec beta_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box
=
let default_mark e' = Pos.same_pos_as e' e in let default_mark e' = Pos.same_pos_as e' e in
match Pos.unmark e with match Pos.unmark e with
| EApp (e1, args) -> ( | EApp (e1, args) -> (
let+ e1 = beta_expr () e1 and+ args = List.map (beta_expr ()) args |> Bindlib.box_list in let+ e1 = beta_expr () e1
and+ args = List.map (beta_expr ()) args |> Bindlib.box_list in
match Pos.unmark e1 with match Pos.unmark e1 with
| EAbs ((binder, _pos_binder), _ts) -> | EAbs ((binder, _pos_binder), _ts) ->
let (_ : (_, _) Bindlib.mbinder) = binder in let (_ : (_, _) Bindlib.mbinder) = binder in
@ -99,14 +108,16 @@ let iota_optimizations (p : program) : program =
(fun scope_body -> (fun scope_body ->
{ {
scope_body with scope_body with
scope_body_expr = Bindlib.unbox (iota_expr () scope_body.scope_body_expr); scope_body_expr =
Bindlib.unbox (iota_expr () scope_body.scope_body_expr);
}) })
p.scopes; p.scopes;
} }
(* TODO: beta optimizations apply inlining of the program. We left the inclusion of (* TODO: beta optimizations apply inlining of the program. We left the inclusion
beta-optimization as future work since its produce code that is harder to read, and can produce of beta-optimization as future work since its produce code that is harder to
exponential blowup of the size of the generated program. *) read, and can produce exponential blowup of the size of the generated
program. *)
let _beta_optimizations (p : program) : program = let _beta_optimizations (p : program) : program =
{ {
p with p with
@ -115,20 +126,28 @@ let _beta_optimizations (p : program) : program =
(fun scope_body -> (fun scope_body ->
{ {
scope_body with scope_body with
scope_body_expr = Bindlib.unbox (beta_expr () scope_body.scope_body_expr); scope_body_expr =
Bindlib.unbox (beta_expr () scope_body.scope_body_expr);
}) })
p.scopes; p.scopes;
} }
let rec peephole_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box = let rec peephole_expr (_ : unit) (e : expr Pos.marked) :
expr Pos.marked Bindlib.box =
let default_mark e' = Pos.mark (Pos.get_position e) e' in let default_mark e' = Pos.mark (Pos.get_position e) e' in
match Pos.unmark e with match Pos.unmark e with
| EIfThenElse (e1, e2, e3) -> ( | EIfThenElse (e1, e2, e3) -> (
let+ e1 = peephole_expr () e1 and+ e2 = peephole_expr () e2 and+ e3 = peephole_expr () e3 in let+ e1 = peephole_expr () e1
and+ e2 = peephole_expr () e2
and+ e3 = peephole_expr () e3 in
match Pos.unmark e1 with match Pos.unmark e1 with
| ELit (LBool true) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) -> e2 | ELit (LBool true)
| ELit (LBool false) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) -> e3 | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) ->
e2
| ELit (LBool false)
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) ->
e3
| _ -> default_mark @@ EIfThenElse (e1, e2, e3)) | _ -> default_mark @@ EIfThenElse (e1, e2, e3))
| _ -> visitor_map peephole_expr () e | _ -> visitor_map peephole_expr () e
@ -140,9 +159,11 @@ let peephole_optimizations (p : program) : program =
(fun scope_body -> (fun scope_body ->
{ {
scope_body with scope_body with
scope_body_expr = Bindlib.unbox (peephole_expr () scope_body.scope_body_expr); scope_body_expr =
Bindlib.unbox (peephole_expr () scope_body.scope_body_expr);
}) })
p.scopes; p.scopes;
} }
let optimize_program (p : program) : program = p |> iota_optimizations |> peephole_optimizations let optimize_program (p : program) : program =
p |> iota_optimizations |> peephole_optimizations

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Ast open Ast

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -38,12 +40,18 @@ let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
(Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) (Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i)
| LMoney e -> ( | LMoney e -> (
match !Utils.Cli.locale_lang with match !Utils.Cli.locale_lang with
| En -> Dcalc.Print.format_lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e)) | En ->
| Fr -> Dcalc.Print.format_lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e)) Dcalc.Print.format_lit_style fmt
(Format.asprintf "$%s" (Runtime.money_to_string e))
| Fr ->
Dcalc.Print.format_lit_style fmt
(Format.asprintf "%s €" (Runtime.money_to_string e))
| Pl -> | Pl ->
Dcalc.Print.format_lit_style fmt (Format.asprintf "%s PLN" (Runtime.money_to_string e))) Dcalc.Print.format_lit_style fmt
(Format.asprintf "%s PLN" (Runtime.money_to_string e)))
| LDate d -> Dcalc.Print.format_lit_style fmt (Runtime.date_to_string d) | LDate d -> Dcalc.Print.format_lit_style fmt (Runtime.date_to_string d)
| LDuration d -> Dcalc.Print.format_lit_style fmt (Runtime.duration_to_string d) | LDuration d ->
Dcalc.Print.format_lit_style fmt (Runtime.duration_to_string d)
let format_exception (fmt : Format.formatter) (exn : except) : unit = let format_exception (fmt : Format.formatter) (exn : except) : unit =
Dcalc.Print.format_operator fmt Dcalc.Print.format_operator fmt
@ -65,12 +73,16 @@ let needs_parens (e : expr Pos.marked) : bool =
let format_var (fmt : Format.formatter) (v : Var.t) : unit = let format_var (fmt : Format.formatter) (v : Var.t) : unit =
Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v) Format.fprintf fmt "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v)
let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) let rec format_expr
(ctx : Dcalc.Ast.decl_ctx)
?(debug : bool = false)
(fmt : Format.formatter)
(e : expr Pos.marked) : unit = (e : expr Pos.marked) : unit =
let format_expr = format_expr ctx ~debug in let format_expr = format_expr ctx ~debug in
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
if needs_parens e then if needs_parens e then
Format.fprintf fmt "%a%a%a" format_punctuation "(" format_expr e format_punctuation ")" Format.fprintf fmt "%a%a%a" format_punctuation "(" format_expr e
format_punctuation ")"
else Format.fprintf fmt "%a" format_expr e else Format.fprintf fmt "%a" format_expr e
in in
match Pos.unmark e with match Pos.unmark e with
@ -82,15 +94,16 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Fo
(fun fmt e -> Format.fprintf fmt "%a" format_expr e)) (fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es format_punctuation ")" es format_punctuation ")"
| ETuple (es, Some s) -> | ETuple (es, Some s) ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s format_punctuation Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s
"{" format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\"" Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\""
Dcalc.Ast.StructFieldName.format_t struct_field format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t struct_field
format_punctuation ":" format_expr e)) format_punctuation "\"" format_punctuation ":" format_expr e))
(List.combine es (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) (List.combine es
(List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs)))
format_punctuation "}" format_punctuation "}"
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" format_punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" format_punctuation "["
@ -100,10 +113,11 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Fo
es format_punctuation "]" es format_punctuation "]"
| ETupleAccess (e1, n, s, _ts) -> ( | ETupleAccess (e1, n, s, _ts) -> (
match s with match s with
| None -> Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n | None ->
Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n
| Some s -> | Some s ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "." format_punctuation Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "."
"\"" Dcalc.Ast.StructFieldName.format_t format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t
(fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) (fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n))
format_punctuation "\"") format_punctuation "\"")
| EInj (e, n, en, _ts) -> | EInj (e, n, en, _ts) ->
@ -111,26 +125,31 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Fo
(fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n)) (fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n))
format_expr e format_expr e
| EMatch (e, es, e_name) -> | EMatch (e, es, e_name) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" format_keyword "match" format_expr e Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" format_keyword "match"
format_keyword "with" format_expr e format_keyword "with"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (e, c) -> (fun fmt (e, c) ->
Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" format_punctuation "|" Format.fprintf fmt "@[<hov 2>%a %a%a@ %a@]" format_punctuation "|"
Dcalc.Print.format_enum_constructor c format_punctuation ":" format_expr e)) Dcalc.Print.format_enum_constructor c format_punctuation ":"
(List.combine es (List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums))) format_expr e))
(List.combine es
(List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums)))
| ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e) | ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e)
| EApp ((EAbs ((binder, _), taus), _), args) -> | EApp ((EAbs ((binder, _), taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in let xs_tau_arg =
List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args
in
Format.fprintf fmt "%a%a" Format.fprintf fmt "%a%a"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "") ~pp_sep:(fun fmt () -> Format.fprintf fmt "")
(fun fmt (x, tau, arg) -> (fun fmt (x, tau, arg) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" format_keyword "let" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n"
format_var x format_punctuation ":" (Dcalc.Print.format_typ ctx) tau format_keyword "let" format_var x format_punctuation ":"
format_punctuation "=" format_expr arg format_keyword "in")) (Dcalc.Print.format_typ ctx)
tau format_punctuation "=" format_expr arg format_keyword "in"))
xs_tau_arg format_expr body xs_tau_arg format_expr body
| EAbs ((binder, _), taus) -> | EAbs ((binder, _), taus) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
@ -139,39 +158,57 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Fo
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (x, tau) -> (fun fmt (x, tau) ->
Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x format_punctuation Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var
":" (Dcalc.Print.format_typ ctx) tau format_punctuation ")")) x format_punctuation ":"
(Dcalc.Print.format_typ ctx)
tau format_punctuation ")"))
xs_tau format_punctuation "" format_expr body xs_tau format_punctuation "" format_expr body
| EApp ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [ arg1; arg2 ]) -> | EApp
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop (op, Pos.no_pos) ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _),
format_with_parens arg1 format_with_parens arg2 [ arg1; arg2 ] ) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop
(op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
(op, Pos.no_pos) format_with_parens arg2 Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1 Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) -> | EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos) Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_unop
format_with_parens arg1 (op, Pos.no_pos) format_with_parens arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args args
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if" format_expr e1 Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if"
format_keyword "then" format_expr e2 format_keyword "else" format_expr e3 format_expr e1 format_keyword "then" format_expr e2 format_keyword
| EOp (Ternop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) "else" format_expr e3
| EOp (Binop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) | EOp (Ternop op) ->
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos)
| EOp (Binop op) ->
Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos)
| EOp (Unop op) ->
Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos)
| ECatch (e1, exn, e2) -> | ECatch (e1, exn, e2) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" format_keyword "try" format_with_parens Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a ->@ %a@]" format_keyword "try"
e1 format_keyword "with" format_exception exn format_with_parens e2 format_with_parens e1 format_keyword "with" format_exception exn
| ERaise exn -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_keyword "raise" format_exception exn format_with_parens e2
| ERaise exn ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_keyword "raise"
format_exception exn
| EAssert e' -> | EAssert e' ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" format_keyword "assert" format_punctuation "(" Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" format_keyword "assert"
format_expr e' format_punctuation ")" format_punctuation "(" format_expr e' format_punctuation ")"
let format_scope (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) let format_scope
(decl_ctx : Dcalc.Ast.decl_ctx)
?(debug : bool = false)
(fmt : Format.formatter)
(body : scope_body) : unit = (body : scope_body) : unit =
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" format_keyword "let" format_var body.scope_body_var Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" format_keyword "let" format_var
format_punctuation "=" (format_expr decl_ctx ~debug) body.scope_body_expr body.scope_body_var format_punctuation "="
(format_expr decl_ctx ~debug)
body.scope_body_expr

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -17,18 +19,24 @@ open Utils
(** {1 Helpers} *) (** {1 Helpers} *)
val is_uppercase : CamomileLibraryDefault.Camomile.UChar.t -> bool val is_uppercase : CamomileLibraryDefault.Camomile.UChar.t -> bool
val begins_with_uppercase : string -> bool val begins_with_uppercase : string -> bool
(** {1 Formatters} *) (** {1 Formatters} *)
val format_lit : Format.formatter -> Ast.lit Pos.marked -> unit val format_lit : Format.formatter -> Ast.lit Pos.marked -> unit
val format_var : Format.formatter -> Ast.Var.t -> unit val format_var : Format.formatter -> Ast.Var.t -> unit
val format_exception : Format.formatter -> Ast.except -> unit val format_exception : Format.formatter -> Ast.except -> unit
val format_expr : val format_expr :
Dcalc.Ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.expr Pos.marked -> unit Dcalc.Ast.decl_ctx ->
?debug:bool ->
Format.formatter ->
Ast.expr Pos.marked ->
unit
val format_scope : Dcalc.Ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.scope_body -> unit val format_scope :
Dcalc.Ast.decl_ctx ->
?debug:bool ->
Format.formatter ->
Ast.scope_body ->
unit

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -23,20 +25,25 @@ let find_struct (s : D.StructName.t) (ctx : D.decl_ctx) :
with Not_found -> with Not_found ->
let s_name, pos = D.StructName.get_info s in let s_name, pos = D.StructName.get_info s in
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: Structure %s was not found in the current environment." s_name "Internal Error: Structure %s was not found in the current environment."
s_name
let find_enum (en : D.EnumName.t) (ctx : D.decl_ctx) : (D.EnumConstructor.t * D.typ Pos.marked) list let find_enum (en : D.EnumName.t) (ctx : D.decl_ctx) :
= (D.EnumConstructor.t * D.typ Pos.marked) list =
try D.EnumMap.find en ctx.D.ctx_enums try D.EnumMap.find en ctx.D.ctx_enums
with Not_found -> with Not_found ->
let en_name, pos = D.EnumName.get_info en in let en_name, pos = D.EnumName.get_info en in
Errors.raise_spanned_error pos Errors.raise_spanned_error pos
"Internal Error: Enumeration %s was not found in the current environment." en_name "Internal Error: Enumeration %s was not found in the current environment."
en_name
let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit = let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
match Pos.unmark l with match Pos.unmark l with
| LBool b -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LBool b) l) | LBool b ->
| LInt i -> Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i) Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LBool b) l)
| LInt i ->
Format.fprintf fmt "integer_of_string@ \"%s\""
(Runtime.integer_to_string i)
| LUnit -> Dcalc.Print.format_lit fmt (Pos.same_pos_as Dcalc.Ast.LUnit l) | LUnit -> Dcalc.Print.format_lit fmt (Pos.same_pos_as Dcalc.Ast.LUnit l)
| LRat i -> | LRat i ->
Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit
@ -55,9 +62,15 @@ let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit =
let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) = let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(match k with KInt -> "!" | KRat -> "&" | KMoney -> "$" | KDate -> "@" | KDuration -> "^") (match k with
| KInt -> "!"
| KRat -> "&"
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) : unit = let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) :
unit =
match Pos.unmark op with match Pos.unmark op with
| Add k -> Format.fprintf fmt "+%a" format_op_kind k | Add k -> Format.fprintf fmt "+%a" format_op_kind k
| Sub k -> Format.fprintf fmt "-%a" format_op_kind k | Sub k -> Format.fprintf fmt "-%a" format_op_kind k
@ -75,14 +88,17 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) : un
| Map -> Format.fprintf fmt "Array.map" | Map -> Format.fprintf fmt "Array.map"
| Filter -> Format.fprintf fmt "array_filter" | Filter -> Format.fprintf fmt "array_filter"
let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Pos.marked) : unit = let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Pos.marked) :
unit =
match Pos.unmark op with Fold -> Format.fprintf fmt "Array.fold_left" match Pos.unmark op with Fold -> Format.fprintf fmt "Array.fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) : unit = let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit =
Format.fprintf fmt "@[<hov 2>[%a]@]" Format.fprintf fmt "@[<hov 2>[%a]@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt info -> Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info)) (fun fmt info ->
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info))
uids uids
let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -92,13 +108,15 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(fun fmt info -> Format.fprintf fmt "\"%s\"" info)) (fun fmt info -> Format.fprintf fmt "\"%s\"" info))
uids uids
let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit = let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit
=
match Pos.unmark op with match Pos.unmark op with
| Minus k -> Format.fprintf fmt "~-%a" format_op_kind k | Minus k -> Format.fprintf fmt "~-%a" format_op_kind k
| Not -> Format.fprintf fmt "%s" "not" | Not -> Format.fprintf fmt "%s" "not"
| Log (_entry, _infos) -> | Log (_entry, _infos) ->
Errors.raise_spanned_error (Pos.get_position op) Errors.raise_spanned_error (Pos.get_position op)
"Internal error: a log operator has not been caught by the expression match" "Internal error: a log operator has not been caught by the expression \
match"
| Length -> Format.fprintf fmt "%s" "array_length" | Length -> Format.fprintf fmt "%s" "array_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer" | IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date" | GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
@ -108,36 +126,49 @@ let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit
let avoid_keywords (s : string) : string = let avoid_keywords (s : string) : string =
if if
match s with match s with
(* list taken from http://caml.inria.fr/pub/docs/manual-ocaml/lex.html#sss:keywords *) (* list taken from
| "and" | "as" | "assert" | "asr" | "begin" | "class" | "constraint" | "do" | "done" | "downto" http://caml.inria.fr/pub/docs/manual-ocaml/lex.html#sss:keywords *)
| "else" | "end" | "exception" | "external" | "false" | "for" | "fun" | "function" | "functor" | "and" | "as" | "assert" | "asr" | "begin" | "class" | "constraint" | "do"
| "if" | "in" | "include" | "inherit" | "initializer" | "land" | "lazy" | "let" | "lor" | "lsl" | "done" | "downto" | "else" | "end" | "exception" | "external" | "false"
| "lsr" | "lxor" | "match" | "method" | "mod" | "module" | "mutable" | "new" | "nonrec" | "for" | "fun" | "function" | "functor" | "if" | "in" | "include"
| "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct" | "then" | "to" | "inherit" | "initializer" | "land" | "lazy" | "let" | "lor" | "lsl"
| "true" | "try" | "type" | "val" | "virtual" | "when" | "while" | "with" -> | "lsr" | "lxor" | "match" | "method" | "mod" | "module" | "mutable" | "new"
| "nonrec" | "object" | "of" | "open" | "or" | "private" | "rec" | "sig"
| "struct" | "then" | "to" | "true" | "try" | "type" | "val" | "virtual"
| "when" | "while" | "with" ->
true true
| _ -> false | _ -> false
then s ^ "_" then s ^ "_"
else s else s
let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : unit = let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) :
unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_lowercase (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v)))) (to_lowercase
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v))))
let format_struct_field_name (fmt : Format.formatter) (v : Dcalc.Ast.StructFieldName.t) : unit = let format_struct_field_name
(fmt : Format.formatter) (v : Dcalc.Ast.StructFieldName.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit = let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit
=
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (to_lowercase (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) (avoid_keywords
(to_lowercase
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v))))
let format_enum_cons_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructor.t) : unit = let format_enum_cons_name
(fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructor.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v)))
let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Pos.marked) : unit = let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Pos.marked) :
unit =
match Pos.unmark ty with match Pos.unmark ty with
| D.TLit D.TUnit -> Format.fprintf fmt "embed_unit" | D.TLit D.TUnit -> Format.fprintf fmt "embed_unit"
| D.TLit D.TBool -> Format.fprintf fmt "embed_bool" | D.TLit D.TBool -> Format.fprintf fmt "embed_bool"
@ -146,7 +177,8 @@ let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Pos.marked) : un
| D.TLit D.TMoney -> Format.fprintf fmt "embed_money" | D.TLit D.TMoney -> Format.fprintf fmt "embed_money"
| D.TLit D.TDate -> Format.fprintf fmt "embed_date" | D.TLit D.TDate -> Format.fprintf fmt "embed_date"
| D.TLit D.TDuration -> Format.fprintf fmt "embed_duration" | D.TLit D.TDuration -> Format.fprintf fmt "embed_duration"
| D.TTuple (_, Some s_name) -> Format.fprintf fmt "embed_%a" format_struct_name s_name | D.TTuple (_, Some s_name) ->
Format.fprintf fmt "embed_%a" format_struct_name s_name
| D.TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name | D.TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name
| D.TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty | D.TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty
| _ -> Format.fprintf fmt "unembeddable" | _ -> Format.fprintf fmt "unembeddable"
@ -154,9 +186,11 @@ let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Pos.marked) : un
let typ_needs_parens (e : Dcalc.Ast.typ Pos.marked) : bool = let typ_needs_parens (e : Dcalc.Ast.typ Pos.marked) : bool =
match Pos.unmark e with TArrow _ | TArray _ -> true | _ -> false match Pos.unmark e with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : unit = let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) :
unit =
let format_typ = format_typ in let format_typ = format_typ in
let format_typ_with_parens (fmt : Format.formatter) (t : Dcalc.Ast.typ Pos.marked) = let format_typ_with_parens
(fmt : Format.formatter) (t : Dcalc.Ast.typ Pos.marked) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
@ -170,20 +204,25 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : u
ts ts
| TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s | TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s
| TEnum ([ t ], e) when D.EnumName.compare e Ast.option_enum = 0 -> | TEnum ([ t ], e) when D.EnumName.compare e Ast.option_enum = 0 ->
Format.fprintf fmt "@[<hov 2>(%a)@] %a" format_typ_with_parens t format_enum_name e Format.fprintf fmt "@[<hov 2>(%a)@] %a" format_typ_with_parens t
format_enum_name e
| TEnum (_, e) when D.EnumName.compare e Ast.option_enum = 0 -> | TEnum (_, e) when D.EnumName.compare e Ast.option_enum = 0 ->
Errors.raise_spanned_error (Pos.get_position typ) Errors.raise_spanned_error (Pos.get_position typ)
"Internal Error: found an typing parameter for an eoption type of the wrong lenght." "Internal Error: found an typing parameter for an eoption type of the \
wrong lenght."
| TEnum (_ts, e) -> Format.fprintf fmt "%a" format_enum_name e | TEnum (_ts, e) -> Format.fprintf fmt "%a" format_enum_name e
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1 format_typ_with_parens t2 Format.fprintf fmt "@[<hov 2>%a ->@ %a@]" format_typ_with_parens t1
format_typ_with_parens t2
| TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "_" | TAny -> Format.fprintf fmt "_"
let format_var (fmt : Format.formatter) (v : Var.t) : unit = let format_var (fmt : Format.formatter) (v : Var.t) : unit =
let lowercase_name = to_lowercase (to_ascii (Bindlib.name_of v)) in let lowercase_name = to_lowercase (to_ascii (Bindlib.name_of v)) in
let lowercase_name = let lowercase_name =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") lowercase_name Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.")
~subst:(fun _ -> "_dot_")
lowercase_name
in in
let lowercase_name = avoid_keywords (to_ascii lowercase_name) in let lowercase_name = avoid_keywords (to_ascii lowercase_name) in
if if
@ -195,7 +234,10 @@ let format_var (fmt : Format.formatter) (v : Var.t) : unit =
let needs_parens (e : expr Pos.marked) : bool = let needs_parens (e : expr Pos.marked) : bool =
match Pos.unmark e with match Pos.unmark e with
| EApp ((EAbs (_, _), _), _) | ELit (LBool _ | LUnit) | EVar _ | ETuple _ | EOp _ -> false | EApp ((EAbs (_, _), _), _)
| ELit (LBool _ | LUnit)
| EVar _ | ETuple _ | EOp _ ->
false
| _ -> true | _ -> true
let format_exception (fmt : Format.formatter) (exc : except Pos.marked) : unit = let format_exception (fmt : Format.formatter) (exc : except Pos.marked) : unit =
@ -206,13 +248,15 @@ let format_exception (fmt : Format.formatter) (exc : except Pos.marked) : unit =
| NoValueProvided -> | NoValueProvided ->
let pos = Pos.get_position exc in let pos = Pos.get_position exc in
Format.fprintf fmt Format.fprintf fmt
"(NoValueProvided@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ \ "(NoValueProvided@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \
end_line=%d; end_column=%d;@ law_headings=%a}@])" start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@])"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)
let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : unit let rec format_expr
= (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) :
unit =
let format_expr = format_expr ctx in let format_expr = format_expr ctx in
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
if needs_parens e then Format.fprintf fmt "(%a)" format_expr e if needs_parens e then Format.fprintf fmt "(%a)" format_expr e
@ -233,8 +277,8 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "@[<hov 2>%a =@ %a@]" format_struct_field_name struct_field Format.fprintf fmt "@[<hov 2>%a =@ %a@]" format_struct_field_name
format_with_parens e)) struct_field format_with_parens e))
(List.combine es (List.map fst (find_struct s ctx))) (List.combine es (List.map fst (find_struct s ctx)))
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>[|%a|]@]" Format.fprintf fmt "@[<hov 2>[|%a|]@]"
@ -248,11 +292,13 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
Format.fprintf fmt "let@ %a@ = %a@ in@ x" Format.fprintf fmt "let@ %a@ = %a@ in@ x"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt i -> Format.fprintf fmt "%s" (if i = n then "x" else "_"))) (fun fmt i ->
Format.fprintf fmt "%s" (if i = n then "x" else "_")))
(List.mapi (fun i _ -> i) ts) (List.mapi (fun i _ -> i) ts)
format_with_parens e1 format_with_parens e1
| Some s -> | Some s ->
Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name Format.fprintf fmt "%a.%a" format_with_parens e1
format_struct_field_name
(fst (List.nth (find_struct s ctx) n))) (fst (List.nth (find_struct s ctx) n)))
| EInj (e, n, en, _ts) -> | EInj (e, n, en, _ts) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_enum_cons_name Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_enum_cons_name
@ -281,13 +327,15 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
| EApp ((EAbs ((binder, _), taus), _), args) -> | EApp ((EAbs ((binder, _), taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in let xs_tau_arg =
List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args
in
Format.fprintf fmt "(%a%a)" Format.fprintf fmt "(%a%a)"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "") ~pp_sep:(fun fmt () -> Format.fprintf fmt "")
(fun fmt (x, tau, arg) -> (fun fmt (x, tau, arg) ->
Format.fprintf fmt "@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n" format_var x format_typ Format.fprintf fmt "@[<hov 2>let@ %a@ :@ %a@ =@ %a@]@ in@\n"
tau format_with_parens arg)) format_var x format_typ tau format_with_parens arg))
xs_tau_arg format_with_parens body xs_tau_arg format_with_parens body
| EAbs ((binder, _), taus) -> | EAbs ((binder, _), taus) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
@ -296,114 +344,145 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (x, tau) -> (fun fmt (x, tau) ->
Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau)) Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ
tau))
xs_tau format_expr body xs_tau format_expr body
| EApp ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [ arg1; arg2 ]) -> | EApp
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos) format_with_parens ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _),
arg1 format_with_parens arg2 [ arg1; arg2 ] ) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos)
format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 format_binop Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
(op, Pos.no_pos) format_with_parens arg2 format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ]) | EApp
((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "(log_begin_call@ %a@ %a@ %a)" format_uid_list info format_with_parens f Format.fprintf fmt "(log_begin_call@ %a@ %a@ %a)" format_uid_list info
format_with_parens arg format_with_parens f format_with_parens arg
| EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ]) when !Cli.trace_flag -> | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ])
Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list info when !Cli.trace_flag ->
typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1 Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)"
| EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ]) when !Cli.trace_flag -> format_uid_list info typ_embedding_name (tau, Pos.no_pos)
Format.fprintf fmt
"(log_decision_taken@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ \
end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos)
format_with_parens arg1 format_with_parens arg1
| EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ]) when !Cli.trace_flag -> | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ])
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info format_with_parens arg1 when !Cli.trace_flag ->
| EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) -> Format.fprintf fmt "%a" format_with_parens arg1 Format.fprintf fmt
"(log_decision_taken@ @[<hov 2>{filename = \"%s\";@ start_line=%d;@ \
start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ \
%a)"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) format_with_parens arg1
| EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ])
when !Cli.trace_flag ->
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info
format_with_parens arg1
| EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) -> | EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos) format_with_parens arg1 Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_with_parens f
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args args
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
Format.fprintf fmt "@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]@]" Format.fprintf fmt
"@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov \
2>%a@]@]"
format_with_parens e1 format_with_parens e2 format_with_parens e3 format_with_parens e1 format_with_parens e2 format_with_parens e3
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
| EAssert e' -> | EAssert e' ->
Format.fprintf fmt "@[<hov 2>if @ %a@ then@ ()@ else@ raise AssertionFailed@]" Format.fprintf fmt
"@[<hov 2>if @ %a@ then@ ()@ else@ raise AssertionFailed@]"
format_with_parens e' format_with_parens e'
| ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception (exc, Pos.get_position e) | ERaise exc ->
Format.fprintf fmt "raise@ %a" format_exception (exc, Pos.get_position e)
| ECatch (e1, exc, e2) -> | ECatch (e1, exc, e2) ->
Format.fprintf fmt "@[<hov 2>try@ %a@ with@ %a@ ->@ %a@]" format_with_parens e1 Format.fprintf fmt "@[<hov 2>try@ %a@ with@ %a@ ->@ %a@]"
format_exception format_with_parens e1 format_exception
(exc, Pos.get_position e) (exc, Pos.get_position e)
format_with_parens e2 format_with_parens e2
let format_struct_embedding (fmt : Format.formatter) let format_struct_embedding
((struct_name, struct_fields) : D.StructName.t * (D.StructFieldName.t * D.typ Pos.marked) list) (fmt : Format.formatter)
= ((struct_name, struct_fields) :
D.StructName.t * (D.StructFieldName.t * D.typ Pos.marked) list) =
if List.length struct_fields = 0 then if List.length struct_fields = 0 then
Format.fprintf fmt "let embed_%a (_: %a) : runtime_value = Unit@\n@\n" format_struct_name Format.fprintf fmt "let embed_%a (_: %a) : runtime_value = Unit@\n@\n"
struct_name format_struct_name struct_name format_struct_name struct_name format_struct_name struct_name
else else
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>let embed_%a (x: %a) : runtime_value =@ Struct([\"%a\"],@ @[<hov 2>[%a]@])@]@\n@\n" "@[<hov 2>let embed_%a (x: %a) : runtime_value =@ Struct([\"%a\"],@ \
format_struct_name struct_name format_struct_name struct_name D.StructName.format_t @[<hov 2>[%a]@])@]@\n\
struct_name @\n"
format_struct_name struct_name format_struct_name struct_name
D.StructName.format_t struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n")
(fun _fmt (struct_field, struct_field_type) -> (fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" D.StructFieldName.format_t struct_field Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" D.StructFieldName.format_t
typ_embedding_name struct_field_type format_struct_field_name struct_field)) struct_field typ_embedding_name struct_field_type
format_struct_field_name struct_field))
struct_fields struct_fields
let format_enum_embedding (fmt : Format.formatter) let format_enum_embedding
((enum_name, enum_cases) : D.EnumName.t * (D.EnumConstructor.t * D.typ Pos.marked) list) = (fmt : Format.formatter)
((enum_name, enum_cases) :
D.EnumName.t * (D.EnumConstructor.t * D.typ Pos.marked) list) =
if List.length enum_cases = 0 then if List.length enum_cases = 0 then
Format.fprintf fmt "let embed_%a (_: %a) : runtime_value = Unit@\n@\n" format_enum_name Format.fprintf fmt "let embed_%a (_: %a) : runtime_value = Unit@\n@\n"
enum_name format_enum_name enum_name format_enum_name enum_name format_enum_name enum_name
else else
Format.fprintf fmt Format.fprintf fmt
"@[<hov 2>let embed_%a (x: %a) : runtime_value =@ Enum([\"%a\"],@ @[<hov 2>match x with@ \ "@[<hov 2>let embed_%a (x: %a) : runtime_value =@ Enum([\"%a\"],@ @[<hov \
%a@])@]@\n\ 2>match x with@ %a@])@]@\n\
@\n" @\n"
format_enum_name enum_name format_enum_name enum_name D.EnumName.format_t enum_name format_enum_name enum_name format_enum_name enum_name D.EnumName.format_t
enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (enum_cons, enum_cons_type) -> (fun _fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]" format_enum_cons_name enum_cons Format.fprintf fmt "@[<hov 2>| %a x ->@ (\"%a\", %a x)@]"
D.EnumConstructor.format_t enum_cons typ_embedding_name enum_cons_type)) format_enum_cons_name enum_cons D.EnumConstructor.format_t
enum_cons typ_embedding_name enum_cons_type))
enum_cases enum_cases
let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter)
(ctx : D.decl_ctx) : unit = (ctx : D.decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
if List.length struct_fields = 0 then if List.length struct_fields = 0 then
Format.fprintf fmt "type %a = unit@\n@\n" format_struct_name struct_name Format.fprintf fmt "type %a = unit@\n@\n" format_struct_name struct_name
else else
Format.fprintf fmt "type %a = {@\n@[<hov 2> %a@]@\n}@\n@\n" format_struct_name struct_name Format.fprintf fmt "type %a = {@\n@[<hov 2> %a@]@\n}@\n@\n"
format_struct_name struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (struct_field, struct_field_type) -> (fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "%a:@ %a;" format_struct_field_name struct_field format_typ Format.fprintf fmt "%a:@ %a;" format_struct_field_name struct_field
struct_field_type)) format_typ struct_field_type))
struct_fields; struct_fields;
if !Cli.trace_flag then format_struct_embedding fmt (struct_name, struct_fields) if !Cli.trace_flag then
format_struct_embedding fmt (struct_name, struct_fields)
in in
let format_enum_decl fmt (enum_name, enum_cons) = let format_enum_decl fmt (enum_name, enum_cons) =
if List.length enum_cons = 0 then if List.length enum_cons = 0 then
Format.fprintf fmt "type %a = unit@\n@\n" format_enum_name enum_name Format.fprintf fmt "type %a = unit@\n@\n" format_enum_name enum_name
else else
Format.fprintf fmt "type %a =@\n@[<hov 2> %a@]@\n@\n" format_enum_name enum_name Format.fprintf fmt "type %a =@\n@[<hov 2> %a@]@\n@\n" format_enum_name
enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (enum_cons, enum_cons_type) -> (fun _fmt (enum_cons, enum_cons_type) ->
Format.fprintf fmt "| %a@ of@ %a" format_enum_cons_name enum_cons format_typ Format.fprintf fmt "| %a@ of@ %a" format_enum_cons_name enum_cons
enum_cons_type)) format_typ enum_cons_type))
enum_cons; enum_cons;
if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons) if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons)
in in
@ -419,7 +498,9 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(Dcalc.Ast.StructMap.bindings (Dcalc.Ast.StructMap.bindings
(Dcalc.Ast.StructMap.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) (Dcalc.Ast.StructMap.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
in in
List.iter List.iter
(fun struct_or_enum -> (fun struct_or_enum ->
@ -430,7 +511,9 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
Format.fprintf fmt "%a@\n@\n" format_enum_decl (e, find_enum e ctx)) Format.fprintf fmt "%a@\n@\n" format_enum_decl (e, find_enum e ctx))
(type_ordering @ scope_structs) (type_ordering @ scope_structs)
let format_program (fmt : Format.formatter) (p : Ast.program) let format_program
(fmt : Format.formatter)
(p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit = (type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
Cli.style_flag := false; Cli.style_flag := false;
Format.fprintf fmt Format.fprintf fmt
@ -447,6 +530,6 @@ let format_program (fmt : Format.formatter) (p : Ast.program)
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
(fun fmt body -> (fun fmt body ->
Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var body.scope_body_var Format.fprintf fmt "@[<hov 2>let@ %a@ =@ %a@]" format_var
(format_expr p.decl_ctx) body.scope_body_expr)) body.scope_body_var (format_expr p.decl_ctx) body.scope_body_expr))
p.scopes p.scopes

View File

@ -1,18 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Formats a lambda calculus program into a valid OCaml program *) (** Formats a lambda calculus program into a valid OCaml program *)
val format_program : Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit val format_program :
Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit
(** Usage [format_program fmt p type_dependencies_ordering] *) (** Usage [format_program fmt p type_dependencies_ordering] *)

View File

@ -1,19 +1,22 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** This modules weaves the source code and the legislative text together into a document that law (** This modules weaves the source code and the legislative text together into a
professionals can understand. *) document that law professionals can understand. *)
open Utils open Utils
open Literate_common open Literate_common
@ -33,29 +36,40 @@ let pre_html (s : string) =
(** Raise an error if pygments cannot be found *) (** Raise an error if pygments cannot be found *)
let raise_failed_pygments (command : string) (error_code : int) : 'a = let raise_failed_pygments (command : string) (error_code : int) : 'a =
Errors.raise_error "Weaving to HTML failed: pygmentize command \"%s\" returned with error code %d" Errors.raise_error
"Weaving to HTML failed: pygmentize command \"%s\" returned with error \
code %d"
command error_code command error_code
(** Partial application allowing to remove first code lines of [<td class="code">] and (** Partial application allowing to remove first code lines of
[<td class="linenos">] generated HTML. Basically, remove all code block first lines. *) [<td class="code">] and [<td class="linenos">] generated HTML. Basically,
remove all code block first lines. *)
let remove_cb_first_lines : string -> string = let remove_cb_first_lines : string -> string =
R.substitute ~rex:(R.regexp "<pre>.*\n") ~subst:(function _ -> "<pre>\n") R.substitute ~rex:(R.regexp "<pre>.*\n") ~subst:(function _ -> "<pre>\n")
(** Partial application allowing to remove last code lines of [<td class="code">] and (** Partial application allowing to remove last code lines of
[<td class="linenos">] generated HTML. Basically, remove all code block last lines. *) [<td class="code">] and [<td class="linenos">] generated HTML. Basically,
remove all code block last lines. *)
let remove_cb_last_lines : string -> string = let remove_cb_last_lines : string -> string =
R.substitute ~rex:(R.regexp "<.*\n*</pre>") ~subst:(function _ -> "</pre>") R.substitute ~rex:(R.regexp "<.*\n*</pre>") ~subst:(function _ -> "</pre>")
(** Usage: [wrap_html source_files custom_pygments language fmt wrapped] (** Usage: [wrap_html source_files custom_pygments language fmt wrapped]
Prints an HTML complete page structure around the [wrapped] content. *) Prints an HTML complete page structure around the [wrapped] content. *)
let wrap_html (source_files : string list) (language : Cli.backend_lang) (fmt : Format.formatter) let wrap_html
(source_files : string list)
(language : Cli.backend_lang)
(fmt : Format.formatter)
(wrapped : Format.formatter -> unit) : unit = (wrapped : Format.formatter -> unit) : unit =
let pygments = "pygmentize" in let pygments = "pygmentize" in
let css_file = Filename.temp_file "catala_css_pygments" "" in let css_file = Filename.temp_file "catala_css_pygments" "" in
let pygments_args = [| "-f"; "html"; "-S"; "colorful"; "-a"; ".catala-code" |] in let pygments_args =
[| "-f"; "html"; "-S"; "colorful"; "-a"; ".catala-code" |]
in
let cmd = let cmd =
Format.sprintf "%s %s > %s" pygments (String.concat " " (Array.to_list pygments_args)) css_file Format.sprintf "%s %s > %s" pygments
(String.concat " " (Array.to_list pygments_args))
css_file
in in
let return_code = Sys.command cmd in let return_code = Sys.command cmd in
if return_code <> 0 then raise_failed_pygments cmd return_code; if return_code <> 0 then raise_failed_pygments cmd return_code;
@ -78,7 +92,9 @@ let wrap_html (source_files : string list) (language : Cli.backend_lang) (fmt :
<ul>\n\ <ul>\n\
%s\n\ %s\n\
</ul>\n" </ul>\n"
css_as_string (literal_title language) (literal_generated_by language) Utils.Cli.version css_as_string (literal_title language)
(literal_generated_by language)
Utils.Cli.version
(literal_source_files language) (literal_source_files language)
(String.concat "\n" (String.concat "\n"
(List.map (List.map
@ -86,8 +102,10 @@ let wrap_html (source_files : string list) (language : Cli.backend_lang) (fmt :
let mtime = (Unix.stat filename).Unix.st_mtime in let mtime = (Unix.stat filename).Unix.st_mtime in
let ltime = Unix.localtime mtime in let ltime = Unix.localtime mtime in
let ftime = let ftime =
Printf.sprintf "%d-%02d-%02d, %d:%02d" (1900 + ltime.Unix.tm_year) Printf.sprintf "%d-%02d-%02d, %d:%02d"
(ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday ltime.Unix.tm_hour ltime.Unix.tm_min (1900 + ltime.Unix.tm_year)
(ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday ltime.Unix.tm_hour
ltime.Unix.tm_min
in in
Printf.sprintf "<li><tt>%s</tt>, %s %s</li>" Printf.sprintf "<li><tt>%s</tt>, %s %s</li>"
(pre_html (Filename.basename filename)) (pre_html (Filename.basename filename))
@ -96,9 +114,12 @@ let wrap_html (source_files : string list) (language : Cli.backend_lang) (fmt :
source_files)); source_files));
wrapped fmt wrapped fmt
(** Performs syntax highlighting on a piece of code by using Pygments and the special Catala lexer. *) (** Performs syntax highlighting on a piece of code by using Pygments and the
let pygmentize_code (c : string Pos.marked) (language : C.backend_lang) : string = special Catala lexer. *)
C.debug_print "Pygmenting the code chunk %s" (Pos.to_string (Pos.get_position c)); let pygmentize_code (c : string Pos.marked) (language : C.backend_lang) : string
=
C.debug_print "Pygmenting the code chunk %s"
(Pos.to_string (Pos.get_position c));
let temp_file_in = Filename.temp_file "catala_html_pygments" "in" in let temp_file_in = Filename.temp_file "catala_html_pygments" "in" in
let temp_file_out = Filename.temp_file "catala_html_pygments" "out" in let temp_file_out = Filename.temp_file "catala_html_pygments" "out" in
let oc = open_out temp_file_in in let oc = open_out temp_file_in in
@ -122,48 +143,66 @@ let pygmentize_code (c : string Pos.marked) (language : C.backend_lang) : string
temp_file_in; temp_file_in;
|] |]
in in
let cmd = Format.asprintf "%s %s" pygments (String.concat " " (Array.to_list pygments_args)) in let cmd =
Format.asprintf "%s %s" pygments
(String.concat " " (Array.to_list pygments_args))
in
let return_code = Sys.command cmd in let return_code = Sys.command cmd in
if return_code <> 0 then raise_failed_pygments cmd return_code; if return_code <> 0 then raise_failed_pygments cmd return_code;
let oc = open_in temp_file_out in let oc = open_in temp_file_out in
let output = really_input_string oc (in_channel_length oc) in let output = really_input_string oc (in_channel_length oc) in
close_in oc; close_in oc;
(* Remove code blocks delimiters needed by [Pygments]. *) (* Remove code blocks delimiters needed by [Pygments]. *)
let trimmed_output = output |> remove_cb_first_lines |> remove_cb_last_lines in let trimmed_output =
output |> remove_cb_first_lines |> remove_cb_last_lines
in
trimmed_output trimmed_output
(** {1 Weaving} *) (** {1 Weaving} *)
let rec law_structure_to_html (language : C.backend_lang) (fmt : Format.formatter) let rec law_structure_to_html
(i : A.law_structure) : unit = (language : C.backend_lang) (fmt : Format.formatter) (i : A.law_structure) :
unit =
match i with match i with
| A.LawText t -> | A.LawText t ->
let t = pre_html t in let t = pre_html t in
if t = "" then () else Format.fprintf fmt "<p class='law-text'>%s</p>" t if t = "" then () else Format.fprintf fmt "<p class='law-text'>%s</p>" t
| A.CodeBlock (_, c, metadata) -> | A.CodeBlock (_, c, metadata) ->
Format.fprintf fmt "<div class='code-wrapper%s'>\n<div class='filename'>%s</div>\n%s\n</div>" Format.fprintf fmt
"<div class='code-wrapper%s'>\n\
<div class='filename'>%s</div>\n\
%s\n\
</div>"
(if metadata then " code-metadata" else "") (if metadata then " code-metadata" else "")
(Pos.get_file (Pos.get_position c)) (Pos.get_file (Pos.get_position c))
(pygmentize_code (Pos.same_pos_as ("```catala\n" ^ Pos.unmark c ^ "```") c) language) (pygmentize_code
(Pos.same_pos_as ("```catala\n" ^ Pos.unmark c ^ "```") c)
language)
| A.LawHeading (heading, children) -> | A.LawHeading (heading, children) ->
let h_number = heading.law_heading_precedence + 1 in let h_number = heading.law_heading_precedence + 1 in
Format.fprintf fmt "<h%d class='law-heading'><a href='%s'>%s</a></h%d>\n" h_number Format.fprintf fmt "<h%d class='law-heading'><a href='%s'>%s</a></h%d>\n"
h_number
(match (heading.law_heading_id, language) with (match (heading.law_heading_id, language) with
| Some id, Fr -> | Some id, Fr ->
let ltime = Unix.localtime (Unix.time ()) in let ltime = Unix.localtime (Unix.time ()) in
P.sprintf "https://legifrance.gouv.fr/codes/id/%s/%d-%02d-%02d" id P.sprintf "https://legifrance.gouv.fr/codes/id/%s/%d-%02d-%02d" id
(1900 + ltime.Unix.tm_year) (ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday (1900 + ltime.Unix.tm_year)
(ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday
| _ -> "#") | _ -> "#")
(pre_html (Pos.unmark heading.law_heading_name)) (pre_html (Pos.unmark heading.law_heading_name))
h_number; h_number;
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n")
(law_structure_to_html language) fmt children (law_structure_to_html language)
fmt children
| A.LawInclude _ -> () | A.LawInclude _ -> ()
(** {1 API} *) (** {1 API} *)
let ast_to_html (language : C.backend_lang) (fmt : Format.formatter) (program : A.program) : unit = let ast_to_html
(language : C.backend_lang) (fmt : Format.formatter) (program : A.program) :
unit =
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(law_structure_to_html language) fmt program.program_items (law_structure_to_html language)
fmt program.program_items

View File

@ -1,30 +1,37 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** This modules weaves the source code and the legislative text together into a document that law (** This modules weaves the source code and the legislative text together into a
professionals can understand. *) document that law professionals can understand. *)
open Utils open Utils
(** {1 Helpers} *) (** {1 Helpers} *)
val wrap_html : val wrap_html :
string list -> Cli.backend_lang -> Format.formatter -> (Format.formatter -> unit) -> unit string list ->
Cli.backend_lang ->
Format.formatter ->
(Format.formatter -> unit) ->
unit
(** Usage: [wrap_html source_files language fmt wrapped] (** Usage: [wrap_html source_files language fmt wrapped]
Prints an HTML complete page structure around the [wrapped] content. *) Prints an HTML complete page structure around the [wrapped] content. *)
(** {1 API} *) (** {1 API} *)
val ast_to_html : Cli.backend_lang -> Format.formatter -> Surface.Ast.program -> unit val ast_to_html :
Cli.backend_lang -> Format.formatter -> Surface.Ast.program -> unit

View File

@ -1,19 +1,22 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** This modules weaves the source code and the legislative text together into a document that law (** This modules weaves the source code and the legislative text together into a
professionals can understand. *) document that law professionals can understand. *)
open Utils open Utils
open Literate_common open Literate_common
@ -25,7 +28,9 @@ module C = Cli
(** Espaces various LaTeX-sensitive characters *) (** Espaces various LaTeX-sensitive characters *)
let pre_latexify (s : string) : string = let pre_latexify (s : string) : string =
let substitute s (old_s, new_s) = R.substitute ~rex:(R.regexp old_s) ~subst:(fun _ -> new_s) s in let substitute s (old_s, new_s) =
R.substitute ~rex:(R.regexp old_s) ~subst:(fun _ -> new_s) s
in
[ [
("\\$", "\\$"); ("\\$", "\\$");
("%", "\\%"); ("%", "\\%");
@ -39,7 +44,10 @@ let pre_latexify (s : string) : string =
(** Usage: [wrap_latex source_files custom_pygments language fmt wrapped] (** Usage: [wrap_latex source_files custom_pygments language fmt wrapped]
Prints an LaTeX complete documùent structure around the [wrapped] content. *) Prints an LaTeX complete documùent structure around the [wrapped] content. *)
let wrap_latex (source_files : string list) (language : C.backend_lang) (fmt : Format.formatter) let wrap_latex
(source_files : string list)
(language : C.backend_lang)
(fmt : Format.formatter)
(wrapped : Format.formatter -> unit) = (wrapped : Format.formatter -> unit) =
Format.fprintf fmt Format.fprintf fmt
"\\documentclass[%s, 11pt, a4paper]{article}\n\n\ "\\documentclass[%s, 11pt, a4paper]{article}\n\n\
@ -83,7 +91,9 @@ let wrap_latex (source_files : string list) (language : C.backend_lang) (fmt : F
\\begin{itemize}%s\\end{itemize}\n\n\ \\begin{itemize}%s\\end{itemize}\n\n\
\\[\\star\\star\\star\\]\\\\\n" \\[\\star\\star\\star\\]\\\\\n"
(match language with Fr -> "french" | En -> "english" | Pl -> "polish") (match language with Fr -> "french" | En -> "english" | Pl -> "polish")
(literal_title language) (literal_generated_by language) Utils.Cli.version (literal_title language)
(literal_generated_by language)
Utils.Cli.version
(literal_source_files language) (literal_source_files language)
(String.concat "," (String.concat ","
(List.map (List.map
@ -91,8 +101,10 @@ let wrap_latex (source_files : string list) (language : C.backend_lang) (fmt : F
let mtime = (Unix.stat filename).Unix.st_mtime in let mtime = (Unix.stat filename).Unix.st_mtime in
let ltime = Unix.localtime mtime in let ltime = Unix.localtime mtime in
let ftime = let ftime =
Printf.sprintf "%d-%02d-%02d, %d:%02d" (1900 + ltime.Unix.tm_year) Printf.sprintf "%d-%02d-%02d, %d:%02d"
(ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday ltime.Unix.tm_hour ltime.Unix.tm_min (1900 + ltime.Unix.tm_year)
(ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday ltime.Unix.tm_hour
ltime.Unix.tm_min
in in
Printf.sprintf "\\item\\texttt{%s}, %s %s" Printf.sprintf "\\item\\texttt{%s}, %s %s"
(pre_latexify (Filename.basename filename)) (pre_latexify (Filename.basename filename))
@ -104,8 +116,9 @@ let wrap_latex (source_files : string list) (language : C.backend_lang) (fmt : F
(** {1 Weaving} *) (** {1 Weaving} *)
let rec law_structure_to_latex (language : C.backend_lang) (fmt : Format.formatter) let rec law_structure_to_latex
(i : A.law_structure) : unit = (language : C.backend_lang) (fmt : Format.formatter) (i : A.law_structure) :
unit =
match i with match i with
| A.LawHeading (heading, children) -> | A.LawHeading (heading, children) ->
Format.fprintf fmt "\\%s*{%s}\n\n" Format.fprintf fmt "\\%s*{%s}\n\n"
@ -118,11 +131,16 @@ let rec law_structure_to_latex (language : C.backend_lang) (fmt : Format.formatt
(pre_latexify (Pos.unmark heading.law_heading_name)); (pre_latexify (Pos.unmark heading.law_heading_name));
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(law_structure_to_latex language) fmt children (law_structure_to_latex language)
fmt children
| A.LawInclude (A.PdfFile ((file, _), page)) -> | A.LawInclude (A.PdfFile ((file, _), page)) ->
let label = file ^ match page with None -> "" | Some p -> Format.sprintf "_page_%d," p in let label =
file
^ match page with None -> "" | Some p -> Format.sprintf "_page_%d," p
in
Format.fprintf fmt Format.fprintf fmt
"\\begin{center}\\textit{Annexe incluse, retranscrite page \\pageref{%s}}\\end{center} \ "\\begin{center}\\textit{Annexe incluse, retranscrite page \
\\pageref{%s}}\\end{center} \
\\begin{figure}[p]\\begin{center}\\includegraphics[%swidth=\\textwidth]{%s}\\label{%s}\\end{center}\\end{figure}" \\begin{figure}[p]\\begin{center}\\includegraphics[%swidth=\\textwidth]{%s}\\label{%s}\\end{center}\\end{figure}"
label label
(match page with None -> "" | Some p -> Format.sprintf "page=%d," p) (match page with None -> "" | Some p -> Format.sprintf "page=%d," p)
@ -137,16 +155,22 @@ let rec law_structure_to_latex (language : C.backend_lang) (fmt : Format.formatt
\\end{minted}" \\end{minted}"
(pre_latexify (Filename.basename (Pos.get_file (Pos.get_position c)))) (pre_latexify (Filename.basename (Pos.get_file (Pos.get_position c))))
(Pos.get_start_line (Pos.get_position c) - 1) (Pos.get_start_line (Pos.get_position c) - 1)
(get_language_extension language) (Pos.unmark c) (get_language_extension language)
(Pos.unmark c)
| A.CodeBlock (_, c, true) -> | A.CodeBlock (_, c, true) ->
let metadata_title = let metadata_title =
match language with Fr -> "Métadonnées" | En -> "Metadata" | Pl -> "Metadane" match language with
| Fr -> "Métadonnées"
| En -> "Metadata"
| Pl -> "Metadane"
in in
Format.fprintf fmt Format.fprintf fmt
"\\begin{tcolorbox}[colframe=OliveGreen, breakable, \ "\\begin{tcolorbox}[colframe=OliveGreen, breakable, \
title=\\textcolor{black}{\\texttt{%s}},title after \ title=\\textcolor{black}{\\texttt{%s}},title after \
break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\ break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after \
\\begin{minted}[numbersep=9mm, firstnumber=%d, label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\ skip=1em]\n\
\\begin{minted}[numbersep=9mm, firstnumber=%d, \
label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\
```catala\n\ ```catala\n\
%s```\n\ %s```\n\
\\end{minted}\n\ \\end{minted}\n\
@ -154,11 +178,15 @@ let rec law_structure_to_latex (language : C.backend_lang) (fmt : Format.formatt
metadata_title metadata_title metadata_title metadata_title
(Pos.get_start_line (Pos.get_position c) - 1) (Pos.get_start_line (Pos.get_position c) - 1)
(pre_latexify (Filename.basename (Pos.get_file (Pos.get_position c)))) (pre_latexify (Filename.basename (Pos.get_file (Pos.get_position c))))
(get_language_extension language) (Pos.unmark c) (get_language_extension language)
(Pos.unmark c)
(** {1 API} *) (** {1 API} *)
let ast_to_latex (language : C.backend_lang) (fmt : Format.formatter) (program : A.program) : unit = let ast_to_latex
(language : C.backend_lang) (fmt : Format.formatter) (program : A.program) :
unit =
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
(law_structure_to_latex language) fmt program.program_items (law_structure_to_latex language)
fmt program.program_items

View File

@ -1,30 +1,37 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** This modules weaves the source code and the legislative text together into a document that law (** This modules weaves the source code and the legislative text together into a
professionals can understand. *) document that law professionals can understand. *)
open Utils open Utils
(** {1 Helpers} *) (** {1 Helpers} *)
val wrap_latex : val wrap_latex :
string list -> Cli.backend_lang -> Format.formatter -> (Format.formatter -> unit) -> unit string list ->
Cli.backend_lang ->
Format.formatter ->
(Format.formatter -> unit) ->
unit
(** Usage: [wrap_latex source_files language fmt wrapped] (** Usage: [wrap_latex source_files language fmt wrapped]
Prints an LaTeX complete documùent structure around the [wrapped] content. *) Prints an LaTeX complete documùent structure around the [wrapped] content. *)
(** {1 API} *) (** {1 API} *)
val ast_to_latex : Cli.backend_lang -> Format.formatter -> Surface.Ast.program -> unit val ast_to_latex :
Cli.backend_lang -> Format.formatter -> Surface.Ast.program -> unit

View File

@ -1,14 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Emile Rolley <emile.rolley@tuta.io> and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Emile Rolley <emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -34,4 +37,7 @@ let literal_last_modification = function
| Fr -> "dernière modification le" | Fr -> "dernière modification le"
| Pl -> "ostatnia modyfikacja" | Pl -> "ostatnia modyfikacja"
let get_language_extension = function Fr -> "catala_fr" | En -> "catala_en" | Pl -> "catala_pl" let get_language_extension = function
| Fr -> "catala_fr"
| En -> "catala_en"
| Pl -> "catala_pl"

View File

@ -1,29 +1,37 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Emile Rolley <emile.rolley@tuta.io> and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Emile Rolley <emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
val literal_title : Cli.backend_lang -> string val literal_title : Cli.backend_lang -> string
(** Return the title traduction according the given {!type: Utils.Cli.backend_lang}. *) (** Return the title traduction according the given {!type:
Utils.Cli.backend_lang}. *)
val literal_generated_by : Cli.backend_lang -> string val literal_generated_by : Cli.backend_lang -> string
(** Return the 'generated by' traduction according the given {!type: Utils.Cli.backend_lang}. *) (** Return the 'generated by' traduction according the given {!type:
Utils.Cli.backend_lang}. *)
val literal_source_files : Cli.backend_lang -> string val literal_source_files : Cli.backend_lang -> string
(** Return the 'source files weaved' traduction according the given {!type: Utils.Cli.backend_lang}. *) (** Return the 'source files weaved' traduction according the given {!type:
Utils.Cli.backend_lang}. *)
val literal_last_modification : Cli.backend_lang -> string val literal_last_modification : Cli.backend_lang -> string
(** Return the 'last modification' traduction according the given {!type: Utils.Cli.backend_lang}. *) (** Return the 'last modification' traduction according the given {!type:
Utils.Cli.backend_lang}. *)
val get_language_extension : Cli.backend_lang -> string val get_language_extension : Cli.backend_lang -> string
(** Return the file extension corresponding to the given {!type: Utils.Cli.backend_lang}. *) (** Return the file extension corresponding to the given {!type:
Utils.Cli.backend_lang}. *)

View File

@ -1,25 +1,23 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
type money = Z.t type money = Z.t
type integer = Z.t type integer = Z.t
type decimal = Q.t type decimal = Q.t
type date = CalendarLib.Date.t type date = CalendarLib.Date.t
type duration = CalendarLib.Date.Period.t type duration = CalendarLib.Date.Period.t
type source_position = { type source_position = {
@ -34,17 +32,11 @@ type source_position = {
type 'a eoption = ENone of unit | ESome of 'a type 'a eoption = ENone of unit | ESome of 'a
exception EmptyError exception EmptyError
exception AssertionFailed exception AssertionFailed
exception ConflictError exception ConflictError
exception UncomparableDurations exception UncomparableDurations
exception IndivisableDurations exception IndivisableDurations
exception ImpossibleDate exception ImpossibleDate
exception NoValueProvided of source_position exception NoValueProvided of source_position
type runtime_value = type runtime_value =
@ -61,21 +53,13 @@ type runtime_value =
| Unembeddable | Unembeddable
let unembeddable _ = Unembeddable let unembeddable _ = Unembeddable
let embed_unit () = Unit let embed_unit () = Unit
let embed_bool x = Bool x let embed_bool x = Bool x
let embed_money x = Money x let embed_money x = Money x
let embed_integer x = Integer x let embed_integer x = Integer x
let embed_decimal x = Decimal x let embed_decimal x = Decimal x
let embed_date x = Date x let embed_date x = Date x
let embed_duration x = Duration x let embed_duration x = Duration x
let embed_array f x = Array (Array.map f x) let embed_array f x = Array (Array.map f x)
type event = type event =
@ -85,9 +69,7 @@ type event =
| DecisionTaken of source_position | DecisionTaken of source_position
let log_ref : event list ref = ref [] let log_ref : event list ref = ref []
let reset_log () = log_ref := [] let reset_log () = log_ref := []
let retrieve_log () = List.rev !log_ref let retrieve_log () = List.rev !log_ref
let log_begin_call info f x = let log_begin_call info f x =
@ -107,24 +89,17 @@ let log_decision_taken pos x =
x x
let money_of_cents_string (cents : string) : money = Z.of_string cents let money_of_cents_string (cents : string) : money = Z.of_string cents
let money_of_units_int (units : int) : money = Z.(of_int units * of_int 100) let money_of_units_int (units : int) : money = Z.(of_int units * of_int 100)
let money_of_cents_integer (cents : integer) : money = cents let money_of_cents_integer (cents : integer) : money = cents
let money_to_float (m : money) : float = Z.to_float m /. 100. let money_to_float (m : money) : float = Z.to_float m /. 100.
let money_to_string (m : money) : string = let money_to_string (m : money) : string =
Format.asprintf "%.2f" Q.(to_float (of_bigint m / of_int 100)) Format.asprintf "%.2f" Q.(to_float (of_bigint m / of_int 100))
let money_to_cents m = m let money_to_cents m = m
let decimal_of_string (d : string) : decimal = Q.of_string d let decimal_of_string (d : string) : decimal = Q.of_string d
let decimal_to_float (d : decimal) : float = Q.to_float d let decimal_to_float (d : decimal) : float = Q.to_float d
let decimal_of_float (d : float) : decimal = Q.of_float d let decimal_of_float (d : float) : decimal = Q.of_float d
let decimal_of_integer (d : integer) : decimal = Q.of_bigint d let decimal_of_integer (d : integer) : decimal = Q.of_bigint d
let decimal_to_string ~(max_prec_digits : int) (i : decimal) : string = let decimal_to_string ~(max_prec_digits : int) (i : decimal) : string =
@ -146,7 +121,10 @@ let decimal_to_string ~(max_prec_digits : int) (i : decimal) : string =
| `End i -> i | `End i -> i
| `Begin i -> i | `Begin i -> i
in in
while !n <> Z.zero && List.length !digits - leading_zeroes !digits < max_prec_digits do while
!n <> Z.zero
&& List.length !digits - leading_zeroes !digits < max_prec_digits
do
n := Z.mul !n (Z.of_int 10); n := Z.mul !n (Z.of_int 10);
digits := Z.ediv !n d :: !digits; digits := Z.ediv !n d :: !digits;
n := Z.erem !n d n := Z.erem !n d
@ -158,26 +136,22 @@ let decimal_to_string ~(max_prec_digits : int) (i : decimal) : string =
~pp_sep:(fun _fmt () -> ()) ~pp_sep:(fun _fmt () -> ())
(fun fmt digit -> Format.fprintf fmt "%a" Z.pp_print digit)) (fun fmt digit -> Format.fprintf fmt "%a" Z.pp_print digit))
(List.rev !digits) (List.rev !digits)
(if List.length !digits - leading_zeroes !digits = max_prec_digits then "" else "") (if List.length !digits - leading_zeroes !digits = max_prec_digits then ""
else "")
let integer_of_string (s : string) : integer = Z.of_string s let integer_of_string (s : string) : integer = Z.of_string s
let integer_to_string (i : integer) : string = Z.to_string i let integer_to_string (i : integer) : string = Z.to_string i
let integer_to_int (i : integer) : int = Z.to_int i let integer_to_int (i : integer) : int = Z.to_int i
let integer_of_int (i : int) : integer = Z.of_int i let integer_of_int (i : int) : integer = Z.of_int i
let integer_exponentiation (i : integer) (e : int) : integer = Z.pow i e let integer_exponentiation (i : integer) (e : int) : integer = Z.pow i e
let integer_log2 = Z.log2 let integer_log2 = Z.log2
let year_of_date (d : date) : integer = Z.of_int (CalendarLib.Date.year d) let year_of_date (d : date) : integer = Z.of_int (CalendarLib.Date.year d)
let month_number_of_date (d : date) : integer = let month_number_of_date (d : date) : integer =
Z.of_int (CalendarLib.Date.int_of_month (CalendarLib.Date.month d)) Z.of_int (CalendarLib.Date.int_of_month (CalendarLib.Date.month d))
let day_of_month_of_date (d : date) : integer = Z.of_int (CalendarLib.Date.day_of_month d) let day_of_month_of_date (d : date) : integer =
Z.of_int (CalendarLib.Date.day_of_month d)
let date_of_numbers (year : int) (month : int) (day : int) : date = let date_of_numbers (year : int) (month : int) (day : int) : date =
try CalendarLib.Date.make year month day with _ -> raise ImpossibleDate try CalendarLib.Date.make year month day with _ -> raise ImpossibleDate
@ -189,7 +163,11 @@ let duration_of_numbers (year : int) (month : int) (day : int) : duration =
let duration_to_string (d : duration) : string = let duration_to_string (d : duration) : string =
let x, y, z = CalendarLib.Date.Period.ymd d in let x, y, z = CalendarLib.Date.Period.ymd d in
let to_print = List.filter (fun (a, _) -> a <> 0) [ (x, "years"); (y, "months"); (z, "days") ] in let to_print =
List.filter
(fun (a, _) -> a <> 0)
[ (x, "years"); (y, "months"); (z, "days") ]
in
match to_print with match to_print with
| [] -> "empty duration" | [] -> "empty duration"
| _ -> | _ ->
@ -199,9 +177,11 @@ let duration_to_string (d : duration) : string =
(fun fmt (d, l) -> Format.fprintf fmt "%d %s" d l)) (fun fmt (d, l) -> Format.fprintf fmt "%d %s" d l))
to_print to_print
let duration_to_years_months_days (d : duration) : int * int * int = CalendarLib.Date.Period.ymd d let duration_to_years_months_days (d : duration) : int * int * int =
CalendarLib.Date.Period.ymd d
let handle_default : 'a. (unit -> 'a) array -> (unit -> bool) -> (unit -> 'a) -> 'a = let handle_default :
'a. (unit -> 'a) array -> (unit -> bool) -> (unit -> 'a) -> 'a =
fun exceptions just cons -> fun exceptions just cons ->
let except = let except =
Array.fold_left Array.fold_left
@ -213,9 +193,12 @@ let handle_default : 'a. (unit -> 'a) array -> (unit -> bool) -> (unit -> 'a) ->
| Some _, Some _ -> raise ConflictError) | Some _, Some _ -> raise ConflictError)
None exceptions None exceptions
in in
match except with Some x -> x | None -> if just () then cons () else raise EmptyError match except with
| Some x -> x
| None -> if just () then cons () else raise EmptyError
let handle_default_opt (exceptions : 'a eoption array) (just : bool eoption) (cons : 'a eoption) : let handle_default_opt
(exceptions : 'a eoption array) (just : bool eoption) (cons : 'a eoption) :
'a eoption = 'a eoption =
let except = let except =
Array.fold_left Array.fold_left
@ -228,58 +211,56 @@ let handle_default_opt (exceptions : 'a eoption array) (just : bool eoption) (co
in in
match except with match except with
| ESome _ -> except | ESome _ -> except
| ENone _ -> ( match just with ESome b -> if b then cons else ENone () | ENone _ -> ENone ()) | ENone _ -> (
match just with
| ESome b -> if b then cons else ENone ()
| ENone _ -> ENone ())
let no_input : unit -> 'a = fun _ -> raise EmptyError let no_input : unit -> 'a = fun _ -> raise EmptyError
let ( *$ ) (i1 : money) (i2 : decimal) : money = let ( *$ ) (i1 : money) (i2 : decimal) : money =
let rat_result = Q.mul (Q.of_bigint i1) i2 in let rat_result = Q.mul (Q.of_bigint i1) i2 in
let res, remainder = Z.div_rem (Q.num rat_result) (Q.den rat_result) in let res, remainder = Z.div_rem (Q.num rat_result) (Q.den rat_result) in
(* we perform nearest rounding when multiplying an amount of money by a decimal !*) (* we perform nearest rounding when multiplying an amount of money by a
if Z.(of_int 2 * remainder >= Q.den rat_result) then Z.add res (Z.of_int 1) else res decimal !*)
if Z.(of_int 2 * remainder >= Q.den rat_result) then Z.add res (Z.of_int 1)
else res
let ( /$ ) (m1 : money) (m2 : money) : decimal = let ( /$ ) (m1 : money) (m2 : money) : decimal =
if Z.zero = m2 then raise Division_by_zero else Q.div (Q.of_bigint m1) (Q.of_bigint m2) if Z.zero = m2 then raise Division_by_zero
else Q.div (Q.of_bigint m1) (Q.of_bigint m2)
let ( +$ ) (m1 : money) (m2 : money) : money = Z.add m1 m2 let ( +$ ) (m1 : money) (m2 : money) : money = Z.add m1 m2
let ( -$ ) (m1 : money) (m2 : money) : money = Z.sub m1 m2 let ( -$ ) (m1 : money) (m2 : money) : money = Z.sub m1 m2
let ( ~-$ ) (m1 : money) : money = Z.sub Z.zero m1 let ( ~-$ ) (m1 : money) : money = Z.sub Z.zero m1
let ( +! ) (i1 : integer) (i2 : integer) : integer = Z.add i1 i2 let ( +! ) (i1 : integer) (i2 : integer) : integer = Z.add i1 i2
let ( -! ) (i1 : integer) (i2 : integer) : integer = Z.sub i1 i2 let ( -! ) (i1 : integer) (i2 : integer) : integer = Z.sub i1 i2
let ( ~-! ) (i1 : integer) : integer = Z.sub Z.zero i1 let ( ~-! ) (i1 : integer) : integer = Z.sub Z.zero i1
let ( *! ) (i1 : integer) (i2 : integer) : integer = Z.mul i1 i2 let ( *! ) (i1 : integer) (i2 : integer) : integer = Z.mul i1 i2
let ( /! ) (i1 : integer) (i2 : integer) : integer = let ( /! ) (i1 : integer) (i2 : integer) : integer =
if Z.zero = i2 then raise Division_by_zero else Z.div i1 i2 if Z.zero = i2 then raise Division_by_zero else Z.div i1 i2
let ( +& ) (i1 : decimal) (i2 : decimal) : decimal = Q.add i1 i2 let ( +& ) (i1 : decimal) (i2 : decimal) : decimal = Q.add i1 i2
let ( -& ) (i1 : decimal) (i2 : decimal) : decimal = Q.sub i1 i2 let ( -& ) (i1 : decimal) (i2 : decimal) : decimal = Q.sub i1 i2
let ( ~-& ) (i1 : decimal) : decimal = Q.sub Q.zero i1 let ( ~-& ) (i1 : decimal) : decimal = Q.sub Q.zero i1
let ( *& ) (i1 : decimal) (i2 : decimal) : decimal = Q.mul i1 i2 let ( *& ) (i1 : decimal) (i2 : decimal) : decimal = Q.mul i1 i2
let ( /& ) (i1 : decimal) (i2 : decimal) : decimal = let ( /& ) (i1 : decimal) (i2 : decimal) : decimal =
if Q.zero = i2 then raise Division_by_zero else Q.div i1 i2 if Q.zero = i2 then raise Division_by_zero else Q.div i1 i2
let ( +@ ) (d1 : date) (d2 : duration) : date = CalendarLib.Date.add d1 d2 let ( +@ ) (d1 : date) (d2 : duration) : date = CalendarLib.Date.add d1 d2
let ( -@ ) (d1 : date) (d2 : date) : duration = CalendarLib.Date.sub d1 d2 let ( -@ ) (d1 : date) (d2 : date) : duration = CalendarLib.Date.sub d1 d2
let ( +^ ) (d1 : duration) (d2 : duration) : duration = CalendarLib.Date.Period.add d1 d2 let ( +^ ) (d1 : duration) (d2 : duration) : duration =
CalendarLib.Date.Period.add d1 d2
let ( -^ ) (d1 : duration) (d2 : duration) : duration = CalendarLib.Date.Period.sub d1 d2 let ( -^ ) (d1 : duration) (d2 : duration) : duration =
CalendarLib.Date.Period.sub d1 d2
(* (EmileRolley) NOTE: {!CalendarLib.Date.Period.nb_days} is deprecated, (* (EmileRolley) NOTE: {!CalendarLib.Date.Period.nb_days} is deprecated,
{!CalendarLib.Date.Period.safe_nb_days} should be used. But the current {!duration} is greater {!CalendarLib.Date.Period.safe_nb_days} should be used. But the current
that the supported polymorphic variants.*) {!duration} is greater that the supported polymorphic variants.*)
let ( /^ ) (d1 : duration) (d2 : duration) : decimal = let ( /^ ) (d1 : duration) (d2 : duration) : decimal =
try try
let nb_day1 = CalendarLib.Date.Period.nb_days d1 in let nb_day1 = CalendarLib.Date.Period.nb_days d1 in
@ -288,46 +269,28 @@ let ( /^ ) (d1 : duration) (d2 : duration) : decimal =
with CalendarLib.Date.Period.Not_computable -> raise IndivisableDurations with CalendarLib.Date.Period.Not_computable -> raise IndivisableDurations
let ( <=$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 <= 0 let ( <=$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 <= 0
let ( >=$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 >= 0 let ( >=$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 >= 0
let ( <$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 < 0 let ( <$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 < 0
let ( >$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 > 0 let ( >$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 > 0
let ( =$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 = 0 let ( =$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 = 0
let ( >=! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 >= 0 let ( >=! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 >= 0
let ( <=! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 <= 0 let ( <=! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 <= 0
let ( >! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 > 0 let ( >! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 > 0
let ( <! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 < 0 let ( <! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 < 0
let ( =! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 = 0 let ( =! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 = 0
let ( >=& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 >= 0 let ( >=& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 >= 0
let ( <=& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 <= 0 let ( <=& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 <= 0
let ( >& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 > 0 let ( >& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 > 0
let ( <& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 < 0 let ( <& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 < 0
let ( =& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 = 0 let ( =& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 = 0
let ( >=@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 >= 0 let ( >=@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 >= 0
let ( <=@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 <= 0 let ( <=@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 <= 0
let ( >@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 > 0 let ( >@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 > 0
let ( <@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 < 0 let ( <@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 < 0
let ( =@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 = 0 let ( =@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 = 0
let compare_periods (p1 : CalendarLib.Date.Period.t) (p2 : CalendarLib.Date.Period.t) : int = let compare_periods
(p1 : CalendarLib.Date.Period.t) (p2 : CalendarLib.Date.Period.t) : int =
try try
let p1_days = CalendarLib.Date.Period.nb_days p1 in let p1_days = CalendarLib.Date.Period.nb_days p1 in
let p2_days = CalendarLib.Date.Period.nb_days p2 in let p2_days = CalendarLib.Date.Period.nb_days p2 in
@ -335,15 +298,10 @@ let compare_periods (p1 : CalendarLib.Date.Period.t) (p2 : CalendarLib.Date.Peri
with CalendarLib.Date.Period.Not_computable -> raise UncomparableDurations with CalendarLib.Date.Period.Not_computable -> raise UncomparableDurations
let ( >=^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 >= 0 let ( >=^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 >= 0
let ( <=^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 <= 0 let ( <=^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 <= 0
let ( >^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 > 0 let ( >^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 > 0
let ( <^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 < 0 let ( <^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 < 0
let ( =^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 = 0 let ( =^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 = 0
let ( ~-^ ) (d1 : duration) : duration = CalendarLib.Date.Period.opp d1 let ( ~-^ ) (d1 : duration) : duration = CalendarLib.Date.Period.opp d1
let array_filter (f : 'a -> bool) (a : 'a array) : 'a array = let array_filter (f : 'a -> bool) (a : 'a array) : 'a array =

View File

@ -1,27 +1,26 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** {1 Types} *) (** {1 Types} *)
type money type money
type integer type integer
type decimal type decimal
type date type date
type duration type duration
type source_position = { type source_position = {
@ -38,17 +37,11 @@ type 'a eoption = ENone of unit | ESome of 'a
(** {1 Exceptions} *) (** {1 Exceptions} *)
exception EmptyError exception EmptyError
exception AssertionFailed exception AssertionFailed
exception ConflictError exception ConflictError
exception UncomparableDurations exception UncomparableDurations
exception IndivisableDurations exception IndivisableDurations
exception ImpossibleDate exception ImpossibleDate
exception NoValueProvided of source_position exception NoValueProvided of source_position
(** {1 Value Embedding} *) (** {1 Value Embedding} *)
@ -67,21 +60,13 @@ type runtime_value =
| Unembeddable | Unembeddable
val unembeddable : 'a -> runtime_value val unembeddable : 'a -> runtime_value
val embed_unit : unit -> runtime_value val embed_unit : unit -> runtime_value
val embed_bool : bool -> runtime_value val embed_bool : bool -> runtime_value
val embed_money : money -> runtime_value val embed_money : money -> runtime_value
val embed_integer : integer -> runtime_value val embed_integer : integer -> runtime_value
val embed_decimal : decimal -> runtime_value val embed_decimal : decimal -> runtime_value
val embed_date : date -> runtime_value val embed_date : date -> runtime_value
val embed_duration : duration -> runtime_value val embed_duration : duration -> runtime_value
val embed_array : ('a -> runtime_value) -> 'a Array.t -> runtime_value val embed_array : ('a -> runtime_value) -> 'a Array.t -> runtime_value
(** {1 Logging} *) (** {1 Logging} *)
@ -93,15 +78,10 @@ type event =
| DecisionTaken of source_position | DecisionTaken of source_position
val reset_log : unit -> unit val reset_log : unit -> unit
val retrieve_log : unit -> event list val retrieve_log : unit -> event list
val log_begin_call : string list -> ('a -> 'b) -> 'a -> 'b val log_begin_call : string list -> ('a -> 'b) -> 'a -> 'b
val log_end_call : string list -> 'a -> 'a val log_end_call : string list -> 'a -> 'a
val log_variable_definition : string list -> ('a -> runtime_value) -> 'a -> 'a val log_variable_definition : string list -> ('a -> runtime_value) -> 'a -> 'a
val log_decision_taken : source_position -> bool -> bool val log_decision_taken : source_position -> bool -> bool
(**{1 Constructors and conversions} *) (**{1 Constructors and conversions} *)
@ -109,51 +89,34 @@ val log_decision_taken : source_position -> bool -> bool
(**{2 Money}*) (**{2 Money}*)
val money_of_cents_string : string -> money val money_of_cents_string : string -> money
val money_of_units_int : int -> money val money_of_units_int : int -> money
val money_of_cents_integer : integer -> money val money_of_cents_integer : integer -> money
val money_to_float : money -> float val money_to_float : money -> float
val money_to_string : money -> string val money_to_string : money -> string
val money_to_cents : money -> integer val money_to_cents : money -> integer
(** {2 Decimals} *) (** {2 Decimals} *)
val decimal_of_string : string -> decimal val decimal_of_string : string -> decimal
val decimal_to_string : max_prec_digits:int -> decimal -> string val decimal_to_string : max_prec_digits:int -> decimal -> string
val decimal_of_integer : integer -> decimal val decimal_of_integer : integer -> decimal
val decimal_of_float : float -> decimal val decimal_of_float : float -> decimal
val decimal_to_float : decimal -> float val decimal_to_float : decimal -> float
(**{2 Integers} *) (**{2 Integers} *)
val integer_of_string : string -> integer val integer_of_string : string -> integer
val integer_to_string : integer -> string val integer_to_string : integer -> string
val integer_to_int : integer -> int val integer_to_int : integer -> int
val integer_of_int : int -> integer val integer_of_int : int -> integer
val integer_log2 : integer -> int val integer_log2 : integer -> int
val integer_exponentiation : integer -> int -> integer val integer_exponentiation : integer -> int -> integer
(**{2 Dates} *) (**{2 Dates} *)
val day_of_month_of_date : date -> integer val day_of_month_of_date : date -> integer
val month_number_of_date : date -> integer val month_number_of_date : date -> integer
val year_of_date : date -> integer val year_of_date : date -> integer
val date_to_string : date -> string val date_to_string : date -> string
val date_of_numbers : int -> int -> int -> date val date_of_numbers : int -> int -> int -> date
@ -164,9 +127,7 @@ val date_of_numbers : int -> int -> int -> date
(**{2 Durations} *) (**{2 Durations} *)
val duration_of_numbers : int -> int -> int -> duration val duration_of_numbers : int -> int -> int -> duration
val duration_to_years_months_days : duration -> int * int * int val duration_to_years_months_days : duration -> int * int * int
val duration_to_string : duration -> string val duration_to_string : duration -> string
(**{1 Defaults} *) (**{1 Defaults} *)
@ -175,7 +136,8 @@ val handle_default : (unit -> 'a) array -> (unit -> bool) -> (unit -> 'a) -> 'a
(** @raise EmptyError (** @raise EmptyError
@raise ConflictError *) @raise ConflictError *)
val handle_default_opt : 'a eoption array -> bool eoption -> 'a eoption -> 'a eoption val handle_default_opt :
'a eoption array -> bool eoption -> 'a eoption -> 'a eoption
(** @raise ConflictError *) (** @raise ConflictError *)
val no_input : unit -> 'a val no_input : unit -> 'a
@ -190,87 +152,59 @@ val ( /$ ) : money -> money -> decimal
(** @raise Division_by_zero *) (** @raise Division_by_zero *)
val ( +$ ) : money -> money -> money val ( +$ ) : money -> money -> money
val ( -$ ) : money -> money -> money val ( -$ ) : money -> money -> money
val ( ~-$ ) : money -> money val ( ~-$ ) : money -> money
val ( =$ ) : money -> money -> bool val ( =$ ) : money -> money -> bool
val ( <=$ ) : money -> money -> bool val ( <=$ ) : money -> money -> bool
val ( >=$ ) : money -> money -> bool val ( >=$ ) : money -> money -> bool
val ( <$ ) : money -> money -> bool val ( <$ ) : money -> money -> bool
val ( >$ ) : money -> money -> bool val ( >$ ) : money -> money -> bool
(**{2 Integers} *) (**{2 Integers} *)
val ( +! ) : integer -> integer -> integer val ( +! ) : integer -> integer -> integer
val ( -! ) : integer -> integer -> integer val ( -! ) : integer -> integer -> integer
val ( ~-! ) : integer -> integer val ( ~-! ) : integer -> integer
val ( *! ) : integer -> integer -> integer val ( *! ) : integer -> integer -> integer
val ( /! ) : integer -> integer -> integer val ( /! ) : integer -> integer -> integer
(** @raise Division_by_zero *) (** @raise Division_by_zero *)
val ( =! ) : integer -> integer -> bool val ( =! ) : integer -> integer -> bool
val ( >=! ) : integer -> integer -> bool val ( >=! ) : integer -> integer -> bool
val ( <=! ) : integer -> integer -> bool val ( <=! ) : integer -> integer -> bool
val ( >! ) : integer -> integer -> bool val ( >! ) : integer -> integer -> bool
val ( <! ) : integer -> integer -> bool val ( <! ) : integer -> integer -> bool
(** {2 Decimals} *) (** {2 Decimals} *)
val ( +& ) : decimal -> decimal -> decimal val ( +& ) : decimal -> decimal -> decimal
val ( -& ) : decimal -> decimal -> decimal val ( -& ) : decimal -> decimal -> decimal
val ( ~-& ) : decimal -> decimal val ( ~-& ) : decimal -> decimal
val ( *& ) : decimal -> decimal -> decimal val ( *& ) : decimal -> decimal -> decimal
val ( /& ) : decimal -> decimal -> decimal val ( /& ) : decimal -> decimal -> decimal
(** @raise Division_by_zero *) (** @raise Division_by_zero *)
val ( =& ) : decimal -> decimal -> bool val ( =& ) : decimal -> decimal -> bool
val ( >=& ) : decimal -> decimal -> bool val ( >=& ) : decimal -> decimal -> bool
val ( <=& ) : decimal -> decimal -> bool val ( <=& ) : decimal -> decimal -> bool
val ( >& ) : decimal -> decimal -> bool val ( >& ) : decimal -> decimal -> bool
val ( <& ) : decimal -> decimal -> bool val ( <& ) : decimal -> decimal -> bool
(** {2 Dates} *) (** {2 Dates} *)
val ( +@ ) : date -> duration -> date val ( +@ ) : date -> duration -> date
val ( -@ ) : date -> date -> duration val ( -@ ) : date -> date -> duration
val ( =@ ) : date -> date -> bool val ( =@ ) : date -> date -> bool
val ( >=@ ) : date -> date -> bool val ( >=@ ) : date -> date -> bool
val ( <=@ ) : date -> date -> bool val ( <=@ ) : date -> date -> bool
val ( >@ ) : date -> date -> bool val ( >@ ) : date -> date -> bool
val ( <@ ) : date -> date -> bool val ( <@ ) : date -> date -> bool
(** {2 Durations} *) (** {2 Durations} *)
val ( +^ ) : duration -> duration -> duration val ( +^ ) : duration -> duration -> duration
val ( -^ ) : duration -> duration -> duration val ( -^ ) : duration -> duration -> duration
val ( /^ ) : duration -> duration -> decimal val ( /^ ) : duration -> duration -> decimal
@ -278,7 +212,6 @@ val ( /^ ) : duration -> duration -> decimal
@raise IndivisableDurations *) @raise IndivisableDurations *)
val ( ~-^ ) : duration -> duration val ( ~-^ ) : duration -> duration
val ( =^ ) : duration -> duration -> bool val ( =^ ) : duration -> duration -> bool
val ( >=^ ) : duration -> duration -> bool val ( >=^ ) : duration -> duration -> bool
@ -296,5 +229,4 @@ val ( <^ ) : duration -> duration -> bool
(** {2 Arrays} *) (** {2 Arrays} *)
val array_filter : ('a -> bool) -> 'a array -> 'a array val array_filter : ('a -> bool) -> 'a array -> 'a array
val array_length : 'a array -> integer val array_length : 'a array -> integer

View File

@ -1,23 +1,23 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2021 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2021 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
module D = Dcalc.Ast module D = Dcalc.Ast
module L = Lcalc.Ast module L = Lcalc.Ast
module TopLevelName = Uid.Make (Uid.MarkedString) () module TopLevelName = Uid.Make (Uid.MarkedString) ()
module LocalName = Uid.Make (Uid.MarkedString) () module LocalName = Uid.Make (Uid.MarkedString) ()
type expr = type expr =
@ -49,7 +49,10 @@ type stmt =
and block = stmt Pos.marked list and block = stmt Pos.marked list
and func = { func_params : (LocalName.t Pos.marked * D.typ Pos.marked) list; func_body : block } and func = {
func_params : (LocalName.t Pos.marked * D.typ Pos.marked) list;
func_body : block;
}
type scope_body = { type scope_body = {
scope_body_name : Dcalc.Ast.ScopeName.t; scope_body_name : Dcalc.Ast.ScopeName.t;

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2021 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2021 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -24,14 +26,16 @@ type ctxt = {
inside_definition_of : A.LocalName.t option; inside_definition_of : A.LocalName.t option;
} }
(* Expressions can spill out side effect, hence this function also returns a list of statements to (* Expressions can spill out side effect, hence this function also returns a
be prepended before the expression is evaluated *) list of statements to be prepended before the expression is evaluated *)
let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.expr Pos.marked = let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) :
A.block * A.expr Pos.marked =
match Pos.unmark expr with match Pos.unmark expr with
| L.EVar v -> | L.EVar v ->
let local_var = let local_var =
try A.EVar (L.VarMap.find (Pos.unmark v) ctxt.var_dict) try A.EVar (L.VarMap.find (Pos.unmark v) ctxt.var_dict)
with Not_found -> A.EFunc (L.VarMap.find (Pos.unmark v) ctxt.func_dict) with Not_found ->
A.EFunc (L.VarMap.find (Pos.unmark v) ctxt.func_dict)
in in
([], (local_var, Pos.get_position v)) ([], (local_var, Pos.get_position v))
| L.ETuple (args, Some s_name) -> | L.ETuple (args, Some s_name) ->
@ -45,17 +49,26 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.ex
let new_args = List.rev new_args in let new_args = List.rev new_args in
let args_stmts = List.rev args_stmts in let args_stmts = List.rev args_stmts in
(args_stmts, (A.EStruct (new_args, s_name), Pos.get_position expr)) (args_stmts, (A.EStruct (new_args, s_name), Pos.get_position expr))
| L.ETuple (_, None) -> failwith "Non-struct tuples cannot be compiled to scalc" | L.ETuple (_, None) ->
failwith "Non-struct tuples cannot be compiled to scalc"
| L.ETupleAccess (e1, num_field, Some s_name, _) -> | L.ETupleAccess (e1, num_field, Some s_name, _) ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let field_name = let field_name =
fst (List.nth (D.StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field) fst
(List.nth
(D.StructMap.find s_name ctxt.decl_ctx.ctx_structs)
num_field)
in in
(e1_stmts, (A.EStructFieldAccess (new_e1, field_name, s_name), Pos.get_position expr)) ( e1_stmts,
| L.ETupleAccess (_, _, None, _) -> failwith "Non-struct tuples cannot be compiled to scalc" ( A.EStructFieldAccess (new_e1, field_name, s_name),
Pos.get_position expr ) )
| L.ETupleAccess (_, _, None, _) ->
failwith "Non-struct tuples cannot be compiled to scalc"
| L.EInj (e1, num_cons, e_name, _) -> | L.EInj (e1, num_cons, e_name, _) ->
let e1_stmts, new_e1 = translate_expr ctxt e1 in let e1_stmts, new_e1 = translate_expr ctxt e1 in
let cons_name = fst (List.nth (D.EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) in let cons_name =
fst (List.nth (D.EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons)
in
(e1_stmts, (A.EInj (new_e1, cons_name, e_name), Pos.get_position expr)) (e1_stmts, (A.EInj (new_e1, cons_name, e_name), Pos.get_position expr))
| L.EApp (f, args) -> | L.EApp (f, args) ->
let f_stmts, new_f = translate_expr ctxt f in let f_stmts, new_f = translate_expr ctxt f in
@ -84,14 +97,18 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.ex
let tmp_var = A.LocalName.fresh ("local_var", Pos.get_position expr) in let tmp_var = A.LocalName.fresh ("local_var", Pos.get_position expr) in
let ctxt = { ctxt with inside_definition_of = Some tmp_var } in let ctxt = { ctxt with inside_definition_of = Some tmp_var } in
let tmp_stmts = translate_statements ctxt expr in let tmp_stmts = translate_statements ctxt expr in
( ( A.SLocalDecl ((tmp_var, Pos.get_position expr), (D.TAny, Pos.get_position expr)), ( ( A.SLocalDecl
((tmp_var, Pos.get_position expr), (D.TAny, Pos.get_position expr)),
Pos.get_position expr ) Pos.get_position expr )
:: tmp_stmts, :: tmp_stmts,
(A.EVar tmp_var, Pos.get_position expr) ) (A.EVar tmp_var, Pos.get_position expr) )
and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.block = and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) :
A.block =
match Pos.unmark block_expr with match Pos.unmark block_expr with
| L.EApp ((L.EAbs ((binder, _), [ (D.TLit D.TUnit, _) ]), _), [ (L.EAssert e, _) ]) -> | L.EApp
((L.EAbs ((binder, _), [ (D.TLit D.TUnit, _) ]), _), [ (L.EAssert e, _) ])
->
(* Assertions are always encapsulated in a unit-typed let binding *) (* Assertions are always encapsulated in a unit-typed let binding *)
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
let e_stmts, new_e = translate_expr ctxt e in let e_stmts, new_e = translate_expr ctxt e in
@ -101,32 +118,40 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
| L.EApp ((L.EAbs ((binder, binder_pos), taus), eabs_pos), args) -> | L.EApp ((L.EAbs ((binder, binder_pos), taus), eabs_pos), args) ->
(* This defines multiple local variables at the time *) (* This defines multiple local variables at the time *)
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let vars_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus in let vars_tau =
List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus
in
let ctxt = let ctxt =
{ {
ctxt with ctxt with
var_dict = var_dict =
List.fold_left List.fold_left
(fun var_dict (x, _) -> (fun var_dict (x, _) ->
L.VarMap.add x (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) var_dict) L.VarMap.add x
(A.LocalName.fresh (Bindlib.name_of x, binder_pos))
var_dict)
ctxt.var_dict vars_tau; ctxt.var_dict vars_tau;
} }
in in
let local_decls = let local_decls =
List.map List.map
(fun (x, tau) -> (fun (x, tau) ->
(A.SLocalDecl ((L.VarMap.find x ctxt.var_dict, binder_pos), tau), eabs_pos)) ( A.SLocalDecl ((L.VarMap.find x ctxt.var_dict, binder_pos), tau),
eabs_pos ))
vars_tau vars_tau
in in
let vars_args = let vars_args =
List.map2 List.map2
(fun (x, tau) arg -> ((L.VarMap.find x ctxt.var_dict, binder_pos), tau, arg)) (fun (x, tau) arg ->
((L.VarMap.find x ctxt.var_dict, binder_pos), tau, arg))
vars_tau args vars_tau args
in in
let def_blocks = let def_blocks =
List.map List.map
(fun (x, _tau, arg) -> (fun (x, _tau, arg) ->
let ctxt = { ctxt with inside_definition_of = Some (Pos.unmark x) } in let ctxt =
{ ctxt with inside_definition_of = Some (Pos.unmark x) }
in
let arg_stmts, new_arg = translate_expr ctxt arg in let arg_stmts, new_arg = translate_expr ctxt arg in
arg_stmts @ [ (A.SLocalDef (x, new_arg), binder_pos) ]) arg_stmts @ [ (A.SLocalDef (x, new_arg), binder_pos) ])
vars_args vars_args
@ -135,7 +160,9 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
local_decls @ List.flatten def_blocks @ rest_of_block local_decls @ List.flatten def_blocks @ rest_of_block
| L.EAbs ((binder, binder_pos), taus) -> | L.EAbs ((binder, binder_pos), taus) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let vars_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus in let vars_tau =
List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus
in
let closure_name = let closure_name =
match ctxt.inside_definition_of with match ctxt.inside_definition_of with
| None -> A.LocalName.fresh ("closure", Pos.get_position block_expr) | None -> A.LocalName.fresh ("closure", Pos.get_position block_expr)
@ -147,7 +174,9 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
var_dict = var_dict =
List.fold_left List.fold_left
(fun var_dict (x, _) -> (fun var_dict (x, _) ->
L.VarMap.add x (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) var_dict) L.VarMap.add x
(A.LocalName.fresh (Bindlib.name_of x, binder_pos))
var_dict)
ctxt.var_dict vars_tau; ctxt.var_dict vars_tau;
inside_definition_of = None; inside_definition_of = None;
} }
@ -159,7 +188,8 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
{ {
func_params = func_params =
List.map List.map
(fun (var, tau) -> ((L.VarMap.find var ctxt.var_dict, binder_pos), tau)) (fun (var, tau) ->
((L.VarMap.find var ctxt.var_dict, binder_pos), tau))
vars_tau; vars_tau;
func_body = new_body; func_body = new_body;
} ), } ),
@ -175,8 +205,15 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
assert (Array.length vars = 1); assert (Array.length vars = 1);
let var = vars.(0) in let var = vars.(0) in
let scalc_var = A.LocalName.fresh (Bindlib.name_of var, pos_binder) in let scalc_var =
let ctxt = { ctxt with var_dict = L.VarMap.add var scalc_var ctxt.var_dict } in A.LocalName.fresh (Bindlib.name_of var, pos_binder)
in
let ctxt =
{
ctxt with
var_dict = L.VarMap.add var scalc_var ctxt.var_dict;
}
in
let new_arg = translate_statements ctxt body in let new_arg = translate_statements ctxt body in
(new_arg, scalc_var) :: new_args (new_arg, scalc_var) :: new_args
| _ -> assert false | _ -> assert false
@ -184,16 +221,23 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
[] args [] args
in in
let new_args = List.rev new_args in let new_args = List.rev new_args in
e1_stmts @ [ (A.SSwitch (new_e1, e_name, new_args), Pos.get_position block_expr) ] e1_stmts
@ [ (A.SSwitch (new_e1, e_name, new_args), Pos.get_position block_expr) ]
| L.EIfThenElse (cond, e_true, e_false) -> | L.EIfThenElse (cond, e_true, e_false) ->
let cond_stmts, s_cond = translate_expr ctxt cond in let cond_stmts, s_cond = translate_expr ctxt cond in
let s_e_true = translate_statements ctxt e_true in let s_e_true = translate_statements ctxt e_true in
let s_e_false = translate_statements ctxt e_false in let s_e_false = translate_statements ctxt e_false in
cond_stmts @ [ (A.SIfThenElse (s_cond, s_e_true, s_e_false), Pos.get_position block_expr) ] cond_stmts
@ [
( A.SIfThenElse (s_cond, s_e_true, s_e_false),
Pos.get_position block_expr );
]
| L.ECatch (e_try, except, e_catch) -> | L.ECatch (e_try, except, e_catch) ->
let s_e_try = translate_statements ctxt e_try in let s_e_try = translate_statements ctxt e_try in
let s_e_catch = translate_statements ctxt e_catch in let s_e_catch = translate_statements ctxt e_catch in
[ (A.STryExcept (s_e_try, except, s_e_catch), Pos.get_position block_expr) ] [
(A.STryExcept (s_e_try, except, s_e_catch), Pos.get_position block_expr);
]
| L.ERaise except -> [ (A.SRaise except, Pos.get_position block_expr) ] | L.ERaise except -> [ (A.SRaise except, Pos.get_position block_expr) ]
| _ -> ( | _ -> (
let e_stmts, new_e = translate_expr ctxt block_expr in let e_stmts, new_e = translate_expr ctxt block_expr in
@ -201,8 +245,9 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
@ @
match e_stmts with match e_stmts with
| (A.SRaise _, _) :: _ -> | (A.SRaise _, _) :: _ ->
(* if the last statement raises an exception, then we don't need to return or to define (* if the last statement raises an exception, then we don't need to
the current variable since this code will be unreachable *) return or to define the current variable since this code will be
unreachable *)
[] []
| _ -> | _ ->
[ [
@ -212,16 +257,20 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc
Pos.get_position block_expr ); Pos.get_position block_expr );
]) ])
let translate_scope (decl_ctx : D.decl_ctx) (func_dict : A.TopLevelName.t L.VarMap.t) let translate_scope
(scope_expr : L.expr Pos.marked) : (A.LocalName.t Pos.marked * D.typ Pos.marked) list * A.block (decl_ctx : D.decl_ctx)
= (func_dict : A.TopLevelName.t L.VarMap.t)
(scope_expr : L.expr Pos.marked) :
(A.LocalName.t Pos.marked * D.typ Pos.marked) list * A.block =
match Pos.unmark scope_expr with match Pos.unmark scope_expr with
| L.EAbs ((binder, binder_pos), typs) -> | L.EAbs ((binder, binder_pos), typs) ->
let vars, body = Bindlib.unmbind binder in let vars, body = Bindlib.unmbind binder in
let var_dict = let var_dict =
Array.fold_left Array.fold_left
(fun var_dict var -> (fun var_dict var ->
L.VarMap.add var (A.LocalName.fresh (Bindlib.name_of var, binder_pos)) var_dict) L.VarMap.add var
(A.LocalName.fresh (Bindlib.name_of var, binder_pos))
var_dict)
L.VarMap.empty vars L.VarMap.empty vars
in in
let param_list = let param_list =
@ -230,7 +279,9 @@ let translate_scope (decl_ctx : D.decl_ctx) (func_dict : A.TopLevelName.t L.VarM
(Array.to_list vars) typs (Array.to_list vars) typs
in in
let new_body = let new_body =
translate_statements { decl_ctx; func_dict; var_dict; inside_definition_of = None } body translate_statements
{ decl_ctx; func_dict; var_dict; inside_definition_of = None }
body
in in
(param_list, new_body) (param_list, new_body)
| _ -> assert false | _ -> assert false
@ -244,18 +295,25 @@ let translate_program (p : L.program) : A.program =
List.fold_left List.fold_left
(fun (func_dict, new_scopes) body -> (fun (func_dict, new_scopes) body ->
let new_scope_params, new_scope_body = let new_scope_params, new_scope_body =
translate_scope p.decl_ctx func_dict body.Lcalc.Ast.scope_body_expr translate_scope p.decl_ctx func_dict
body.Lcalc.Ast.scope_body_expr
in in
let func_id = let func_id =
A.TopLevelName.fresh (Bindlib.name_of body.Lcalc.Ast.scope_body_var, Pos.no_pos) A.TopLevelName.fresh
(Bindlib.name_of body.Lcalc.Ast.scope_body_var, Pos.no_pos)
in
let func_dict =
L.VarMap.add body.Lcalc.Ast.scope_body_var func_id func_dict
in in
let func_dict = L.VarMap.add body.Lcalc.Ast.scope_body_var func_id func_dict in
( func_dict, ( func_dict,
{ {
Ast.scope_body_name = body.Lcalc.Ast.scope_body_name; Ast.scope_body_name = body.Lcalc.Ast.scope_body_name;
Ast.scope_body_var = func_id; Ast.scope_body_var = func_id;
scope_body_func = scope_body_func =
{ A.func_params = new_scope_params; A.func_body = new_scope_body }; {
A.func_params = new_scope_params;
A.func_body = new_scope_body;
};
} }
:: new_scopes )) :: new_scopes ))
( (if !Cli.avoid_exceptions_flag then ( (if !Cli.avoid_exceptions_flag then

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -18,15 +20,19 @@ open Ast
let needs_parens (_e : expr Pos.marked) : bool = false let needs_parens (_e : expr Pos.marked) : bool = false
let format_local_name (fmt : Format.formatter) (v : LocalName.t) : unit = let format_local_name (fmt : Format.formatter) (v : LocalName.t) : unit =
Format.fprintf fmt "%a_%s" LocalName.format_t v (string_of_int (LocalName.hash v)) Format.fprintf fmt "%a_%s" LocalName.format_t v
(string_of_int (LocalName.hash v))
let rec format_expr (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) let rec format_expr
(decl_ctx : Dcalc.Ast.decl_ctx)
?(debug : bool = false)
(fmt : Format.formatter)
(e : expr Pos.marked) : unit = (e : expr Pos.marked) : unit =
let format_expr = format_expr decl_ctx ~debug in let format_expr = format_expr decl_ctx ~debug in
let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) =
if needs_parens e then if needs_parens e then
Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "(" format_expr e Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "(" format_expr
Dcalc.Print.format_punctuation ")" e Dcalc.Print.format_punctuation ")"
else Format.fprintf fmt "%a" format_expr e else Format.fprintf fmt "%a" format_expr e
in in
match Pos.unmark e with match Pos.unmark e with
@ -38,10 +44,12 @@ let rec format_expr (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation "\"" Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation
Dcalc.Ast.StructFieldName.format_t struct_field Dcalc.Print.format_punctuation "\"" "\"" Dcalc.Ast.StructFieldName.format_t struct_field
Dcalc.Print.format_punctuation "\""
Dcalc.Print.format_punctuation ":" format_expr e)) Dcalc.Print.format_punctuation ":" format_expr e))
(List.combine es (List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) (List.combine es
(List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs)))
Dcalc.Print.format_punctuation "}" Dcalc.Print.format_punctuation "}"
| EArray es -> | EArray es ->
Format.fprintf fmt "@[<hov 2>%a%a%a@]" Dcalc.Print.format_punctuation "[" Format.fprintf fmt "@[<hov 2>%a%a%a@]" Dcalc.Print.format_punctuation "["
@ -50,76 +58,103 @@ let rec format_expr (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt
(fun fmt e -> Format.fprintf fmt "%a" format_expr e)) (fun fmt e -> Format.fprintf fmt "%a" format_expr e))
es Dcalc.Print.format_punctuation "]" es Dcalc.Print.format_punctuation "]"
| EStructFieldAccess (e1, field, s) -> | EStructFieldAccess (e1, field, s) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Dcalc.Print.format_punctuation "." Format.fprintf fmt "%a%a%a%a%a" format_expr e1
Dcalc.Print.format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\""
Dcalc.Ast.StructFieldName.format_t
(fst (fst
(List.find (List.find
(fun (field', _) -> Dcalc.Ast.StructFieldName.compare field' field = 0) (fun (field', _) ->
Dcalc.Ast.StructFieldName.compare field' field = 0)
(Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs)))
Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation "\""
| EInj (e, case, enum) -> | EInj (e, case, enum) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_enum_constructor
(fst (fst
(List.find (List.find
(fun (case', _) -> Dcalc.Ast.EnumConstructor.compare case' case = 0) (fun (case', _) ->
Dcalc.Ast.EnumConstructor.compare case' case = 0)
(Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums))) (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums)))
format_expr e format_expr e
| ELit l -> Format.fprintf fmt "%a" Lcalc.Print.format_lit (Pos.same_pos_as l e) | ELit l ->
| EApp ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [ arg1; arg2 ]) -> Format.fprintf fmt "%a" Lcalc.Print.format_lit (Pos.same_pos_as l e)
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop (op, Pos.no_pos) | EApp
format_with_parens arg1 format_with_parens arg2 ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _),
[ arg1; arg2 ] ) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Dcalc.Print.format_binop
(op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
(op, Pos.no_pos) format_with_parens arg2 Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1 Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) -> | EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos) Format.fprintf fmt "@[<hov 2>%a@ %a@]" Dcalc.Print.format_unop
format_with_parens arg1 (op, Pos.no_pos) format_with_parens arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args args
| EOp (Ternop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) | EOp (Ternop op) ->
| EOp (Binop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) | EOp (Binop op) ->
Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos)
| EOp (Unop op) ->
Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos)
let rec format_statement (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) let rec format_statement
(fmt : Format.formatter) (stmt : stmt Pos.marked) : unit = (decl_ctx : Dcalc.Ast.decl_ctx)
?(debug : bool = false)
(fmt : Format.formatter)
(stmt : stmt Pos.marked) : unit =
if debug then () else (); if debug then () else ();
match Pos.unmark stmt with match Pos.unmark stmt with
| SInnerFuncDef (name, func) -> | SInnerFuncDef (name, func) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Dcalc.Print.format_keyword Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]"
"let" LocalName.format_t (Pos.unmark name) Dcalc.Print.format_keyword "let" LocalName.format_t (Pos.unmark name)
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) -> (fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation "(" Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation
LocalName.format_t name Dcalc.Print.format_punctuation ":" "(" LocalName.format_t name Dcalc.Print.format_punctuation ":"
(Dcalc.Print.format_typ decl_ctx) typ Dcalc.Print.format_punctuation ")")) (Dcalc.Print.format_typ decl_ctx)
func.func_params Dcalc.Print.format_punctuation "=" (format_block decl_ctx ~debug) typ Dcalc.Print.format_punctuation ")"))
func.func_params Dcalc.Print.format_punctuation "="
(format_block decl_ctx ~debug)
func.func_body func.func_body
| SLocalDecl (name, typ) -> | SLocalDecl (name, typ) ->
Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Dcalc.Print.format_keyword "decl" Format.fprintf fmt "@[<hov 2>%a %a %a@ %a@]" Dcalc.Print.format_keyword
LocalName.format_t (Pos.unmark name) Dcalc.Print.format_punctuation ":" "decl" LocalName.format_t (Pos.unmark name)
(Dcalc.Print.format_typ decl_ctx) typ Dcalc.Print.format_punctuation ":"
(Dcalc.Print.format_typ decl_ctx)
typ
| SLocalDef (name, expr) -> | SLocalDef (name, expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" LocalName.format_t (Pos.unmark name) Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" LocalName.format_t
Dcalc.Print.format_punctuation "=" (format_expr decl_ctx ~debug) expr (Pos.unmark name) Dcalc.Print.format_punctuation "="
(format_expr decl_ctx ~debug)
expr
| STryExcept (b_try, except, b_with) -> | STryExcept (b_try, except, b_with) ->
Format.fprintf fmt "@[<v 2>%a%a@ %a@]@\n@[<v 2>%a %a%a@ %a@]" Dcalc.Print.format_keyword "try" Format.fprintf fmt "@[<v 2>%a%a@ %a@]@\n@[<v 2>%a %a%a@ %a@]"
Dcalc.Print.format_punctuation ":" (format_block decl_ctx ~debug) b_try Dcalc.Print.format_keyword "try" Dcalc.Print.format_punctuation ":"
Dcalc.Print.format_keyword "with" Lcalc.Print.format_exception except (format_block decl_ctx ~debug)
Dcalc.Print.format_punctuation ":" (format_block decl_ctx ~debug) b_with b_try Dcalc.Print.format_keyword "with" Lcalc.Print.format_exception
except Dcalc.Print.format_punctuation ":"
(format_block decl_ctx ~debug)
b_with
| SRaise except -> | SRaise except ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "raise" Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "raise"
Lcalc.Print.format_exception except Lcalc.Print.format_exception except
| SIfThenElse (e_if, b_true, b_false) -> | SIfThenElse (e_if, b_true, b_false) ->
Format.fprintf fmt "@[<v 2>%a @[<hov 2>%a@]%a@ %a@ @]@[<v 2>%a%a@ %a@]" Format.fprintf fmt "@[<v 2>%a @[<hov 2>%a@]%a@ %a@ @]@[<v 2>%a%a@ %a@]"
Dcalc.Print.format_keyword "if" (format_expr decl_ctx ~debug) e_if Dcalc.Print.format_keyword "if"
Dcalc.Print.format_punctuation ":" (format_block decl_ctx ~debug) b_true (format_expr decl_ctx ~debug)
Dcalc.Print.format_keyword "else" Dcalc.Print.format_punctuation ":" e_if Dcalc.Print.format_punctuation ":"
(format_block decl_ctx ~debug) b_false (format_block decl_ctx ~debug)
b_true Dcalc.Print.format_keyword "else" Dcalc.Print.format_punctuation
":"
(format_block decl_ctx ~debug)
b_false
| SReturn ret -> | SReturn ret ->
Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "return" Format.fprintf fmt "@[<hov 2>%a %a@]" Dcalc.Print.format_keyword "return"
(format_expr decl_ctx ~debug) (format_expr decl_ctx ~debug)
@ -129,34 +164,48 @@ let rec format_statement (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false)
(format_expr decl_ctx ~debug) (format_expr decl_ctx ~debug)
(expr, Pos.get_position stmt) (expr, Pos.get_position stmt)
| SSwitch (e_switch, enum, arms) -> | SSwitch (e_switch, enum, arms) ->
Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@]%a" Dcalc.Print.format_keyword "switch" Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@]%a"
(format_expr decl_ctx ~debug) e_switch Dcalc.Print.format_punctuation ":" Dcalc.Print.format_keyword "switch"
(format_expr decl_ctx ~debug)
e_switch Dcalc.Print.format_punctuation ":"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt ((case, _), (arm_block, payload_name)) -> (fun fmt ((case, _), (arm_block, payload_name)) ->
Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Dcalc.Print.format_punctuation "|" Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]"
Dcalc.Print.format_enum_constructor case Dcalc.Print.format_punctuation ":" Dcalc.Print.format_punctuation "|"
LocalName.format_t payload_name Dcalc.Print.format_punctuation "" Dcalc.Print.format_enum_constructor case
(format_block decl_ctx ~debug) arm_block)) Dcalc.Print.format_punctuation ":" LocalName.format_t
payload_name Dcalc.Print.format_punctuation ""
(format_block decl_ctx ~debug)
arm_block))
(List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms) (List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms)
and format_block (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) and format_block
(decl_ctx : Dcalc.Ast.decl_ctx)
?(debug : bool = false)
(fmt : Format.formatter)
(block : block) : unit = (block : block) : unit =
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";")
(format_statement decl_ctx ~debug) (format_statement decl_ctx ~debug)
fmt block fmt block
let format_scope (decl_ctx : Dcalc.Ast.decl_ctx) ?(debug : bool = false) (fmt : Format.formatter) let format_scope
(decl_ctx : Dcalc.Ast.decl_ctx)
?(debug : bool = false)
(fmt : Format.formatter)
(body : scope_body) : unit = (body : scope_body) : unit =
if debug then () else (); if debug then () else ();
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Dcalc.Print.format_keyword "let" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]@\n@[<v 2> %a@]"
TopLevelName.format_t body.scope_body_var Dcalc.Print.format_keyword "let" TopLevelName.format_t body.scope_body_var
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt ((name, _), typ) -> (fun fmt ((name, _), typ) ->
Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation "(" LocalName.format_t Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation "("
name Dcalc.Print.format_punctuation ":" (Dcalc.Print.format_typ decl_ctx) typ LocalName.format_t name Dcalc.Print.format_punctuation ":"
Dcalc.Print.format_punctuation ")")) (Dcalc.Print.format_typ decl_ctx)
typ Dcalc.Print.format_punctuation ")"))
body.scope_body_func.func_params Dcalc.Print.format_punctuation "=" body.scope_body_func.func_params Dcalc.Print.format_punctuation "="
(format_block decl_ctx ~debug) body.scope_body_func.func_body (format_block decl_ctx ~debug)
body.scope_body_func.func_body

View File

@ -1,15 +1,22 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
val format_scope : Dcalc.Ast.decl_ctx -> ?debug:bool -> Format.formatter -> Ast.scope_body -> unit val format_scope :
Dcalc.Ast.decl_ctx ->
?debug:bool ->
Format.formatter ->
Ast.scope_body ->
unit

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
[@@@warning "-32-27"] [@@@warning "-32-27"]
@ -23,7 +25,9 @@ let format_lit (fmt : Format.formatter) (l : L.lit Pos.marked) : unit =
match Pos.unmark l with match Pos.unmark l with
| LBool true -> Format.fprintf fmt "True" | LBool true -> Format.fprintf fmt "True"
| LBool false -> Format.fprintf fmt "False" | LBool false -> Format.fprintf fmt "False"
| LInt i -> Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i) | LInt i ->
Format.fprintf fmt "integer_of_string(\"%s\")"
(Runtime.integer_to_string i)
| LUnit -> Format.fprintf fmt "Unit()" | LUnit -> Format.fprintf fmt "Unit()"
| LRat i -> | LRat i ->
Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit
@ -40,14 +44,16 @@ let format_lit (fmt : Format.formatter) (l : L.lit Pos.marked) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days
let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) : unit = let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) :
unit =
match entry with match entry with
| VarDef _ -> Format.fprintf fmt ":=" | VarDef _ -> Format.fprintf fmt ":="
| BeginCall -> Format.fprintf fmt "" | BeginCall -> Format.fprintf fmt ""
| EndCall -> Format.fprintf fmt "%s" "" | EndCall -> Format.fprintf fmt "%s" ""
| PosRecordIfTrueBool -> Format.fprintf fmt "" | PosRecordIfTrueBool -> Format.fprintf fmt ""
let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) : unit = let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) :
unit =
match Pos.unmark op with match Pos.unmark op with
| Add _ | Concat -> Format.fprintf fmt "+" | Add _ | Concat -> Format.fprintf fmt "+"
| Sub _ -> Format.fprintf fmt "-" | Sub _ -> Format.fprintf fmt "-"
@ -65,14 +71,17 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) : un
| Map -> Format.fprintf fmt "list_map" | Map -> Format.fprintf fmt "list_map"
| Filter -> Format.fprintf fmt "list_filter" | Filter -> Format.fprintf fmt "list_filter"
let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Pos.marked) : unit = let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Pos.marked) :
unit =
match Pos.unmark op with Fold -> Format.fprintf fmt "list_fold_left" match Pos.unmark op with Fold -> Format.fprintf fmt "list_fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) : unit = let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit =
Format.fprintf fmt "[%a]" Format.fprintf fmt "[%a]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt info -> Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info)) (fun fmt info ->
Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info))
uids uids
let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
@ -82,7 +91,8 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(fun fmt info -> Format.fprintf fmt "\"%s\"" info)) (fun fmt info -> Format.fprintf fmt "\"%s\"" info))
uids uids
let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit = let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit
=
match Pos.unmark op with match Pos.unmark op with
| Minus _ -> Format.fprintf fmt "-" | Minus _ -> Format.fprintf fmt "-"
| Not -> Format.fprintf fmt "not" | Not -> Format.fprintf fmt "not"
@ -96,39 +106,52 @@ let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit
let avoid_keywords (s : string) : string = let avoid_keywords (s : string) : string =
if if
match s with match s with
(* list taken from https://www.programiz.com/python-programming/keyword-list *) (* list taken from
| "False" | "None" | "True" | "and" | "as" | "assert" | "async" | "await" | "break" | "class" https://www.programiz.com/python-programming/keyword-list *)
| "continue" | "def" | "del" | "elif" | "else" | "except" | "finally" | "for" | "from" | "False" | "None" | "True" | "and" | "as" | "assert" | "async" | "await"
| "global" | "if" | "import" | "in" | "is" | "lambda" | "nonlocal" | "not" | "or" | "pass" | "break" | "class" | "continue" | "def" | "del" | "elif" | "else"
| "raise" | "return" | "try" | "while" | "with" | "yield" -> | "except" | "finally" | "for" | "from" | "global" | "if" | "import" | "in"
| "is" | "lambda" | "nonlocal" | "not" | "or" | "pass" | "raise" | "return"
| "try" | "while" | "with" | "yield" ->
true true
| _ -> false | _ -> false
then s ^ "_" then s ^ "_"
else s else s
let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : unit = let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) :
unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (avoid_keywords
(to_uppercase (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v)))) (to_uppercase
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v))))
let format_struct_field_name (fmt : Format.formatter) (v : Dcalc.Ast.StructFieldName.t) : unit = let format_struct_field_name
(fmt : Format.formatter) (v : Dcalc.Ast.StructFieldName.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v)))
let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit = let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit
=
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (to_uppercase (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) (avoid_keywords
(to_uppercase
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v))))
let format_enum_cons_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructor.t) : unit = let format_enum_cons_name
(fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructor.t) : unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) (avoid_keywords
(to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v)))
let typ_needs_parens (e : Dcalc.Ast.typ Pos.marked) : bool = let typ_needs_parens (e : Dcalc.Ast.typ Pos.marked) : bool =
match Pos.unmark e with TArrow _ | TArray _ -> true | _ -> false match Pos.unmark e with TArrow _ | TArray _ -> true | _ -> false
let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : unit = let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) :
unit =
let format_typ = format_typ in let format_typ = format_typ in
let format_typ_with_parens (fmt : Format.formatter) (t : Dcalc.Ast.typ Pos.marked) = let format_typ_with_parens
(fmt : Format.formatter) (t : Dcalc.Ast.typ Pos.marked) =
if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
@ -152,14 +175,17 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : u
Format.fprintf fmt "Optional[%a]" format_typ some_typ Format.fprintf fmt "Optional[%a]" format_typ some_typ
| TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e | TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e
| TArrow (t1, t2) -> | TArrow (t1, t2) ->
Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 format_typ_with_parens t2 Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1
format_typ_with_parens t2
| TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1
| TAny -> Format.fprintf fmt "Any" | TAny -> Format.fprintf fmt "Any"
let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = let format_name_cleaned (fmt : Format.formatter) (s : string) : unit =
let lowercase_name = to_lowercase (to_ascii s) in let lowercase_name = to_lowercase (to_ascii s) in
let lowercase_name = let lowercase_name =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") lowercase_name Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.")
~subst:(fun _ -> "_dot_")
lowercase_name
in in
let lowercase_name = avoid_keywords (to_ascii lowercase_name) in let lowercase_name = avoid_keywords (to_ascii lowercase_name) in
Format.fprintf fmt "%s" lowercase_name Format.fprintf fmt "%s" lowercase_name
@ -174,9 +200,12 @@ let format_toplevel_name (fmt : Format.formatter) (v : TopLevelName.t) : unit =
format_name_cleaned fmt v_str format_name_cleaned fmt v_str
let needs_parens (e : expr Pos.marked) : bool = let needs_parens (e : expr Pos.marked) : bool =
match Pos.unmark e with ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false | _ -> true match Pos.unmark e with
| ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false
| _ -> true
let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit = let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit
=
match Pos.unmark exc with match Pos.unmark exc with
| ConflictError -> Format.fprintf fmt "ConflictError" | ConflictError -> Format.fprintf fmt "ConflictError"
| EmptyError -> Format.fprintf fmt "EmptyError" | EmptyError -> Format.fprintf fmt "EmptyError"
@ -184,13 +213,16 @@ let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit
| NoValueProvided -> | NoValueProvided ->
let pos = Pos.get_position exc in let pos = Pos.get_position exc in
Format.fprintf fmt Format.fprintf fmt
"NoValueProvided(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ start_line=%d,@ \ "NoValueProvided(@[<hov 0>SourcePosition(@[<hov 0>filename=\"%s\",@ \
start_column=%d,@ end_line=%d,@ end_column=%d,@ law_headings=%a)@])@]" start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \
law_headings=%a)@])@]"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos)
let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) let rec format_expression
: unit = (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) :
unit =
match Pos.unmark e with match Pos.unmark e with
| EVar v -> format_var fmt v | EVar v -> format_var fmt v
| EFunc f -> format_toplevel_name fmt f | EFunc f -> format_toplevel_name fmt f
@ -201,9 +233,11 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
(fun fmt (e, struct_field) -> (fun fmt (e, struct_field) ->
Format.fprintf fmt "%a = %a" format_struct_field_name struct_field Format.fprintf fmt "%a = %a" format_struct_field_name struct_field
(format_expression ctx) e)) (format_expression ctx) e))
(List.combine es (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) (List.combine es
(List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs)))
| EStructFieldAccess (e1, field, _) -> | EStructFieldAccess (e1, field, _) ->
Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field Format.fprintf fmt "%a.%a" (format_expression ctx) e1
format_struct_field_name field
| EInj (_, cons, e_name) | EInj (_, cons, e_name)
when D.EnumName.compare e_name L.option_enum = 0 when D.EnumName.compare e_name L.option_enum = 0
&& D.EnumConstructor.compare cons L.none_constr = 0 -> && D.EnumConstructor.compare cons L.none_constr = 0 ->
@ -215,8 +249,9 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
(* We translate the option type with an overloading by Python's [None] *) (* We translate the option type with an overloading by Python's [None] *)
format_expression ctx fmt e format_expression ctx fmt e
| EInj (e, cons, enum_name) -> | EInj (e, cons, enum_name) ->
Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name format_enum_name enum_name Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name
format_enum_cons_name cons (format_expression ctx) e format_enum_name enum_name format_enum_cons_name cons
(format_expression ctx) e
| EArray es -> | EArray es ->
Format.fprintf fmt "[%a]" Format.fprintf fmt "[%a]"
(Format.pp_print_list (Format.pp_print_list
@ -224,34 +259,43 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e)) (fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
es es
| ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e) | ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e)
| EApp ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [ arg1; arg2 ]) -> | EApp
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) (format_expression ctx) arg1 ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _),
(format_expression ctx) arg2 [ arg1; arg2 ] ) ->
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop (op, Pos.no_pos) Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop
(format_expression ctx) arg2 (op, Pos.no_pos) (format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ]) | EApp
((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ])
when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
(format_expression ctx) f (format_expression ctx) arg
| EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ])
when !Cli.trace_flag -> when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info (format_expression ctx)
f (format_expression ctx) arg
| EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ]) when !Cli.trace_flag ->
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ]) when !Cli.trace_flag -> | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ])
when !Cli.trace_flag ->
Format.fprintf fmt Format.fprintf fmt
"log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ start_column=%d,@ \ "log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \
end_line=%d, end_column=%d,@ law_headings=%a), %a)" start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)"
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ])
when !Cli.trace_flag ->
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1 (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ]) when !Cli.trace_flag ->
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info (format_expression ctx) arg1
| EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) -> | EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) ->
Format.fprintf fmt "%a" (format_expression ctx) arg1 Format.fprintf fmt "%a" (format_expression ctx) arg1
| EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [ arg1 ]) -> | EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [ arg1 ]) ->
Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos) (format_expression ctx) arg1 Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EOp (Unop op), _), [ arg1 ]) -> | EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos) (format_expression ctx) arg1 Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f Format.fprintf fmt "%a(@[<hov 0>%a)@]" (format_expression ctx) f
(Format.pp_print_list (Format.pp_print_list
@ -262,60 +306,85 @@ let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
let rec format_statement (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s : stmt Pos.marked) : let rec format_statement
(ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s : stmt Pos.marked) :
unit = unit =
match Pos.unmark s with match Pos.unmark s with
| SInnerFuncDef (name, { func_params; func_body }) -> | SInnerFuncDef (name, { func_params; func_body }) ->
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var (Pos.unmark name) Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_var
(Pos.unmark name)
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (var, typ) -> (fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ typ)) Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ
typ))
func_params (format_block ctx) func_body func_params (format_block ctx) func_body
| SLocalDecl _ -> assert false (* We don't need to declare variables in Python *) | SLocalDecl _ ->
assert false (* We don't need to declare variables in Python *)
| SLocalDef (v, e) -> | SLocalDef (v, e) ->
Format.fprintf fmt "@[<hov 4>%a = %a@]" format_var (Pos.unmark v) (format_expression ctx) e Format.fprintf fmt "@[<hov 4>%a = %a@]" format_var (Pos.unmark v)
(format_expression ctx) e
| STryExcept (try_b, except, catch_b) -> | STryExcept (try_b, except, catch_b) ->
Format.fprintf fmt "@[<hov 4>try:@\n%a@]@\n@[<hov 4>except %a:@\n%a@]" (format_block ctx) Format.fprintf fmt "@[<hov 4>try:@\n%a@]@\n@[<hov 4>except %a:@\n%a@]"
try_b format_exception (except, Pos.no_pos) (format_block ctx) catch_b (format_block ctx) try_b format_exception (except, Pos.no_pos)
(format_block ctx) catch_b
| SRaise except -> | SRaise except ->
Format.fprintf fmt "@[<hov 4>raise %a@]" format_exception (except, Pos.get_position s) Format.fprintf fmt "@[<hov 4>raise %a@]" format_exception
(except, Pos.get_position s)
| SIfThenElse (cond, b1, b2) -> | SIfThenElse (cond, b1, b2) ->
Format.fprintf fmt "@[<hov 4>if %a:@\n%a@]@\n@[<hov 4>else:@\n%a@]" (format_expression ctx) Format.fprintf fmt "@[<hov 4>if %a:@\n%a@]@\n@[<hov 4>else:@\n%a@]"
cond (format_block ctx) b1 (format_block ctx) b2 (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2
| SSwitch (e1, e_name, [ (case_none, _); (case_some, case_some_var) ]) | SSwitch (e1, e_name, [ (case_none, _); (case_some, case_some_var) ])
when D.EnumName.compare e_name L.option_enum = 0 -> when D.EnumName.compare e_name L.option_enum = 0 ->
(* We translate the option type with an overloading by Python's [None] *) (* We translate the option type with an overloading by Python's [None] *)
let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in
Format.fprintf fmt Format.fprintf fmt
"%a = %a@\n@[<hov 4>if %a is None:@\n%a@]@\n@[<hov 4>else:@\n%a = %a@\n%a@]" format_var "%a = %a@\n\
tmp_var (format_expression ctx) e1 format_var tmp_var (format_block ctx) case_none @[<hov 4>if %a is None:@\n\
format_var case_some_var format_var tmp_var (format_block ctx) case_some %a@]@\n\
@[<hov 4>else:@\n\
%a = %a@\n\
%a@]"
format_var tmp_var (format_expression ctx) e1 format_var tmp_var
(format_block ctx) case_none format_var case_some_var format_var tmp_var
(format_block ctx) case_some
| SSwitch (e1, e_name, cases) -> | SSwitch (e1, e_name, cases) ->
let cases = let cases =
List.map2 (fun (x, y) (cons, _) -> (x, y, cons)) cases (D.EnumMap.find e_name ctx.ctx_enums) List.map2
(fun (x, y) (cons, _) -> (x, y, cons))
cases
(D.EnumMap.find e_name ctx.ctx_enums)
in in
let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in
Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var (format_expression ctx) e1 Format.fprintf fmt "%a = %a@\n@[<hov 4>if %a@]" format_var tmp_var
(format_expression ctx) e1
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[<hov 4>elif ")
(fun fmt (case_block, payload_var, cons_name) -> (fun fmt (case_block, payload_var, cons_name) ->
Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" format_var tmp_var Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a"
format_enum_name e_name format_enum_cons_name cons_name format_var payload_var format_var tmp_var format_enum_name e_name format_enum_cons_name
format_var tmp_var (format_block ctx) case_block)) cons_name format_var payload_var format_var tmp_var
(format_block ctx) case_block))
cases cases
| SReturn e1 -> | SReturn e1 ->
Format.fprintf fmt "@[<hov 4>return %a@]" (format_expression ctx) (e1, Pos.get_position s) Format.fprintf fmt "@[<hov 4>return %a@]" (format_expression ctx)
(e1, Pos.get_position s)
| SAssert e1 -> | SAssert e1 ->
Format.fprintf fmt "@[<hov 4>assert %a@]" (format_expression ctx) (e1, Pos.get_position s) Format.fprintf fmt "@[<hov 4>assert %a@]" (format_expression ctx)
(e1, Pos.get_position s)
and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block) : unit = and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block)
: unit =
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(format_statement ctx) fmt (format_statement ctx) fmt
(List.filter (fun s -> match Pos.unmark s with SLocalDecl _ -> false | _ -> true) b) (List.filter
(fun s -> match Pos.unmark s with SLocalDecl _ -> false | _ -> true)
b)
let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(fmt : Format.formatter)
(ctx : D.decl_ctx) : unit = (ctx : D.decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) = let format_struct_decl fmt (struct_name, struct_fields) =
Format.fprintf fmt Format.fprintf fmt
@ -333,27 +402,29 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
\t\treturn not (self == other)@\n\ \t\treturn not (self == other)@\n\
@\n\ @\n\
\tdef __str__(self) -> str:@\n\ \tdef __str__(self) -> str:@\n\
\t\t@[<hov 4>return \"%a(%a)\".format(%a)@]" format_struct_name struct_name \t\t@[<hov 4>return \"%a(%a)\".format(%a)@]" format_struct_name
struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun _fmt (struct_field, struct_field_type) -> (fun _fmt (struct_field, struct_field_type) ->
Format.fprintf fmt "%a: %a" format_struct_field_name struct_field format_typ Format.fprintf fmt "%a: %a" format_struct_field_name struct_field
struct_field_type)) format_typ struct_field_type))
struct_fields struct_fields
(if List.length struct_fields = 0 then fun fmt _ -> Format.fprintf fmt "\t\tpass" (if List.length struct_fields = 0 then fun fmt _ ->
Format.fprintf fmt "\t\tpass"
else else
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (struct_field, _) -> (fun _fmt (struct_field, _) ->
Format.fprintf fmt "\t\tself.%a = %a" format_struct_field_name struct_field Format.fprintf fmt "\t\tself.%a = %a" format_struct_field_name
format_struct_field_name struct_field)) struct_field format_struct_field_name struct_field))
struct_fields format_struct_name struct_name struct_fields format_struct_name struct_name
(if List.length struct_fields > 0 then (if List.length struct_fields > 0 then
Format.pp_print_list Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ")
(fun _fmt (struct_field, _) -> (fun _fmt (struct_field, _) ->
Format.fprintf fmt "self.%a == other.%a" format_struct_field_name struct_field Format.fprintf fmt "self.%a == other.%a" format_struct_field_name
format_struct_field_name struct_field) struct_field format_struct_field_name struct_field)
else fun fmt _ -> Format.fprintf fmt "True") else fun fmt _ -> Format.fprintf fmt "True")
struct_fields format_struct_name struct_name struct_fields format_struct_name struct_name
(Format.pp_print_list (Format.pp_print_list
@ -391,13 +462,15 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
\t\treturn not (self == other)@\n\ \t\treturn not (self == other)@\n\
@\n\ @\n\
\tdef __str__(self) -> str:@\n\ \tdef __str__(self) -> str:@\n\
\t\t@[<hov 4>return \"{}({})\".format(self.code, self.value)@]" format_enum_name enum_name \t\t@[<hov 4>return \"{}({})\".format(self.code, self.value)@]"
format_enum_name enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun _fmt (i, enum_cons, enum_cons_type) -> (fun _fmt (i, enum_cons, enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i)) Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi (fun i (x, y) -> (i, x, y)) enum_cons) (List.mapi (fun i (x, y) -> (i, x, y)) enum_cons)
format_enum_name enum_name format_enum_name enum_name format_enum_name enum_name format_enum_name enum_name format_enum_name enum_name format_enum_name
enum_name
in in
let is_in_type_ordering s = let is_in_type_ordering s =
@ -412,7 +485,9 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
List.map List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(Dcalc.Ast.StructMap.bindings (Dcalc.Ast.StructMap.bindings
(Dcalc.Ast.StructMap.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) (Dcalc.Ast.StructMap.filter
(fun s _ -> not (is_in_type_ordering s))
ctx.ctx_structs))
in in
List.iter List.iter
(fun struct_or_enum -> (fun struct_or_enum ->
@ -425,10 +500,13 @@ let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Form
(e, Dcalc.Ast.EnumMap.find e ctx.Dcalc.Ast.ctx_enums)) (e, Dcalc.Ast.EnumMap.find e ctx.Dcalc.Ast.ctx_enums))
(type_ordering @ scope_structs) (type_ordering @ scope_structs)
let format_program (fmt : Format.formatter) (p : Ast.program) let format_program
(fmt : Format.formatter)
(p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit = (type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(* We disable the style flag in order to enjoy formatting from the pretty-printers of Dcalc and (* We disable the style flag in order to enjoy formatting from the
Lcalc but without the color terminal markers. *) pretty-printers of Dcalc and Lcalc but without the color terminal
markers. *)
Cli.style_flag := false; Cli.style_flag := false;
Format.fprintf fmt Format.fprintf fmt
"# This file has been generated by the Catala compiler, do not edit!\n\ "# This file has been generated by the Catala compiler, do not edit!\n\
@ -445,10 +523,12 @@ let format_program (fmt : Format.formatter) (p : Ast.program)
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n")
(fun fmt body -> (fun fmt body ->
let { Ast.func_params; Ast.func_body } = body.scope_body_func in let { Ast.func_params; Ast.func_body } = body.scope_body_func in
Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_toplevel_name body.scope_body_var Format.fprintf fmt "@[<hov 4>def %a(%a):@\n%a@]" format_toplevel_name
body.scope_body_var
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (var, typ) -> (fun fmt (var, typ) ->
Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ typ)) Format.fprintf fmt "%a:%a" format_var (Pos.unmark var)
format_typ typ))
func_params (format_block p.decl_ctx) func_body)) func_params (format_block p.decl_ctx) func_body))
p.scopes p.scopes

View File

@ -1,18 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2021 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2021 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Formats a lambda calculus program into a valid Python program *) (** Formats a lambda calculus program into a valid Python program *)
val format_program : Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit val format_program :
Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit
(** Usage [format_program fmt p type_dependencies_ordering] *) (** Usage [format_program fmt p type_dependencies_ordering] *)

View File

@ -1,63 +1,69 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
module ScopeName = Dcalc.Ast.ScopeName module ScopeName = Dcalc.Ast.ScopeName
module ScopeNameSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName) module ScopeNameSet : Set.S with type elt = ScopeName.t = Set.Make (ScopeName)
module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName) module ScopeMap : Map.S with type key = ScopeName.t = Map.Make (ScopeName)
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module SubScopeName : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module SubScopeNameSet : Set.S with type elt = SubScopeName.t = Set.Make (SubScopeName) module SubScopeNameSet : Set.S with type elt = SubScopeName.t =
Set.Make (SubScopeName)
module SubScopeMap : Map.S with type key = SubScopeName.t = Map.Make (SubScopeName) module SubScopeMap : Map.S with type key = SubScopeName.t =
Map.Make (SubScopeName)
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info = Uid.Make (Uid.MarkedString) () module ScopeVar : Uid.Id with type info = Uid.MarkedString.info =
Uid.Make (Uid.MarkedString) ()
module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar) module ScopeVarSet : Set.S with type elt = ScopeVar.t = Set.Make (ScopeVar)
module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar) module ScopeVarMap : Map.S with type key = ScopeVar.t = Map.Make (ScopeVar)
module StructName = Dcalc.Ast.StructName module StructName = Dcalc.Ast.StructName
module StructMap = Dcalc.Ast.StructMap module StructMap = Dcalc.Ast.StructMap
module StructFieldName = Dcalc.Ast.StructFieldName module StructFieldName = Dcalc.Ast.StructFieldName
module StructFieldMap : Map.S with type key = StructFieldName.t = Map.Make (StructFieldName) module StructFieldMap : Map.S with type key = StructFieldName.t =
Map.Make (StructFieldName)
module StructFieldMapLift = Bindlib.Lift (StructFieldMap) module StructFieldMapLift = Bindlib.Lift (StructFieldMap)
module EnumName = Dcalc.Ast.EnumName module EnumName = Dcalc.Ast.EnumName
module EnumMap = Dcalc.Ast.EnumMap module EnumMap = Dcalc.Ast.EnumMap
module EnumConstructor = Dcalc.Ast.EnumConstructor module EnumConstructor = Dcalc.Ast.EnumConstructor
module EnumConstructorMap : Map.S with type key = EnumConstructor.t = Map.Make (EnumConstructor) module EnumConstructorMap : Map.S with type key = EnumConstructor.t =
Map.Make (EnumConstructor)
module EnumConstructorMapLift = Bindlib.Lift (EnumConstructorMap) module EnumConstructorMapLift = Bindlib.Lift (EnumConstructorMap)
type location = type location =
| ScopeVar of ScopeVar.t Pos.marked | ScopeVar of ScopeVar.t Pos.marked
| SubScopeVar of ScopeName.t * SubScopeName.t Pos.marked * ScopeVar.t Pos.marked | SubScopeVar of
ScopeName.t * SubScopeName.t Pos.marked * ScopeVar.t Pos.marked
module LocationSet : Set.S with type elt = location Pos.marked = Set.Make (struct module LocationSet : Set.S with type elt = location Pos.marked =
Set.Make (struct
type t = location Pos.marked type t = location Pos.marked
let compare x y = let compare x y =
match (Pos.unmark x, Pos.unmark y) with match (Pos.unmark x, Pos.unmark y) with
| ScopeVar (vx, _), ScopeVar (vy, _) -> ScopeVar.compare vx vy | ScopeVar (vx, _), ScopeVar (vy, _) -> ScopeVar.compare vx vy
| SubScopeVar (_, (xsubindex, _), (xsubvar, _)), SubScopeVar (_, (ysubindex, _), (ysubvar, _)) | ( SubScopeVar (_, (xsubindex, _), (xsubvar, _)),
-> SubScopeVar (_, (ysubindex, _), (ysubvar, _)) ) ->
let c = SubScopeName.compare xsubindex ysubindex in let c = SubScopeName.compare xsubindex ysubindex in
if c = 0 then ScopeVar.compare xsubvar ysubvar else c if c = 0 then ScopeVar.compare xsubvar ysubvar else c
| ScopeVar _, SubScopeVar _ -> -1 | ScopeVar _, SubScopeVar _ -> -1
@ -78,9 +84,11 @@ type expr =
| EStruct of StructName.t * expr Pos.marked StructFieldMap.t | EStruct of StructName.t * expr Pos.marked StructFieldMap.t
| EStructAccess of expr Pos.marked * StructFieldName.t * StructName.t | EStructAccess of expr Pos.marked * StructFieldName.t * StructName.t
| EEnumInj of expr Pos.marked * EnumConstructor.t * EnumName.t | EEnumInj of expr Pos.marked * EnumConstructor.t * EnumName.t
| EMatch of expr Pos.marked * EnumName.t * expr Pos.marked EnumConstructorMap.t | EMatch of
expr Pos.marked * EnumName.t * expr Pos.marked EnumConstructorMap.t
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked * typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EOp of Dcalc.Ast.operator | EOp of Dcalc.Ast.operator
| EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked | EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked
@ -118,11 +126,12 @@ let rec locations_used (e : expr Pos.marked) : LocationSet.t =
(LocationSet.union (locations_used just) (locations_used cons)) (LocationSet.union (locations_used just) (locations_used cons))
excepts excepts
| EArray es -> | EArray es ->
List.fold_left (fun acc e' -> LocationSet.union acc (locations_used e')) LocationSet.empty es List.fold_left
(fun acc e' -> LocationSet.union acc (locations_used e'))
LocationSet.empty es
| ErrorOnEmpty e' -> locations_used e' | ErrorOnEmpty e' -> locations_used e'
type io_input = NoInput | OnlyInput | Reentrant type io_input = NoInput | OnlyInput | Reentrant
type io = { io_output : bool Pos.marked; io_input : io_input Pos.marked } type io = { io_output : bool Pos.marked; io_input : io_input Pos.marked }
type rule = type rule =
@ -137,7 +146,6 @@ type scope_decl = {
} }
type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t
type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t
type program = { type program = {
@ -162,15 +170,26 @@ type vars = expr Bindlib.mvar
let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box =
Bindlib.box_apply (fun v -> (v, pos)) (Bindlib.box_var x) Bindlib.box_apply (fun v -> (v, pos)) (Bindlib.box_var x)
let make_abs (xs : vars) (e : expr Pos.marked Bindlib.box) (pos_binder : Pos.t) let make_abs
(taus : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = (xs : vars)
Bindlib.box_apply (fun b -> (EAbs ((b, pos_binder), taus), pos)) (Bindlib.bind_mvar xs e) (e : expr Pos.marked Bindlib.box)
(pos_binder : Pos.t)
(taus : typ Pos.marked list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply
(fun b -> (EAbs ((b, pos_binder), taus), pos))
(Bindlib.bind_mvar xs e)
let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) let make_app
: expr Pos.marked Bindlib.box = (e : expr Pos.marked Bindlib.box)
(u : expr Pos.marked Bindlib.box list)
(pos : Pos.t) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u)
let make_let_in (x : Var.t) (tau : typ Pos.marked) (e1 : expr Pos.marked Bindlib.box) let make_let_in
(x : Var.t)
(tau : typ Pos.marked)
(e1 : expr Pos.marked Bindlib.box)
(e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = (e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box =
Bindlib.box_apply2 Bindlib.box_apply2
(fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2))) (fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2)))

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Abstract syntax tree of the scope language *) (** Abstract syntax tree of the scope language *)
@ -19,46 +21,38 @@ open Utils
(** {1 Identifiers} *) (** {1 Identifiers} *)
module ScopeName = Dcalc.Ast.ScopeName module ScopeName = Dcalc.Ast.ScopeName
module ScopeNameSet : Set.S with type elt = ScopeName.t module ScopeNameSet : Set.S with type elt = ScopeName.t
module ScopeMap : Map.S with type key = ScopeName.t module ScopeMap : Map.S with type key = ScopeName.t
module SubScopeName : Uid.Id with type info = Uid.MarkedString.info module SubScopeName : Uid.Id with type info = Uid.MarkedString.info
module SubScopeNameSet : Set.S with type elt = SubScopeName.t module SubScopeNameSet : Set.S with type elt = SubScopeName.t
module SubScopeMap : Map.S with type key = SubScopeName.t module SubScopeMap : Map.S with type key = SubScopeName.t
module ScopeVar : Uid.Id with type info = Uid.MarkedString.info module ScopeVar : Uid.Id with type info = Uid.MarkedString.info
module ScopeVarSet : Set.S with type elt = ScopeVar.t module ScopeVarSet : Set.S with type elt = ScopeVar.t
module ScopeVarMap : Map.S with type key = ScopeVar.t module ScopeVarMap : Map.S with type key = ScopeVar.t
module StructName = Dcalc.Ast.StructName module StructName = Dcalc.Ast.StructName
module StructMap = Dcalc.Ast.StructMap module StructMap = Dcalc.Ast.StructMap
module StructFieldName = Dcalc.Ast.StructFieldName module StructFieldName = Dcalc.Ast.StructFieldName
module StructFieldMap : Map.S with type key = StructFieldName.t module StructFieldMap : Map.S with type key = StructFieldName.t
module StructFieldMapLift : sig module StructFieldMapLift : sig
val lift_box : 'a Bindlib.box StructFieldMap.t -> 'a StructFieldMap.t Bindlib.box val lift_box :
'a Bindlib.box StructFieldMap.t -> 'a StructFieldMap.t Bindlib.box
end end
module EnumName = Dcalc.Ast.EnumName module EnumName = Dcalc.Ast.EnumName
module EnumMap = Dcalc.Ast.EnumMap module EnumMap = Dcalc.Ast.EnumMap
module EnumConstructor = Dcalc.Ast.EnumConstructor module EnumConstructor = Dcalc.Ast.EnumConstructor
module EnumConstructorMap : Map.S with type key = EnumConstructor.t module EnumConstructorMap : Map.S with type key = EnumConstructor.t
module EnumConstructorMapLift : sig module EnumConstructorMapLift : sig
val lift_box : 'a Bindlib.box EnumConstructorMap.t -> 'a EnumConstructorMap.t Bindlib.box val lift_box :
'a Bindlib.box EnumConstructorMap.t -> 'a EnumConstructorMap.t Bindlib.box
end end
type location = type location =
| ScopeVar of ScopeVar.t Pos.marked | ScopeVar of ScopeVar.t Pos.marked
| SubScopeVar of ScopeName.t * SubScopeName.t Pos.marked * ScopeVar.t Pos.marked | SubScopeVar of
ScopeName.t * SubScopeName.t Pos.marked * ScopeVar.t Pos.marked
module LocationSet : Set.S with type elt = location Pos.marked module LocationSet : Set.S with type elt = location Pos.marked
@ -72,17 +66,19 @@ type typ =
| TArray of typ | TArray of typ
| TAny | TAny
(** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib} library, based on (** The expressions use the {{:https://lepigre.fr/ocaml-bindlib/} Bindlib}
higher-order abstract syntax*) library, based on higher-order abstract syntax*)
type expr = type expr =
| ELocation of location | ELocation of location
| EVar of expr Bindlib.var Pos.marked | EVar of expr Bindlib.var Pos.marked
| EStruct of StructName.t * expr Pos.marked StructFieldMap.t | EStruct of StructName.t * expr Pos.marked StructFieldMap.t
| EStructAccess of expr Pos.marked * StructFieldName.t * StructName.t | EStructAccess of expr Pos.marked * StructFieldName.t * StructName.t
| EEnumInj of expr Pos.marked * EnumConstructor.t * EnumName.t | EEnumInj of expr Pos.marked * EnumConstructor.t * EnumName.t
| EMatch of expr Pos.marked * EnumName.t * expr Pos.marked EnumConstructorMap.t | EMatch of
expr Pos.marked * EnumName.t * expr Pos.marked EnumConstructorMap.t
| ELit of Dcalc.Ast.lit | ELit of Dcalc.Ast.lit
| EAbs of (expr, expr Pos.marked) Bindlib.mbinder Pos.marked * typ Pos.marked list | EAbs of
(expr, expr Pos.marked) Bindlib.mbinder Pos.marked * typ Pos.marked list
| EApp of expr Pos.marked * expr Pos.marked list | EApp of expr Pos.marked * expr Pos.marked list
| EOp of Dcalc.Ast.operator | EOp of Dcalc.Ast.operator
| EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked | EDefault of expr Pos.marked list * expr Pos.marked * expr Pos.marked
@ -92,19 +88,23 @@ type expr =
val locations_used : expr Pos.marked -> LocationSet.t val locations_used : expr Pos.marked -> LocationSet.t
(** This type characterizes the three levels of visibility for a given scope variable with regards (** This type characterizes the three levels of visibility for a given scope
to the scope's input and possible redefinitions inside the scope.. *) variable with regards to the scope's input and possible redefinitions inside
the scope.. *)
type io_input = type io_input =
| NoInput | NoInput
(** For an internal variable defined only in the scope, and does not appear in the input. *) (** For an internal variable defined only in the scope, and does not
appear in the input. *)
| OnlyInput | OnlyInput
(** For variables that should not be redefined in the scope, because they appear in the input. *) (** For variables that should not be redefined in the scope, because they
appear in the input. *)
| Reentrant | Reentrant
(** For variables defined in the scope that can also be redefined by the caller as they appear (** For variables defined in the scope that can also be redefined by the
in the input. *) caller as they appear in the input. *)
type io = { type io = {
io_output : bool Pos.marked; (** [true] is present in the output of the scope. *) io_output : bool Pos.marked;
(** [true] is present in the output of the scope. *)
io_input : io_input Pos.marked; io_input : io_input Pos.marked;
} }
(** Characterization of the input/output status of a scope variable. *) (** Characterization of the input/output status of a scope variable. *)
@ -121,7 +121,6 @@ type scope_decl = {
} }
type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t type struct_ctx = (StructFieldName.t * typ Pos.marked) list StructMap.t
type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t type enum_ctx = (EnumConstructor.t * typ Pos.marked) list EnumMap.t
type program = { type program = {
@ -136,7 +135,6 @@ module Var : sig
type t = expr Bindlib.var type t = expr Bindlib.var
val make : string Pos.marked -> t val make : string Pos.marked -> t
val compare : t -> t -> int val compare : t -> t -> int
end end

View File

@ -1,19 +1,21 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Graph representation of the dependencies between scopes in the Catala program. Vertices are (** Graph representation of the dependencies between scopes in the Catala
functions, x -> y if x is used in the definition of y. *) program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils open Utils
@ -21,22 +23,22 @@ module SVertex = struct
type t = Ast.ScopeName.t type t = Ast.ScopeName.t
let hash x = Ast.ScopeName.hash x let hash x = Ast.ScopeName.hash x
let compare = Ast.ScopeName.compare let compare = Ast.ScopeName.compare
let equal x y = Ast.ScopeName.compare x y = 0 let equal x y = Ast.ScopeName.compare x y = 0
end end
(** On the edges, the label is the expression responsible for the use of the function *) (** On the edges, the label is the expression responsible for the use of the
function *)
module SEdge = struct module SEdge = struct
type t = Pos.t type t = Pos.t
let compare = compare let compare = compare
let default = Pos.no_pos let default = Pos.no_pos
end end
module SDependencies = Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (SVertex) (SEdge) module SDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (SVertex) (SEdge)
module STopologicalTraversal = Graph.Topological.Make (SDependencies) module STopologicalTraversal = Graph.Topological.Make (SDependencies)
module SSCC = Graph.Components.Make (SDependencies) module SSCC = Graph.Components.Make (SDependencies)
@ -44,7 +46,11 @@ module SSCC = Graph.Components.Make (SDependencies)
let build_program_dep_graph (prgm : Ast.program) : SDependencies.t = let build_program_dep_graph (prgm : Ast.program) : SDependencies.t =
let g = SDependencies.empty in let g = SDependencies.empty in
let g = Ast.ScopeMap.fold (fun v _ g -> SDependencies.add_vertex g v) prgm.program_scopes g in let g =
Ast.ScopeMap.fold
(fun v _ g -> SDependencies.add_vertex g v)
prgm.program_scopes g
in
Ast.ScopeMap.fold Ast.ScopeMap.fold
(fun scope_name scope g -> (fun scope_name scope g ->
let subscopes = let subscopes =
@ -55,9 +61,10 @@ let build_program_dep_graph (prgm : Ast.program) : SDependencies.t =
| Ast.Call (subscope, subindex) -> | Ast.Call (subscope, subindex) ->
if subscope = scope_name then if subscope = scope_name then
Errors.raise_spanned_error Errors.raise_spanned_error
(Pos.get_position (Ast.ScopeName.get_info scope.Ast.scope_decl_name)) (Pos.get_position
"The scope %a is calling into itself as a subscope, which is forbidden since \ (Ast.ScopeName.get_info scope.Ast.scope_decl_name))
Catala does not provide recursion" "The scope %a is calling into itself as a subscope, which \
is forbidden since Catala does not provide recursion"
Ast.ScopeName.format_t scope.Ast.scope_decl_name Ast.ScopeName.format_t scope.Ast.scope_decl_name
else else
Ast.ScopeMap.add subscope Ast.ScopeMap.add subscope
@ -73,7 +80,8 @@ let build_program_dep_graph (prgm : Ast.program) : SDependencies.t =
prgm.program_scopes g prgm.program_scopes g
let check_for_cycle_in_scope (g : SDependencies.t) : unit = let check_for_cycle_in_scope (g : SDependencies.t) : unit =
(* if there is a cycle, there will be an strongly connected component of cardinality > 1 *) (* if there is a cycle, there will be an strongly connected component of
cardinality > 1 *)
let sccs = SSCC.scc_list g in let sccs = SSCC.scc_list g in
if List.length sccs < SDependencies.nb_vertex g then if List.length sccs < SDependencies.nb_vertex g then
let scc = List.find (fun scc -> List.length scc > 1) sccs in let scc = List.find (fun scc -> List.length scc > 1) sccs in
@ -82,19 +90,26 @@ let check_for_cycle_in_scope (g : SDependencies.t) : unit =
(List.map (List.map
(fun v -> (fun v ->
let var_str, var_info = let var_str, var_info =
(Format.asprintf "%a" Ast.ScopeName.format_t v, Ast.ScopeName.get_info v) ( Format.asprintf "%a" Ast.ScopeName.format_t v,
Ast.ScopeName.get_info v )
in in
let succs = SDependencies.succ_e g v in let succs = SDependencies.succ_e g v in
let _, edge_pos, succ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
let succ_str = Format.asprintf "%a" Ast.ScopeName.format_t succ in let succ_str = Format.asprintf "%a" Ast.ScopeName.format_t succ in
[ [
(Some ("Cycle variable " ^ var_str ^ ", declared:"), Pos.get_position var_info); ( Some ("Cycle variable " ^ var_str ^ ", declared:"),
( Some ("Used here in the definition of another cycle variable " ^ succ_str ^ ":"), Pos.get_position var_info );
( Some
("Used here in the definition of another cycle variable "
^ succ_str ^ ":"),
edge_pos ); edge_pos );
]) ])
scc) scc)
in in
Errors.raise_multispanned_error spans "Cyclic dependency detected between scopes!" Errors.raise_multispanned_error spans
"Cyclic dependency detected between scopes!"
let get_scope_ordering (g : SDependencies.t) : Ast.ScopeName.t list = let get_scope_ordering (g : SDependencies.t) : Ast.ScopeName.t list =
List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g []) List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g [])
@ -102,7 +117,10 @@ let get_scope_ordering (g : SDependencies.t) : Ast.ScopeName.t list =
module TVertex = struct module TVertex = struct
type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t
let hash x = match x with Struct x -> Ast.StructName.hash x | Enum x -> Ast.EnumName.hash x let hash x =
match x with
| Struct x -> Ast.StructName.hash x
| Enum x -> Ast.EnumName.hash x
let compare x y = let compare x y =
match (x, y) with match (x, y) with
@ -118,24 +136,30 @@ module TVertex = struct
| _ -> false | _ -> false
let format_t (fmt : Format.formatter) (x : t) : unit = let format_t (fmt : Format.formatter) (x : t) : unit =
match x with Struct x -> Ast.StructName.format_t fmt x | Enum x -> Ast.EnumName.format_t fmt x match x with
| Struct x -> Ast.StructName.format_t fmt x
| Enum x -> Ast.EnumName.format_t fmt x
let get_info (x : t) = let get_info (x : t) =
match x with Struct x -> Ast.StructName.get_info x | Enum x -> Ast.EnumName.get_info x match x with
| Struct x -> Ast.StructName.get_info x
| Enum x -> Ast.EnumName.get_info x
end end
module TVertexSet = Set.Make (TVertex) module TVertexSet = Set.Make (TVertex)
(** On the edges, the label is the expression responsible for the use of the function *) (** On the edges, the label is the expression responsible for the use of the
function *)
module TEdge = struct module TEdge = struct
type t = Pos.t type t = Pos.t
let compare = compare let compare = compare
let default = Pos.no_pos let default = Pos.no_pos
end end
module TDependencies = Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (TVertex) (TEdge) module TDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (TVertex) (TEdge)
module TTopologicalTraversal = Graph.Topological.Make (TDependencies) module TTopologicalTraversal = Graph.Topological.Make (TDependencies)
module TSCC = Graph.Components.Make (TDependencies) module TSCC = Graph.Components.Make (TDependencies)
@ -146,11 +170,14 @@ let rec get_structs_or_enums_in_type (t : Ast.typ Pos.marked) : TVertexSet.t =
| Ast.TStruct s -> TVertexSet.singleton (TVertex.Struct s) | Ast.TStruct s -> TVertexSet.singleton (TVertex.Struct s)
| Ast.TEnum e -> TVertexSet.singleton (TVertex.Enum e) | Ast.TEnum e -> TVertexSet.singleton (TVertex.Enum e)
| Ast.TArrow (t1, t2) -> | Ast.TArrow (t1, t2) ->
TVertexSet.union (get_structs_or_enums_in_type t1) (get_structs_or_enums_in_type t2) TVertexSet.union
(get_structs_or_enums_in_type t1)
(get_structs_or_enums_in_type t2)
| Ast.TLit _ | Ast.TAny -> TVertexSet.empty | Ast.TLit _ | Ast.TAny -> TVertexSet.empty
| Ast.TArray t1 -> get_structs_or_enums_in_type (Pos.same_pos_as t1 t) | Ast.TArray t1 -> get_structs_or_enums_in_type (Pos.same_pos_as t1 t)
let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TDependencies.t = let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) :
TDependencies.t =
let g = TDependencies.empty in let g = TDependencies.empty in
let g = let g =
Ast.StructMap.fold Ast.StructMap.fold
@ -164,11 +191,13 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TDepend
(fun used g -> (fun used g ->
if TVertex.equal used def then if TVertex.equal used def then
Errors.raise_spanned_error (Pos.get_position typ) Errors.raise_spanned_error (Pos.get_position typ)
"The type %a is defined using itself, which is forbidden since Catala does not \ "The type %a is defined using itself, which is forbidden \
provide recursive types" since Catala does not provide recursive types"
TVertex.format_t used TVertex.format_t used
else else
let edge = TDependencies.E.create used (Pos.get_position typ) def in let edge =
TDependencies.E.create used (Pos.get_position typ) def
in
TDependencies.add_edge_e g edge) TDependencies.add_edge_e g edge)
used g) used g)
g fields) g fields)
@ -186,11 +215,13 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TDepend
(fun used g -> (fun used g ->
if TVertex.equal used def then if TVertex.equal used def then
Errors.raise_spanned_error (Pos.get_position typ) Errors.raise_spanned_error (Pos.get_position typ)
"The type %a is defined using itself, which is forbidden since Catala does not \ "The type %a is defined using itself, which is forbidden \
provide recursive types" since Catala does not provide recursive types"
TVertex.format_t used TVertex.format_t used
else else
let edge = TDependencies.E.create used (Pos.get_position typ) def in let edge =
TDependencies.E.create used (Pos.get_position typ) def
in
TDependencies.add_edge_e g edge) TDependencies.add_edge_e g edge)
used g) used g)
g cases) g cases)
@ -198,9 +229,11 @@ let build_type_graph (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TDepend
in in
g g
let check_type_cycles (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TVertex.t list = let check_type_cycles (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) :
TVertex.t list =
let g = build_type_graph structs enums in let g = build_type_graph structs enums in
(* if there is a cycle, there will be an strongly connected component of cardinality > 1 *) (* if there is a cycle, there will be an strongly connected component of
cardinality > 1 *)
let sccs = TSCC.scc_list g in let sccs = TSCC.scc_list g in
(if List.length sccs < TDependencies.nb_vertex g then (if List.length sccs < TDependencies.nb_vertex g then
let scc = List.find (fun scc -> List.length scc > 1) sccs in let scc = List.find (fun scc -> List.length scc > 1) sccs in
@ -208,16 +241,24 @@ let check_type_cycles (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : TVerte
List.flatten List.flatten
(List.map (List.map
(fun v -> (fun v ->
let var_str, var_info = (Format.asprintf "%a" TVertex.format_t v, TVertex.get_info v) in let var_str, var_info =
(Format.asprintf "%a" TVertex.format_t v, TVertex.get_info v)
in
let succs = TDependencies.succ_e g v in let succs = TDependencies.succ_e g v in
let _, edge_pos, succ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
let succ_str = Format.asprintf "%a" TVertex.format_t succ in let succ_str = Format.asprintf "%a" TVertex.format_t succ in
[ [
(Some ("Cycle type " ^ var_str ^ ", declared:"), Pos.get_position var_info); ( Some ("Cycle type " ^ var_str ^ ", declared:"),
( Some ("Used here in the definition of another cycle type " ^ succ_str ^ ":"), Pos.get_position var_info );
( Some
("Used here in the definition of another cycle type "
^ succ_str ^ ":"),
edge_pos ); edge_pos );
]) ])
scc) scc)
in in
Errors.raise_multispanned_error spans "Cyclic dependency detected between types!"); Errors.raise_multispanned_error spans
"Cyclic dependency detected between types!");
List.rev (TTopologicalTraversal.fold (fun v acc -> v :: acc) g []) List.rev (TTopologicalTraversal.fold (fun v acc -> v :: acc) g [])

View File

@ -1,31 +1,33 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Graph representation of the dependencies between scopes in the Catala program. Vertices are (** Graph representation of the dependencies between scopes in the Catala
functions, x -> y if x is used in the definition of y. *) program. Vertices are functions, x -> y if x is used in the definition of y. *)
open Utils open Utils
(** {1 Scope dependencies} *) (** {1 Scope dependencies} *)
(** On the edges, the label is the expression responsible for the use of the function *) (** On the edges, the label is the expression responsible for the use of the
module SDependencies : Graph.Sig.P with type V.t = Ast.ScopeName.t and type E.label = Pos.t function *)
module SDependencies :
Graph.Sig.P with type V.t = Ast.ScopeName.t and type E.label = Pos.t
val build_program_dep_graph : Ast.program -> SDependencies.t val build_program_dep_graph : Ast.program -> SDependencies.t
val check_for_cycle_in_scope : SDependencies.t -> unit val check_for_cycle_in_scope : SDependencies.t -> unit
val get_scope_ordering : SDependencies.t -> Ast.ScopeName.t list val get_scope_ordering : SDependencies.t -> Ast.ScopeName.t list
(** {1 Type dependencies} *) (** {1 Type dependencies} *)
@ -34,7 +36,6 @@ module TVertex : sig
type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t type t = Struct of Ast.StructName.t | Enum of Ast.EnumName.t
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val get_info : t -> Ast.StructName.info val get_info : t -> Ast.StructName.info
include Graph.Sig.COMPARABLE with type t := t include Graph.Sig.COMPARABLE with type t := t
@ -42,11 +43,11 @@ end
module TVertexSet : Set.S with type elt = TVertex.t module TVertexSet : Set.S with type elt = TVertex.t
(** On the edges, the label is the expression responsible for the use of the function *) (** On the edges, the label is the expression responsible for the use of the
module TDependencies : Graph.Sig.P with type V.t = TVertex.t and type E.label = Pos.t function *)
module TDependencies :
Graph.Sig.P with type V.t = TVertex.t and type E.label = Pos.t
val get_structs_or_enums_in_type : Ast.typ Pos.marked -> TVertexSet.t val get_structs_or_enums_in_type : Ast.typ Pos.marked -> TVertexSet.t
val build_type_graph : Ast.struct_ctx -> Ast.enum_ctx -> TDependencies.t val build_type_graph : Ast.struct_ctx -> Ast.enum_ctx -> TDependencies.t
val check_type_cycles : Ast.struct_ctx -> Ast.enum_ctx -> TVertex.t list val check_type_cycles : Ast.struct_ctx -> Ast.enum_ctx -> TVertex.t list

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -25,8 +27,8 @@ let format_location (fmt : Format.formatter) (l : location) : unit =
match l with match l with
| ScopeVar v -> Format.fprintf fmt "%a" ScopeVar.format_t (Pos.unmark v) | ScopeVar v -> Format.fprintf fmt "%a" ScopeVar.format_t (Pos.unmark v)
| SubScopeVar (_, subindex, subvar) -> | SubScopeVar (_, subindex, subvar) ->
Format.fprintf fmt "%a.%a" SubScopeName.format_t (Pos.unmark subindex) ScopeVar.format_t Format.fprintf fmt "%a.%a" SubScopeName.format_t (Pos.unmark subindex)
(Pos.unmark subvar) ScopeVar.format_t (Pos.unmark subvar)
let typ_needs_parens (e : typ Pos.marked) : bool = let typ_needs_parens (e : typ Pos.marked) : bool =
match Pos.unmark e with TArrow _ -> true | _ -> false match Pos.unmark e with TArrow _ -> true | _ -> false
@ -34,8 +36,8 @@ let typ_needs_parens (e : typ Pos.marked) : bool =
let rec format_typ (fmt : Format.formatter) (typ : typ Pos.marked) : unit = let rec format_typ (fmt : Format.formatter) (typ : typ Pos.marked) : unit =
let format_typ_with_parens (fmt : Format.formatter) (t : typ Pos.marked) = let format_typ_with_parens (fmt : Format.formatter) (t : typ Pos.marked) =
if typ_needs_parens t then if typ_needs_parens t then
Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "(" format_typ t Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "(" format_typ
Dcalc.Print.format_punctuation ")" t Dcalc.Print.format_punctuation ")"
else Format.fprintf fmt "%a" format_typ t else Format.fprintf fmt "%a" format_typ t
in in
match Pos.unmark typ with match Pos.unmark typ with
@ -58,84 +60,108 @@ let rec format_expr (fmt : Format.formatter) (e : expr Pos.marked) : unit =
match Pos.unmark e with match Pos.unmark e with
| ELocation l -> Format.fprintf fmt "%a" format_location l | ELocation l -> Format.fprintf fmt "%a" format_location l
| EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v) | EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v)
| ELit l -> Format.fprintf fmt "%a" Dcalc.Print.format_lit (Pos.same_pos_as l e) | ELit l ->
Format.fprintf fmt "%a" Dcalc.Print.format_lit (Pos.same_pos_as l e)
| EStruct (name, fields) -> | EStruct (name, fields) ->
Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" Ast.StructName.format_t name Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" Ast.StructName.format_t
Dcalc.Print.format_punctuation "{" name Dcalc.Print.format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";")
(fun fmt (field_name, field_expr) -> (fun fmt (field_name, field_expr) ->
Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation "\"" Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation
Ast.StructFieldName.format_t field_name Dcalc.Print.format_punctuation "\"" "\"" Ast.StructFieldName.format_t field_name
Dcalc.Print.format_punctuation "\""
Dcalc.Print.format_punctuation "=" format_expr field_expr)) Dcalc.Print.format_punctuation "=" format_expr field_expr))
(Ast.StructFieldMap.bindings fields) (Ast.StructFieldMap.bindings fields)
Dcalc.Print.format_punctuation "}" Dcalc.Print.format_punctuation "}"
| EStructAccess (e1, field, _) -> | EStructAccess (e1, field, _) ->
Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Dcalc.Print.format_punctuation "." Format.fprintf fmt "%a%a%a%a%a" format_expr e1
Dcalc.Print.format_punctuation "\"" Ast.StructFieldName.format_t field Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\""
Dcalc.Print.format_punctuation "\"" Ast.StructFieldName.format_t field Dcalc.Print.format_punctuation "\""
| EEnumInj (e1, cons, _) -> | EEnumInj (e1, cons, _) ->
Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr e1 Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr
e1
| EMatch (e1, _, cases) -> | EMatch (e1, _, cases) ->
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" Dcalc.Print.format_keyword "match" Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]"
format_expr e1 Dcalc.Print.format_keyword "with" Dcalc.Print.format_keyword "match" format_expr e1
Dcalc.Print.format_keyword "with"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (cons_name, case_expr) -> (fun fmt (cons_name, case_expr) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" Dcalc.Print.format_punctuation "|" Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]"
Dcalc.Print.format_enum_constructor cons_name Dcalc.Print.format_punctuation "" Dcalc.Print.format_punctuation "|"
format_expr case_expr)) Dcalc.Print.format_enum_constructor cons_name
Dcalc.Print.format_punctuation "" format_expr case_expr))
(Ast.EnumConstructorMap.bindings cases) (Ast.EnumConstructorMap.bindings cases)
| EApp ((EAbs ((binder, _), taus), _), args) -> | EApp ((EAbs ((binder, _), taus), _), args) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
let xs_tau_arg = List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args in let xs_tau_arg =
List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args
in
Format.fprintf fmt "@[%a%a@]" Format.fprintf fmt "@[%a%a@]"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
(fun fmt (x, tau, arg) -> (fun fmt (x, tau, arg) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@\n@]" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@\n@]"
Dcalc.Print.format_keyword "let" format_var x Dcalc.Print.format_punctuation ":" Dcalc.Print.format_keyword "let" format_var x
format_typ tau Dcalc.Print.format_punctuation "=" format_expr arg Dcalc.Print.format_punctuation ":" format_typ tau
Dcalc.Print.format_punctuation "=" format_expr arg
Dcalc.Print.format_keyword "in")) Dcalc.Print.format_keyword "in"))
xs_tau_arg format_expr body xs_tau_arg format_expr body
| EAbs ((binder, _), taus) -> | EAbs ((binder, _), taus) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" Dcalc.Print.format_punctuation "λ" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]"
Dcalc.Print.format_punctuation "λ"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt " ") ~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
(fun fmt (x, tau) -> (fun fmt (x, tau) ->
Format.fprintf fmt "@[%a%a%a@ %a%a@]" Dcalc.Print.format_punctuation "(" format_var x Format.fprintf fmt "@[%a%a%a@ %a%a@]"
Dcalc.Print.format_punctuation ":" format_typ tau Dcalc.Print.format_punctuation ")")) Dcalc.Print.format_punctuation "(" format_var x
Dcalc.Print.format_punctuation ":" format_typ tau
Dcalc.Print.format_punctuation ")"))
xs_tau Dcalc.Print.format_punctuation "" format_expr body xs_tau Dcalc.Print.format_punctuation "" format_expr body
| EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) ->
Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 Dcalc.Print.format_binop Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1
(op, Pos.no_pos) format_with_parens arg2 Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp ((EOp (Unop op), _), [ arg1 ]) -> | EApp ((EOp (Unop op), _), [ arg1 ]) ->
Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos) format_with_parens Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos)
arg1 format_with_parens arg1
| EApp (f, args) -> | EApp (f, args) ->
Format.fprintf fmt "@[%a@ %a@]" format_expr f Format.fprintf fmt "@[%a@ %a@]" format_expr f
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") format_with_parens) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args args
| EIfThenElse (e1, e2, e3) -> | EIfThenElse (e1, e2, e3) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" Dcalc.Print.format_keyword "if" Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]"
format_expr e1 Dcalc.Print.format_keyword "then" format_expr e2 Dcalc.Print.format_keyword Dcalc.Print.format_keyword "if" format_expr e1
"else" format_expr e3 Dcalc.Print.format_keyword "then" format_expr e2
| EOp (Ternop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) Dcalc.Print.format_keyword "else" format_expr e3
| EOp (Binop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) | EOp (Ternop op) ->
| EOp (Unop op) -> Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos)
| EOp (Binop op) ->
Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos)
| EOp (Unop op) ->
Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos)
| EDefault (excepts, just, cons) -> | EDefault (excepts, just, cons) ->
if List.length excepts = 0 then if List.length excepts = 0 then
Format.fprintf fmt "@[%a%a %a@ %a%a@]" Dcalc.Print.format_punctuation "" format_expr just Format.fprintf fmt "@[%a%a %a@ %a%a@]" Dcalc.Print.format_punctuation
Dcalc.Print.format_punctuation "" format_expr cons Dcalc.Print.format_punctuation "" "" format_expr just Dcalc.Print.format_punctuation "" format_expr
cons Dcalc.Print.format_punctuation ""
else else
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a %a@ %a%a@]" Dcalc.Print.format_punctuation "" Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a %a@ %a%a@]"
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") format_expr) Dcalc.Print.format_punctuation ""
excepts Dcalc.Print.format_punctuation "|" format_expr just Dcalc.Print.format_punctuation (Format.pp_print_list
"" format_expr cons Dcalc.Print.format_punctuation "" ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
| ErrorOnEmpty e' -> Format.fprintf fmt "error_empty@ %a" format_with_parens e' format_expr)
excepts Dcalc.Print.format_punctuation "|" format_expr just
Dcalc.Print.format_punctuation "" format_expr cons
Dcalc.Print.format_punctuation ""
| ErrorOnEmpty e' ->
Format.fprintf fmt "error_empty@ %a" format_with_parens e'
| EArray es -> | EArray es ->
Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "[" Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "["
(Format.pp_print_list (Format.pp_print_list
@ -143,10 +169,13 @@ let rec format_expr (fmt : Format.formatter) (e : expr Pos.marked) : unit =
(fun fmt e -> Format.fprintf fmt "@[%a@]" format_expr e)) (fun fmt e -> Format.fprintf fmt "@[%a@]" format_expr e))
es Dcalc.Print.format_punctuation "]" es Dcalc.Print.format_punctuation "]"
let format_struct (fmt : Format.formatter) let format_struct
((name, fields) : StructName.t * (StructFieldName.t * typ Pos.marked) list) : unit = (fmt : Format.formatter)
Format.fprintf fmt "%a %a %a %a@\n@[<hov 2> %a@]@\n%a" Dcalc.Print.format_keyword "type" ((name, fields) : StructName.t * (StructFieldName.t * typ Pos.marked) list)
StructName.format_t name Dcalc.Print.format_punctuation "=" Dcalc.Print.format_punctuation "{" : unit =
Format.fprintf fmt "%a %a %a %a@\n@[<hov 2> %a@]@\n%a"
Dcalc.Print.format_keyword "type" StructName.format_t name
Dcalc.Print.format_punctuation "=" Dcalc.Print.format_punctuation "{"
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (field_name, typ) -> (fun fmt (field_name, typ) ->
@ -154,26 +183,32 @@ let format_struct (fmt : Format.formatter)
Dcalc.Print.format_punctuation ":" format_typ typ)) Dcalc.Print.format_punctuation ":" format_typ typ))
fields Dcalc.Print.format_punctuation "}" fields Dcalc.Print.format_punctuation "}"
let format_enum (fmt : Format.formatter) let format_enum
((name, cases) : EnumName.t * (EnumConstructor.t * typ Pos.marked) list) : unit = (fmt : Format.formatter)
Format.fprintf fmt "%a %a %a @\n@[<hov 2> %a@]" Dcalc.Print.format_keyword "type" ((name, cases) : EnumName.t * (EnumConstructor.t * typ Pos.marked) list) :
EnumName.format_t name Dcalc.Print.format_punctuation "=" unit =
Format.fprintf fmt "%a %a %a @\n@[<hov 2> %a@]" Dcalc.Print.format_keyword
"type" EnumName.format_t name Dcalc.Print.format_punctuation "="
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (field_name, typ) -> (fun fmt (field_name, typ) ->
Format.fprintf fmt "%a %a%a %a" Dcalc.Print.format_punctuation "|" EnumConstructor.format_t Format.fprintf fmt "%a %a%a %a" Dcalc.Print.format_punctuation "|"
field_name Dcalc.Print.format_punctuation ":" format_typ typ)) EnumConstructor.format_t field_name Dcalc.Print.format_punctuation
":" format_typ typ))
cases cases
let format_scope (fmt : Format.formatter) ((name, decl) : ScopeName.t * scope_decl) : unit = let format_scope
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]" Dcalc.Print.format_keyword (fmt : Format.formatter) ((name, decl) : ScopeName.t * scope_decl) : unit =
"let" Dcalc.Print.format_keyword "scope" ScopeName.format_t name Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2> %a@]"
Dcalc.Print.format_keyword "let" Dcalc.Print.format_keyword "scope"
ScopeName.format_t name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
(fun fmt (scope_var, (typ, vis)) -> (fun fmt (scope_var, (typ, vis)) ->
Format.fprintf fmt "%a%a%a %a%a%a%a%a" Dcalc.Print.format_punctuation "(" ScopeVar.format_t Format.fprintf fmt "%a%a%a %a%a%a%a%a" Dcalc.Print.format_punctuation
scope_var Dcalc.Print.format_punctuation ":" format_typ typ "(" ScopeVar.format_t scope_var Dcalc.Print.format_punctuation ":"
Dcalc.Print.format_punctuation "|" Dcalc.Print.format_keyword format_typ typ Dcalc.Print.format_punctuation "|"
Dcalc.Print.format_keyword
(match Pos.unmark vis.io_input with (match Pos.unmark vis.io_input with
| NoInput -> "internal" | NoInput -> "internal"
| OnlyInput -> "input" | OnlyInput -> "input"
@ -186,19 +221,23 @@ let format_scope (fmt : Format.formatter) ((name, decl) : ScopeName.t * scope_de
(ScopeVarMap.bindings decl.scope_sig) (ScopeVarMap.bindings decl.scope_sig)
Dcalc.Print.format_punctuation "=" Dcalc.Print.format_punctuation "="
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") ~pp_sep:(fun fmt () ->
Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";")
(fun fmt rule -> (fun fmt rule ->
match rule with match rule with
| Definition (loc, typ, _, e) -> | Definition (loc, typ, _, e) ->
Format.fprintf fmt "@[<hov 2>%a %a %a %a %a@ %a@]" Dcalc.Print.format_keyword "let" Format.fprintf fmt "@[<hov 2>%a %a %a %a %a@ %a@]"
format_location (Pos.unmark loc) Dcalc.Print.format_punctuation ":" format_typ typ Dcalc.Print.format_keyword "let" format_location (Pos.unmark loc)
Dcalc.Print.format_punctuation ":" format_typ typ
Dcalc.Print.format_punctuation "=" Dcalc.Print.format_punctuation "="
(fun fmt e -> (fun fmt e ->
match Pos.unmark loc with match Pos.unmark loc with
| SubScopeVar _ -> format_expr fmt e | SubScopeVar _ -> format_expr fmt e
| ScopeVar v -> ( | ScopeVar v -> (
match match
Pos.unmark (snd (ScopeVarMap.find (Pos.unmark v) decl.scope_sig)).io_input Pos.unmark
(snd (ScopeVarMap.find (Pos.unmark v) decl.scope_sig))
.io_input
with with
| Reentrant -> | Reentrant ->
Format.fprintf fmt "%a@ %a" Dcalc.Print.format_operator Format.fprintf fmt "%a@ %a" Dcalc.Print.format_operator
@ -206,25 +245,34 @@ let format_scope (fmt : Format.formatter) ((name, decl) : ScopeName.t * scope_de
| _ -> Format.fprintf fmt "%a" format_expr e)) | _ -> Format.fprintf fmt "%a" format_expr e))
e e
| Assertion e -> | Assertion e ->
Format.fprintf fmt "%a %a" Dcalc.Print.format_keyword "assert" format_expr e Format.fprintf fmt "%a %a" Dcalc.Print.format_keyword "assert"
format_expr e
| Call (scope_name, subscope_name) -> | Call (scope_name, subscope_name) ->
Format.fprintf fmt "%a %a%a%a%a" Dcalc.Print.format_keyword "call" ScopeName.format_t Format.fprintf fmt "%a %a%a%a%a" Dcalc.Print.format_keyword "call"
scope_name Dcalc.Print.format_punctuation "[" SubScopeName.format_t subscope_name ScopeName.format_t scope_name Dcalc.Print.format_punctuation "["
SubScopeName.format_t subscope_name
Dcalc.Print.format_punctuation "]")) Dcalc.Print.format_punctuation "]"))
decl.scope_decl_rules decl.scope_decl_rules
let format_program (fmt : Format.formatter) (p : program) : unit = let format_program (fmt : Format.formatter) (p : program) : unit =
Format.fprintf fmt "%a%a%a%a%a" Format.fprintf fmt "%a%a%a%a%a"
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") format_struct) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
format_struct)
(StructMap.bindings p.program_structs) (StructMap.bindings p.program_structs)
(fun fmt () -> (fun fmt () ->
if StructMap.is_empty p.program_structs then Format.fprintf fmt "" if StructMap.is_empty p.program_structs then Format.fprintf fmt ""
else Format.fprintf fmt "\n\n") else Format.fprintf fmt "\n\n")
() ()
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") format_enum) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
format_enum)
(EnumMap.bindings p.program_enums) (EnumMap.bindings p.program_enums)
(fun fmt () -> (fun fmt () ->
if EnumMap.is_empty p.program_enums then Format.fprintf fmt "" else Format.fprintf fmt "\n\n") if EnumMap.is_empty p.program_enums then Format.fprintf fmt ""
else Format.fprintf fmt "\n\n")
() ()
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") format_scope) (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n")
format_scope)
(ScopeMap.bindings p.program_scopes) (ScopeMap.bindings p.program_scopes)

View File

@ -1,27 +1,24 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
val format_var : Format.formatter -> Ast.Var.t -> unit val format_var : Format.formatter -> Ast.Var.t -> unit
val format_location : Format.formatter -> Ast.location -> unit val format_location : Format.formatter -> Ast.location -> unit
val format_typ : Format.formatter -> Ast.typ Pos.marked -> unit val format_typ : Format.formatter -> Ast.typ Pos.marked -> unit
val format_expr : Format.formatter -> Ast.expr Pos.marked -> unit val format_expr : Format.formatter -> Ast.expr Pos.marked -> unit
val format_scope : Format.formatter -> Ast.ScopeName.t * Ast.scope_decl -> unit val format_scope : Format.formatter -> Ast.ScopeName.t * Ast.scope_decl -> unit
val format_program : Format.formatter -> Ast.program -> unit val format_program : Format.formatter -> Ast.program -> unit

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -37,11 +39,16 @@ type ctx = {
scope_name : Ast.ScopeName.t; scope_name : Ast.ScopeName.t;
scopes_parameters : scope_sigs_ctx; scopes_parameters : scope_sigs_ctx;
scope_vars : (Dcalc.Ast.Var.t * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t; scope_vars : (Dcalc.Ast.Var.t * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t;
subscope_vars : (Dcalc.Ast.Var.t * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t Ast.SubScopeMap.t; subscope_vars :
(Dcalc.Ast.Var.t * Dcalc.Ast.typ * Ast.io) Ast.ScopeVarMap.t
Ast.SubScopeMap.t;
local_vars : Dcalc.Ast.Var.t Ast.VarMap.t; local_vars : Dcalc.Ast.Var.t Ast.VarMap.t;
} }
let empty_ctx (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) (scopes_ctx : scope_sigs_ctx) let empty_ctx
(struct_ctx : Ast.struct_ctx)
(enum_ctx : Ast.enum_ctx)
(scopes_ctx : scope_sigs_ctx)
(scope_name : Ast.ScopeName.t) = (scope_name : Ast.ScopeName.t) =
{ {
structs = struct_ctx; structs = struct_ctx;
@ -53,23 +60,30 @@ let empty_ctx (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) (scopes_ct
local_vars = Ast.VarMap.empty; local_vars = Ast.VarMap.empty;
} }
let rec translate_typ (ctx : ctx) (t : Ast.typ Pos.marked) : Dcalc.Ast.typ Pos.marked = let rec translate_typ (ctx : ctx) (t : Ast.typ Pos.marked) :
Dcalc.Ast.typ Pos.marked =
Pos.same_pos_as Pos.same_pos_as
(match Pos.unmark t with (match Pos.unmark t with
| Ast.TLit l -> Dcalc.Ast.TLit l | Ast.TLit l -> Dcalc.Ast.TLit l
| Ast.TArrow (t1, t2) -> Dcalc.Ast.TArrow (translate_typ ctx t1, translate_typ ctx t2) | Ast.TArrow (t1, t2) ->
Dcalc.Ast.TArrow (translate_typ ctx t1, translate_typ ctx t2)
| Ast.TStruct s_uid -> | Ast.TStruct s_uid ->
let s_fields = Ast.StructMap.find s_uid ctx.structs in let s_fields = Ast.StructMap.find s_uid ctx.structs in
Dcalc.Ast.TTuple (List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid) Dcalc.Ast.TTuple
(List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid)
| Ast.TEnum e_uid -> | Ast.TEnum e_uid ->
let e_cases = Ast.EnumMap.find e_uid ctx.enums in let e_cases = Ast.EnumMap.find e_uid ctx.enums in
Dcalc.Ast.TEnum (List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid) Dcalc.Ast.TEnum
| Ast.TArray t1 -> Dcalc.Ast.TArray (translate_typ ctx (Pos.same_pos_as t1 t)) (List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid)
| Ast.TArray t1 ->
Dcalc.Ast.TArray (translate_typ ctx (Pos.same_pos_as t1 t))
| Ast.TAny -> Dcalc.Ast.TAny) | Ast.TAny -> Dcalc.Ast.TAny)
t t
let merge_defaults (caller : Dcalc.Ast.expr Pos.marked Bindlib.box) let merge_defaults
(callee : Dcalc.Ast.expr Pos.marked Bindlib.box) : Dcalc.Ast.expr Pos.marked Bindlib.box = (caller : Dcalc.Ast.expr Pos.marked Bindlib.box)
(callee : Dcalc.Ast.expr Pos.marked Bindlib.box) :
Dcalc.Ast.expr Pos.marked Bindlib.box =
let caller = let caller =
Dcalc.Ast.make_app caller Dcalc.Ast.make_app caller
[ Bindlib.box (Dcalc.Ast.ELit Dcalc.Ast.LUnit, Pos.no_pos) ] [ Bindlib.box (Dcalc.Ast.ELit Dcalc.Ast.LUnit, Pos.no_pos) ]
@ -79,23 +93,30 @@ let merge_defaults (caller : Dcalc.Ast.expr Pos.marked Bindlib.box)
Bindlib.box_apply2 Bindlib.box_apply2
(fun caller callee -> (fun caller callee ->
( Dcalc.Ast.EDefault ( Dcalc.Ast.EDefault
([ caller ], (Dcalc.Ast.ELit (Dcalc.Ast.LBool true), Pos.no_pos), callee), ( [ caller ],
(Dcalc.Ast.ELit (Dcalc.Ast.LBool true), Pos.no_pos),
callee ),
Pos.no_pos )) Pos.no_pos ))
caller callee caller callee
in in
body body
let tag_with_log_entry (e : Dcalc.Ast.expr Pos.marked Bindlib.box) (l : Dcalc.Ast.log_entry) let tag_with_log_entry
(markings : Utils.Uid.MarkedString.info list) : Dcalc.Ast.expr Pos.marked Bindlib.box = (e : Dcalc.Ast.expr Pos.marked Bindlib.box)
(l : Dcalc.Ast.log_entry)
(markings : Utils.Uid.MarkedString.info list) :
Dcalc.Ast.expr Pos.marked Bindlib.box =
Bindlib.box_apply Bindlib.box_apply
(fun e -> (fun e ->
( Dcalc.Ast.EApp ( Dcalc.Ast.EApp
((Dcalc.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings))), Pos.get_position e), [ e ]), ( ( Dcalc.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings))),
Pos.get_position e ),
[ e ] ),
Pos.get_position e )) Pos.get_position e ))
e e
let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Pos.marked Bindlib.box let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) :
= Dcalc.Ast.expr Pos.marked Bindlib.box =
Bindlib.box_apply Bindlib.box_apply
(fun (x : Dcalc.Ast.expr) -> Pos.same_pos_as x e) (fun (x : Dcalc.Ast.expr) -> Pos.same_pos_as x e)
(match Pos.unmark e with (match Pos.unmark e with
@ -108,13 +129,14 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
(fun (field_name, _) (d_fields, e_fields) -> (fun (field_name, _) (d_fields, e_fields) ->
let field_e = Ast.StructFieldMap.find field_name e_fields in let field_e = Ast.StructFieldMap.find field_name e_fields in
let field_d = translate_expr ctx field_e in let field_d = translate_expr ctx field_e in
(field_d :: d_fields, Ast.StructFieldMap.remove field_name e_fields)) ( field_d :: d_fields,
Ast.StructFieldMap.remove field_name e_fields ))
struct_sig ([], e_fields) struct_sig ([], e_fields)
in in
if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"The fields \"%a\" do not belong to the structure %a" Ast.StructName.format_t "The fields \"%a\" do not belong to the structure %a"
struct_name Ast.StructName.format_t struct_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (field_name, _) -> (fun fmt (field_name, _) ->
@ -127,11 +149,14 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
| EStructAccess (e1, field_name, struct_name) -> | EStructAccess (e1, field_name, struct_name) ->
let struct_sig = Ast.StructMap.find struct_name ctx.structs in let struct_sig = Ast.StructMap.find struct_name ctx.structs in
let _, field_index = let _, field_index =
try List.assoc field_name (List.mapi (fun i (x, y) -> (x, (y, i))) struct_sig) try
List.assoc field_name
(List.mapi (fun i (x, y) -> (x, (y, i))) struct_sig)
with Not_found -> with Not_found ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"The field \"%a\" does not belong to the structure %a" Ast.StructFieldName.format_t "The field \"%a\" does not belong to the structure %a"
field_name Ast.StructName.format_t struct_name Ast.StructFieldName.format_t field_name Ast.StructName.format_t
struct_name
in in
let e1 = translate_expr ctx e1 in let e1 = translate_expr ctx e1 in
Bindlib.box_apply Bindlib.box_apply
@ -145,11 +170,14 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
| EEnumInj (e1, constructor, enum_name) -> | EEnumInj (e1, constructor, enum_name) ->
let enum_sig = Ast.EnumMap.find enum_name ctx.enums in let enum_sig = Ast.EnumMap.find enum_name ctx.enums in
let _, constructor_index = let _, constructor_index =
try List.assoc constructor (List.mapi (fun i (x, y) -> (x, (y, i))) enum_sig) try
List.assoc constructor
(List.mapi (fun i (x, y) -> (x, (y, i))) enum_sig)
with Not_found -> with Not_found ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"The constructor \"%a\" does not belong to the enum %a" Ast.EnumConstructor.format_t "The constructor \"%a\" does not belong to the enum %a"
constructor Ast.EnumName.format_t enum_name Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t
enum_name
in in
let e1 = translate_expr ctx e1 in let e1 = translate_expr ctx e1 in
Bindlib.box_apply Bindlib.box_apply
@ -169,17 +197,20 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
try Ast.EnumConstructorMap.find constructor e_cases try Ast.EnumConstructorMap.find constructor e_cases
with Not_found -> with Not_found ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"The constructor %a of enum %a is missing from this pattern matching" "The constructor %a of enum %a is missing from this \
Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t enum_name pattern matching"
Ast.EnumConstructor.format_t constructor
Ast.EnumName.format_t enum_name
in in
let case_d = translate_expr ctx case_e in let case_d = translate_expr ctx case_e in
(case_d :: d_cases, Ast.EnumConstructorMap.remove constructor e_cases)) ( case_d :: d_cases,
Ast.EnumConstructorMap.remove constructor e_cases ))
enum_sig ([], cases) enum_sig ([], cases)
in in
if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Patter matching is incomplete for enum %a: missing cases %a" Ast.EnumName.format_t "Patter matching is incomplete for enum %a: missing cases %a"
enum_name Ast.EnumName.format_t enum_name
(Format.pp_print_list (Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
(fun fmt (case_name, _) -> (fun fmt (case_name, _) ->
@ -191,18 +222,20 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
(fun d_fields e1 -> Dcalc.Ast.EMatch (e1, d_fields, enum_name)) (fun d_fields e1 -> Dcalc.Ast.EMatch (e1, d_fields, enum_name))
(Bindlib.box_list d_cases) e1 (Bindlib.box_list d_cases) e1
| EApp (e1, args) -> | EApp (e1, args) ->
(* We insert various log calls to record arguments and outputs of user-defined functions (* We insert various log calls to record arguments and outputs of
belonging to scopes *) user-defined functions belonging to scopes *)
let e1_func = translate_expr ctx e1 in let e1_func = translate_expr ctx e1 in
let markings l = let markings l =
match l with match l with
| Ast.ScopeVar (v, _) -> | Ast.ScopeVar (v, _) ->
[ Ast.ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v ] [ Ast.ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v ]
| Ast.SubScopeVar (s, _, (v, _)) -> [ Ast.ScopeName.get_info s; Ast.ScopeVar.get_info v ] | Ast.SubScopeVar (s, _, (v, _)) ->
[ Ast.ScopeName.get_info s; Ast.ScopeVar.get_info v ]
in in
let e1_func = let e1_func =
match Pos.unmark e1 with match Pos.unmark e1 with
| ELocation l -> tag_with_log_entry e1_func Dcalc.Ast.BeginCall (markings l) | ELocation l ->
tag_with_log_entry e1_func Dcalc.Ast.BeginCall (markings l)
| _ -> e1_func | _ -> e1_func
in in
let new_args = List.map (translate_expr ctx) args in let new_args = List.map (translate_expr ctx) args in
@ -218,7 +251,8 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
let new_e = let new_e =
Bindlib.box_apply2 Bindlib.box_apply2
(fun e' u -> (Dcalc.Ast.EApp (e', u), Pos.get_position e)) (fun e' u -> (Dcalc.Ast.EApp (e', u), Pos.get_position e))
e1_func (Bindlib.box_list new_args) e1_func
(Bindlib.box_list new_args)
in in
let new_e = let new_e =
match Pos.unmark e1 with match Pos.unmark e1 with
@ -232,7 +266,11 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
Bindlib.box_apply Pos.unmark new_e Bindlib.box_apply Pos.unmark new_e
| EAbs ((binder, pos_binder), typ) -> | EAbs ((binder, pos_binder), typ) ->
let xs, body = Bindlib.unmbind binder in let xs, body = Bindlib.unmbind binder in
let new_xs = Array.map (fun x -> Dcalc.Ast.Var.make (Bindlib.name_of x, Pos.no_pos)) xs in let new_xs =
Array.map
(fun x -> Dcalc.Ast.Var.make (Bindlib.name_of x, Pos.no_pos))
xs
in
let both_xs = Array.map2 (fun x new_x -> (x, new_x)) xs new_xs in let both_xs = Array.map2 (fun x new_x -> (x, new_x)) xs new_xs in
let body = let body =
translate_expr translate_expr
@ -240,17 +278,22 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
ctx with ctx with
local_vars = local_vars =
Array.fold_left Array.fold_left
(fun local_vars (x, new_x) -> Ast.VarMap.add x new_x local_vars) (fun local_vars (x, new_x) ->
Ast.VarMap.add x new_x local_vars)
ctx.local_vars both_xs; ctx.local_vars both_xs;
} }
body body
in in
let binder = Bindlib.bind_mvar new_xs body in let binder = Bindlib.bind_mvar new_xs body in
Bindlib.box_apply Bindlib.box_apply
(fun b -> Dcalc.Ast.EAbs ((b, pos_binder), List.map (translate_typ ctx) typ)) (fun b ->
Dcalc.Ast.EAbs ((b, pos_binder), List.map (translate_typ ctx) typ))
binder binder
| EDefault (excepts, just, cons) -> | EDefault (excepts, just, cons) ->
let just = tag_with_log_entry (translate_expr ctx just) Dcalc.Ast.PosRecordIfTrueBool [] in let just =
tag_with_log_entry (translate_expr ctx just)
Dcalc.Ast.PosRecordIfTrueBool []
in
Bindlib.box_apply3 Bindlib.box_apply3
(fun e j c -> Dcalc.Ast.EDefault (e, j, c)) (fun e j c -> Dcalc.Ast.EDefault (e, j, c))
(Bindlib.box_list (List.map (translate_expr ctx) excepts)) (Bindlib.box_list (List.map (translate_expr ctx) excepts))
@ -274,28 +317,36 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
( Some "Incriminated subscope declaration:", ( Some "Incriminated subscope declaration:",
Pos.get_position (Ast.SubScopeName.get_info (Pos.unmark s)) ); Pos.get_position (Ast.SubScopeName.get_info (Pos.unmark s)) );
] ]
"The variable %a.%a cannot be used here, as it is not part subscope %a's results. \ "The variable %a.%a cannot be used here, as it is not part \
Maybe you forgot to qualify it as an output?" subscope %a's results. Maybe you forgot to qualify it as an \
Ast.SubScopeName.format_t (Pos.unmark s) Ast.ScopeVar.format_t (Pos.unmark a) output?"
Ast.SubScopeName.format_t (Pos.unmark s)) Ast.SubScopeName.format_t (Pos.unmark s) Ast.ScopeVar.format_t
(Pos.unmark a) Ast.SubScopeName.format_t (Pos.unmark s))
| EIfThenElse (cond, et, ef) -> | EIfThenElse (cond, et, ef) ->
Bindlib.box_apply3 Bindlib.box_apply3
(fun c t f -> Dcalc.Ast.EIfThenElse (c, t, f)) (fun c t f -> Dcalc.Ast.EIfThenElse (c, t, f))
(translate_expr ctx cond) (translate_expr ctx et) (translate_expr ctx ef) (translate_expr ctx cond) (translate_expr ctx et)
(translate_expr ctx ef)
| EOp op -> Bindlib.box (Dcalc.Ast.EOp op) | EOp op -> Bindlib.box (Dcalc.Ast.EOp op)
| ErrorOnEmpty e' -> | ErrorOnEmpty e' ->
Bindlib.box_apply (fun e' -> Dcalc.Ast.ErrorOnEmpty e') (translate_expr ctx e') Bindlib.box_apply
(fun e' -> Dcalc.Ast.ErrorOnEmpty e')
(translate_expr ctx e')
| EArray es -> | EArray es ->
Bindlib.box_apply Bindlib.box_apply
(fun es -> Dcalc.Ast.EArray es) (fun es -> Dcalc.Ast.EArray es)
(Bindlib.box_list (List.map (translate_expr ctx) es))) (Bindlib.box_list (List.map (translate_expr ctx) es)))
(** The result of a rule translation is a list of assignment, with variables and expressions. We (** The result of a rule translation is a list of assignment, with variables and
also return the new translation context available after the assignment to use in later rule expressions. We also return the new translation context available after the
translations. The list is actually a list of list because we want to group in assignments that assignment to use in later rule translations. The list is actually a list of
are independent of each other to speed up the translation by minimizing Bindlib.bind_mvar *) list because we want to group in assignments that are independent of each
let translate_rule (ctx : ctx) (rule : Ast.rule) other to speed up the translation by minimizing Bindlib.bind_mvar *)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) : Dcalc.Ast.scope_let list * ctx = let translate_rule
(ctx : ctx)
(rule : Ast.rule)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) :
Dcalc.Ast.scope_let list * ctx =
match rule with match rule with
| Definition ((ScopeVar a, var_def_pos), tau, a_io, e) -> | Definition ((ScopeVar a, var_def_pos), tau, a_io, e) ->
let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in
@ -305,11 +356,13 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
let a_expr = Dcalc.Ast.make_var (a_var, var_def_pos) in let a_expr = Dcalc.Ast.make_var (a_var, var_def_pos) in
let merged_expr = let merged_expr =
Bindlib.box_apply Bindlib.box_apply
(fun merged_expr -> (Dcalc.Ast.ErrorOnEmpty merged_expr, Pos.get_position a_name)) (fun merged_expr ->
(Dcalc.Ast.ErrorOnEmpty merged_expr, Pos.get_position a_name))
(match Pos.unmark a_io.io_input with (match Pos.unmark a_io.io_input with
| OnlyInput -> | OnlyInput ->
failwith "should not happen" failwith "should not happen"
(* scopelang should not contain any definitions of input only variables *) (* scopelang should not contain any definitions of input only
variables *)
| Reentrant -> merge_defaults a_expr new_e | Reentrant -> merge_defaults a_expr new_e
| NoInput -> new_e) | NoInput -> new_e)
in in
@ -329,12 +382,19 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
{ {
ctx with ctx with
scope_vars = scope_vars =
Ast.ScopeVarMap.add (Pos.unmark a) (a_var, Pos.unmark tau, a_io) ctx.scope_vars; Ast.ScopeVarMap.add (Pos.unmark a)
(a_var, Pos.unmark tau, a_io)
ctx.scope_vars;
} ) } )
| Definition ((SubScopeVar (_subs_name, subs_index, subs_var), var_def_pos), tau, a_io, e) -> | Definition
( (SubScopeVar (_subs_name, subs_index, subs_var), var_def_pos),
tau,
a_io,
e ) ->
let a_name = let a_name =
Pos.map_under_mark Pos.map_under_mark
(fun str -> str ^ "." ^ Pos.unmark (Ast.ScopeVar.get_info (Pos.unmark subs_var))) (fun str ->
str ^ "." ^ Pos.unmark (Ast.ScopeVar.get_info (Pos.unmark subs_var)))
(Ast.SubScopeName.get_info (Pos.unmark subs_index)) (Ast.SubScopeName.get_info (Pos.unmark subs_index))
in in
let a_var = Dcalc.Ast.Var.make a_name in let a_var = Dcalc.Ast.Var.make a_name in
@ -350,7 +410,8 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
| OnlyInput -> | OnlyInput ->
Bindlib.box_apply Bindlib.box_apply
(fun new_e -> (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position subs_var)) (fun new_e ->
(Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position subs_var))
new_e new_e
| Reentrant -> | Reentrant ->
Dcalc.Ast.make_abs Dcalc.Ast.make_abs
@ -366,7 +427,9 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
(match Pos.unmark a_io.io_input with (match Pos.unmark a_io.io_input with
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
| OnlyInput -> tau | OnlyInput -> tau
| Reentrant -> (Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos)); | Reentrant ->
( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau),
var_def_pos ));
Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition;
}; };
@ -379,7 +442,9 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
match map with match map with
| Some map -> | Some map ->
Some Some
(Ast.ScopeVarMap.add (Pos.unmark subs_var) (a_var, Pos.unmark tau, a_io) map) (Ast.ScopeVarMap.add (Pos.unmark subs_var)
(a_var, Pos.unmark tau, a_io)
map)
| None -> | None ->
Some Some
(Ast.ScopeVarMap.singleton (Pos.unmark subs_var) (Ast.ScopeVarMap.singleton (Pos.unmark subs_var)
@ -392,11 +457,15 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
let all_subscope_input_vars = let all_subscope_input_vars =
List.filter List.filter
(fun var_ctx -> (fun var_ctx ->
match Pos.unmark var_ctx.scope_var_io.Ast.io_input with NoInput -> false | _ -> true) match Pos.unmark var_ctx.scope_var_io.Ast.io_input with
| NoInput -> false
| _ -> true)
all_subscope_vars all_subscope_vars
in in
let all_subscope_output_vars = let all_subscope_output_vars =
List.filter (fun var_ctx -> Pos.unmark var_ctx.scope_var_io.Ast.io_output) all_subscope_vars List.filter
(fun var_ctx -> Pos.unmark var_ctx.scope_var_io.Ast.io_output)
all_subscope_vars
in in
let scope_dcalc_var = subscope_sig.scope_sig_scope_var in let scope_dcalc_var = subscope_sig.scope_sig_scope_var in
let called_scope_input_struct = subscope_sig.scope_sig_input_struct in let called_scope_input_struct = subscope_sig.scope_sig_input_struct in
@ -413,19 +482,23 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
List.map List.map
(fun (subvar : scope_var_ctx) -> (fun (subvar : scope_var_ctx) ->
if subscope_var_not_yet_defined subvar.scope_var_name then if subscope_var_not_yet_defined subvar.scope_var_name then
(* This is a redundant check. Normally, all subscope varaibles should have been (* This is a redundant check. Normally, all subscope varaibles
defined (even an empty definition, if they're not defined by any rule in the source should have been defined (even an empty definition, if they're
code) by the translation from desugared to the scope language. *) not defined by any rule in the source code) by the translation
from desugared to the scope language. *)
Bindlib.box Dcalc.Ast.empty_thunked_term Bindlib.box Dcalc.Ast.empty_thunked_term
else else
let a_var, _, _ = Ast.ScopeVarMap.find subvar.scope_var_name subscope_vars_defined in let a_var, _, _ =
Ast.ScopeVarMap.find subvar.scope_var_name subscope_vars_defined
in
Dcalc.Ast.make_var (a_var, pos_call)) Dcalc.Ast.make_var (a_var, pos_call))
all_subscope_input_vars all_subscope_input_vars
in in
let subscope_struct_arg = let subscope_struct_arg =
Bindlib.box_apply Bindlib.box_apply
(fun subscope_args -> (fun subscope_args ->
(Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct), pos_call)) ( Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct),
pos_call ))
(Bindlib.box_list subscope_args) (Bindlib.box_list subscope_args)
in in
let all_subscope_output_vars_dcalc = let all_subscope_output_vars_dcalc =
@ -434,7 +507,8 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
let sub_dcalc_var = let sub_dcalc_var =
Dcalc.Ast.Var.make Dcalc.Ast.Var.make
(Pos.map_under_mark (Pos.map_under_mark
(fun s -> Pos.unmark (Ast.SubScopeName.get_info subindex) ^ "." ^ s) (fun s ->
Pos.unmark (Ast.SubScopeName.get_info subindex) ^ "." ^ s)
(Ast.ScopeVar.get_info subvar.scope_var_name)) (Ast.ScopeVar.get_info subvar.scope_var_name))
in in
(subvar, sub_dcalc_var)) (subvar, sub_dcalc_var))
@ -443,7 +517,8 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
let subscope_func = let subscope_func =
tag_with_log_entry tag_with_log_entry
(Dcalc.Ast.make_var (Dcalc.Ast.make_var
(scope_dcalc_var, Pos.get_position (Ast.SubScopeName.get_info subindex))) ( scope_dcalc_var,
Pos.get_position (Ast.SubScopeName.get_info subindex) ))
Dcalc.Ast.BeginCall Dcalc.Ast.BeginCall
[ [
(sigma_name, pos_sigma); (sigma_name, pos_sigma);
@ -495,7 +570,8 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
i, i,
Some called_scope_return_struct, Some called_scope_return_struct,
List.map List.map
(fun (var_ctx, _) -> (var_ctx.scope_var_typ, pos_sigma)) (fun (var_ctx, _) ->
(var_ctx.scope_var_typ, pos_sigma))
all_subscope_output_vars_dcalc ), all_subscope_output_vars_dcalc ),
pos_sigma )) pos_sigma ))
(Dcalc.Ast.make_var (result_tuple_var, pos_sigma)); (Dcalc.Ast.make_var (result_tuple_var, pos_sigma));
@ -523,12 +599,13 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
(Dcalc.Ast.Var.make ("_", Pos.get_position e), Pos.get_position e); (Dcalc.Ast.Var.make ("_", Pos.get_position e), Pos.get_position e);
Dcalc.Ast.scope_let_typ = (Dcalc.Ast.TLit TUnit, Pos.get_position e); Dcalc.Ast.scope_let_typ = (Dcalc.Ast.TLit TUnit, Pos.get_position e);
Dcalc.Ast.scope_let_expr = Dcalc.Ast.scope_let_expr =
(* To ensure that we throw an error if the value is not defined, we add an check (* To ensure that we throw an error if the value is not defined,
"ErrorOnEmpty" here. *) we add an check "ErrorOnEmpty" here. *)
Bindlib.box_apply Bindlib.box_apply
(fun new_e -> (fun new_e ->
Pos.same_pos_as Pos.same_pos_as
(Dcalc.Ast.EAssert (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position e)) (Dcalc.Ast.EAssert
(Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position e))
e) e)
new_e; new_e;
Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion;
@ -536,50 +613,66 @@ let translate_rule (ctx : ctx) (rule : Ast.rule)
], ],
ctx ) ctx )
let translate_rules (ctx : ctx) (rules : Ast.rule list) let translate_rules
(ctx : ctx)
(rules : Ast.rule list)
((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info)
(sigma_return_struct_name : Ast.StructName.t) : (sigma_return_struct_name : Ast.StructName.t) :
Dcalc.Ast.scope_let list * Dcalc.Ast.expr Pos.marked Bindlib.box * ctx = Dcalc.Ast.scope_let list * Dcalc.Ast.expr Pos.marked Bindlib.box * ctx =
let scope_lets, new_ctx = let scope_lets, new_ctx =
List.fold_left List.fold_left
(fun (scope_lets, ctx) rule -> (fun (scope_lets, ctx) rule ->
let new_scope_lets, new_ctx = translate_rule ctx rule (sigma_name, pos_sigma) in let new_scope_lets, new_ctx =
translate_rule ctx rule (sigma_name, pos_sigma)
in
(scope_lets @ new_scope_lets, new_ctx)) (scope_lets @ new_scope_lets, new_ctx))
([], ctx) rules ([], ctx) rules
in in
let scope_variables = Ast.ScopeVarMap.bindings new_ctx.scope_vars in let scope_variables = Ast.ScopeVarMap.bindings new_ctx.scope_vars in
let scope_output_variables = let scope_output_variables =
List.filter (fun (_, (_, _, io)) -> Pos.unmark io.Ast.io_output) scope_variables List.filter
(fun (_, (_, _, io)) -> Pos.unmark io.Ast.io_output)
scope_variables
in in
let return_exp = let return_exp =
Bindlib.box_apply Bindlib.box_apply
(fun args -> (Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma)) (fun args ->
(Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma))
(Bindlib.box_list (Bindlib.box_list
(List.map (List.map
(fun (_, (dcalc_var, _, _)) -> Dcalc.Ast.make_var (dcalc_var, pos_sigma)) (fun (_, (dcalc_var, _, _)) ->
Dcalc.Ast.make_var (dcalc_var, pos_sigma))
scope_output_variables)) scope_output_variables))
in in
(scope_lets, return_exp, new_ctx) (scope_lets, return_exp, new_ctx)
let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) let translate_scope_decl
(sctx : scope_sigs_ctx) (scope_name : Ast.ScopeName.t) (sigma : Ast.scope_decl) : (struct_ctx : Ast.struct_ctx)
Dcalc.Ast.scope_body * Dcalc.Ast.struct_ctx = (enum_ctx : Ast.enum_ctx)
(sctx : scope_sigs_ctx)
(scope_name : Ast.ScopeName.t)
(sigma : Ast.scope_decl) : Dcalc.Ast.scope_body * Dcalc.Ast.struct_ctx =
let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in
let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in
let scope_variables = scope_sig.scope_sig_local_vars in let scope_variables = scope_sig.scope_sig_local_vars in
let ctx = let ctx =
(* the context must be initialized for fresh variables for all only-input scope variables *) (* the context must be initialized for fresh variables for all only-input
scope variables *)
List.fold_left List.fold_left
(fun ctx scope_var -> (fun ctx scope_var ->
match Pos.unmark scope_var.scope_var_io.io_input with match Pos.unmark scope_var.scope_var_io.io_input with
| OnlyInput -> | OnlyInput ->
let scope_var_name = Ast.ScopeVar.get_info scope_var.scope_var_name in let scope_var_name =
Ast.ScopeVar.get_info scope_var.scope_var_name
in
let scope_var_dcalc = Dcalc.Ast.Var.make scope_var_name in let scope_var_dcalc = Dcalc.Ast.Var.make scope_var_name in
{ {
ctx with ctx with
scope_vars = scope_vars =
Ast.ScopeVarMap.add scope_var.scope_var_name Ast.ScopeVarMap.add scope_var.scope_var_name
(scope_var_dcalc, scope_var.scope_var_typ, scope_var.scope_var_io) ( scope_var_dcalc,
scope_var.scope_var_typ,
scope_var.scope_var_io )
ctx.scope_vars; ctx.scope_vars;
} }
| _ -> ctx) | _ -> ctx)
@ -591,12 +684,15 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
let scope_return_struct_name = scope_sig.scope_sig_output_struct in let scope_return_struct_name = scope_sig.scope_sig_output_struct in
let pos_sigma = Pos.get_position sigma_info in let pos_sigma = Pos.get_position sigma_info in
let rules, return_exp, ctx = let rules, return_exp, ctx =
translate_rules ctx sigma.scope_decl_rules sigma_info scope_return_struct_name translate_rules ctx sigma.scope_decl_rules sigma_info
scope_return_struct_name
in in
let scope_variables = let scope_variables =
List.map List.map
(fun var_ctx -> (fun var_ctx ->
let dcalc_x, _, _ = Ast.ScopeVarMap.find var_ctx.scope_var_name ctx.scope_vars in let dcalc_x, _, _ =
Ast.ScopeVarMap.find var_ctx.scope_var_name ctx.scope_vars
in
(var_ctx, dcalc_x)) (var_ctx, dcalc_x))
scope_variables scope_variables
in in
@ -604,17 +700,23 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
let scope_input_variables = let scope_input_variables =
List.filter List.filter
(fun (var_ctx, _) -> (fun (var_ctx, _) ->
match Pos.unmark var_ctx.scope_var_io.io_input with NoInput -> false | _ -> true) match Pos.unmark var_ctx.scope_var_io.io_input with
| NoInput -> false
| _ -> true)
scope_variables scope_variables
in in
let scope_output_variables = let scope_output_variables =
List.filter (fun (var_ctx, _) -> Pos.unmark var_ctx.scope_var_io.io_output) scope_variables List.filter
(fun (var_ctx, _) -> Pos.unmark var_ctx.scope_var_io.io_output)
scope_variables
in in
let input_var_typ (var_ctx : scope_var_ctx) = let input_var_typ (var_ctx : scope_var_ctx) =
match Pos.unmark var_ctx.scope_var_io.io_input with match Pos.unmark var_ctx.scope_var_io.io_input with
| OnlyInput -> (var_ctx.scope_var_typ, pos_sigma) | OnlyInput -> (var_ctx.scope_var_typ, pos_sigma)
| Reentrant -> | Reentrant ->
( Dcalc.Ast.TArrow ((Dcalc.Ast.TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), ( Dcalc.Ast.TArrow
( (Dcalc.Ast.TLit TUnit, pos_sigma),
(var_ctx.scope_var_typ, pos_sigma) ),
pos_sigma ) pos_sigma )
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
in in
@ -632,7 +734,9 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
( r, ( r,
i, i,
Some scope_input_struct_name, Some scope_input_struct_name,
List.map (fun (var_ctx, _) -> input_var_typ var_ctx) scope_input_variables ), List.map
(fun (var_ctx, _) -> input_var_typ var_ctx)
scope_input_variables ),
pos_sigma )) pos_sigma ))
(Dcalc.Ast.make_var (scope_input_var, pos_sigma)); (Dcalc.Ast.make_var (scope_input_var, pos_sigma));
}) })
@ -658,7 +762,8 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
in in
let new_struct_ctx = let new_struct_ctx =
Ast.StructMap.add scope_input_struct_name scope_input_struct_fields Ast.StructMap.add scope_input_struct_name scope_input_struct_fields
(Ast.StructMap.singleton scope_return_struct_name scope_return_struct_fields) (Ast.StructMap.singleton scope_return_struct_name
scope_return_struct_fields)
in in
( { ( {
Dcalc.Ast.scope_body_lets = input_destructurings @ rules; Dcalc.Ast.scope_body_lets = input_destructurings @ rules;
@ -669,10 +774,13 @@ let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx)
}, },
new_struct_ctx ) new_struct_ctx )
let translate_program (prgm : Ast.program) : Dcalc.Ast.program * Dependency.TVertex.t list = let translate_program (prgm : Ast.program) :
Dcalc.Ast.program * Dependency.TVertex.t list =
let scope_dependencies = Dependency.build_program_dep_graph prgm in let scope_dependencies = Dependency.build_program_dep_graph prgm in
Dependency.check_for_cycle_in_scope scope_dependencies; Dependency.check_for_cycle_in_scope scope_dependencies;
let types_ordering = Dependency.check_type_cycles prgm.program_structs prgm.program_enums in let types_ordering =
Dependency.check_type_cycles prgm.program_structs prgm.program_enums
in
let scope_ordering = Dependency.get_scope_ordering scope_dependencies in let scope_ordering = Dependency.get_scope_ordering scope_dependencies in
let struct_ctx = prgm.program_structs in let struct_ctx = prgm.program_structs in
let enum_ctx = prgm.program_enums in let enum_ctx = prgm.program_enums in
@ -684,36 +792,52 @@ let translate_program (prgm : Ast.program) : Dcalc.Ast.program * Dependency.TVer
{ {
Dcalc.Ast.ctx_structs = Dcalc.Ast.ctx_structs =
Ast.StructMap.map Ast.StructMap.map
(List.map (fun (x, y) -> (x, translate_typ (ctx_for_typ_translation dummy_scope) y))) (List.map (fun (x, y) ->
(x, translate_typ (ctx_for_typ_translation dummy_scope) y)))
struct_ctx; struct_ctx;
Dcalc.Ast.ctx_enums = Dcalc.Ast.ctx_enums =
Ast.EnumMap.map Ast.EnumMap.map
(List.map (fun (x, y) -> (x, (translate_typ (ctx_for_typ_translation dummy_scope)) y))) (List.map (fun (x, y) ->
(x, (translate_typ (ctx_for_typ_translation dummy_scope)) y)))
enum_ctx; enum_ctx;
} }
in in
let sctx : scope_sigs_ctx = let sctx : scope_sigs_ctx =
Ast.ScopeMap.mapi Ast.ScopeMap.mapi
(fun scope_name scope -> (fun scope_name scope ->
let scope_dvar = Dcalc.Ast.Var.make (Ast.ScopeName.get_info scope.Ast.scope_decl_name) in let scope_dvar =
Dcalc.Ast.Var.make (Ast.ScopeName.get_info scope.Ast.scope_decl_name)
in
let scope_return_struct_name = let scope_return_struct_name =
Ast.StructName.fresh Ast.StructName.fresh
(Pos.map_under_mark (fun s -> s ^ "_out") (Ast.ScopeName.get_info scope_name)) (Pos.map_under_mark
(fun s -> s ^ "_out")
(Ast.ScopeName.get_info scope_name))
in in
let scope_input_var = let scope_input_var =
Dcalc.Ast.Var.make Dcalc.Ast.Var.make
(Pos.map_under_mark (fun s -> s ^ "_in") (Ast.ScopeName.get_info scope_name)) (Pos.map_under_mark
(fun s -> s ^ "_in")
(Ast.ScopeName.get_info scope_name))
in in
let scope_input_struct_name = let scope_input_struct_name =
Ast.StructName.fresh Ast.StructName.fresh
(Pos.map_under_mark (fun s -> s ^ "_in") (Ast.ScopeName.get_info scope_name)) (Pos.map_under_mark
(fun s -> s ^ "_in")
(Ast.ScopeName.get_info scope_name))
in in
{ {
scope_sig_local_vars = scope_sig_local_vars =
List.map List.map
(fun (scope_var, (tau, vis)) -> (fun (scope_var, (tau, vis)) ->
let tau = translate_typ (ctx_for_typ_translation scope_name) tau in let tau =
{ scope_var_name = scope_var; scope_var_typ = Pos.unmark tau; scope_var_io = vis }) translate_typ (ctx_for_typ_translation scope_name) tau
in
{
scope_var_name = scope_var;
scope_var_typ = Pos.unmark tau;
scope_var_io = vis;
})
(Ast.ScopeVarMap.bindings scope.scope_sig); (Ast.ScopeVarMap.bindings scope.scope_sig);
scope_sig_scope_var = scope_dvar; scope_sig_scope_var = scope_dvar;
scope_sig_input_var = scope_input_var; scope_sig_input_var = scope_input_var;
@ -722,14 +846,20 @@ let translate_program (prgm : Ast.program) : Dcalc.Ast.program * Dependency.TVer
}) })
prgm.program_scopes prgm.program_scopes
in in
(* the resulting expression is the list of definitions of all the scopes, ending with the (* the resulting expression is the list of definitions of all the scopes,
top-level scope. *) ending with the top-level scope. *)
let (scopes, decl_ctx) let (scopes, decl_ctx)
: (Ast.ScopeName.t * Dcalc.Ast.expr Bindlib.var * Dcalc.Ast.scope_body) list * _ = : (Ast.ScopeName.t * Dcalc.Ast.expr Bindlib.var * Dcalc.Ast.scope_body)
list
* _ =
List.fold_right List.fold_right
(fun scope_name (fun scope_name
((scopes, decl_ctx) : ((scopes, decl_ctx) :
(Ast.ScopeName.t * Dcalc.Ast.expr Bindlib.var * Dcalc.Ast.scope_body) list * _) -> (Ast.ScopeName.t
* Dcalc.Ast.expr Bindlib.var
* Dcalc.Ast.scope_body)
list
* _) ->
let scope = Ast.ScopeMap.find scope_name prgm.program_scopes in let scope = Ast.ScopeMap.find scope_name prgm.program_scopes in
let scope_body, scope_out_struct = let scope_body, scope_out_struct =
translate_scope_decl struct_ctx enum_ctx sctx scope_name scope translate_scope_decl struct_ctx enum_ctx sctx scope_name scope

View File

@ -1,20 +1,24 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Scope language to default calculus translator *) (** Scope language to default calculus translator *)
val translate_program : Ast.program -> Dcalc.Ast.program * Dependency.TVertex.t list val translate_program :
(** Usage [translate_program p] returns a tuple [(new_program, types_list)] where [new_program] is Ast.program -> Dcalc.Ast.program * Dependency.TVertex.t list
the map of translated scopes. Finally, [types_list] is a list of all types (structs and enums) (** Usage [translate_program p] returns a tuple [(new_program, types_list)]
used in the program, correctly ordered with respect to inter-types dependency. *) where [new_program] is the map of translated scopes. Finally, [types_list]
is a list of all types (structs and enums) used in the program, correctly
ordered with respect to inter-types dependency. *)

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Abstract syntax tree built by the Catala parser *) (** Abstract syntax tree built by the Catala parser *)
@ -19,10 +22,11 @@
open Utils open Utils
(** {1 Visitor classes for programs} *) (** {1 Visitor classes for programs} *)
(** To allow for quick traversal and/or modification of this AST structure, we provide a (** To allow for quick traversal and/or modification of this AST structure, we
{{:https://en.wikipedia.org/wiki/Visitor_pattern} visitor design pattern}. This feature is provide a {{:https://en.wikipedia.org/wiki/Visitor_pattern} visitor design
implemented via {{:https://gitlab.inria.fr/fpottier/visitors} François Pottier's OCaml visitors pattern}. This feature is implemented via
library}. *) {{:https://gitlab.inria.fr/fpottier/visitors} François Pottier's OCaml
visitors library}. *)
(** {1 Type definitions} *) (** {1 Type definitions} *)
@ -40,9 +44,18 @@ type ident = (string[@opaque])
type qident = ident Pos.marked list type qident = ident Pos.marked list
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "Pos.marked_map"; "ident_map" ]; name = "qident_map" },
visitors visitors
{ variety = "iter"; ancestors = [ "Pos.marked_iter"; "ident_iter" ]; name = "qident_iter" }] {
variety = "map";
ancestors = [ "Pos.marked_map"; "ident_map" ];
name = "qident_map";
},
visitors
{
variety = "iter";
ancestors = [ "Pos.marked_iter"; "ident_iter" ];
name = "qident_iter";
}]
type primitive_typ = type primitive_typ =
| Integer | Integer
@ -54,10 +67,22 @@ type primitive_typ =
| Date | Date
| Named of constructor | Named of constructor
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "constructor_map" ]; name = "primitive_typ_map" }, visitors
visitors { variety = "iter"; ancestors = [ "constructor_iter" ]; name = "primitive_typ_iter" }] {
variety = "map";
ancestors = [ "constructor_map" ];
name = "primitive_typ_map";
},
visitors
{
variety = "iter";
ancestors = [ "constructor_iter" ];
name = "primitive_typ_iter";
}]
type base_typ_data = Primitive of primitive_typ | Collection of base_typ_data Pos.marked type base_typ_data =
| Primitive of primitive_typ
| Collection of base_typ_data Pos.marked
[@@deriving [@@deriving
visitors visitors
{ {
@ -75,7 +100,12 @@ type base_typ_data = Primitive of primitive_typ | Collection of base_typ_data Po
type base_typ = Condition | Data of base_typ_data type base_typ = Condition | Data of base_typ_data
[@@deriving [@@deriving
visitors visitors
{ variety = "map"; ancestors = [ "base_typ_data_map" ]; name = "base_typ_map"; nude = true }, {
variety = "map";
ancestors = [ "base_typ_data_map" ];
name = "base_typ_map";
nude = true;
},
visitors visitors
{ {
variety = "iter"; variety = "iter";
@ -84,16 +114,42 @@ type base_typ = Condition | Data of base_typ_data
nude = true; nude = true;
}] }]
type func_typ = { arg_typ : base_typ Pos.marked; return_typ : base_typ Pos.marked } type func_typ = {
arg_typ : base_typ Pos.marked;
return_typ : base_typ Pos.marked;
}
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "base_typ_map" ]; name = "func_typ_map"; nude = true },
visitors visitors
{ variety = "iter"; ancestors = [ "base_typ_iter" ]; name = "func_typ_iter"; nude = true }] {
variety = "map";
ancestors = [ "base_typ_map" ];
name = "func_typ_map";
nude = true;
},
visitors
{
variety = "iter";
ancestors = [ "base_typ_iter" ];
name = "func_typ_iter";
nude = true;
}]
type typ = Base of base_typ | Func of func_typ type typ = Base of base_typ | Func of func_typ
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "func_typ_map" ]; name = "typ_map"; nude = true }, visitors
visitors { variety = "iter"; ancestors = [ "func_typ_iter" ]; name = "typ_iter"; nude = true }] {
variety = "map";
ancestors = [ "func_typ_map" ];
name = "typ_map";
nude = true;
},
visitors
{
variety = "iter";
ancestors = [ "func_typ_iter" ];
name = "typ_iter";
nude = true;
}]
type struct_decl_field = { type struct_decl_field = {
struct_decl_field_name : ident Pos.marked; struct_decl_field_name : ident Pos.marked;
@ -101,7 +157,11 @@ type struct_decl_field = {
} }
[@@deriving [@@deriving
visitors visitors
{ variety = "map"; ancestors = [ "typ_map"; "ident_map" ]; name = "struct_decl_field_map" }, {
variety = "map";
ancestors = [ "typ_map"; "ident_map" ];
name = "struct_decl_field_map";
},
visitors visitors
{ {
variety = "iter"; variety = "iter";
@ -114,18 +174,38 @@ type struct_decl = {
struct_decl_fields : struct_decl_field Pos.marked list; struct_decl_fields : struct_decl_field Pos.marked list;
} }
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "struct_decl_field_map" ]; name = "struct_decl_map" },
visitors visitors
{ variety = "iter"; ancestors = [ "struct_decl_field_iter" ]; name = "struct_decl_iter" }] {
variety = "map";
ancestors = [ "struct_decl_field_map" ];
name = "struct_decl_map";
},
visitors
{
variety = "iter";
ancestors = [ "struct_decl_field_iter" ];
name = "struct_decl_iter";
}]
type enum_decl_case = { type enum_decl_case = {
enum_decl_case_name : constructor Pos.marked; enum_decl_case_name : constructor Pos.marked;
enum_decl_case_typ : typ Pos.marked option; enum_decl_case_typ : typ Pos.marked option;
} }
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "typ_map" ]; name = "enum_decl_case_map"; nude = true },
visitors visitors
{ variety = "iter"; ancestors = [ "typ_iter" ]; name = "enum_decl_case_iter"; nude = true }] {
variety = "map";
ancestors = [ "typ_map" ];
name = "enum_decl_case_map";
nude = true;
},
visitors
{
variety = "iter";
ancestors = [ "typ_iter" ];
name = "enum_decl_case_iter";
nude = true;
}]
type enum_decl = { type enum_decl = {
enum_decl_name : constructor Pos.marked; enum_decl_name : constructor Pos.marked;
@ -133,7 +213,12 @@ type enum_decl = {
} }
[@@deriving [@@deriving
visitors visitors
{ variety = "map"; ancestors = [ "enum_decl_case_map" ]; name = "enum_decl_map"; nude = true }, {
variety = "map";
ancestors = [ "enum_decl_case_map" ];
name = "enum_decl_map";
nude = true;
},
visitors visitors
{ {
variety = "iter"; variety = "iter";
@ -143,7 +228,8 @@ type enum_decl = {
}] }]
type match_case_pattern = type match_case_pattern =
(constructor Pos.marked option * constructor Pos.marked) list * ident Pos.marked option (constructor Pos.marked option * constructor Pos.marked) list
* ident Pos.marked option
[@@deriving [@@deriving
visitors visitors
{ {
@ -179,13 +265,37 @@ type binop =
| Neq | Neq
| Concat | Concat
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "op_kind_map" ]; name = "binop_map"; nude = true }, visitors
visitors { variety = "iter"; ancestors = [ "op_kind_iter" ]; name = "binop_iter"; nude = true }] {
variety = "map";
ancestors = [ "op_kind_map" ];
name = "binop_map";
nude = true;
},
visitors
{
variety = "iter";
ancestors = [ "op_kind_iter" ];
name = "binop_iter";
nude = true;
}]
type unop = Not | Minus of op_kind type unop = Not | Minus of op_kind
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "op_kind_map" ]; name = "unop_map"; nude = true }, visitors
visitors { variety = "iter"; ancestors = [ "op_kind_iter" ]; name = "unop_iter"; nude = true }] {
variety = "map";
ancestors = [ "op_kind_map" ];
name = "unop_map";
nude = true;
},
visitors
{
variety = "iter";
ancestors = [ "op_kind_iter" ];
name = "unop_iter";
nude = true;
}]
type builtin_expression = Cardinal | IntToDec | GetDay | GetMonth | GetYear type builtin_expression = Cardinal | IntToDec | GetDay | GetMonth | GetYear
[@@deriving [@@deriving
@ -198,8 +308,18 @@ type literal_date = {
literal_date_year : (int[@opaque]) Pos.marked; literal_date_year : (int[@opaque]) Pos.marked;
} }
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "Pos.marked_map" ]; name = "literal_date_map" }, visitors
visitors { variety = "iter"; ancestors = [ "Pos.marked_iter" ]; name = "literal_date_iter" }] {
variety = "map";
ancestors = [ "Pos.marked_map" ];
name = "literal_date_map";
},
visitors
{
variety = "iter";
ancestors = [ "Pos.marked_iter" ];
name = "literal_date_iter";
}]
type literal_number = type literal_number =
| Int of (Runtime.integer[@opaque]) | Int of (Runtime.integer[@opaque])
@ -231,14 +351,24 @@ type literal =
{ {
variety = "map"; variety = "map";
ancestors = ancestors =
[ "literal_number_map"; "money_amount_map"; "literal_date_map"; "literal_unit_map" ]; [
"literal_number_map";
"money_amount_map";
"literal_date_map";
"literal_unit_map";
];
name = "literal_map"; name = "literal_map";
}, },
visitors visitors
{ {
variety = "iter"; variety = "iter";
ancestors = ancestors =
[ "literal_number_iter"; "money_amount_iter"; "literal_date_iter"; "literal_unit_iter" ]; [
"literal_number_iter";
"money_amount_iter";
"literal_date_iter";
"literal_unit_iter";
];
name = "literal_iter"; name = "literal_iter";
}] }]
@ -248,35 +378,50 @@ type aggregate_func =
| AggregateExtremum of bool * primitive_typ * expression Pos.marked | AggregateExtremum of bool * primitive_typ * expression Pos.marked
| AggregateArgExtremum of bool * primitive_typ * expression Pos.marked | AggregateArgExtremum of bool * primitive_typ * expression Pos.marked
and collection_op = Exists | Forall | Aggregate of aggregate_func | Map | Filter and collection_op =
| Exists
| Forall
| Aggregate of aggregate_func
| Map
| Filter
and explicit_match_case = { and explicit_match_case = {
match_case_pattern : match_case_pattern Pos.marked; match_case_pattern : match_case_pattern Pos.marked;
match_case_expr : expression Pos.marked; match_case_expr : expression Pos.marked;
} }
and match_case = WildCard of expression Pos.marked | MatchCase of explicit_match_case and match_case =
| WildCard of expression Pos.marked
| MatchCase of explicit_match_case
and match_cases = match_case Pos.marked list and match_cases = match_case Pos.marked list
and expression = and expression =
| MatchWith of expression Pos.marked * match_cases Pos.marked | MatchWith of expression Pos.marked * match_cases Pos.marked
| IfThenElse of expression Pos.marked * expression Pos.marked * expression Pos.marked | IfThenElse of
expression Pos.marked * expression Pos.marked * expression Pos.marked
| Binop of binop Pos.marked * expression Pos.marked * expression Pos.marked | Binop of binop Pos.marked * expression Pos.marked * expression Pos.marked
| Unop of unop Pos.marked * expression Pos.marked | Unop of unop Pos.marked * expression Pos.marked
| CollectionOp of | CollectionOp of
collection_op Pos.marked * ident Pos.marked * expression Pos.marked * expression Pos.marked collection_op Pos.marked
* ident Pos.marked
* expression Pos.marked
* expression Pos.marked
| MemCollection of expression Pos.marked * expression Pos.marked | MemCollection of expression Pos.marked * expression Pos.marked
| TestMatchCase of expression Pos.marked * match_case_pattern Pos.marked | TestMatchCase of expression Pos.marked * match_case_pattern Pos.marked
| FunCall of expression Pos.marked * expression Pos.marked | FunCall of expression Pos.marked * expression Pos.marked
| Builtin of builtin_expression | Builtin of builtin_expression
| Literal of literal | Literal of literal
| EnumInject of | EnumInject of
constructor Pos.marked option * constructor Pos.marked * expression Pos.marked option constructor Pos.marked option
| StructLit of constructor Pos.marked * (ident Pos.marked * expression Pos.marked) list * constructor Pos.marked
* expression Pos.marked option
| StructLit of
constructor Pos.marked * (ident Pos.marked * expression Pos.marked) list
| ArrayLit of expression Pos.marked list | ArrayLit of expression Pos.marked list
| Ident of ident | Ident of ident
| Dotted of expression Pos.marked * constructor Pos.marked option * ident Pos.marked | Dotted of
expression Pos.marked * constructor Pos.marked option * ident Pos.marked
(** Dotted is for both struct field projection and sub-scope variables *) (** Dotted is for both struct field projection and sub-scope variables *)
[@@deriving [@@deriving
visitors visitors
@ -308,10 +453,17 @@ and expression =
name = "expression_iter"; name = "expression_iter";
}] }]
type exception_to = NotAnException | UnlabeledException | ExceptionToLabel of ident Pos.marked type exception_to =
| NotAnException
| UnlabeledException
| ExceptionToLabel of ident Pos.marked
[@@deriving [@@deriving
visitors visitors
{ variety = "map"; ancestors = [ "ident_map"; "Pos.marked_map" ]; name = "exception_to_map" }, {
variety = "map";
ancestors = [ "ident_map"; "Pos.marked_map" ];
name = "exception_to_map";
},
visitors visitors
{ {
variety = "iter"; variety = "iter";
@ -374,7 +526,10 @@ type variation_typ = Increasing | Decreasing
type meta_assertion = type meta_assertion =
| FixedBy of qident Pos.marked * ident Pos.marked | FixedBy of qident Pos.marked * ident Pos.marked
| VariesWith of qident Pos.marked * expression Pos.marked * variation_typ Pos.marked option | VariesWith of
qident Pos.marked
* expression Pos.marked
* variation_typ Pos.marked option
[@@deriving [@@deriving
visitors visitors
{ {
@ -394,8 +549,18 @@ type assertion = {
assertion_content : expression Pos.marked; assertion_content : expression Pos.marked;
} }
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "expression_map" ]; name = "assertion_map" }, visitors
visitors { variety = "iter"; ancestors = [ "expression_iter" ]; name = "assertion_iter" }] {
variety = "map";
ancestors = [ "expression_map" ];
name = "assertion_map";
},
visitors
{
variety = "iter";
ancestors = [ "expression_iter" ];
name = "assertion_iter";
}]
type scope_use_item = type scope_use_item =
| Rule of rule | Rule of rule
@ -406,13 +571,20 @@ type scope_use_item =
visitors visitors
{ {
variety = "map"; variety = "map";
ancestors = [ "meta_assertion_map"; "definition_map"; "assertion_map"; "rule_map" ]; ancestors =
[ "meta_assertion_map"; "definition_map"; "assertion_map"; "rule_map" ];
name = "scope_use_item_map"; name = "scope_use_item_map";
}, },
visitors visitors
{ {
variety = "iter"; variety = "iter";
ancestors = [ "meta_assertion_iter"; "definition_iter"; "assertion_iter"; "rule_iter" ]; ancestors =
[
"meta_assertion_iter";
"definition_iter";
"assertion_iter";
"rule_iter";
];
name = "scope_use_item_iter"; name = "scope_use_item_iter";
}] }]
@ -467,14 +639,25 @@ type scope_decl_context_scope = {
visitors visitors
{ {
variety = "map"; variety = "map";
ancestors = [ "ident_map"; "constructor_map"; "scope_decl_context_io_map"; "Pos.marked_map" ]; ancestors =
[
"ident_map";
"constructor_map";
"scope_decl_context_io_map";
"Pos.marked_map";
];
name = "scope_decl_context_scope_map"; name = "scope_decl_context_scope_map";
}, },
visitors visitors
{ {
variety = "iter"; variety = "iter";
ancestors = ancestors =
[ "ident_iter"; "constructor_iter"; "scope_decl_context_io_iter"; "Pos.marked_iter" ]; [
"ident_iter";
"constructor_iter";
"scope_decl_context_io_iter";
"Pos.marked_iter";
];
name = "scope_decl_context_scope_iter"; name = "scope_decl_context_scope_iter";
}] }]
@ -505,13 +688,15 @@ type scope_decl_context_item =
visitors visitors
{ {
variety = "map"; variety = "map";
ancestors = [ "scope_decl_context_data_map"; "scope_decl_context_scope_map" ]; ancestors =
[ "scope_decl_context_data_map"; "scope_decl_context_scope_map" ];
name = "scope_decl_context_item_map"; name = "scope_decl_context_item_map";
}, },
visitors visitors
{ {
variety = "iter"; variety = "iter";
ancestors = [ "scope_decl_context_data_iter"; "scope_decl_context_scope_iter" ]; ancestors =
[ "scope_decl_context_data_iter"; "scope_decl_context_scope_iter" ];
name = "scope_decl_context_item_iter"; name = "scope_decl_context_item_iter";
}] }]
@ -521,9 +706,17 @@ type scope_decl = {
} }
[@@deriving [@@deriving
visitors visitors
{ variety = "map"; ancestors = [ "scope_decl_context_item_map" ]; name = "scope_decl_map" }, {
variety = "map";
ancestors = [ "scope_decl_context_item_map" ];
name = "scope_decl_map";
},
visitors visitors
{ variety = "iter"; ancestors = [ "scope_decl_context_item_iter" ]; name = "scope_decl_iter" }] {
variety = "iter";
ancestors = [ "scope_decl_context_item_iter" ];
name = "scope_decl_iter";
}]
type code_item = type code_item =
| ScopeUse of scope_use | ScopeUse of scope_use
@ -534,25 +727,54 @@ type code_item =
visitors visitors
{ {
variety = "map"; variety = "map";
ancestors = [ "scope_decl_map"; "enum_decl_map"; "struct_decl_map"; "scope_use_map" ]; ancestors =
[
"scope_decl_map"; "enum_decl_map"; "struct_decl_map"; "scope_use_map";
];
name = "code_item_map"; name = "code_item_map";
}, },
visitors visitors
{ {
variety = "iter"; variety = "iter";
ancestors = [ "scope_decl_iter"; "enum_decl_iter"; "struct_decl_iter"; "scope_use_iter" ]; ancestors =
[
"scope_decl_iter";
"enum_decl_iter";
"struct_decl_iter";
"scope_use_iter";
];
name = "code_item_iter"; name = "code_item_iter";
}] }]
type code_block = code_item Pos.marked list type code_block = code_item Pos.marked list
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "code_item_map" ]; name = "code_block_map" }, visitors
visitors { variety = "iter"; ancestors = [ "code_item_iter" ]; name = "code_block_iter" }] {
variety = "map";
ancestors = [ "code_item_map" ];
name = "code_block_map";
},
visitors
{
variety = "iter";
ancestors = [ "code_item_iter" ];
name = "code_block_iter";
}]
type source_repr = (string[@opaque]) Pos.marked type source_repr = (string[@opaque]) Pos.marked
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "Pos.marked_map" ]; name = "source_repr_map" }, visitors
visitors { variety = "iter"; ancestors = [ "Pos.marked_iter" ]; name = "source_repr_iter" }] {
variety = "map";
ancestors = [ "Pos.marked_map" ];
name = "source_repr_map";
},
visitors
{
variety = "iter";
ancestors = [ "Pos.marked_iter" ];
name = "source_repr_iter";
}]
type law_heading = { type law_heading = {
law_heading_name : (string[@opaque]) Pos.marked; law_heading_name : (string[@opaque]) Pos.marked;
@ -561,16 +783,36 @@ type law_heading = {
law_heading_precedence : (int[@opaque]); law_heading_precedence : (int[@opaque]);
} }
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "Pos.marked_map" ]; name = "law_heading_map" }, visitors
visitors { variety = "iter"; ancestors = [ "Pos.marked_iter" ]; name = "law_heading_iter" }] {
variety = "map";
ancestors = [ "Pos.marked_map" ];
name = "law_heading_map";
},
visitors
{
variety = "iter";
ancestors = [ "Pos.marked_iter" ];
name = "law_heading_iter";
}]
type law_include = type law_include =
| PdfFile of (string[@opaque]) Pos.marked * (int[@opaque]) option | PdfFile of (string[@opaque]) Pos.marked * (int[@opaque]) option
| CatalaFile of (string[@opaque]) Pos.marked | CatalaFile of (string[@opaque]) Pos.marked
| LegislativeText of (string[@opaque]) Pos.marked | LegislativeText of (string[@opaque]) Pos.marked
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "Pos.marked_map" ]; name = "law_include_map" }, visitors
visitors { variety = "iter"; ancestors = [ "Pos.marked_iter" ]; name = "law_include_iter" }] {
variety = "map";
ancestors = [ "Pos.marked_map" ];
name = "law_include_map";
},
visitors
{
variety = "iter";
ancestors = [ "Pos.marked_iter" ];
name = "law_include_iter";
}]
type law_structure = type law_structure =
| LawInclude of law_include | LawInclude of law_include
@ -581,21 +823,45 @@ type law_structure =
visitors visitors
{ {
variety = "map"; variety = "map";
ancestors = [ "law_include_map"; "code_block_map"; "source_repr_map"; "law_heading_map" ]; ancestors =
[
"law_include_map";
"code_block_map";
"source_repr_map";
"law_heading_map";
];
name = "law_structure_map"; name = "law_structure_map";
}, },
visitors visitors
{ {
variety = "iter"; variety = "iter";
ancestors = ancestors =
[ "law_include_iter"; "code_block_iter"; "source_repr_iter"; "law_heading_iter" ]; [
"law_include_iter";
"code_block_iter";
"source_repr_iter";
"law_heading_iter";
];
name = "law_structure_iter"; name = "law_structure_iter";
}] }]
type program = { program_items : law_structure list; program_source_files : (string[@opaque]) list } type program = {
program_items : law_structure list;
program_source_files : (string[@opaque]) list;
}
[@@deriving [@@deriving
visitors { variety = "map"; ancestors = [ "law_structure_map" ]; name = "program_map" }, visitors
visitors { variety = "iter"; ancestors = [ "law_structure_iter" ]; name = "program_iter" }] {
variety = "map";
ancestors = [ "law_structure_map" ];
name = "program_map";
},
visitors
{
variety = "iter";
ancestors = [ "law_structure_iter" ];
name = "program_iter";
}]
type source_file = law_structure list type source_file = law_structure list

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Denis Merigoux <denis.merigoux@inria.fr> Nicolas Chataing <nicolas.chataing@ens.fr> Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Translation from {!module: Surface.Ast} to {!module: Desugared.Ast}. (** Translation from {!module: Surface.Ast} to {!module: Desugared.Ast}.
@ -17,5 +20,6 @@
- Removes syntactic sugars - Removes syntactic sugars
- Separate code from legislation *) - Separate code from legislation *)
val desugar_program : Name_resolution.context -> Ast.program -> Desugared.Ast.program val desugar_program :
Name_resolution.context -> Ast.program -> Desugared.Ast.program
(** Main function of this module *) (** Main function of this module *)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -22,12 +24,15 @@ let fill_pos_with_legislative_info (p : Ast.program) : Ast.program =
method! visit_marked f env x = method! visit_marked f env x =
(f env (Pos.unmark x), Pos.overwrite_law_info (Pos.get_position x) env) (f env (Pos.unmark x), Pos.overwrite_law_info (Pos.get_position x) env)
method! visit_LawHeading (env : string list) (heading : Ast.law_heading) method! visit_LawHeading
(env : string list)
(heading : Ast.law_heading)
(children : Ast.law_structure list) = (children : Ast.law_structure list) =
let env = Pos.unmark heading.law_heading_name :: env in let env = Pos.unmark heading.law_heading_name :: env in
Ast.LawHeading Ast.LawHeading
( super#visit_law_heading env heading, ( super#visit_law_heading env heading,
List.map (fun child -> super#visit_law_structure env child) children ) List.map (fun child -> super#visit_law_structure env child) children
)
end end
in in
visitor#visit_program [] p visitor#visit_program [] p

View File

@ -1,18 +1,20 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Fills the position tags in the AST with info about the legislative article this position belongs (** Fills the position tags in the AST with info about the legislative article
to. *) this position belongs to. *)
val fill_pos_with_legislative_info : Ast.program -> Ast.program val fill_pos_with_legislative_info : Ast.program -> Ast.program

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Tokens open Tokens
@ -17,45 +20,57 @@ open Sedlexing
open Utils open Utils
module R = Re.Pcre module R = Re.Pcre
(* Calculates the precedence according a {!val: matched_regex} of the form : '[#]+'. (* Calculates the precedence according a {!val: matched_regex} of the form :
'[#]+'.
@note -2 because [LAW_HEADING] start with at least "#" and the number of '#' remaining @note -2 because [LAW_HEADING] start with at least "#" and the number of '#'
corresponds to the precedence. *) remaining corresponds to the precedence. *)
let calc_precedence (matched_regex : string) : int = String.length matched_regex - 1 let calc_precedence (matched_regex : string) : int =
String.length matched_regex - 1
(* Gets the [LAW_HEADING] token from the current {!val: lexbuf} *) (* Gets the [LAW_HEADING] token from the current {!val: lexbuf} *)
let get_law_heading (lexbuf : lexbuf) : token = let get_law_heading (lexbuf : lexbuf) : token =
let extract_article_title = let extract_article_title =
R.regexp "([#]+)\\s*([^\\|]+)(\\|([^\\|]+)|)(\\|\\s*([0-9]{4}\\-[0-9]{2}\\-[0-9]{2})|)" R.regexp
"([#]+)\\s*([^\\|]+)(\\|([^\\|]+)|)(\\|\\s*([0-9]{4}\\-[0-9]{2}\\-[0-9]{2})|)"
in
let get_substring =
R.get_substring (R.exec ~rex:extract_article_title (Utf8.lexeme lexbuf))
in in
let get_substring = R.get_substring (R.exec ~rex:extract_article_title (Utf8.lexeme lexbuf)) in
let title = String.trim (get_substring 2) in let title = String.trim (get_substring 2) in
let article_id = try Some (String.trim (get_substring 4)) with Not_found -> None in let article_id =
let article_expiration_date = try Some (String.trim (get_substring 6)) with Not_found -> None in try Some (String.trim (get_substring 4)) with Not_found -> None
in
let article_expiration_date =
try Some (String.trim (get_substring 6)) with Not_found -> None
in
let precedence = calc_precedence (String.trim (get_substring 1)) in let precedence = calc_precedence (String.trim (get_substring 1)) in
LAW_HEADING (title, article_id, article_expiration_date, precedence) LAW_HEADING (title, article_id, article_expiration_date, precedence)
type lexing_context = Law | Code | Directive | Directive_args type lexing_context = Law | Code | Directive | Directive_args
(** Boolean reference, used by the lexer as the mutable state to distinguish whether it is lexing (** Boolean reference, used by the lexer as the mutable state to distinguish
code or law. *) whether it is lexing code or law. *)
let context : lexing_context ref = ref Law let context : lexing_context ref = ref Law
(** Mutable string reference that accumulates the string representation of the body of code being (** Mutable string reference that accumulates the string representation of the
lexed. This string representation is used in the literate programming backends to faithfully body of code being lexed. This string representation is used in the literate
capture the spacing pattern of the original program *) programming backends to faithfully capture the spacing pattern of the
original program *)
let code_buffer : Buffer.t = Buffer.create 4000 let code_buffer : Buffer.t = Buffer.create 4000
(** Updates {!val:code_buffer} with the current lexeme *) (** Updates {!val:code_buffer} with the current lexeme *)
let update_acc (lexbuf : lexbuf) : unit = Buffer.add_string code_buffer (Utf8.lexeme lexbuf) let update_acc (lexbuf : lexbuf) : unit =
Buffer.add_string code_buffer (Utf8.lexeme lexbuf)
(** Error-generating helper *) (** Error-generating helper *)
let raise_lexer_error (loc : Pos.t) (token : string) = let raise_lexer_error (loc : Pos.t) (token : string) =
Errors.raise_spanned_error loc "Parsing error after token \"%s\": what comes after is unknown" Errors.raise_spanned_error loc
token "Parsing error after token \"%s\": what comes after is unknown" token
(** Associative list matching each punctuation string part of the Catala syntax with its {!module: (** Associative list matching each punctuation string part of the Catala syntax
Surface.Parser} token. Same for all the input languages (English, French, etc.) *) with its {!module: Surface.Parser} token. Same for all the input languages
(English, French, etc.) *)
let token_list_language_agnostic : (string * token) list = let token_list_language_agnostic : (string * token) list =
[ [
(".", DOT); (".", DOT);
@ -83,7 +98,8 @@ let token_list_language_agnostic : (string * token) list =
module type LocalisedLexer = sig module type LocalisedLexer = sig
val token_list : (string * Tokens.token) list val token_list : (string * Tokens.token) list
(** Same as {!val: token_list_language_agnostic}, but with tokens specialized to a given language. *) (** Same as {!val: token_list_language_agnostic}, but with tokens specialized
to a given language. *)
val lex_builtin : string -> Ast.builtin_expression option val lex_builtin : string -> Ast.builtin_expression option
(** Simple lexer for builtins *) (** Simple lexer for builtins *)
@ -95,6 +111,7 @@ module type LocalisedLexer = sig
(** Main lexing function used outside code blocks *) (** Main lexing function used outside code blocks *)
val lexer : Sedlexing.lexbuf -> Tokens.token val lexer : Sedlexing.lexbuf -> Tokens.token
(** Entry point of the lexer, distributes to {!val: lex_code} or {!val:lex_law} depending of the (** Entry point of the lexer, distributes to {!val: lex_code} or
current {!val: Surface.Lexer_common.context}. *) {!val:lex_law} depending of the current {!val:
Surface.Lexer_common.context}. *)
end end

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Auxiliary functions used by all lexers. *) (** Auxiliary functions used by all lexers. *)
@ -17,13 +20,13 @@
type lexing_context = Law | Code | Directive | Directive_args type lexing_context = Law | Code | Directive | Directive_args
val context : lexing_context ref val context : lexing_context ref
(** Reference, used by the lexer as the mutable state to distinguish whether it is lexing code or (** Reference, used by the lexer as the mutable state to distinguish whether it
law. *) is lexing code or law. *)
val code_buffer : Buffer.t val code_buffer : Buffer.t
(** Buffer that accumulates the string representation of the body of code being lexed. This string (** Buffer that accumulates the string representation of the body of code being
representation is used in the literate programming backends to faithfully capture the spacing lexed. This string representation is used in the literate programming
pattern of the original program *) backends to faithfully capture the spacing pattern of the original program *)
val update_acc : Sedlexing.lexbuf -> unit val update_acc : Sedlexing.lexbuf -> unit
(** Updates {!val:code_buffer} with the current lexeme *) (** Updates {!val:code_buffer} with the current lexeme *)
@ -32,8 +35,9 @@ val raise_lexer_error : Utils.Pos.t -> string -> 'a
(** Error-generating helper *) (** Error-generating helper *)
val token_list_language_agnostic : (string * Tokens.token) list val token_list_language_agnostic : (string * Tokens.token) list
(** Associative list matching each punctuation string part of the Catala syntax with its (** Associative list matching each punctuation string part of the Catala syntax
{!Surface.Parser} token. Same for all the input languages (English, French, etc.) *) with its {!Surface.Parser} token. Same for all the input languages (English,
French, etc.) *)
val calc_precedence : string -> int val calc_precedence : string -> int
(** Calculates the precedence according a matched regex of the form : '[#]+' *) (** Calculates the precedence according a matched regex of the form : '[#]+' *)
@ -43,8 +47,8 @@ val get_law_heading : Sedlexing.lexbuf -> Tokens.token
module type LocalisedLexer = sig module type LocalisedLexer = sig
val token_list : (string * Tokens.token) list val token_list : (string * Tokens.token) list
(** Same as {!val: Surface.Lexer_common.token_list_language_agnostic}, but with tokens whose (** Same as {!val: Surface.Lexer_common.token_list_language_agnostic}, but
string varies with the input language. *) with tokens whose string varies with the input language. *)
val lex_builtin : string -> Ast.builtin_expression option val lex_builtin : string -> Ast.builtin_expression option
(** Simple lexer for builtins *) (** Simple lexer for builtins *)
@ -56,6 +60,7 @@ module type LocalisedLexer = sig
(** Main lexing function used outside code blocks *) (** Main lexing function used outside code blocks *)
val lexer : Sedlexing.lexbuf -> Tokens.token val lexer : Sedlexing.lexbuf -> Tokens.token
(** Entry point of the lexer, distributes to {!val: lex_code} or {!val:lex_law} depending of the (** Entry point of the lexer, distributes to {!val: lex_code} or
current {!val: Surface.Lexer_common.context}. *) {!val:lex_law} depending of the current {!val:
Surface.Lexer_common.context}. *)
end end

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
include Lexer_common.LocalisedLexer include Lexer_common.LocalisedLexer

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
include Lexer_common.LocalisedLexer include Lexer_common.LocalisedLexer

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
include Lexer_common.LocalisedLexer include Lexer_common.LocalisedLexer

View File

@ -1,29 +1,33 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Denis Merigoux <denis.merigoux@inria.fr> Nicolas Chataing <nicolas.chataing@ens.fr> Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Builds a context that allows for mapping each name to a precise uid, taking lexical scopes into (** Builds a context that allows for mapping each name to a precise uid, taking
account *) lexical scopes into account *)
open Utils open Utils
(** {1 Name resolution context} *) (** {1 Name resolution context} *)
type ident = string type ident = string
type typ = Scopelang.Ast.typ type typ = Scopelang.Ast.typ
type unique_rulename = Ambiguous of Pos.t list | Unique of Desugared.Ast.RuleName.t Pos.marked type unique_rulename =
| Ambiguous of Pos.t list
| Unique of Desugared.Ast.RuleName.t Pos.marked
type scope_def_context = { type scope_def_context = {
default_exception_rulename : unique_rulename option; default_exception_rulename : unique_rulename option;
@ -32,7 +36,8 @@ type scope_def_context = {
} }
type scope_context = { type scope_context = {
var_idmap : Desugared.Ast.ScopeVar.t Desugared.Ast.IdentMap.t; (** Scope variables *) var_idmap : Desugared.Ast.ScopeVar.t Desugared.Ast.IdentMap.t;
(** Scope variables *)
scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t; scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t;
(** What is the default rule to refer to for unnamed exceptions, if any *) (** What is the default rule to refer to for unnamed exceptions, if any *)
sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t; sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t;
@ -58,21 +63,30 @@ type var_sig = {
type context = { type context = {
local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t; local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t;
(** Inside a definition, local variables can be introduced by functions arguments or pattern (** Inside a definition, local variables can be introduced by functions
matching *) arguments or pattern matching *)
scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t; (** The names of the scopes *) scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t;
(** The names of the scopes *)
struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t; struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t;
(** The names of the structs *) (** The names of the structs *)
field_idmap : Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t Desugared.Ast.IdentMap.t; field_idmap :
(** The names of the struct fields. Names of fields can be shared between different structs *) Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t
enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t; (** The names of the enums *) Desugared.Ast.IdentMap.t;
(** The names of the struct fields. Names of fields can be shared between
different structs *)
enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t;
(** The names of the enums *)
constructor_idmap : constructor_idmap :
Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t Desugared.Ast.IdentMap.t; Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t
(** The names of the enum constructors. Constructor names can be shared between different Desugared.Ast.IdentMap.t;
enums *) (** The names of the enum constructors. Constructor names can be shared
scopes : scope_context Scopelang.Ast.ScopeMap.t; (** For each scope, its context *) between different enums *)
structs : struct_context Scopelang.Ast.StructMap.t; (** For each struct, its context *) scopes : scope_context Scopelang.Ast.ScopeMap.t;
enums : enum_context Scopelang.Ast.EnumMap.t; (** For each enum, its context *) (** For each scope, its context *)
structs : struct_context Scopelang.Ast.StructMap.t;
(** For each struct, its context *)
enums : enum_context Scopelang.Ast.EnumMap.t;
(** For each enum, its context *)
var_typs : var_sig Desugared.Ast.ScopeVarMap.t; var_typs : var_sig Desugared.Ast.ScopeVarMap.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
} }
@ -80,53 +94,67 @@ type context = {
(** {1 Helpers} *) (** {1 Helpers} *)
(** Temporary function raising an error message saying that a feature is not supported yet *) (** Temporary function raising an error message saying that a feature is not
supported yet *)
let raise_unsupported_feature (msg : string) (pos : Pos.t) = let raise_unsupported_feature (msg : string) (pos : Pos.t) =
Errors.raise_spanned_error pos "Unsupported feature: %s" msg Errors.raise_spanned_error pos "Unsupported feature: %s" msg
(** Function to call whenever an identifier used somewhere has not been declared in the program (** Function to call whenever an identifier used somewhere has not been declared
previously *) in the program previously *)
let raise_unknown_identifier (msg : string) (ident : ident Pos.marked) = let raise_unknown_identifier (msg : string) (ident : ident Pos.marked) =
Errors.raise_spanned_error (Pos.get_position ident) "\"%s\": unknown identifier %s" Errors.raise_spanned_error (Pos.get_position ident)
"\"%s\": unknown identifier %s"
(Utils.Cli.with_style [ ANSITerminal.yellow ] "%s" (Pos.unmark ident)) (Utils.Cli.with_style [ ANSITerminal.yellow ] "%s" (Pos.unmark ident))
msg msg
(** Gets the type associated to an uid *) (** Gets the type associated to an uid *)
let get_var_typ (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) : typ Pos.marked = let get_var_typ (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) :
typ Pos.marked =
(Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs).var_sig_typ (Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs).var_sig_typ
let is_var_cond (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) : bool = let is_var_cond (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) : bool =
(Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs).var_sig_is_condition (Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs).var_sig_is_condition
let get_var_io (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) : Ast.scope_decl_context_io = let get_var_io (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) :
Ast.scope_decl_context_io =
(Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs).var_sig_io (Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs).var_sig_io
(** Get the variable uid inside the scope given in argument *) (** Get the variable uid inside the scope given in argument *)
let get_var_uid (scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : context) let get_var_uid
(scope_uid : Scopelang.Ast.ScopeName.t)
(ctxt : context)
((x, pos) : ident Pos.marked) : Desugared.Ast.ScopeVar.t = ((x, pos) : ident Pos.marked) : Desugared.Ast.ScopeVar.t =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
match Desugared.Ast.IdentMap.find_opt x scope.var_idmap with match Desugared.Ast.IdentMap.find_opt x scope.var_idmap with
| None -> | None ->
raise_unknown_identifier raise_unknown_identifier
(Format.asprintf "for a variable of scope %a" Scopelang.Ast.ScopeName.format_t scope_uid) (Format.asprintf "for a variable of scope %a"
Scopelang.Ast.ScopeName.format_t scope_uid)
(x, pos) (x, pos)
| Some uid -> uid | Some uid -> uid
(** Get the subscope uid inside the scope given in argument *) (** Get the subscope uid inside the scope given in argument *)
let get_subscope_uid (scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : context) let get_subscope_uid
(scope_uid : Scopelang.Ast.ScopeName.t)
(ctxt : context)
((y, pos) : ident Pos.marked) : Scopelang.Ast.SubScopeName.t = ((y, pos) : ident Pos.marked) : Scopelang.Ast.SubScopeName.t =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
match Desugared.Ast.IdentMap.find_opt y scope.sub_scopes_idmap with match Desugared.Ast.IdentMap.find_opt y scope.sub_scopes_idmap with
| None -> raise_unknown_identifier "for a subscope of this scope" (y, pos) | None -> raise_unknown_identifier "for a subscope of this scope" (y, pos)
| Some sub_uid -> sub_uid | Some sub_uid -> sub_uid
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the subscopes of [scope_uid]. *) (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
let is_subscope_uid (scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : context) (y : ident) : bool = subscopes of [scope_uid]. *)
let is_subscope_uid
(scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : context) (y : ident) : bool
=
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
Desugared.Ast.IdentMap.mem y scope.sub_scopes_idmap Desugared.Ast.IdentMap.mem y scope.sub_scopes_idmap
(** Checks if the var_uid belongs to the scope scope_uid *) (** Checks if the var_uid belongs to the scope scope_uid *)
let belongs_to (ctxt : context) (uid : Desugared.Ast.ScopeVar.t) let belongs_to
(ctxt : context)
(uid : Desugared.Ast.ScopeVar.t)
(scope_uid : Scopelang.Ast.ScopeName.t) : bool = (scope_uid : Scopelang.Ast.ScopeName.t) : bool =
let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
Desugared.Ast.IdentMap.exists Desugared.Ast.IdentMap.exists
@ -134,24 +162,28 @@ let belongs_to (ctxt : context) (uid : Desugared.Ast.ScopeVar.t)
scope.var_idmap scope.var_idmap
(** Retrieves the type of a scope definition from the context *) (** Retrieves the type of a scope definition from the context *)
let get_def_typ (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : typ Pos.marked = let get_def_typ (ctxt : context) (def : Desugared.Ast.ScopeDef.t) :
typ Pos.marked =
match def with match def with
| Desugared.Ast.ScopeDef.SubScopeVar (_, x) | Desugared.Ast.ScopeDef.SubScopeVar (_, x)
(* we don't need to look at the subscope prefix because [x] is already the uid referring back to (* we don't need to look at the subscope prefix because [x] is already the uid
the original subscope *) referring back to the original subscope *)
| Desugared.Ast.ScopeDef.Var (x, _) -> | Desugared.Ast.ScopeDef.Var (x, _) ->
get_var_typ ctxt x get_var_typ ctxt x
let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool = let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool =
match def with match def with
| Desugared.Ast.ScopeDef.SubScopeVar (_, x) | Desugared.Ast.ScopeDef.SubScopeVar (_, x)
(* we don't need to look at the subscope prefix because [x] is already the uid referring back to (* we don't need to look at the subscope prefix because [x] is already the uid
the original subscope *) referring back to the original subscope *)
| Desugared.Ast.ScopeDef.Var (x, _) -> | Desugared.Ast.ScopeDef.Var (x, _) ->
is_var_cond ctxt x is_var_cond ctxt x
let label_groups (ctxt : context) (s_uid : Scopelang.Ast.ScopeName.t) let label_groups
(def : Desugared.Ast.ScopeDef.t) : Desugared.Ast.RuleSet.t Desugared.Ast.LabelMap.t = (ctxt : context)
(s_uid : Scopelang.Ast.ScopeName.t)
(def : Desugared.Ast.ScopeDef.t) :
Desugared.Ast.RuleSet.t Desugared.Ast.LabelMap.t =
try try
(Desugared.Ast.ScopeDefMap.find def (Desugared.Ast.ScopeDefMap.find def
(Scopelang.Ast.ScopeMap.find s_uid ctxt.scopes).scope_defs_contexts) (Scopelang.Ast.ScopeMap.find s_uid ctxt.scopes).scope_defs_contexts)
@ -161,16 +193,21 @@ let label_groups (ctxt : context) (s_uid : Scopelang.Ast.ScopeName.t)
(** {1 Declarations pass} *) (** {1 Declarations pass} *)
(** Process a subscope declaration *) (** Process a subscope declaration *)
let process_subscope_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context) let process_subscope_decl
(scope : Scopelang.Ast.ScopeName.t)
(ctxt : context)
(decl : Ast.scope_decl_context_scope) : context = (decl : Ast.scope_decl_context_scope) : context =
let name, name_pos = decl.scope_decl_context_scope_name in let name, name_pos = decl.scope_decl_context_scope_name in
let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in
match Desugared.Ast.IdentMap.find_opt subscope scope_ctxt.sub_scopes_idmap with match
Desugared.Ast.IdentMap.find_opt subscope scope_ctxt.sub_scopes_idmap
with
| Some use -> | Some use ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
(Some "first use", Pos.get_position (Scopelang.Ast.SubScopeName.get_info use)); ( Some "first use",
Pos.get_position (Scopelang.Ast.SubScopeName.get_info use) );
(Some "second use", s_pos); (Some "second use", s_pos);
] ]
"Subscope name \"%a\" already used" "Subscope name \"%a\" already used"
@ -187,26 +224,36 @@ let process_subscope_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context)
{ {
scope_ctxt with scope_ctxt with
sub_scopes_idmap = sub_scopes_idmap =
Desugared.Ast.IdentMap.add name sub_scope_uid scope_ctxt.sub_scopes_idmap; Desugared.Ast.IdentMap.add name sub_scope_uid
scope_ctxt.sub_scopes_idmap;
sub_scopes = sub_scopes =
Scopelang.Ast.SubScopeMap.add sub_scope_uid original_subscope_uid scope_ctxt.sub_scopes; Scopelang.Ast.SubScopeMap.add sub_scope_uid original_subscope_uid
scope_ctxt.sub_scopes;
} }
in in
{ ctxt with scopes = Scopelang.Ast.ScopeMap.add scope scope_ctxt ctxt.scopes } {
ctxt with
scopes = Scopelang.Ast.ScopeMap.add scope scope_ctxt ctxt.scopes;
}
let is_type_cond ((typ, _) : Ast.typ Pos.marked) = let is_type_cond ((typ, _) : Ast.typ Pos.marked) =
match typ with match typ with
| Ast.Base Ast.Condition | Ast.Func { arg_typ = _; return_typ = Ast.Condition, _ } -> true | Ast.Base Ast.Condition
| Ast.Func { arg_typ = _; return_typ = Ast.Condition, _ } ->
true
| _ -> false | _ -> false
(** Process a basic type (all types except function types) *) (** Process a basic type (all types except function types) *)
let rec process_base_typ (ctxt : context) ((typ, typ_pos) : Ast.base_typ Pos.marked) : let rec process_base_typ
(ctxt : context) ((typ, typ_pos) : Ast.base_typ Pos.marked) :
Scopelang.Ast.typ Pos.marked = Scopelang.Ast.typ Pos.marked =
match typ with match typ with
| Ast.Condition -> (Scopelang.Ast.TLit TBool, typ_pos) | Ast.Condition -> (Scopelang.Ast.TLit TBool, typ_pos)
| Ast.Data (Ast.Collection t) -> | Ast.Data (Ast.Collection t) ->
( Scopelang.Ast.TArray ( Scopelang.Ast.TArray
(Pos.unmark (process_base_typ ctxt (Ast.Data (Pos.unmark t), Pos.get_position t))), (Pos.unmark
(process_base_typ ctxt
(Ast.Data (Pos.unmark t), Pos.get_position t))),
typ_pos ) typ_pos )
| Ast.Data (Ast.Primitive prim) -> ( | Ast.Data (Ast.Primitive prim) -> (
match prim with match prim with
@ -225,7 +272,8 @@ let rec process_base_typ (ctxt : context) ((typ, typ_pos) : Ast.base_typ Pos.mar
| Some e_uid -> (Scopelang.Ast.TEnum e_uid, typ_pos) | Some e_uid -> (Scopelang.Ast.TEnum e_uid, typ_pos)
| None -> | None ->
Errors.raise_spanned_error typ_pos Errors.raise_spanned_error typ_pos
"Unknown type \"%a\", not a struct or enum previously declared" "Unknown type \"%a\", not a struct or enum previously \
declared"
(Utils.Cli.format_with_style [ ANSITerminal.yellow ]) (Utils.Cli.format_with_style [ ANSITerminal.yellow ])
ident))) ident)))
@ -235,11 +283,14 @@ let process_type (ctxt : context) ((typ, typ_pos) : Ast.typ Pos.marked) :
match typ with match typ with
| Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos) | Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos)
| Ast.Func { arg_typ; return_typ } -> | Ast.Func { arg_typ; return_typ } ->
( Scopelang.Ast.TArrow (process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ), ( Scopelang.Ast.TArrow
(process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ),
typ_pos ) typ_pos )
(** Process data declaration *) (** Process data declaration *)
let process_data_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context) let process_data_decl
(scope : Scopelang.Ast.ScopeName.t)
(ctxt : context)
(decl : Ast.scope_decl_context_data) : context = (decl : Ast.scope_decl_context_data) : context =
(* First check the type of the context data *) (* First check the type of the context data *)
let data_typ = process_type ctxt decl.scope_decl_context_item_typ in let data_typ = process_type ctxt decl.scope_decl_context_item_typ in
@ -250,7 +301,8 @@ let process_data_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context)
| Some use -> | Some use ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
(Some "first use", Pos.get_position (Desugared.Ast.ScopeVar.get_info use)); ( Some "first use",
Pos.get_position (Desugared.Ast.ScopeVar.get_info use) );
(Some "second use", pos); (Some "second use", pos);
] ]
"var name \"%a\" already used" "var name \"%a\" already used"
@ -259,13 +311,17 @@ let process_data_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context)
| None -> | None ->
let uid = Desugared.Ast.ScopeVar.fresh (name, pos) in let uid = Desugared.Ast.ScopeVar.fresh (name, pos) in
let scope_ctxt = let scope_ctxt =
{ scope_ctxt with var_idmap = Desugared.Ast.IdentMap.add name uid scope_ctxt.var_idmap } {
scope_ctxt with
var_idmap = Desugared.Ast.IdentMap.add name uid scope_ctxt.var_idmap;
}
in in
let states_idmap, states_list = let states_idmap, states_list =
List.fold_right List.fold_right
(fun state_id (states_idmap, states_list) -> (fun state_id (states_idmap, states_list) ->
let state_uid = Desugared.Ast.StateName.fresh state_id in let state_uid = Desugared.Ast.StateName.fresh state_id in
( Desugared.Ast.IdentMap.add (Pos.unmark state_id) state_uid states_idmap, ( Desugared.Ast.IdentMap.add (Pos.unmark state_id) state_uid
states_idmap,
state_uid :: states_list )) state_uid :: states_list ))
decl.scope_decl_context_item_states decl.scope_decl_context_item_states
(Desugared.Ast.IdentMap.empty, []) (Desugared.Ast.IdentMap.empty, [])
@ -286,20 +342,24 @@ let process_data_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context)
} }
(** Process an item declaration *) (** Process an item declaration *)
let process_item_decl (scope : Scopelang.Ast.ScopeName.t) (ctxt : context) let process_item_decl
(scope : Scopelang.Ast.ScopeName.t)
(ctxt : context)
(decl : Ast.scope_decl_context_item) : context = (decl : Ast.scope_decl_context_item) : context =
match decl with match decl with
| Ast.ContextData data_decl -> process_data_decl scope ctxt data_decl | Ast.ContextData data_decl -> process_data_decl scope ctxt data_decl
| Ast.ContextScope sub_decl -> process_subscope_decl scope ctxt sub_decl | Ast.ContextScope sub_decl -> process_subscope_decl scope ctxt sub_decl
(** Adds a binding to the context *) (** Adds a binding to the context *)
let add_def_local_var (ctxt : context) (name : ident Pos.marked) : context * Desugared.Ast.Var.t = let add_def_local_var (ctxt : context) (name : ident Pos.marked) :
context * Desugared.Ast.Var.t =
let local_var_uid = Desugared.Ast.Var.make name in let local_var_uid = Desugared.Ast.Var.make name in
let ctxt = let ctxt =
{ {
ctxt with ctxt with
local_var_idmap = local_var_idmap =
Desugared.Ast.IdentMap.add (Pos.unmark name) local_var_uid ctxt.local_var_idmap; Desugared.Ast.IdentMap.add (Pos.unmark name) local_var_uid
ctxt.local_var_idmap;
} }
in in
(ctxt, local_var_uid) (ctxt, local_var_uid)
@ -314,10 +374,14 @@ let process_scope_decl (ctxt : context) (decl : Ast.scope_decl) : context =
(** Process a struct declaration *) (** Process a struct declaration *)
let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context =
let s_uid = Desugared.Ast.IdentMap.find (fst sdecl.struct_decl_name) ctxt.struct_idmap in let s_uid =
Desugared.Ast.IdentMap.find (fst sdecl.struct_decl_name) ctxt.struct_idmap
in
List.fold_left List.fold_left
(fun ctxt (fdecl, _) -> (fun ctxt (fdecl, _) ->
let f_uid = Scopelang.Ast.StructFieldName.fresh fdecl.Ast.struct_decl_field_name in let f_uid =
Scopelang.Ast.StructFieldName.fresh fdecl.Ast.struct_decl_field_name
in
let ctxt = let ctxt =
{ {
ctxt with ctxt with
@ -327,7 +391,8 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context =
(fun uids -> (fun uids ->
match uids with match uids with
| None -> Some (Scopelang.Ast.StructMap.singleton s_uid f_uid) | None -> Some (Scopelang.Ast.StructMap.singleton s_uid f_uid)
| Some uids -> Some (Scopelang.Ast.StructMap.add s_uid f_uid uids)) | Some uids ->
Some (Scopelang.Ast.StructMap.add s_uid f_uid uids))
ctxt.field_idmap; ctxt.field_idmap;
} }
in in
@ -352,10 +417,14 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context =
(** Process an enum declaration *) (** Process an enum declaration *)
let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context =
let e_uid = Desugared.Ast.IdentMap.find (fst edecl.enum_decl_name) ctxt.enum_idmap in let e_uid =
Desugared.Ast.IdentMap.find (fst edecl.enum_decl_name) ctxt.enum_idmap
in
List.fold_left List.fold_left
(fun ctxt (cdecl, cdecl_pos) -> (fun ctxt (cdecl, cdecl_pos) ->
let c_uid = Scopelang.Ast.EnumConstructor.fresh cdecl.Ast.enum_decl_case_name in let c_uid =
Scopelang.Ast.EnumConstructor.fresh cdecl.Ast.enum_decl_case_name
in
let ctxt = let ctxt =
{ {
ctxt with ctxt with
@ -380,17 +449,23 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context =
| Some typ -> process_type ctxt typ | Some typ -> process_type ctxt typ
in in
match cases with match cases with
| None -> Some (Scopelang.Ast.EnumConstructorMap.singleton c_uid typ) | None ->
| Some fields -> Some (Scopelang.Ast.EnumConstructorMap.add c_uid typ fields)) Some (Scopelang.Ast.EnumConstructorMap.singleton c_uid typ)
| Some fields ->
Some (Scopelang.Ast.EnumConstructorMap.add c_uid typ fields))
ctxt.enums; ctxt.enums;
}) })
ctxt edecl.enum_decl_cases ctxt edecl.enum_decl_cases
(** Process the names of all declaration items *) (** Process the names of all declaration items *)
let process_name_item (ctxt : context) (item : Ast.code_item Pos.marked) : context = let process_name_item (ctxt : context) (item : Ast.code_item Pos.marked) :
context =
let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg = let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg =
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ (Some "First definition:", Pos.get_position use); (Some "Second definition:", pos) ] [
(Some "First definition:", Pos.get_position use);
(Some "Second definition:", pos);
]
"%s name \"%a\" already defined" msg "%s name \"%a\" already defined" msg
(Utils.Cli.format_with_style [ ANSITerminal.yellow ]) (Utils.Cli.format_with_style [ ANSITerminal.yellow ])
name name
@ -401,12 +476,15 @@ let process_name_item (ctxt : context) (item : Ast.code_item Pos.marked) : conte
(* Checks if the name is already used *) (* Checks if the name is already used *)
match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with
| Some use -> | Some use ->
raise_already_defined_error (Scopelang.Ast.ScopeName.get_info use) name pos "scope" raise_already_defined_error
(Scopelang.Ast.ScopeName.get_info use)
name pos "scope"
| None -> | None ->
let scope_uid = Scopelang.Ast.ScopeName.fresh (name, pos) in let scope_uid = Scopelang.Ast.ScopeName.fresh (name, pos) in
{ {
ctxt with ctxt with
scope_idmap = Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap; scope_idmap =
Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap;
scopes = scopes =
Scopelang.Ast.ScopeMap.add scope_uid Scopelang.Ast.ScopeMap.add scope_uid
{ {
@ -421,31 +499,40 @@ let process_name_item (ctxt : context) (item : Ast.code_item Pos.marked) : conte
let name, pos = sdecl.struct_decl_name in let name, pos = sdecl.struct_decl_name in
match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with
| Some use -> | Some use ->
raise_already_defined_error (Scopelang.Ast.StructName.get_info use) name pos "struct" raise_already_defined_error
(Scopelang.Ast.StructName.get_info use)
name pos "struct"
| None -> | None ->
let s_uid = Scopelang.Ast.StructName.fresh sdecl.struct_decl_name in let s_uid = Scopelang.Ast.StructName.fresh sdecl.struct_decl_name in
{ {
ctxt with ctxt with
struct_idmap = struct_idmap =
Desugared.Ast.IdentMap.add (Pos.unmark sdecl.struct_decl_name) s_uid ctxt.struct_idmap; Desugared.Ast.IdentMap.add
(Pos.unmark sdecl.struct_decl_name)
s_uid ctxt.struct_idmap;
}) })
| EnumDecl edecl -> ( | EnumDecl edecl -> (
let name, pos = edecl.enum_decl_name in let name, pos = edecl.enum_decl_name in
match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with
| Some use -> | Some use ->
raise_already_defined_error (Scopelang.Ast.EnumName.get_info use) name pos "enum" raise_already_defined_error
(Scopelang.Ast.EnumName.get_info use)
name pos "enum"
| None -> | None ->
let e_uid = Scopelang.Ast.EnumName.fresh edecl.enum_decl_name in let e_uid = Scopelang.Ast.EnumName.fresh edecl.enum_decl_name in
{ {
ctxt with ctxt with
enum_idmap = enum_idmap =
Desugared.Ast.IdentMap.add (Pos.unmark edecl.enum_decl_name) e_uid ctxt.enum_idmap; Desugared.Ast.IdentMap.add
(Pos.unmark edecl.enum_decl_name)
e_uid ctxt.enum_idmap;
}) })
| ScopeUse _ -> ctxt | ScopeUse _ -> ctxt
(** Process a code item that is a declaration *) (** Process a code item that is a declaration *)
let process_decl_item (ctxt : context) (item : Ast.code_item Pos.marked) : context = let process_decl_item (ctxt : context) (item : Ast.code_item Pos.marked) :
context =
match Pos.unmark item with match Pos.unmark item with
| ScopeDecl decl -> process_scope_decl ctxt decl | ScopeDecl decl -> process_scope_decl ctxt decl
| StructDecl sdecl -> process_struct_decl ctxt sdecl | StructDecl sdecl -> process_struct_decl ctxt sdecl
@ -453,24 +540,33 @@ let process_decl_item (ctxt : context) (item : Ast.code_item Pos.marked) : conte
| ScopeUse _ -> ctxt | ScopeUse _ -> ctxt
(** Process a code block *) (** Process a code block *)
let process_code_block (ctxt : context) (block : Ast.code_block) let process_code_block
(ctxt : context)
(block : Ast.code_block)
(process_item : context -> Ast.code_item Pos.marked -> context) : context = (process_item : context -> Ast.code_item Pos.marked -> context) : context =
List.fold_left (fun ctxt decl -> process_item ctxt decl) ctxt block List.fold_left (fun ctxt decl -> process_item ctxt decl) ctxt block
(** Process a law structure, only considering the code blocks *) (** Process a law structure, only considering the code blocks *)
let rec process_law_structure (ctxt : context) (s : Ast.law_structure) let rec process_law_structure
(ctxt : context)
(s : Ast.law_structure)
(process_item : context -> Ast.code_item Pos.marked -> context) : context = (process_item : context -> Ast.code_item Pos.marked -> context) : context =
match s with match s with
| Ast.LawHeading (_, children) -> | Ast.LawHeading (_, children) ->
List.fold_left (fun ctxt child -> process_law_structure ctxt child process_item) ctxt children List.fold_left
(fun ctxt child -> process_law_structure ctxt child process_item)
ctxt children
| Ast.CodeBlock (block, _, _) -> process_code_block ctxt block process_item | Ast.CodeBlock (block, _, _) -> process_code_block ctxt block process_item
| Ast.LawInclude _ | Ast.LawText _ -> ctxt | Ast.LawInclude _ | Ast.LawText _ -> ctxt
(** {1 Scope uses pass} *) (** {1 Scope uses pass} *)
let get_def_key (name : Ast.qident) (state : Ast.ident Pos.marked option) let get_def_key
(scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : context) (default_pos : Pos.t) : (name : Ast.qident)
Desugared.Ast.ScopeDef.t = (state : Ast.ident Pos.marked option)
(scope_uid : Scopelang.Ast.ScopeName.t)
(ctxt : context)
(default_pos : Pos.t) : Desugared.Ast.ScopeDef.t =
let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in
match name with match name with
| [ x ] -> | [ x ] ->
@ -480,30 +576,40 @@ let get_def_key (name : Ast.qident) (state : Ast.ident Pos.marked option)
( x_uid, ( x_uid,
match state with match state with
| Some state -> ( | Some state -> (
try Some (Desugared.Ast.IdentMap.find (Pos.unmark state) var_sig.var_sig_states_idmap) try
Some
(Desugared.Ast.IdentMap.find (Pos.unmark state)
var_sig.var_sig_states_idmap)
with Not_found -> with Not_found ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
(None, Pos.get_position state); (None, Pos.get_position state);
( Some "Variable declaration:", ( Some "Variable declaration:",
Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid) ); Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid)
);
] ]
"This identifier is not a state declared for variable %a." "This identifier is not a state declared for variable %a."
Desugared.Ast.ScopeVar.format_t x_uid) Desugared.Ast.ScopeVar.format_t x_uid)
| None -> | None ->
if not (Desugared.Ast.IdentMap.is_empty var_sig.var_sig_states_idmap) then if
not
(Desugared.Ast.IdentMap.is_empty var_sig.var_sig_states_idmap)
then
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
(None, Pos.get_position x); (None, Pos.get_position x);
( Some "Variable declaration:", ( Some "Variable declaration:",
Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid) ); Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid)
);
] ]
"This definition does not indicate which state has to be considered for variable \ "This definition does not indicate which state has to be \
%a." considered for variable %a."
Desugared.Ast.ScopeVar.format_t x_uid Desugared.Ast.ScopeVar.format_t x_uid
else None ) else None )
| [ y; x ] -> | [ y; x ] ->
let subscope_uid : Scopelang.Ast.SubScopeName.t = get_subscope_uid scope_uid ctxt y in let subscope_uid : Scopelang.Ast.SubScopeName.t =
get_subscope_uid scope_uid ctxt y
in
let subscope_real_uid : Scopelang.Ast.ScopeName.t = let subscope_real_uid : Scopelang.Ast.ScopeName.t =
Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes
in in
@ -511,7 +617,8 @@ let get_def_key (name : Ast.qident) (state : Ast.ident Pos.marked option)
Desugared.Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid) Desugared.Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid)
| _ -> Errors.raise_spanned_error default_pos "Structs are not handled yet" | _ -> Errors.raise_spanned_error default_pos "Structs are not handled yet"
let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d : Ast.definition) : let process_definition
(ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d : Ast.definition) :
context = context =
(* We update the definition context inside the big context *) (* We update the definition context inside the big context *)
{ {
@ -520,7 +627,9 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
Scopelang.Ast.ScopeMap.update s_name Scopelang.Ast.ScopeMap.update s_name
(fun (s_ctxt : scope_context option) -> (fun (s_ctxt : scope_context option) ->
let def_key = let def_key =
get_def_key (Pos.unmark d.definition_name) d.definition_state s_name ctxt get_def_key
(Pos.unmark d.definition_name)
d.definition_state s_name ctxt
(Pos.get_position d.definition_expr) (Pos.get_position d.definition_expr)
in in
match s_ctxt with match s_ctxt with
@ -536,8 +645,8 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
Option.fold Option.fold
~none: ~none:
{ {
(* Here, this is the first time we encounter a definition for this (* Here, this is the first time we encounter a
definition key *) definition for this definition key *)
default_exception_rulename = None; default_exception_rulename = None;
label_idmap = Desugared.Ast.IdentMap.empty; label_idmap = Desugared.Ast.IdentMap.empty;
label_groups = Desugared.Ast.LabelMap.empty; label_groups = Desugared.Ast.LabelMap.empty;
@ -545,8 +654,8 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
~some:(fun x -> x) ~some:(fun x -> x)
def_key_ctx def_key_ctx
in in
(* First, we update the def key context with information about the (* First, we update the def key context with information
definition's label*) about the definition's label*)
let def_key_ctx = let def_key_ctx =
match d.Ast.definition_label with match d.Ast.definition_label with
| None -> def_key_ctx | None -> def_key_ctx
@ -556,11 +665,14 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
(fun existing_label -> (fun existing_label ->
match existing_label with match existing_label with
| Some existing_label -> Some existing_label | Some existing_label -> Some existing_label
| None -> Some (Desugared.Ast.LabelName.fresh label)) | None ->
Some
(Desugared.Ast.LabelName.fresh label))
def_key_ctx.label_idmap def_key_ctx.label_idmap
in in
let label_id = let label_id =
Desugared.Ast.IdentMap.find (Pos.unmark label) new_label_idmap Desugared.Ast.IdentMap.find (Pos.unmark label)
new_label_idmap
in in
{ {
def_key_ctx with def_key_ctx with
@ -570,34 +682,39 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
(fun group -> (fun group ->
match group with match group with
| None -> | None ->
Some (Desugared.Ast.RuleSet.singleton d.definition_id) Some
(Desugared.Ast.RuleSet.singleton
d.definition_id)
| Some existing_group -> | Some existing_group ->
Some Some
(Desugared.Ast.RuleSet.add d.definition_id (Desugared.Ast.RuleSet.add
existing_group)) d.definition_id existing_group))
def_key_ctx.label_groups; def_key_ctx.label_groups;
} }
in in
(* And second, we update the map of default rulenames for unlabeled (* And second, we update the map of default rulenames
exceptions *) for unlabeled exceptions *)
let def_key_ctx = let def_key_ctx =
match d.Ast.definition_exception_to with match d.Ast.definition_exception_to with
(* If this definition is an exception, it cannot be a default (* If this definition is an exception, it cannot be a
definition *) default definition *)
| UnlabeledException | ExceptionToLabel _ -> def_key_ctx | UnlabeledException | ExceptionToLabel _ ->
(* If it is not an exception, we need to distinguish between several def_key_ctx
cases *) (* If it is not an exception, we need to distinguish
between several cases *)
| NotAnException -> ( | NotAnException -> (
match def_key_ctx.default_exception_rulename with match def_key_ctx.default_exception_rulename with
(* There was already a default definition for this key. If we need it, (* There was already a default definition for this
it is ambiguous *) key. If we need it, it is ambiguous *)
| Some old -> | Some old ->
{ {
def_key_ctx with def_key_ctx with
default_exception_rulename = default_exception_rulename =
Some Some
(Ambiguous (Ambiguous
([ Pos.get_position d.definition_name ] ([
Pos.get_position d.definition_name;
]
@ @
match old with match old with
| Ambiguous old -> old | Ambiguous old -> old
@ -606,23 +723,31 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
(* No definition has been set yet for this key *) (* No definition has been set yet for this key *)
| None -> ( | None -> (
match d.Ast.definition_label with match d.Ast.definition_label with
(* This default definition has a label. This is not allowed for (* This default definition has a label. This
unlabeled exceptions *) is not allowed for unlabeled exceptions *)
| Some _ -> | Some _ ->
{ {
def_key_ctx with def_key_ctx with
default_exception_rulename = default_exception_rulename =
Some (Ambiguous [ Pos.get_position d.definition_name ]); Some
(Ambiguous
[
Pos.get_position
d.definition_name;
]);
} }
(* This is a possible default definition for this key. We create (* This is a possible default definition for
and store a fresh rulename *) this key. We create and store a fresh
rulename *)
| None -> | None ->
{ {
def_key_ctx with def_key_ctx with
default_exception_rulename = default_exception_rulename =
Some Some
(Unique (Unique
(d.definition_id, Pos.get_position d.definition_name)); ( d.definition_id,
Pos.get_position
d.definition_name ));
})) }))
in in
Some def_key_ctx) Some def_key_ctx)
@ -631,7 +756,9 @@ let process_definition (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d
ctxt.scopes; ctxt.scopes;
} }
let process_scope_use_item (s_name : Scopelang.Ast.ScopeName.t) (ctxt : context) let process_scope_use_item
(s_name : Scopelang.Ast.ScopeName.t)
(ctxt : context)
(sitem : Ast.scope_use_item Pos.marked) : context = (sitem : Ast.scope_use_item Pos.marked) : context =
match Pos.unmark sitem with match Pos.unmark sitem with
| Rule r -> process_definition ctxt s_name (Ast.rule_to_def r) | Rule r -> process_definition ctxt s_name (Ast.rule_to_def r)
@ -640,7 +767,10 @@ let process_scope_use_item (s_name : Scopelang.Ast.ScopeName.t) (ctxt : context)
let process_scope_use (ctxt : context) (suse : Ast.scope_use) : context = let process_scope_use (ctxt : context) (suse : Ast.scope_use) : context =
let s_name = let s_name =
try Desugared.Ast.IdentMap.find (Pos.unmark suse.Ast.scope_use_name) ctxt.scope_idmap try
Desugared.Ast.IdentMap.find
(Pos.unmark suse.Ast.scope_use_name)
ctxt.scope_idmap
with Not_found -> with Not_found ->
Errors.raise_spanned_error Errors.raise_spanned_error
(Pos.get_position suse.Ast.scope_use_name) (Pos.get_position suse.Ast.scope_use_name)
@ -650,7 +780,8 @@ let process_scope_use (ctxt : context) (suse : Ast.scope_use) : context =
in in
List.fold_left (process_scope_use_item s_name) ctxt suse.Ast.scope_use_items List.fold_left (process_scope_use_item s_name) ctxt suse.Ast.scope_use_items
let process_use_item (ctxt : context) (item : Ast.code_item Pos.marked) : context = let process_use_item (ctxt : context) (item : Ast.code_item Pos.marked) :
context =
match Pos.unmark item with match Pos.unmark item with
| ScopeDecl _ | StructDecl _ | EnumDecl _ -> ctxt | ScopeDecl _ | StructDecl _ | EnumDecl _ -> ctxt
| ScopeUse suse -> process_scope_use ctxt suse | ScopeUse suse -> process_scope_use ctxt suse

View File

@ -1,29 +1,33 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<nicolas.chataing@ens.fr> Denis Merigoux <denis.merigoux@inria.fr> Nicolas Chataing <nicolas.chataing@ens.fr> Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Builds a context that allows for mapping each name to a precise uid, taking lexical scopes into (** Builds a context that allows for mapping each name to a precise uid, taking
account *) lexical scopes into account *)
open Utils open Utils
(** {1 Name resolution context} *) (** {1 Name resolution context} *)
type ident = string type ident = string
type typ = Scopelang.Ast.typ type typ = Scopelang.Ast.typ
type unique_rulename = Ambiguous of Pos.t list | Unique of Desugared.Ast.RuleName.t Pos.marked type unique_rulename =
| Ambiguous of Pos.t list
| Unique of Desugared.Ast.RuleName.t Pos.marked
type scope_def_context = { type scope_def_context = {
default_exception_rulename : unique_rulename option; default_exception_rulename : unique_rulename option;
@ -32,7 +36,8 @@ type scope_def_context = {
} }
type scope_context = { type scope_context = {
var_idmap : Desugared.Ast.ScopeVar.t Desugared.Ast.IdentMap.t; (** Scope variables *) var_idmap : Desugared.Ast.ScopeVar.t Desugared.Ast.IdentMap.t;
(** Scope variables *)
scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t; scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t;
(** What is the default rule to refer to for unnamed exceptions, if any *) (** What is the default rule to refer to for unnamed exceptions, if any *)
sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t; sub_scopes_idmap : Scopelang.Ast.SubScopeName.t Desugared.Ast.IdentMap.t;
@ -58,21 +63,30 @@ type var_sig = {
type context = { type context = {
local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t; local_var_idmap : Desugared.Ast.Var.t Desugared.Ast.IdentMap.t;
(** Inside a definition, local variables can be introduced by functions arguments or pattern (** Inside a definition, local variables can be introduced by functions
matching *) arguments or pattern matching *)
scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t; (** The names of the scopes *) scope_idmap : Scopelang.Ast.ScopeName.t Desugared.Ast.IdentMap.t;
(** The names of the scopes *)
struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t; struct_idmap : Scopelang.Ast.StructName.t Desugared.Ast.IdentMap.t;
(** The names of the structs *) (** The names of the structs *)
field_idmap : Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t Desugared.Ast.IdentMap.t; field_idmap :
(** The names of the struct fields. Names of fields can be shared between different structs *) Scopelang.Ast.StructFieldName.t Scopelang.Ast.StructMap.t
enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t; (** The names of the enums *) Desugared.Ast.IdentMap.t;
(** The names of the struct fields. Names of fields can be shared between
different structs *)
enum_idmap : Scopelang.Ast.EnumName.t Desugared.Ast.IdentMap.t;
(** The names of the enums *)
constructor_idmap : constructor_idmap :
Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t Desugared.Ast.IdentMap.t; Scopelang.Ast.EnumConstructor.t Scopelang.Ast.EnumMap.t
(** The names of the enum constructors. Constructor names can be shared between different Desugared.Ast.IdentMap.t;
enums *) (** The names of the enum constructors. Constructor names can be shared
scopes : scope_context Scopelang.Ast.ScopeMap.t; (** For each scope, its context *) between different enums *)
structs : struct_context Scopelang.Ast.StructMap.t; (** For each struct, its context *) scopes : scope_context Scopelang.Ast.ScopeMap.t;
enums : enum_context Scopelang.Ast.EnumMap.t; (** For each enum, its context *) (** For each scope, its context *)
structs : struct_context Scopelang.Ast.StructMap.t;
(** For each struct, its context *)
enums : enum_context Scopelang.Ast.EnumMap.t;
(** For each enum, its context *)
var_typs : var_sig Desugared.Ast.ScopeVarMap.t; var_typs : var_sig Desugared.Ast.ScopeVarMap.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
} }
@ -81,31 +95,41 @@ type context = {
(** {1 Helpers} *) (** {1 Helpers} *)
val raise_unsupported_feature : string -> Pos.t -> 'a val raise_unsupported_feature : string -> Pos.t -> 'a
(** Temporary function raising an error message saying that a feature is not supported yet *) (** Temporary function raising an error message saying that a feature is not
supported yet *)
val raise_unknown_identifier : string -> ident Pos.marked -> 'a val raise_unknown_identifier : string -> ident Pos.marked -> 'a
(** Function to call whenever an identifier used somewhere has not been declared in the program (** Function to call whenever an identifier used somewhere has not been declared
previously *) in the program previously *)
val get_var_typ : context -> Desugared.Ast.ScopeVar.t -> typ Pos.marked val get_var_typ : context -> Desugared.Ast.ScopeVar.t -> typ Pos.marked
(** Gets the type associated to an uid *) (** Gets the type associated to an uid *)
val is_var_cond : context -> Desugared.Ast.ScopeVar.t -> bool val is_var_cond : context -> Desugared.Ast.ScopeVar.t -> bool
val get_var_io : context -> Desugared.Ast.ScopeVar.t -> Ast.scope_decl_context_io val get_var_io :
context -> Desugared.Ast.ScopeVar.t -> Ast.scope_decl_context_io
val get_var_uid : val get_var_uid :
Scopelang.Ast.ScopeName.t -> context -> ident Pos.marked -> Desugared.Ast.ScopeVar.t Scopelang.Ast.ScopeName.t ->
context ->
ident Pos.marked ->
Desugared.Ast.ScopeVar.t
(** Get the variable uid inside the scope given in argument *) (** Get the variable uid inside the scope given in argument *)
val get_subscope_uid : val get_subscope_uid :
Scopelang.Ast.ScopeName.t -> context -> ident Pos.marked -> Scopelang.Ast.SubScopeName.t Scopelang.Ast.ScopeName.t ->
context ->
ident Pos.marked ->
Scopelang.Ast.SubScopeName.t
(** Get the subscope uid inside the scope given in argument *) (** Get the subscope uid inside the scope given in argument *)
val is_subscope_uid : Scopelang.Ast.ScopeName.t -> context -> ident -> bool val is_subscope_uid : Scopelang.Ast.ScopeName.t -> context -> ident -> bool
(** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the subscopes of [scope_uid]. *) (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the
subscopes of [scope_uid]. *)
val belongs_to : context -> Desugared.Ast.ScopeVar.t -> Scopelang.Ast.ScopeName.t -> bool val belongs_to :
context -> Desugared.Ast.ScopeVar.t -> Scopelang.Ast.ScopeName.t -> bool
(** Checks if the var_uid belongs to the scope scope_uid *) (** Checks if the var_uid belongs to the scope scope_uid *)
val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ Pos.marked val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ Pos.marked
@ -121,7 +145,8 @@ val label_groups :
val is_type_cond : Ast.typ Pos.marked -> bool val is_type_cond : Ast.typ Pos.marked -> bool
val add_def_local_var : context -> ident Pos.marked -> context * Desugared.Ast.Var.t val add_def_local_var :
context -> ident Pos.marked -> context * Desugared.Ast.Var.t
(** Adds a binding to the context *) (** Adds a binding to the context *)
val get_def_key : val get_def_key :

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Helpers for parsing *) (** Helpers for parsing *)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Helpers for parsing *) (** Helpers for parsing *)

View File

@ -1,19 +1,22 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Wrapping module around parser and lexer that offers the {!: Parser_driver.parse_source_file} (** Wrapping module around parser and lexer that offers the {!:
API. *) Parser_driver.parse_source_file} API. *)
open Sedlexing open Sedlexing
open Utils open Utils
@ -23,24 +26,27 @@ open Utils
(** Three-way minimum *) (** Three-way minimum *)
let minimum a b c = min a (min b c) let minimum a b c = min a (min b c)
(** Computes the levenshtein distance between two strings, used to provide error messages (** Computes the levenshtein distance between two strings, used to provide error
suggestions *) messages suggestions *)
let levenshtein_distance (s : string) (t : string) : int = let levenshtein_distance (s : string) (t : string) : int =
let m = String.length s and n = String.length t in let m = String.length s and n = String.length t in
(* for all i and j, d.(i).(j) will hold the Levenshtein distance between the first i characters of (* for all i and j, d.(i).(j) will hold the Levenshtein distance between the
s and the first j characters of t *) first i characters of s and the first j characters of t *)
let d = Array.make_matrix (m + 1) (n + 1) 0 in let d = Array.make_matrix (m + 1) (n + 1) 0 in
for i = 0 to m do for i = 0 to m do
d.(i).(0) <- i (* the distance of any first string to an empty second string *) d.(i).(0) <- i
(* the distance of any first string to an empty second string *)
done; done;
for j = 0 to n do for j = 0 to n do
d.(0).(j) <- j (* the distance of any second string to an empty first string *) d.(0).(j) <- j
(* the distance of any second string to an empty first string *)
done; done;
for j = 1 to n do for j = 1 to n do
for i = 1 to m do for i = 1 to m do
if s.[i - 1] = t.[j - 1] then d.(i).(j) <- d.(i - 1).(j - 1) (* no operation required *) if s.[i - 1] = t.[j - 1] then d.(i).(j) <- d.(i - 1).(j - 1)
(* no operation required *)
else else
d.(i).(j) <- d.(i).(j) <-
minimum minimum
@ -52,9 +58,11 @@ let levenshtein_distance (s : string) (t : string) : int =
d.(m).(n) d.(m).(n)
(** After parsing, heading structure is completely flat because of the [source_file_item] rule. We (** After parsing, heading structure is completely flat because of the
need to tree-i-fy the flat structure, by looking at the precedence of the law headings. *) [source_file_item] rule. We need to tree-i-fy the flat structure, by looking
let rec law_struct_list_to_tree (f : Ast.law_structure list) : Ast.law_structure list = at the precedence of the law headings. *)
let rec law_struct_list_to_tree (f : Ast.law_structure list) :
Ast.law_structure list =
match f with match f with
| [] -> [] | [] -> []
| [ item ] -> [ item ] | [ item ] -> [ item ]
@ -65,18 +73,20 @@ let rec law_struct_list_to_tree (f : Ast.law_structure list) : Ast.law_structure
| rest_head :: rest_tail -> ( | rest_head :: rest_tail -> (
match first_item with match first_item with
| CodeBlock _ | LawText _ | LawInclude _ -> | CodeBlock _ | LawText _ | LawInclude _ ->
(* if an article or an include is just before a new heading , then we don't merge it (* if an article or an include is just before a new heading , then
with what comes next *) we don't merge it with what comes next *)
first_item :: rest_head :: rest_tail first_item :: rest_head :: rest_tail
| LawHeading (heading, _) -> | LawHeading (heading, _) ->
(* here we have encountered a heading, which is going to "gobble" everything in the (* here we have encountered a heading, which is going to "gobble"
[rest_tree] until it finds a heading of at least the same precedence *) everything in the [rest_tree] until it finds a heading of at
least the same precedence *)
let rec split_rest_tree (rest_tree : Ast.law_structure list) : let rec split_rest_tree (rest_tree : Ast.law_structure list) :
Ast.law_structure list * Ast.law_structure list = Ast.law_structure list * Ast.law_structure list =
match rest_tree with match rest_tree with
| [] -> ([], []) | [] -> ([], [])
| LawHeading (new_heading, _) :: _ | LawHeading (new_heading, _) :: _
when new_heading.law_heading_precedence <= heading.law_heading_precedence -> when new_heading.law_heading_precedence
<= heading.law_heading_precedence ->
(* we stop gobbling *) (* we stop gobbling *)
([], rest_tree) ([], rest_tree)
| first :: after -> | first :: after ->
@ -92,10 +102,14 @@ let syntax_hints_style = [ ANSITerminal.yellow ]
(** Usage: [raise_parser_error error_loc last_good_loc token msg] (** Usage: [raise_parser_error error_loc last_good_loc token msg]
Raises an error message featuring the [error_loc] position where the parser has failed, the Raises an error message featuring the [error_loc] position where the parser
[token] on which the parser has failed, and the error message [msg]. If available, displays has failed, the [token] on which the parser has failed, and the error
[last_good_loc] the location of the last token correctly parsed. *) message [msg]. If available, displays [last_good_loc] the location of the
let raise_parser_error (error_loc : Pos.t) (last_good_loc : Pos.t option) (token : string) last token correctly parsed. *)
let raise_parser_error
(error_loc : Pos.t)
(last_good_loc : Pos.t option)
(token : string)
(msg : string) : 'a = (msg : string) : 'a =
Errors.raise_multispanned_error Errors.raise_multispanned_error
((Some "Error token:", error_loc) ((Some "Error token:", error_loc)
@ -105,7 +119,8 @@ let raise_parser_error (error_loc : Pos.t) (last_good_loc : Pos.t option) (token
| Some last_good_loc -> [ (Some "Last good token:", last_good_loc) ])) | Some last_good_loc -> [ (Some "Last good token:", last_good_loc) ]))
"Syntax error at token %a\n%s" "Syntax error at token %a\n%s"
(Cli.format_with_style syntax_hints_style) (Cli.format_with_style syntax_hints_style)
(Printf.sprintf "\"%s\"" token) msg (Printf.sprintf "\"%s\"" token)
msg
module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
include Parser.Make (LocalisedLexer) include Parser.Make (LocalisedLexer)
@ -119,21 +134,28 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
(** Usage: [fail lexbuf env token_list last_input_needed] (** Usage: [fail lexbuf env token_list last_input_needed]
Raises an error with meaningful hints about what the parsing error was. [lexbuf] is the lexing Raises an error with meaningful hints about what the parsing error was.
buffer state at the failure point, [env] is the Menhir environment and [last_input_needed] is [lexbuf] is the lexing buffer state at the failure point, [env] is the
the last checkpoint of a valid Menhir state before the parsing error. [token_list] is provided Menhir environment and [last_input_needed] is the last checkpoint of a
by things like {!val: Surface.Lexer_common.token_list_language_agnostic} and is used to valid Menhir state before the parsing error. [token_list] is provided by
provide suggestions of the tokens acceptable at the failure point *) things like {!val: Surface.Lexer_common.token_list_language_agnostic} and
let fail (lexbuf : lexbuf) (env : 'semantic_value I.env) is used to provide suggestions of the tokens acceptable at the failure
(token_list : (string * Tokens.token) list) (last_input_needed : 'semantic_value I.env option) point *)
: 'a = let fail
(lexbuf : lexbuf)
(env : 'semantic_value I.env)
(token_list : (string * Tokens.token) list)
(last_input_needed : 'semantic_value I.env option) : 'a =
let wrong_token = Utf8.lexeme lexbuf in let wrong_token = Utf8.lexeme lexbuf in
let acceptable_tokens, last_positions = let acceptable_tokens, last_positions =
match last_input_needed with match last_input_needed with
| Some last_input_needed -> | Some last_input_needed ->
( List.filter ( List.filter
(fun (_, t) -> (fun (_, t) ->
I.acceptable (I.input_needed last_input_needed) t (fst (lexing_positions lexbuf))) I.acceptable
(I.input_needed last_input_needed)
t
(fst (lexing_positions lexbuf)))
token_list, token_list,
Some (I.positions last_input_needed) ) Some (I.positions last_input_needed) )
| None -> (token_list, None) | None -> (token_list, None)
@ -163,23 +185,27 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
(Printf.sprintf "did you mean %s?" (Printf.sprintf "did you mean %s?"
(String.concat ", or maybe " (String.concat ", or maybe "
(List.map (List.map
(fun (ts, _) -> Cli.with_style syntax_hints_style "\"%s\"" ts) (fun (ts, _) ->
Cli.with_style syntax_hints_style "\"%s\"" ts)
similar_acceptable_tokens))) similar_acceptable_tokens)))
in in
(* The parser has suspended itself because of a syntax error. Stop. *) (* The parser has suspended itself because of a syntax error. Stop. *)
let custom_menhir_message = let custom_menhir_message =
match Parser_errors.message (state env) with match Parser_errors.message (state env) with
| exception Not_found -> | exception Not_found ->
"Message: " ^ Cli.with_style syntax_hints_style "%s" "unexpected token" "Message: "
^ Cli.with_style syntax_hints_style "%s" "unexpected token"
| msg -> | msg ->
"Message: " "Message: "
^ Cli.with_style syntax_hints_style "%s" (String.trim (String.uncapitalize_ascii msg)) ^ Cli.with_style syntax_hints_style "%s"
(String.trim (String.uncapitalize_ascii msg))
in in
let msg = let msg =
match similar_token_msg with match similar_token_msg with
| None -> custom_menhir_message | None -> custom_menhir_message
| Some similar_token_msg -> | Some similar_token_msg ->
Printf.sprintf "%s\nAutosuggestion: %s" custom_menhir_message similar_token_msg Printf.sprintf "%s\nAutosuggestion: %s" custom_menhir_message
similar_token_msg
in in
raise_parser_error raise_parser_error
(Pos.from_lpos (lexing_positions lexbuf)) (Pos.from_lpos (lexing_positions lexbuf))
@ -187,10 +213,12 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
(Utf8.lexeme lexbuf) msg (Utf8.lexeme lexbuf) msg
(** Main parsing loop *) (** Main parsing loop *)
let rec loop (next_token : unit -> Tokens.token * Lexing.position * Lexing.position) let rec loop
(token_list : (string * Tokens.token) list) (lexbuf : lexbuf) (next_token : unit -> Tokens.token * Lexing.position * Lexing.position)
(last_input_needed : 'semantic_value I.env option) (checkpoint : 'semantic_value I.checkpoint) (token_list : (string * Tokens.token) list)
: Ast.source_file = (lexbuf : lexbuf)
(last_input_needed : 'semantic_value I.env option)
(checkpoint : 'semantic_value I.checkpoint) : Ast.source_file =
match checkpoint with match checkpoint with
| I.InputNeeded env -> | I.InputNeeded env ->
let token = next_token () in let token = next_token () in
@ -205,21 +233,27 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct
(* Cannot happen as we stop at syntax error immediatly *) (* Cannot happen as we stop at syntax error immediatly *)
assert false assert false
(** Stub that wraps the parsing main loop and handles the Menhir/Sedlex type difference for (** Stub that wraps the parsing main loop and handles the Menhir/Sedlex type
[lexbuf]. *) difference for [lexbuf]. *)
let sedlex_with_menhir (lexer' : lexbuf -> Tokens.token) let sedlex_with_menhir
(lexer' : lexbuf -> Tokens.token)
(token_list : (string * Tokens.token) list) (token_list : (string * Tokens.token) list)
(target_rule : Lexing.position -> 'semantic_value I.checkpoint) (lexbuf : lexbuf) : (target_rule : Lexing.position -> 'semantic_value I.checkpoint)
Ast.source_file = (lexbuf : lexbuf) : Ast.source_file =
let lexer : unit -> Tokens.token * Lexing.position * Lexing.position = let lexer : unit -> Tokens.token * Lexing.position * Lexing.position =
with_tokenizer lexer' lexbuf with_tokenizer lexer' lexbuf
in in
try loop lexer token_list lexbuf None (target_rule (fst @@ Sedlexing.lexing_positions lexbuf)) try
loop lexer token_list lexbuf None
(target_rule (fst @@ Sedlexing.lexing_positions lexbuf))
with Sedlexing.MalFormed | Sedlexing.InvalidCodepoint _ -> with Sedlexing.MalFormed | Sedlexing.InvalidCodepoint _ ->
Lexer_common.raise_lexer_error (Pos.from_lpos (lexing_positions lexbuf)) (Utf8.lexeme lexbuf) Lexer_common.raise_lexer_error
(Pos.from_lpos (lexing_positions lexbuf))
(Utf8.lexeme lexbuf)
let commands_or_includes (lexbuf : lexbuf) : Ast.source_file = let commands_or_includes (lexbuf : lexbuf) : Ast.source_file =
sedlex_with_menhir LocalisedLexer.lexer LocalisedLexer.token_list Incremental.source_file lexbuf sedlex_with_menhir LocalisedLexer.lexer LocalisedLexer.token_list
Incremental.source_file lexbuf
end end
module Parser_En = ParserAux (Lexer_en) module Parser_En = ParserAux (Lexer_en)
@ -234,9 +268,10 @@ let localised_parser : Cli.backend_lang -> lexbuf -> Ast.source_file = function
(** {1 Parsing multiple files} *) (** {1 Parsing multiple files} *)
(** Parses a single source file *) (** Parses a single source file *)
let rec parse_source_file (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program let rec parse_source_file
= (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program =
Cli.debug_print "Parsing %s" (match source_file with FileName s | Contents s -> s); Cli.debug_print "Parsing %s"
(match source_file with FileName s | Contents s -> s);
let lexbuf, input = let lexbuf, input =
match source_file with match source_file with
| FileName source_file -> ( | FileName source_file -> (
@ -246,7 +281,9 @@ let rec parse_source_file (source_file : Pos.input_file) (language : Cli.backend
with Sys_error msg -> Errors.raise_error "%s" msg) with Sys_error msg -> Errors.raise_error "%s" msg)
| Contents contents -> (Sedlexing.Utf8.from_string contents, None) | Contents contents -> (Sedlexing.Utf8.from_string contents, None)
in in
let source_file_name = match source_file with FileName s -> s | Contents _ -> "stdin" in let source_file_name =
match source_file with FileName s -> s | Contents _ -> "stdin"
in
Sedlexing.set_filename lexbuf source_file_name; Sedlexing.set_filename lexbuf source_file_name;
Parse_utils.current_file := source_file_name; Parse_utils.current_file := source_file_name;
let commands = localised_parser language lexbuf in let commands = localised_parser language lexbuf in
@ -257,8 +294,11 @@ let rec parse_source_file (source_file : Pos.input_file) (language : Cli.backend
program_source_files = source_file_name :: program.Ast.program_source_files; program_source_files = source_file_name :: program.Ast.program_source_files;
} }
(** Expands the include directives in a parsing result, thus parsing new source files *) (** Expands the include directives in a parsing result, thus parsing new source
and expand_includes (source_file : string) (commands : Ast.law_structure list) files *)
and expand_includes
(source_file : string)
(commands : Ast.law_structure list)
(language : Cli.backend_lang) : Ast.program = (language : Cli.backend_lang) : Ast.program =
List.fold_left List.fold_left
(fun acc command -> (fun acc command ->
@ -266,19 +306,27 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list)
| Ast.LawInclude (Ast.CatalaFile sub_source) -> | Ast.LawInclude (Ast.CatalaFile sub_source) ->
let source_dir = Filename.dirname source_file in let source_dir = Filename.dirname source_file in
let sub_source = Filename.concat source_dir (Pos.unmark sub_source) in let sub_source = Filename.concat source_dir (Pos.unmark sub_source) in
let includ_program = parse_source_file (FileName sub_source) language in let includ_program =
parse_source_file (FileName sub_source) language
in
{ {
Ast.program_source_files = Ast.program_source_files =
acc.Ast.program_source_files @ includ_program.program_source_files; acc.Ast.program_source_files @ includ_program.program_source_files;
Ast.program_items = acc.Ast.program_items @ includ_program.program_items; Ast.program_items =
acc.Ast.program_items @ includ_program.program_items;
} }
| Ast.LawHeading (heading, commands') -> | Ast.LawHeading (heading, commands') ->
let { Ast.program_items = commands'; Ast.program_source_files = new_sources } = let {
Ast.program_items = commands';
Ast.program_source_files = new_sources;
} =
expand_includes source_file commands' language expand_includes source_file commands' language
in in
{ {
Ast.program_source_files = acc.Ast.program_source_files @ new_sources; Ast.program_source_files =
Ast.program_items = acc.Ast.program_items @ [ Ast.LawHeading (heading, commands') ]; acc.Ast.program_source_files @ new_sources;
Ast.program_items =
acc.Ast.program_items @ [ Ast.LawHeading (heading, commands') ];
} }
| i -> { acc with Ast.program_items = acc.Ast.program_items @ [ i ] }) | i -> { acc with Ast.program_items = acc.Ast.program_items @ [ i ] })
{ Ast.program_source_files = []; Ast.program_items = [] } { Ast.program_source_files = []; Ast.program_items = [] }
@ -286,7 +334,10 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list)
(** {1 API} *) (** {1 API} *)
let parse_top_level_file (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program let parse_top_level_file
= (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program =
let program = parse_source_file source_file language in let program = parse_source_file source_file language in
{ program with Ast.program_items = law_struct_list_to_tree program.Ast.program_items } {
program with
Ast.program_items = law_struct_list_to_tree program.Ast.program_items;
}

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Wrapping module around parser and lexer that offers the (** Wrapping module around parser and lexer that offers the

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Interface of the module auto-generated based on "parser.messages". *) (** Interface of the module auto-generated based on "parser.messages". *)

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Ast open Ast

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
val format_primitive_typ : Format.formatter -> Ast.primitive_typ -> unit val format_primitive_typ : Format.formatter -> Ast.primitive_typ -> unit

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
type backend_lang = En | Fr | Pl type backend_lang = En | Fr | Pl
@ -18,7 +21,6 @@ type backend_lang = En | Fr | Pl
let source_files : string list ref = ref [] let source_files : string list ref = ref []
let locale_lang : backend_lang ref = ref En let locale_lang : backend_lang ref = ref En
let contents : string ref = ref "" let contents : string ref = ref ""
(** Prints debug information *) (** Prints debug information *)
@ -29,13 +31,9 @@ let style_flag = ref true
(* Max number of digits to show for decimal results *) (* Max number of digits to show for decimal results *)
let max_prec_digits = ref 20 let max_prec_digits = ref 20
let trace_flag = ref false let trace_flag = ref false
let optimize_flag = ref false let optimize_flag = ref false
let disable_counterexamples = ref false let disable_counterexamples = ref false
let avoid_exceptions_flag = ref false let avoid_exceptions_flag = ref false
open Cmdliner open Cmdliner
@ -46,39 +44,46 @@ let file =
& pos 1 (some file) None & pos 1 (some file) None
& info [] ~docv:"FILE" ~doc:"Catala master file to be compiled.") & info [] ~docv:"FILE" ~doc:"Catala master file to be compiled.")
let debug = Arg.(value & flag & info [ "debug"; "d" ] ~doc:"Prints debug information.") let debug =
Arg.(value & flag & info [ "debug"; "d" ] ~doc:"Prints debug information.")
let unstyled = let unstyled =
Arg.( Arg.(
value & flag value & flag
& info [ "unstyled"; "u" ] ~doc:"Removes styling (colors, etc.) from terminal output.") & info [ "unstyled"; "u" ]
~doc:"Removes styling (colors, etc.) from terminal output.")
let optimize = Arg.(value & flag & info [ "optimize"; "O" ] ~doc:"Run compiler optimizations.") let optimize =
Arg.(
value & flag & info [ "optimize"; "O" ] ~doc:"Run compiler optimizations.")
let trace_opt = let trace_opt =
Arg.( Arg.(
value & flag value & flag
& info [ "trace"; "t" ] & info [ "trace"; "t" ]
~doc: ~doc:
"Displays a trace of the interpreter's computation or generates logging instructions in \ "Displays a trace of the interpreter's computation or generates \
translate programs.") logging instructions in translate programs.")
let avoid_exceptions = let avoid_exceptions =
Arg.( Arg.(
value & flag value & flag
& info [ "avoid_exceptions" ] ~doc:"Compiles the default calculus without exceptions") & info [ "avoid_exceptions" ]
~doc:"Compiles the default calculus without exceptions")
let wrap_weaved_output = let wrap_weaved_output =
Arg.( Arg.(
value & flag value & flag
& info [ "wrap"; "w" ] ~doc:"Wraps literate programming output with a minimal preamble.") & info [ "wrap"; "w" ]
~doc:"Wraps literate programming output with a minimal preamble.")
let backend = let backend =
Arg.( Arg.(
required required
& pos 0 (some string) None & pos 0 (some string) None
& info [] ~docv:"COMMAND" & info [] ~docv:"COMMAND"
~doc:"Backend selection (see the list of commands for available options).") ~doc:
"Backend selection (see the list of commands for available options).")
type backend_option = type backend_option =
| Dcalc | Dcalc
@ -98,23 +103,29 @@ let language =
Arg.( Arg.(
value value
& opt (some string) None & opt (some string) None
& info [ "l"; "language" ] ~docv:"LANG" ~doc:"Input language among: en, fr, pl.") & info [ "l"; "language" ] ~docv:"LANG"
~doc:"Input language among: en, fr, pl.")
let max_prec_digits_opt = let max_prec_digits_opt =
Arg.( Arg.(
value value
& opt (some int) None & opt (some int) None
& info [ "p"; "max_digits_printed" ] ~docv:"DIGITS" & info
~doc:"Maximum number of significant digits printed for decimal results (default 20).") [ "p"; "max_digits_printed" ]
~docv:"DIGITS"
~doc:
"Maximum number of significant digits printed for decimal results \
(default 20).")
let disable_counterexamples_opt = let disable_counterexamples_opt =
Arg.( Arg.(
value & flag value & flag
& info [ "disable_counterexamples" ] & info
[ "disable_counterexamples" ]
~doc: ~doc:
"Disables the search for counterexamples in proof mode. Useful when you want a \ "Disables the search for counterexamples in proof mode. Useful when \
deterministic output from the Catala compiler, since provers can have some randomness \ you want a deterministic output from the Catala compiler, since \
in them.") provers can have some randomness in them.")
let ex_scope = let ex_scope =
Arg.( Arg.(
@ -128,60 +139,74 @@ let output =
& opt (some string) None & opt (some string) None
& info [ "output"; "o" ] ~docv:"OUTPUT" & info [ "output"; "o" ] ~docv:"OUTPUT"
~doc: ~doc:
"$(i, OUTPUT) is the file that will contain the output of the compiler. Defaults to \ "$(i, OUTPUT) is the file that will contain the output of the \
$(i,FILE).$(i,EXT) where $(i,EXT) depends on the chosen backend.") compiler. Defaults to $(i,FILE).$(i,EXT) where $(i,EXT) depends on \
the chosen backend.")
let catala_t f = let catala_t f =
Term.( Term.(
const f $ file $ debug $ unstyled $ wrap_weaved_output $ avoid_exceptions $ backend $ language const f $ file $ debug $ unstyled $ wrap_weaved_output $ avoid_exceptions
$ max_prec_digits_opt $ trace_opt $ disable_counterexamples_opt $ optimize $ ex_scope $ output) $ backend $ language $ max_prec_digits_opt $ trace_opt
$ disable_counterexamples_opt $ optimize $ ex_scope $ output)
let version = "0.5.0" let version = "0.5.0"
let info = let info =
let doc = let doc =
"Compiler for Catala, a specification language for tax and social benefits computation rules." "Compiler for Catala, a specification language for tax and social benefits \
computation rules."
in in
let man = let man =
[ [
`S Manpage.s_description; `S Manpage.s_description;
`P `P
"Catala is a domain-specific language for deriving faithful-by-construction algorithms \ "Catala is a domain-specific language for deriving \
from legislative texts."; faithful-by-construction algorithms from legislative texts.";
`S Manpage.s_commands; `S Manpage.s_commands;
`I `I
( "$(b,Intepret)", ( "$(b,Intepret)",
"Runs the interpreter on the Catala program, executing the scope specified by the \ "Runs the interpreter on the Catala program, executing the scope \
$(b,-s) option assuming no additional external inputs." ); specified by the $(b,-s) option assuming no additional external \
`I ("$(b,Typecheck)", "Parses and typechecks a Catala program, without interpreting it."); inputs." );
`I
( "$(b,Typecheck)",
"Parses and typechecks a Catala program, without interpreting it." );
`I `I
( "$(b,Proof)", ( "$(b,Proof)",
"Generates and proves verification conditions about the well-behaved execution of the \ "Generates and proves verification conditions about the well-behaved \
Catala program." ); execution of the Catala program." );
`I ("$(b,OCaml)", "Generates an OCaml translation of the Catala program."); `I ("$(b,OCaml)", "Generates an OCaml translation of the Catala program.");
`I ("$(b,Python)", "Generates a Python translation of the Catala program."); `I ("$(b,Python)", "Generates a Python translation of the Catala program.");
`I ("$(b,LaTeX)", "Weaves a LaTeX literate programming output of the Catala program."); `I
`I ("$(b,HTML)", "Weaves an HTML literate programming output of the Catala program."); ( "$(b,LaTeX)",
"Weaves a LaTeX literate programming output of the Catala program." );
`I
( "$(b,HTML)",
"Weaves an HTML literate programming output of the Catala program." );
`I `I
( "$(b,Makefile)", ( "$(b,Makefile)",
"Generates a Makefile-compatible list of the file dependencies of a Catala program." ); "Generates a Makefile-compatible list of the file dependencies of a \
Catala program." );
`I `I
( "$(b,Scopelang)", ( "$(b,Scopelang)",
"Prints a debugging verbatim of the scope language intermediate representation of the \ "Prints a debugging verbatim of the scope language intermediate \
Catala program. Use the $(b,-s) option to restrict the output to a particular scope." ); representation of the Catala program. Use the $(b,-s) option to \
restrict the output to a particular scope." );
`I `I
( "$(b,Dcalc)", ( "$(b,Dcalc)",
"Prints a debugging verbatim of the default calculus intermediate representation of the \ "Prints a debugging verbatim of the default calculus intermediate \
Catala program. Use the $(b,-s) option to restrict the output to a particular scope." ); representation of the Catala program. Use the $(b,-s) option to \
restrict the output to a particular scope." );
`I `I
( "$(b,Lcalc)", ( "$(b,Lcalc)",
"Prints a debugging verbatim of the lambda calculus intermediate representation of the \ "Prints a debugging verbatim of the lambda calculus intermediate \
Catala program. Use the $(b,-s) option to restrict the output to a particular scope." ); representation of the Catala program. Use the $(b,-s) option to \
restrict the output to a particular scope." );
`I `I
( "$(b,Scalc)", ( "$(b,Scalc)",
"Prints a debugging verbatim of the statement calculus intermediate representation of \ "Prints a debugging verbatim of the statement calculus intermediate \
the Catala program. Use the $(b,-s) option to restrict the output to a particular \ representation of the Catala program. Use the $(b,-s) option to \
scope." ); restrict the output to a particular scope." );
`S Manpage.s_authors; `S Manpage.s_authors;
`P "The authors are listed by alphabetical order."; `P "The authors are listed by alphabetical order.";
`P "Nicolas Chataing <nicolas.chataing@ens.fr>"; `P "Nicolas Chataing <nicolas.chataing@ens.fr>";
@ -194,7 +219,8 @@ let info =
`Pre "catala Interpret -s Foo file.catala_en"; `Pre "catala Interpret -s Foo file.catala_en";
`Pre "catala Ocaml -o target/file.ml file.catala_en"; `Pre "catala Ocaml -o target/file.ml file.catala_en";
`S Manpage.s_bugs; `S Manpage.s_bugs;
`P "Please file bug reports at https://github.com/CatalaLang/catala/issues"; `P
"Please file bug reports at https://github.com/CatalaLang/catala/issues";
] ]
in in
let exits = Term.default_exits @ [ Term.exit_info ~doc:"on error." 1 ] in let exits = Term.default_exits @ [ Term.exit_info ~doc:"on error." 1 ] in
@ -206,12 +232,14 @@ let info =
let time : float ref = ref (Unix.gettimeofday ()) let time : float ref = ref (Unix.gettimeofday ())
let with_style (styles : ANSITerminal.style list) (str : ('a, unit, string) format) = let with_style
(styles : ANSITerminal.style list) (str : ('a, unit, string) format) =
if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str
let format_with_style (styles : ANSITerminal.style list) fmt (str : string) = let format_with_style (styles : ANSITerminal.style list) fmt (str : string) =
if !style_flag then if !style_flag then
Format.pp_print_as fmt (String.length str) (ANSITerminal.sprintf styles "%s" str) Format.pp_print_as fmt (String.length str)
(ANSITerminal.sprintf styles "%s" str)
else Format.pp_print_string fmt str else Format.pp_print_string fmt str
let time_marker () = let time_marker () =
@ -221,7 +249,9 @@ let time_marker () =
let delta = (new_time -. old_time) *. 1000. in let delta = (new_time -. old_time) *. 1000. in
if delta > 50. then if delta > 50. then
Printf.printf "%s" Printf.printf "%s"
(with_style [ ANSITerminal.Bold; ANSITerminal.black ] "[TIME] %.0f ms\n" delta) (with_style
[ ANSITerminal.Bold; ANSITerminal.black ]
"[TIME] %.0f ms\n" delta)
(** Prints [\[DEBUG\]] in purple on the terminal standard output *) (** Prints [\[DEBUG\]] in purple on the terminal standard output *)
let debug_marker () = let debug_marker () =
@ -229,29 +259,35 @@ let debug_marker () =
with_style [ ANSITerminal.Bold; ANSITerminal.magenta ] "[DEBUG] " with_style [ ANSITerminal.Bold; ANSITerminal.magenta ] "[DEBUG] "
(** Prints [\[ERROR\]] in red on the terminal error output *) (** Prints [\[ERROR\]] in red on the terminal error output *)
let error_marker () = with_style [ ANSITerminal.Bold; ANSITerminal.red ] "[ERROR] " let error_marker () =
with_style [ ANSITerminal.Bold; ANSITerminal.red ] "[ERROR] "
(** Prints [\[WARNING\]] in yellow on the terminal standard output *) (** Prints [\[WARNING\]] in yellow on the terminal standard output *)
let warning_marker () = with_style [ ANSITerminal.Bold; ANSITerminal.yellow ] "[WARNING] " let warning_marker () =
with_style [ ANSITerminal.Bold; ANSITerminal.yellow ] "[WARNING] "
(** Prints [\[RESULT\]] in green on the terminal standard output *) (** Prints [\[RESULT\]] in green on the terminal standard output *)
let result_marker () = with_style [ ANSITerminal.Bold; ANSITerminal.green ] "[RESULT] " let result_marker () =
with_style [ ANSITerminal.Bold; ANSITerminal.green ] "[RESULT] "
(** Prints [\[LOG\]] in red on the terminal error output *) (** Prints [\[LOG\]] in red on the terminal error output *)
let log_marker () = with_style [ ANSITerminal.Bold; ANSITerminal.black ] "[LOG] " let log_marker () =
with_style [ ANSITerminal.Bold; ANSITerminal.black ] "[LOG] "
(**{2 Printers}*) (**{2 Printers}*)
(** All the printers below print their argument after the correct marker *) (** All the printers below print their argument after the correct marker *)
let concat_with_line_depending_prefix_and_suffix (prefix : int -> string) (suffix : int -> string) let concat_with_line_depending_prefix_and_suffix
(ss : string list) = (prefix : int -> string) (suffix : int -> string) (ss : string list) =
match ss with match ss with
| hd :: rest -> | hd :: rest ->
let out, _ = let out, _ =
List.fold_left List.fold_left
(fun (acc, i) s -> (fun (acc, i) s ->
((acc ^ prefix i ^ s ^ if i = List.length ss - 1 then "" else suffix i), i + 1)) ( (acc ^ prefix i ^ s
^ if i = List.length ss - 1 then "" else suffix i),
i + 1 ))
((prefix 0 ^ hd ^ if 0 = List.length ss - 1 then "" else suffix 0), 1) ((prefix 0 ^ hd ^ if 0 = List.length ss - 1 then "" else suffix 0), 1)
rest rest
in in
@ -270,7 +306,8 @@ let debug_print (format : ('a, out_channel, unit) format) =
else Printf.ifprintf stdout format else Printf.ifprintf stdout format
let debug_format (format : ('a, Format.formatter, unit) format) = let debug_format (format : ('a, Format.formatter, unit) format) =
if !debug_flag then Format.printf ("%s@[<hov>" ^^ format ^^ "@]@.") (debug_marker ()) if !debug_flag then
Format.printf ("%s@[<hov>" ^^ format ^^ "@]@.") (debug_marker ())
else Format.ifprintf Format.std_formatter format else Format.ifprintf Format.std_formatter format
let error_print (format : ('a, out_channel, unit) format) = let error_print (format : ('a, out_channel, unit) format) =

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributors: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria,
<denis.merigoux@inria.fr>, Emile Rolley <emile.rolley@tuta.io> contributors: Denis Merigoux <denis.merigoux@inria.fr>, Emile Rolley
<emile.rolley@tuta.io>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
type backend_lang = En | Fr | Pl type backend_lang = En | Fr | Pl
@ -20,9 +23,7 @@ val source_files : string list ref
(** Source files to be compiled *) (** Source files to be compiled *)
val locale_lang : backend_lang ref val locale_lang : backend_lang ref
val contents : string ref val contents : string ref
val debug_flag : bool ref val debug_flag : bool ref
val style_flag : bool ref val style_flag : bool ref
@ -44,15 +45,10 @@ val avoid_exceptions_flag : bool ref
(** {2 CLI terms} *) (** {2 CLI terms} *)
val file : string Cmdliner.Term.t val file : string Cmdliner.Term.t
val debug : bool Cmdliner.Term.t val debug : bool Cmdliner.Term.t
val unstyled : bool Cmdliner.Term.t val unstyled : bool Cmdliner.Term.t
val trace_opt : bool Cmdliner.Term.t val trace_opt : bool Cmdliner.Term.t
val wrap_weaved_output : bool Cmdliner.Term.t val wrap_weaved_output : bool Cmdliner.Term.t
val backend : string Cmdliner.Term.t val backend : string Cmdliner.Term.t
type backend_option = type backend_option =
@ -70,11 +66,8 @@ type backend_option =
| Typecheck | Typecheck
val language : string option Cmdliner.Term.t val language : string option Cmdliner.Term.t
val max_prec_digits_opt : int option Cmdliner.Term.t val max_prec_digits_opt : int option Cmdliner.Term.t
val ex_scope : string option Cmdliner.Term.t val ex_scope : string option Cmdliner.Term.t
val output : string option Cmdliner.Term.t val output : string option Cmdliner.Term.t
val catala_t : val catala_t :
@ -97,7 +90,6 @@ val catala_t :
[catala_t file debug unstyled wrap_weaved_output avoid_exceptions backend language max_prec_digits_opt trace_opt disable_counterexamples optimize ex_scope output] *) [catala_t file debug unstyled wrap_weaved_output avoid_exceptions backend language max_prec_digits_opt trace_opt disable_counterexamples optimize ex_scope output] *)
val version : string val version : string
val info : Cmdliner.Term.info val info : Cmdliner.Term.info
(**{1 Terminal formatting}*) (**{1 Terminal formatting}*)
@ -106,16 +98,13 @@ val info : Cmdliner.Term.info
val with_style : ANSITerminal.style list -> ('a, unit, string) format -> 'a val with_style : ANSITerminal.style list -> ('a, unit, string) format -> 'a
val format_with_style : ANSITerminal.style list -> Format.formatter -> string -> unit val format_with_style :
ANSITerminal.style list -> Format.formatter -> string -> unit
val debug_marker : unit -> string val debug_marker : unit -> string
val error_marker : unit -> string val error_marker : unit -> string
val warning_marker : unit -> string val warning_marker : unit -> string
val result_marker : unit -> string val result_marker : unit -> string
val log_marker : unit -> string val log_marker : unit -> string
(**{2 Printers}*) (**{2 Printers}*)
@ -129,17 +118,10 @@ val add_prefix_to_each_line : string -> (int -> string) -> string
(** The int argument of the prefix corresponds to the line number, starting at 0 *) (** The int argument of the prefix corresponds to the line number, starting at 0 *)
val debug_print : ('a, out_channel, unit) format -> 'a val debug_print : ('a, out_channel, unit) format -> 'a
val debug_format : ('a, Format.formatter, unit) format -> 'a val debug_format : ('a, Format.formatter, unit) format -> 'a
val error_print : ('a, out_channel, unit) format -> 'a val error_print : ('a, out_channel, unit) format -> 'a
val warning_print : ('a, out_channel, unit) format -> 'a val warning_print : ('a, out_channel, unit) format -> 'a
val result_print : ('a, out_channel, unit) format -> 'a val result_print : ('a, out_channel, unit) format -> 'a
val result_format : ('a, Format.formatter, unit) format -> 'a val result_format : ('a, Format.formatter, unit) format -> 'a
val log_print : ('a, out_channel, unit) format -> 'a val log_print : ('a, out_channel, unit) format -> 'a
val log_format : ('a, Format.formatter, unit) format -> 'a val log_format : ('a, Format.formatter, unit) format -> 'a

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Error formatting and helper functions *) (** Error formatting and helper functions *)
@ -17,11 +19,12 @@
(** {1 Error exception and printing} *) (** {1 Error exception and printing} *)
exception StructuredError of (string * (string option * Pos.t) list) exception StructuredError of (string * (string option * Pos.t) list)
(** The payload of the expression is a main error message, with a list of secondary positions (** The payload of the expression is a main error message, with a list of
related to the error, each carrying an optional secondary message to describe what is pointed by secondary positions related to the error, each carrying an optional
the position. *) secondary message to describe what is pointed by the position. *)
let print_structured_error (msg : string) (pos : (string option * Pos.t) list) : string = let print_structured_error (msg : string) (pos : (string option * Pos.t) list) :
string =
Printf.sprintf "%s%s%s" msg Printf.sprintf "%s%s%s" msg
(if List.length pos = 0 then "" else "\n\n") (if List.length pos = 0 then "" else "\n\n")
(String.concat "\n\n" (String.concat "\n\n"
@ -35,17 +38,22 @@ let print_structured_error (msg : string) (pos : (string option * Pos.t) list) :
(** {1 Error exception and printing} *) (** {1 Error exception and printing} *)
let raise_spanned_error ?(span_msg : string option) (span : Pos.t) format = let raise_spanned_error ?(span_msg : string option) (span : Pos.t) format =
Format.kasprintf (fun msg -> raise (StructuredError (msg, [ (span_msg, span) ]))) format Format.kasprintf
(fun msg -> raise (StructuredError (msg, [ (span_msg, span) ])))
format
let raise_multispanned_error (spans : (string option * Pos.t) list) format = let raise_multispanned_error (spans : (string option * Pos.t) list) format =
Format.kasprintf (fun msg -> raise (StructuredError (msg, spans))) format Format.kasprintf (fun msg -> raise (StructuredError (msg, spans))) format
let raise_error format = Format.kasprintf (fun msg -> raise (StructuredError (msg, []))) format let raise_error format =
Format.kasprintf (fun msg -> raise (StructuredError (msg, []))) format
(** {1 Warning printing}*) (** {1 Warning printing}*)
let format_multispanned_warning (pos : (string option * Pos.t) list) format = let format_multispanned_warning (pos : (string option * Pos.t) list) format =
Format.kasprintf (fun msg -> Cli.warning_print "%s" (print_structured_error msg pos)) format Format.kasprintf
(fun msg -> Cli.warning_print "%s" (print_structured_error msg pos))
format
let format_spanned_warning ?(span_msg : string option) (span : Pos.t) format = let format_spanned_warning ?(span_msg : string option) (span : Pos.t) format =
format_multispanned_warning [ (span_msg, span) ] format format_multispanned_warning [ (span_msg, span) ] format

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Error formatting and helper functions *) (** Error formatting and helper functions *)
@ -17,9 +19,9 @@
(** {1 Error exception and printing} *) (** {1 Error exception and printing} *)
exception StructuredError of (string * (string option * Pos.t) list) exception StructuredError of (string * (string option * Pos.t) list)
(** The payload of the expression is a main error message, with a list of secondary positions (** The payload of the expression is a main error message, with a list of
related to the error, each carrying an optional secondary message to describe what is pointed by secondary positions related to the error, each carrying an optional
the position. *) secondary message to describe what is pointed by the position. *)
val print_structured_error : string -> (string option * Pos.t) list -> string val print_structured_error : string -> (string option * Pos.t) list -> string
@ -38,6 +40,7 @@ val raise_error : ('a, Format.formatter, unit, 'b) format4 -> 'a
val format_multispanned_warning : val format_multispanned_warning :
(string option * Pos.t) list -> ('a, Format.formatter, unit) format -> 'a (string option * Pos.t) list -> ('a, Format.formatter, unit) format -> 'a
val format_spanned_warning : ?span_msg:string -> Pos.t -> ('a, Format.formatter, unit) format -> 'a val format_spanned_warning :
?span_msg:string -> Pos.t -> ('a, Format.formatter, unit) format -> 'a
val format_warning : ('a, Format.formatter, unit) format -> 'a val format_warning : ('a, Format.formatter, unit) format -> 'a

View File

@ -1,31 +1,46 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
type t = { code_pos : Lexing.position * Lexing.position; law_pos : string list } type t = { code_pos : Lexing.position * Lexing.position; law_pos : string list }
let from_lpos (p : Lexing.position * Lexing.position) : t = { code_pos = p; law_pos = [] } let from_lpos (p : Lexing.position * Lexing.position) : t =
{ code_pos = p; law_pos = [] }
let from_info (file : string) (sline : int) (scol : int) (eline : int) (ecol : int) : t = let from_info
(file : string) (sline : int) (scol : int) (eline : int) (ecol : int) : t =
let spos = let spos =
{ Lexing.pos_fname = file; Lexing.pos_lnum = sline; Lexing.pos_cnum = scol; Lexing.pos_bol = 1 } {
Lexing.pos_fname = file;
Lexing.pos_lnum = sline;
Lexing.pos_cnum = scol;
Lexing.pos_bol = 1;
}
in in
let epos = let epos =
{ Lexing.pos_fname = file; Lexing.pos_lnum = eline; Lexing.pos_cnum = ecol; Lexing.pos_bol = 1 } {
Lexing.pos_fname = file;
Lexing.pos_lnum = eline;
Lexing.pos_cnum = ecol;
Lexing.pos_bol = 1;
}
in in
{ code_pos = (spos, epos); law_pos = [] } { code_pos = (spos, epos); law_pos = [] }
let overwrite_law_info (pos : t) (law_pos : string list) : t = { pos with law_pos } let overwrite_law_info (pos : t) (law_pos : string list) : t =
{ pos with law_pos }
let get_law_info (pos : t) : string list = pos.law_pos let get_law_info (pos : t) : string list = pos.law_pos
@ -51,7 +66,8 @@ type input_file = FileName of string | Contents of string
let to_string (pos : t) : string = let to_string (pos : t) : string =
let s, e = pos.code_pos in let s, e = pos.code_pos in
Printf.sprintf "in file %s, from %d:%d to %d:%d" s.Lexing.pos_fname s.Lexing.pos_lnum Printf.sprintf "in file %s, from %d:%d to %d:%d" s.Lexing.pos_fname
s.Lexing.pos_lnum
(s.Lexing.pos_cnum - s.Lexing.pos_bol + 1) (s.Lexing.pos_cnum - s.Lexing.pos_bol + 1)
e.Lexing.pos_lnum e.Lexing.pos_lnum
(e.Lexing.pos_cnum - e.Lexing.pos_bol + 1) (e.Lexing.pos_cnum - e.Lexing.pos_bol + 1)
@ -107,11 +123,15 @@ let retrieve_loc_text (pos : t) : string =
if line_no = sline && line_no = eline then if line_no = sline && line_no = eline then
Cli.with_style error_indicator_style "%*s" Cli.with_style error_indicator_style "%*s"
(get_end_column pos - 1) (get_end_column pos - 1)
(String.make (max (get_end_column pos - get_start_column pos) 0) '^') (String.make
(max (get_end_column pos - get_start_column pos) 0)
'^')
else if line_no = sline && line_no <> eline then else if line_no = sline && line_no <> eline then
Cli.with_style error_indicator_style "%*s" Cli.with_style error_indicator_style "%*s"
(String.length line - 1) (String.length line - 1)
(String.make (max (String.length line - get_start_column pos) 0) '^') (String.make
(max (String.length line - get_start_column pos) 0)
'^')
else if line_no <> sline && line_no <> eline then else if line_no <> sline && line_no <> eline then
Cli.with_style error_indicator_style "%*s%s" line_indent "" Cli.with_style error_indicator_style "%*s%s" line_indent ""
(String.make (max (String.length line - line_indent) 0) '^') (String.make (max (String.length line - line_indent) 0) '^')
@ -127,8 +147,10 @@ let retrieve_loc_text (pos : t) : string =
match input_line_opt () with match input_line_opt () with
| Some line -> | Some line ->
if n < sline - include_extra_count then get_lines (n + 1) if n < sline - include_extra_count then get_lines (n + 1)
else if n >= sline - include_extra_count && n <= eline + include_extra_count then else if
print_matched_line line n :: get_lines (n + 1) n >= sline - include_extra_count
&& n <= eline + include_extra_count
then print_matched_line line n :: get_lines (n + 1)
else [] else []
| None -> [] | None -> []
in in
@ -137,7 +159,10 @@ let retrieve_loc_text (pos : t) : string =
let legal_pos_lines = let legal_pos_lines =
List.rev List.rev
(List.map (List.map
(fun s -> Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") ~subst:(fun _ -> " ") s) (fun s ->
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
s)
pos.law_pos) pos.law_pos)
in in
(match oc with None -> () | Some oc -> close_in oc); (match oc with None -> () | Some oc -> close_in oc);
@ -150,20 +175,28 @@ let retrieve_loc_text (pos : t) : string =
cur_line >= sline cur_line >= sline
&& cur_line <= sline + (2 * (eline - sline)) && cur_line <= sline + (2 * (eline - sline))
&& cur_line mod 2 = sline mod 2 && cur_line mod 2 = sline mod 2
then Cli.with_style blue_style "%*d | " spaces (sline + ((cur_line - sline) / 2)) then
else if cur_line >= sline - include_extra_count && cur_line < sline then Cli.with_style blue_style "%*d | " spaces
Cli.with_style blue_style "%*d | " spaces cur_line (sline + ((cur_line - sline) / 2))
else if cur_line >= sline - include_extra_count && cur_line < sline
then Cli.with_style blue_style "%*d | " spaces cur_line
else if else if
cur_line <= sline + (2 * (eline - sline)) + 1 + include_extra_count cur_line
<= sline + (2 * (eline - sline)) + 1 + include_extra_count
&& cur_line > sline + (2 * (eline - sline)) + 1 && cur_line > sline + (2 * (eline - sline)) + 1
then Cli.with_style blue_style "%*d | " spaces (cur_line - (eline - sline + 1)) then
Cli.with_style blue_style "%*d | " spaces
(cur_line - (eline - sline + 1))
else Cli.with_style blue_style "%*s | " spaces "")) else Cli.with_style blue_style "%*s | " spaces ""))
(Cli.add_prefix_to_each_line (Cli.add_prefix_to_each_line
(Printf.sprintf "%s" (Printf.sprintf "%s"
(String.concat "\n" (String.concat "\n"
(List.map (fun l -> Cli.with_style blue_style "%s" l) legal_pos_lines))) (List.map
(fun l -> Cli.with_style blue_style "%s" l)
legal_pos_lines)))
(fun i -> (fun i ->
if i = 0 then Cli.with_style blue_style "%*s + " (spaces + (2 * i)) "" if i = 0 then
Cli.with_style blue_style "%*s + " (spaces + (2 * i)) ""
else Cli.with_style blue_style "%*s+-+ " (spaces + (2 * i) - 1) "")) else Cli.with_style blue_style "%*s+-+ " (spaces + (2 * i) - 1) ""))
with Sys_error _ -> "Location:" ^ to_string pos with Sys_error _ -> "Location:" ^ to_string pos
@ -171,18 +204,19 @@ type 'a marked = 'a * t
let no_pos : t = let no_pos : t =
let zero_pos = let zero_pos =
{ Lexing.pos_fname = ""; Lexing.pos_lnum = 0; Lexing.pos_cnum = 0; Lexing.pos_bol = 0 } {
Lexing.pos_fname = "";
Lexing.pos_lnum = 0;
Lexing.pos_cnum = 0;
Lexing.pos_bol = 0;
}
in in
{ code_pos = (zero_pos, zero_pos); law_pos = [] } { code_pos = (zero_pos, zero_pos); law_pos = [] }
let mark pos e : 'a marked = (e, pos) let mark pos e : 'a marked = (e, pos)
let unmark ((x, _) : 'a marked) : 'a = x let unmark ((x, _) : 'a marked) : 'a = x
let get_position ((_, x) : 'a marked) : t = x let get_position ((_, x) : 'a marked) : t = x
let map_under_mark (f : 'a -> 'b) ((x, y) : 'a marked) : 'b marked = (f x, y) let map_under_mark (f : 'a -> 'b) ((x, y) : 'a marked) : 'b marked = (f x, y)
let same_pos_as (x : 'a) ((_, y) : 'b marked) : 'a marked = (x, y) let same_pos_as (x : 'a) ((_, y) : 'b marked) : 'a marked = (x, y)
let unmark_option (x : 'a marked option) : 'a option = let unmark_option (x : 'a marked option) : 'a option =
@ -191,16 +225,23 @@ let unmark_option (x : 'a marked option) : 'a option =
class ['self] marked_map = class ['self] marked_map =
object (_self : 'self) object (_self : 'self)
constraint constraint
'self = < visit_marked : 'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked ; .. > 'self = < visit_marked :
'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked
; .. >
method visit_marked : 'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked = method visit_marked
: 'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked =
fun f env x -> same_pos_as (f env (unmark x)) x fun f env x -> same_pos_as (f env (unmark x)) x
end end
class ['self] marked_iter = class ['self] marked_iter =
object (_self : 'self) object (_self : 'self)
constraint 'self = < visit_marked : 'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit ; .. > constraint
'self = < visit_marked :
'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit
; .. >
method visit_marked : 'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit = method visit_marked : 'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit
=
fun f env x -> f env (unmark x) fun f env x -> f env (unmark x)
end end

View File

@ -1,42 +1,37 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Source code position *) (** Source code position *)
type t type t
(** A position in the source code is a file, as well as begin and end location of the form col:line *) (** A position in the source code is a file, as well as begin and end location
of the form col:line *)
(** Custom visitor for the [Pos.marked] type *) (** Custom visitor for the [Pos.marked] type *)
(**{2 Constructor and getters}*) (**{2 Constructor and getters}*)
val from_lpos : Lexing.position * Lexing.position -> t val from_lpos : Lexing.position * Lexing.position -> t
val from_info : string -> int -> int -> int -> int -> t val from_info : string -> int -> int -> int -> int -> t
val overwrite_law_info : t -> string list -> t val overwrite_law_info : t -> string list -> t
val get_law_info : t -> string list val get_law_info : t -> string list
val get_start_line : t -> int val get_start_line : t -> int
val get_start_column : t -> int val get_start_column : t -> int
val get_end_line : t -> int val get_end_line : t -> int
val get_end_column : t -> int val get_end_column : t -> int
val get_file : t -> string val get_file : t -> string
type input_file = FileName of string | Contents of string type input_file = FileName of string | Contents of string
@ -54,26 +49,23 @@ val to_string_short : t -> string
{v <file>;<start_line>:<start_col>--<end_line>:<end_col> v} *) {v <file>;<start_line>:<start_col>--<end_line>:<end_col> v} *)
val retrieve_loc_text : t -> string val retrieve_loc_text : t -> string
(** Open the file corresponding to the position and retrieves the text concerned by the position *) (** Open the file corresponding to the position and retrieves the text concerned
by the position *)
(**{2 AST markings}*) (**{2 AST markings}*)
type 'a marked = 'a * t type 'a marked = 'a * t
(** Everything related to the source code should keep its position stored, to improve error messages *) (** Everything related to the source code should keep its position stored, to
improve error messages *)
val no_pos : t val no_pos : t
(** Placeholder position *) (** Placeholder position *)
val mark : t -> 'a -> 'a marked val mark : t -> 'a -> 'a marked
val unmark : 'a marked -> 'a val unmark : 'a marked -> 'a
val get_position : 'a marked -> t val get_position : 'a marked -> t
val map_under_mark : ('a -> 'b) -> 'a marked -> 'b marked val map_under_mark : ('a -> 'b) -> 'a marked -> 'b marked
val same_pos_as : 'a -> 'b marked -> 'a marked val same_pos_as : 'a -> 'b marked -> 'a marked
val unmark_option : 'a marked option -> 'a option val unmark_option : 'a marked option -> 'a option
(** Visitors *) (** Visitors *)
@ -81,14 +73,20 @@ val unmark_option : 'a marked option -> 'a option
class ['self] marked_map : class ['self] marked_map :
object ('self) object ('self)
constraint constraint
'self = < visit_marked : 'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked ; .. > 'self = < visit_marked :
'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked
; .. >
method visit_marked : 'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked method visit_marked :
'a. ('env -> 'a -> 'a) -> 'env -> 'a marked -> 'a marked
end end
class ['self] marked_iter : class ['self] marked_iter :
object ('self) object ('self)
constraint 'self = < visit_marked : 'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit ; .. > constraint
'self = < visit_marked :
'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit
; .. >
method visit_marked : 'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit method visit_marked : 'a. ('env -> 'a -> unit) -> 'env -> 'a marked -> unit
end end

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module type Info = sig module type Info = sig
@ -20,23 +22,17 @@ end
module type Id = sig module type Id = sig
type t type t
type info type info
val fresh : info -> t val fresh : info -> t
val get_info : t -> info val get_info : t -> info
val compare : t -> t -> int val compare : t -> t -> int
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val hash : t -> int val hash : t -> int
end end
module Make (X : Info) () : Id with type info = X.info = struct module Make (X : Info) () : Id with type info = X.info = struct
type t = { id : int; info : X.info } type t = { id : int; info : X.info }
type info = X.info type info = X.info
let counter = ref 0 let counter = ref 0
@ -46,7 +42,6 @@ module Make (X : Info) () : Id with type info = X.info = struct
{ id = !counter; info } { id = !counter; info }
let get_info (uid : t) : X.info = uid.info let get_info (uid : t) : X.info = uid.info
let compare (x : t) (y : t) : int = compare x.id y.id let compare (x : t) (y : t) : int = compare x.id y.id
let format_t (fmt : Format.formatter) (x : t) : unit = let format_t (fmt : Format.formatter) (x : t) : unit =

View File

@ -1,15 +1,17 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
<denis.merigoux@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Global identifiers factories using a generative functor *) (** Global identifiers factories using a generative functor *)
@ -22,28 +24,25 @@ module type Info = sig
end end
module MarkedString : Info with type info = string Pos.marked module MarkedString : Info with type info = string Pos.marked
(** The only kind of information carried in Catala identifiers is the original string of the (** The only kind of information carried in Catala identifiers is the original
identifier annotated with the position where it is declared or used. *) string of the identifier annotated with the position where it is declared or
used. *)
(** Identifiers have abstract types, but are comparable so they can be used as keys in maps or sets. (** Identifiers have abstract types, but are comparable so they can be used as
Their underlying information can be retrieved at any time. *) keys in maps or sets. Their underlying information can be retrieved at any
time. *)
module type Id = sig module type Id = sig
type t type t
type info type info
val fresh : info -> t val fresh : info -> t
val get_info : t -> info val get_info : t -> info
val compare : t -> t -> int val compare : t -> t -> int
val format_t : Format.formatter -> t -> unit val format_t : Format.formatter -> t -> unit
val hash : t -> int val hash : t -> int
end end
(** This is the generative functor that ensures that two modules resulting from two different calls (** This is the generative functor that ensures that two modules resulting from
to [Make] will be viewed as different types [t] by the OCaml typechecker. Prevents mixing up two different calls to [Make] will be viewed as different types [t] by the
different sorts of identifiers. *) OCaml typechecker. Prevents mixing up different sorts of identifiers. *)
module Make (X : Info) () : Id with type info = X.info module Make (X : Info) () : Id with type info = X.info

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<denis.merigoux@inria.fr>, Alain Delaët <alain.delaet--tixeuil@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët
<alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -19,19 +22,22 @@ open Ast
(** {1 Helpers and type definitions}*) (** {1 Helpers and type definitions}*)
type vc_return = expr Pos.marked * typ Pos.marked VarMap.t type vc_return = expr Pos.marked * typ Pos.marked VarMap.t
(** The return type of VC generators is the VC expression plus the types of any locally free (** The return type of VC generators is the VC expression plus the types of any
variable inside that expression. *) locally free variable inside that expression. *)
type ctx = { decl : decl_ctx; input_vars : Var.t list } type ctx = { decl : decl_ctx; input_vars : Var.t list }
let conjunction (args : vc_return list) (pos : Pos.t) : vc_return = let conjunction (args : vc_return list) (pos : Pos.t) : vc_return =
let acc, list = let acc, list =
match args with hd :: tl -> (hd, tl) | [] -> (((ELit (LBool true), pos), VarMap.empty), []) match args with
| hd :: tl -> (hd, tl)
| [] -> (((ELit (LBool true), pos), VarMap.empty), [])
in in
List.fold_left List.fold_left
(fun (acc, acc_ty) (arg, arg_ty) -> (fun (acc, acc_ty) (arg, arg_ty) ->
( (EApp ((EOp (Binop And), pos), [ arg; acc ]), pos), ( (EApp ((EOp (Binop And), pos), [ arg; acc ]), pos),
VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty )) VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty
))
acc list acc list
let negation ((arg, arg_ty) : vc_return) (pos : Pos.t) : vc_return = let negation ((arg, arg_ty) : vc_return) (pos : Pos.t) : vc_return =
@ -39,26 +45,31 @@ let negation ((arg, arg_ty) : vc_return) (pos : Pos.t) : vc_return =
let disjunction (args : vc_return list) (pos : Pos.t) : vc_return = let disjunction (args : vc_return list) (pos : Pos.t) : vc_return =
let acc, list = let acc, list =
match args with hd :: tl -> (hd, tl) | [] -> (((ELit (LBool false), pos), VarMap.empty), []) match args with
| hd :: tl -> (hd, tl)
| [] -> (((ELit (LBool false), pos), VarMap.empty), [])
in in
List.fold_left List.fold_left
(fun ((acc, acc_ty) : vc_return) (arg, arg_ty) -> (fun ((acc, acc_ty) : vc_return) (arg, arg_ty) ->
( (EApp ((EOp (Binop Or), pos), [ arg; acc ]), pos), ( (EApp ((EOp (Binop Or), pos), [ arg; acc ]), pos),
VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty )) VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty
))
acc list acc list
(** [half_product \[a1,...,an\] \[b1,...,bm\] returns \[(a1,b1),...(a1,bn),...(an,b1),...(an,bm)\]] *) (** [half_product \[a1,...,an\] \[b1,...,bm\] returns \[(a1,b1),...(a1,bn),...(an,b1),...(an,bm)\]] *)
let half_product (l1 : 'a list) (l2 : 'b list) : ('a * 'b) list = let half_product (l1 : 'a list) (l2 : 'b list) : ('a * 'b) list =
l1 l1
|> List.mapi (fun i ei -> List.filteri (fun j _ -> i < j) l2 |> List.map (fun ej -> (ei, ej))) |> List.mapi (fun i ei ->
List.filteri (fun j _ -> i < j) l2 |> List.map (fun ej -> (ei, ej)))
|> List.concat |> List.concat
(** This code skims through the topmost layers of the terms like this: (** This code skims through the topmost layers of the terms like this:
[log (error_on_empty < reentrant_variable () | true :- e1 >)] for scope variables, or [log (error_on_empty < reentrant_variable () | true :- e1 >)] for scope
[fun () -> e1] for subscope variables. But what we really want to analyze is only [e1], so we variables, or [fun () -> e1] for subscope variables. But what we really want
match this outermost structure explicitely and have a clean verification condition generator to analyze is only [e1], so we match this outermost structure explicitely
that only runs on [e1] *) and have a clean verification condition generator that only runs on [e1] *)
let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : expr Pos.marked) : expr Pos.marked = let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : expr Pos.marked) :
expr Pos.marked =
match Pos.unmark e with match Pos.unmark e with
| EApp | EApp
( (EOp (Unop (Log _)), _), ( (EOp (Unop (Log _)), _),
@ -81,8 +92,8 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : expr Pos.marked) :
| EApp ((EOp (Unop (Log _)), _), [ arg ]) -> arg | EApp ((EOp (Unop (Log _)), _), [ arg ]) -> arg
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Internal error: this expression does not have the structure expected by the VC \ "Internal error: this expression does not have the structure \
generator:\n\ expected by the VC generator:\n\
%a" %a"
(Print.format_expr ~debug:true ctx.decl) (Print.format_expr ~debug:true ctx.decl)
e) e)
@ -91,33 +102,43 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : expr Pos.marked) :
d (* input subscope variables and non-input scope variable *) d (* input subscope variables and non-input scope variable *)
| _ -> | _ ->
Errors.raise_spanned_error (Pos.get_position e) Errors.raise_spanned_error (Pos.get_position e)
"Internal error: this expression does not have the structure expected by the VC generator:\n\ "Internal error: this expression does not have the structure expected \
by the VC generator:\n\
%a" %a"
(Print.format_expr ~debug:true ctx.decl) (Print.format_expr ~debug:true ctx.decl)
e e
(** {1 Verification conditions generator}*) (** {1 Verification conditions generator}*)
(** [generate_vc_must_not_return_empty e] returns the dcalc boolean expression [b] such that if [b] (** [generate_vc_must_not_return_empty e] returns the dcalc boolean expression
is true, then [e] will never return an empty error. It also returns a map of all the types of [b] such that if [b] is true, then [e] will never return an empty error. It
locally free variables inside the expression. *) also returns a map of all the types of locally free variables inside the
let rec generate_vc_must_not_return_empty (ctx : ctx) (e : expr Pos.marked) : vc_return = expression. *)
let rec generate_vc_must_not_return_empty (ctx : ctx) (e : expr Pos.marked) :
vc_return =
let out = let out =
match Pos.unmark e with match Pos.unmark e with
| ETuple (args, _) | EArray args -> | ETuple (args, _) | EArray args ->
conjunction (List.map (generate_vc_must_not_return_empty ctx) args) (Pos.get_position e) conjunction
(List.map (generate_vc_must_not_return_empty ctx) args)
(Pos.get_position e)
| EMatch (arg, arms, _) -> | EMatch (arg, arms, _) ->
conjunction conjunction
(List.map (generate_vc_must_not_return_empty ctx) (arg :: arms)) (List.map (generate_vc_must_not_return_empty ctx) (arg :: arms))
(Pos.get_position e) (Pos.get_position e)
| ETupleAccess (e1, _, _, _) | EInj (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 -> | ETupleAccess (e1, _, _, _)
| EInj (e1, _, _, _)
| EAssert e1
| ErrorOnEmpty e1 ->
(generate_vc_must_not_return_empty ctx) e1 (generate_vc_must_not_return_empty ctx) e1
| EAbs (binder, typs) -> | EAbs (binder, typs) ->
(* Hot take: for a function never to return an empty error when called, it has to do (* Hot take: for a function never to return an empty error when called, it has to do
so whatever its input. So we universally quantify over the variable of the function so whatever its input. So we universally quantify over the variable of the function
when inspecting the body, resulting in simply traversing through in the code here. *) when inspecting the body, resulting in simply traversing through in the code here. *)
let vars, body = Bindlib.unmbind (Pos.unmark binder) in let vars, body = Bindlib.unmbind (Pos.unmark binder) in
let vc_body_expr, vc_body_ty = (generate_vc_must_not_return_empty ctx) body in let vc_body_expr, vc_body_ty =
(generate_vc_must_not_return_empty ctx) body
in
( vc_body_expr, ( vc_body_expr,
List.fold_left List.fold_left
(fun acc (var, ty) -> VarMap.add var ty acc) (fun acc (var, ty) -> VarMap.add var ty acc)
@ -137,7 +158,9 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : expr Pos.marked) : vc
[ [
(e1_vc, vc_typ1); (e1_vc, vc_typ1);
( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e), ( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e),
VarMap.union (fun _ _ _ -> failwith "should not happen") vc_typ2 vc_typ3 ); VarMap.union
(fun _ _ _ -> failwith "should not happen")
vc_typ2 vc_typ3 );
] ]
(Pos.get_position e) (Pos.get_position e)
| ELit LEmptyError -> (Pos.same_pos_as (ELit (LBool false)) e, VarMap.empty) | ELit LEmptyError -> (Pos.same_pos_as (ELit (LBool false)) e, VarMap.empty)
@ -157,7 +180,9 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : expr Pos.marked) : vc
conjunction conjunction
[ [
generate_vc_must_not_return_empty ctx just; generate_vc_must_not_return_empty ctx just;
(let vc_just_expr, vc_just_ty = generate_vc_must_not_return_empty ctx cons in (let vc_just_expr, vc_just_ty =
generate_vc_must_not_return_empty ctx cons
in
( ( EIfThenElse ( ( EIfThenElse
( just, ( just,
(* Comment from Alain: the justification is not checked for holding an default term. (* Comment from Alain: the justification is not checked for holding an default term.
@ -178,25 +203,34 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : expr Pos.marked) : vc
out out
[@@ocamlformat "wrap-comments=false"] [@@ocamlformat "wrap-comments=false"]
(** [generate_vs_must_not_return_confict e] returns the dcalc boolean expression [b] such that if (** [generate_vs_must_not_return_confict e] returns the dcalc boolean expression
[b] is true, then [e] will never return a conflict error. It also returns a map of all the types [b] such that if [b] is true, then [e] will never return a conflict error.
of locally free variables inside the expression. *) It also returns a map of all the types of locally free variables inside the
let rec generate_vs_must_not_return_confict (ctx : ctx) (e : expr Pos.marked) : vc_return = expression. *)
let rec generate_vs_must_not_return_confict (ctx : ctx) (e : expr Pos.marked) :
vc_return =
let out = let out =
(* See the code of [generate_vc_must_not_return_empty] for a list of invariants on which this (* See the code of [generate_vc_must_not_return_empty] for a list of invariants on which this
function relies on. *) function relies on. *)
match Pos.unmark e with match Pos.unmark e with
| ETuple (args, _) | EArray args -> | ETuple (args, _) | EArray args ->
conjunction (List.map (generate_vs_must_not_return_confict ctx) args) (Pos.get_position e) conjunction
(List.map (generate_vs_must_not_return_confict ctx) args)
(Pos.get_position e)
| EMatch (arg, arms, _) -> | EMatch (arg, arms, _) ->
conjunction conjunction
(List.map (generate_vs_must_not_return_confict ctx) (arg :: arms)) (List.map (generate_vs_must_not_return_confict ctx) (arg :: arms))
(Pos.get_position e) (Pos.get_position e)
| ETupleAccess (e1, _, _, _) | EInj (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 -> | ETupleAccess (e1, _, _, _)
| EInj (e1, _, _, _)
| EAssert e1
| ErrorOnEmpty e1 ->
generate_vs_must_not_return_confict ctx e1 generate_vs_must_not_return_confict ctx e1
| EAbs (binder, typs) -> | EAbs (binder, typs) ->
let vars, body = Bindlib.unmbind (Pos.unmark binder) in let vars, body = Bindlib.unmbind (Pos.unmark binder) in
let vc_body_expr, vc_body_ty = (generate_vs_must_not_return_confict ctx) body in let vc_body_expr, vc_body_ty =
(generate_vs_must_not_return_confict ctx) body
in
( vc_body_expr, ( vc_body_expr,
List.fold_left List.fold_left
(fun acc (var, ty) -> VarMap.add var ty acc) (fun acc (var, ty) -> VarMap.add var ty acc)
@ -214,10 +248,13 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : expr Pos.marked) :
[ [
(e1_vc, vc_typ1); (e1_vc, vc_typ1);
( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e), ( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e),
VarMap.union (fun _ _ _ -> failwith "should not happen") vc_typ2 vc_typ3 ); VarMap.union
(fun _ _ _ -> failwith "should not happen")
vc_typ2 vc_typ3 );
] ]
(Pos.get_position e) (Pos.get_position e)
| EVar _ | ELit _ | EOp _ -> (Pos.same_pos_as (ELit (LBool true)) e, VarMap.empty) | EVar _ | ELit _ | EOp _ ->
(Pos.same_pos_as (ELit (LBool true)) e, VarMap.empty)
| EDefault (exceptions, just, cons) -> | EDefault (exceptions, just, cons) ->
(* <e1 ... en | ejust :- econs > never returns conflict if and only if: (* <e1 ... en | ejust :- econs > never returns conflict if and only if:
- neither e1 nor ... nor en nor ejust nor econs return conflict - neither e1 nor ... nor en nor ejust nor econs return conflict
@ -238,7 +275,9 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : expr Pos.marked) :
(Pos.get_position e) (Pos.get_position e)
in in
let others = let others =
List.map (generate_vs_must_not_return_confict ctx) (just :: cons :: exceptions) List.map
(generate_vs_must_not_return_confict ctx)
(just :: cons :: exceptions)
in in
let out = conjunction (quadratic :: others) (Pos.get_position e) in let out = conjunction (quadratic :: others) (Pos.get_position e) in
out out
@ -259,7 +298,8 @@ type verification_condition = {
vc_free_vars_typ : typ Pos.marked VarMap.t; vc_free_vars_typ : typ Pos.marked VarMap.t;
} }
let generate_verification_conditions (p : program) : verification_condition list = let generate_verification_conditions (p : program) : verification_condition list
=
List.fold_left List.fold_left
(fun acc (s_name, _s_var, s_body) -> (fun acc (s_name, _s_var, s_body) ->
let ctx = { decl = p.decl_ctx; input_vars = [] } in let ctx = { decl = p.decl_ctx; input_vars = [] } in
@ -268,19 +308,29 @@ let generate_verification_conditions (p : program) : verification_condition list
(fun (acc, ctx) s_let -> (fun (acc, ctx) s_let ->
match s_let.scope_let_kind with match s_let.scope_let_kind with
| DestructuringInputStruct -> | DestructuringInputStruct ->
(acc, { ctx with input_vars = Pos.unmark s_let.scope_let_var :: ctx.input_vars }) ( acc,
{
ctx with
input_vars =
Pos.unmark s_let.scope_let_var :: ctx.input_vars;
} )
| ScopeVarDefinition | SubScopeVarDefinition -> | ScopeVarDefinition | SubScopeVarDefinition ->
(* For scope variables, we should check both that they never evaluate to emptyError (* For scope variables, we should check both that they never
nor conflictError. But for subscope variable definitions, what we're really doing evaluate to emptyError nor conflictError. But for subscope
is adding exceptions to something defined in the subscope so we just ought to variable definitions, what we're really doing is adding
verify only that the exceptions overlap. *) exceptions to something defined in the subscope so we just
ought to verify only that the exceptions overlap. *)
let e = let e =
match_and_ignore_outer_reentrant_default ctx (Bindlib.unbox s_let.scope_let_expr) match_and_ignore_outer_reentrant_default ctx
(Bindlib.unbox s_let.scope_let_expr)
in
let vc_confl, vc_confl_typs =
generate_vs_must_not_return_confict ctx e
in in
let vc_confl, vc_confl_typs = generate_vs_must_not_return_confict ctx e in
let vc_confl = let vc_confl =
if !Cli.optimize_flag then if !Cli.optimize_flag then
Bindlib.unbox (Optimizations.optimize_expr p.decl_ctx vc_confl) Bindlib.unbox
(Optimizations.optimize_expr p.decl_ctx vc_confl)
else vc_confl else vc_confl
in in
let vc_list = let vc_list =
@ -297,10 +347,13 @@ let generate_verification_conditions (p : program) : verification_condition list
let vc_list = let vc_list =
match s_let.scope_let_kind with match s_let.scope_let_kind with
| ScopeVarDefinition -> | ScopeVarDefinition ->
let vc_empty, vc_empty_typs = generate_vc_must_not_return_empty ctx e in let vc_empty, vc_empty_typs =
generate_vc_must_not_return_empty ctx e
in
let vc_empty = let vc_empty =
if !Cli.optimize_flag then if !Cli.optimize_flag then
Bindlib.unbox (Optimizations.optimize_expr p.decl_ctx vc_empty) Bindlib.unbox
(Optimizations.optimize_expr p.decl_ctx vc_empty)
else vc_empty else vc_empty
in in
{ {

View File

@ -1,33 +1,41 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Denis Merigoux and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<denis.merigoux@inria.fr>, Alain Delaët <alain.delaet--tixeuil@inria.fr> Denis Merigoux <denis.merigoux@inria.fr>, Alain Delaët
<alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Generates verification conditions from scope definitions *) (** Generates verification conditions from scope definitions *)
type verification_condition_kind = type verification_condition_kind =
| NoEmptyError | NoEmptyError
(** This verification condition checks whether a definition never returns an empty error *) (** This verification condition checks whether a definition never returns
an empty error *)
| NoOverlappingExceptions | NoOverlappingExceptions
(** This verification condition checks whether a definition never returns a conflict error *) (** This verification condition checks whether a definition never returns
a conflict error *)
type verification_condition = { type verification_condition = {
vc_guard : Dcalc.Ast.expr Utils.Pos.marked; (** This expression should have type [bool]*) vc_guard : Dcalc.Ast.expr Utils.Pos.marked;
(** This expression should have type [bool]*)
vc_kind : verification_condition_kind; vc_kind : verification_condition_kind;
vc_scope : Dcalc.Ast.ScopeName.t; vc_scope : Dcalc.Ast.ScopeName.t;
vc_variable : Dcalc.Ast.Var.t Utils.Pos.marked; vc_variable : Dcalc.Ast.Var.t Utils.Pos.marked;
vc_free_vars_typ : Dcalc.Ast.typ Utils.Pos.marked Dcalc.Ast.VarMap.t; vc_free_vars_typ : Dcalc.Ast.typ Utils.Pos.marked Dcalc.Ast.VarMap.t;
(** Types of the locally free variables in [vc_guard]. The types of other free variables (** Types of the locally free variables in [vc_guard]. The types of other
linked to scope variables can be obtained with [Dcalc.Ast.variable_types]. *) free variables linked to scope variables can be obtained with
[Dcalc.Ast.variable_types]. *)
} }
val generate_verification_conditions : Dcalc.Ast.program -> verification_condition list val generate_verification_conditions :
Dcalc.Ast.program -> verification_condition list

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Aymeric Fromherz and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<aymeric.fromherz@inria.fr>, Denis Merigoux <denis.merigoux@inria.fr> Aymeric Fromherz <aymeric.fromherz@inria.fr>, Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
@ -27,17 +30,16 @@ module type Backend = sig
val print_encoding : vc_encoding -> string val print_encoding : vc_encoding -> string
type model type model
type solver_result = ProvenTrue | ProvenFalse of model option | Unknown type solver_result = ProvenTrue | ProvenFalse of model option | Unknown
val solve_vc_encoding : backend_context -> vc_encoding -> solver_result val solve_vc_encoding : backend_context -> vc_encoding -> solver_result
val print_model : backend_context -> model -> string val print_model : backend_context -> model -> string
val is_model_empty : model -> bool val is_model_empty : model -> bool
val translate_expr : val translate_expr :
backend_context -> Dcalc.Ast.expr Utils.Pos.marked -> backend_context * vc_encoding backend_context ->
Dcalc.Ast.expr Utils.Pos.marked ->
backend_context * vc_encoding
end end
module type BackendIO = sig module type BackendIO = sig
@ -50,19 +52,28 @@ module type BackendIO = sig
type vc_encoding type vc_encoding
val translate_expr : val translate_expr :
backend_context -> Dcalc.Ast.expr Utils.Pos.marked -> backend_context * vc_encoding backend_context ->
Dcalc.Ast.expr Utils.Pos.marked ->
backend_context * vc_encoding
type model type model
type vc_encoding_result = Success of vc_encoding * backend_context | Fail of string type vc_encoding_result =
| Success of vc_encoding * backend_context
| Fail of string
val print_positive_result : Conditions.verification_condition -> string val print_positive_result : Conditions.verification_condition -> string
val print_negative_result : val print_negative_result :
Conditions.verification_condition -> backend_context -> model option -> string Conditions.verification_condition ->
backend_context ->
model option ->
string
val encode_and_check_vc : val encode_and_check_vc :
Dcalc.Ast.decl_ctx -> Conditions.verification_condition * vc_encoding_result -> unit Dcalc.Ast.decl_ctx ->
Conditions.verification_condition * vc_encoding_result ->
unit
end end
module MakeBackendIO (B : Backend) = struct module MakeBackendIO (B : Backend) = struct
@ -78,7 +89,9 @@ module MakeBackendIO (B : Backend) = struct
type model = B.model type model = B.model
type vc_encoding_result = Success of B.vc_encoding * B.backend_context | Fail of string type vc_encoding_result =
| Success of B.vc_encoding * B.backend_context
| Fail of string
let print_positive_result (vc : Conditions.verification_condition) : string = let print_positive_result (vc : Conditions.verification_condition) : string =
match vc.Conditions.vc_kind with match vc.Conditions.vc_kind with
@ -93,7 +106,9 @@ module MakeBackendIO (B : Backend) = struct
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope) (Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Pos.unmark vc.vc_variable))) (Bindlib.name_of (Pos.unmark vc.vc_variable)))
let print_negative_result (vc : Conditions.verification_condition) (ctx : B.backend_context) let print_negative_result
(vc : Conditions.verification_condition)
(ctx : B.backend_context)
(model : B.model option) : string = (model : B.model option) : string =
let var_and_pos = let var_and_pos =
match vc.Conditions.vc_kind with match vc.Conditions.vc_kind with
@ -104,7 +119,8 @@ module MakeBackendIO (B : Backend) = struct
(Bindlib.name_of (Pos.unmark vc.vc_variable))) (Bindlib.name_of (Pos.unmark vc.vc_variable)))
(Pos.retrieve_loc_text (Pos.get_position vc.vc_variable)) (Pos.retrieve_loc_text (Pos.get_position vc.vc_variable))
| Conditions.NoOverlappingExceptions -> | Conditions.NoOverlappingExceptions ->
Format.asprintf "%s At least two exceptions overlap for this variable:\n%s" Format.asprintf
"%s At least two exceptions overlap for this variable:\n%s"
(Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]" (Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]"
(Format.asprintf "%a" ScopeName.format_t vc.vc_scope) (Format.asprintf "%a" ScopeName.format_t vc.vc_scope)
(Bindlib.name_of (Pos.unmark vc.vc_variable))) (Bindlib.name_of (Pos.unmark vc.vc_variable)))
@ -117,23 +133,28 @@ module MakeBackendIO (B : Backend) = struct
match model with match model with
| None -> | None ->
Some Some
"The solver did not manage to generate a counterexample to explain the faulty \ "The solver did not manage to generate a counterexample to \
behavior." explain the faulty behavior."
| Some model -> | Some model ->
if B.is_model_empty model then None if B.is_model_empty model then None
else else
Some Some
(Format.asprintf (Format.asprintf
"The solver generated the following counterexample to explain the faulty \ "The solver generated the following counterexample to \
behavior:\n\ explain the faulty behavior:\n\
%s" %s"
(B.print_model ctx model)) (B.print_model ctx model))
in in
var_and_pos var_and_pos
^ match counterexample with None -> "" | Some counterexample -> "\n" ^ counterexample ^
match counterexample with
| None -> ""
| Some counterexample -> "\n" ^ counterexample
(** [encode_and_check_vc] spawns a new Z3 solver and tries to solve the expression [vc] **) (** [encode_and_check_vc] spawns a new Z3 solver and tries to solve the
let encode_and_check_vc (decl_ctx : decl_ctx) expression [vc] **)
let encode_and_check_vc
(decl_ctx : decl_ctx)
(vc : Conditions.verification_condition * vc_encoding_result) : unit = (vc : Conditions.verification_condition * vc_encoding_result) : unit =
let vc, z3_vc = vc in let vc, z3_vc = vc in
@ -142,17 +163,21 @@ module MakeBackendIO (B : Backend) = struct
Cli.debug_format "This verification condition was generated for %a:@\n%a" Cli.debug_format "This verification condition was generated for %a:@\n%a"
(Cli.format_with_style [ ANSITerminal.yellow ]) (Cli.format_with_style [ ANSITerminal.yellow ])
(match vc.vc_kind with (match vc.vc_kind with
| Conditions.NoEmptyError -> "the variable definition never to return an empty error" | Conditions.NoEmptyError ->
"the variable definition never to return an empty error"
| NoOverlappingExceptions -> "no two exceptions to ever overlap") | NoOverlappingExceptions -> "no two exceptions to ever overlap")
(Dcalc.Print.format_expr decl_ctx) (Dcalc.Print.format_expr decl_ctx)
vc.vc_guard; vc.vc_guard;
match z3_vc with match z3_vc with
| Success (encoding, backend_ctx) -> ( | Success (encoding, backend_ctx) -> (
Cli.debug_print "The translation to Z3 is the following:@\n%s" (B.print_encoding encoding); Cli.debug_print "The translation to Z3 is the following:@\n%s"
(B.print_encoding encoding);
match B.solve_vc_encoding backend_ctx encoding with match B.solve_vc_encoding backend_ctx encoding with
| ProvenTrue -> Cli.result_print "%s" (print_positive_result vc) | ProvenTrue -> Cli.result_print "%s" (print_positive_result vc)
| ProvenFalse model -> Cli.error_print "%s" (print_negative_result vc backend_ctx model) | ProvenFalse model ->
| Unknown -> failwith "The solver failed at proving or disproving the VC") Cli.error_print "%s" (print_negative_result vc backend_ctx model)
| Unknown ->
failwith "The solver failed at proving or disproving the VC")
| Fail msg -> Cli.error_print "The translation to Z3 failed:@\n%s" msg | Fail msg -> Cli.error_print "The translation to Z3 failed:@\n%s" msg
end end

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Aymeric Fromherz and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<aymeric.fromherz@inria.fr>, Denis Merigoux <denis.merigoux@inria.fr> Aymeric Fromherz <aymeric.fromherz@inria.fr>, Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Common code for handling the IO of all proof backends supported *) (** Common code for handling the IO of all proof backends supported *)
@ -20,24 +23,25 @@ module type Backend = sig
type backend_context type backend_context
val make_context : val make_context :
Dcalc.Ast.decl_ctx -> Dcalc.Ast.typ Utils.Pos.marked Dcalc.Ast.VarMap.t -> backend_context Dcalc.Ast.decl_ctx ->
Dcalc.Ast.typ Utils.Pos.marked Dcalc.Ast.VarMap.t ->
backend_context
type vc_encoding type vc_encoding
val print_encoding : vc_encoding -> string val print_encoding : vc_encoding -> string
type model type model
type solver_result = ProvenTrue | ProvenFalse of model option | Unknown type solver_result = ProvenTrue | ProvenFalse of model option | Unknown
val solve_vc_encoding : backend_context -> vc_encoding -> solver_result val solve_vc_encoding : backend_context -> vc_encoding -> solver_result
val print_model : backend_context -> model -> string val print_model : backend_context -> model -> string
val is_model_empty : model -> bool val is_model_empty : model -> bool
val translate_expr : val translate_expr :
backend_context -> Dcalc.Ast.expr Utils.Pos.marked -> backend_context * vc_encoding backend_context ->
Dcalc.Ast.expr Utils.Pos.marked ->
backend_context * vc_encoding
end end
module type BackendIO = sig module type BackendIO = sig
@ -46,24 +50,35 @@ module type BackendIO = sig
type backend_context type backend_context
val make_context : val make_context :
Dcalc.Ast.decl_ctx -> Dcalc.Ast.typ Utils.Pos.marked Dcalc.Ast.VarMap.t -> backend_context Dcalc.Ast.decl_ctx ->
Dcalc.Ast.typ Utils.Pos.marked Dcalc.Ast.VarMap.t ->
backend_context
type vc_encoding type vc_encoding
val translate_expr : val translate_expr :
backend_context -> Dcalc.Ast.expr Utils.Pos.marked -> backend_context * vc_encoding backend_context ->
Dcalc.Ast.expr Utils.Pos.marked ->
backend_context * vc_encoding
type model type model
type vc_encoding_result = Success of vc_encoding * backend_context | Fail of string type vc_encoding_result =
| Success of vc_encoding * backend_context
| Fail of string
val print_positive_result : Conditions.verification_condition -> string val print_positive_result : Conditions.verification_condition -> string
val print_negative_result : val print_negative_result :
Conditions.verification_condition -> backend_context -> model option -> string Conditions.verification_condition ->
backend_context ->
model option ->
string
val encode_and_check_vc : val encode_and_check_vc :
Dcalc.Ast.decl_ctx -> Conditions.verification_condition * vc_encoding_result -> unit Dcalc.Ast.decl_ctx ->
Conditions.verification_condition * vc_encoding_result ->
unit
end end
module MakeBackendIO : functor (B : Backend) -> module MakeBackendIO : functor (B : Backend) ->

View File

@ -1,26 +1,30 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Aymeric Fromherz and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<aymeric.fromherz@inria.fr> Aymeric Fromherz <aymeric.fromherz@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Dcalc.Ast open Dcalc.Ast
(** [solve_vc] is the main entry point of this module. It takes a list of expressions [vcs] (** [solve_vc] is the main entry point of this module. It takes a list of
corresponding to verification conditions that must be discharged by Z3, and attempts to solve expressions [vcs] corresponding to verification conditions that must be
them **) discharged by Z3, and attempts to solve them **)
let solve_vc (prgm : program) (decl_ctx : decl_ctx) (vcs : Conditions.verification_condition list) : let solve_vc
unit = (prgm : program)
(* Right now we only use the Z3 backend but the functorial interface should make it easy to mix (decl_ctx : decl_ctx)
and match different proof backends. *) (vcs : Conditions.verification_condition list) : unit =
(* Right now we only use the Z3 backend but the functorial interface should
make it easy to mix and match different proof backends. *)
Z3backend.Io.init_backend (); Z3backend.Io.init_backend ();
let z3_vcs = let z3_vcs =
List.map List.map
@ -32,9 +36,12 @@ let solve_vc (prgm : program) (decl_ctx : decl_ctx) (vcs : Conditions.verificati
(Z3backend.Io.make_context decl_ctx (Z3backend.Io.make_context decl_ctx
(VarMap.union (VarMap.union
(fun _ _ _ -> (fun _ _ _ ->
failwith "[Proof encoding]: A Variable cannot be both free and bound") failwith
"[Proof encoding]: A Variable cannot be both free \
and bound")
(variable_types prgm) vc.Conditions.vc_free_vars_typ)) (variable_types prgm) vc.Conditions.vc_free_vars_typ))
(Bindlib.unbox (Dcalc.Optimizations.remove_all_logs vc.Conditions.vc_guard)) (Bindlib.unbox
(Dcalc.Optimizations.remove_all_logs vc.Conditions.vc_guard))
in in
Z3backend.Io.Success (z3_vc, ctx) Z3backend.Io.Success (z3_vc, ctx)
with Failure msg -> Fail msg )) with Failure msg -> Fail msg ))

View File

@ -1,18 +1,23 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Aymeric Fromherz and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<aymeric.fromherz@inria.fr> Aymeric Fromherz <aymeric.fromherz@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Solves verification conditions using various proof backends *) (** Solves verification conditions using various proof backends *)
val solve_vc : val solve_vc :
Dcalc.Ast.program -> Dcalc.Ast.decl_ctx -> Conditions.verification_condition list -> unit Dcalc.Ast.program ->
Dcalc.Ast.decl_ctx ->
Conditions.verification_condition list ->
unit

View File

@ -1,135 +1,154 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Aymeric Fromherz and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<aymeric.fromherz@inria.fr> Aymeric Fromherz <aymeric.fromherz@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Utils open Utils
open Dcalc open Dcalc
open Ast open Ast
open Z3 open Z3
module StringMap : Map.S with type key = String.t = Map.Make (String) module StringMap : Map.S with type key = String.t = Map.Make (String)
type context = { type context = {
ctx_z3 : Z3.context; ctx_z3 : Z3.context;
(* The Z3 context, used to create symbols and expressions *) (* The Z3 context, used to create symbols and expressions *)
ctx_decl : decl_ctx; ctx_decl : decl_ctx;
(* The declaration context from the Catala program, containing information to precisely pretty (* The declaration context from the Catala program, containing information to
print Catala expressions *) precisely pretty print Catala expressions *)
ctx_var : typ Pos.marked VarMap.t; ctx_var : typ Pos.marked VarMap.t;
(* A map from Catala variables to their types, needed to create Z3 expressions of the right (* A map from Catala variables to their types, needed to create Z3 expressions
sort *) of the right sort *)
ctx_funcdecl : FuncDecl.func_decl VarMap.t; ctx_funcdecl : FuncDecl.func_decl VarMap.t;
(* A map from Catala function names (represented as variables) to Z3 function declarations, used (* A map from Catala function names (represented as variables) to Z3 function
to only define once functions in Z3 queries *) declarations, used to only define once functions in Z3 queries *)
ctx_z3vars : Var.t StringMap.t; ctx_z3vars : Var.t StringMap.t;
(* A map from strings, corresponding to Z3 symbol names, to the Catala variable they represent. (* A map from strings, corresponding to Z3 symbol names, to the Catala
Used when to pretty-print Z3 models when a counterexample is generated *) variable they represent. Used when to pretty-print Z3 models when a
counterexample is generated *)
ctx_z3datatypes : Sort.sort EnumMap.t; ctx_z3datatypes : Sort.sort EnumMap.t;
(* A map from Catala enumeration names to the corresponding Z3 sort, from which we can retrieve (* A map from Catala enumeration names to the corresponding Z3 sort, from
constructors and accessors *) which we can retrieve constructors and accessors *)
ctx_z3matchsubsts : Expr.expr VarMap.t; ctx_z3matchsubsts : Expr.expr VarMap.t;
(* A map from Catala temporary variables, generated when translating a match, to the corresponding (* A map from Catala temporary variables, generated when translating a match,
enum accessor call as a Z3 expression *) to the corresponding enum accessor call as a Z3 expression *)
ctx_z3structs : Sort.sort StructMap.t; ctx_z3structs : Sort.sort StructMap.t;
(* A map from Catala struct names to the corresponding Z3 sort, from which we can retrieve the (* A map from Catala struct names to the corresponding Z3 sort, from which we
constructor and the accessors *) can retrieve the constructor and the accessors *)
ctx_z3unit : Sort.sort * Expr.expr; ctx_z3unit : Sort.sort * Expr.expr;
(* A pair containing the Z3 encodings of the unit type, encoded as a tuple of 0 elements, and (* A pair containing the Z3 encodings of the unit type, encoded as a tuple
the unit value *) of 0 elements, and the unit value *)
} }
(** The context contains all the required information to encode a VC represented as a Catala term to (** The context contains all the required information to encode a VC represented
Z3. The fields [ctx_decl] and [ctx_var] are computed before starting the translation to Z3, and as a Catala term to Z3. The fields [ctx_decl] and [ctx_var] are computed
are thus unmodified throughout the translation. The [ctx_z3] context is an OCaml abstraction on before starting the translation to Z3, and are thus unmodified throughout
top of an underlying C++ imperative implementation, it is therefore only created once. the translation. The [ctx_z3] context is an OCaml abstraction on top of an
Unfortunately, the maps [ctx_funcdecl], [ctx_z3vars], and [ctx_z3datatypes] are computed underlying C++ imperative implementation, it is therefore only created once.
dynamically during the translation requiring us to pass the context around in a functional way **) Unfortunately, the maps [ctx_funcdecl], [ctx_z3vars], and [ctx_z3datatypes]
are computed dynamically during the translation requiring us to pass the
context around in a functional way **)
(** [add_funcdecl] adds the mapping between the Catala variable [v] and the Z3 function declaration (** [add_funcdecl] adds the mapping between the Catala variable [v] and the Z3
[fd] to the context **) function declaration [fd] to the context **)
let add_funcdecl (v : Var.t) (fd : FuncDecl.func_decl) (ctx : context) : context = let add_funcdecl (v : Var.t) (fd : FuncDecl.func_decl) (ctx : context) : context
=
{ ctx with ctx_funcdecl = VarMap.add v fd ctx.ctx_funcdecl } { ctx with ctx_funcdecl = VarMap.add v fd ctx.ctx_funcdecl }
(** [add_z3var] adds the mapping between [name] and the Catala variable [v] to the context **) (** [add_z3var] adds the mapping between [name] and the Catala variable [v] to
the context **)
let add_z3var (name : string) (v : Var.t) (ctx : context) : context = let add_z3var (name : string) (v : Var.t) (ctx : context) : context =
{ ctx with ctx_z3vars = StringMap.add name v ctx.ctx_z3vars } { ctx with ctx_z3vars = StringMap.add name v ctx.ctx_z3vars }
(** [add_z3enum] adds the mapping between the Catala enumeration [enum] and the corresponding Z3 (** [add_z3enum] adds the mapping between the Catala enumeration [enum] and the
datatype [sort] to the context **) corresponding Z3 datatype [sort] to the context **)
let add_z3enum (enum : EnumName.t) (sort : Sort.sort) (ctx : context) : context = let add_z3enum (enum : EnumName.t) (sort : Sort.sort) (ctx : context) : context
=
{ ctx with ctx_z3datatypes = EnumMap.add enum sort ctx.ctx_z3datatypes } { ctx with ctx_z3datatypes = EnumMap.add enum sort ctx.ctx_z3datatypes }
(** [add_z3var] adds the mapping between temporary variable [v] and the Z3 expression [e] (** [add_z3var] adds the mapping between temporary variable [v] and the Z3
representing an accessor application to the context **) expression [e] representing an accessor application to the context **)
let add_z3matchsubst (v : Var.t) (e : Expr.expr) (ctx : context) : context = let add_z3matchsubst (v : Var.t) (e : Expr.expr) (ctx : context) : context =
{ ctx with ctx_z3matchsubsts = VarMap.add v e ctx.ctx_z3matchsubsts } { ctx with ctx_z3matchsubsts = VarMap.add v e ctx.ctx_z3matchsubsts }
(** [add_z3struct] adds the mapping between the Catala struct [s] and the corresponding Z3 datatype (** [add_z3struct] adds the mapping between the Catala struct [s] and the
[sort] to the context **) corresponding Z3 datatype [sort] to the context **)
let add_z3struct (s : StructName.t) (sort : Sort.sort) (ctx : context) : context = let add_z3struct (s : StructName.t) (sort : Sort.sort) (ctx : context) : context
=
{ ctx with ctx_z3structs = StructMap.add s sort ctx.ctx_z3structs } { ctx with ctx_z3structs = StructMap.add s sort ctx.ctx_z3structs }
(** For the Z3 encoding of Catala programs, we define the "day 0" as Jan 1, 1900 **) (** For the Z3 encoding of Catala programs, we define the "day 0" as Jan 1, 1900
**)
let base_day = CalendarLib.Date.make 1900 1 1 let base_day = CalendarLib.Date.make 1900 1 1
(** [unique_name] returns the full, unique name corresponding to variable [v], as given by Bindlib **) (** [unique_name] returns the full, unique name corresponding to variable [v],
as given by Bindlib **)
let unique_name (v : Var.t) : string = let unique_name (v : Var.t) : string =
Format.asprintf "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v) Format.asprintf "%s_%d" (Bindlib.name_of v) (Bindlib.uid_of v)
(** [date_to_int] translates [date] to an integer corresponding to the number of days since Jan 1, (** [date_to_int] translates [date] to an integer corresponding to the number of
1900 **) days since Jan 1, 1900 **)
let date_to_int (d : Runtime.date) : int = let date_to_int (d : Runtime.date) : int =
(* Alternatively, could expose this from Runtime as a (noop) coercion, but would allow to break (* Alternatively, could expose this from Runtime as a (noop) coercion, but
abstraction more easily elsewhere *) would allow to break abstraction more easily elsewhere *)
let date : CalendarLib.Date.t = CalendarLib.Printer.Date.from_string (Runtime.date_to_string d) in let date : CalendarLib.Date.t =
CalendarLib.Printer.Date.from_string (Runtime.date_to_string d)
in
let period = CalendarLib.Date.sub date base_day in let period = CalendarLib.Date.sub date base_day in
CalendarLib.Date.Period.nb_days period CalendarLib.Date.Period.nb_days period
(** [date_of_year] translates a [year], represented as an integer into an OCaml date corresponding (** [date_of_year] translates a [year], represented as an integer into an OCaml
to Jan 1st of the same year *) date corresponding to Jan 1st of the same year *)
let date_of_year (year : int) = Runtime.date_of_numbers year 1 1 let date_of_year (year : int) = Runtime.date_of_numbers year 1 1
(** Returns the date (as a string) corresponding to nb days after the base day, defined here as Jan (** Returns the date (as a string) corresponding to nb days after the base day,
1, 1900 **) defined here as Jan 1, 1900 **)
let nb_days_to_date (nb : int) : string = let nb_days_to_date (nb : int) : string =
CalendarLib.Printer.Date.to_string CalendarLib.Printer.Date.to_string
(CalendarLib.Date.add base_day (CalendarLib.Date.Period.day nb)) (CalendarLib.Date.add base_day (CalendarLib.Date.Period.day nb))
(** [print_z3model_expr] pretty-prints the value [e] given by a Z3 model according to the Catala (** [print_z3model_expr] pretty-prints the value [e] given by a Z3 model
type [ty], corresponding to [e] **) according to the Catala type [ty], corresponding to [e] **)
let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr) : string = let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr)
: string =
let print_lit (ty : typ_lit) = let print_lit (ty : typ_lit) =
match ty with match ty with
(* TODO: Print boolean according to current language *) (* TODO: Print boolean according to current language *)
| TBool -> Expr.to_string e | TBool -> Expr.to_string e
(* TUnit is only used for the absence of an enum constructor argument. Hence, when (* TUnit is only used for the absence of an enum constructor argument.
pretty-printing, we print nothing to remain closer from Catala sources *) Hence, when pretty-printing, we print nothing to remain closer from
Catala sources *)
| TUnit -> "" | TUnit -> ""
| TInt -> Expr.to_string e | TInt -> Expr.to_string e
| TRat -> Arithmetic.Real.to_decimal_string e !Cli.max_prec_digits | TRat -> Arithmetic.Real.to_decimal_string e !Cli.max_prec_digits
(* TODO: Print the right money symbol according to language *) (* TODO: Print the right money symbol according to language *)
| TMoney -> | TMoney ->
let z3_str = Expr.to_string e in let z3_str = Expr.to_string e in
(* The Z3 model returns an integer corresponding to the amount of cents. We reformat it as (* The Z3 model returns an integer corresponding to the amount of cents.
dollars *) We reformat it as dollars *)
let to_dollars s = Runtime.money_to_string (Runtime.money_of_cents_string s) in let to_dollars s =
Runtime.money_to_string (Runtime.money_of_cents_string s)
in
if String.contains z3_str '-' then if String.contains z3_str '-' then
Format.asprintf "-%s $" (to_dollars (String.sub z3_str 3 (String.length z3_str - 4))) Format.asprintf "-%s $"
(to_dollars (String.sub z3_str 3 (String.length z3_str - 4)))
else Format.asprintf "%s $" (to_dollars z3_str) else Format.asprintf "%s $" (to_dollars z3_str)
(* The Z3 date representation corresponds to the number of days since Jan 1, 1900. We (* The Z3 date representation corresponds to the number of days since Jan 1,
pretty-print it as the actual date *) 1900. We pretty-print it as the actual date *)
(* TODO: Use differnt dates conventions depending on the language ? *) (* TODO: Use differnt dates conventions depending on the language ? *)
| TDate -> nb_days_to_date (int_of_string (Expr.to_string e)) | TDate -> nb_days_to_date (int_of_string (Expr.to_string e))
| TDuration -> failwith "[Z3 model]: Pretty-printing of duration literals not supported" | TDuration ->
failwith
"[Z3 model]: Pretty-printing of duration literals not supported"
in in
match Pos.unmark ty with match Pos.unmark ty with
@ -142,14 +161,18 @@ let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr)
let fields = let fields =
List.map2 List.map2
(fun (fn, ty) e -> (fun (fn, ty) e ->
Format.asprintf "-- %s : %s" (get_fieldname fn) (print_z3model_expr ctx ty e)) Format.asprintf "-- %s : %s" (get_fieldname fn)
(print_z3model_expr ctx ty e))
s (Expr.get_args e) s (Expr.get_args e)
in in
let fields_str = String.concat " " fields in let fields_str = String.concat " " fields in
Format.asprintf "%s { %s }" (Pos.unmark (StructName.get_info name)) fields_str Format.asprintf "%s { %s }"
| TTuple (_, None) -> failwith "[Z3 model]: Pretty-printing of unnamed structs not supported" (Pos.unmark (StructName.get_info name))
fields_str
| TTuple (_, None) ->
failwith "[Z3 model]: Pretty-printing of unnamed structs not supported"
| TEnum (_tys, name) -> | TEnum (_tys, name) ->
(* The value associated to the enum is a single argument *) (* The value associated to the enum is a single argument *)
let e' = List.hd (Expr.get_args e) in let e' = List.hd (Expr.get_args e) in
@ -159,7 +182,8 @@ let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr)
let enum_ctrs = EnumMap.find name ctx.ctx_decl.ctx_enums in let enum_ctrs = EnumMap.find name ctx.ctx_decl.ctx_enums in
let case = let case =
List.find List.find
(fun (ctr, _) -> String.equal fd_name (Pos.unmark (EnumConstructor.get_info ctr))) (fun (ctr, _) ->
String.equal fd_name (Pos.unmark (EnumConstructor.get_info ctr)))
enum_ctrs enum_ctrs
in in
@ -168,10 +192,11 @@ let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr)
| TArray _ -> failwith "[Z3 model]: Pretty-printing of arrays not supported" | TArray _ -> failwith "[Z3 model]: Pretty-printing of arrays not supported"
| TAny -> failwith "[Z3 model]: Pretty-printing of Any not supported" | TAny -> failwith "[Z3 model]: Pretty-printing of Any not supported"
(** [print_model] pretty prints a Z3 model, used to exhibit counter examples where verification (** [print_model] pretty prints a Z3 model, used to exhibit counter examples
conditions are not satisfied. The context [ctx] is useful to retrieve the mapping between Z3 where verification conditions are not satisfied. The context [ctx] is useful
variables and Catala variables, and to retrieve type information about the variables that was to retrieve the mapping between Z3 variables and Catala variables, and to
lost during the translation (e.g., by translating a date to an integer) **) retrieve type information about the variables that was lost during the
translation (e.g., by translating a date to an integer) **)
let print_model (ctx : context) (model : Model.model) : string = let print_model (ctx : context) (model : Model.model) : string =
let decls = Model.get_decls model in let decls = Model.get_decls model in
Format.asprintf "%a" Format.asprintf "%a"
@ -182,32 +207,41 @@ let print_model (ctx : context) (model : Model.model) : string =
(* Constant case *) (* Constant case *)
match Model.get_const_interp model d with match Model.get_const_interp model d with
(* TODO: Better handling of this case *) (* TODO: Better handling of this case *)
| None -> failwith "[Z3 model]: A variable does not have an associated Z3 solution" | None ->
failwith
"[Z3 model]: A variable does not have an associated Z3 \
solution"
(* Print "name : value\n" *) (* Print "name : value\n" *)
| Some e -> | Some e ->
let symbol_name = Symbol.to_string (FuncDecl.get_name d) in let symbol_name = Symbol.to_string (FuncDecl.get_name d) in
let v = StringMap.find symbol_name ctx.ctx_z3vars in let v = StringMap.find symbol_name ctx.ctx_z3vars in
Format.fprintf fmt "%s %s : %s" Format.fprintf fmt "%s %s : %s"
(Cli.with_style [ ANSITerminal.blue ] "%s" "-->") (Cli.with_style [ ANSITerminal.blue ] "%s" "-->")
(Cli.with_style [ ANSITerminal.yellow ] "%s" (Bindlib.name_of v)) (Cli.with_style [ ANSITerminal.yellow ] "%s"
(Bindlib.name_of v))
(print_z3model_expr ctx (VarMap.find v ctx.ctx_var) e) (print_z3model_expr ctx (VarMap.find v ctx.ctx_var) e)
else else
(* Declaration d is a function *) (* Declaration d is a function *)
match Model.get_func_interp model d with match Model.get_func_interp model d with
(* TODO: Better handling of this case *) (* TODO: Better handling of this case *)
| None -> failwith "[Z3 model]: A variable does not have an associated Z3 solution" | None ->
failwith
"[Z3 model]: A variable does not have an associated Z3 \
solution"
(* Print "name : value\n" *) (* Print "name : value\n" *)
| Some f -> | Some f ->
let symbol_name = Symbol.to_string (FuncDecl.get_name d) in let symbol_name = Symbol.to_string (FuncDecl.get_name d) in
let v = StringMap.find symbol_name ctx.ctx_z3vars in let v = StringMap.find symbol_name ctx.ctx_z3vars in
Format.fprintf fmt "%s %s : %s" Format.fprintf fmt "%s %s : %s"
(Cli.with_style [ ANSITerminal.blue ] "%s" "-->") (Cli.with_style [ ANSITerminal.blue ] "%s" "-->")
(Cli.with_style [ ANSITerminal.yellow ] "%s" (Bindlib.name_of v)) (Cli.with_style [ ANSITerminal.yellow ] "%s"
(Bindlib.name_of v))
(* TODO: Model of a Z3 function should be pretty-printed *) (* TODO: Model of a Z3 function should be pretty-printed *)
(Model.FuncInterp.to_string f))) (Model.FuncInterp.to_string f)))
decls decls
(** [translate_typ_lit] returns the Z3 sort corresponding to the Catala literal type [t] **) (** [translate_typ_lit] returns the Z3 sort corresponding to the Catala literal
type [t] **)
let translate_typ_lit (ctx : context) (t : typ_lit) : Sort.sort = let translate_typ_lit (ctx : context) (t : typ_lit) : Sort.sort =
match t with match t with
| TBool -> Boolean.mk_sort ctx.ctx_z3 | TBool -> Boolean.mk_sort ctx.ctx_z3
@ -215,7 +249,8 @@ let translate_typ_lit (ctx : context) (t : typ_lit) : Sort.sort =
| TInt -> Arithmetic.Integer.mk_sort ctx.ctx_z3 | TInt -> Arithmetic.Integer.mk_sort ctx.ctx_z3
| TRat -> Arithmetic.Real.mk_sort ctx.ctx_z3 | TRat -> Arithmetic.Real.mk_sort ctx.ctx_z3
| TMoney -> Arithmetic.Integer.mk_sort ctx.ctx_z3 | TMoney -> Arithmetic.Integer.mk_sort ctx.ctx_z3
(* Dates are encoded as integers, corresponding to the number of days since Jan 1, 1900 *) (* Dates are encoded as integers, corresponding to the number of days since
Jan 1, 1900 *)
| TDate -> Arithmetic.Integer.mk_sort ctx.ctx_z3 | TDate -> Arithmetic.Integer.mk_sort ctx.ctx_z3
| TDuration -> failwith "[Z3 encoding] TDuration type not supported" | TDuration -> failwith "[Z3 encoding] TDuration type not supported"
@ -224,34 +259,40 @@ let rec translate_typ (ctx : context) (t : typ) : context * Sort.sort =
match t with match t with
| TLit t -> (ctx, translate_typ_lit ctx t) | TLit t -> (ctx, translate_typ_lit ctx t)
| TTuple (_, Some name) -> find_or_create_struct ctx name | TTuple (_, Some name) -> find_or_create_struct ctx name
| TTuple (_, None) -> failwith "[Z3 encoding] TTuple type of unnamed struct not supported" | TTuple (_, None) ->
failwith "[Z3 encoding] TTuple type of unnamed struct not supported"
| TEnum (_, e) -> find_or_create_enum ctx e | TEnum (_, e) -> find_or_create_enum ctx e
| TArrow _ -> failwith "[Z3 encoding] TArrow type not supported" | TArrow _ -> failwith "[Z3 encoding] TArrow type not supported"
| TArray _ -> failwith "[Z3 encoding] TArray type not supported" | TArray _ -> failwith "[Z3 encoding] TArray type not supported"
| TAny -> failwith "[Z3 encoding] TAny type not supported" | TAny -> failwith "[Z3 encoding] TAny type not supported"
(** [find_or_create_enum] attempts to retrieve the Z3 sort corresponding to the Catala enumeration (** [find_or_create_enum] attempts to retrieve the Z3 sort corresponding to the
[enum]. If no such sort exists yet, it constructs it by creating a Z3 constructor for each Catala enumeration [enum]. If no such sort exists yet, it constructs it by
Catala constructor of [enum], and adds it to the context *) creating a Z3 constructor for each Catala constructor of [enum], and adds it
and find_or_create_enum (ctx : context) (enum : EnumName.t) : context * Sort.sort = to the context *)
and find_or_create_enum (ctx : context) (enum : EnumName.t) :
context * Sort.sort =
(* Creates a Z3 constructor corresponding to the Catala constructor [c] *) (* Creates a Z3 constructor corresponding to the Catala constructor [c] *)
let create_constructor (ctx : context) (c : EnumConstructor.t * typ Pos.marked) : let create_constructor
(ctx : context) (c : EnumConstructor.t * typ Pos.marked) :
context * Datatype.Constructor.constructor = context * Datatype.Constructor.constructor =
let name, ty = c in let name, ty = c in
let name = Pos.unmark (EnumConstructor.get_info name) in let name = Pos.unmark (EnumConstructor.get_info name) in
let ctx, arg_z3_ty = translate_typ ctx (Pos.unmark ty) in let ctx, arg_z3_ty = translate_typ ctx (Pos.unmark ty) in
(* The mk_constructor_s Z3 function is not so well documented. From my understanding, its (* The mk_constructor_s Z3 function is not so well documented. From my
argument are: - a string corresponding to the name of the constructor - a recognizer as a understanding, its argument are: - a string corresponding to the name of
symbol corresponding to the name (unsure why) - a list of symbols corresponding to the the constructor - a recognizer as a symbol corresponding to the name
arguments of the constructor - a list of types, that must be of the same length as the list (unsure why) - a list of symbols corresponding to the arguments of the
of arguments - a list of sort_refs, of the same length as the list of arguments. I'm unsure constructor - a list of types, that must be of the same length as the
what this corresponds to *) list of arguments - a list of sort_refs, of the same length as the list
of arguments. I'm unsure what this corresponds to *)
( ctx, ( ctx,
Datatype.mk_constructor_s ctx.ctx_z3 name Datatype.mk_constructor_s ctx.ctx_z3 name
(Symbol.mk_string ctx.ctx_z3 name) (Symbol.mk_string ctx.ctx_z3 name)
(* We need a name for the argument of the constructor, we arbitrary pick the name of the (* We need a name for the argument of the constructor, we arbitrary pick
constructor to which we append the special character "!" and the integer 0 *) the name of the constructor to which we append the special character
"!" and the integer 0 *)
[ Symbol.mk_string ctx.ctx_z3 (name ^ "!0") ] [ Symbol.mk_string ctx.ctx_z3 (name ^ "!0") ]
(* The type of the argument, translated to a Z3 sort *) (* The type of the argument, translated to a Z3 sort *)
[ Some arg_z3_ty ] [ Some arg_z3_ty ]
@ -263,13 +304,19 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : context * Sort.sor
| None -> | None ->
let ctrs = EnumMap.find enum ctx.ctx_decl.ctx_enums in let ctrs = EnumMap.find enum ctx.ctx_decl.ctx_enums in
let ctx, z3_ctrs = List.fold_left_map create_constructor ctx ctrs in let ctx, z3_ctrs = List.fold_left_map create_constructor ctx ctrs in
let z3_enum = Datatype.mk_sort_s ctx.ctx_z3 (Pos.unmark (EnumName.get_info enum)) z3_ctrs in let z3_enum =
Datatype.mk_sort_s ctx.ctx_z3
(Pos.unmark (EnumName.get_info enum))
z3_ctrs
in
(add_z3enum enum z3_enum ctx, z3_enum) (add_z3enum enum z3_enum ctx, z3_enum)
(** [find_or_create_struct] attemps to retrieve the Z3 sort corresponding to the struct [s]. If no (** [find_or_create_struct] attemps to retrieve the Z3 sort corresponding to the
such sort exists yet, we construct it as a datatype with one constructor taking all the fields struct [s]. If no such sort exists yet, we construct it as a datatype with
as arguments, and add it to the context *) one constructor taking all the fields as arguments, and add it to the
and find_or_create_struct (ctx : context) (s : StructName.t) : context * Sort.sort = context *)
and find_or_create_struct (ctx : context) (s : StructName.t) :
context * Sort.sort =
match StructMap.find_opt s ctx.ctx_z3structs with match StructMap.find_opt s ctx.ctx_z3structs with
| Some s -> (ctx, s) | Some s -> (ctx, s)
| None -> | None ->
@ -277,11 +324,15 @@ and find_or_create_struct (ctx : context) (s : StructName.t) : context * Sort.so
let fields = StructMap.find s ctx.ctx_decl.ctx_structs in let fields = StructMap.find s ctx.ctx_decl.ctx_structs in
let z3_fieldnames = let z3_fieldnames =
List.map List.map
(fun f -> Pos.unmark (StructFieldName.get_info (fst f)) |> Symbol.mk_string ctx.ctx_z3) (fun f ->
Pos.unmark (StructFieldName.get_info (fst f))
|> Symbol.mk_string ctx.ctx_z3)
fields fields
in in
let ctx, z3_fieldtypes = let ctx, z3_fieldtypes =
List.fold_left_map (fun ctx f -> Pos.unmark (snd f) |> translate_typ ctx) ctx fields List.fold_left_map
(fun ctx f -> Pos.unmark (snd f) |> translate_typ ctx)
ctx fields
in in
let z3_sortrefs = List.map Sort.get_id z3_fieldtypes in let z3_sortrefs = List.map Sort.get_id z3_fieldtypes in
let mk_struct_s = "mk!" ^ s_name in let mk_struct_s = "mk!" ^ s_name in
@ -296,25 +347,33 @@ and find_or_create_struct (ctx : context) (s : StructName.t) : context * Sort.so
let z3_struct = Datatype.mk_sort_s ctx.ctx_z3 s_name [ z3_mk_struct ] in let z3_struct = Datatype.mk_sort_s ctx.ctx_z3 s_name [ z3_mk_struct ] in
(add_z3struct s z3_struct ctx, z3_struct) (add_z3struct s z3_struct ctx, z3_struct)
(** [translate_lit] returns the Z3 expression as a literal corresponding to [lit] **) (** [translate_lit] returns the Z3 expression as a literal corresponding to
[lit] **)
let translate_lit (ctx : context) (l : lit) : Expr.expr = let translate_lit (ctx : context) (l : lit) : Expr.expr =
match l with match l with
| LBool b -> if b then Boolean.mk_true ctx.ctx_z3 else Boolean.mk_false ctx.ctx_z3 | LBool b ->
if b then Boolean.mk_true ctx.ctx_z3 else Boolean.mk_false ctx.ctx_z3
| LEmptyError -> failwith "[Z3 encoding] LEmptyError literals not supported" | LEmptyError -> failwith "[Z3 encoding] LEmptyError literals not supported"
| LInt n -> Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (Runtime.integer_to_int n) | LInt n ->
| LRat r -> Arithmetic.Real.mk_numeral_s ctx.ctx_z3 (string_of_float (Runtime.decimal_to_float r)) Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (Runtime.integer_to_int n)
| LRat r ->
Arithmetic.Real.mk_numeral_s ctx.ctx_z3
(string_of_float (Runtime.decimal_to_float r))
| LMoney m -> | LMoney m ->
let z3_m = Runtime.integer_to_int (Runtime.money_to_cents m) in let z3_m = Runtime.integer_to_int (Runtime.money_to_cents m) in
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 z3_m Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 z3_m
| LUnit -> failwith "[Z3 encoding] LUnit literals not supported" | LUnit -> failwith "[Z3 encoding] LUnit literals not supported"
(* Encoding a date as an integer corresponding to the number of days since Jan 1, 1900 *) (* Encoding a date as an integer corresponding to the number of days since Jan
1, 1900 *)
| LDate d -> Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int d) | LDate d -> Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int d)
| LDuration _ -> failwith "[Z3 encoding] LDuration literals not supported" | LDuration _ -> failwith "[Z3 encoding] LDuration literals not supported"
(** [find_or_create_funcdecl] attempts to retrieve the Z3 function declaration corresponding to the (** [find_or_create_funcdecl] attempts to retrieve the Z3 function declaration
variable [v]. If no such function declaration exists yet, we construct it and add it to the corresponding to the variable [v]. If no such function declaration exists
context, thus requiring to return a new context *) yet, we construct it and add it to the context, thus requiring to return a
let find_or_create_funcdecl (ctx : context) (v : Var.t) : context * FuncDecl.func_decl = new context *)
let find_or_create_funcdecl (ctx : context) (v : Var.t) :
context * FuncDecl.func_decl =
match VarMap.find_opt v ctx.ctx_funcdecl with match VarMap.find_opt v ctx.ctx_funcdecl with
| Some fd -> (ctx, fd) | Some fd -> (ctx, fd)
| None -> ( | None -> (
@ -331,14 +390,17 @@ let find_or_create_funcdecl (ctx : context) (v : Var.t) : context * FuncDecl.fun
(ctx, fd) (ctx, fd)
| TAny -> | TAny ->
failwith failwith
"[Z3 Encoding] A function being applied has type TAny, the type was not fully inferred" "[Z3 Encoding] A function being applied has type TAny, the type \
was not fully inferred"
| _ -> | _ ->
failwith failwith
"[Z3 Encoding] Ill-formed VC, a function application does not have a function type") "[Z3 Encoding] Ill-formed VC, a function application does not have \
a function type")
(** [translate_op] returns the Z3 expression corresponding to the application of [op] to the (** [translate_op] returns the Z3 expression corresponding to the application of
arguments [args] **) [op] to the arguments [args] **)
let rec translate_op (ctx : context) (op : operator) (args : expr Pos.marked list) : let rec translate_op
(ctx : context) (op : operator) (args : expr Pos.marked list) :
context * Expr.expr = context * Expr.expr =
match op with match op with
| Ternop _top -> | Ternop _top ->
@ -347,7 +409,8 @@ let rec translate_op (ctx : context) (op : operator) (args : expr Pos.marked lis
| [ e1; e2; e3 ] -> (e1, e2, e3) | [ e1; e2; e3 ] -> (e1, e2, e3)
| _ -> | _ ->
failwith failwith
(Format.asprintf "[Z3 encoding] Ill-formed ternary operator application: %a" (Format.asprintf
"[Z3 encoding] Ill-formed ternary operator application: %a"
(Print.format_expr ctx.ctx_decl) (Print.format_expr ctx.ctx_decl)
(EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos))
in in
@ -356,41 +419,61 @@ let rec translate_op (ctx : context) (op : operator) (args : expr Pos.marked lis
| Binop bop -> ( | Binop bop -> (
(* Special case for GetYear comparisons *) (* Special case for GetYear comparisons *)
match (bop, args) with match (bop, args) with
| Lt KInt, [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] -> | ( Lt KInt,
[ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] )
->
let n = Runtime.integer_to_int n in let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in let ctx, e1 = translate_expr ctx e1 in
let e2 = Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n)) in let e2 =
(* e2 corresponds to the first day of the year n. GetYear e1 < e2 can thus be directly Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
translated as < in the Z3 encoding using the number of days *) (date_to_int (date_of_year n))
in
(* e2 corresponds to the first day of the year n. GetYear e1 < e2 can
thus be directly translated as < in the Z3 encoding using the
number of days *)
(ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2) (ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2)
| Lte KInt, [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] -> | ( Lte KInt,
[ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] )
->
let n = Runtime.integer_to_int n in let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in let ctx, e1 = translate_expr ctx e1 in
let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in
(* We want that the year corresponding to e1 is smaller or equal to n. We encode this as (* We want that the year corresponding to e1 is smaller or equal to n.
the day corresponding to e1 is smaller or equal than the last day of the year [n], We encode this as the day corresponding to e1 is smaller or equal
which is Jan 1st + 365 days if [n] is a leap year, Jan 1st + 364 else *) than the last day of the year [n], which is Jan 1st + 365 days if
[n] is a leap year, Jan 1st + 364 else *)
let e2 = let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n) + nb_days) Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n) + nb_days)
in in
(ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2) (ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2)
| Gt KInt, [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] -> | ( Gt KInt,
[ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] )
->
let n = Runtime.integer_to_int n in let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in let ctx, e1 = translate_expr ctx e1 in
let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in
(* We want that the year corresponding to e1 is greater to n. We encode this as the day (* We want that the year corresponding to e1 is greater to n. We
corresponding to e1 is greater than the last day of the year [n], which is Jan 1st + encode this as the day corresponding to e1 is greater than the last
365 days if [n] is a leap year, Jan 1st + 364 else *) day of the year [n], which is Jan 1st + 365 days if [n] is a leap
year, Jan 1st + 364 else *)
let e2 = let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n) + nb_days) Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n) + nb_days)
in in
(ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2) (ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2)
| Gte KInt, [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] -> | ( Gte KInt,
[ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] )
->
let n = Runtime.integer_to_int n in let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in let ctx, e1 = translate_expr ctx e1 in
let e2 = Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n)) in let e2 =
(* e2 corresponds to the first day of the year n. GetYear e1 >= e2 can thus be directly Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
translated as >= in the Z3 encoding using the number of days *) (date_to_int (date_of_year n))
in
(* e2 corresponds to the first day of the year n. GetYear e1 >= e2 can
thus be directly translated as >= in the Z3 encoding using the
number of days *)
(ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2) (ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2)
| _ -> ( | _ -> (
let ctx, e1, e2 = let ctx, e1, e2 =
@ -401,7 +484,8 @@ let rec translate_op (ctx : context) (op : operator) (args : expr Pos.marked lis
(ctx, e1, e2) (ctx, e1, e2)
| _ -> | _ ->
failwith failwith
(Format.asprintf "[Z3 encoding] Ill-formed binary operator application: %a" (Format.asprintf
"[Z3 encoding] Ill-formed binary operator application: %a"
(Print.format_expr ctx.ctx_decl) (Print.format_expr ctx.ctx_decl)
(EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos))
in in
@ -410,67 +494,111 @@ let rec translate_op (ctx : context) (op : operator) (args : expr Pos.marked lis
| And -> (ctx, Boolean.mk_and ctx.ctx_z3 [ e1; e2 ]) | And -> (ctx, Boolean.mk_and ctx.ctx_z3 [ e1; e2 ])
| Or -> (ctx, Boolean.mk_or ctx.ctx_z3 [ e1; e2 ]) | Or -> (ctx, Boolean.mk_or ctx.ctx_z3 [ e1; e2 ])
| Xor -> (ctx, Boolean.mk_xor ctx.ctx_z3 e1 e2) | Xor -> (ctx, Boolean.mk_xor ctx.ctx_z3 e1 e2)
| Add KInt | Add KRat | Add KMoney -> (ctx, Arithmetic.mk_add ctx.ctx_z3 [ e1; e2 ]) | Add KInt | Add KRat | Add KMoney ->
(ctx, Arithmetic.mk_add ctx.ctx_z3 [ e1; e2 ])
| Add _ -> | Add _ ->
failwith "[Z3 encoding] application of non-integer binary operator Add not supported" failwith
| Sub KInt | Sub KRat | Sub KMoney -> (ctx, Arithmetic.mk_sub ctx.ctx_z3 [ e1; e2 ]) "[Z3 encoding] application of non-integer binary operator Add \
not supported"
| Sub KInt | Sub KRat | Sub KMoney ->
(ctx, Arithmetic.mk_sub ctx.ctx_z3 [ e1; e2 ])
| Sub _ -> | Sub _ ->
failwith "[Z3 encoding] application of non-integer binary operator Sub not supported" failwith
| Mult KInt | Mult KRat | Mult KMoney -> (ctx, Arithmetic.mk_mul ctx.ctx_z3 [ e1; e2 ]) "[Z3 encoding] application of non-integer binary operator Sub \
not supported"
| Mult KInt | Mult KRat | Mult KMoney ->
(ctx, Arithmetic.mk_mul ctx.ctx_z3 [ e1; e2 ])
| Mult _ -> | Mult _ ->
failwith "[Z3 encoding] application of non-integer binary operator Mult not supported" failwith
| Div KInt | Div KRat | Div KMoney -> (ctx, Arithmetic.mk_div ctx.ctx_z3 e1 e2) "[Z3 encoding] application of non-integer binary operator Mult \
not supported"
| Div KInt | Div KRat | Div KMoney ->
(ctx, Arithmetic.mk_div ctx.ctx_z3 e1 e2)
| Div _ -> | Div _ ->
failwith "[Z3 encoding] application of non-integer binary operator Div not supported" failwith
| Lt KInt | Lt KRat | Lt KMoney | Lt KDate -> (ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2) "[Z3 encoding] application of non-integer binary operator Div \
not supported"
| Lt KInt | Lt KRat | Lt KMoney | Lt KDate ->
(ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2)
| Lt _ -> | Lt _ ->
failwith failwith
"[Z3 encoding] application of non-integer or money binary operator Lt not supported" "[Z3 encoding] application of non-integer or money binary \
| Lte KInt | Lte KRat | Lte KMoney | Lte KDate -> (ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2) operator Lt not supported"
| Lte KInt | Lte KRat | Lte KMoney | Lte KDate ->
(ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2)
| Lte _ -> | Lte _ ->
failwith failwith
"[Z3 encoding] application of non-integer or money binary operator Lte not \ "[Z3 encoding] application of non-integer or money binary \
supported" operator Lte not supported"
| Gt KInt | Gt KRat | Gt KMoney | Gt KDate -> (ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2) | Gt KInt | Gt KRat | Gt KMoney | Gt KDate ->
(ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2)
| Gt _ -> | Gt _ ->
failwith failwith
"[Z3 encoding] application of non-integer or money binary operator Gt not supported" "[Z3 encoding] application of non-integer or money binary \
| Gte KInt | Gte KRat | Gte KMoney | Gte KDate -> (ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2) operator Gt not supported"
| Gte KInt | Gte KRat | Gte KMoney | Gte KDate ->
(ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2)
| Gte _ -> | Gte _ ->
failwith failwith
"[Z3 encoding] application of non-integer or money binary operator Gte not \ "[Z3 encoding] application of non-integer or money binary \
supported" operator Gte not supported"
| Eq -> (ctx, Boolean.mk_eq ctx.ctx_z3 e1 e2) | Eq -> (ctx, Boolean.mk_eq ctx.ctx_z3 e1 e2)
| Neq -> (ctx, Boolean.mk_not ctx.ctx_z3 (Boolean.mk_eq ctx.ctx_z3 e1 e2)) | Neq ->
| Map -> failwith "[Z3 encoding] application of binary operator Map not supported" (ctx, Boolean.mk_not ctx.ctx_z3 (Boolean.mk_eq ctx.ctx_z3 e1 e2))
| Concat -> failwith "[Z3 encoding] application of binary operator Concat not supported" | Map ->
| Filter -> failwith "[Z3 encoding] application of binary operator Filter not supported")) failwith
"[Z3 encoding] application of binary operator Map not supported"
| Concat ->
failwith
"[Z3 encoding] application of binary operator Concat not \
supported"
| Filter ->
failwith
"[Z3 encoding] application of binary operator Filter not \
supported"))
| Unop uop -> ( | Unop uop -> (
let ctx, e1 = let ctx, e1 =
match args with match args with
| [ e1 ] -> translate_expr ctx e1 | [ e1 ] -> translate_expr ctx e1
| _ -> | _ ->
failwith failwith
(Format.asprintf "[Z3 encoding] Ill-formed unary operator application: %a" (Format.asprintf
"[Z3 encoding] Ill-formed unary operator application: %a"
(Print.format_expr ctx.ctx_decl) (Print.format_expr ctx.ctx_decl)
(EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos))
in in
match uop with match uop with
| Not -> (ctx, Boolean.mk_not ctx.ctx_z3 e1) | Not -> (ctx, Boolean.mk_not ctx.ctx_z3 e1)
| Minus _ -> failwith "[Z3 encoding] application of unary operator Minus not supported" | Minus _ ->
failwith
"[Z3 encoding] application of unary operator Minus not supported"
(* Omitting the log from the VC *) (* Omitting the log from the VC *)
| Log _ -> (ctx, e1) | Log _ -> (ctx, e1)
| Length -> failwith "[Z3 encoding] application of unary operator Length not supported" | Length ->
| IntToRat -> failwith "[Z3 encoding] application of unary operator IntToRat not supported" failwith
| GetDay -> failwith "[Z3 encoding] application of unary operator GetDay not supported" "[Z3 encoding] application of unary operator Length not supported"
| GetMonth -> failwith "[Z3 encoding] application of unary operator GetMonth not supported" | IntToRat ->
failwith
"[Z3 encoding] application of unary operator IntToRat not supported"
| GetDay ->
failwith
"[Z3 encoding] application of unary operator GetDay not supported"
| GetMonth ->
failwith
"[Z3 encoding] application of unary operator GetMonth not supported"
| GetYear -> | GetYear ->
failwith "[Z3 encoding] GetYear operator only supported in comparisons with literal") failwith
"[Z3 encoding] GetYear operator only supported in comparisons with \
literal")
(** [translate_expr] translate the expression [vc] to its corresponding Z3 expression **) (** [translate_expr] translate the expression [vc] to its corresponding Z3
and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr = expression **)
let translate_match_arm (head : Expr.expr) (ctx : context) and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
=
let translate_match_arm
(head : Expr.expr)
(ctx : context)
(e : expr Pos.marked * FuncDecl.func_decl list) : context * Expr.expr = (e : expr Pos.marked * FuncDecl.func_decl list) : context * Expr.expr =
let e, accessors = e in let e, accessors = e in
match Pos.unmark e with match Pos.unmark e with
@ -482,8 +610,8 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
(* Invariant: Catala enums always have exactly one argument *) (* Invariant: Catala enums always have exactly one argument *)
let accessor = List.hd accessors in let accessor = List.hd accessors in
let proj = Expr.mk_app ctx.ctx_z3 accessor [ head ] in let proj = Expr.mk_app ctx.ctx_z3 accessor [ head ] in
(* The fresh variable should be substituted by a projection into the enum in the body, we (* The fresh variable should be substituted by a projection into the
add this to the context *) enum in the body, we add this to the context *)
let ctx = add_z3matchsubst fresh_v proj ctx in let ctx = add_z3matchsubst fresh_v proj ctx in
let body = Bindlib.msubst (Pos.unmark e) [| fresh_e |] in let body = Bindlib.msubst (Pos.unmark e) [| fresh_e |] in
@ -496,7 +624,8 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
| EVar v -> ( | EVar v -> (
match VarMap.find_opt (Pos.unmark v) ctx.ctx_z3matchsubsts with match VarMap.find_opt (Pos.unmark v) ctx.ctx_z3matchsubsts with
| None -> | None ->
(* We are in the standard case, where this is a true Catala variable *) (* We are in the standard case, where this is a true Catala
variable *)
let v = Pos.unmark v in let v = Pos.unmark v in
let t = VarMap.find v ctx.ctx_var in let t = VarMap.find v ctx.ctx_var in
let name = unique_name v in let name = unique_name v in
@ -504,20 +633,23 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
let ctx, ty = translate_typ ctx (Pos.unmark t) in let ctx, ty = translate_typ ctx (Pos.unmark t) in
(ctx, Expr.mk_const_s ctx.ctx_z3 name ty) (ctx, Expr.mk_const_s ctx.ctx_z3 name ty)
| Some e -> | Some e ->
(* This variable is a temporary variable generated during VC translation of a match. It (* This variable is a temporary variable generated during VC
actually corresponds to applying an accessor to an enum, the corresponding Z3 translation of a match. It actually corresponds to applying an
expression was previously stored in the context *) accessor to an enum, the corresponding Z3 expression was previously
stored in the context *)
(ctx, e)) (ctx, e))
| ETuple _ -> failwith "[Z3 encoding] ETuple unsupported" | ETuple _ -> failwith "[Z3 encoding] ETuple unsupported"
| ETupleAccess (s, idx, oname, _tys) -> | ETupleAccess (s, idx, oname, _tys) ->
let name = let name =
match oname with match oname with
| None -> failwith "[Z3 encoding]: ETupleAccess of unnamed struct unsupported" | None ->
failwith "[Z3 encoding]: ETupleAccess of unnamed struct unsupported"
| Some n -> n | Some n -> n
in in
let ctx, z3_struct = find_or_create_struct ctx name in let ctx, z3_struct = find_or_create_struct ctx name in
(* This datatype should have only one constructor, corresponding to mk_struct. The accessors (* This datatype should have only one constructor, corresponding to
of this constructor correspond to the field accesses *) mk_struct. The accessors of this constructor correspond to the field
accesses *)
let accessors = List.hd (Datatype.get_accessors z3_struct) in let accessors = List.hd (Datatype.get_accessors z3_struct) in
let accessor = List.nth accessors idx in let accessor = List.nth accessors idx in
let ctx, s = translate_expr ctx s in let ctx, s = translate_expr ctx s in
@ -527,7 +659,9 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
let ctx, z3_enum = find_or_create_enum ctx enum in let ctx, z3_enum = find_or_create_enum ctx enum in
let ctx, z3_arg = translate_expr ctx arg in let ctx, z3_arg = translate_expr ctx arg in
let _ctx, z3_arms = let _ctx, z3_arms =
List.fold_left_map (translate_match_arm z3_arg) ctx List.fold_left_map
(translate_match_arm z3_arg)
ctx
(List.combine arms (Datatype.get_accessors z3_enum)) (List.combine arms (Datatype.get_accessors z3_enum))
in in
let z3_arms = let z3_arms =
@ -548,8 +682,8 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
| EOp op -> translate_op ctx op args | EOp op -> translate_op ctx op args
| EVar v -> | EVar v ->
let ctx, fd = find_or_create_funcdecl ctx (Pos.unmark v) in let ctx, fd = find_or_create_funcdecl ctx (Pos.unmark v) in
(* Fold_right to preserve the order of the arguments: The head argument is appended at the (* Fold_right to preserve the order of the arguments: The head
head *) argument is appended at the head *)
let ctx, z3_args = let ctx, z3_args =
List.fold_right List.fold_right
(fun arg (ctx, acc) -> (fun arg (ctx, acc) ->
@ -560,8 +694,8 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
(ctx, Expr.mk_app ctx.ctx_z3 fd z3_args) (ctx, Expr.mk_app ctx.ctx_z3 fd z3_args)
| _ -> | _ ->
failwith failwith
"[Z3 encoding] EApp node: Catala function calls should only include operators or \ "[Z3 encoding] EApp node: Catala function calls should only \
function names") include operators or function names")
| EAssert _ -> failwith "[Z3 encoding] EAssert unsupported" | EAssert _ -> failwith "[Z3 encoding] EAssert unsupported"
| EOp _ -> failwith "[Z3 encoding] EOp unsupported" | EOp _ -> failwith "[Z3 encoding] EOp unsupported"
| EDefault _ -> failwith "[Z3 encoding] EDefault unsupported" | EDefault _ -> failwith "[Z3 encoding] EDefault unsupported"
@ -574,12 +708,15 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr
Boolean.mk_and ctx.ctx_z3 Boolean.mk_and ctx.ctx_z3
[ [
Boolean.mk_implies ctx.ctx_z3 z3_if z3_then; Boolean.mk_implies ctx.ctx_z3 z3_if z3_then;
Boolean.mk_implies ctx.ctx_z3 (Boolean.mk_not ctx.ctx_z3 z3_if) z3_else; Boolean.mk_implies ctx.ctx_z3
(Boolean.mk_not ctx.ctx_z3 z3_if)
z3_else;
] ) ] )
| ErrorOnEmpty _ -> failwith "[Z3 encoding] ErrorOnEmpty unsupported" | ErrorOnEmpty _ -> failwith "[Z3 encoding] ErrorOnEmpty unsupported"
(** [create_z3unit] creates a Z3 sort and expression corresponding to the unit type and value (** [create_z3unit] creates a Z3 sort and expression corresponding to the unit
respectively. Concretely, we represent unit as a tuple with 0 elements **) type and value respectively. Concretely, we represent unit as a tuple with 0
elements **)
let create_z3unit (ctx : Z3.context) : Z3.context * (Sort.sort * Expr.expr) = let create_z3unit (ctx : Z3.context) : Z3.context * (Sort.sort * Expr.expr) =
let unit_sort = Tuple.mk_sort ctx (Symbol.mk_string ctx "unit") [] [] in let unit_sort = Tuple.mk_sort ctx (Symbol.mk_string ctx "unit") [] [] in
let mk_unit = Tuple.get_mk_decl unit_sort in let mk_unit = Tuple.get_mk_decl unit_sort in
@ -588,16 +725,15 @@ let create_z3unit (ctx : Z3.context) : Z3.context * (Sort.sort * Expr.expr) =
module Backend = struct module Backend = struct
type backend_context = context type backend_context = context
type vc_encoding = Z3.Expr.expr type vc_encoding = Z3.Expr.expr
let print_encoding (vc : vc_encoding) : string = Expr.to_string vc let print_encoding (vc : vc_encoding) : string = Expr.to_string vc
type model = Z3.Model.model type model = Z3.Model.model
type solver_result = ProvenTrue | ProvenFalse of model option | Unknown type solver_result = ProvenTrue | ProvenFalse of model option | Unknown
let solve_vc_encoding (ctx : backend_context) (encoding : vc_encoding) : solver_result = let solve_vc_encoding (ctx : backend_context) (encoding : vc_encoding) :
solver_result =
let solver = Z3.Solver.mk_solver ctx.ctx_z3 None in let solver = Z3.Solver.mk_solver ctx.ctx_z3 None in
Z3.Solver.add solver [ Boolean.mk_not ctx.ctx_z3 encoding ]; Z3.Solver.add solver [ Boolean.mk_not ctx.ctx_z3 encoding ];
match Z3.Solver.check solver [] with match Z3.Solver.check solver [] with
@ -605,18 +741,23 @@ module Backend = struct
| SATISFIABLE -> ProvenFalse (Z3.Solver.get_model solver) | SATISFIABLE -> ProvenFalse (Z3.Solver.get_model solver)
| UNKNOWN -> Unknown | UNKNOWN -> Unknown
let print_model (ctx : backend_context) (m : model) : string = print_model ctx m let print_model (ctx : backend_context) (m : model) : string =
print_model ctx m
let is_model_empty (m : model) : bool = List.length (Z3.Model.get_decls m) = 0 let is_model_empty (m : model) : bool = List.length (Z3.Model.get_decls m) = 0
let translate_expr (ctx : backend_context) (e : Dcalc.Ast.expr Pos.marked) = translate_expr ctx e let translate_expr (ctx : backend_context) (e : Dcalc.Ast.expr Pos.marked) =
translate_expr ctx e
let init_backend () = Cli.debug_print "Running Z3 version %s" Version.to_string let init_backend () =
Cli.debug_print "Running Z3 version %s" Version.to_string
let make_context (decl_ctx : decl_ctx) (free_vars_typ : typ Pos.marked VarMap.t) : backend_context let make_context
= (decl_ctx : decl_ctx) (free_vars_typ : typ Pos.marked VarMap.t) :
backend_context =
let cfg = let cfg =
(if !Cli.disable_counterexamples then [] else [ ("model", "true") ]) @ [ ("proof", "false") ] (if !Cli.disable_counterexamples then [] else [ ("model", "true") ])
@ [ ("proof", "false") ]
in in
let z3_ctx = mk_context cfg in let z3_ctx = mk_context cfg in
let z3_ctx, z3unit = create_z3unit z3_ctx in let z3_ctx, z3unit = create_z3unit z3_ctx in

View File

@ -1,15 +1,18 @@
(* This file is part of the Catala compiler, a specification language for tax and social benefits (* This file is part of the Catala compiler, a specification language for tax
computation rules. Copyright (C) 2022 Inria, contributor: Aymeric Fromherz and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
<aymeric.fromherz@inria.fr>, Denis Merigoux <denis.merigoux@inria.fr> Aymeric Fromherz <aymeric.fromherz@inria.fr>, Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
(** Interfacing with the Z3 SMT solver *) (** Interfacing with the Z3 SMT solver *)

View File

@ -1,24 +1,31 @@
(* This file is part of the French law library, a collection of functions for computing French taxes (* This file is part of the French law library, a collection of functions for
and benefits derived from Catala programs. Copyright (C) 2021 Inria, contributor: Denis Merigoux computing French taxes and benefits derived from Catala programs. Copyright
<denis.merigoux@inria.fr> (C) 2021 Inria, contributor: Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module Allocations_familiales = Law_source.Allocations_familiales module Allocations_familiales = Law_source.Allocations_familiales
module AF = Allocations_familiales module AF = Allocations_familiales
open Runtime open Runtime
let compute_allocations_familiales ~(current_date : Runtime.date) let compute_allocations_familiales
~(children : AF.enfant_entree array) ~(income : int) ~(residence : AF.collectivite) ~(current_date : Runtime.date)
~(is_parent : bool) ~(fills_title_I : bool) ~(had_rights_open_before_2012 : bool) : float = ~(children : AF.enfant_entree array)
~(income : int)
~(residence : AF.collectivite)
~(is_parent : bool)
~(fills_title_I : bool)
~(had_rights_open_before_2012 : bool) : float =
let result = let result =
AF.interface_allocations_familiales AF.interface_allocations_familiales
{ {
@ -27,8 +34,10 @@ let compute_allocations_familiales ~(current_date : Runtime.date)
AF.i_ressources_menage_in = money_of_units_int income; AF.i_ressources_menage_in = money_of_units_int income;
AF.i_residence_in = residence; AF.i_residence_in = residence;
AF.i_personne_charge_effective_permanente_est_parent_in = is_parent; AF.i_personne_charge_effective_permanente_est_parent_in = is_parent;
AF.i_personne_charge_effective_permanente_remplit_titre_I_in = fills_title_I; AF.i_personne_charge_effective_permanente_remplit_titre_I_in =
AF.i_avait_enfant_a_charge_avant_1er_janvier_2012_in = had_rights_open_before_2012; fills_title_I;
AF.i_avait_enfant_a_charge_avant_1er_janvier_2012_in =
had_rights_open_before_2012;
} }
in in
money_to_float result.AF.i_montant_verse_out money_to_float result.AF.i_montant_verse_out

View File

@ -1,15 +1,17 @@
(* This file is part of the French law library, a collection of functions for computing French taxes (* This file is part of the French law library, a collection of functions for
and benefits derived from Catala programs. Copyright (C) 2021 Inria, contributor: Denis Merigoux computing French taxes and benefits derived from Catala programs. Copyright
<denis.merigoux@inria.fr> (C) 2021 Inria, contributor: Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module Allocations_familiales = Law_source.Allocations_familiales module Allocations_familiales = Law_source.Allocations_familiales

View File

@ -1,15 +1,17 @@
(* This file is part of the French law library, a collection of functions for computing French taxes (* This file is part of the French law library, a collection of functions for
and benefits derived from Catala programs. Copyright (C) 2021 Inria, contributor: Denis Merigoux computing French taxes and benefits derived from Catala programs. Copyright
<denis.merigoux@inria.fr> (C) 2021 Inria, contributor: Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module Allocations_familiales = Law_source.Allocations_familiales module Allocations_familiales = Law_source.Allocations_familiales
@ -20,13 +22,9 @@ open Js_of_ocaml
class type enfant_entree = class type enfant_entree =
object object
method id : int Js.readonly_prop method id : int Js.readonly_prop
method remunerationMensuelle : int Js.readonly_prop method remunerationMensuelle : int Js.readonly_prop
method dateNaissance : Js.date Js.t Js.readonly_prop method dateNaissance : Js.date Js.t Js.readonly_prop
method gardeAlternee : bool Js.t Js.readonly_prop method gardeAlternee : bool Js.t Js.readonly_prop
method gardeAlterneePartageAllocation : bool Js.t Js.readonly_prop method gardeAlterneePartageAllocation : bool Js.t Js.readonly_prop
method priseEnCharge : Js.js_string Js.t Js.readonly_prop method priseEnCharge : Js.js_string Js.t Js.readonly_prop
@ -44,42 +42,33 @@ class type enfant_entree =
class type allocations_familiales_input = class type allocations_familiales_input =
object object
method currentDate : Js.date Js.t Js.readonly_prop method currentDate : Js.date Js.t Js.readonly_prop
method children : enfant_entree Js.t Js.js_array Js.t Js.readonly_prop method children : enfant_entree Js.t Js.js_array Js.t Js.readonly_prop
method income : int Js.readonly_prop method income : int Js.readonly_prop
method residence : Js.js_string Js.t Js.readonly_prop method residence : Js.js_string Js.t Js.readonly_prop
method personneQuiAssumeLaChargeEffectivePermanenteEstParent : bool Js.t Js.readonly_prop method personneQuiAssumeLaChargeEffectivePermanenteEstParent :
bool Js.t Js.readonly_prop
method personneQuiAssumeLaChargeEffectivePermanenteRemplitConditionsTitreISecuriteSociale : method
personneQuiAssumeLaChargeEffectivePermanenteRemplitConditionsTitreISecuriteSociale :
bool Js.t Js.readonly_prop bool Js.t Js.readonly_prop
end end
class type source_position = class type source_position =
object object
method fileName : Js.js_string Js.t Js.prop method fileName : Js.js_string Js.t Js.prop
method startLine : int Js.prop method startLine : int Js.prop
method endLine : int Js.prop method endLine : int Js.prop
method startColumn : int Js.prop method startColumn : int Js.prop
method endColumn : int Js.prop method endColumn : int Js.prop
method lawHeadings : Js.js_string Js.t Js.js_array Js.t Js.prop method lawHeadings : Js.js_string Js.t Js.js_array Js.t Js.prop
end end
class type log_event = class type log_event =
object object
method eventType : Js.js_string Js.t Js.prop method eventType : Js.js_string Js.t Js.prop
method information : Js.js_string Js.t Js.js_array Js.t Js.prop method information : Js.js_string Js.t Js.js_array Js.t Js.prop
method sourcePosition : source_position Js.t Js.optdef Js.prop method sourcePosition : source_position Js.t Js.optdef Js.prop
method loggedValue : Js.Unsafe.any Js.prop method loggedValue : Js.Unsafe.any Js.prop
end end
@ -102,13 +91,17 @@ let rec embed_to_js (v : runtime_value) : Js.Unsafe.any =
Js.Unsafe.inject date Js.Unsafe.inject date
| Duration d -> | Duration d ->
let days, months, years = duration_to_years_months_days d in let days, months, years = duration_to_years_months_days d in
Js.Unsafe.inject (Js.string (Printf.sprintf "%dD%dM%dY" days months years)) Js.Unsafe.inject
(Js.string (Printf.sprintf "%dD%dM%dY" days months years))
| Struct (name, fields) -> | Struct (name, fields) ->
Js.Unsafe.inject Js.Unsafe.inject
(object%js (object%js
val mutable structName = val mutable structName =
if List.length name = 1 then Js.Unsafe.inject (Js.string (List.hd name)) if List.length name = 1 then
else Js.Unsafe.inject (Js.array (Array.of_list (List.map Js.string name))) Js.Unsafe.inject (Js.string (List.hd name))
else
Js.Unsafe.inject
(Js.array (Array.of_list (List.map Js.string name)))
val mutable structFields = val mutable structFields =
Js.Unsafe.inject Js.Unsafe.inject
@ -117,9 +110,11 @@ let rec embed_to_js (v : runtime_value) : Js.Unsafe.any =
(List.map (List.map
(fun (name, v) -> (fun (name, v) ->
object%js object%js
val mutable fieldName = Js.Unsafe.inject (Js.string name) val mutable fieldName =
Js.Unsafe.inject (Js.string name)
val mutable fieldValue = Js.Unsafe.inject (embed_to_js v) val mutable fieldValue =
Js.Unsafe.inject (embed_to_js v)
end) end)
fields))) fields)))
end) end)
@ -127,11 +122,13 @@ let rec embed_to_js (v : runtime_value) : Js.Unsafe.any =
Js.Unsafe.inject Js.Unsafe.inject
(object%js (object%js
val mutable enumName = val mutable enumName =
if List.length name = 1 then Js.Unsafe.inject (Js.string (List.hd name)) if List.length name = 1 then
else Js.Unsafe.inject (Js.array (Array.of_list (List.map Js.string name))) Js.Unsafe.inject (Js.string (List.hd name))
else
Js.Unsafe.inject
(Js.array (Array.of_list (List.map Js.string name)))
val mutable enumCase = Js.Unsafe.inject (Js.string case) val mutable enumCase = Js.Unsafe.inject (Js.string case)
val mutable enumPayload = Js.Unsafe.inject (embed_to_js v) val mutable enumPayload = Js.Unsafe.inject (embed_to_js v)
end) end)
| Array vs -> Js.Unsafe.inject (Js.array (Array.map embed_to_js vs)) | Array vs -> Js.Unsafe.inject (Js.array (Array.map embed_to_js vs))
@ -142,7 +139,8 @@ let _ =
(object%js (object%js
method resetLog : (unit -> unit) Js.callback = Js.wrap_callback reset_log method resetLog : (unit -> unit) Js.callback = Js.wrap_callback reset_log
method retrieveLog : (unit -> log_event Js.t Js.js_array Js.t) Js.callback = method retrieveLog
: (unit -> log_event Js.t Js.js_array Js.t) Js.callback =
Js.wrap_callback (fun () -> Js.wrap_callback (fun () ->
Js.array Js.array
(Array.of_list (Array.of_list
@ -161,7 +159,9 @@ let _ =
Js.array Js.array
(Array.of_list (Array.of_list
(match evt with (match evt with
| BeginCall info | EndCall info | VariableDefinition (info, _) -> | BeginCall info
| EndCall info
| VariableDefinition (info, _) ->
List.map Js.string info List.map Js.string info
| DecisionTaken _ -> [])) | DecisionTaken _ -> []))
@ -176,18 +176,18 @@ let _ =
| DecisionTaken pos -> | DecisionTaken pos ->
Js.def Js.def
(object%js (object%js
val mutable fileName = Js.string pos.filename val mutable fileName =
Js.string pos.filename
val mutable startLine = pos.start_line val mutable startLine = pos.start_line
val mutable endLine = pos.end_line val mutable endLine = pos.end_line
val mutable startColumn = pos.start_column val mutable startColumn = pos.start_column
val mutable endColumn = pos.end_column val mutable endColumn = pos.end_column
val mutable lawHeadings = val mutable lawHeadings =
Js.array (Array.of_list (List.map Js.string pos.law_headings)) Js.array
(Array.of_list
(List.map Js.string pos.law_headings))
end) end)
| _ -> Js.undefined | _ -> Js.undefined
end) end)
@ -200,7 +200,8 @@ let _ =
AF.interface_allocations_familiales AF.interface_allocations_familiales
{ {
AF.i_personne_charge_effective_permanente_est_parent_in = AF.i_personne_charge_effective_permanente_est_parent_in =
Js.to_bool input##.personneQuiAssumeLaChargeEffectivePermanenteEstParent; Js.to_bool
input##.personneQuiAssumeLaChargeEffectivePermanenteEstParent;
AF.i_personne_charge_effective_permanente_remplit_titre_I_in = AF.i_personne_charge_effective_permanente_remplit_titre_I_in =
Js.to_bool Js.to_bool
input##.personneQuiAssumeLaChargeEffectivePermanenteRemplitConditionsTitreISecuriteSociale; input##.personneQuiAssumeLaChargeEffectivePermanenteRemplitConditionsTitreISecuriteSociale;
@ -214,7 +215,8 @@ let _ =
(fun (child : enfant_entree Js.t) -> (fun (child : enfant_entree Js.t) ->
{ {
AF.d_a_deja_ouvert_droit_aux_allocations_familiales = AF.d_a_deja_ouvert_droit_aux_allocations_familiales =
Js.to_bool child##.aDejaOuvertDroitAuxAllocationsFamiliales; Js.to_bool
child##.aDejaOuvertDroitAuxAllocationsFamiliales;
AF.d_identifiant = integer_of_int child##.id; AF.d_identifiant = integer_of_int child##.id;
AF.d_date_de_naissance = AF.d_date_de_naissance =
date_of_numbers date_of_numbers
@ -223,16 +225,19 @@ let _ =
child##.dateNaissance##getUTCDate; child##.dateNaissance##getUTCDate;
AF.d_prise_en_charge = AF.d_prise_en_charge =
(match Js.to_string child##.priseEnCharge with (match Js.to_string child##.priseEnCharge with
| "Effective et permanente" -> EffectiveEtPermanente () | "Effective et permanente" ->
EffectiveEtPermanente ()
| "Garde alternée, allocataire unique" -> | "Garde alternée, allocataire unique" ->
GardeAlterneeAllocataireUnique () GardeAlterneeAllocataireUnique ()
| "Garde alternée, partage des allocations" -> | "Garde alternée, partage des allocations" ->
GardeAlterneePartageAllocations () GardeAlterneePartageAllocations ()
| "Confié aux service sociaux, allocation versée à la famille" -> | "Confié aux service sociaux, allocation versée \
à la famille" ->
ServicesSociauxAllocationVerseeALaFamille () ServicesSociauxAllocationVerseeALaFamille ()
| "Confié aux service sociaux, allocation versée aux services sociaux" | "Confié aux service sociaux, allocation versée \
-> aux services sociaux" ->
ServicesSociauxAllocationVerseeAuxServicesSociaux () ServicesSociauxAllocationVerseeAuxServicesSociaux
()
| _ -> failwith "Unknown prise en charge"); | _ -> failwith "Unknown prise en charge");
AF.d_remuneration_mensuelle = AF.d_remuneration_mensuelle =
money_of_units_int child##.remunerationMensuelle; money_of_units_int child##.remunerationMensuelle;

View File

@ -1,15 +1,17 @@
(* This file is part of the French law library, a collection of functions for computing French taxes (* This file is part of the French law library, a collection of functions for
and benefits derived from Catala programs. Copyright (C) 2021 Inria, contributor: Denis Merigoux computing French taxes and benefits derived from Catala programs. Copyright
<denis.merigoux@inria.fr> (C) 2021 Inria, contributor: Denis Merigoux <denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except Licensed under the Apache License, Version 2.0 (the "License"); you may not
in compliance with the License. You may obtain a copy of the License at use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License Unless required by applicable law or agreed to in writing, software
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
or implied. See the License for the specific language governing permissions and limitations under WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *) the License. *)
module AF = Api.Allocations_familiales module AF = Api.Allocations_familiales
@ -20,7 +22,10 @@ let random_children (id : int) =
AF.d_identifiant = integer_of_int id; AF.d_identifiant = integer_of_int id;
d_remuneration_mensuelle = money_of_units_int (Random.int 2000); d_remuneration_mensuelle = money_of_units_int (Random.int 2000);
d_date_de_naissance = d_date_de_naissance =
date_of_numbers (2020 - Random.int 22) (1 + Random.int 12) (1 + Random.int 28); date_of_numbers
(2020 - Random.int 22)
(1 + Random.int 12)
(1 + Random.int 28);
d_prise_en_charge = d_prise_en_charge =
(match Random.int 5 with (match Random.int 5 with
| 0 -> AF.EffectiveEtPermanente () | 0 -> AF.EffectiveEtPermanente ()
@ -44,18 +49,21 @@ let format_residence (fmt : Format.formatter) (r : AF.collectivite) : unit =
| AF.SaintMartin _ -> "Saint Martin" | AF.SaintMartin _ -> "Saint Martin"
| AF.Mayotte _ -> "Mayotte") | AF.Mayotte _ -> "Mayotte")
let format_prise_en_charge (fmt : Format.formatter) (g : AF.prise_en_charge) : unit = let format_prise_en_charge (fmt : Format.formatter) (g : AF.prise_en_charge) :
unit =
Format.fprintf fmt "%s" Format.fprintf fmt "%s"
(match g with (match g with
| AF.EffectiveEtPermanente _ -> "Effective et permanente" | AF.EffectiveEtPermanente _ -> "Effective et permanente"
| AF.GardeAlterneePartageAllocations _ -> "Garde alternée, allocations partagée" | AF.GardeAlterneePartageAllocations _ ->
| AF.GardeAlterneeAllocataireUnique _ -> "Garde alternée, allocataire unique" "Garde alternée, allocations partagée"
| AF.ServicesSociauxAllocationVerseeALaFamille _ -> "Oui, allocations versée à la famille" | AF.GardeAlterneeAllocataireUnique _ ->
"Garde alternée, allocataire unique"
| AF.ServicesSociauxAllocationVerseeALaFamille _ ->
"Oui, allocations versée à la famille"
| AF.ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> | AF.ServicesSociauxAllocationVerseeAuxServicesSociaux _ ->
"Oui, allocations versée aux services sociaux") "Oui, allocations versée aux services sociaux")
let num_successful = ref 0 let num_successful = ref 0
let total_amount = ref 0. let total_amount = ref 0.
let run_test () = let run_test () =
@ -65,12 +73,16 @@ let run_test () =
let current_date = Runtime.date_of_numbers 2020 05 01 in let current_date = Runtime.date_of_numbers 2020 05 01 in
let residence = let residence =
let x = Random.int 2 in let x = Random.int 2 in
match x with 0 -> AF.Metropole () | 1 -> AF.Guadeloupe () | _ -> AF.Mayotte () match x with
| 0 -> AF.Metropole ()
| 1 -> AF.Guadeloupe ()
| _ -> AF.Mayotte ()
in in
try try
let amount = let amount =
Api.compute_allocations_familiales ~current_date ~income ~residence ~children ~is_parent:true Api.compute_allocations_familiales ~current_date ~income ~residence
~fills_title_I:true ~had_rights_open_before_2012:(Random.bool ()) ~children ~is_parent:true ~fills_title_I:true
~had_rights_open_before_2012:(Random.bool ())
in in
incr num_successful; incr num_successful;
total_amount := Float.add !total_amount amount total_amount := Float.add !total_amount amount
@ -82,7 +94,11 @@ let run_test () =
| ConflictError -> "Conflict error!" | ConflictError -> "Conflict error!"
| _ -> failwith "impossible") | _ -> failwith "impossible")
(Format.pp_print_list (fun fmt child -> (Format.pp_print_list (fun fmt child ->
Format.fprintf fmt "Child %d:\n income: %.2f\n birth date: %s\n prise en charge: %a" Format.fprintf fmt
"Child %d:\n\
\ income: %.2f\n\
\ birth date: %s\n\
\ prise en charge: %a"
(integer_to_int child.AF.d_identifiant) (integer_to_int child.AF.d_identifiant)
(money_to_float child.AF.d_remuneration_mensuelle) (money_to_float child.AF.d_remuneration_mensuelle)
(Runtime.date_to_string child.AF.d_date_de_naissance) (Runtime.date_to_string child.AF.d_date_de_naissance)
@ -97,11 +113,15 @@ let bench =
Random.init (int_of_float (Unix.time ())); Random.init (int_of_float (Unix.time ()));
let num_iter = 10000 in let num_iter = 10000 in
let _ = let _ =
Benchmark.latency1 ~style:Auto ~name:"Allocations familiales" (Int64.of_int num_iter) run_test Benchmark.latency1 ~style:Auto ~name:"Allocations familiales"
() (Int64.of_int num_iter) run_test ()
in in
Printf.printf "Successful computations: %d (%.2f%%)\nTotal benefits awarded: %.2f€ (mean %.2f€)\n" Printf.printf
"Successful computations: %d (%.2f%%)\n\
Total benefits awarded: %.2f (mean %.2f)\n"
!num_successful !num_successful
(Float.mul (Float.div (float_of_int !num_successful) (float_of_int num_iter)) 100.) (Float.mul
(Float.div (float_of_int !num_successful) (float_of_int num_iter))
100.)
!total_amount !total_amount
(Float.div !total_amount (float_of_int !num_successful)) (Float.div !total_amount (float_of_int !num_successful))

Some files were not shown because too many files have changed in this diff Show More