diff --git a/.ocamlformat b/.ocamlformat index 484d5d0d..e8450b9a 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,7 +1,11 @@ profile = default margin = 80 exp-grouping = preserve -break-fun-decl = smart +break-fun-decl = fit-or-vertical wrap-comments parse-docstrings -version=0.20.1 +version=0.21.0 +cases-exp-indent=2 +indicate-multiline-delimiters=no +parens-tuple=multi-line-only +space-around-lists=false diff --git a/build_system/clerk_driver.ml b/build_system/clerk_driver.ml index ce84e484..768cc80c 100644 --- a/build_system/clerk_driver.ml +++ b/build_system/clerk_driver.ml @@ -34,12 +34,12 @@ let command = & info [] ~docv:"COMMAND" ~doc:"Command selection among: test, run") let debug = - Arg.(value & flag & info [ "debug"; "d" ] ~doc:"Prints debug information") + Arg.(value & flag & info ["debug"; "d"] ~doc:"Prints debug information") let reset_test_outputs = Arg.( value & flag - & info [ "r"; "reset" ] + & info ["r"; "reset"] ~doc: "Used with the `test` command, resets the test output to whatever is \ output by the Catala compiler.") @@ -48,14 +48,14 @@ let catalac = Arg.( value & opt (some string) None - & info [ "e"; "exe" ] ~docv:"EXE" + & info ["e"; "exe"] ~docv:"EXE" ~doc:"Catala compiler executable, defaults to `catala`") let ninja_output = Arg.( value & opt (some string) None - & info [ "o"; "output" ] ~docv:"OUTPUT" + & info ["o"; "output"] ~docv:"OUTPUT" ~doc: "$(i, OUTPUT) is the file that will contain the build.ninja file \ output. If not specified, the build.ninja file will be outputed in \ @@ -65,7 +65,7 @@ let scope = Arg.( value & opt (some string) None - & info [ "s"; "scope" ] ~docv:"SCOPE" + & info ["s"; "scope"] ~docv:"SCOPE" ~doc: "Used with the `run` command, selects which scope of a given Catala \ file to run.") @@ -74,7 +74,7 @@ let makeflags = Arg.( value & opt (some string) None - & info [ "makeflags" ] ~docv:"LANG" + & info ["makeflags"] ~docv:"LANG" ~doc: "Provides the contents of a $(i, MAKEFLAGS) variable to pass on to \ Ninja. Currently recognizes the -i and -j options.") @@ -83,7 +83,7 @@ let catala_opts = Arg.( value & opt (some string) None - & info [ "c"; "catala-opts" ] ~docv:"LANG" + & info ["c"; "catala-opts"] ~docv:"LANG" ~doc:"Options to pass to the Catala compiler") let clerk_t f = @@ -134,7 +134,7 @@ let info = "Please file bug reports at https://github.com/CatalaLang/catala/issues"; ] in - let exits = Term.default_exits @ [ Term.exit_info ~doc:"on error." 1 ] in + let exits = Term.default_exits @ [Term.exit_info ~doc:"on error." 1] in Term.info "clerk" ~version ~doc ~exits ~man (**{1 Testing}*) @@ -173,19 +173,19 @@ let filename_to_expected_output_descr (output_dir : string) (filename : string) match backend with | None -> None | Some backend -> - let second_extension = Filename.extension filename in - let base_filename, scope = - if Re.Pcre.pmatch ~rex:catala_suffix_regex second_extension then - (filename, None) - else - let scope_name_regex = Re.Pcre.regexp "\\.(.+)" in - let scope_name = - try (Re.Pcre.extract ~rex:scope_name_regex second_extension).(1) - with Not_found -> "" - in - (Filename.remove_extension filename, Some scope_name) - in - Some { output_dir; complete_filename; base_filename; backend; scope } + let second_extension = Filename.extension filename in + let base_filename, scope = + if Re.Pcre.pmatch ~rex:catala_suffix_regex second_extension then + filename, None + else + let scope_name_regex = Re.Pcre.regexp "\\.(.+)" in + let scope_name = + try (Re.Pcre.extract ~rex:scope_name_regex second_extension).(1) + with Not_found -> "" + in + Filename.remove_extension filename, Some scope_name + in + Some { output_dir; complete_filename; base_filename; backend; scope } (** [readdir_sort dirname] returns the sorted subdirectories of [dirname] in an array or an empty array if the [dirname] doesn't exist. *) @@ -206,9 +206,9 @@ let search_for_expected_outputs (file : string) : expected_output_descr list = match filename_to_expected_output_descr output_dir output_file with | None -> None | Some expected_output -> - if expected_output.base_filename = Filename.basename file then - Some expected_output - else None) + if expected_output.base_filename = Filename.basename file then + Some expected_output + else None) (Array.to_list output_files) let add_reset_rules_aux @@ -235,7 +235,7 @@ let add_reset_rules_aux ~command: Nj.Expr.( Seq - ([ Lit catala_exe_opts; Lit "-s"; Var "scope" ] + ([Lit catala_exe_opts; Lit "-s"; Var "scope"] @ reset_common_cmd_exprs)) ~description: Nj.Expr.( @@ -281,7 +281,7 @@ let add_test_rules_aux ~command: Nj.Expr.( Seq - ([ Lit catala_exe_opts; Lit "-s"; Var "scope" ] + ([Lit catala_exe_opts; Lit "-s"; Var "scope"] @ test_common_cmd_exprs)) ~description: Nj.Expr.( @@ -346,8 +346,8 @@ let add_test_rules (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : (** [add_reset_with_ouput_rules catala_exe_opts rules] adds ninja rules used to reset test files using an output flag into [rules] and returns it.*) let add_reset_with_output_rules - (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : - Rule.t Nj.RuleMap.t = + (catala_exe_opts : string) + (rules : Rule.t Nj.RuleMap.t) : Rule.t Nj.RuleMap.t = add_reset_rules_aux ~with_scope_output_rule:"reset_with_scope_and_output" ~without_scope_output_rule:"reset_without_scope_and_output" ~redirect:"-o" catala_exe_opts rules @@ -355,8 +355,8 @@ let add_reset_with_output_rules (** [add_test_with_output_rules catala_exe_opts rules] adds ninja rules used to test files using an output flag into [rules] and returns it.*) let add_test_with_output_rules - (catala_exe_opts : string) (rules : Rule.t Nj.RuleMap.t) : - Rule.t Nj.RuleMap.t = + (catala_exe_opts : string) + (rules : Rule.t Nj.RuleMap.t) : Rule.t Nj.RuleMap.t = let test_common_cmd_exprs = Nj.Expr. [ @@ -382,10 +382,10 @@ let ninja_start (catala_exe : string) (catala_opts : string) : ninja = let catala_exe_opts = catala_exe ^ " " ^ catala_opts in let run_and_display_final_message = Nj.Rule.make "run_and_display_final_message" - ~command:Nj.Expr.(Seq [ Lit ":" ]) + ~command:Nj.Expr.(Seq [Lit ":"]) ~description: Nj.Expr.( - Seq [ Lit "All tests"; Var "test_file_or_folder"; Lit "passed!" ]) + Seq [Lit "All tests"; Var "test_file_or_folder"; Lit "passed!"]) in { rules = @@ -403,8 +403,9 @@ let ninja_start (catala_exe : string) (catala_opts : string) : ninja = creates and returns all ninja build statements needed to test the [tested_file]. *) let collect_all_ninja_build - (ninja : ninja) (tested_file : string) (reset_test_outputs : bool) : - (string * ninja) option = + (ninja : ninja) + (tested_file : string) + (reset_test_outputs : bool) : (string * ninja) option = let expected_outputs = search_for_expected_outputs tested_file in if List.length expected_outputs = 0 then ( Cli.debug_print "No expected outputs were found for test file %s" @@ -420,7 +421,7 @@ let collect_all_ninja_build Nj.Expr.Lit (Cli.catala_backend_option_to_string expected_output.backend) ); - ("tested_file", Nj.Expr.Lit tested_file); + "tested_file", Nj.Expr.Lit tested_file; ( "expected_output", Nj.Expr.Lit (expected_output.output_dir @@ -435,17 +436,17 @@ let collect_all_ninja_build string option -> string * string * (string * Nj.Expr.t) list = function | Some scope -> - ( Printf.sprintf "%s_%s_%s_%s" output_build_kind scope - catala_backend tested_file - |> Nj.Build.unpath, - output_build_kind ^ "_with_scope" ^ rule_postfix, - ("scope", Nj.Expr.Lit scope) :: vars ) + ( Printf.sprintf "%s_%s_%s_%s" output_build_kind scope + catala_backend tested_file + |> Nj.Build.unpath, + output_build_kind ^ "_with_scope" ^ rule_postfix, + ("scope", Nj.Expr.Lit scope) :: vars ) | None -> - ( Printf.sprintf "%s_%s_%s" output_build_kind catala_backend - tested_file - |> Nj.Build.unpath, - output_build_kind ^ "_without_scope" ^ rule_postfix, - vars ) + ( Printf.sprintf "%s_%s_%s" output_build_kind catala_backend + tested_file + |> Nj.Build.unpath, + output_build_kind ^ "_without_scope" ^ rule_postfix, + vars ) in let ninja_add_new_rule @@ -457,8 +458,7 @@ let collect_all_ninja_build ninja with builds = Nj.BuildMap.add rule_output - (Nj.Build.make_with_vars - ~outputs:[ Nj.Expr.Lit rule_output ] + (Nj.Build.make_with_vars ~outputs:[Nj.Expr.Lit rule_output] ~rule ~vars) ninja.builds; } @@ -467,33 +467,30 @@ let collect_all_ninja_build match expected_output.backend with | Cli.Interpret | Cli.Proof | Cli.Typecheck | Cli.Dcalc | Cli.Scopelang | Cli.Scalc | Cli.Lcalc -> - let rule_output, rule_name, rule_vars = - get_rule_infos expected_output.scope - in - let rule_vars = - match expected_output.backend with - | Cli.Proof -> - ("extra_flags", Nj.Expr.Lit "--disable_counterexamples") - :: rule_vars - (* Counterexamples can be different at each call because of - the randomness inside SMT solver, so we can't expect - their value to remain constant. Hence we disable the - counterexamples when testing the replication of failed - proofs. *) - | _ -> rule_vars - in - ( ninja_add_new_rule rule_output rule_name rule_vars ninja, - test_names ^ " $\n " ^ rule_output ) + let rule_output, rule_name, rule_vars = + get_rule_infos expected_output.scope + in + let rule_vars = + match expected_output.backend with + | Cli.Proof -> + ("extra_flags", Nj.Expr.Lit "--disable_counterexamples") + :: rule_vars + (* Counterexamples can be different at each call because of the + randomness inside SMT solver, so we can't expect their value + to remain constant. Hence we disable the counterexamples when + testing the replication of failed proofs. *) + | _ -> rule_vars + in + ( ninja_add_new_rule rule_output rule_name rule_vars ninja, + test_names ^ " $\n " ^ rule_output ) | Cli.Python | Cli.OCaml | Cli.Latex | Cli.Html | Cli.Makefile -> - let tmp_file = - Filename.temp_file "clerk_" ("_" ^ catala_backend) - in - let rule_output, rule_name, rule_vars = - get_rule_infos ~rule_postfix:"_and_output" expected_output.scope - in - let rule_vars = ("tmp_file", Nj.Expr.Lit tmp_file) :: rule_vars in - ( ninja_add_new_rule rule_output rule_name rule_vars ninja, - test_names ^ " $\n " ^ rule_output )) + let tmp_file = Filename.temp_file "clerk_" ("_" ^ catala_backend) in + let rule_output, rule_name, rule_vars = + get_rule_infos ~rule_postfix:"_and_output" expected_output.scope + in + let rule_vars = ("tmp_file", Nj.Expr.Lit tmp_file) :: rule_vars in + ( ninja_add_new_rule rule_output rule_name rule_vars ninja, + test_names ^ " $\n " ^ rule_output )) (ninja, "") expected_outputs in let test_name = @@ -508,8 +505,8 @@ let collect_all_ninja_build ninja with builds = Nj.BuildMap.add test_name - (Nj.Build.make_with_inputs ~outputs:[ Nj.Expr.Lit test_name ] - ~rule:"phony" ~inputs:[ Nj.Expr.Lit test_names ]) + (Nj.Build.make_with_inputs ~outputs:[Nj.Expr.Lit test_name] + ~rule:"phony" ~inputs:[Nj.Expr.Lit test_names]) ninja.builds; } ) @@ -517,8 +514,9 @@ let collect_all_ninja_build ninja build declaration calling the rule 'run_and_display_final_message' for [all_test_builds] which correspond to [all_file_names]. *) let add_root_test_build - (ninja : ninja) (all_file_names : string list) (all_test_builds : string) : - ninja = + (ninja : ninja) + (all_file_names : string list) + (all_test_builds : string) : ninja = let file_names_str = List.hd all_file_names ^ "" ^ List.fold_left @@ -529,9 +527,9 @@ let add_root_test_build ninja with builds = Nj.BuildMap.add "test" - (Nj.Build.make_with_vars_and_inputs ~outputs:[ Nj.Expr.Lit "test" ] + (Nj.Build.make_with_vars_and_inputs ~outputs:[Nj.Expr.Lit "test"] ~rule:"run_and_display_final_message" - ~inputs:[ Nj.Expr.Lit all_test_builds ] + ~inputs:[Nj.Expr.Lit all_test_builds] ~vars: [ ( "test_file_or_folder", @@ -551,7 +549,7 @@ let run_file String.concat " " (List.filter (fun s -> s <> "") - [ catala_exe; catala_opts; "-s " ^ scope; "Interpret"; file ]) + [catala_exe; catala_opts; "-s " ^ scope; "Interpret"; file]) in Cli.debug_print "Running: %s" command; Sys.command command @@ -561,20 +559,20 @@ let run_file let get_catala_files_in_folder (dir : string) : string list = let rec loop result = function | f :: fs -> - let f_is_dir = - try Sys.is_directory f - with Sys_error e -> - Cli.warning_print "skipping %s" e; - false - in - if f_is_dir then - readdir_sort f |> Array.to_list - |> List.map (Filename.concat f) - |> List.append fs |> loop result - else loop (f :: result) fs + let f_is_dir = + try Sys.is_directory f + with Sys_error e -> + Cli.warning_print "skipping %s" e; + false + in + if f_is_dir then + readdir_sort f |> Array.to_list + |> List.map (Filename.concat f) + |> List.append fs |> loop result + else loop (f :: result) fs | [] -> result in - let all_files_in_folder = loop [] [ dir ] in + let all_files_in_folder = loop [] [dir] in List.filter (Re.Pcre.pmatch ~rex:catala_suffix_regex) all_files_in_folder type ninja_building_context = { @@ -612,10 +610,10 @@ let collect_in_folder (fun (ninja, test_file_names) file -> match collect_all_ninja_build ninja file reset_test_outputs with | None -> - (* Skips none Catala file. *) - (ninja, test_file_names) + (* Skips none Catala file. *) + ninja, test_file_names | Some (test_file_name, ninja) -> - (ninja, test_file_names ^ " $\n " ^ test_file_name)) + ninja, test_file_names ^ " $\n " ^ test_file_name) (ninja_start, "") (get_catala_files_in_folder folder) in @@ -631,9 +629,9 @@ let collect_in_folder builds = Nj.BuildMap.add test_dir_name (Nj.Build.make_with_vars_and_inputs - ~outputs:[ Nj.Expr.Lit test_dir_name ] + ~outputs:[Nj.Expr.Lit test_dir_name] ~rule:"run_and_display_final_message" - ~inputs:[ Nj.Expr.Lit test_file_names ] + ~inputs:[Nj.Expr.Lit test_file_names] ~vars: [ ( "test_file_or_folder", @@ -668,20 +666,20 @@ let collect_in_file (reset_test_outputs : bool) : ninja_building_context = match collect_all_ninja_build ninja_start tested_file reset_test_outputs with | Some (test_file_name, ninja) -> - { - ctx with - last_valid_ninja = ninja; - curr_ninja = Some ninja; - all_file_names = tested_file :: ctx.all_file_names; - all_test_builds = ctx.all_test_builds ^ " $\n " ^ test_file_name; - } + { + ctx with + last_valid_ninja = ninja; + curr_ninja = Some ninja; + all_file_names = tested_file :: ctx.all_file_names; + all_test_builds = ctx.all_test_builds ^ " $\n " ^ test_file_name; + } | None -> - { - ctx with - last_valid_ninja = ninja_start; - curr_ninja = None; - all_failed_names = tested_file :: ctx.all_failed_names; - } + { + ctx with + last_valid_ninja = ninja_start; + curr_ninja = None; + all_failed_names = tested_file :: ctx.all_failed_names; + } (** {1 Return code values} *) @@ -714,18 +712,15 @@ let makeflags_to_ninja_flags (makeflags : string option) = match makeflags with | None -> "" | Some makeflags -> - let ignore_rex = Re.(compile @@ word (char 'i')) in - let has_ignore = Re.execp ignore_rex makeflags in - let jobs_rex = Re.(compile @@ seq [ str "-j"; group (rep digit) ]) in - let number_of_jobs = - try int_of_string (Re.Group.get (Re.exec jobs_rex makeflags) 1) - with _ -> 0 - in - String.concat " " - [ - (if has_ignore then "-k0" else ""); - "-j" ^ string_of_int number_of_jobs; - ] + let ignore_rex = Re.(compile @@ word (char 'i')) in + let has_ignore = Re.execp ignore_rex makeflags in + let jobs_rex = Re.(compile @@ seq [str "-j"; group (rep digit)]) in + let number_of_jobs = + try int_of_string (Re.Group.get (Re.exec jobs_rex makeflags) 1) + with _ -> 0 + in + String.concat " " + [(if has_ignore then "-k0" else ""); "-j" ^ string_of_int number_of_jobs] let driver (files_or_folders : string list) @@ -749,56 +744,56 @@ let driver in match String.lowercase_ascii command with | "test" -> ( - Cli.debug_print "building ninja rules..."; - let ctx = - add_test_builds - (ninja_building_context_init (ninja_start catala_exe catala_opts)) - files_or_folders reset_test_outputs - in - let there_is_some_fails = 0 <> List.length ctx.all_failed_names in - let ninja = - match ctx.curr_ninja with - | Some ninja -> ninja - | None -> ctx.last_valid_ninja - in - if there_is_some_fails then - List.iter - (fun f -> - f - |> Cli.with_style [ ANSITerminal.magenta ] "%s" - |> Cli.warning_print "No test case found for %s") - ctx.all_failed_names; - if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then - return_ok - else - try - let out = open_out ninja_output in - Cli.debug_print "writing %s..." ninja_output; - Nj.format - (Format.formatter_of_out_channel out) - (add_root_test_build ninja ctx.all_file_names ctx.all_test_builds); - close_out out; - let ninja_cmd = "ninja " ^ ninja_flags ^ " test -f " ^ ninja_output in - Cli.debug_print "executing '%s'..." ninja_cmd; - Sys.command ninja_cmd - with Sys_error e -> - Cli.error_print "can not write in %s" e; - return_err) + Cli.debug_print "building ninja rules..."; + let ctx = + add_test_builds + (ninja_building_context_init (ninja_start catala_exe catala_opts)) + files_or_folders reset_test_outputs + in + let there_is_some_fails = 0 <> List.length ctx.all_failed_names in + let ninja = + match ctx.curr_ninja with + | Some ninja -> ninja + | None -> ctx.last_valid_ninja + in + if there_is_some_fails then + List.iter + (fun f -> + f + |> Cli.with_style [ANSITerminal.magenta] "%s" + |> Cli.warning_print "No test case found for %s") + ctx.all_failed_names; + if 0 = List.compare_lengths ctx.all_failed_names files_or_folders then + return_ok + else + try + let out = open_out ninja_output in + Cli.debug_print "writing %s..." ninja_output; + Nj.format + (Format.formatter_of_out_channel out) + (add_root_test_build ninja ctx.all_file_names ctx.all_test_builds); + close_out out; + let ninja_cmd = "ninja " ^ ninja_flags ^ " test -f " ^ ninja_output in + Cli.debug_print "executing '%s'..." ninja_cmd; + Sys.command ninja_cmd + with Sys_error e -> + Cli.error_print "can not write in %s" e; + return_err) | "run" -> ( - match scope with - | Some scope -> - let res = - List.fold_left - (fun ret f -> ret + run_file f catala_exe catala_opts scope) - 0 files_or_folders - in - if 0 <> res then return_err else return_ok - | None -> - Cli.error_print "Please provide a scope to run with the -s option"; - 1) + match scope with + | Some scope -> + let res = + List.fold_left + (fun ret f -> ret + run_file f catala_exe catala_opts scope) + 0 files_or_folders + in + if 0 <> res then return_err else return_ok + | None -> + Cli.error_print "Please provide a scope to run with the -s option"; + 1) | _ -> - Cli.error_print "The command \"%s\" is unknown to clerk." command; - 1 + Cli.error_print "The command \"%s\" is unknown to clerk." command; + 1 let main () = match Cmdliner.Term.eval (clerk_t driver, info) with diff --git a/build_system/ninja_utils.ml b/build_system/ninja_utils.ml index 83d2f365..d93958fa 100644 --- a/build_system/ninja_utils.ml +++ b/build_system/ninja_utils.ml @@ -24,10 +24,10 @@ module Expr = struct and format_list fmt = function | hd :: tl -> - Format.fprintf fmt "%a%a" format hd - (fun fmt tl -> - tl |> List.iter (fun s -> Format.fprintf fmt " %a" format s)) - tl + Format.fprintf fmt "%a%a" format hd + (fun fmt tl -> + tl |> List.iter (fun s -> Format.fprintf fmt " %a" format s)) + tl | [] -> () end @@ -65,7 +65,7 @@ module Build = struct let make_with_vars_and_inputs ~outputs ~rule ~inputs ~vars = { outputs; rule; inputs = Option.some inputs; vars } - let empty = make ~outputs:[ Expr.Lit "empty" ] ~rule:"phony" + let empty = make ~outputs:[Expr.Lit "empty"] ~rule:"phony" let unpath ?(sep = "-") path = Re.Pcre.(substitute ~rex:(regexp "/") ~subst:(fun _ -> sep)) path diff --git a/build_system/tests/test_clerk_driver.ml b/build_system/tests/test_clerk_driver.ml index 4a3ec496..7bce44d2 100644 --- a/build_system/tests/test_clerk_driver.ml +++ b/build_system/tests/test_clerk_driver.ml @@ -25,7 +25,7 @@ let test_ninja_start () = let test_add_test_builds_for_folder () = let ctx = D.ninja_building_context_init ninja_start in let nj_building_ctx = - To_test.add_test_builds ctx [ test_files_dir ^ "folder" ] false + To_test.add_test_builds ctx [test_files_dir ^ "folder"] false in al_assert "a test case should be found" (Option.is_some nj_building_ctx.curr_ninja); @@ -47,7 +47,7 @@ let test_add_test_builds_for_folder () = let test_add_test_builds_for_untested_file () = let untested_file = test_files_dir ^ "untested_file.catala_en" in let ctx = D.ninja_building_context_init Nj.empty in - let nj_building_ctx = To_test.add_test_builds ctx [ untested_file ] false in + let nj_building_ctx = To_test.add_test_builds ctx [untested_file] false in al_assert "no test cases should be found" (Option.is_none nj_building_ctx.curr_ninja); @@ -61,7 +61,7 @@ let test_add_test_builds_for_simple_interpret_scope_file () = in let ctx = D.ninja_building_context_init ninja_start in let nj_building_ctx = - To_test.add_test_builds ctx [ simple_interpret_scope_file ] false + To_test.add_test_builds ctx [simple_interpret_scope_file] false in al_assert "a test case should be found" (Option.is_some nj_building_ctx.curr_ninja); @@ -76,13 +76,13 @@ let test_add_test_builds_for_simple_interpret_scope_file () = in let test_A_file = Build.make_with_vars - ~outputs:[ Expr.Lit test_A_file_output ] + ~outputs:[Expr.Lit test_A_file_output] ~rule:"test_with_scope" ~vars: [ - ("scope", Lit "A"); - ("catala_cmd", Lit "Interpret"); - ("tested_file", Lit simple_interpret_scope_file); + "scope", Lit "A"; + "catala_cmd", Lit "Interpret"; + "tested_file", Lit simple_interpret_scope_file; ( "expected_output", Lit (test_files_dir @@ -91,9 +91,9 @@ let test_add_test_builds_for_simple_interpret_scope_file () = in let test_file = Build.make_with_inputs - ~outputs:[ Expr.Lit test_file_output ] + ~outputs:[Expr.Lit test_file_output] ~rule:"phony" - ~inputs:[ Expr.Lit (" $\n " ^ test_A_file_output) ] + ~inputs:[Expr.Lit (" $\n " ^ test_A_file_output)] in BuildMap.empty |> BuildMap.add test_file_output test_file diff --git a/compiler/dcalc/ast.ml b/compiler/dcalc/ast.ml index 281096f7..c50958a9 100644 --- a/compiler/dcalc/ast.ml +++ b/compiler/dcalc/ast.ml @@ -163,15 +163,13 @@ and 'expr scopes = Nil | ScopeDef of 'expr scope_def type program = { decl_ctx : decl_ctx; scopes : expr scopes } let evar (v : expr Bindlib.var) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun v' -> (v', pos)) (Bindlib.box_var v) + Bindlib.box_apply (fun v' -> v', pos) (Bindlib.box_var v) let etuple (args : expr Pos.marked Bindlib.box list) (s : StructName.t option) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply - (fun args -> (ETuple (args, s), pos)) - (Bindlib.box_list args) + Bindlib.box_apply (fun args -> ETuple (args, s), pos) (Bindlib.box_list args) let etupleaccess (e1 : expr Pos.marked Bindlib.box) @@ -179,7 +177,7 @@ let etupleaccess (s : StructName.t option) (typs : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun e1 -> (ETupleAccess (e1, i, s, typs), pos)) e1 + Bindlib.box_apply (fun e1 -> ETupleAccess (e1, i, s, typs), pos) e1 let einj (e1 : expr Pos.marked Bindlib.box) @@ -187,7 +185,7 @@ let einj (e_name : EnumName.t) (typs : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun e1 -> (EInj (e1, i, e_name, typs), pos)) e1 + Bindlib.box_apply (fun e1 -> EInj (e1, i, e_name, typs), pos) e1 let ematch (arg : expr Pos.marked Bindlib.box) @@ -195,12 +193,12 @@ let ematch (e_name : EnumName.t) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply2 - (fun arg arms -> (EMatch (arg, arms, e_name), pos)) + (fun arg arms -> EMatch (arg, arms, e_name), pos) arg (Bindlib.box_list arms) let earray (args : expr Pos.marked Bindlib.box list) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun args -> (EArray args, pos)) (Bindlib.box_list args) + Bindlib.box_apply (fun args -> EArray args, pos) (Bindlib.box_list args) let elit (l : lit) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box (ELit l, pos) @@ -211,7 +209,7 @@ let eabs (typs : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply - (fun binder -> (EAbs ((binder, pos_binder), typs), pos)) + (fun binder -> EAbs ((binder, pos_binder), typs), pos) binder let eapp @@ -219,12 +217,12 @@ let eapp (args : expr Pos.marked Bindlib.box list) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply2 - (fun e1 args -> (EApp (e1, args), pos)) + (fun e1 args -> EApp (e1, args), pos) e1 (Bindlib.box_list args) let eassert (e1 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) e1 + Bindlib.box_apply (fun e1 -> EAssert e1, pos) e1 let eop (op : operator) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box (EOp op, pos) @@ -235,7 +233,7 @@ let edefault (cons : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply3 - (fun excepts just cons -> (EDefault (excepts, just, cons), pos)) + (fun excepts just cons -> EDefault (excepts, just, cons), pos) (Bindlib.box_list excepts) just cons let eifthenelse @@ -243,11 +241,11 @@ let eifthenelse (e2 : expr Pos.marked Bindlib.box) (e3 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply3 (fun e1 e2 e3 -> (EIfThenElse (e1, e2, e3), pos)) e1 e2 e3 + Bindlib.box_apply3 (fun e1 e2 e3 -> EIfThenElse (e1, e2, e3), pos) e1 e2 e3 let eerroronempty (e1 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) e1 + Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, pos) e1 let map_expr (ctx : 'a) @@ -256,30 +254,30 @@ let map_expr match Pos.unmark e with | EVar (v, _pos) -> evar v (Pos.get_position e) | EApp (e1, args) -> - eapp (f ctx e1) (List.map (f ctx) args) (Pos.get_position e) + eapp (f ctx e1) (List.map (f ctx) args) (Pos.get_position e) | EAbs ((binder, binder_pos), typs) -> - eabs - (Bindlib.box_mbinder (f ctx) binder) - binder_pos typs (Pos.get_position e) + eabs + (Bindlib.box_mbinder (f ctx) binder) + binder_pos typs (Pos.get_position e) | ETuple (args, s) -> etuple (List.map (f ctx) args) s (Pos.get_position e) | ETupleAccess (e1, n, s_name, typs) -> - etupleaccess ((f ctx) e1) n s_name typs (Pos.get_position e) + etupleaccess ((f ctx) e1) n s_name typs (Pos.get_position e) | EInj (e1, i, e_name, typs) -> - einj ((f ctx) e1) i e_name typs (Pos.get_position e) + einj ((f ctx) e1) i e_name typs (Pos.get_position e) | EMatch (arg, arms, e_name) -> - ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Pos.get_position e) + ematch ((f ctx) arg) (List.map (f ctx) arms) e_name (Pos.get_position e) | EArray args -> earray (List.map (f ctx) args) (Pos.get_position e) | ELit l -> elit l (Pos.get_position e) | EAssert e1 -> eassert ((f ctx) e1) (Pos.get_position e) | EOp op -> Bindlib.box (EOp op, Pos.get_position e) | EDefault (excepts, just, cons) -> - edefault - (List.map (f ctx) excepts) - ((f ctx) just) - ((f ctx) cons) - (Pos.get_position e) + edefault + (List.map (f ctx) excepts) + ((f ctx) just) + ((f ctx) cons) + (Pos.get_position e) | EIfThenElse (e1, e2, e3) -> - eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Pos.get_position e) + eifthenelse ((f ctx) e1) ((f ctx) e2) ((f ctx) e3) (Pos.get_position e) | ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) (Pos.get_position e) (** See [Bindlib.box_term] documentation for why we are doing that. *) @@ -296,8 +294,8 @@ let rec fold_left_scope_lets match scope_body_expr with | Result _ -> init | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - fold_left_scope_lets ~f ~init:(f init scope_let var) next + let var, next = Bindlib.unbind scope_let.scope_let_next in + fold_left_scope_lets ~f ~init:(f init scope_let var) next let rec fold_right_scope_lets ~(f : 'expr scope_let -> 'expr Bindlib.var -> 'a -> 'a) @@ -306,9 +304,9 @@ let rec fold_right_scope_lets match scope_body_expr with | Result result -> init result | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - let next_result = fold_right_scope_lets ~f ~init next in - f scope_let var next_result + let var, next = Bindlib.unbind scope_let.scope_let_next in + let next_result = fold_right_scope_lets ~f ~init next in + f scope_let var next_result let map_exprs_in_scope_lets ~(f : 'expr Pos.marked -> 'expr Pos.marked Bindlib.box) @@ -336,8 +334,8 @@ let rec fold_left_scope_defs match scopes with | Nil -> init | ScopeDef scope_def -> - let var, next = Bindlib.unbind scope_def.scope_next in - fold_left_scope_defs ~f ~init:(f init scope_def var) next + let var, next = Bindlib.unbind scope_def.scope_next in + fold_left_scope_defs ~f ~init:(f init scope_def var) next let rec fold_right_scope_defs ~(f : 'expr scope_def -> 'expr Bindlib.var -> 'a -> 'a) @@ -346,9 +344,9 @@ let rec fold_right_scope_defs match scopes with | Nil -> init | ScopeDef scope_def -> - let var_next, next = Bindlib.unbind scope_def.scope_next in - let result_next = fold_right_scope_defs ~f ~init next in - f scope_def var_next result_next + let var_next, next = Bindlib.unbind scope_def.scope_next in + let result_next = fold_right_scope_defs ~f ~init next in + f scope_def var_next result_next let map_scope_defs ~(f : 'expr scope_def -> 'expr scope_def Bindlib.box) @@ -406,34 +404,34 @@ let rec free_vars_expr (e : expr Pos.marked) : VarSet.t = match Pos.unmark e with | EVar (v, _) -> VarSet.singleton v | ETuple (es, _) | EArray es -> - es |> List.map free_vars_expr |> List.fold_left VarSet.union VarSet.empty + es |> List.map free_vars_expr |> List.fold_left VarSet.union VarSet.empty | ETupleAccess (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 | EInj (e1, _, _, _) -> - free_vars_expr e1 + free_vars_expr e1 | EApp (e1, es) | EMatch (e1, es, _) -> - e1 :: es |> List.map free_vars_expr - |> List.fold_left VarSet.union VarSet.empty + e1 :: es |> List.map free_vars_expr + |> List.fold_left VarSet.union VarSet.empty | EDefault (es, ejust, econs) -> - ejust :: econs :: es |> List.map free_vars_expr - |> List.fold_left VarSet.union VarSet.empty + ejust :: econs :: es |> List.map free_vars_expr + |> List.fold_left VarSet.union VarSet.empty | EOp _ | ELit _ -> VarSet.empty | EIfThenElse (e1, e2, e3) -> - [ e1; e2; e3 ] |> List.map free_vars_expr - |> List.fold_left VarSet.union VarSet.empty + [e1; e2; e3] |> List.map free_vars_expr + |> List.fold_left VarSet.union VarSet.empty | EAbs ((binder, _), _) -> - let vs, body = Bindlib.unmbind binder in - Array.fold_right VarSet.remove vs (free_vars_expr body) + let vs, body = Bindlib.unmbind binder in + Array.fold_right VarSet.remove vs (free_vars_expr body) let rec free_vars_scope_body_expr (scope_lets : expr scope_body_expr) : VarSet.t = match scope_lets with | Result e -> free_vars_expr e | ScopeLet { scope_let_expr = e; scope_let_next = next; _ } -> - let v, body = Bindlib.unbind next in - VarSet.union (free_vars_expr e) - (VarSet.remove v (free_vars_scope_body_expr body)) + let v, body = Bindlib.unbind next in + VarSet.union (free_vars_expr e) + (VarSet.remove v (free_vars_scope_body_expr body)) let free_vars_scope_body (scope_body : expr scope_body) : VarSet.t = let { scope_body_expr = binder; _ } = scope_body in @@ -444,15 +442,15 @@ let rec free_vars_scopes (scopes : expr scopes) : VarSet.t = match scopes with | Nil -> VarSet.empty | ScopeDef { scope_body = body; scope_next = next; _ } -> - let v, next = Bindlib.unbind next in - VarSet.union - (VarSet.remove v (free_vars_scopes next)) - (free_vars_scope_body body) + let v, next = Bindlib.unbind next in + VarSet.union + (VarSet.remove v (free_vars_scopes next)) + (free_vars_scope_body body) type vars = expr Bindlib.mvar let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x) + Bindlib.box_apply (fun x -> x, pos) (Bindlib.box_var x) let make_abs (xs : vars) @@ -461,14 +459,14 @@ let make_abs (taus : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply - (fun b -> (EAbs ((b, pos_binder), taus), pos)) + (fun b -> EAbs ((b, pos_binder), taus), pos) (Bindlib.bind_mvar xs e) let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) + Bindlib.box_apply2 (fun e u -> EApp (e, u), pos) e (Bindlib.box_list u) let make_let_in (x : Var.t) @@ -476,23 +474,22 @@ let make_let_in (e1 : expr Pos.marked Bindlib.box) (e2 : expr Pos.marked Bindlib.box) (pos : Pos.t) : expr Pos.marked Bindlib.box = - make_app (make_abs (Array.of_list [ x ]) e2 pos [ tau ] pos) [ e1 ] pos + make_app (make_abs (Array.of_list [x]) e2 pos [tau] pos) [e1] pos let empty_thunked_term : expr Pos.marked = let silent = Var.make ("_", Pos.no_pos) in Bindlib.unbox - (make_abs - (Array.of_list [ silent ]) + (make_abs (Array.of_list [silent]) (Bindlib.box (ELit LEmptyError, Pos.no_pos)) Pos.no_pos - [ (TLit TUnit, Pos.no_pos) ] + [TLit TUnit, Pos.no_pos] Pos.no_pos) let is_value (e : expr Pos.marked) : bool = match Pos.unmark e with ELit _ | EAbs _ | EOp _ -> true | _ -> false let rec equal_typs (ty1 : typ Pos.marked) (ty2 : typ Pos.marked) : bool = - match (Pos.unmark ty1, Pos.unmark ty2) with + match Pos.unmark ty1, Pos.unmark ty2 with | TLit l1, TLit l2 -> l1 = l2 | TTuple (tys1, n1), TTuple (tys2, n2) -> n1 = n2 && equal_typs_list tys1 tys2 | TEnum (tys1, n1), TEnum (tys2, n2) -> n1 = n2 && equal_typs_list tys1 tys2 @@ -509,12 +506,12 @@ and equal_typs_list (tys1 : typ Pos.marked list) (tys2 : typ Pos.marked list) : List.for_all (fun (x, y) -> equal_typs x y) (List.combine tys1 tys2) let equal_log_entries (l1 : log_entry) (l2 : log_entry) : bool = - match (l1, l2) with + match l1, l2 with | VarDef t1, VarDef t2 -> equal_typs (t1, Pos.no_pos) (t2, Pos.no_pos) | x, y -> x = y let equal_unops (op1 : unop) (op2 : unop) : bool = - match (op1, op2) with + match op1, op2 with (* Log entries contain a typ which contain position information, we thus need to descend into them *) | Log (l1, info1), Log (l2, info2) -> equal_log_entries l1 l2 && info1 = info2 @@ -522,40 +519,40 @@ let equal_unops (op1 : unop) (op2 : unop) : bool = | _ -> op1 = op2 let equal_ops (op1 : operator) (op2 : operator) : bool = - match (op1, op2) with + match op1, op2 with | Ternop op1, Ternop op2 -> op1 = op2 | Binop op1, Binop op2 -> op1 = op2 | Unop op1, Unop op2 -> equal_unops op1 op2 | _, _ -> false let rec equal_exprs (e1 : expr Pos.marked) (e2 : expr Pos.marked) : bool = - match (Pos.unmark e1, Pos.unmark e2) with + match Pos.unmark e1, Pos.unmark e2 with | EVar v1, EVar v2 -> Pos.unmark v1 = Pos.unmark v2 | ETuple (es1, n1), ETuple (es2, n2) -> n1 = n2 && equal_exprs_list es1 es2 | ETupleAccess (e1, id1, n1, tys1), ETupleAccess (e2, id2, n2, tys2) -> - equal_exprs e1 e2 && id1 = id2 && n1 = n2 && equal_typs_list tys1 tys2 + equal_exprs e1 e2 && id1 = id2 && n1 = n2 && equal_typs_list tys1 tys2 | EInj (e1, id1, n1, tys1), EInj (e2, id2, n2, tys2) -> - equal_exprs e1 e2 && id1 = id2 && n1 = n2 && equal_typs_list tys1 tys2 + equal_exprs e1 e2 && id1 = id2 && n1 = n2 && equal_typs_list tys1 tys2 | EMatch (e1, cases1, n1), EMatch (e2, cases2, n2) -> - n1 = n2 && equal_exprs e1 e2 && equal_exprs_list cases1 cases2 + n1 = n2 && equal_exprs e1 e2 && equal_exprs_list cases1 cases2 | EArray es1, EArray es2 -> equal_exprs_list es1 es2 | ELit l1, ELit l2 -> l1 = l2 | EAbs (b1, tys1), EAbs (b2, tys2) -> - equal_typs_list tys1 tys2 - && - let vars1, body1 = Bindlib.unmbind (Pos.unmark b1) in - let body2 = - Bindlib.msubst (Pos.unmark b2) - (Array.map (fun x -> EVar (x, Pos.no_pos)) vars1) - in - equal_exprs body1 body2 + equal_typs_list tys1 tys2 + && + let vars1, body1 = Bindlib.unmbind (Pos.unmark b1) in + let body2 = + Bindlib.msubst (Pos.unmark b2) + (Array.map (fun x -> EVar (x, Pos.no_pos)) vars1) + in + equal_exprs body1 body2 | EAssert e1, EAssert e2 -> equal_exprs e1 e2 | EOp op1, EOp op2 -> equal_ops op1 op2 | EDefault (exc1, def1, cons1), EDefault (exc2, def2, cons2) -> - equal_exprs def1 def2 && equal_exprs cons1 cons2 - && equal_exprs_list exc1 exc2 + equal_exprs def1 def2 && equal_exprs cons1 cons2 + && equal_exprs_list exc1 exc2 | EIfThenElse (if1, then1, else1), EIfThenElse (if2, then2, else2) -> - equal_exprs if1 if2 && equal_exprs then1 then2 && equal_exprs else1 else2 + equal_exprs if1 if2 && equal_exprs then1 then2 && equal_exprs else1 else2 | ErrorOnEmpty e1, ErrorOnEmpty e2 -> equal_exprs e1 e2 | _, _ -> false @@ -597,10 +594,10 @@ let rec unfold_scope_body_expr scope_let_next; scope_let_pos; } -> - let var, next = Bindlib.unbind scope_let_next in - make_let_in var scope_let_typ (box_expr scope_let_expr) - (unfold_scope_body_expr ~box_expr ~make_let_in ctx next) - scope_let_pos + let var, next = Bindlib.unbind scope_let_next in + make_let_in var scope_let_typ (box_expr scope_let_expr) + (unfold_scope_body_expr ~box_expr ~make_let_in ctx next) + scope_let_pos let build_whole_scope_expr ~(box_expr : 'expr box_expr_sig) @@ -611,9 +608,7 @@ let build_whole_scope_expr (pos_scope : Pos.t) : 'expr Pos.marked Bindlib.box = let var, body_expr = Bindlib.unbind body.scope_body_expr in let body_expr = unfold_scope_body_expr ~box_expr ~make_let_in ctx body_expr in - make_abs - (Array.of_list [ var ]) - body_expr pos_scope + make_abs (Array.of_list [var]) body_expr pos_scope [ ( TTuple ( List.map snd @@ -633,12 +628,12 @@ let build_scope_typ_from_sig StructMap.find scope_return_struct_name ctx.ctx_structs in let result_typ = - (TTuple (List.map snd scope_return_typ, Some scope_return_struct_name), pos) + TTuple (List.map snd scope_return_typ, Some scope_return_struct_name), pos in let input_typ = - (TTuple (List.map snd scope_sig, Some scope_input_struct_name), pos) + TTuple (List.map snd scope_sig, Some scope_input_struct_name), pos in - (TArrow (input_typ, result_typ), pos) + TArrow (input_typ, result_typ), pos type 'expr scope_name_or_var = | ScopeName of ScopeName.t @@ -653,28 +648,27 @@ let rec unfold_scopes (main_scope : 'expr scope_name_or_var) : 'expr Pos.marked Bindlib.box = match s with | Nil -> ( - match main_scope with - | ScopeVar v -> - Bindlib.box_apply (fun v -> (v, Pos.no_pos)) (Bindlib.box_var v) - | ScopeName _ -> failwith "should not happen") + match main_scope with + | ScopeVar v -> + Bindlib.box_apply (fun v -> v, Pos.no_pos) (Bindlib.box_var v) + | ScopeName _ -> failwith "should not happen") | ScopeDef { scope_name; scope_body; scope_next } -> - let scope_var, scope_next = Bindlib.unbind scope_next in - let scope_pos = Pos.get_position (ScopeName.get_info scope_name) in - let main_scope = - match main_scope with - | ScopeVar v -> ScopeVar v - | ScopeName n -> - if ScopeName.compare n scope_name = 0 then ScopeVar scope_var - else ScopeName n - in - make_let_in scope_var - (build_scope_typ_from_sig ctx scope_body.scope_body_input_struct - scope_body.scope_body_output_struct scope_pos) - (build_whole_scope_expr ~box_expr ~make_abs ~make_let_in ctx scope_body - scope_pos) - (unfold_scopes ~box_expr ~make_abs ~make_let_in ctx scope_next - main_scope) - scope_pos + let scope_var, scope_next = Bindlib.unbind scope_next in + let scope_pos = Pos.get_position (ScopeName.get_info scope_name) in + let main_scope = + match main_scope with + | ScopeVar v -> ScopeVar v + | ScopeName n -> + if ScopeName.compare n scope_name = 0 then ScopeVar scope_var + else ScopeName n + in + make_let_in scope_var + (build_scope_typ_from_sig ctx scope_body.scope_body_input_struct + scope_body.scope_body_output_struct scope_pos) + (build_whole_scope_expr ~box_expr ~make_abs ~make_let_in ctx scope_body + scope_pos) + (unfold_scopes ~box_expr ~make_abs ~make_let_in ctx scope_next main_scope) + scope_pos let build_whole_program_expr (p : program) (main_scope : ScopeName.t) = unfold_scopes ~box_expr ~make_abs ~make_let_in p.decl_ctx p.scopes @@ -684,31 +678,28 @@ let rec expr_size (e : expr Pos.marked) : int = match Pos.unmark e with | EVar _ | ELit _ | EOp _ -> 1 | ETuple (args, _) | EArray args -> - List.fold_left (fun acc arg -> acc + expr_size arg) 1 args + List.fold_left (fun acc arg -> acc + expr_size arg) 1 args | ETupleAccess (e1, _, _, _) | EInj (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 -> - expr_size e1 + 1 + expr_size e1 + 1 | EMatch (arg, args, _) | EApp (arg, args) -> - List.fold_left - (fun acc arg -> acc + expr_size arg) - (1 + expr_size arg) - args + List.fold_left (fun acc arg -> acc + expr_size arg) (1 + expr_size arg) args | EAbs ((binder, _), _) -> - let _, body = Bindlib.unmbind binder in - 1 + expr_size body + let _, body = Bindlib.unmbind binder in + 1 + expr_size body | EIfThenElse (e1, e2, e3) -> 1 + expr_size e1 + expr_size e2 + expr_size e3 | EDefault (exceptions, just, cons) -> - List.fold_left - (fun acc except -> acc + expr_size except) - (1 + expr_size just + expr_size cons) - exceptions + List.fold_left + (fun acc except -> acc + expr_size except) + (1 + expr_size just + expr_size cons) + exceptions let remove_logging_calls (e : expr Pos.marked) : expr Pos.marked Bindlib.box = let rec f () e = match Pos.unmark e with - | EApp ((EOp (Unop (Log _)), _), [ arg ]) -> map_expr () ~f arg + | EApp ((EOp (Unop (Log _)), _), [arg]) -> map_expr () ~f arg | _ -> map_expr () ~f e in f () e diff --git a/compiler/dcalc/interpreter.ml b/compiler/dcalc/interpreter.ml index a5da8bb8..87b473b9 100644 --- a/compiler/dcalc/interpreter.ml +++ b/compiler/dcalc/interpreter.ml @@ -40,22 +40,23 @@ let rec evaluate_operator with Division_by_zero -> Errors.raise_multispanned_error [ - (Some "The division operator:", Pos.get_position op); - (Some "The null denominator:", Pos.get_position (List.nth args 1)); + Some "The division operator:", Pos.get_position op; + Some "The null denominator:", Pos.get_position (List.nth args 1); ] "division by zero at runtime" in let get_binop_args_pos (args : (A.expr * Pos.t) list) : (string option * Pos.t) list = [ - (None, Pos.get_position (List.nth args 0)); - (None, Pos.get_position (List.nth args 1)); + None, Pos.get_position (List.nth args 0); + None, Pos.get_position (List.nth args 1); ] in (* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched, use [args] to raise multispanned errors. *) let apply_cmp_or_raise_err - (cmp : unit -> A.expr) (args : (A.expr * Pos.t) list) : A.expr = + (cmp : unit -> A.expr) + (args : (A.expr * Pos.t) list) : A.expr = try cmp () with Runtime.UncomparableDurations -> Errors.raise_multispanned_error (get_binop_args_pos args) @@ -63,469 +64,461 @@ let rec evaluate_operator precise number of days" in Pos.same_pos_as - (match (Pos.unmark op, List.map Pos.unmark args) with - | A.Ternop A.Fold, [ _f; _init; EArray es ] -> + (match Pos.unmark op, List.map Pos.unmark args with + | A.Ternop A.Fold, [_f; _init; EArray es] -> + Pos.unmark + (List.fold_left + (fun acc e' -> + evaluate_expr ctx + (Pos.same_pos_as (A.EApp (List.nth args 0, [acc; e'])) e')) + (List.nth args 1) es) + | A.Binop A.And, [ELit (LBool b1); ELit (LBool b2)] -> + A.ELit (LBool (b1 && b2)) + | A.Binop A.Or, [ELit (LBool b1); ELit (LBool b2)] -> + A.ELit (LBool (b1 || b2)) + | A.Binop A.Xor, [ELit (LBool b1); ELit (LBool b2)] -> + A.ELit (LBool (b1 <> b2)) + | A.Binop (A.Add KInt), [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LInt Runtime.(i1 +! i2)) + | A.Binop (A.Sub KInt), [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LInt Runtime.(i1 -! i2)) + | A.Binop (A.Mult KInt), [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LInt Runtime.(i1 *! i2)) + | A.Binop (A.Div KInt), [ELit (LInt i1); ELit (LInt i2)] -> + apply_div_or_raise_err (fun _ -> A.ELit (LInt Runtime.(i1 /! i2))) op + | A.Binop (A.Add KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LRat Runtime.(i1 +& i2)) + | A.Binop (A.Sub KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LRat Runtime.(i1 -& i2)) + | A.Binop (A.Mult KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LRat Runtime.(i1 *& i2)) + | A.Binop (A.Div KRat), [ELit (LRat i1); ELit (LRat i2)] -> + apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(i1 /& i2))) op + | A.Binop (A.Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LMoney Runtime.(m1 +$ m2)) + | A.Binop (A.Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LMoney Runtime.(m1 -$ m2)) + | A.Binop (A.Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] -> + A.ELit (LMoney Runtime.(m1 *$ m2)) + | A.Binop (A.Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(m1 /$ m2))) op + | A.Binop (A.Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + A.ELit (LDuration Runtime.(d1 +^ d2)) + | A.Binop (A.Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + A.ELit (LDuration Runtime.(d1 -^ d2)) + | A.Binop (A.Sub KDate), [ELit (LDate d1); ELit (LDate d2)] -> + A.ELit (LDuration Runtime.(d1 -@ d2)) + | A.Binop (A.Add KDate), [ELit (LDate d1); ELit (LDuration d2)] -> + A.ELit (LDate Runtime.(d1 +@ d2)) + | A.Binop (A.Div KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_div_or_raise_err + (fun _ -> + try A.ELit (LRat Runtime.(d1 /^ d2)) + with Runtime.IndivisableDurations -> + Errors.raise_multispanned_error (get_binop_args_pos args) + "Cannot divide durations that cannot be converted to a precise \ + number of days") + op + | A.Binop (A.Lt KInt), [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LBool Runtime.(i1 + A.ELit (LBool Runtime.(i1 <=! i2)) + | A.Binop (A.Gt KInt), [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LBool Runtime.(i1 >! i2)) + | A.Binop (A.Gte KInt), [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LBool Runtime.(i1 >=! i2)) + | A.Binop (A.Lt KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LBool Runtime.(i1 <& i2)) + | A.Binop (A.Lte KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LBool Runtime.(i1 <=& i2)) + | A.Binop (A.Gt KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LBool Runtime.(i1 >& i2)) + | A.Binop (A.Gte KRat), [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LBool Runtime.(i1 >=& i2)) + | A.Binop (A.Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LBool Runtime.(m1 <$ m2)) + | A.Binop (A.Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LBool Runtime.(m1 <=$ m2)) + | A.Binop (A.Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LBool Runtime.(m1 >$ m2)) + | A.Binop (A.Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LBool Runtime.(m1 >=$ m2)) + | A.Binop (A.Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <^ d2))) args + | A.Binop (A.Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <=^ d2))) args + | A.Binop (A.Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >^ d2))) args + | A.Binop (A.Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] -> + apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >=^ d2))) args + | A.Binop (A.Lt KDate), [ELit (LDate d1); ELit (LDate d2)] -> + A.ELit (LBool Runtime.(d1 <@ d2)) + | A.Binop (A.Lte KDate), [ELit (LDate d1); ELit (LDate d2)] -> + A.ELit (LBool Runtime.(d1 <=@ d2)) + | A.Binop (A.Gt KDate), [ELit (LDate d1); ELit (LDate d2)] -> + A.ELit (LBool Runtime.(d1 >@ d2)) + | A.Binop (A.Gte KDate), [ELit (LDate d1); ELit (LDate d2)] -> + A.ELit (LBool Runtime.(d1 >=@ d2)) + | A.Binop A.Eq, [ELit LUnit; ELit LUnit] -> A.ELit (LBool true) + | A.Binop A.Eq, [ELit (LDuration d1); ELit (LDuration d2)] -> + A.ELit (LBool Runtime.(d1 =^ d2)) + | A.Binop A.Eq, [ELit (LDate d1); ELit (LDate d2)] -> + A.ELit (LBool Runtime.(d1 =@ d2)) + | A.Binop A.Eq, [ELit (LMoney m1); ELit (LMoney m2)] -> + A.ELit (LBool Runtime.(m1 =$ m2)) + | A.Binop A.Eq, [ELit (LRat i1); ELit (LRat i2)] -> + A.ELit (LBool Runtime.(i1 =& i2)) + | A.Binop A.Eq, [ELit (LInt i1); ELit (LInt i2)] -> + A.ELit (LBool Runtime.(i1 =! i2)) + | A.Binop A.Eq, [ELit (LBool b1); ELit (LBool b2)] -> + A.ELit (LBool (b1 = b2)) + | A.Binop A.Eq, [EArray es1; EArray es2] -> + A.ELit + (LBool + (try + List.for_all2 + (fun e1 e2 -> + match Pos.unmark (evaluate_operator ctx op [e1; e2]) with + | A.ELit (LBool b) -> b + | _ -> assert false + (* should not happen *)) + es1 es2 + with Invalid_argument _ -> false)) + | A.Binop A.Eq, [ETuple (es1, s1); ETuple (es2, s2)] -> + A.ELit + (LBool + (try + s1 = s2 + && List.for_all2 + (fun e1 e2 -> + match Pos.unmark (evaluate_operator ctx op [e1; e2]) with + | A.ELit (LBool b) -> b + | _ -> assert false + (* should not happen *)) + es1 es2 + with Invalid_argument _ -> false)) + | A.Binop A.Eq, [EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2)] -> + A.ELit + (LBool + (try + en1 = en2 && i1 = i2 + && + match Pos.unmark (evaluate_operator ctx op [e1; e2]) with + | A.ELit (LBool b) -> b + | _ -> assert false + (* should not happen *) + with Invalid_argument _ -> false)) + | A.Binop A.Eq, [_; _] -> + A.ELit (LBool false) (* comparing anything else return false *) + | A.Binop A.Neq, [_; _] -> ( + match Pos.unmark - (List.fold_left - (fun acc e' -> + (evaluate_operator ctx (Pos.same_pos_as (A.Binop A.Eq) op) args) + with + | A.ELit (A.LBool b) -> A.ELit (A.LBool (not b)) + | _ -> assert false (*should not happen *)) + | A.Binop A.Concat, [A.EArray es1; A.EArray es2] -> A.EArray (es1 @ es2) + | A.Binop A.Map, [_; A.EArray es] -> + A.EArray + (List.map + (fun e' -> + evaluate_expr ctx + (Pos.same_pos_as (A.EApp (List.nth args 0, [e'])) e')) + es) + | A.Binop A.Filter, [_; A.EArray es] -> + A.EArray + (List.filter + (fun e' -> + match evaluate_expr ctx - (Pos.same_pos_as (A.EApp (List.nth args 0, [ acc; e' ])) e')) - (List.nth args 1) es) - | A.Binop A.And, [ ELit (LBool b1); ELit (LBool b2) ] -> - A.ELit (LBool (b1 && b2)) - | A.Binop A.Or, [ ELit (LBool b1); ELit (LBool b2) ] -> - A.ELit (LBool (b1 || b2)) - | A.Binop A.Xor, [ ELit (LBool b1); ELit (LBool b2) ] -> - A.ELit (LBool (b1 <> b2)) - | A.Binop (A.Add KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LInt Runtime.(i1 +! i2)) - | A.Binop (A.Sub KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LInt Runtime.(i1 -! i2)) - | A.Binop (A.Mult KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LInt Runtime.(i1 *! i2)) - | A.Binop (A.Div KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - apply_div_or_raise_err (fun _ -> A.ELit (LInt Runtime.(i1 /! i2))) op - | A.Binop (A.Add KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LRat Runtime.(i1 +& i2)) - | A.Binop (A.Sub KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LRat Runtime.(i1 -& i2)) - | A.Binop (A.Mult KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LRat Runtime.(i1 *& i2)) - | A.Binop (A.Div KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(i1 /& i2))) op - | A.Binop (A.Add KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LMoney Runtime.(m1 +$ m2)) - | A.Binop (A.Sub KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LMoney Runtime.(m1 -$ m2)) - | A.Binop (A.Mult KMoney), [ ELit (LMoney m1); ELit (LRat m2) ] -> - A.ELit (LMoney Runtime.(m1 *$ m2)) - | A.Binop (A.Div KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - apply_div_or_raise_err (fun _ -> A.ELit (LRat Runtime.(m1 /$ m2))) op - | A.Binop (A.Add KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - A.ELit (LDuration Runtime.(d1 +^ d2)) - | A.Binop (A.Sub KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - A.ELit (LDuration Runtime.(d1 -^ d2)) - | A.Binop (A.Sub KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> - A.ELit (LDuration Runtime.(d1 -@ d2)) - | A.Binop (A.Add KDate), [ ELit (LDate d1); ELit (LDuration d2) ] -> - A.ELit (LDate Runtime.(d1 +@ d2)) - | A.Binop (A.Div KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - apply_div_or_raise_err - (fun _ -> - try A.ELit (LRat Runtime.(d1 /^ d2)) - with Runtime.IndivisableDurations -> - Errors.raise_multispanned_error (get_binop_args_pos args) - "Cannot divide durations that cannot be converted to a precise \ - number of days") - op - | A.Binop (A.Lt KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LBool Runtime.(i1 - A.ELit (LBool Runtime.(i1 <=! i2)) - | A.Binop (A.Gt KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LBool Runtime.(i1 >! i2)) - | A.Binop (A.Gte KInt), [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LBool Runtime.(i1 >=! i2)) - | A.Binop (A.Lt KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LBool Runtime.(i1 <& i2)) - | A.Binop (A.Lte KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LBool Runtime.(i1 <=& i2)) - | A.Binop (A.Gt KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LBool Runtime.(i1 >& i2)) - | A.Binop (A.Gte KRat), [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LBool Runtime.(i1 >=& i2)) - | A.Binop (A.Lt KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LBool Runtime.(m1 <$ m2)) - | A.Binop (A.Lte KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LBool Runtime.(m1 <=$ m2)) - | A.Binop (A.Gt KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LBool Runtime.(m1 >$ m2)) - | A.Binop (A.Gte KMoney), [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LBool Runtime.(m1 >=$ m2)) - | A.Binop (A.Lt KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 <^ d2))) args - | A.Binop (A.Lte KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - apply_cmp_or_raise_err - (fun _ -> A.ELit (LBool Runtime.(d1 <=^ d2))) - args - | A.Binop (A.Gt KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - apply_cmp_or_raise_err (fun _ -> A.ELit (LBool Runtime.(d1 >^ d2))) args - | A.Binop (A.Gte KDuration), [ ELit (LDuration d1); ELit (LDuration d2) ] -> - apply_cmp_or_raise_err - (fun _ -> A.ELit (LBool Runtime.(d1 >=^ d2))) - args - | A.Binop (A.Lt KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> - A.ELit (LBool Runtime.(d1 <@ d2)) - | A.Binop (A.Lte KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> - A.ELit (LBool Runtime.(d1 <=@ d2)) - | A.Binop (A.Gt KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> - A.ELit (LBool Runtime.(d1 >@ d2)) - | A.Binop (A.Gte KDate), [ ELit (LDate d1); ELit (LDate d2) ] -> - A.ELit (LBool Runtime.(d1 >=@ d2)) - | A.Binop A.Eq, [ ELit LUnit; ELit LUnit ] -> A.ELit (LBool true) - | A.Binop A.Eq, [ ELit (LDuration d1); ELit (LDuration d2) ] -> - A.ELit (LBool Runtime.(d1 =^ d2)) - | A.Binop A.Eq, [ ELit (LDate d1); ELit (LDate d2) ] -> - A.ELit (LBool Runtime.(d1 =@ d2)) - | A.Binop A.Eq, [ ELit (LMoney m1); ELit (LMoney m2) ] -> - A.ELit (LBool Runtime.(m1 =$ m2)) - | A.Binop A.Eq, [ ELit (LRat i1); ELit (LRat i2) ] -> - A.ELit (LBool Runtime.(i1 =& i2)) - | A.Binop A.Eq, [ ELit (LInt i1); ELit (LInt i2) ] -> - A.ELit (LBool Runtime.(i1 =! i2)) - | A.Binop A.Eq, [ ELit (LBool b1); ELit (LBool b2) ] -> - A.ELit (LBool (b1 = b2)) - | A.Binop A.Eq, [ EArray es1; EArray es2 ] -> - A.ELit - (LBool - (try - List.for_all2 - (fun e1 e2 -> - match Pos.unmark (evaluate_operator ctx op [ e1; e2 ]) with - | A.ELit (LBool b) -> b - | _ -> assert false - (* should not happen *)) - es1 es2 - with Invalid_argument _ -> false)) - | A.Binop A.Eq, [ ETuple (es1, s1); ETuple (es2, s2) ] -> - A.ELit - (LBool - (try - s1 = s2 - && List.for_all2 - (fun e1 e2 -> - match - Pos.unmark (evaluate_operator ctx op [ e1; e2 ]) - with - | A.ELit (LBool b) -> b - | _ -> assert false - (* should not happen *)) - es1 es2 - with Invalid_argument _ -> false)) - | A.Binop A.Eq, [ EInj (e1, i1, en1, _ts1); EInj (e2, i2, en2, _ts2) ] -> - A.ELit - (LBool - (try - en1 = en2 && i1 = i2 - && - match Pos.unmark (evaluate_operator ctx op [ e1; e2 ]) with - | A.ELit (LBool b) -> b - | _ -> assert false - (* should not happen *) - with Invalid_argument _ -> false)) - | A.Binop A.Eq, [ _; _ ] -> - A.ELit (LBool false) (* comparing anything else return false *) - | A.Binop A.Neq, [ _; _ ] -> ( - match - Pos.unmark - (evaluate_operator ctx (Pos.same_pos_as (A.Binop A.Eq) op) args) - with - | A.ELit (A.LBool b) -> A.ELit (A.LBool (not b)) - | _ -> assert false (*should not happen *)) - | A.Binop A.Concat, [ A.EArray es1; A.EArray es2 ] -> A.EArray (es1 @ es2) - | A.Binop A.Map, [ _; A.EArray es ] -> - A.EArray - (List.map - (fun e' -> - evaluate_expr ctx - (Pos.same_pos_as (A.EApp (List.nth args 0, [ e' ])) e')) - es) - | A.Binop A.Filter, [ _; A.EArray es ] -> - A.EArray - (List.filter - (fun e' -> - match - evaluate_expr ctx - (Pos.same_pos_as (A.EApp (List.nth args 0, [ e' ])) e') - with - | A.ELit (A.LBool b), _ -> b - | _ -> - Errors.raise_spanned_error - (Pos.get_position (List.nth args 0)) - "This predicate evaluated to something else than a \ - boolean (should not happen if the term was well-typed)") - es) - | A.Binop _, ([ ELit LEmptyError; _ ] | [ _; ELit LEmptyError ]) -> - A.ELit LEmptyError - | A.Unop (A.Minus KInt), [ ELit (LInt i) ] -> - A.ELit (LInt Runtime.(integer_of_int 0 -! i)) - | A.Unop (A.Minus KRat), [ ELit (LRat i) ] -> - A.ELit (LRat Runtime.(decimal_of_string "0" -& i)) - | A.Unop (A.Minus KMoney), [ ELit (LMoney i) ] -> - A.ELit (LMoney Runtime.(money_of_units_int 0 -$ i)) - | A.Unop (A.Minus KDuration), [ ELit (LDuration i) ] -> - A.ELit (LDuration Runtime.(~-^i)) - | A.Unop A.Not, [ ELit (LBool b) ] -> A.ELit (LBool (not b)) - | A.Unop A.Length, [ EArray es ] -> - A.ELit (LInt (Runtime.integer_of_int (List.length es))) - | A.Unop A.GetDay, [ ELit (LDate d) ] -> - A.ELit (LInt Runtime.(day_of_month_of_date d)) - | A.Unop A.GetMonth, [ ELit (LDate d) ] -> - A.ELit (LInt Runtime.(month_number_of_date d)) - | A.Unop A.GetYear, [ ELit (LDate d) ] -> - A.ELit (LInt Runtime.(year_of_date d)) - | A.Unop A.IntToRat, [ ELit (LInt i) ] -> - A.ELit (LRat Runtime.(decimal_of_integer i)) - | A.Unop A.RoundMoney, [ ELit (LMoney m) ] -> - A.ELit (LMoney Runtime.(money_round m)) - | A.Unop A.RoundDecimal, [ ELit (LRat m) ] -> - A.ELit (LRat Runtime.(decimal_round m)) - | A.Unop (A.Log (entry, infos)), [ e' ] -> - if !Cli.trace_flag then ( - match entry with - | VarDef _ -> - (* TODO: this usage of Format is broken, Formatting requires that - all is formatted in one pass, without going through - intermediate "%s" *) - Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" - Print.format_log_entry entry Print.format_uid_list infos - (match e' with - (* | Ast.EAbs _ -> Cli.with_style [ ANSITerminal.green ] - "" *) - | _ -> - let expr_str = - Format.asprintf "%a" - (Print.format_expr ctx ~debug:false) - (e', Pos.no_pos) - in - let expr_str = - Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") - ~subst:(fun _ -> " ") - expr_str - in - Cli.with_style [ ANSITerminal.green ] "%s" expr_str) - | PosRecordIfTrueBool -> ( - let pos = Pos.get_position op in - match (pos <> Pos.no_pos, e') with - | true, ELit (LBool true) -> - Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" - Print.format_log_entry entry - (Cli.with_style [ ANSITerminal.green ] "Definition applied") - (Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) - (fun _ -> Format.asprintf "%*s" (!log_indent * 2) "")) - | _ -> ()) - | BeginCall -> - Cli.log_format "%*s%a %a" (!log_indent * 2) "" - Print.format_log_entry entry Print.format_uid_list infos; - log_indent := !log_indent + 1 - | EndCall -> - log_indent := !log_indent - 1; - Cli.log_format "%*s%a %a" (!log_indent * 2) "" - Print.format_log_entry entry Print.format_uid_list infos) - else (); - e' - | A.Unop _, [ ELit LEmptyError ] -> A.ELit LEmptyError + (Pos.same_pos_as (A.EApp (List.nth args 0, [e'])) e') + with + | A.ELit (A.LBool b), _ -> b + | _ -> + Errors.raise_spanned_error + (Pos.get_position (List.nth args 0)) + "This predicate evaluated to something else than a boolean \ + (should not happen if the term was well-typed)") + es) + | A.Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> + A.ELit LEmptyError + | A.Unop (A.Minus KInt), [ELit (LInt i)] -> + A.ELit (LInt Runtime.(integer_of_int 0 -! i)) + | A.Unop (A.Minus KRat), [ELit (LRat i)] -> + A.ELit (LRat Runtime.(decimal_of_string "0" -& i)) + | A.Unop (A.Minus KMoney), [ELit (LMoney i)] -> + A.ELit (LMoney Runtime.(money_of_units_int 0 -$ i)) + | A.Unop (A.Minus KDuration), [ELit (LDuration i)] -> + A.ELit (LDuration Runtime.(~-^i)) + | A.Unop A.Not, [ELit (LBool b)] -> A.ELit (LBool (not b)) + | A.Unop A.Length, [EArray es] -> + A.ELit (LInt (Runtime.integer_of_int (List.length es))) + | A.Unop A.GetDay, [ELit (LDate d)] -> + A.ELit (LInt Runtime.(day_of_month_of_date d)) + | A.Unop A.GetMonth, [ELit (LDate d)] -> + A.ELit (LInt Runtime.(month_number_of_date d)) + | A.Unop A.GetYear, [ELit (LDate d)] -> + A.ELit (LInt Runtime.(year_of_date d)) + | A.Unop A.IntToRat, [ELit (LInt i)] -> + A.ELit (LRat Runtime.(decimal_of_integer i)) + | A.Unop A.RoundMoney, [ELit (LMoney m)] -> + A.ELit (LMoney Runtime.(money_round m)) + | A.Unop A.RoundDecimal, [ELit (LRat m)] -> + A.ELit (LRat Runtime.(decimal_round m)) + | A.Unop (A.Log (entry, infos)), [e'] -> + if !Cli.trace_flag then ( + match entry with + | VarDef _ -> + (* TODO: this usage of Format is broken, Formatting requires that all + is formatted in one pass, without going through intermediate + "%s" *) + Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" + Print.format_log_entry entry Print.format_uid_list infos + (match e' with + (* | Ast.EAbs _ -> Cli.with_style [ ANSITerminal.green ] + "" *) + | _ -> + let expr_str = + Format.asprintf "%a" + (Print.format_expr ctx ~debug:false) + (e', Pos.no_pos) + in + let expr_str = + Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") + ~subst:(fun _ -> " ") + expr_str + in + Cli.with_style [ANSITerminal.green] "%s" expr_str) + | PosRecordIfTrueBool -> ( + let pos = Pos.get_position op in + match pos <> Pos.no_pos, e' with + | true, ELit (LBool true) -> + Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" + Print.format_log_entry entry + (Cli.with_style [ANSITerminal.green] "Definition applied") + (Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ -> + Format.asprintf "%*s" (!log_indent * 2) "")) + | _ -> ()) + | BeginCall -> + Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.format_log_entry + entry Print.format_uid_list infos; + log_indent := !log_indent + 1 + | EndCall -> + log_indent := !log_indent - 1; + Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.format_log_entry + entry Print.format_uid_list infos) + else (); + e' + | A.Unop _, [ELit LEmptyError] -> A.ELit LEmptyError | _ -> - Errors.raise_multispanned_error - ([ (Some "Operator:", Pos.get_position op) ] - @ List.mapi - (fun i arg -> - ( Some - (Format.asprintf "Argument n°%d, value %a" (i + 1) - (Print.format_expr ctx ~debug:true) - arg), - Pos.get_position arg )) - args) - "Operator applied to the wrong arguments\n\ - (should not happen if the term was well-typed)") + Errors.raise_multispanned_error + ([Some "Operator:", Pos.get_position op] + @ List.mapi + (fun i arg -> + ( Some + (Format.asprintf "Argument n°%d, value %a" (i + 1) + (Print.format_expr ctx ~debug:true) + arg), + Pos.get_position arg )) + args) + "Operator applied to the wrong arguments\n\ + (should not happen if the term was well-typed)") op and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.marked = match Pos.unmark e with | EVar _ -> - Errors.raise_spanned_error (Pos.get_position e) - "free variable found at evaluation (should not happen if term was \ - well-typed" + Errors.raise_spanned_error (Pos.get_position e) + "free variable found at evaluation (should not happen if term was \ + well-typed" | EApp (e1, args) -> ( - let e1 = evaluate_expr ctx e1 in - let args = List.map (evaluate_expr ctx) args in - match Pos.unmark e1 with - | EAbs ((binder, _), _) -> - if Bindlib.mbinder_arity binder = List.length args then - evaluate_expr ctx - (Bindlib.msubst binder (Array.of_list (List.map Pos.unmark args))) - else - Errors.raise_spanned_error (Pos.get_position e) - "wrong function call, expected %d arguments, got %d" - (Bindlib.mbinder_arity binder) - (List.length args) - | EOp op -> - Pos.same_pos_as - (Pos.unmark (evaluate_operator ctx (Pos.same_pos_as op e1) args)) - e - | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e - | _ -> - Errors.raise_spanned_error (Pos.get_position e) - "function has not been reduced to a lambda at evaluation (should \ - not happen if the term was well-typed") + let e1 = evaluate_expr ctx e1 in + let args = List.map (evaluate_expr ctx) args in + match Pos.unmark e1 with + | EAbs ((binder, _), _) -> + if Bindlib.mbinder_arity binder = List.length args then + evaluate_expr ctx + (Bindlib.msubst binder (Array.of_list (List.map Pos.unmark args))) + else + Errors.raise_spanned_error (Pos.get_position e) + "wrong function call, expected %d arguments, got %d" + (Bindlib.mbinder_arity binder) + (List.length args) + | EOp op -> + Pos.same_pos_as + (Pos.unmark (evaluate_operator ctx (Pos.same_pos_as op e1) args)) + e + | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + | _ -> + Errors.raise_spanned_error (Pos.get_position e) + "function has not been reduced to a lambda at evaluation (should not \ + happen if the term was well-typed") | EAbs _ | ELit _ | EOp _ -> e (* these are values *) | ETuple (es, s) -> - let new_es = List.map (evaluate_expr ctx) es in - if List.exists is_empty_error new_es then - Pos.same_pos_as (A.ELit LEmptyError) e - else Pos.same_pos_as (A.ETuple (new_es, s)) e + let new_es = List.map (evaluate_expr ctx) es in + if List.exists is_empty_error new_es then + Pos.same_pos_as (A.ELit LEmptyError) e + else Pos.same_pos_as (A.ETuple (new_es, s)) e | ETupleAccess (e1, n, s, _) -> ( - let e1 = evaluate_expr ctx e1 in - match Pos.unmark e1 with - | ETuple (es, s') -> ( - (match (s, s') with - | None, None -> () - | Some s, Some s' when s = s' -> () - | _ -> - Errors.raise_multispanned_error - [ (None, Pos.get_position e); (None, Pos.get_position e1) ] - "Error during tuple access: not the same structs (should not \ - happen if the term was well-typed)"); - match List.nth_opt es n with - | Some e' -> e' - | None -> - Errors.raise_spanned_error (Pos.get_position e1) - "The tuple has %d components but the %i-th element was \ - requested (should not happen if the term was well-type)" - (List.length es) n) - | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + let e1 = evaluate_expr ctx e1 in + match Pos.unmark e1 with + | ETuple (es, s') -> ( + (match s, s' with + | None, None -> () + | Some s, Some s' when s = s' -> () | _ -> - Errors.raise_spanned_error (Pos.get_position e1) - "The expression %a should be a tuple with %d components but is not \ - (should not happen if the term was well-typed)" - (Print.format_expr ctx ~debug:true) - e n) + Errors.raise_multispanned_error + [None, Pos.get_position e; None, Pos.get_position e1] + "Error during tuple access: not the same structs (should not happen \ + if the term was well-typed)"); + match List.nth_opt es n with + | Some e' -> e' + | None -> + Errors.raise_spanned_error (Pos.get_position e1) + "The tuple has %d components but the %i-th element was requested \ + (should not happen if the term was well-type)" + (List.length es) n) + | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + | _ -> + Errors.raise_spanned_error (Pos.get_position e1) + "The expression %a should be a tuple with %d components but is not \ + (should not happen if the term was well-typed)" + (Print.format_expr ctx ~debug:true) + e n) | EInj (e1, n, en, ts) -> - let e1' = evaluate_expr ctx e1 in - if is_empty_error e1' then Pos.same_pos_as (A.ELit LEmptyError) e - else Pos.same_pos_as (A.EInj (e1', n, en, ts)) e + let e1' = evaluate_expr ctx e1 in + if is_empty_error e1' then Pos.same_pos_as (A.ELit LEmptyError) e + else Pos.same_pos_as (A.EInj (e1', n, en, ts)) e | EMatch (e1, es, e_name) -> ( - let e1 = evaluate_expr ctx e1 in - match Pos.unmark e1 with - | A.EInj (e1, n, e_name', _) -> - if e_name <> e_name' then - Errors.raise_multispanned_error - [ (None, Pos.get_position e); (None, Pos.get_position e1) ] - "Error during match: two different enums found (should not \ - happend if the term was well-typed)"; - let es_n = - match List.nth_opt es n with - | Some es_n -> es_n - | None -> - Errors.raise_spanned_error (Pos.get_position e) - "sum type index error (should not happend if the term was \ - well-typed)" - in - let new_e = Pos.same_pos_as (A.EApp (es_n, [ e1 ])) e in - evaluate_expr ctx new_e - | A.ELit A.LEmptyError -> Pos.same_pos_as (A.ELit A.LEmptyError) e - | _ -> - Errors.raise_spanned_error (Pos.get_position e1) - "Expected a term having a sum type as an argument to a match \ - (should not happend if the term was well-typed") + let e1 = evaluate_expr ctx e1 in + match Pos.unmark e1 with + | A.EInj (e1, n, e_name', _) -> + if e_name <> e_name' then + Errors.raise_multispanned_error + [None, Pos.get_position e; None, Pos.get_position e1] + "Error during match: two different enums found (should not happend \ + if the term was well-typed)"; + let es_n = + match List.nth_opt es n with + | Some es_n -> es_n + | None -> + Errors.raise_spanned_error (Pos.get_position e) + "sum type index error (should not happend if the term was \ + well-typed)" + in + let new_e = Pos.same_pos_as (A.EApp (es_n, [e1])) e in + evaluate_expr ctx new_e + | A.ELit A.LEmptyError -> Pos.same_pos_as (A.ELit A.LEmptyError) e + | _ -> + Errors.raise_spanned_error (Pos.get_position e1) + "Expected a term having a sum type as an argument to a match (should \ + not happend if the term was well-typed") | EDefault (exceptions, just, cons) -> ( - let exceptions = List.map (evaluate_expr ctx) exceptions in - let empty_count = List.length (List.filter is_empty_error exceptions) in - match List.length exceptions - empty_count with - | 0 -> ( - let just = evaluate_expr ctx just in - match Pos.unmark just with - | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e - | ELit (LBool true) -> evaluate_expr ctx cons - | ELit (LBool false) -> Pos.same_pos_as (A.ELit LEmptyError) e - | _ -> - Errors.raise_spanned_error (Pos.get_position e) - "Default justification has not been reduced to a boolean at \ - evaluation (should not happen if the term was well-typed") - | 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions + let exceptions = List.map (evaluate_expr ctx) exceptions in + let empty_count = List.length (List.filter is_empty_error exceptions) in + match List.length exceptions - empty_count with + | 0 -> ( + let just = evaluate_expr ctx just in + match Pos.unmark just with + | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + | ELit (LBool true) -> evaluate_expr ctx cons + | ELit (LBool false) -> Pos.same_pos_as (A.ELit LEmptyError) e | _ -> - Errors.raise_multispanned_error - (List.map - (fun except -> - ( Some "This consequence has a valid justification:", - Pos.get_position except )) - (List.filter (fun sub -> not (is_empty_error sub)) exceptions)) - "There is a conflict between multiple valid consequences for \ - assigning the same variable.") + Errors.raise_spanned_error (Pos.get_position e) + "Default justification has not been reduced to a boolean at \ + evaluation (should not happen if the term was well-typed") + | 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions + | _ -> + Errors.raise_multispanned_error + (List.map + (fun except -> + ( Some "This consequence has a valid justification:", + Pos.get_position except )) + (List.filter (fun sub -> not (is_empty_error sub)) exceptions)) + "There is a conflict between multiple valid consequences for assigning \ + the same variable.") | EIfThenElse (cond, et, ef) -> ( - match Pos.unmark (evaluate_expr ctx cond) with - | ELit (LBool true) -> evaluate_expr ctx et - | ELit (LBool false) -> evaluate_expr ctx ef - | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e - | _ -> - Errors.raise_spanned_error (Pos.get_position cond) - "Expected a boolean literal for the result of this condition \ - (should not happen if the term was well-typed)") + match Pos.unmark (evaluate_expr ctx cond) with + | ELit (LBool true) -> evaluate_expr ctx et + | ELit (LBool false) -> evaluate_expr ctx ef + | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + | _ -> + Errors.raise_spanned_error (Pos.get_position cond) + "Expected a boolean literal for the result of this condition (should \ + not happen if the term was well-typed)") | EArray es -> - let new_es = List.map (evaluate_expr ctx) es in - if List.exists is_empty_error new_es then - Pos.same_pos_as (A.ELit LEmptyError) e - else Pos.same_pos_as (A.EArray new_es) e + let new_es = List.map (evaluate_expr ctx) es in + if List.exists is_empty_error new_es then + Pos.same_pos_as (A.ELit LEmptyError) e + else Pos.same_pos_as (A.EArray new_es) e | ErrorOnEmpty e' -> - let e' = evaluate_expr ctx e' in - if Pos.unmark e' = A.ELit LEmptyError then - Errors.raise_spanned_error (Pos.get_position e') - "This variable evaluated to an empty term (no rule that defined it \ - applied in this situation)" - else e' + let e' = evaluate_expr ctx e' in + if Pos.unmark e' = A.ELit LEmptyError then + Errors.raise_spanned_error (Pos.get_position e') + "This variable evaluated to an empty term (no rule that defined it \ + applied in this situation)" + else e' | EAssert e' -> ( - match Pos.unmark (evaluate_expr ctx e') with - | ELit (LBool true) -> Pos.same_pos_as (Ast.ELit LUnit) e' - | ELit (LBool false) -> ( - match Pos.unmark e' with - | Ast.ErrorOnEmpty - ( EApp - ( (Ast.EOp (Binop op), pos_op), - [ ((ELit _, _) as e1); ((ELit _, _) as e2) ] ), - _ ) - | EApp - ( (Ast.EOp (Ast.Unop (Ast.Log _)), _), - [ - ( Ast.EApp - ( (Ast.EOp (Binop op), pos_op), - [ ((ELit _, _) as e1); ((ELit _, _) as e2) ] ), - _ ); - ] ) - | EApp + match Pos.unmark (evaluate_expr ctx e') with + | ELit (LBool true) -> Pos.same_pos_as (Ast.ELit LUnit) e' + | ELit (LBool false) -> ( + match Pos.unmark e' with + | Ast.ErrorOnEmpty + ( EApp ( (Ast.EOp (Binop op), pos_op), - [ ((ELit _, _) as e1); ((ELit _, _) as e2) ] ) -> - Errors.raise_spanned_error (Pos.get_position e') - "Assertion failed: %a %a %a" - (Print.format_expr ctx ~debug:false) - e1 Print.format_binop (op, pos_op) - (Print.format_expr ctx ~debug:false) - e2 - | _ -> - Cli.debug_format "%a" (Print.format_expr ctx) e'; - Errors.raise_spanned_error (Pos.get_position e') - "Assertion failed") - | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + [((ELit _, _) as e1); ((ELit _, _) as e2)] ), + _ ) + | EApp + ( (Ast.EOp (Ast.Unop (Ast.Log _)), _), + [ + ( Ast.EApp + ( (Ast.EOp (Binop op), pos_op), + [((ELit _, _) as e1); ((ELit _, _) as e2)] ), + _ ); + ] ) + | EApp + ( (Ast.EOp (Binop op), pos_op), + [((ELit _, _) as e1); ((ELit _, _) as e2)] ) -> + Errors.raise_spanned_error (Pos.get_position e') + "Assertion failed: %a %a %a" + (Print.format_expr ctx ~debug:false) + e1 Print.format_binop (op, pos_op) + (Print.format_expr ctx ~debug:false) + e2 | _ -> - Errors.raise_spanned_error (Pos.get_position e') - "Expected a boolean literal for the result of this assertion \ - (should not happen if the term was well-typed)") + Cli.debug_format "%a" (Print.format_expr ctx) e'; + Errors.raise_spanned_error (Pos.get_position e') "Assertion failed") + | ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e + | _ -> + Errors.raise_spanned_error (Pos.get_position e') + "Expected a boolean literal for the result of this assertion (should \ + not happen if the term was well-typed)") (** {1 API} *) let interpret_program (ctx : Ast.decl_ctx) (e : Ast.expr Pos.marked) : (Uid.MarkedString.info * Ast.expr Pos.marked) list = match Pos.unmark (evaluate_expr ctx e) with - | Ast.EAbs (_, [ (Ast.TTuple (taus, Some s_in), _) ]) -> ( - let application_term = List.map (fun _ -> Ast.empty_thunked_term) taus in - let to_interpret = - ( Ast.EApp - (e, [ (Ast.ETuple (application_term, Some s_in), Pos.no_pos) ]), - Pos.no_pos ) + | Ast.EAbs (_, [(Ast.TTuple (taus, Some s_in), _)]) -> ( + let application_term = List.map (fun _ -> Ast.empty_thunked_term) taus in + let to_interpret = + ( Ast.EApp (e, [Ast.ETuple (application_term, Some s_in), Pos.no_pos]), + Pos.no_pos ) + in + match Pos.unmark (evaluate_expr ctx to_interpret) with + | Ast.ETuple (args, Some s_out) -> + let s_out_fields = + List.map + (fun (f, _) -> Ast.StructFieldName.get_info f) + (Ast.StructMap.find s_out ctx.ctx_structs) in - match Pos.unmark (evaluate_expr ctx to_interpret) with - | Ast.ETuple (args, Some s_out) -> - let s_out_fields = - List.map - (fun (f, _) -> Ast.StructFieldName.get_info f) - (Ast.StructMap.find s_out ctx.ctx_structs) - in - List.map2 (fun arg var -> (var, arg)) args s_out_fields - | _ -> - Errors.raise_spanned_error (Pos.get_position e) - "The interpretation of a program should always yield a struct \ - corresponding to the scope variables") - | _ -> + List.map2 (fun arg var -> var, arg) args s_out_fields + | _ -> Errors.raise_spanned_error (Pos.get_position e) - "The interpreter can only interpret terms starting with functions \ - having thunked arguments" + "The interpretation of a program should always yield a struct \ + corresponding to the scope variables") + | _ -> + Errors.raise_spanned_error (Pos.get_position e) + "The interpreter can only interpret terms starting with functions having \ + thunked arguments" diff --git a/compiler/dcalc/optimizations.ml b/compiler/dcalc/optimizations.ml index 15cb1923..ea93cb36 100644 --- a/compiler/dcalc/optimizations.ml +++ b/compiler/dcalc/optimizations.ml @@ -29,163 +29,161 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked) match Pos.unmark e with | EApp ( (( EOp (Unop Not), _ - | EApp ((EOp (Unop (Log _)), _), [ (EOp (Unop Not), _) ]), _ ) as op), - [ e1 ] ) -> - (* reduction of logical not *) - (Bindlib.box_apply (fun e1 -> - match e1 with - | ELit (LBool false), _ -> (ELit (LBool true), pos) - | ELit (LBool true), _ -> (ELit (LBool false), pos) - | _ -> (EApp (op, [ e1 ]), pos))) - (rec_helper e1) + | EApp ((EOp (Unop (Log _)), _), [(EOp (Unop Not), _)]), _ ) as op), + [e1] ) -> + (* reduction of logical not *) + (Bindlib.box_apply (fun e1 -> + match e1 with + | ELit (LBool false), _ -> ELit (LBool true), pos + | ELit (LBool true), _ -> ELit (LBool false), pos + | _ -> EApp (op, [e1]), pos)) + (rec_helper e1) | EApp ( (( EOp (Binop Or), _ - | EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop Or), _) ]), _ ) as op), - [ e1; e2 ] ) -> - (* reduction of logical or *) - (Bindlib.box_apply2 (fun e1 e2 -> - match (e1, e2) with - | (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) -> - new_e - | (ELit (LBool true), _), _ | _, (ELit (LBool true), _) -> - (ELit (LBool true), pos) - | _ -> (EApp (op, [ e1; e2 ]), pos))) - (rec_helper e1) (rec_helper e2) + | EApp ((EOp (Unop (Log _)), _), [(EOp (Binop Or), _)]), _ ) as op), + [e1; e2] ) -> + (* reduction of logical or *) + (Bindlib.box_apply2 (fun e1 e2 -> + match e1, e2 with + | (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) -> + new_e + | (ELit (LBool true), _), _ | _, (ELit (LBool true), _) -> + ELit (LBool true), pos + | _ -> EApp (op, [e1; e2]), pos)) + (rec_helper e1) (rec_helper e2) | EApp ( (( EOp (Binop And), _ - | EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop And), _) ]), _ ) as op), - [ e1; e2 ] ) -> - (* reduction of logical and *) - (Bindlib.box_apply2 (fun e1 e2 -> - match (e1, e2) with - | (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) -> - new_e - | (ELit (LBool false), _), _ | _, (ELit (LBool false), _) -> - (ELit (LBool false), pos) - | _ -> (EApp (op, [ e1; e2 ]), pos))) - (rec_helper e1) (rec_helper e2) - | EVar (x, _) -> Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x) + | EApp ((EOp (Unop (Log _)), _), [(EOp (Binop And), _)]), _ ) as op), + [e1; e2] ) -> + (* reduction of logical and *) + (Bindlib.box_apply2 (fun e1 e2 -> + match e1, e2 with + | (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) -> + new_e + | (ELit (LBool false), _), _ | _, (ELit (LBool false), _) -> + ELit (LBool false), pos + | _ -> EApp (op, [e1; e2]), pos)) + (rec_helper e1) (rec_helper e2) + | EVar (x, _) -> Bindlib.box_apply (fun x -> x, pos) (Bindlib.box_var x) | ETuple (args, s_name) -> - Bindlib.box_apply - (fun args -> (ETuple (args, s_name), pos)) - (List.map rec_helper args |> Bindlib.box_list) + Bindlib.box_apply + (fun args -> ETuple (args, s_name), pos) + (List.map rec_helper args |> Bindlib.box_list) | ETupleAccess (arg, i, s_name, typs) -> - Bindlib.box_apply - (fun arg -> (ETupleAccess (arg, i, s_name, typs), pos)) - (rec_helper arg) + Bindlib.box_apply + (fun arg -> ETupleAccess (arg, i, s_name, typs), pos) + (rec_helper arg) | EInj (arg, i, e_name, typs) -> - Bindlib.box_apply - (fun arg -> (EInj (arg, i, e_name, typs), pos)) - (rec_helper arg) + Bindlib.box_apply + (fun arg -> EInj (arg, i, e_name, typs), pos) + (rec_helper arg) | EMatch (arg, arms, e_name) -> - Bindlib.box_apply2 - (fun arg arms -> - match (arg, arms) with - | (EInj (e1, i, e_name', _ts), _), _ - when Ast.EnumName.compare e_name e_name' = 0 -> - (* iota reduction *) - (EApp (List.nth arms i, [ e1 ]), pos) - | _ -> (EMatch (arg, arms, e_name), pos)) - (rec_helper arg) - (List.map rec_helper arms |> Bindlib.box_list) + Bindlib.box_apply2 + (fun arg arms -> + match arg, arms with + | (EInj (e1, i, e_name', _ts), _), _ + when Ast.EnumName.compare e_name e_name' = 0 -> + (* iota reduction *) + EApp (List.nth arms i, [e1]), pos + | _ -> EMatch (arg, arms, e_name), pos) + (rec_helper arg) + (List.map rec_helper arms |> Bindlib.box_list) | EArray args -> - Bindlib.box_apply - (fun args -> (EArray args, pos)) - (List.map rec_helper args |> Bindlib.box_list) + Bindlib.box_apply + (fun args -> EArray args, pos) + (List.map rec_helper args |> Bindlib.box_list) | ELit l -> Bindlib.box (ELit l, pos) | EAbs ((binder, binder_pos), typs) -> - let vars, body = Bindlib.unmbind binder in - let new_body = rec_helper body in - let new_binder = Bindlib.bind_mvar vars new_body in - Bindlib.box_apply - (fun binder -> (EAbs ((binder, binder_pos), typs), pos)) - new_binder + let vars, body = Bindlib.unmbind binder in + let new_body = rec_helper body in + let new_binder = Bindlib.bind_mvar vars new_body in + Bindlib.box_apply + (fun binder -> EAbs ((binder, binder_pos), typs), pos) + new_binder | EApp (f, args) -> - Bindlib.box_apply2 - (fun f args -> - match Pos.unmark f with - | EAbs ((binder, _pos_binder), _ts) -> - (* beta reduction *) - Bindlib.msubst binder (List.map fst args |> Array.of_list) - | _ -> (EApp (f, args), pos)) - (rec_helper f) - (List.map rec_helper args |> Bindlib.box_list) - | EAssert e1 -> - Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) (rec_helper e1) + Bindlib.box_apply2 + (fun f args -> + match Pos.unmark f with + | EAbs ((binder, _pos_binder), _ts) -> + (* beta reduction *) + Bindlib.msubst binder (List.map fst args |> Array.of_list) + | _ -> EApp (f, args), pos) + (rec_helper f) + (List.map rec_helper args |> Bindlib.box_list) + | EAssert e1 -> Bindlib.box_apply (fun e1 -> EAssert e1, pos) (rec_helper e1) | EOp op -> Bindlib.box (EOp op, pos) | EDefault (exceptions, just, cons) -> - Bindlib.box_apply3 - (fun exceptions just cons -> - (* TODO: mechanically prove each of these optimizations correct :) *) - match - ( List.filter - (fun except -> - match Pos.unmark except with - | ELit LEmptyError -> false - | _ -> true) - exceptions - (* we can discard the exceptions that are always empty error *), - just, - cons ) - with - | exceptions, just, cons - when List.fold_left - (fun nb except -> if is_value except then nb + 1 else nb) - 0 exceptions - > 1 -> - (* at this point we know a conflict error will be triggered so we - just feed the expression to the interpreter that will print the - beautiful right error message *) - Interpreter.evaluate_expr ctx.decl_ctx - (EDefault (exceptions, just, cons), pos) - | [ except ], _, _ when is_value except -> - (* if there is only one exception and it is a non-empty value it - is always chosen *) - except - | ( [], - ( ( ELit (LBool true) - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) ), - _ ), - cons ) -> - cons - | ( [], - ( ( ELit (LBool false) - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) ), - _ ), - _ ) -> - (ELit LEmptyError, pos) - | [], just, cons when not !Cli.avoid_exceptions_flag -> - (* without exceptions, a default is just an [if then else] raising - an error in the else case. This exception is only valid in the - context of compilation_with_exceptions, so we desactivate with - a global flag to know if we will be compiling using exceptions - or the option monad. *) - (EIfThenElse (just, cons, (ELit LEmptyError, pos)), pos) - | exceptions, just, cons -> (EDefault (exceptions, just, cons), pos)) - (List.map rec_helper exceptions |> Bindlib.box_list) - (rec_helper just) (rec_helper cons) + Bindlib.box_apply3 + (fun exceptions just cons -> + (* TODO: mechanically prove each of these optimizations correct :) *) + match + ( List.filter + (fun except -> + match Pos.unmark except with + | ELit LEmptyError -> false + | _ -> true) + exceptions + (* we can discard the exceptions that are always empty error *), + just, + cons ) + with + | exceptions, just, cons + when List.fold_left + (fun nb except -> if is_value except then nb + 1 else nb) + 0 exceptions + > 1 -> + (* at this point we know a conflict error will be triggered so we just + feed the expression to the interpreter that will print the + beautiful right error message *) + Interpreter.evaluate_expr ctx.decl_ctx + (EDefault (exceptions, just, cons), pos) + | [except], _, _ when is_value except -> + (* if there is only one exception and it is a non-empty value it is + always chosen *) + except + | ( [], + ( ( ELit (LBool true) + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ), + _ ), + cons ) -> + cons + | ( [], + ( ( ELit (LBool false) + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ), + _ ), + _ ) -> + ELit LEmptyError, pos + | [], just, cons when not !Cli.avoid_exceptions_flag -> + (* without exceptions, a default is just an [if then else] raising an + error in the else case. This exception is only valid in the context + of compilation_with_exceptions, so we desactivate with a global + flag to know if we will be compiling using exceptions or the option + monad. *) + EIfThenElse (just, cons, (ELit LEmptyError, pos)), pos + | exceptions, just, cons -> EDefault (exceptions, just, cons), pos) + (List.map rec_helper exceptions |> Bindlib.box_list) + (rec_helper just) (rec_helper cons) | EIfThenElse (e1, e2, e3) -> - Bindlib.box_apply3 - (fun e1 e2 e3 -> - match (Pos.unmark e1, Pos.unmark e2, Pos.unmark e3) with - | ELit (LBool true), _, _ - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]), _, _ -> - e2 - | ELit (LBool false), _, _ - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]), _, _ -> - e3 - | ( _, - ( ELit (LBool true) - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) ), - ( ELit (LBool false) - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) ) ) - -> - e1 - | _ when equal_exprs e2 e3 -> e2 - | _ -> (EIfThenElse (e1, e2, e3), pos)) - (rec_helper e1) (rec_helper e2) (rec_helper e3) + Bindlib.box_apply3 + (fun e1 e2 e3 -> + match Pos.unmark e1, Pos.unmark e2, Pos.unmark e3 with + | ELit (LBool true), _, _ + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]), _, _ -> + e2 + | ELit (LBool false), _, _ + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]), _, _ -> + e3 + | ( _, + ( ELit (LBool true) + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) ), + ( ELit (LBool false) + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) ) ) -> + e1 + | _ when equal_exprs e2 e3 -> e2 + | _ -> EIfThenElse (e1, e2, e3), pos) + (rec_helper e1) (rec_helper e2) (rec_helper e3) | ErrorOnEmpty e1 -> - Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) (rec_helper e1) + Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, pos) (rec_helper e1) let optimize_expr (decl_ctx : decl_ctx) (e : expr Pos.marked) = partial_evaluation { var_values = VarMap.empty; decl_ctx } e @@ -198,19 +196,19 @@ let rec scope_lets_map match scope_body_expr with | Result e -> Bindlib.box_apply (fun e' -> Result e') (t ctx e) | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - let new_scope_let_expr = t ctx scope_let.scope_let_expr in - let new_next = scope_lets_map t ctx next in - let new_next = Bindlib.bind_var var new_next in - Bindlib.box_apply2 - (fun new_scope_let_expr new_next -> - ScopeLet - { - scope_let with - scope_let_expr = new_scope_let_expr; - scope_let_next = new_next; - }) - new_scope_let_expr new_next + let var, next = Bindlib.unbind scope_let.scope_let_next in + let new_scope_let_expr = t ctx scope_let.scope_let_expr in + let new_next = scope_lets_map t ctx next in + let new_next = Bindlib.bind_var var new_next in + Bindlib.box_apply2 + (fun new_scope_let_expr new_next -> + ScopeLet + { + scope_let with + scope_let_expr = new_scope_let_expr; + scope_let_next = new_next; + }) + new_scope_let_expr new_next let rec scopes_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) @@ -219,29 +217,29 @@ let rec scopes_map match scopes with | Nil -> Bindlib.box Nil | ScopeDef scope_def -> - let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in - let scope_arg_var, scope_body_expr = - Bindlib.unbind scope_def.scope_body.scope_body_expr - in - let new_scope_body_expr = scope_lets_map t ctx scope_body_expr in - let new_scope_body_expr = - Bindlib.bind_var scope_arg_var new_scope_body_expr - in - let new_scope_next = scopes_map t ctx scope_next in - let new_scope_next = Bindlib.bind_var scope_var new_scope_next in - Bindlib.box_apply2 - (fun new_scope_body_expr new_scope_next -> - ScopeDef - { - scope_def with - scope_next = new_scope_next; - scope_body = - { - scope_def.scope_body with - scope_body_expr = new_scope_body_expr; - }; - }) - new_scope_body_expr new_scope_next + let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in + let scope_arg_var, scope_body_expr = + Bindlib.unbind scope_def.scope_body.scope_body_expr + in + let new_scope_body_expr = scope_lets_map t ctx scope_body_expr in + let new_scope_body_expr = + Bindlib.bind_var scope_arg_var new_scope_body_expr + in + let new_scope_next = scopes_map t ctx scope_next in + let new_scope_next = Bindlib.bind_var scope_var new_scope_next in + Bindlib.box_apply2 + (fun new_scope_body_expr new_scope_next -> + ScopeDef + { + scope_def with + scope_next = new_scope_next; + scope_body = + { + scope_def.scope_body with + scope_body_expr = new_scope_body_expr; + }; + }) + new_scope_body_expr new_scope_next let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) diff --git a/compiler/dcalc/print.ml b/compiler/dcalc/print.ml index 98435351..b1162f33 100644 --- a/compiler/dcalc/print.ml +++ b/compiler/dcalc/print.ml @@ -33,7 +33,8 @@ let begins_with_uppercase (s : string) : bool = is_uppercase first_letter let format_uid_list - (fmt : Format.formatter) (infos : Uid.MarkedString.info list) : unit = + (fmt : Format.formatter) + (infos : Uid.MarkedString.info list) : unit = Format.fprintf fmt "%a" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ".") @@ -41,29 +42,25 @@ let format_uid_list Format.fprintf fmt "%a" (Utils.Cli.format_with_style (if begins_with_uppercase (Pos.unmark info) then - [ ANSITerminal.red ] + [ANSITerminal.red] else [])) (Format.asprintf "%a" Utils.Uid.MarkedString.format_info info))) infos let format_keyword (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.red ]) s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.red]) s let format_base_type (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) - s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.yellow]) s let format_punctuation (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.cyan ]) s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.cyan]) s let format_operator (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.green ]) s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.green]) s let format_lit_style (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) - s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.yellow]) s let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = format_base_type fmt @@ -79,12 +76,13 @@ let format_tlit (fmt : Format.formatter) (l : typ_lit) : unit = let format_enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = Format.fprintf fmt "%a" - (Utils.Cli.format_with_style [ ANSITerminal.magenta ]) + (Utils.Cli.format_with_style [ANSITerminal.magenta]) (Format.asprintf "%a" EnumConstructor.format_t c) let rec format_typ - (ctx : Ast.decl_ctx) (fmt : Format.formatter) (typ : typ Pos.marked) : unit - = + (ctx : Ast.decl_ctx) + (fmt : Format.formatter) + (typ : typ Pos.marked) : unit = let format_typ = format_typ ctx in let format_typ_with_parens (fmt : Format.formatter) (t : typ Pos.marked) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t @@ -93,41 +91,40 @@ let rec format_typ match Pos.unmark typ with | TLit l -> Format.fprintf fmt "%a" format_tlit l | TTuple (ts, None) -> - Format.fprintf fmt "@[(%a)@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> - Format.fprintf fmt "@ %a@ " format_operator "*") - (fun fmt t -> Format.fprintf fmt "%a" format_typ t)) - ts + Format.fprintf fmt "@[(%a)@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " format_operator "*") + (fun fmt t -> Format.fprintf fmt "%a" format_typ t)) + ts | TTuple (_args, Some s) -> - Format.fprintf fmt "@[%a%a%a%a@]" Ast.StructName.format_t s - format_punctuation "{" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> - Format.fprintf fmt "%a@ " format_punctuation ";") - (fun fmt (field, typ) -> - Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" - StructFieldName.format_t field format_punctuation "\"" - format_punctuation ":" format_typ typ)) - (StructMap.find s ctx.ctx_structs) - format_punctuation "}" + Format.fprintf fmt "@[%a%a%a%a@]" Ast.StructName.format_t s + format_punctuation "{" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> + Format.fprintf fmt "%a@ " format_punctuation ";") + (fun fmt (field, typ) -> + Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" + StructFieldName.format_t field format_punctuation "\"" + format_punctuation ":" format_typ typ)) + (StructMap.find s ctx.ctx_structs) + format_punctuation "}" | TEnum (_, e) -> - Format.fprintf fmt "@[%a%a%a%a@]" Ast.EnumName.format_t e - format_punctuation "[" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> - Format.fprintf fmt "@ %a@ " format_punctuation "|") - (fun fmt (case, typ) -> - Format.fprintf fmt "%a%a@ %a" format_enum_constructor case - format_punctuation ":" format_typ typ)) - (EnumMap.find e ctx.ctx_enums) - format_punctuation "]" + Format.fprintf fmt "@[%a%a%a%a@]" Ast.EnumName.format_t e + format_punctuation "[" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> + Format.fprintf fmt "@ %a@ " format_punctuation "|") + (fun fmt (case, typ) -> + Format.fprintf fmt "%a%a@ %a" format_enum_constructor case + format_punctuation ":" format_typ typ)) + (EnumMap.find e ctx.ctx_enums) + format_punctuation "]" | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a %a@ %a@]" format_typ_with_parens t1 - format_operator "→" format_typ t2 + Format.fprintf fmt "@[%a %a@ %a@]" format_typ_with_parens t1 + format_operator "→" format_typ t2 | TArray t1 -> - Format.fprintf fmt "@[%a@ %a@]" format_base_type "array" format_typ - t1 + Format.fprintf fmt "@[%a@ %a@]" format_base_type "array" format_typ + t1 | TAny -> format_base_type fmt "any" (* (EmileRolley) NOTE: seems to be factorizable with Lcalc.Print.format_lit. *) @@ -138,19 +135,17 @@ let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit = | LEmptyError -> format_lit_style fmt "∅ " | LUnit -> format_lit_style fmt "()" | LRat i -> - format_lit_style fmt - (Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) + format_lit_style fmt + (Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) | LMoney e -> ( - match !Utils.Cli.locale_lang with - | En -> - format_lit_style fmt - (Format.asprintf "$%s" (Runtime.money_to_string e)) - | Fr -> - format_lit_style fmt - (Format.asprintf "%s €" (Runtime.money_to_string e)) - | Pl -> - format_lit_style fmt - (Format.asprintf "%s PLN" (Runtime.money_to_string e))) + match !Utils.Cli.locale_lang with + | En -> + format_lit_style fmt (Format.asprintf "$%s" (Runtime.money_to_string e)) + | Fr -> + format_lit_style fmt (Format.asprintf "%s €" (Runtime.money_to_string e)) + | Pl -> + format_lit_style fmt + (Format.asprintf "%s PLN" (Runtime.money_to_string e))) | LDate d -> format_lit_style fmt (Runtime.date_to_string d) | LDuration d -> format_lit_style fmt (Runtime.duration_to_string d) @@ -189,10 +184,10 @@ let format_ternop (fmt : Format.formatter) (op : ternop Pos.marked) : unit = let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit = Format.fprintf fmt "@<2>%s" (match entry with - | VarDef _ -> Utils.Cli.with_style [ ANSITerminal.blue ] "≔ " - | BeginCall -> Utils.Cli.with_style [ ANSITerminal.yellow ] "→ " - | EndCall -> Utils.Cli.with_style [ ANSITerminal.yellow ] "← " - | PosRecordIfTrueBool -> Utils.Cli.with_style [ ANSITerminal.green ] "☛ ") + | VarDef _ -> Utils.Cli.with_style [ANSITerminal.blue] "≔ " + | BeginCall -> Utils.Cli.with_style [ANSITerminal.yellow] "→ " + | EndCall -> Utils.Cli.with_style [ANSITerminal.yellow] "← " + | PosRecordIfTrueBool -> Utils.Cli.with_style [ANSITerminal.green] "☛ ") let format_unop (fmt : Format.formatter) (op : unop Pos.marked) : unit = Format.fprintf fmt "%s" @@ -200,11 +195,11 @@ let format_unop (fmt : Format.formatter) (op : unop Pos.marked) : unit = | Minus _ -> "-" | Not -> "~" | Log (entry, infos) -> - Format.asprintf "log@[[%a|%a]@]" format_log_entry entry - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ".") - (fun fmt info -> Utils.Uid.MarkedString.format_info fmt info)) - infos + Format.asprintf "log@[[%a|%a]@]" format_log_entry entry + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ".") + (fun fmt info -> Utils.Uid.MarkedString.format_info fmt info)) + infos | Length -> "length" | IntToRat -> "int_to_rat" | GetDay -> "get_day" @@ -234,123 +229,120 @@ let rec format_expr match Pos.unmark e with | EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v) | ETuple (es, None) -> - Format.fprintf fmt "@[%a%a%a@]" format_punctuation "(" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) - es format_punctuation ")" + Format.fprintf fmt "@[%a%a%a@]" format_punctuation "(" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) + es format_punctuation ")" | ETuple (es, Some s) -> - Format.fprintf fmt "@[%a@ @[%a%a%a@]@]" - Ast.StructName.format_t s format_punctuation "{" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> - Format.fprintf fmt "%a@ " format_punctuation ";") - (fun fmt (e, struct_field) -> - Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" - Ast.StructFieldName.format_t struct_field format_punctuation "\"" - format_punctuation "=" format_expr e)) - (List.combine es (List.map fst (Ast.StructMap.find s ctx.ctx_structs))) - format_punctuation "}" + Format.fprintf fmt "@[%a@ @[%a%a%a@]@]" + Ast.StructName.format_t s format_punctuation "{" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> + Format.fprintf fmt "%a@ " format_punctuation ";") + (fun fmt (e, struct_field) -> + Format.fprintf fmt "%a%a%a%a@ %a" format_punctuation "\"" + Ast.StructFieldName.format_t struct_field format_punctuation "\"" + format_punctuation "=" format_expr e)) + (List.combine es (List.map fst (Ast.StructMap.find s ctx.ctx_structs))) + format_punctuation "}" | EArray es -> - Format.fprintf fmt "@[%a%a%a@]" format_punctuation "[" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) - es format_punctuation "]" + Format.fprintf fmt "@[%a%a%a@]" format_punctuation "[" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) + es format_punctuation "]" | ETupleAccess (e1, n, s, _ts) -> ( - match s with - | None -> - Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n - | Some s -> - Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "." - format_punctuation "\"" Ast.StructFieldName.format_t - (fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n)) - format_punctuation "\"") + match s with + | None -> + Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n + | Some s -> + Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_operator "." + format_punctuation "\"" Ast.StructFieldName.format_t + (fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n)) + format_punctuation "\"") | EInj (e, n, en, _ts) -> - Format.fprintf fmt "@[%a@ %a@]" format_enum_constructor - (fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n)) - format_expr e + Format.fprintf fmt "@[%a@ %a@]" format_enum_constructor + (fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n)) + format_expr e | EMatch (e, es, e_name) -> - Format.fprintf fmt "@[%a@ @[%a@]@ %a@ %a@]" format_keyword - "match" format_expr e format_keyword "with" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt (e, c) -> - Format.fprintf fmt "@[%a %a%a@ %a@]" format_punctuation "|" - format_enum_constructor c format_punctuation ":" format_expr e)) - (List.combine es (List.map fst (Ast.EnumMap.find e_name ctx.ctx_enums))) + Format.fprintf fmt "@[%a@ @[%a@]@ %a@ %a@]" format_keyword + "match" format_expr e format_keyword "with" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt (e, c) -> + Format.fprintf fmt "@[%a %a%a@ %a@]" format_punctuation "|" + format_enum_constructor c format_punctuation ":" format_expr e)) + (List.combine es (List.map fst (Ast.EnumMap.find e_name ctx.ctx_enums))) | ELit l -> format_lit fmt (Pos.same_pos_as l e) | EApp ((EAbs ((binder, _), taus), _), args) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - let xs_tau_arg = - List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args - in - Format.fprintf fmt "%a%a" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "") - (fun fmt (x, tau, arg) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" - format_keyword "let" format_var x format_punctuation ":" - (format_typ ctx) tau format_punctuation "=" format_expr arg - format_keyword "in")) - xs_tau_arg format_expr body + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in + Format.fprintf fmt "%a%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "") + (fun fmt (x, tau, arg) -> + Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" + format_keyword "let" format_var x format_punctuation ":" + (format_typ ctx) tau format_punctuation "=" format_expr arg + format_keyword "in")) + xs_tau_arg format_expr body | EAbs ((binder, _), taus) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - Format.fprintf fmt "@[%a @[%a@] %a@ %a@]" format_punctuation - "λ" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - (fun fmt (x, tau) -> - Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var - x format_punctuation ":" (format_typ ctx) tau format_punctuation - ")")) - xs_tau format_punctuation "→" format_expr body - | EApp ((EOp (Binop ((Ast.Map | Ast.Filter) as op)), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop (op, Pos.no_pos) - format_with_parens arg1 format_with_parens arg2 - | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 - format_binop (op, Pos.no_pos) format_with_parens arg2 - | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> - format_expr fmt arg1 - | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "@[%a@ %a@]" format_unop (op, Pos.no_pos) - format_with_parens arg1 + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + Format.fprintf fmt "@[%a @[%a@] %a@ %a@]" format_punctuation + "λ" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + (fun fmt (x, tau) -> + Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x + format_punctuation ":" (format_typ ctx) tau format_punctuation ")")) + xs_tau format_punctuation "→" format_expr body + | EApp ((EOp (Binop ((Ast.Map | Ast.Filter) as op)), _), [arg1; arg2]) -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop (op, Pos.no_pos) + format_with_parens arg1 format_with_parens arg2 + | EApp ((EOp (Binop op), _), [arg1; arg2]) -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 + format_binop (op, Pos.no_pos) format_with_parens arg2 + | EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> + format_expr fmt arg1 + | EApp ((EOp (Unop op), _), [arg1]) -> + Format.fprintf fmt "@[%a@ %a@]" format_unop (op, Pos.no_pos) + format_with_parens arg1 | EApp (f, args) -> - Format.fprintf fmt "@[%a@ %a@]" format_expr f - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - format_with_parens) - args + Format.fprintf fmt "@[%a@ %a@]" format_expr f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + format_with_parens) + args | EIfThenElse (e1, e2, e3) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if" - format_expr e1 format_keyword "then" format_expr e2 format_keyword - "else" format_expr e3 + Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if" + format_expr e1 format_keyword "then" format_expr e2 format_keyword "else" + format_expr e3 | EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) | EDefault (exceptions, just, cons) -> - if List.length exceptions = 0 then - Format.fprintf fmt "@[%a%a@ %a@ %a%a@]" format_punctuation "⟨" - format_expr just format_punctuation "⊢" format_expr cons - format_punctuation "⟩" - else - Format.fprintf fmt "@[%a%a@ %a@ %a@ %a@ %a%a@]" - format_punctuation "⟨" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> - Format.fprintf fmt "%a@ " format_punctuation ",") - format_expr) - exceptions format_punctuation "|" format_expr just format_punctuation - "⊢" format_expr cons format_punctuation "⟩" + if List.length exceptions = 0 then + Format.fprintf fmt "@[%a%a@ %a@ %a%a@]" format_punctuation "⟨" + format_expr just format_punctuation "⊢" format_expr cons + format_punctuation "⟩" + else + Format.fprintf fmt "@[%a%a@ %a@ %a@ %a@ %a%a@]" format_punctuation + "⟨" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> + Format.fprintf fmt "%a@ " format_punctuation ",") + format_expr) + exceptions format_punctuation "|" format_expr just format_punctuation + "⊢" format_expr cons format_punctuation "⟩" | ErrorOnEmpty e' -> - Format.fprintf fmt "%a@ %a" format_operator "error_empty" - format_with_parens e' + Format.fprintf fmt "%a@ %a" format_operator "error_empty" format_with_parens + e' | EAssert e' -> - Format.fprintf fmt "@[%a@ %a%a%a@]" format_keyword "assert" - format_punctuation "(" format_expr e' format_punctuation ")" + Format.fprintf fmt "@[%a@ %a%a%a@]" format_keyword "assert" + format_punctuation "(" format_expr e' format_punctuation ")" let format_scope ?(debug : bool = false) diff --git a/compiler/dcalc/typing.ml b/compiler/dcalc/typing.ml index b4cf79c1..a6097c48 100644 --- a/compiler/dcalc/typing.ml +++ b/compiler/dcalc/typing.ml @@ -52,7 +52,8 @@ let rec format_typ (typ : typ Pos.marked UnionFind.elem) : unit = let format_typ = format_typ ctx in let format_typ_with_parens - (fmt : Format.formatter) (t : typ Pos.marked UnionFind.elem) = + (fmt : Format.formatter) + (t : typ Pos.marked UnionFind.elem) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t else Format.fprintf fmt "%a" format_typ t in @@ -60,16 +61,16 @@ let rec format_typ match Pos.unmark typ with | TLit l -> Format.fprintf fmt "%a" Print.format_tlit l | TTuple (ts, None) -> - Format.fprintf fmt "@[(%a)]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") - (fun fmt t -> Format.fprintf fmt "%a" format_typ t)) - ts + Format.fprintf fmt "@[(%a)]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") + (fun fmt t -> Format.fprintf fmt "%a" format_typ t)) + ts | TTuple (_ts, Some s) -> Format.fprintf fmt "%a" Ast.StructName.format_t s | TEnum (_ts, e) -> Format.fprintf fmt "%a" Ast.EnumName.format_t e | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a →@ %a@]" format_typ_with_parens t1 - format_typ t2 + Format.fprintf fmt "@[%a →@ %a@]" format_typ_with_parens t1 + format_typ t2 | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ t1 | TAny d -> Format.fprintf fmt "any[%d]" (Any.hash d) @@ -87,50 +88,50 @@ let rec unify (* TODO: if we get weird error messages, then it means that we should use the persistent version of the union-find data structure. *) let t1_s = - Cli.with_style [ ANSITerminal.yellow ] "%s" + Cli.with_style [ANSITerminal.yellow] "%s" (Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") ~subst:(fun _ -> " ") (Format.asprintf "%a" (format_typ ctx) t1)) in let t2_s = - Cli.with_style [ ANSITerminal.yellow ] "%s" + Cli.with_style [ANSITerminal.yellow] "%s" (Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*") ~subst:(fun _ -> " ") (Format.asprintf "%a" (format_typ ctx) t2)) in Errors.raise_multispanned_error [ - (Some (Format.asprintf "Type %s coming from expression:" t1_s), t1_pos); - (Some (Format.asprintf "Type %s coming from expression:" t2_s), t2_pos); + Some (Format.asprintf "Type %s coming from expression:" t1_s), t1_pos; + Some (Format.asprintf "Type %s coming from expression:" t2_s), t2_pos; ] "Error during typechecking, incompatible types:\n%a %s\n%a %s" - (Cli.format_with_style [ ANSITerminal.blue; ANSITerminal.Bold ]) + (Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold]) "-->" t1_s - (Cli.format_with_style [ ANSITerminal.blue; ANSITerminal.Bold ]) + (Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold]) "-->" t2_s in let repr = - match (t1_repr, t2_repr) with + match t1_repr, t2_repr with | (TLit tl1, _), (TLit tl2, _) when tl1 = tl2 -> None | (TArrow (t11, t12), _), (TArrow (t21, t22), _) -> - unify t11 t21; - unify t12 t22; - None + unify t11 t21; + unify t12 t22; + None | (TTuple (ts1, s1), t1_pos), (TTuple (ts2, s2), t2_pos) -> - if s1 = s2 && List.length ts1 = List.length ts2 then begin - List.iter2 unify ts1 ts2; - None - end - else raise_type_error t1_pos t2_pos - | (TEnum (ts1, e1), t1_pos), (TEnum (ts2, e2), t2_pos) -> - if e1 = e2 && List.length ts1 = List.length ts2 then begin - List.iter2 unify ts1 ts2; - None - end - else raise_type_error t1_pos t2_pos - | (TArray t1', _), (TArray t2', _) -> - unify t1' t2'; + if s1 = s2 && List.length ts1 = List.length ts2 then begin + List.iter2 unify ts1 ts2; None + end + else raise_type_error t1_pos t2_pos + | (TEnum (ts1, e1), t1_pos), (TEnum (ts2, e2), t2_pos) -> + if e1 = e2 && List.length ts1 = List.length ts2 then begin + List.iter2 unify ts1 ts2; + None + end + else raise_type_error t1_pos t2_pos + | (TArray t1', _), (TArray t2', _) -> + unify t1' t2'; + None | (TAny _, _), (TAny _, _) -> None | (TAny _, _), t_repr | t_repr, (TAny _, _) -> Some t_repr | (_, t1_pos), (_, t2_pos) -> raise_type_error t1_pos t2_pos @@ -157,12 +158,12 @@ let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem = let arr x y = UnionFind.make (TArrow (x, y), pos) in match Pos.unmark op with | A.Ternop A.Fold -> - arr (arr any2 (arr any any2)) (arr any2 (arr array_any any2)) + arr (arr any2 (arr any any2)) (arr any2 (arr array_any any2)) | A.Binop (A.And | A.Or | A.Xor) -> arr bt (arr bt bt) | A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) -> - arr it (arr it it) + arr it (arr it it) | A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) -> - arr rt (arr rt rt) + arr rt (arr rt rt) | A.Binop (A.Add KMoney | A.Sub KMoney) -> arr mt (arr mt mt) | A.Binop (A.Add KDuration | A.Sub KDuration) -> arr dut (arr dut dut) | A.Binop (A.Sub KDate) -> arr dat (arr dat dut) @@ -171,16 +172,16 @@ let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem = | A.Binop (A.Div KMoney) -> arr mt (arr mt rt) | A.Binop (A.Mult KMoney) -> arr mt (arr rt mt) | A.Binop (A.Lt KInt | A.Lte KInt | A.Gt KInt | A.Gte KInt) -> - arr it (arr it bt) + arr it (arr it bt) | A.Binop (A.Lt KRat | A.Lte KRat | A.Gt KRat | A.Gte KRat) -> - arr rt (arr rt bt) + arr rt (arr rt bt) | A.Binop (A.Lt KMoney | A.Lte KMoney | A.Gt KMoney | A.Gte KMoney) -> - arr mt (arr mt bt) + arr mt (arr mt bt) | A.Binop (A.Lt KDate | A.Lte KDate | A.Gt KDate | A.Gte KDate) -> - arr dat (arr dat bt) + arr dat (arr dat bt) | A.Binop (A.Lt KDuration | A.Lte KDuration | A.Gt KDuration | A.Gte KDuration) -> - arr dut (arr dut bt) + arr dut (arr dut bt) | A.Binop (A.Eq | A.Neq) -> arr any (arr any bt) | A.Binop A.Map -> arr (arr any any2) (arr array_any array_any2) | A.Binop A.Filter -> arr (arr any bt) (arr array_any array_any) @@ -200,23 +201,23 @@ let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem = | A.Unop A.RoundDecimal -> arr rt rt | A.Unop A.IntToRat -> arr it rt | Binop (Mult (KDate | KDuration)) | Binop (Div KDate) | Unop (Minus KDate) -> - Errors.raise_spanned_error pos "This operator is not available!" + Errors.raise_spanned_error pos "This operator is not available!" let rec ast_to_typ (ty : A.typ) : typ = match ty with | A.TLit l -> TLit l | A.TArrow (t1, t2) -> - TArrow - ( UnionFind.make (Pos.map_under_mark ast_to_typ t1), - UnionFind.make (Pos.map_under_mark ast_to_typ t2) ) + TArrow + ( UnionFind.make (Pos.map_under_mark ast_to_typ t1), + UnionFind.make (Pos.map_under_mark ast_to_typ t2) ) | A.TTuple (ts, s) -> - TTuple - ( List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts, - s ) + TTuple + ( List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts, + s ) | A.TEnum (ts, e) -> - TEnum - ( List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts, - e ) + TEnum + ( List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts, + e ) | A.TArray t -> TArray (UnionFind.make (Pos.map_under_mark ast_to_typ t)) | A.TAny -> TAny (Any.fresh ()) @@ -238,155 +239,152 @@ type env = typ Pos.marked UnionFind.elem A.VarMap.t (** Infers the most permissive type from an expression *) let rec typecheck_expr_bottom_up - (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.marked) : - typ Pos.marked UnionFind.elem = + (ctx : Ast.decl_ctx) + (env : env) + (e : A.expr Pos.marked) : typ Pos.marked UnionFind.elem = (* Cli.debug_print (Format.asprintf "Looking for type of %a" (Print.format_expr ctx) e); *) try let out = match Pos.unmark e with | EVar v -> ( - match A.VarMap.find_opt (Pos.unmark v) env with - | Some t -> t - | None -> - Errors.raise_spanned_error (Pos.get_position e) - "Variable not found in the current context") + match A.VarMap.find_opt (Pos.unmark v) env with + | Some t -> t + | None -> + Errors.raise_spanned_error (Pos.get_position e) + "Variable not found in the current context") | ELit (LBool _) -> UnionFind.make (Pos.same_pos_as (TLit TBool) e) | ELit (LInt _) -> UnionFind.make (Pos.same_pos_as (TLit TInt) e) | ELit (LRat _) -> UnionFind.make (Pos.same_pos_as (TLit TRat) e) | ELit (LMoney _) -> UnionFind.make (Pos.same_pos_as (TLit TMoney) e) | ELit (LDate _) -> UnionFind.make (Pos.same_pos_as (TLit TDate) e) | ELit (LDuration _) -> - UnionFind.make (Pos.same_pos_as (TLit TDuration) e) + UnionFind.make (Pos.same_pos_as (TLit TDuration) e) | ELit LUnit -> UnionFind.make (Pos.same_pos_as (TLit TUnit) e) | ELit LEmptyError -> - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) + UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) | ETuple (es, s) -> - let ts = List.map (typecheck_expr_bottom_up ctx env) es in - UnionFind.make (Pos.same_pos_as (TTuple (ts, s)) e) + let ts = List.map (typecheck_expr_bottom_up ctx env) es in + UnionFind.make (Pos.same_pos_as (TTuple (ts, s)) e) | ETupleAccess (e1, n, s, typs) -> ( - let typs = - List.map - (fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ)) - typs - in - typecheck_expr_top_down ctx env e1 - (UnionFind.make (TTuple (typs, s), Pos.get_position e)); - match List.nth_opt typs n with - | Some t' -> t' - | None -> - Errors.raise_spanned_error (Pos.get_position e1) - "Expression should have a tuple type with at least %d elements \ - but only has %d" - n (List.length typs)) + let typs = + List.map + (fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ)) + typs + in + typecheck_expr_top_down ctx env e1 + (UnionFind.make (TTuple (typs, s), Pos.get_position e)); + match List.nth_opt typs n with + | Some t' -> t' + | None -> + Errors.raise_spanned_error (Pos.get_position e1) + "Expression should have a tuple type with at least %d elements but \ + only has %d" + n (List.length typs)) | EInj (e1, n, e_name, ts) -> - let ts = - List.map - (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) - ts - in - let ts_n = - match List.nth_opt ts n with - | Some ts_n -> ts_n - | None -> - Errors.raise_spanned_error (Pos.get_position e) - "Expression should have a sum type with at least %d cases \ - but only has %d" - n (List.length ts) - in - typecheck_expr_top_down ctx env e1 ts_n; - UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e) + let ts = + List.map + (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) + ts + in + let ts_n = + match List.nth_opt ts n with + | Some ts_n -> ts_n + | None -> + Errors.raise_spanned_error (Pos.get_position e) + "Expression should have a sum type with at least %d cases but \ + only has %d" + n (List.length ts) + in + typecheck_expr_top_down ctx env e1 ts_n; + UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e) | EMatch (e1, es, e_name) -> - let enum_cases = - List.map - (fun e' -> - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) - es - in - let t_e1 = - UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1) - in - typecheck_expr_top_down ctx env e1 t_e1; - let t_ret = - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) - in - List.iteri - (fun i es' -> - let enum_t = List.nth enum_cases i in - let t_es' = - UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') - in - typecheck_expr_top_down ctx env es' t_es') - es; - t_ret + let enum_cases = + List.map + (fun e' -> + UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) + es + in + let t_e1 = + UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1) + in + typecheck_expr_top_down ctx env e1 t_e1; + let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in + List.iteri + (fun i es' -> + let enum_t = List.nth enum_cases i in + let t_es' = + UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') + in + typecheck_expr_top_down ctx env es' t_es') + es; + t_ret | EAbs ((binder, pos_binder), taus) -> - let xs, body = Bindlib.unmbind binder in - if Array.length xs = List.length taus then - let xstaus = - List.map2 - (fun x tau -> - ( x, - UnionFind.make - (ast_to_typ (Pos.unmark tau), Pos.get_position tau) )) - (Array.to_list xs) taus - in - let env = - List.fold_left - (fun env (x, tau) -> A.VarMap.add x tau env) - env xstaus - in - List.fold_right - (fun (_, t_arg) (acc : typ Pos.marked UnionFind.elem) -> - UnionFind.make (TArrow (t_arg, acc), pos_binder)) - xstaus - (typecheck_expr_bottom_up ctx env body) - else - Errors.raise_spanned_error pos_binder - "function has %d variables but was supplied %d types" - (Array.length xs) (List.length taus) + let xs, body = Bindlib.unmbind binder in + if Array.length xs = List.length taus then + let xstaus = + List.map2 + (fun x tau -> + ( x, + UnionFind.make + (ast_to_typ (Pos.unmark tau), Pos.get_position tau) )) + (Array.to_list xs) taus + in + let env = + List.fold_left + (fun env (x, tau) -> A.VarMap.add x tau env) + env xstaus + in + List.fold_right + (fun (_, t_arg) (acc : typ Pos.marked UnionFind.elem) -> + UnionFind.make (TArrow (t_arg, acc), pos_binder)) + xstaus + (typecheck_expr_bottom_up ctx env body) + else + Errors.raise_spanned_error pos_binder + "function has %d variables but was supplied %d types" + (Array.length xs) (List.length taus) | EApp (e1, args) -> - let t_args = List.map (typecheck_expr_bottom_up ctx env) args in - let t_ret = - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) - in - let t_app = - List.fold_right - (fun t_arg acc -> - UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) - t_args t_ret - in - typecheck_expr_top_down ctx env e1 t_app; - t_ret + let t_args = List.map (typecheck_expr_bottom_up ctx env) args in + let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in + let t_app = + List.fold_right + (fun t_arg acc -> + UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) + t_args t_ret + in + typecheck_expr_top_down ctx env e1 t_app; + t_ret | EOp op -> op_type (Pos.same_pos_as op e) | EDefault (excepts, just, cons) -> - typecheck_expr_top_down ctx env just - (UnionFind.make (Pos.same_pos_as (TLit TBool) just)); - let tcons = typecheck_expr_bottom_up ctx env cons in - List.iter - (fun except -> typecheck_expr_top_down ctx env except tcons) - excepts; - tcons + typecheck_expr_top_down ctx env just + (UnionFind.make (Pos.same_pos_as (TLit TBool) just)); + let tcons = typecheck_expr_bottom_up ctx env cons in + List.iter + (fun except -> typecheck_expr_top_down ctx env except tcons) + excepts; + tcons | EIfThenElse (cond, et, ef) -> - typecheck_expr_top_down ctx env cond - (UnionFind.make (Pos.same_pos_as (TLit TBool) cond)); - let tt = typecheck_expr_bottom_up ctx env et in - typecheck_expr_top_down ctx env ef tt; - tt + typecheck_expr_top_down ctx env cond + (UnionFind.make (Pos.same_pos_as (TLit TBool) cond)); + let tt = typecheck_expr_bottom_up ctx env et in + typecheck_expr_top_down ctx env ef tt; + tt | EAssert e' -> - typecheck_expr_top_down ctx env e' - (UnionFind.make (Pos.same_pos_as (TLit TBool) e')); - UnionFind.make (Pos.same_pos_as (TLit TUnit) e') + typecheck_expr_top_down ctx env e' + (UnionFind.make (Pos.same_pos_as (TLit TBool) e')); + UnionFind.make (Pos.same_pos_as (TLit TUnit) e') | ErrorOnEmpty e' -> typecheck_expr_bottom_up ctx env e' | EArray es -> - let cell_type = - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) - in - List.iter - (fun e' -> - let t_e' = typecheck_expr_bottom_up ctx env e' in - unify ctx cell_type t_e') - es; - UnionFind.make (Pos.same_pos_as (TArray cell_type) e) + let cell_type = + UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) + in + List.iter + (fun e' -> + let t_e' = typecheck_expr_bottom_up ctx env e' in + unify ctx cell_type t_e') + es; + UnionFind.make (Pos.same_pos_as (TArray cell_type) e) in (* Cli.debug_print (Format.asprintf "Found type of %a: %a" (Print.format_expr ctx) e (format_typ ctx) out); *) @@ -410,154 +408,151 @@ and typecheck_expr_top_down try match Pos.unmark e with | EVar v -> ( - match A.VarMap.find_opt (Pos.unmark v) env with - | Some tau' -> ignore (unify ctx tau tau') - | None -> - Errors.raise_spanned_error (Pos.get_position e) - "Variable not found in the current context") + match A.VarMap.find_opt (Pos.unmark v) env with + | Some tau' -> ignore (unify ctx tau tau') + | None -> + Errors.raise_spanned_error (Pos.get_position e) + "Variable not found in the current context") | ELit (LBool _) -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TBool) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TBool) e)) | ELit (LInt _) -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TInt) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TInt) e)) | ELit (LRat _) -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TRat) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TRat) e)) | ELit (LMoney _) -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TMoney) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TMoney) e)) | ELit (LDate _) -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDate) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDate) e)) | ELit (LDuration _) -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDuration) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TDuration) e)) | ELit LUnit -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e)) | ELit LEmptyError -> - unify ctx tau (UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)) + unify ctx tau (UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)) | ETuple (es, s) -> - let t_es = - UnionFind.make - (Pos.same_pos_as - (TTuple (List.map (typecheck_expr_bottom_up ctx env) es, s)) - e) - in - unify ctx tau t_es + let t_es = + UnionFind.make + (Pos.same_pos_as + (TTuple (List.map (typecheck_expr_bottom_up ctx env) es, s)) + e) + in + unify ctx tau t_es | ETupleAccess (e1, n, s, typs) -> ( - let typs = - List.map - (fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ)) - typs - in - typecheck_expr_top_down ctx env e1 - (UnionFind.make (TTuple (typs, s), Pos.get_position e)); - match List.nth_opt typs n with - | Some t1n -> unify ctx t1n tau - | None -> - Errors.raise_spanned_error (Pos.get_position e1) - "Expression should have a tuple type with at least %d elements \ - but only has %d" - n (List.length typs)) + let typs = + List.map + (fun typ -> UnionFind.make (Pos.map_under_mark ast_to_typ typ)) + typs + in + typecheck_expr_top_down ctx env e1 + (UnionFind.make (TTuple (typs, s), Pos.get_position e)); + match List.nth_opt typs n with + | Some t1n -> unify ctx t1n tau + | None -> + Errors.raise_spanned_error (Pos.get_position e1) + "Expression should have a tuple type with at least %d elements but \ + only has %d" + n (List.length typs)) | EInj (e1, n, e_name, ts) -> - let ts = - List.map - (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) - ts - in - let ts_n = - match List.nth_opt ts n with - | Some ts_n -> ts_n - | None -> - Errors.raise_spanned_error (Pos.get_position e) - "Expression should have a sum type with at least %d cases but \ - only has %d" - n (List.length ts) - in - typecheck_expr_top_down ctx env e1 ts_n; - unify ctx (UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e)) tau + let ts = + List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts + in + let ts_n = + match List.nth_opt ts n with + | Some ts_n -> ts_n + | None -> + Errors.raise_spanned_error (Pos.get_position e) + "Expression should have a sum type with at least %d cases but only \ + has %d" + n (List.length ts) + in + typecheck_expr_top_down ctx env e1 ts_n; + unify ctx (UnionFind.make (Pos.same_pos_as (TEnum (ts, e_name)) e)) tau | EMatch (e1, es, e_name) -> - let enum_cases = - List.map - (fun e' -> - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) - es - in - let t_e1 = - UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1) - in - typecheck_expr_top_down ctx env e1 t_e1; - let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in - List.iteri - (fun i es' -> - let enum_t = List.nth enum_cases i in - let t_es' = - UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') - in - typecheck_expr_top_down ctx env es' t_es') - es; - unify ctx tau t_ret + let enum_cases = + List.map + (fun e' -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) + es + in + let t_e1 = + UnionFind.make (Pos.same_pos_as (TEnum (enum_cases, e_name)) e1) + in + typecheck_expr_top_down ctx env e1 t_e1; + let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in + List.iteri + (fun i es' -> + let enum_t = List.nth enum_cases i in + let t_es' = + UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') + in + typecheck_expr_top_down ctx env es' t_es') + es; + unify ctx tau t_ret | EAbs ((binder, pos_binder), t_args) -> - let xs, body = Bindlib.unmbind binder in - if Array.length xs = List.length t_args then - let xstaus = - List.map2 - (fun x t_arg -> - (x, UnionFind.make (Pos.map_under_mark ast_to_typ t_arg))) - (Array.to_list xs) t_args - in - let env = - List.fold_left - (fun env (x, t_arg) -> A.VarMap.add x t_arg env) - env xstaus - in - let t_out = typecheck_expr_bottom_up ctx env body in - let t_func = - List.fold_right - (fun (_, t_arg) acc -> - UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) - xstaus t_out - in - unify ctx t_func tau - else - Errors.raise_spanned_error pos_binder - "function has %d variables but was supplied %d types" - (Array.length xs) (List.length t_args) - | EApp (e1, args) -> - let t_args = List.map (typecheck_expr_bottom_up ctx env) args in - let te1 = typecheck_expr_bottom_up ctx env e1 in + let xs, body = Bindlib.unmbind binder in + if Array.length xs = List.length t_args then + let xstaus = + List.map2 + (fun x t_arg -> + x, UnionFind.make (Pos.map_under_mark ast_to_typ t_arg)) + (Array.to_list xs) t_args + in + let env = + List.fold_left + (fun env (x, t_arg) -> A.VarMap.add x t_arg env) + env xstaus + in + let t_out = typecheck_expr_bottom_up ctx env body in let t_func = List.fold_right - (fun t_arg acc -> + (fun (_, t_arg) acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) - t_args tau + xstaus t_out in - unify ctx te1 t_func + unify ctx t_func tau + else + Errors.raise_spanned_error pos_binder + "function has %d variables but was supplied %d types" + (Array.length xs) (List.length t_args) + | EApp (e1, args) -> + let t_args = List.map (typecheck_expr_bottom_up ctx env) args in + let te1 = typecheck_expr_bottom_up ctx env e1 in + let t_func = + List.fold_right + (fun t_arg acc -> + UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e)) + t_args tau + in + unify ctx te1 t_func | EOp op -> - let op_typ = op_type (Pos.same_pos_as op e) in - unify ctx op_typ tau + let op_typ = op_type (Pos.same_pos_as op e) in + unify ctx op_typ tau | EDefault (excepts, just, cons) -> - typecheck_expr_top_down ctx env just - (UnionFind.make (Pos.same_pos_as (TLit TBool) just)); - typecheck_expr_top_down ctx env cons tau; - List.iter - (fun except -> typecheck_expr_top_down ctx env except tau) - excepts + typecheck_expr_top_down ctx env just + (UnionFind.make (Pos.same_pos_as (TLit TBool) just)); + typecheck_expr_top_down ctx env cons tau; + List.iter + (fun except -> typecheck_expr_top_down ctx env except tau) + excepts | EIfThenElse (cond, et, ef) -> - typecheck_expr_top_down ctx env cond - (UnionFind.make (Pos.same_pos_as (TLit TBool) cond)); - typecheck_expr_top_down ctx env et tau; - typecheck_expr_top_down ctx env ef tau + typecheck_expr_top_down ctx env cond + (UnionFind.make (Pos.same_pos_as (TLit TBool) cond)); + typecheck_expr_top_down ctx env et tau; + typecheck_expr_top_down ctx env ef tau | EAssert e' -> - typecheck_expr_top_down ctx env e' - (UnionFind.make (Pos.same_pos_as (TLit TBool) e')); - unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e')) + typecheck_expr_top_down ctx env e' + (UnionFind.make (Pos.same_pos_as (TLit TBool) e')); + unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e')) | ErrorOnEmpty e' -> typecheck_expr_top_down ctx env e' tau | EArray es -> - let cell_type = - UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) - in - List.iter - (fun e' -> - let t_e' = typecheck_expr_bottom_up ctx env e' in - unify ctx cell_type t_e') - es; - unify ctx tau (UnionFind.make (Pos.same_pos_as (TArray cell_type) e)) + let cell_type = + UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) + in + List.iter + (fun e' -> + let t_e' = typecheck_expr_bottom_up ctx env e' in + unify ctx cell_type t_e') + es; + unify ctx tau (UnionFind.make (Pos.same_pos_as (TArray cell_type) e)) with Errors.StructuredError (msg, err_pos) when List.length err_pos = 2 -> raise (Errors.StructuredError @@ -575,6 +570,8 @@ let infer_type (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.typ Pos.marked = (** Typechecks an expression given an expected type *) let check_type - (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) (tau : A.typ Pos.marked) = + (ctx : Ast.decl_ctx) + (e : A.expr Pos.marked) + (tau : A.typ Pos.marked) = typecheck_expr_top_down ctx A.VarMap.empty e (UnionFind.make (Pos.map_under_mark ast_to_typ tau)) diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 3e0398e7..d03df46c 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -53,42 +53,42 @@ module ScopeDef = struct subscope's original declaration *) let compare x y = - match (x, y) with + match x, y with | Var (x, None), Var (y, None) | Var (x, Some _), Var (y, None) | Var (x, None), Var (y, Some _) | Var (x, _), SubScopeVar (_, y) | SubScopeVar (_, x), Var (y, _) -> - ScopeVar.compare x y + ScopeVar.compare x y | Var (x, Some sx), Var (y, Some sy) -> - let cmp = ScopeVar.compare x y in - if cmp = 0 then StateName.compare sx sy else cmp + let cmp = ScopeVar.compare x y in + if cmp = 0 then StateName.compare sx sy else cmp | SubScopeVar (x', x), SubScopeVar (y', y) -> - let cmp = Scopelang.Ast.SubScopeName.compare x' y' in - if cmp = 0 then ScopeVar.compare x y else cmp + let cmp = Scopelang.Ast.SubScopeName.compare x' y' in + if cmp = 0 then ScopeVar.compare x y else cmp let get_position x = match x with | Var (x, None) -> Pos.get_position (ScopeVar.get_info x) | Var (_, Some sx) -> Pos.get_position (StateName.get_info sx) | SubScopeVar (x, _) -> - Pos.get_position (Scopelang.Ast.SubScopeName.get_info x) + Pos.get_position (Scopelang.Ast.SubScopeName.get_info x) let format_t fmt x = match x with | Var (v, None) -> ScopeVar.format_t fmt v | Var (v, Some sv) -> - Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t sv + Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t sv | SubScopeVar (s, v) -> - Format.fprintf fmt "%a.%a" Scopelang.Ast.SubScopeName.format_t s - ScopeVar.format_t v + Format.fprintf fmt "%a.%a" Scopelang.Ast.SubScopeName.format_t s + ScopeVar.format_t v let hash x = match x with | Var (v, None) -> ScopeVar.hash v | Var (v, Some sv) -> Int.logxor (ScopeVar.hash v) (StateName.hash sv) | SubScopeVar (w, v) -> - Int.logxor (Scopelang.Ast.SubScopeName.hash w) (ScopeVar.hash v) + Int.logxor (Scopelang.Ast.SubScopeName.hash w) (ScopeVar.hash v) end module ScopeDefMap : Map.S with type key = ScopeDef.t = Map.Make (ScopeDef) @@ -108,18 +108,18 @@ Set.Make (struct type t = location Pos.marked let compare x y = - match (Pos.unmark x, Pos.unmark y) with + match Pos.unmark x, Pos.unmark y with | ScopeVar (vx, None), ScopeVar (vy, None) | ScopeVar (vx, Some _), ScopeVar (vy, None) | ScopeVar (vx, None), ScopeVar (vy, Some _) -> - ScopeVar.compare (Pos.unmark vx) (Pos.unmark vy) + ScopeVar.compare (Pos.unmark vx) (Pos.unmark vy) | ScopeVar ((x, _), Some sx), ScopeVar ((y, _), Some sy) -> - let cmp = ScopeVar.compare x y in - if cmp = 0 then StateName.compare sx sy else cmp + let cmp = ScopeVar.compare x y in + if cmp = 0 then StateName.compare sx sy else cmp | ( SubScopeVar (_, (xsubindex, _), (xsubvar, _)), SubScopeVar (_, (ysubindex, _), (ysubvar, _)) ) -> - let c = Scopelang.Ast.SubScopeName.compare xsubindex ysubindex in - if c = 0 then ScopeVar.compare xsubvar ysubvar else c + let c = Scopelang.Ast.SubScopeName.compare xsubindex ysubindex in + if c = 0 then ScopeVar.compare xsubvar ysubvar else c | ScopeVar _, SubScopeVar _ -> -1 | SubScopeVar _, ScopeVar _ -> 1 end) @@ -177,8 +177,8 @@ type rule = { } let empty_rule - (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule - = + (pos : Pos.t) + (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule = { rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); rule_cons = Bindlib.box (ELit Dcalc.Ast.LEmptyError, pos); @@ -186,13 +186,13 @@ let empty_rule (match have_parameter with | Some typ -> Some (Var.make ("dummy", pos), typ) | None -> None); - rule_exception_to_rules = (RuleSet.empty, pos); + rule_exception_to_rules = RuleSet.empty, pos; rule_id = RuleName.fresh ("empty", pos); } let always_false_rule - (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule - = + (pos : Pos.t) + (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule = { rule_just = Bindlib.box (ELit (Dcalc.Ast.LBool true), pos); rule_cons = Bindlib.box (ELit (Dcalc.Ast.LBool false), pos); @@ -200,7 +200,7 @@ let always_false_rule (match have_parameter with | Some typ -> Some (Var.make ("dummy", pos), typ) | None -> None); - rule_exception_to_rules = (RuleSet.empty, pos); + rule_exception_to_rules = RuleSet.empty, pos; rule_id = RuleName.fresh ("always_false", pos); } @@ -242,34 +242,34 @@ let rec locations_used (e : expr Pos.marked) : LocationSet.t = | ELocation l -> LocationSet.singleton (l, Pos.get_position e) | EVar _ | ELit _ | EOp _ -> LocationSet.empty | EAbs ((binder, _), _) -> - let _, body = Bindlib.unmbind binder in - locations_used body + let _, body = Bindlib.unmbind binder in + locations_used body | EStruct (_, es) -> - Scopelang.Ast.StructFieldMap.fold - (fun _ e' acc -> LocationSet.union acc (locations_used e')) - es LocationSet.empty + Scopelang.Ast.StructFieldMap.fold + (fun _ e' acc -> LocationSet.union acc (locations_used e')) + es LocationSet.empty | EStructAccess (e1, _, _) -> locations_used e1 | EEnumInj (e1, _, _) -> locations_used e1 | EMatch (e1, _, es) -> - Scopelang.Ast.EnumConstructorMap.fold - (fun _ e' acc -> LocationSet.union acc (locations_used e')) - es (locations_used e1) + Scopelang.Ast.EnumConstructorMap.fold + (fun _ e' acc -> LocationSet.union acc (locations_used e')) + es (locations_used e1) | EApp (e1, args) -> - List.fold_left - (fun acc arg -> LocationSet.union (locations_used arg) acc) - (locations_used e1) args + List.fold_left + (fun acc arg -> LocationSet.union (locations_used arg) acc) + (locations_used e1) args | EIfThenElse (e1, e2, e3) -> - LocationSet.union (locations_used e1) - (LocationSet.union (locations_used e2) (locations_used e3)) + LocationSet.union (locations_used e1) + (LocationSet.union (locations_used e2) (locations_used e3)) | EDefault (excepts, just, cons) -> - List.fold_left - (fun acc except -> LocationSet.union (locations_used except) acc) - (LocationSet.union (locations_used just) (locations_used cons)) - excepts + List.fold_left + (fun acc except -> LocationSet.union (locations_used except) acc) + (LocationSet.union (locations_used just) (locations_used cons)) + excepts | EArray es -> - List.fold_left - (fun acc e' -> LocationSet.union acc (locations_used e')) - LocationSet.empty es + List.fold_left + (fun acc e' -> LocationSet.union acc (locations_used e')) + LocationSet.empty es | ErrorOnEmpty e' -> locations_used e' let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = @@ -281,7 +281,7 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = (match loc with | ScopeVar (v, st) -> ScopeDef.Var (Pos.unmark v, st) | SubScopeVar (_, sub_index, sub_var) -> - ScopeDef.SubScopeVar (Pos.unmark sub_index, Pos.unmark sub_var)) + ScopeDef.SubScopeVar (Pos.unmark sub_index, Pos.unmark sub_var)) loc_pos acc) locs acc in @@ -296,7 +296,7 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = def ScopeDefMap.empty let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun v -> (v, pos)) (Bindlib.box_var x) + Bindlib.box_apply (fun v -> v, pos) (Bindlib.box_var x) let make_abs (xs : vars) @@ -305,14 +305,14 @@ let make_abs (taus : Scopelang.Ast.typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply - (fun b -> (EAbs ((b, pos_binder), taus), pos)) + (fun b -> EAbs ((b, pos_binder), taus), pos) (Bindlib.bind_mvar xs e) let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) + Bindlib.box_apply2 (fun e u -> EApp (e, u), pos) e (Bindlib.box_list u) let make_let_in (x : Var.t) @@ -320,13 +320,11 @@ let make_let_in (e1 : expr Pos.marked Bindlib.box) (e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = Bindlib.box_apply2 - (fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2))) - (make_abs - (Array.of_list [ x ]) - e2 + (fun e u -> EApp (e, u), Pos.get_position (Bindlib.unbox e2)) + (make_abs (Array.of_list [x]) e2 (Pos.get_position (Bindlib.unbox e2)) - [ tau ] + [tau] (Pos.get_position (Bindlib.unbox e2))) - (Bindlib.box_list [ e1 ]) + (Bindlib.box_list [e1]) module VarMap = Map.Make (Var) diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index b6f5f905..1dd2d4c4 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -40,16 +40,16 @@ module Vertex = struct match x with | Var (x, None) -> Ast.ScopeVar.hash x | Var (x, Some sx) -> - Int.logxor (Ast.ScopeVar.hash x) (Ast.StateName.hash sx) + Int.logxor (Ast.ScopeVar.hash x) (Ast.StateName.hash sx) | SubScope x -> Scopelang.Ast.SubScopeName.hash x let compare = compare let equal x y = - match (x, y) with + match x, y with | Var (x, None), Var (y, None) -> Ast.ScopeVar.compare x y = 0 | Var (x, Some sx), Var (y, Some sy) -> - Ast.ScopeVar.compare x y = 0 && Ast.StateName.compare sx sy = 0 + Ast.ScopeVar.compare x y = 0 && Ast.StateName.compare sx sy = 0 | SubScope x, SubScope y -> Scopelang.Ast.SubScopeName.compare x y = 0 | _ -> false @@ -57,8 +57,8 @@ module Vertex = struct match x with | Var (v, None) -> Ast.ScopeVar.format_t fmt v | Var (v, Some sv) -> - Format.fprintf fmt "%a.%a" Ast.ScopeVar.format_t v - Ast.StateName.format_t sv + Format.fprintf fmt "%a.%a" Ast.ScopeVar.format_t v Ast.StateName.format_t + sv | SubScope v -> Scopelang.Ast.SubScopeName.format_t fmt v end @@ -103,15 +103,15 @@ let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit = let var_str, var_info = match v with | Vertex.Var (v, None) -> - ( Format.asprintf "%a" Ast.ScopeVar.format_t v, - Ast.ScopeVar.get_info v ) + ( Format.asprintf "%a" Ast.ScopeVar.format_t v, + Ast.ScopeVar.get_info v ) | Vertex.Var (v, Some sv) -> - ( Format.asprintf "%a.%a" Ast.ScopeVar.format_t v - Ast.StateName.format_t sv, - Ast.StateName.get_info sv ) + ( Format.asprintf "%a.%a" Ast.ScopeVar.format_t v + Ast.StateName.format_t sv, + Ast.StateName.get_info sv ) | Vertex.SubScope v -> - ( Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v, - Scopelang.Ast.SubScopeName.get_info v ) + ( Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v, + Scopelang.Ast.SubScopeName.get_info v ) in let succs = ScopeDependencies.succ_e g v in let _, edge_pos, succ = @@ -120,12 +120,12 @@ let check_for_cycle (scope : Ast.scope) (g : ScopeDependencies.t) : unit = let succ_str = match succ with | Vertex.Var (v, None) -> - Format.asprintf "%a" Ast.ScopeVar.format_t v + Format.asprintf "%a" Ast.ScopeVar.format_t v | Vertex.Var (v, Some sv) -> - Format.asprintf "%a.%a" Ast.ScopeVar.format_t v - Ast.StateName.format_t sv + Format.asprintf "%a.%a" Ast.ScopeVar.format_t v + Ast.StateName.format_t sv | Vertex.SubScope v -> - Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v + Format.asprintf "%a" Scopelang.Ast.SubScopeName.format_t v in [ ( Some ("Cycle variable " ^ var_str ^ ", declared:"), @@ -151,10 +151,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = match var_or_state with | Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None)) | Ast.States states -> - List.fold_left - (fun g state -> - ScopeDependencies.add_vertex g (Vertex.Var (v, Some state))) - g states) + List.fold_left + (fun g state -> + ScopeDependencies.add_vertex g (Vertex.Var (v, Some state))) + g states) scope.scope_vars g in let g = @@ -170,59 +170,58 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = let fv = Ast.free_variables def in Ast.ScopeDefMap.fold (fun fv_def fv_def_pos g -> - match (def_key, fv_def) with + match def_key, fv_def with | ( Ast.ScopeDef.Var (v_defined, s_defined), Ast.ScopeDef.Var (v_used, s_used) ) -> - (* simple case *) - if v_used = v_defined && s_used = s_defined then - (* variable definitions cannot be recursive *) - Errors.raise_spanned_error fv_def_pos - "The variable %a is used in one of its definitions, but \ - recursion is forbidden in Catala" - Ast.ScopeDef.format_t def_key - else - let edge = - ScopeDependencies.E.create - (Vertex.Var (v_used, s_used)) - fv_def_pos - (Vertex.Var (v_defined, s_defined)) - in - ScopeDependencies.add_edge_e g edge - | ( Ast.ScopeDef.SubScopeVar (defined, _), - Ast.ScopeDef.Var (v_used, s_used) ) -> - (* here we are defining the input of a subscope using a var of - the scope *) + (* simple case *) + if v_used = v_defined && s_used = s_defined then + (* variable definitions cannot be recursive *) + Errors.raise_spanned_error fv_def_pos + "The variable %a is used in one of its definitions, but \ + recursion is forbidden in Catala" + Ast.ScopeDef.format_t def_key + else let edge = ScopeDependencies.E.create (Vertex.Var (v_used, s_used)) - fv_def_pos (Vertex.SubScope defined) + fv_def_pos + (Vertex.Var (v_defined, s_defined)) in ScopeDependencies.add_edge_e g edge + | ( Ast.ScopeDef.SubScopeVar (defined, _), + Ast.ScopeDef.Var (v_used, s_used) ) -> + (* here we are defining the input of a subscope using a var of the + scope *) + let edge = + ScopeDependencies.E.create + (Vertex.Var (v_used, s_used)) + fv_def_pos (Vertex.SubScope defined) + in + ScopeDependencies.add_edge_e g edge | ( Ast.ScopeDef.SubScopeVar (defined, _), Ast.ScopeDef.SubScopeVar (used, _) ) -> - (* here we are defining the input of a scope with the output of - another subscope *) - if used = defined then - (* subscopes are not recursive functions *) - Errors.raise_spanned_error fv_def_pos - "The subscope %a is used when defining one of its inputs, \ - but recursion is forbidden in Catala" - Scopelang.Ast.SubScopeName.format_t defined - else - let edge = - ScopeDependencies.E.create (Vertex.SubScope used) fv_def_pos - (Vertex.SubScope defined) - in - ScopeDependencies.add_edge_e g edge - | ( Ast.ScopeDef.Var (v_defined, s_defined), - Ast.ScopeDef.SubScopeVar (used, _) ) -> - (* finally we define a scope var with the output of a - subscope *) + (* here we are defining the input of a scope with the output of + another subscope *) + if used = defined then + (* subscopes are not recursive functions *) + Errors.raise_spanned_error fv_def_pos + "The subscope %a is used when defining one of its inputs, \ + but recursion is forbidden in Catala" + Scopelang.Ast.SubScopeName.format_t defined + else let edge = ScopeDependencies.E.create (Vertex.SubScope used) fv_def_pos - (Vertex.Var (v_defined, s_defined)) + (Vertex.SubScope defined) in - ScopeDependencies.add_edge_e g edge) + ScopeDependencies.add_edge_e g edge + | ( Ast.ScopeDef.Var (v_defined, s_defined), + Ast.ScopeDef.SubScopeVar (used, _) ) -> + (* finally we define a scope var with the output of a subscope *) + let edge = + ScopeDependencies.E.create (Vertex.SubScope used) fv_def_pos + (Vertex.Var (v_defined, s_defined)) + in + ScopeDependencies.add_edge_e g edge) fv g) scope.scope_defs g in @@ -252,8 +251,8 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies) (** {2 Graph computations} *) let build_exceptions_graph - (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeDef.t) : - ExceptionsDependencies.t = + (def : Ast.rule Ast.RuleMap.t) + (def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t = (* first we collect all the rule sets referred by exceptions *) let all_rule_sets_pointed_to_by_exceptions : Ast.RuleSet.t list = Ast.RuleMap.fold diff --git a/compiler/desugared/desugared_to_scope.ml b/compiler/desugared/desugared_to_scope.ml index 6987a479..6fc8d3d2 100644 --- a/compiler/desugared/desugared_to_scope.ml +++ b/compiler/desugared/desugared_to_scope.ml @@ -37,123 +37,115 @@ let tag_with_log_entry ( Scopelang.Ast.EApp ( ( Scopelang.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings))), Pos.get_position e ), - [ e ] ), + [e] ), Pos.get_position e ) let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Scopelang.Ast.expr Pos.marked Bindlib.box = match Pos.unmark e with | Ast.ELocation (SubScopeVar (s_name, ss_name, s_var)) -> - (* When referring to a subscope variable in an expression, we are - referring to the output, hence we take the last state. *) - let new_s_var = - match Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping with - | WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var - | States states -> - Pos.same_pos_as (snd (List.hd (List.rev states))) s_var - in - Bindlib.box - ( Scopelang.Ast.ELocation (SubScopeVar (s_name, ss_name, new_s_var)), - Pos.get_position e ) + (* When referring to a subscope variable in an expression, we are referring + to the output, hence we take the last state. *) + let new_s_var = + match Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping with + | WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var + | States states -> Pos.same_pos_as (snd (List.hd (List.rev states))) s_var + in + Bindlib.box + ( Scopelang.Ast.ELocation (SubScopeVar (s_name, ss_name, new_s_var)), + Pos.get_position e ) | Ast.ELocation (ScopeVar (s_var, None)) -> - Bindlib.box - ( Scopelang.Ast.ELocation - (ScopeVar - (match - Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping - with - | WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var - | States _ -> failwith "should not happen")), - Pos.get_position e ) + Bindlib.box + ( Scopelang.Ast.ELocation + (ScopeVar + (match + Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping + with + | WholeVar new_s_var -> Pos.same_pos_as new_s_var s_var + | States _ -> failwith "should not happen")), + Pos.get_position e ) | Ast.ELocation (ScopeVar (s_var, Some state)) -> - Bindlib.box - ( Scopelang.Ast.ELocation - (ScopeVar - (match - Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping - with - | WholeVar _ -> failwith "should not happen" - | States states -> - Pos.same_pos_as (List.assoc state states) s_var)), - Pos.get_position e ) + Bindlib.box + ( Scopelang.Ast.ELocation + (ScopeVar + (match + Ast.ScopeVarMap.find (Pos.unmark s_var) ctx.scope_var_mapping + with + | WholeVar _ -> failwith "should not happen" + | States states -> Pos.same_pos_as (List.assoc state states) s_var)), + Pos.get_position e ) | Ast.EVar v -> - Bindlib.box_apply - (fun v -> Pos.same_pos_as v e) - (Bindlib.box_var (Ast.VarMap.find (Pos.unmark v) ctx.var_mapping)) + Bindlib.box_apply + (fun v -> Pos.same_pos_as v e) + (Bindlib.box_var (Ast.VarMap.find (Pos.unmark v) ctx.var_mapping)) | EStruct (s_name, fields) -> - Bindlib.box_apply - (fun new_fields -> - (Scopelang.Ast.EStruct (s_name, new_fields), Pos.get_position e)) - (Scopelang.Ast.StructFieldMapLift.lift_box - (Scopelang.Ast.StructFieldMap.map (translate_expr ctx) fields)) + Bindlib.box_apply + (fun new_fields -> + Scopelang.Ast.EStruct (s_name, new_fields), Pos.get_position e) + (Scopelang.Ast.StructFieldMapLift.lift_box + (Scopelang.Ast.StructFieldMap.map (translate_expr ctx) fields)) | EStructAccess (e1, s_name, f_name) -> - Bindlib.box_apply - (fun new_e1 -> - ( Scopelang.Ast.EStructAccess (new_e1, s_name, f_name), - Pos.get_position e )) - (translate_expr ctx e1) + Bindlib.box_apply + (fun new_e1 -> + Scopelang.Ast.EStructAccess (new_e1, s_name, f_name), Pos.get_position e) + (translate_expr ctx e1) | EEnumInj (e1, cons, e_name) -> - Bindlib.box_apply - (fun new_e1 -> - (Scopelang.Ast.EEnumInj (new_e1, cons, e_name), Pos.get_position e)) - (translate_expr ctx e1) + Bindlib.box_apply + (fun new_e1 -> + Scopelang.Ast.EEnumInj (new_e1, cons, e_name), Pos.get_position e) + (translate_expr ctx e1) | EMatch (e1, e_name, arms) -> - Bindlib.box_apply2 - (fun new_e1 new_arms -> - (Scopelang.Ast.EMatch (new_e1, e_name, new_arms), Pos.get_position e)) - (translate_expr ctx e1) - (Scopelang.Ast.EnumConstructorMapLift.lift_box - (Scopelang.Ast.EnumConstructorMap.map (translate_expr ctx) arms)) + Bindlib.box_apply2 + (fun new_e1 new_arms -> + Scopelang.Ast.EMatch (new_e1, e_name, new_arms), Pos.get_position e) + (translate_expr ctx e1) + (Scopelang.Ast.EnumConstructorMapLift.lift_box + (Scopelang.Ast.EnumConstructorMap.map (translate_expr ctx) arms)) | ELit l -> Bindlib.box (Scopelang.Ast.ELit l, Pos.get_position e) | EAbs ((binder, binder_pos), typs) -> - let vars, body = Bindlib.unmbind binder in - let new_vars = - Array.map - (fun var -> Scopelang.Ast.Var.make (Bindlib.name_of var, binder_pos)) - vars - in - let ctx = - List.fold_left2 - (fun ctx var new_var -> - { - ctx with - var_mapping = Ast.VarMap.add var new_var ctx.var_mapping; - }) - ctx (Array.to_list vars) (Array.to_list new_vars) - in - Bindlib.box_apply - (fun new_binder -> - ( Scopelang.Ast.EAbs ((new_binder, binder_pos), typs), - Pos.get_position e )) - (Bindlib.bind_mvar new_vars (translate_expr ctx body)) + let vars, body = Bindlib.unmbind binder in + let new_vars = + Array.map + (fun var -> Scopelang.Ast.Var.make (Bindlib.name_of var, binder_pos)) + vars + in + let ctx = + List.fold_left2 + (fun ctx var new_var -> + { ctx with var_mapping = Ast.VarMap.add var new_var ctx.var_mapping }) + ctx (Array.to_list vars) (Array.to_list new_vars) + in + Bindlib.box_apply + (fun new_binder -> + Scopelang.Ast.EAbs ((new_binder, binder_pos), typs), Pos.get_position e) + (Bindlib.bind_mvar new_vars (translate_expr ctx body)) | EApp (e1, args) -> - Bindlib.box_apply2 - (fun new_e1 new_args -> - (Scopelang.Ast.EApp (new_e1, new_args), Pos.get_position e)) - (translate_expr ctx e1) - (Bindlib.box_list (List.map (translate_expr ctx) args)) + Bindlib.box_apply2 + (fun new_e1 new_args -> + Scopelang.Ast.EApp (new_e1, new_args), Pos.get_position e) + (translate_expr ctx e1) + (Bindlib.box_list (List.map (translate_expr ctx) args)) | EOp op -> Bindlib.box (Scopelang.Ast.EOp op, Pos.get_position e) | EDefault (excepts, just, cons) -> - Bindlib.box_apply3 - (fun new_excepts new_just new_cons -> - ( Scopelang.Ast.EDefault (new_excepts, new_just, new_cons), - Pos.get_position e )) - (Bindlib.box_list (List.map (translate_expr ctx) excepts)) - (translate_expr ctx just) (translate_expr ctx cons) + Bindlib.box_apply3 + (fun new_excepts new_just new_cons -> + ( Scopelang.Ast.EDefault (new_excepts, new_just, new_cons), + Pos.get_position e )) + (Bindlib.box_list (List.map (translate_expr ctx) excepts)) + (translate_expr ctx just) (translate_expr ctx cons) | EIfThenElse (e1, e2, e3) -> - Bindlib.box_apply3 - (fun new_e1 new_e2 new_e3 -> - ( Scopelang.Ast.EIfThenElse (new_e1, new_e2, new_e3), - Pos.get_position e )) - (translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3) + Bindlib.box_apply3 + (fun new_e1 new_e2 new_e3 -> + Scopelang.Ast.EIfThenElse (new_e1, new_e2, new_e3), Pos.get_position e) + (translate_expr ctx e1) (translate_expr ctx e2) (translate_expr ctx e3) | EArray args -> - Bindlib.box_apply - (fun new_args -> (Scopelang.Ast.EArray new_args, Pos.get_position e)) - (Bindlib.box_list (List.map (translate_expr ctx) args)) + Bindlib.box_apply + (fun new_args -> Scopelang.Ast.EArray new_args, Pos.get_position e) + (Bindlib.box_list (List.map (translate_expr ctx) args)) | ErrorOnEmpty e1 -> - Bindlib.box_apply - (fun new_e1 -> (Scopelang.Ast.ErrorOnEmpty new_e1, Pos.get_position e)) - (translate_expr ctx e1) + Bindlib.box_apply + (fun new_e1 -> Scopelang.Ast.ErrorOnEmpty new_e1, Pos.get_position e) + (translate_expr ctx e1) (** {1 Rule tree construction} *) @@ -207,21 +199,21 @@ let rec rule_tree_to_expr (is_func : Ast.Var.t option) (tree : rule_tree) : Scopelang.Ast.expr Pos.marked Bindlib.box = let exceptions, base_rules = - match tree with Leaf r -> ([], r) | Node (exceptions, r) -> (exceptions, r) + match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r in (* because each rule has its own variable parameter and we want to convert the whole rule tree into a function, we need to perform some alpha-renaming of all the expressions *) let substitute_parameter - (e : Ast.expr Pos.marked Bindlib.box) (rule : Ast.rule) : - Ast.expr Pos.marked Bindlib.box = - match (is_func, rule.Ast.rule_parameter) with + (e : Ast.expr Pos.marked Bindlib.box) + (rule : Ast.rule) : Ast.expr Pos.marked Bindlib.box = + match is_func, rule.Ast.rule_parameter with | Some new_param, Some (old_param, _) -> - let binder = Bindlib.bind_var old_param e in - Bindlib.box_apply2 - (fun binder new_param -> Bindlib.subst binder new_param) - binder - (Bindlib.box_var new_param) + let binder = Bindlib.bind_var old_param e in + Bindlib.box_apply2 + (fun binder new_param -> Bindlib.subst binder new_param) + binder + (Bindlib.box_var new_param) | None, None -> e | _ -> assert false (* should not happen *) @@ -230,22 +222,21 @@ let rec rule_tree_to_expr match is_func with | None -> ctx | Some new_param -> ( - match Ast.VarMap.find_opt new_param ctx.var_mapping with - | None -> - let new_param_scope = - Scopelang.Ast.Var.make (Bindlib.name_of new_param, def_pos) - in - { - ctx with - var_mapping = - Ast.VarMap.add new_param new_param_scope ctx.var_mapping; - } - | Some _ -> - (* We only create a mapping if none exists because - [rule_tree_to_expr] is called recursively on the exceptions of - the tree and we don't want to create a new Scopelang variable for - the parameter at each tree level. *) - ctx) + match Ast.VarMap.find_opt new_param ctx.var_mapping with + | None -> + let new_param_scope = + Scopelang.Ast.Var.make (Bindlib.name_of new_param, def_pos) + in + { + ctx with + var_mapping = Ast.VarMap.add new_param new_param_scope ctx.var_mapping; + } + | Some _ -> + (* We only create a mapping if none exists because [rule_tree_to_expr] + is called recursively on the exceptions of the tree and we don't want + to create a new Scopelang variable for the parameter at each tree + level. *) + ctx) in let base_just_list = List.map @@ -304,22 +295,22 @@ let rec rule_tree_to_expr def_pos )) exceptions default_containing_base_cases in - match (is_func, (List.hd base_rules).Ast.rule_parameter) with + match is_func, (List.hd base_rules).Ast.rule_parameter with | None, None -> default | Some new_param, Some (_, typ) -> - if toplevel then - (* When we're creating a function from multiple defaults, we must check - that the result returned by the function is not empty *) - let default = - Bindlib.box_apply - (fun (default : Scopelang.Ast.expr * Pos.t) -> - (Scopelang.Ast.ErrorOnEmpty default, def_pos)) - default - in - Scopelang.Ast.make_abs - (Array.of_list [ Ast.VarMap.find new_param ctx.var_mapping ]) - default def_pos [ typ ] def_pos - else default + if toplevel then + (* When we're creating a function from multiple defaults, we must check + that the result returned by the function is not empty *) + let default = + Bindlib.box_apply + (fun (default : Scopelang.Ast.expr * Pos.t) -> + Scopelang.Ast.ErrorOnEmpty default, def_pos) + default + in + Scopelang.Ast.make_abs + (Array.of_list [Ast.VarMap.find new_param ctx.var_mapping]) + default def_pos [typ] def_pos + else default | _ -> (* should not happen *) assert false (** {1 AST translation} *) @@ -350,10 +341,10 @@ let translate_def match Pos.unmark typ with | Scopelang.Ast.TArrow (t_param, _) -> Some t_param | _ -> - Errors.raise_spanned_error (Pos.get_position typ) - "The definitions of %a are function but its type, %a, is not a \ - function type" - Ast.ScopeDef.format_t def_info Scopelang.Print.format_typ typ + Errors.raise_spanned_error (Pos.get_position typ) + "The definitions of %a are function but its type, %a, is not a \ + function type" + Ast.ScopeDef.format_t def_info Scopelang.Print.format_typ typ else if (not is_def_func) && all_rules_not_func then None else let spans = @@ -408,7 +399,7 @@ let translate_def defined as an OnlyInput to a subscope, since the [false] default value will not be provided by the calee scope, it has to be placed in the caller. *) - then (ELit LEmptyError, Pos.no_pos) + then ELit LEmptyError, Pos.no_pos else Bindlib.unbox (rule_tree_to_expr ~toplevel:true ctx @@ -419,9 +410,9 @@ let translate_def is_def_func_param_typ) (match top_list with | [] -> - (* In this case, there are no rules to define the expression *) - Leaf [ top_value ] - | _ -> Node (top_list, [ top_value ]))) + (* In this case, there are no rules to define the expression *) + Leaf [top_value] + | _ -> Node (top_list, [top_value]))) (** Translates a scope *) let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl = @@ -436,177 +427,166 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl = (fun vertex -> match vertex with | Dependency.Vertex.Var (var, state) -> ( - let scope_def = - Ast.ScopeDefMap.find + let scope_def = + Ast.ScopeDefMap.find + (Ast.ScopeDef.Var (var, state)) + scope.scope_defs + in + let var_def = scope_def.scope_def_rules in + let var_typ = scope_def.scope_def_typ in + let is_cond = scope_def.scope_def_is_condition in + match Pos.unmark scope_def.Ast.scope_def_io.io_input with + | OnlyInput when not (Ast.RuleMap.is_empty var_def) -> + (* If the variable is tagged as input, then it shall not be + redefined. *) + Errors.raise_multispanned_error + (( Some "Incriminated variable:", + Pos.get_position (Ast.ScopeVar.get_info var) ) + :: List.map + (fun (rule, _) -> + ( Some "Incriminated variable definition:", + Pos.get_position (Ast.RuleName.get_info rule) )) + (Ast.RuleMap.bindings var_def)) + "It is impossible to give a definition to a scope variable \ + tagged as input." + | OnlyInput -> + [] + (* we do not provide any definition for an input-only variable *) + | _ -> + let expr_def = + translate_def ctx (Ast.ScopeDef.Var (var, state)) - scope.scope_defs + var_def var_typ scope_def.Ast.scope_def_io ~is_cond + ~is_subscope_var:false in - let var_def = scope_def.scope_def_rules in - let var_typ = scope_def.scope_def_typ in - let is_cond = scope_def.scope_def_is_condition in - match Pos.unmark scope_def.Ast.scope_def_io.io_input with - | OnlyInput when not (Ast.RuleMap.is_empty var_def) -> - (* If the variable is tagged as input, then it shall not be - redefined. *) - Errors.raise_multispanned_error - (( Some "Incriminated variable:", - Pos.get_position (Ast.ScopeVar.get_info var) ) - :: List.map - (fun (rule, _) -> - ( Some "Incriminated variable definition:", - Pos.get_position (Ast.RuleName.get_info rule) )) - (Ast.RuleMap.bindings var_def)) - "It is impossible to give a definition to a scope \ - variable tagged as input." - | OnlyInput -> - [] - (* we do not provide any definition for an input-only - variable *) - | _ -> - let expr_def = - translate_def ctx - (Ast.ScopeDef.Var (var, state)) - var_def var_typ scope_def.Ast.scope_def_io ~is_cond - ~is_subscope_var:false - in - let scope_var = - match - (Ast.ScopeVarMap.find var ctx.scope_var_mapping, state) - with - | WholeVar v, None -> v - | States states, Some state -> List.assoc state states - | _ -> failwith "should not happen" - in - [ - Scopelang.Ast.Definition - ( ( Scopelang.Ast.ScopeVar - ( scope_var, - Pos.get_position - (Scopelang.Ast.ScopeVar.get_info scope_var) ), + let scope_var = + match + Ast.ScopeVarMap.find var ctx.scope_var_mapping, state + with + | WholeVar v, None -> v + | States states, Some state -> List.assoc state states + | _ -> failwith "should not happen" + in + [ + Scopelang.Ast.Definition + ( ( Scopelang.Ast.ScopeVar + ( scope_var, Pos.get_position (Scopelang.Ast.ScopeVar.get_info scope_var) ), - var_typ, - scope_def.Ast.scope_def_io, - expr_def ); - ]) + Pos.get_position + (Scopelang.Ast.ScopeVar.get_info scope_var) ), + var_typ, + scope_def.Ast.scope_def_io, + expr_def ); + ]) | Dependency.Vertex.SubScope sub_scope_index -> - (* Before calling the sub_scope, we need to include all the - re-definitions of subscope parameters*) - let sub_scope = - Scopelang.Ast.SubScopeMap.find sub_scope_index - scope.scope_sub_scopes - in - let sub_scope_vars_redefs_candidates = - Ast.ScopeDefMap.filter - (fun def_key scope_def -> - match def_key with - | Ast.ScopeDef.Var _ -> false - | Ast.ScopeDef.SubScopeVar (sub_scope_index', _) -> - sub_scope_index = sub_scope_index' - (* We exclude subscope variables that have 0 - re-definitions and are not visible in the input of - the subscope *) - && not - ((match - Pos.unmark scope_def.Ast.scope_def_io.io_input - with - | Scopelang.Ast.NoInput -> true - | _ -> false) - && Ast.RuleMap.is_empty scope_def.scope_def_rules - )) - scope.scope_defs - in - let sub_scope_vars_redefs = - Ast.ScopeDefMap.mapi - (fun def_key scope_def -> - let def = scope_def.Ast.scope_def_rules in - let def_typ = scope_def.scope_def_typ in - let is_cond = scope_def.scope_def_is_condition in - match def_key with - | Ast.ScopeDef.Var _ -> - assert false (* should not happen *) - | Ast.ScopeDef.SubScopeVar (_, sub_scope_var) -> - (* This definition redefines a variable of the correct - subscope. But we have to check that this - redefinition is allowed with respect to the io - parameters of that subscope variable. *) - (match - Pos.unmark scope_def.Ast.scope_def_io.io_input - with - | Scopelang.Ast.NoInput -> - Errors.raise_multispanned_error - (( Some "Incriminated subscope:", - Ast.ScopeDef.get_position def_key ) - :: ( Some "Incriminated variable:", - Pos.get_position - (Ast.ScopeVar.get_info sub_scope_var) ) - :: List.map - (fun (rule, _) -> - ( Some - "Incriminated subscope variable \ - definition:", - Pos.get_position - (Ast.RuleName.get_info rule) )) - (Ast.RuleMap.bindings def)) - "It is impossible to give a definition to a \ - subscope variable not tagged as input or \ - context." - | OnlyInput - when Ast.RuleMap.is_empty def && not is_cond -> - (* If the subscope variable is tagged as input, - then it shall be defined. *) - Errors.raise_multispanned_error - [ - ( Some "Incriminated subscope:", - Ast.ScopeDef.get_position def_key ); - ( Some "Incriminated variable:", - Pos.get_position - (Ast.ScopeVar.get_info sub_scope_var) ); - ] - "This subscope variable is a mandatory input \ - but no definition was provided." - | _ -> ()); - (* Now that all is good, we can proceed with - translating this redefinition to a proper Scopelang - term. *) - let expr_def = - translate_def ctx def_key def def_typ - scope_def.Ast.scope_def_io ~is_cond - ~is_subscope_var:true - in - let subscop_real_name = - Scopelang.Ast.SubScopeMap.find sub_scope_index - scope.scope_sub_scopes - in - let var_pos = - Pos.get_position - (Ast.ScopeVar.get_info sub_scope_var) - in - Scopelang.Ast.Definition - ( ( Scopelang.Ast.SubScopeVar - ( subscop_real_name, - (sub_scope_index, var_pos), - match - Ast.ScopeVarMap.find sub_scope_var - ctx.scope_var_mapping - with - | WholeVar v -> (v, var_pos) - | States states -> - (* When defining a sub-scope variable, we - always define its first state in the - sub-scope. *) - (snd (List.hd states), var_pos) ), - var_pos ), - def_typ, - scope_def.Ast.scope_def_io, - expr_def )) - sub_scope_vars_redefs_candidates - in - let sub_scope_vars_redefs = - List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs) - in - sub_scope_vars_redefs - @ [ Scopelang.Ast.Call (sub_scope, sub_scope_index) ]) + (* Before calling the sub_scope, we need to include all the + re-definitions of subscope parameters*) + let sub_scope = + Scopelang.Ast.SubScopeMap.find sub_scope_index + scope.scope_sub_scopes + in + let sub_scope_vars_redefs_candidates = + Ast.ScopeDefMap.filter + (fun def_key scope_def -> + match def_key with + | Ast.ScopeDef.Var _ -> false + | Ast.ScopeDef.SubScopeVar (sub_scope_index', _) -> + sub_scope_index = sub_scope_index' + (* We exclude subscope variables that have 0 re-definitions + and are not visible in the input of the subscope *) + && not + ((match + Pos.unmark scope_def.Ast.scope_def_io.io_input + with + | Scopelang.Ast.NoInput -> true + | _ -> false) + && Ast.RuleMap.is_empty scope_def.scope_def_rules)) + scope.scope_defs + in + let sub_scope_vars_redefs = + Ast.ScopeDefMap.mapi + (fun def_key scope_def -> + let def = scope_def.Ast.scope_def_rules in + let def_typ = scope_def.scope_def_typ in + let is_cond = scope_def.scope_def_is_condition in + match def_key with + | Ast.ScopeDef.Var _ -> assert false (* should not happen *) + | Ast.ScopeDef.SubScopeVar (_, sub_scope_var) -> + (* This definition redefines a variable of the correct + subscope. But we have to check that this redefinition is + allowed with respect to the io parameters of that + subscope variable. *) + (match Pos.unmark scope_def.Ast.scope_def_io.io_input with + | Scopelang.Ast.NoInput -> + Errors.raise_multispanned_error + (( Some "Incriminated subscope:", + Ast.ScopeDef.get_position def_key ) + :: ( Some "Incriminated variable:", + Pos.get_position + (Ast.ScopeVar.get_info sub_scope_var) ) + :: List.map + (fun (rule, _) -> + ( Some + "Incriminated subscope variable definition:", + Pos.get_position (Ast.RuleName.get_info rule) + )) + (Ast.RuleMap.bindings def)) + "It is impossible to give a definition to a subscope \ + variable not tagged as input or context." + | OnlyInput when Ast.RuleMap.is_empty def && not is_cond -> + (* If the subscope variable is tagged as input, then it + shall be defined. *) + Errors.raise_multispanned_error + [ + ( Some "Incriminated subscope:", + Ast.ScopeDef.get_position def_key ); + ( Some "Incriminated variable:", + Pos.get_position + (Ast.ScopeVar.get_info sub_scope_var) ); + ] + "This subscope variable is a mandatory input but no \ + definition was provided." + | _ -> ()); + (* Now that all is good, we can proceed with translating + this redefinition to a proper Scopelang term. *) + let expr_def = + translate_def ctx def_key def def_typ + scope_def.Ast.scope_def_io ~is_cond + ~is_subscope_var:true + in + let subscop_real_name = + Scopelang.Ast.SubScopeMap.find sub_scope_index + scope.scope_sub_scopes + in + let var_pos = + Pos.get_position (Ast.ScopeVar.get_info sub_scope_var) + in + Scopelang.Ast.Definition + ( ( Scopelang.Ast.SubScopeVar + ( subscop_real_name, + (sub_scope_index, var_pos), + match + Ast.ScopeVarMap.find sub_scope_var + ctx.scope_var_mapping + with + | WholeVar v -> v, var_pos + | States states -> + (* When defining a sub-scope variable, we + always define its first state in the + sub-scope. *) + snd (List.hd states), var_pos ), + var_pos ), + def_typ, + scope_def.Ast.scope_def_io, + expr_def )) + sub_scope_vars_redefs_candidates + in + let sub_scope_vars_redefs = + List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs) + in + sub_scope_vars_redefs + @ [Scopelang.Ast.Call (sub_scope, sub_scope_index)]) scope_ordering) in (* Then, after having computed all the scopes variables, we add the @@ -628,36 +608,34 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : Scopelang.Ast.scope_decl = (fun var (states : Ast.var_or_states) acc -> match states with | WholeVar -> - let scope_def = - Ast.ScopeDefMap.find - (Ast.ScopeDef.Var (var, None)) - scope.scope_defs - in - let typ = scope_def.scope_def_typ in - Scopelang.Ast.ScopeVarMap.add - (match Ast.ScopeVarMap.find var ctx.scope_var_mapping with - | WholeVar v -> v - | States _ -> failwith "should not happen") - (typ, scope_def.scope_def_io) - acc + let scope_def = + Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs + in + let typ = scope_def.scope_def_typ in + Scopelang.Ast.ScopeVarMap.add + (match Ast.ScopeVarMap.find var ctx.scope_var_mapping with + | WholeVar v -> v + | States _ -> failwith "should not happen") + (typ, scope_def.scope_def_io) + acc | States states -> - (* What happens in the case of variables with multiple states is - interesting. We need to create as many Scopelang.Var entries in - the scope signature as there are states. *) - List.fold_left - (fun acc (state : Ast.StateName.t) -> - let scope_def = - Ast.ScopeDefMap.find - (Ast.ScopeDef.Var (var, Some state)) - scope.scope_defs - in - Scopelang.Ast.ScopeVarMap.add - (match Ast.ScopeVarMap.find var ctx.scope_var_mapping with - | WholeVar _ -> failwith "should not happen" - | States states' -> List.assoc state states') - (scope_def.scope_def_typ, scope_def.scope_def_io) - acc) - acc states) + (* What happens in the case of variables with multiple states is + interesting. We need to create as many Scopelang.Var entries in the + scope signature as there are states. *) + List.fold_left + (fun acc (state : Ast.StateName.t) -> + let scope_def = + Ast.ScopeDefMap.find + (Ast.ScopeDef.Var (var, Some state)) + scope.scope_defs + in + Scopelang.Ast.ScopeVarMap.add + (match Ast.ScopeVarMap.find var ctx.scope_var_mapping with + | WholeVar _ -> failwith "should not happen" + | States states' -> List.assoc state states') + (scope_def.scope_def_typ, scope_def.scope_def_io) + acc) + acc states) scope.scope_vars Scopelang.Ast.ScopeVarMap.empty in { @@ -679,34 +657,34 @@ let translate_program (pgrm : Ast.program) : Scopelang.Ast.program = (fun scope_var (states : Ast.var_or_states) ctx -> match states with | Ast.WholeVar -> - { - ctx with - scope_var_mapping = - Ast.ScopeVarMap.add scope_var - (WholeVar - (Scopelang.Ast.ScopeVar.fresh - (Ast.ScopeVar.get_info scope_var))) - ctx.scope_var_mapping; - } + { + ctx with + scope_var_mapping = + Ast.ScopeVarMap.add scope_var + (WholeVar + (Scopelang.Ast.ScopeVar.fresh + (Ast.ScopeVar.get_info scope_var))) + ctx.scope_var_mapping; + } | States states -> - { - ctx with - scope_var_mapping = - Ast.ScopeVarMap.add scope_var - (States - (List.map - (fun state -> - ( state, - Scopelang.Ast.ScopeVar.fresh - (let state_name, state_pos = - Ast.StateName.get_info state - in - ( Pos.unmark (Ast.ScopeVar.get_info scope_var) - ^ "_" ^ state_name, - state_pos )) )) - states)) - ctx.scope_var_mapping; - }) + { + ctx with + scope_var_mapping = + Ast.ScopeVarMap.add scope_var + (States + (List.map + (fun state -> + ( state, + Scopelang.Ast.ScopeVar.fresh + (let state_name, state_pos = + Ast.StateName.get_info state + in + ( Pos.unmark (Ast.ScopeVar.get_info scope_var) + ^ "_" ^ state_name, + state_pos )) )) + states)) + ctx.scope_var_mapping; + }) scope_decl.Ast.scope_vars ctx) pgrm.Ast.program_scopes { diff --git a/compiler/driver.ml b/compiler/driver.ml index 662faf30..d1a19a52 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -20,12 +20,11 @@ module Errors = Utils.Errors module Pos = Utils.Pos (** Associates a {!type: Cli.backend_lang} with its string represtation. *) -let languages = [ ("en", Cli.En); ("fr", Cli.Fr); ("pl", Cli.Pl) ] +let languages = ["en", Cli.En; "fr", Cli.Fr; "pl", Cli.Pl] (** Associates a file extension with its corresponding {!type: Cli.backend_lang} string representation. *) -let extensions = - [ (".catala_fr", "fr"); (".catala_en", "en"); (".catala_pl", "pl") ] +let extensions = [".catala_fr", "fr"; ".catala_en", "en"; ".catala_pl", "pl"] (** Entry function for the executable. Returns a negative number in case of error. Usage: [driver source_file options]*) @@ -44,14 +43,14 @@ let driver source_file (options : Cli.options) : int = match options.language with | Some l -> l | None -> ( - (* Try to infer the language from the intput file extension. *) - let ext = Filename.extension !filename in - if ext = "" then - Errors.raise_error - "No file extension found for the file '%s'. (Try to add one or \ - to specify the -l flag)" - !filename; - try List.assoc ext extensions with Not_found -> ext) + (* Try to infer the language from the intput file extension. *) + let ext = Filename.extension !filename in + if ext = "" then + Errors.raise_error + "No file extension found for the file '%s'. (Try to add one or to \ + specify the -l flag)" + !filename; + try List.assoc ext extensions with Not_found -> ext) in let language = try List.assoc l languages @@ -65,8 +64,8 @@ let driver source_file (options : Cli.options) : int = match Cli.catala_backend_option_of_string backend with | Some b -> b | None -> - Errors.raise_error - "The selected backend (%s) is not supported by Catala" backend + Errors.raise_error + "The selected backend (%s) is not supported by Catala" backend in let prgm = Surface.Parser_driver.parse_top_level_file source_file language @@ -74,147 +73,245 @@ let driver source_file (options : Cli.options) : int = let prgm = Surface.Fill_positions.fill_pos_with_legislative_info prgm in match backend with | Cli.Makefile -> - let backend_extensions_list = [ ".tex" ] in - let source_file = - match source_file with - | FileName f -> f - | Contents _ -> - Errors.raise_error - "The Makefile backend does not work if the input is not a file" - in - let output_file = - match options.output_file with - | Some f -> f - | None -> Filename.remove_extension source_file ^ ".d" - in - Cli.debug_print "Writing list of dependencies to %s..." output_file; - let oc = open_out output_file in - Printf.fprintf oc "%s:\\\n%s\n%s:" - (String.concat "\\\n" - (output_file - :: List.map - (fun ext -> Filename.remove_extension source_file ^ ext) - backend_extensions_list)) - (String.concat "\\\n" prgm.program_source_files) - (String.concat "\\\n" prgm.program_source_files); - 0 + let backend_extensions_list = [".tex"] in + let source_file = + match source_file with + | FileName f -> f + | Contents _ -> + Errors.raise_error + "The Makefile backend does not work if the input is not a file" + in + let output_file = + match options.output_file with + | Some f -> f + | None -> Filename.remove_extension source_file ^ ".d" + in + Cli.debug_print "Writing list of dependencies to %s..." output_file; + let oc = open_out output_file in + Printf.fprintf oc "%s:\\\n%s\n%s:" + (String.concat "\\\n" + (output_file + :: List.map + (fun ext -> Filename.remove_extension source_file ^ ext) + backend_extensions_list)) + (String.concat "\\\n" prgm.program_source_files) + (String.concat "\\\n" prgm.program_source_files); + 0 | Cli.Latex | Cli.Html -> - let source_file = - match source_file with - | FileName f -> f - | Contents _ -> - Errors.raise_error - "The literate programming backends do not work if the input is \ - not a file" - in - Cli.debug_print "Weaving literate program into %s" - (match backend with - | Cli.Latex -> "LaTeX" - | Cli.Html -> "HTML" - | _ -> assert false (* should not happen *)); - let output_file = - match options.output_file with - | Some f -> f - | None -> ( - Filename.remove_extension source_file - ^ - match backend with - | Cli.Latex -> ".tex" - | Cli.Html -> ".html" - | _ -> assert false - (* should not happen *)) - in - let oc = open_out output_file in - let weave_output = + let source_file = + match source_file with + | FileName f -> f + | Contents _ -> + Errors.raise_error + "The literate programming backends do not work if the input is not \ + a file" + in + Cli.debug_print "Weaving literate program into %s" + (match backend with + | Cli.Latex -> "LaTeX" + | Cli.Html -> "HTML" + | _ -> assert false (* should not happen *)); + let output_file = + match options.output_file with + | Some f -> f + | None -> ( + Filename.remove_extension source_file + ^ match backend with - | Cli.Latex -> Literate.Latex.ast_to_latex language - | Cli.Html -> Literate.Html.ast_to_html language + | Cli.Latex -> ".tex" + | Cli.Html -> ".html" | _ -> assert false - (* should not happen *) - in - Cli.debug_print "Writing to %s" output_file; - let fmt = Format.formatter_of_out_channel oc in - if options.wrap_weaved_output then - match backend with - | Cli.Latex -> - Literate.Latex.wrap_latex prgm.Surface.Ast.program_source_files - language fmt (fun fmt -> weave_output fmt prgm) - | Cli.Html -> - Literate.Html.wrap_html prgm.Surface.Ast.program_source_files - language fmt (fun fmt -> weave_output fmt prgm) - | _ -> assert false (* should not happen *) - else weave_output fmt prgm; - close_out oc; - 0 + (* should not happen *)) + in + let oc = open_out output_file in + let weave_output = + match backend with + | Cli.Latex -> Literate.Latex.ast_to_latex language + | Cli.Html -> Literate.Html.ast_to_html language + | _ -> assert false + (* should not happen *) + in + Cli.debug_print "Writing to %s" output_file; + let fmt = Format.formatter_of_out_channel oc in + if options.wrap_weaved_output then + match backend with + | Cli.Latex -> + Literate.Latex.wrap_latex prgm.Surface.Ast.program_source_files + language fmt (fun fmt -> weave_output fmt prgm) + | Cli.Html -> + Literate.Html.wrap_html prgm.Surface.Ast.program_source_files language + fmt (fun fmt -> weave_output fmt prgm) + | _ -> assert false (* should not happen *) + else weave_output fmt prgm; + close_out oc; + 0 | _ -> ( - Cli.debug_print "Name resolution..."; - let ctxt = Surface.Name_resolution.form_context prgm in - let scope_uid = - match (options.ex_scope, backend) with - | None, Cli.Interpret -> - Errors.raise_error "No scope was provided for execution." - | None, _ -> - snd - (try Desugared.Ast.IdentMap.choose ctxt.scope_idmap - with Not_found -> - Errors.raise_error - "There isn't any scope inside the program.") - | Some name, _ -> ( - match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with - | None -> - Errors.raise_error - "There is no scope \"%s\" inside the program." name - | Some uid -> uid) + Cli.debug_print "Name resolution..."; + let ctxt = Surface.Name_resolution.form_context prgm in + let scope_uid = + match options.ex_scope, backend with + | None, Cli.Interpret -> + Errors.raise_error "No scope was provided for execution." + | None, _ -> + snd + (try Desugared.Ast.IdentMap.choose ctxt.scope_idmap + with Not_found -> + Errors.raise_error "There isn't any scope inside the program.") + | Some name, _ -> ( + match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with + | None -> + Errors.raise_error "There is no scope \"%s\" inside the program." + name + | Some uid -> uid) + in + Cli.debug_print "Desugaring..."; + let prgm = Surface.Desugaring.desugar_program ctxt prgm in + Cli.debug_print "Collecting rules..."; + let prgm = Desugared.Desugared_to_scope.translate_program prgm in + if backend = Cli.Scopelang then begin + let fmt, at_end = + match options.output_file with + | Some f -> + let oc = open_out f in + Format.formatter_of_out_channel oc, fun _ -> close_out oc + | None -> Format.std_formatter, fun _ -> () in - Cli.debug_print "Desugaring..."; - let prgm = Surface.Desugaring.desugar_program ctxt prgm in - Cli.debug_print "Collecting rules..."; - let prgm = Desugared.Desugared_to_scope.translate_program prgm in - if backend = Cli.Scopelang then begin - let fmt, at_end = - match options.output_file with - | Some f -> - let oc = open_out f in - (Format.formatter_of_out_channel oc, fun _ -> close_out oc) - | None -> (Format.std_formatter, fun _ -> ()) - in - if Option.is_some options.ex_scope then - Format.fprintf fmt "%a\n" - (Scopelang.Print.format_scope ~debug:options.debug) - ( scope_uid, - Scopelang.Ast.ScopeMap.find scope_uid prgm.program_scopes ) - else - Format.fprintf fmt "%a\n" - (Scopelang.Print.format_program ~debug:options.debug) - prgm; - at_end (); - exit 0 - end; - Cli.debug_print "Translating to default calculus..."; - let prgm, type_ordering = - Scopelang.Scope_to_dcalc.translate_program prgm + if Option.is_some options.ex_scope then + Format.fprintf fmt "%a\n" + (Scopelang.Print.format_scope ~debug:options.debug) + ( scope_uid, + Scopelang.Ast.ScopeMap.find scope_uid prgm.program_scopes ) + else + Format.fprintf fmt "%a\n" + (Scopelang.Print.format_program ~debug:options.debug) + prgm; + at_end (); + exit 0 + end; + Cli.debug_print "Translating to default calculus..."; + let prgm, type_ordering = + Scopelang.Scope_to_dcalc.translate_program prgm + in + let prgm = + if options.optimize then begin + Cli.debug_print "Optimizing default calculus..."; + Dcalc.Optimizations.optimize_program prgm + end + else prgm + in + let prgrm_dcalc_expr = + Bindlib.unbox (Dcalc.Ast.build_whole_program_expr prgm scope_uid) + in + if backend = Cli.Dcalc then begin + let fmt, at_end = + match options.output_file with + | Some f -> + let oc = open_out f in + Format.formatter_of_out_channel oc, fun _ -> close_out oc + | None -> Format.std_formatter, fun _ -> () + in + if Option.is_some options.ex_scope then + Format.fprintf fmt "%a\n" + (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) + ( scope_uid, + Option.get + (Dcalc.Ast.fold_left_scope_defs ~init:None + ~f:(fun acc scope_def _ -> + if + Dcalc.Ast.ScopeName.compare scope_def.scope_name + scope_uid + = 0 + then Some scope_def.scope_body + else acc) + prgm.scopes) ) + else + Format.fprintf fmt "%a\n" + (Dcalc.Print.format_expr prgm.decl_ctx) + prgrm_dcalc_expr; + at_end (); + exit 0 + end; + Cli.debug_print "Typechecking..."; + let _typ = Dcalc.Typing.infer_type prgm.decl_ctx prgrm_dcalc_expr in + (* Cli.debug_format "Typechecking results :@\n%a" (Dcalc.Print.format_typ + prgm.decl_ctx) typ; *) + match backend with + | Cli.Typecheck -> + (* That's it! *) + Cli.result_print "Typechecking successful!"; + 0 + | Cli.Proof -> + let vcs = + Verification.Conditions.generate_verification_conditions prgm + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) + in + Verification.Solver.solve_vc prgm.decl_ctx vcs; + 0 + | Cli.Interpret -> + Cli.debug_print "Starting interpretation..."; + let results = + Dcalc.Interpreter.interpret_program prgm.decl_ctx prgrm_dcalc_expr + in + let out_regex = Re.Pcre.regexp "\\_out$" in + let results = + List.map + (fun ((v1, v1_pos), e1) -> + let v1 = + Re.Pcre.substitute ~rex:out_regex ~subst:(fun _ -> "") v1 + in + (v1, v1_pos), e1) + results + in + let results = + List.sort + (fun ((v1, _), _) ((v2, _), _) -> String.compare v1 v2) + results + in + Cli.debug_print "End of interpretation"; + Cli.result_print "Computation successful!%s" + (if List.length results > 0 then " Results:" else ""); + List.iter + (fun ((var, _), result) -> + Cli.result_format "@[%s@ =@ %a@]" var + (Dcalc.Print.format_expr prgm.decl_ctx) + result) + results; + 0 + | Cli.OCaml | Cli.Python | Cli.Lcalc | Cli.Scalc -> + Cli.debug_print "Compiling program into lambda calculus..."; + let prgm = + if options.avoid_exceptions then + Lcalc.Compile_without_exceptions.translate_program prgm + else Lcalc.Compile_with_exceptions.translate_program prgm in let prgm = if options.optimize then begin - Cli.debug_print "Optimizing default calculus..."; - Dcalc.Optimizations.optimize_program prgm + Cli.debug_print "Optimizing lambda calculus..."; + Lcalc.Optimizations.optimize_program prgm end else prgm in - let prgrm_dcalc_expr = - Bindlib.unbox (Dcalc.Ast.build_whole_program_expr prgm scope_uid) + let prgm = + if options.closure_conversion then ( + Cli.debug_print "Performing closure conversion..."; + let prgm = Lcalc.Closure_conversion.closure_conversion prgm in + let prgm = Bindlib.unbox prgm in + prgm) + else prgm in - if backend = Cli.Dcalc then begin + if backend = Cli.Lcalc then begin let fmt, at_end = match options.output_file with | Some f -> - let oc = open_out f in - (Format.formatter_of_out_channel oc, fun _ -> close_out oc) - | None -> (Format.std_formatter, fun _ -> ()) + let oc = open_out f in + Format.formatter_of_out_channel oc, fun _ -> close_out oc + | None -> Format.std_formatter, fun _ -> () in if Option.is_some options.ex_scope then Format.fprintf fmt "%a\n" - (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) + (Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) ( scope_uid, Option.get (Dcalc.Ast.fold_left_scope_defs ~init:None @@ -227,188 +324,86 @@ let driver source_file (options : Cli.options) : int = else acc) prgm.scopes) ) else - Format.fprintf fmt "%a\n" - (Dcalc.Print.format_expr prgm.decl_ctx) - prgrm_dcalc_expr; + ignore + (Dcalc.Ast.fold_left_scope_defs ~init:0 + ~f:(fun i scope_def _ -> + Format.fprintf fmt "%s%a" + (if i = 0 then "" else "\n") + (Lcalc.Print.format_scope prgm.decl_ctx) + (scope_uid, scope_def.scope_body); + i + 1) + prgm.scopes); at_end (); exit 0 end; - Cli.debug_print "Typechecking..."; - let _typ = Dcalc.Typing.infer_type prgm.decl_ctx prgrm_dcalc_expr in - (* Cli.debug_format "Typechecking results :@\n%a" - (Dcalc.Print.format_typ prgm.decl_ctx) typ; *) - match backend with - | Cli.Typecheck -> - (* That's it! *) - Cli.result_print "Typechecking successful!"; - 0 - | Cli.Proof -> - let vcs = - Verification.Conditions.generate_verification_conditions prgm - (match options.ex_scope with - | None -> None - | Some _ -> Some scope_uid) - in - Verification.Solver.solve_vc prgm.decl_ctx vcs; - 0 - | Cli.Interpret -> - Cli.debug_print "Starting interpretation..."; - let results = - Dcalc.Interpreter.interpret_program prgm.decl_ctx prgrm_dcalc_expr - in - let out_regex = Re.Pcre.regexp "\\_out$" in - let results = - List.map - (fun ((v1, v1_pos), e1) -> - let v1 = - Re.Pcre.substitute ~rex:out_regex ~subst:(fun _ -> "") v1 - in - ((v1, v1_pos), e1)) - results - in - let results = - List.sort - (fun ((v1, _), _) ((v2, _), _) -> String.compare v1 v2) - results - in - Cli.debug_print "End of interpretation"; - Cli.result_print "Computation successful!%s" - (if List.length results > 0 then " Results:" else ""); - List.iter - (fun ((var, _), result) -> - Cli.result_format "@[%s@ =@ %a@]" var - (Dcalc.Print.format_expr prgm.decl_ctx) - result) - results; - 0 - | Cli.OCaml | Cli.Python | Cli.Lcalc | Cli.Scalc -> - Cli.debug_print "Compiling program into lambda calculus..."; - let prgm = - if options.avoid_exceptions then - Lcalc.Compile_without_exceptions.translate_program prgm - else Lcalc.Compile_with_exceptions.translate_program prgm - in - let prgm = - if options.optimize then begin - Cli.debug_print "Optimizing lambda calculus..."; - Lcalc.Optimizations.optimize_program prgm - end - else prgm - in - let prgm = - if options.closure_conversion then ( - Cli.debug_print "Performing closure conversion..."; - let prgm = Lcalc.Closure_conversion.closure_conversion prgm in - let prgm = Bindlib.unbox prgm in - prgm) - else prgm - in - if backend = Cli.Lcalc then begin - let fmt, at_end = - match options.output_file with - | Some f -> - let oc = open_out f in - (Format.formatter_of_out_channel oc, fun _ -> close_out oc) - | None -> (Format.std_formatter, fun _ -> ()) - in - if Option.is_some options.ex_scope then - Format.fprintf fmt "%a\n" - (Lcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) - ( scope_uid, - Option.get - (Dcalc.Ast.fold_left_scope_defs ~init:None - ~f:(fun acc scope_def _ -> - if - Dcalc.Ast.ScopeName.compare scope_def.scope_name - scope_uid - = 0 - then Some scope_def.scope_body - else acc) - prgm.scopes) ) - else - ignore - (Dcalc.Ast.fold_left_scope_defs ~init:0 - ~f:(fun i scope_def _ -> - Format.fprintf fmt "%s%a" - (if i = 0 then "" else "\n") - (Lcalc.Print.format_scope prgm.decl_ctx) - (scope_uid, scope_def.scope_body); - i + 1) - prgm.scopes); - at_end (); - exit 0 - end; - let source_file = - match source_file with - | FileName f -> f - | Contents _ -> - Errors.raise_error - "This backend does not work if the input is not a file" - in - let new_output_file (extension : string) : string = + let source_file = + match source_file with + | FileName f -> f + | Contents _ -> + Errors.raise_error + "This backend does not work if the input is not a file" + in + let new_output_file (extension : string) : string = + match options.output_file with + | Some f -> f + | None -> Filename.remove_extension source_file ^ extension + in + (match backend with + | Cli.OCaml -> + let output_file = new_output_file ".ml" in + Cli.debug_print "Writing to %s..." output_file; + let oc = open_out output_file in + let fmt = Format.formatter_of_out_channel oc in + Cli.debug_print "Compiling program into OCaml..."; + Lcalc.To_ocaml.format_program fmt prgm type_ordering; + close_out oc + | Cli.Python | Cli.Scalc -> + let prgm = Scalc.Compile_from_lambda.translate_program prgm in + if backend = Cli.Scalc then begin + let fmt, at_end = match options.output_file with - | Some f -> f - | None -> Filename.remove_extension source_file ^ extension + | Some f -> + let oc = open_out f in + Format.formatter_of_out_channel oc, fun _ -> close_out oc + | None -> Format.std_formatter, fun _ -> () in - (match backend with - | Cli.OCaml -> - let output_file = new_output_file ".ml" in - Cli.debug_print "Writing to %s..." output_file; - let oc = open_out output_file in - let fmt = Format.formatter_of_out_channel oc in - Cli.debug_print "Compiling program into OCaml..."; - Lcalc.To_ocaml.format_program fmt prgm type_ordering; - close_out oc - | Cli.Python | Cli.Scalc -> - let prgm = Scalc.Compile_from_lambda.translate_program prgm in - if backend = Cli.Scalc then begin - let fmt, at_end = - match options.output_file with - | Some f -> - let oc = open_out f in - ( Format.formatter_of_out_channel oc, - fun _ -> close_out oc ) - | None -> (Format.std_formatter, fun _ -> ()) - in - if Option.is_some options.ex_scope then - Format.fprintf fmt "%a\n" - (Scalc.Print.format_scope ~debug:options.debug - prgm.decl_ctx) - (let body = - List.find - (fun body -> - body.Scalc.Ast.scope_body_name = scope_uid) - prgm.scopes - in - body) - else - Format.fprintf fmt "%a\n" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") - (fun fmt scope -> - (Scalc.Print.format_scope prgm.decl_ctx) fmt scope)) - prgm.scopes; - at_end (); - exit 0 - end; - let output_file = new_output_file ".py" in - Cli.debug_print "Compiling program into Python..."; - Cli.debug_print "Writing to %s..." output_file; - let oc = open_out output_file in - let fmt = Format.formatter_of_out_channel oc in - Scalc.To_python.format_program fmt prgm type_ordering; - close_out oc - | _ -> assert false (* should not happen *)); - 0 - | _ -> assert false - (* should not happen *)) + if Option.is_some options.ex_scope then + Format.fprintf fmt "%a\n" + (Scalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) + (let body = + List.find + (fun body -> body.Scalc.Ast.scope_body_name = scope_uid) + prgm.scopes + in + body) + else + Format.fprintf fmt "%a\n" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") + (fun fmt scope -> + (Scalc.Print.format_scope prgm.decl_ctx) fmt scope)) + prgm.scopes; + at_end (); + exit 0 + end; + let output_file = new_output_file ".py" in + Cli.debug_print "Compiling program into Python..."; + Cli.debug_print "Writing to %s..." output_file; + let oc = open_out output_file in + let fmt = Format.formatter_of_out_channel oc in + Scalc.To_python.format_program fmt prgm type_ordering; + close_out oc + | _ -> assert false (* should not happen *)); + 0 + | _ -> assert false + (* should not happen *)) with | Errors.StructuredError (msg, pos) -> - Cli.error_print "%s" (Errors.print_structured_error msg pos); - -1 + Cli.error_print "%s" (Errors.print_structured_error msg pos); + -1 | Sys_error msg -> - Cli.error_print "System error: %s" msg; - -1 + Cli.error_print "System error: %s" msg; + -1 let main () = let return_code = diff --git a/compiler/lcalc/backends.ml b/compiler/lcalc/backends.ml index c55c2f2a..cfb51330 100644 --- a/compiler/lcalc/backends.ml +++ b/compiler/lcalc/backends.ml @@ -36,8 +36,8 @@ let to_ascii (s : string) : string = | c when c >= 0xd9 && c <= 0xdc -> "U" | c when c >= 0xf9 && c <= 0xfc -> "u" | _ -> - if code > 128 then "_" - else String.make 1 (CamomileLibraryDefault.Camomile.UChar.char_of c)) + if code > 128 then "_" + else String.make 1 (CamomileLibraryDefault.Camomile.UChar.char_of c)) s; !out diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index e2af2d00..39d29673 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -31,250 +31,241 @@ let rec closure_conversion_expr (ctx : ctx) (e : expr Pos.marked) : expr Pos.marked Bindlib.box * VarSet.t = match Pos.unmark e with | EVar v -> - ( Bindlib.box_apply - (fun new_v -> (new_v, Pos.get_position v)) - (Bindlib.box_var (Pos.unmark v)), - VarSet.diff (VarSet.singleton (Pos.unmark v)) ctx.globally_bound_vars ) + ( Bindlib.box_apply + (fun new_v -> new_v, Pos.get_position v) + (Bindlib.box_var (Pos.unmark v)), + VarSet.diff (VarSet.singleton (Pos.unmark v)) ctx.globally_bound_vars ) | ETuple (args, s) -> - let new_args, free_vars = - List.fold_left - (fun (new_args, free_vars) arg -> - let new_arg, new_free_vars = closure_conversion_expr ctx arg in - (new_arg :: new_args, VarSet.union new_free_vars free_vars)) - ([], VarSet.empty) args - in - ( Bindlib.box_apply - (fun new_args -> (ETuple (List.rev new_args, s), Pos.get_position e)) - (Bindlib.box_list new_args), - free_vars ) + let new_args, free_vars = + List.fold_left + (fun (new_args, free_vars) arg -> + let new_arg, new_free_vars = closure_conversion_expr ctx arg in + new_arg :: new_args, VarSet.union new_free_vars free_vars) + ([], VarSet.empty) args + in + ( Bindlib.box_apply + (fun new_args -> ETuple (List.rev new_args, s), Pos.get_position e) + (Bindlib.box_list new_args), + free_vars ) | ETupleAccess (e1, n, s, typs) -> - let new_e1, free_vars = closure_conversion_expr ctx e1 in - ( Bindlib.box_apply - (fun new_e1 -> - (ETupleAccess (new_e1, n, s, typs), Pos.get_position e)) - new_e1, - free_vars ) + let new_e1, free_vars = closure_conversion_expr ctx e1 in + ( Bindlib.box_apply + (fun new_e1 -> ETupleAccess (new_e1, n, s, typs), Pos.get_position e) + new_e1, + free_vars ) | EInj (e1, n, e_name, typs) -> - let new_e1, free_vars = closure_conversion_expr ctx e1 in - ( Bindlib.box_apply - (fun new_e1 -> (EInj (new_e1, n, e_name, typs), Pos.get_position e)) - new_e1, - free_vars ) + let new_e1, free_vars = closure_conversion_expr ctx e1 in + ( Bindlib.box_apply + (fun new_e1 -> EInj (new_e1, n, e_name, typs), Pos.get_position e) + new_e1, + free_vars ) | EMatch (e1, arms, e_name) -> - let new_e1, free_vars = closure_conversion_expr ctx e1 in - (* We do not close the clotures inside the arms of the match expression, - since they get a special treatment at compilation to Scalc. *) - let new_arms, free_vars = - List.fold_right - (fun arm (new_arms, free_vars) -> - match Pos.unmark arm with - | EAbs ((binder, binder_pos), typs) -> - let vars, body = Bindlib.unmbind binder in - let new_body, new_free_vars = - closure_conversion_expr ctx body - in - let new_binder = Bindlib.bind_mvar vars new_body in - ( Bindlib.box_apply - (fun new_binder -> - ( EAbs ((new_binder, binder_pos), typs), - Pos.get_position arm )) - new_binder - :: new_arms, - VarSet.union free_vars new_free_vars ) - | _ -> failwith "should not happen") - arms ([], free_vars) - in - ( Bindlib.box_apply2 - (fun new_e1 new_arms -> - (EMatch (new_e1, new_arms, e_name), Pos.get_position e)) - new_e1 - (Bindlib.box_list new_arms), - free_vars ) + let new_e1, free_vars = closure_conversion_expr ctx e1 in + (* We do not close the clotures inside the arms of the match expression, + since they get a special treatment at compilation to Scalc. *) + let new_arms, free_vars = + List.fold_right + (fun arm (new_arms, free_vars) -> + match Pos.unmark arm with + | EAbs ((binder, binder_pos), typs) -> + let vars, body = Bindlib.unmbind binder in + let new_body, new_free_vars = closure_conversion_expr ctx body in + let new_binder = Bindlib.bind_mvar vars new_body in + ( Bindlib.box_apply + (fun new_binder -> + EAbs ((new_binder, binder_pos), typs), Pos.get_position arm) + new_binder + :: new_arms, + VarSet.union free_vars new_free_vars ) + | _ -> failwith "should not happen") + arms ([], free_vars) + in + ( Bindlib.box_apply2 + (fun new_e1 new_arms -> + EMatch (new_e1, new_arms, e_name), Pos.get_position e) + new_e1 + (Bindlib.box_list new_arms), + free_vars ) | EArray args -> - let new_args, free_vars = - List.fold_right - (fun arg (new_args, free_vars) -> - let new_arg, new_free_vars = closure_conversion_expr ctx arg in - (new_arg :: new_args, VarSet.union free_vars new_free_vars)) - args ([], VarSet.empty) - in - ( Bindlib.box_apply - (fun new_args -> (EArray new_args, Pos.get_position e)) - (Bindlib.box_list new_args), - free_vars ) - | ELit l -> (Bindlib.box (ELit l, Pos.get_position e), VarSet.empty) + let new_args, free_vars = + List.fold_right + (fun arg (new_args, free_vars) -> + let new_arg, new_free_vars = closure_conversion_expr ctx arg in + new_arg :: new_args, VarSet.union free_vars new_free_vars) + args ([], VarSet.empty) + in + ( Bindlib.box_apply + (fun new_args -> EArray new_args, Pos.get_position e) + (Bindlib.box_list new_args), + free_vars ) + | ELit l -> Bindlib.box (ELit l, Pos.get_position e), VarSet.empty | EApp ((EAbs ((binder, binder_pos), typs_abs), e1_pos), args) -> - (* let-binding, we should not close these *) - let vars, body = Bindlib.unmbind binder in - let new_body, free_vars = closure_conversion_expr ctx body in - let new_binder = Bindlib.bind_mvar vars new_body in - let new_args, free_vars = - List.fold_right - (fun arg (new_args, free_vars) -> - let new_arg, new_free_vars = closure_conversion_expr ctx arg in - (new_arg :: new_args, VarSet.union free_vars new_free_vars)) - args ([], free_vars) - in - ( Bindlib.box_apply2 - (fun new_binder new_args -> - ( EApp - ((EAbs ((new_binder, binder_pos), typs_abs), e1_pos), new_args), - Pos.get_position e )) - new_binder - (Bindlib.box_list new_args), - free_vars ) + (* let-binding, we should not close these *) + let vars, body = Bindlib.unmbind binder in + let new_body, free_vars = closure_conversion_expr ctx body in + let new_binder = Bindlib.bind_mvar vars new_body in + let new_args, free_vars = + List.fold_right + (fun arg (new_args, free_vars) -> + let new_arg, new_free_vars = closure_conversion_expr ctx arg in + new_arg :: new_args, VarSet.union free_vars new_free_vars) + args ([], free_vars) + in + ( Bindlib.box_apply2 + (fun new_binder new_args -> + ( EApp ((EAbs ((new_binder, binder_pos), typs_abs), e1_pos), new_args), + Pos.get_position e )) + new_binder + (Bindlib.box_list new_args), + free_vars ) | EAbs ((binder, binder_pos), typs) -> - (* λ x.t *) - (* Converting the closure. *) - let vars, body = Bindlib.unmbind binder in - (* t *) - let new_body, body_vars = closure_conversion_expr ctx body in - (* [[t]] *) - let extra_vars = - VarSet.diff body_vars (VarSet.of_list (Array.to_list vars)) - in - let extra_vars_list = VarSet.elements extra_vars in - (* x1, ..., xn *) - let code_var = Var.make (ctx.name_context, binder_pos) in - (* code *) - let inner_c_var = Var.make ("env", binder_pos) in - let new_closure_body = - make_multiple_let_in - (Array.of_list extra_vars_list) - (List.init (List.length extra_vars_list) (fun _ -> - (Dcalc.Ast.TAny, binder_pos))) - (List.mapi - (fun i _ -> - Bindlib.box_apply - (fun inner_c_var -> - ( ETupleAccess - ( (inner_c_var, binder_pos), - i + 1, - None, - List.init - (List.length extra_vars_list + 1) - (fun _ -> (Dcalc.Ast.TAny, binder_pos)) ), - binder_pos )) - (Bindlib.box_var inner_c_var)) - extra_vars_list) - new_body binder_pos - in - let new_closure = - make_abs - (Array.concat [ Array.make 1 inner_c_var; vars ]) - new_closure_body binder_pos - ((Dcalc.Ast.TAny, binder_pos) :: typs) - (Pos.get_position e) - in - ( make_let_in code_var - (Dcalc.Ast.TAny, Pos.get_position e) - new_closure - (Bindlib.box_apply2 - (fun code_var extra_vars -> - ( ETuple - ( (code_var, binder_pos) - :: List.map - (fun extra_var -> (extra_var, binder_pos)) - extra_vars, - None ), - Pos.get_position e )) - (Bindlib.box_var code_var) - (Bindlib.box_list - (List.map - (fun extra_var -> Bindlib.box_var extra_var) - extra_vars_list))) - (Pos.get_position e), - extra_vars ) + (* λ x.t *) + (* Converting the closure. *) + let vars, body = Bindlib.unmbind binder in + (* t *) + let new_body, body_vars = closure_conversion_expr ctx body in + (* [[t]] *) + let extra_vars = + VarSet.diff body_vars (VarSet.of_list (Array.to_list vars)) + in + let extra_vars_list = VarSet.elements extra_vars in + (* x1, ..., xn *) + let code_var = Var.make (ctx.name_context, binder_pos) in + (* code *) + let inner_c_var = Var.make ("env", binder_pos) in + let new_closure_body = + make_multiple_let_in + (Array.of_list extra_vars_list) + (List.init (List.length extra_vars_list) (fun _ -> + Dcalc.Ast.TAny, binder_pos)) + (List.mapi + (fun i _ -> + Bindlib.box_apply + (fun inner_c_var -> + ( ETupleAccess + ( (inner_c_var, binder_pos), + i + 1, + None, + List.init + (List.length extra_vars_list + 1) + (fun _ -> Dcalc.Ast.TAny, binder_pos) ), + binder_pos )) + (Bindlib.box_var inner_c_var)) + extra_vars_list) + new_body binder_pos + in + let new_closure = + make_abs + (Array.concat [Array.make 1 inner_c_var; vars]) + new_closure_body binder_pos + ((Dcalc.Ast.TAny, binder_pos) :: typs) + (Pos.get_position e) + in + ( make_let_in code_var + (Dcalc.Ast.TAny, Pos.get_position e) + new_closure + (Bindlib.box_apply2 + (fun code_var extra_vars -> + ( ETuple + ( (code_var, binder_pos) + :: List.map + (fun extra_var -> extra_var, binder_pos) + extra_vars, + None ), + Pos.get_position e )) + (Bindlib.box_var code_var) + (Bindlib.box_list + (List.map + (fun extra_var -> Bindlib.box_var extra_var) + extra_vars_list))) + (Pos.get_position e), + extra_vars ) | EApp ((EOp op, pos_op), args) -> - (* This corresponds to an operator call, which we don't want to - transform*) - let new_args, free_vars = - List.fold_right - (fun arg (new_args, free_vars) -> - let new_arg, new_free_vars = closure_conversion_expr ctx arg in - (new_arg :: new_args, VarSet.union free_vars new_free_vars)) - args ([], VarSet.empty) - in - ( Bindlib.box_apply - (fun new_e2 -> (EApp ((EOp op, pos_op), new_e2), Pos.get_position e)) - (Bindlib.box_list new_args), - free_vars ) + (* This corresponds to an operator call, which we don't want to transform*) + let new_args, free_vars = + List.fold_right + (fun arg (new_args, free_vars) -> + let new_arg, new_free_vars = closure_conversion_expr ctx arg in + new_arg :: new_args, VarSet.union free_vars new_free_vars) + args ([], VarSet.empty) + in + ( Bindlib.box_apply + (fun new_e2 -> EApp ((EOp op, pos_op), new_e2), Pos.get_position e) + (Bindlib.box_list new_args), + free_vars ) | EApp ((EVar (v, _), v_pos), args) when VarSet.mem v ctx.globally_bound_vars -> - (* This corresponds to a scope call, which we don't want to transform*) - let new_args, free_vars = - List.fold_right - (fun arg (new_args, free_vars) -> - let new_arg, new_free_vars = closure_conversion_expr ctx arg in - (new_arg :: new_args, VarSet.union free_vars new_free_vars)) - args ([], VarSet.empty) - in - ( Bindlib.box_apply2 - (fun new_v new_e2 -> - (EApp ((new_v, v_pos), new_e2), Pos.get_position e)) - (Bindlib.box_var v) - (Bindlib.box_list new_args), - free_vars ) + (* This corresponds to a scope call, which we don't want to transform*) + let new_args, free_vars = + List.fold_right + (fun arg (new_args, free_vars) -> + let new_arg, new_free_vars = closure_conversion_expr ctx arg in + new_arg :: new_args, VarSet.union free_vars new_free_vars) + args ([], VarSet.empty) + in + ( Bindlib.box_apply2 + (fun new_v new_e2 -> EApp ((new_v, v_pos), new_e2), Pos.get_position e) + (Bindlib.box_var v) + (Bindlib.box_list new_args), + free_vars ) | EApp (e1, args) -> - let new_e1, free_vars = closure_conversion_expr ctx e1 in - let env_var = Var.make ("env", Pos.get_position e1) in - let code_var = Var.make ("code", Pos.get_position e1) in - let new_args, free_vars = - List.fold_right - (fun arg (new_args, free_vars) -> - let new_arg, new_free_vars = closure_conversion_expr ctx arg in - (new_arg :: new_args, VarSet.union free_vars new_free_vars)) - args ([], free_vars) - in - let call_expr = - make_let_in code_var - (Dcalc.Ast.TAny, Pos.get_position e) - (Bindlib.box_apply - (fun env_var -> - ( ETupleAccess - ((env_var, Pos.get_position e1), 0, None, [ (*TODO: fill?*) ]), - Pos.get_position e )) - (Bindlib.box_var env_var)) - (Bindlib.box_apply3 - (fun code_var env_var new_args -> - ( EApp - ( (code_var, Pos.get_position e1), - (env_var, Pos.get_position e1) :: new_args ), - Pos.get_position e )) - (Bindlib.box_var code_var) (Bindlib.box_var env_var) - (Bindlib.box_list new_args)) - (Pos.get_position e) - in - ( make_let_in env_var - (Dcalc.Ast.TAny, Pos.get_position e) - new_e1 call_expr (Pos.get_position e), - free_vars ) + let new_e1, free_vars = closure_conversion_expr ctx e1 in + let env_var = Var.make ("env", Pos.get_position e1) in + let code_var = Var.make ("code", Pos.get_position e1) in + let new_args, free_vars = + List.fold_right + (fun arg (new_args, free_vars) -> + let new_arg, new_free_vars = closure_conversion_expr ctx arg in + new_arg :: new_args, VarSet.union free_vars new_free_vars) + args ([], free_vars) + in + let call_expr = + make_let_in code_var + (Dcalc.Ast.TAny, Pos.get_position e) + (Bindlib.box_apply + (fun env_var -> + ( ETupleAccess + ((env_var, Pos.get_position e1), 0, None, [ (*TODO: fill?*) ]), + Pos.get_position e )) + (Bindlib.box_var env_var)) + (Bindlib.box_apply3 + (fun code_var env_var new_args -> + ( EApp + ( (code_var, Pos.get_position e1), + (env_var, Pos.get_position e1) :: new_args ), + Pos.get_position e )) + (Bindlib.box_var code_var) (Bindlib.box_var env_var) + (Bindlib.box_list new_args)) + (Pos.get_position e) + in + ( make_let_in env_var + (Dcalc.Ast.TAny, Pos.get_position e) + new_e1 call_expr (Pos.get_position e), + free_vars ) | EAssert e1 -> - let new_e1, free_vars = closure_conversion_expr ctx e1 in - ( Bindlib.box_apply - (fun new_e1 -> (EAssert new_e1, Pos.get_position e)) - new_e1, - free_vars ) - | EOp op -> (Bindlib.box (EOp op, Pos.get_position e), VarSet.empty) + let new_e1, free_vars = closure_conversion_expr ctx e1 in + ( Bindlib.box_apply (fun new_e1 -> EAssert new_e1, Pos.get_position e) new_e1, + free_vars ) + | EOp op -> Bindlib.box (EOp op, Pos.get_position e), VarSet.empty | EIfThenElse (e1, e2, e3) -> - let new_e1, free_vars1 = closure_conversion_expr ctx e1 in - let new_e2, free_vars2 = closure_conversion_expr ctx e2 in - let new_e3, free_vars3 = closure_conversion_expr ctx e3 in - ( Bindlib.box_apply3 - (fun new_e1 new_e2 new_e3 -> - (EIfThenElse (new_e1, new_e2, new_e3), Pos.get_position e)) - new_e1 new_e2 new_e3, - VarSet.union (VarSet.union free_vars1 free_vars2) free_vars3 ) + let new_e1, free_vars1 = closure_conversion_expr ctx e1 in + let new_e2, free_vars2 = closure_conversion_expr ctx e2 in + let new_e3, free_vars3 = closure_conversion_expr ctx e3 in + ( Bindlib.box_apply3 + (fun new_e1 new_e2 new_e3 -> + EIfThenElse (new_e1, new_e2, new_e3), Pos.get_position e) + new_e1 new_e2 new_e3, + VarSet.union (VarSet.union free_vars1 free_vars2) free_vars3 ) | ERaise except -> - (Bindlib.box (ERaise except, Pos.get_position e), VarSet.empty) + Bindlib.box (ERaise except, Pos.get_position e), VarSet.empty | ECatch (e1, except, e2) -> - let new_e1, free_vars1 = closure_conversion_expr ctx e1 in - let new_e2, free_vars2 = closure_conversion_expr ctx e2 in - ( Bindlib.box_apply2 - (fun new_e1 new_e2 -> - (ECatch (new_e1, except, new_e2), Pos.get_position e)) - new_e1 new_e2, - VarSet.union free_vars1 free_vars2 ) + let new_e1, free_vars1 = closure_conversion_expr ctx e1 in + let new_e2, free_vars2 = closure_conversion_expr ctx e2 in + ( Bindlib.box_apply2 + (fun new_e1 new_e2 -> + ECatch (new_e1, except, new_e2), Pos.get_position e) + new_e1 new_e2, + VarSet.union free_vars1 free_vars2 ) let closure_conversion (p : program) : program Bindlib.box = let new_scopes, _ = @@ -321,7 +312,7 @@ let closure_conversion (p : program) : program Bindlib.box = new_scope_body_expr (Bindlib.bind_var scope_var next))), global_vars )) - ~init:(Fun.id, VarSet.of_list [ handle_default; handle_default_opt ]) + ~init:(Fun.id, VarSet.of_list [handle_default; handle_default_opt]) p.scopes in Bindlib.box_apply diff --git a/compiler/lcalc/compile_with_exceptions.ml b/compiler/lcalc/compile_with_exceptions.ml index 28a297cf..f0af48a4 100644 --- a/compiler/lcalc/compile_with_exceptions.ml +++ b/compiler/lcalc/compile_with_exceptions.ml @@ -36,7 +36,7 @@ let translate_lit (l : D.lit) : A.expr = let thunk_expr (e : A.expr Pos.marked Bindlib.box) (pos : Pos.t) : A.expr Pos.marked Bindlib.box = let dummy_var = A.Var.make ("_", pos) in - A.make_abs [| dummy_var |] e pos [ (D.TAny, pos) ] pos + A.make_abs [| dummy_var |] e pos [D.TAny, pos] pos let rec translate_default (ctx : ctx) @@ -66,55 +66,55 @@ and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : match Pos.unmark e with | D.EVar v -> A.make_var (D.VarMap.find (Pos.unmark v) ctx, Pos.get_position e) | D.ETuple (args, s) -> - A.etuple (List.map (translate_expr ctx) args) s (Pos.get_position e) + A.etuple (List.map (translate_expr ctx) args) s (Pos.get_position e) | D.ETupleAccess (e1, i, s, ts) -> - A.etupleaccess (translate_expr ctx e1) i s ts (Pos.get_position e) + A.etupleaccess (translate_expr ctx e1) i s ts (Pos.get_position e) | D.EInj (e1, i, en, ts) -> - A.einj (translate_expr ctx e1) i en ts (Pos.get_position e) + A.einj (translate_expr ctx e1) i en ts (Pos.get_position e) | D.EMatch (e1, cases, en) -> - A.ematch (translate_expr ctx e1) - (List.map (translate_expr ctx) cases) - en (Pos.get_position e) + A.ematch (translate_expr ctx e1) + (List.map (translate_expr ctx) cases) + en (Pos.get_position e) | D.EArray es -> - A.earray (List.map (translate_expr ctx) es) (Pos.get_position e) + A.earray (List.map (translate_expr ctx) es) (Pos.get_position e) | D.ELit l -> Bindlib.box (Pos.same_pos_as (translate_lit l) e) | D.EOp op -> A.eop op (Pos.get_position e) | D.EIfThenElse (e1, e2, e3) -> - A.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2) - (translate_expr ctx e3) (Pos.get_position e) + A.eifthenelse (translate_expr ctx e1) (translate_expr ctx e2) + (translate_expr ctx e3) (Pos.get_position e) | D.EAssert e1 -> A.eassert (translate_expr ctx e1) (Pos.get_position e) | D.ErrorOnEmpty arg -> - A.ecatch (translate_expr ctx arg) A.EmptyError - (Bindlib.box (Pos.same_pos_as (A.ERaise A.NoValueProvided) e)) - (Pos.get_position e) + A.ecatch (translate_expr ctx arg) A.EmptyError + (Bindlib.box (Pos.same_pos_as (A.ERaise A.NoValueProvided) e)) + (Pos.get_position e) | D.EApp (e1, args) -> - A.eapp (translate_expr ctx e1) - (List.map (translate_expr ctx) args) - (Pos.get_position e) + A.eapp (translate_expr ctx e1) + (List.map (translate_expr ctx) args) + (Pos.get_position e) | D.EAbs ((binder, pos_binder), ts) -> - let vars, body = Bindlib.unmbind binder in - let ctx, lc_vars = - Array.fold_right - (fun var (ctx, lc_vars) -> - let lc_var = A.Var.make (Bindlib.name_of var, pos_binder) in - (D.VarMap.add var lc_var ctx, lc_var :: lc_vars)) - vars (ctx, []) - in - let lc_vars = Array.of_list lc_vars in - let new_body = translate_expr ctx body in - let new_binder = Bindlib.bind_mvar lc_vars new_body in - Bindlib.box_apply - (fun new_binder -> - Pos.same_pos_as (A.EAbs ((new_binder, pos_binder), ts)) e) - new_binder - | D.EDefault ([ exn ], just, cons) when !Cli.optimize_flag -> - A.ecatch (translate_expr ctx exn) A.EmptyError - (A.eifthenelse (translate_expr ctx just) (translate_expr ctx cons) - (Bindlib.box (Pos.same_pos_as (A.ERaise A.EmptyError) e)) - (Pos.get_position e)) - (Pos.get_position e) + let vars, body = Bindlib.unmbind binder in + let ctx, lc_vars = + Array.fold_right + (fun var (ctx, lc_vars) -> + let lc_var = A.Var.make (Bindlib.name_of var, pos_binder) in + D.VarMap.add var lc_var ctx, lc_var :: lc_vars) + vars (ctx, []) + in + let lc_vars = Array.of_list lc_vars in + let new_body = translate_expr ctx body in + let new_binder = Bindlib.bind_mvar lc_vars new_body in + Bindlib.box_apply + (fun new_binder -> + Pos.same_pos_as (A.EAbs ((new_binder, pos_binder), ts)) e) + new_binder + | D.EDefault ([exn], just, cons) when !Cli.optimize_flag -> + A.ecatch (translate_expr ctx exn) A.EmptyError + (A.eifthenelse (translate_expr ctx just) (translate_expr ctx cons) + (Bindlib.box (Pos.same_pos_as (A.ERaise A.EmptyError) e)) + (Pos.get_position e)) + (Pos.get_position e) | D.EDefault (exceptions, just, cons) -> - translate_default ctx exceptions just cons (Pos.get_position e) + translate_default ctx exceptions just cons (Pos.get_position e) let rec translate_scope_lets (decl_ctx : D.decl_ctx) @@ -124,29 +124,27 @@ let rec translate_scope_lets match scope_lets with | Result e -> Bindlib.box_apply (fun e -> D.Result e) (translate_expr ctx e) | ScopeLet scope_let -> - let old_scope_let_var, scope_let_next = - Bindlib.unbind scope_let.scope_let_next - in - let new_scope_let_var = - A.Var.make (Bindlib.name_of old_scope_let_var, scope_let.scope_let_pos) - in - let new_scope_let_expr = translate_expr ctx scope_let.scope_let_expr in - let new_ctx = D.VarMap.add old_scope_let_var new_scope_let_var ctx in - let new_scope_next = - translate_scope_lets decl_ctx new_ctx scope_let_next - in - let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in - Bindlib.box_apply2 - (fun new_scope_next new_scope_let_expr -> - D.ScopeLet - { - scope_let_typ = scope_let.D.scope_let_typ; - scope_let_kind = scope_let.D.scope_let_kind; - scope_let_pos = scope_let.D.scope_let_pos; - scope_let_next = new_scope_next; - scope_let_expr = new_scope_let_expr; - }) - new_scope_next new_scope_let_expr + let old_scope_let_var, scope_let_next = + Bindlib.unbind scope_let.scope_let_next + in + let new_scope_let_var = + A.Var.make (Bindlib.name_of old_scope_let_var, scope_let.scope_let_pos) + in + let new_scope_let_expr = translate_expr ctx scope_let.scope_let_expr in + let new_ctx = D.VarMap.add old_scope_let_var new_scope_let_var ctx in + let new_scope_next = translate_scope_lets decl_ctx new_ctx scope_let_next in + let new_scope_next = Bindlib.bind_var new_scope_let_var new_scope_next in + Bindlib.box_apply2 + (fun new_scope_next new_scope_let_expr -> + D.ScopeLet + { + scope_let_typ = scope_let.D.scope_let_typ; + scope_let_kind = scope_let.D.scope_let_kind; + scope_let_pos = scope_let.D.scope_let_pos; + scope_let_next = new_scope_next; + scope_let_expr = new_scope_let_expr; + }) + new_scope_next new_scope_let_expr let rec translate_scopes (decl_ctx : D.decl_ctx) @@ -155,51 +153,51 @@ let rec translate_scopes match scopes with | Nil -> Bindlib.box D.Nil | ScopeDef scope_def -> - let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in - let new_scope_var = - A.Var.make (D.ScopeName.get_info scope_def.scope_name) - in - let old_scope_input_var, scope_body_expr = - Bindlib.unbind scope_def.scope_body.scope_body_expr - in - let new_scope_input_var = - A.Var.make - ( Bindlib.name_of old_scope_input_var, - Pos.get_position (D.ScopeName.get_info scope_def.scope_name) ) - in - let new_ctx = D.VarMap.add old_scope_input_var new_scope_input_var ctx in - let new_scope_body_expr = - translate_scope_lets decl_ctx new_ctx scope_body_expr - in - let new_scope_body_expr = - Bindlib.bind_var new_scope_input_var new_scope_body_expr - in - let new_scope : A.expr D.scope_body Bindlib.box = - Bindlib.box_apply - (fun new_scope_body_expr -> - { - D.scope_body_input_struct = - scope_def.scope_body.scope_body_input_struct; - scope_body_output_struct = - scope_def.scope_body.scope_body_output_struct; - scope_body_expr = new_scope_body_expr; - }) - new_scope_body_expr - in - let new_ctx = D.VarMap.add old_scope_var new_scope_var new_ctx in - let scope_next = - Bindlib.bind_var new_scope_var - (translate_scopes decl_ctx new_ctx scope_next) - in - Bindlib.box_apply2 - (fun new_scope scope_next -> - D.ScopeDef - { - scope_name = scope_def.scope_name; - scope_body = new_scope; - scope_next; - }) - new_scope scope_next + let old_scope_var, scope_next = Bindlib.unbind scope_def.scope_next in + let new_scope_var = + A.Var.make (D.ScopeName.get_info scope_def.scope_name) + in + let old_scope_input_var, scope_body_expr = + Bindlib.unbind scope_def.scope_body.scope_body_expr + in + let new_scope_input_var = + A.Var.make + ( Bindlib.name_of old_scope_input_var, + Pos.get_position (D.ScopeName.get_info scope_def.scope_name) ) + in + let new_ctx = D.VarMap.add old_scope_input_var new_scope_input_var ctx in + let new_scope_body_expr = + translate_scope_lets decl_ctx new_ctx scope_body_expr + in + let new_scope_body_expr = + Bindlib.bind_var new_scope_input_var new_scope_body_expr + in + let new_scope : A.expr D.scope_body Bindlib.box = + Bindlib.box_apply + (fun new_scope_body_expr -> + { + D.scope_body_input_struct = + scope_def.scope_body.scope_body_input_struct; + scope_body_output_struct = + scope_def.scope_body.scope_body_output_struct; + scope_body_expr = new_scope_body_expr; + }) + new_scope_body_expr + in + let new_ctx = D.VarMap.add old_scope_var new_scope_var new_ctx in + let scope_next = + Bindlib.bind_var new_scope_var + (translate_scopes decl_ctx new_ctx scope_next) + in + Bindlib.box_apply2 + (fun new_scope scope_next -> + D.ScopeDef + { + scope_name = scope_def.scope_name; + scope_body = new_scope; + scope_next; + }) + new_scope scope_next let translate_program (prgm : D.program) : A.program = { diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 3d903142..9bca337b 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -123,9 +123,8 @@ let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked = | D.TArray ts -> D.TArray (translate_typ ts) (* catala is not polymorphic *) | D.TArrow ((D.TLit D.TUnit, pos_unit), t2) -> - D.TEnum - ([ (D.TLit D.TUnit, pos_unit); translate_typ t2 ], A.option_enum) - (* D.TAny *) + D.TEnum ([D.TLit D.TUnit, pos_unit; translate_typ t2], A.option_enum) + (* D.TAny *) | D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2) end @@ -139,9 +138,9 @@ let translate_lit (l : D.lit) (pos : Pos.t) : A.lit = | D.LDate d -> A.LDate d | D.LDuration d -> A.LDuration d | D.LEmptyError -> - Errors.raise_spanned_error pos - "Internal Error: An empty error was found in a place that shouldn't be \ - possible." + Errors.raise_spanned_error pos + "Internal Error: An empty error was found in a place that shouldn't be \ + possible." (** [c = disjoint_union_maps cs] Compute the disjoint union of multiple maps. Raises an internal error if there is two identicals keys in differnts parts. *) @@ -169,136 +168,132 @@ let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) : EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure about assert. *) | D.EVar v -> - (* todo: for now, every unpure (such that [is_pure] is [false] in the - current context) is thunked, hence matched in the next case. This - assumption can change in the future, and this case is here for this - reason. *) - let v, pos_v = v in - if not (find ~info:"search for a variable" v ctx).is_pure then - let v' = A.Var.make (Bindlib.name_of v, pos_v) in - (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, - created a variable %a to replace it" Dcalc.Print.format_var v - Print.format_var v'; *) - (A.make_var (v', pos), A.VarMap.singleton v' e) - else ((find ~info:"should never happend" v ctx).expr, A.VarMap.empty) - | D.EApp ((D.EVar (v, pos_v), p), [ (D.ELit D.LUnit, _) ]) -> - if not (find ~info:"search for a variable" v ctx).is_pure then - let v' = A.Var.make (Bindlib.name_of v, pos_v) in - (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, - created a variable %a to replace it" Dcalc.Print.format_var v - Print.format_var v'; *) - (A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p)) - else - Errors.raise_spanned_error pos - "Internal error: an pure variable was found in an unpure environment." + (* todo: for now, every unpure (such that [is_pure] is [false] in the + current context) is thunked, hence matched in the next case. This + assumption can change in the future, and this case is here for this + reason. *) + let v, pos_v = v in + if not (find ~info:"search for a variable" v ctx).is_pure then + let v' = A.Var.make (Bindlib.name_of v, pos_v) in + (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, + created a variable %a to replace it" Dcalc.Print.format_var v + Print.format_var v'; *) + A.make_var (v', pos), A.VarMap.singleton v' e + else (find ~info:"should never happend" v ctx).expr, A.VarMap.empty + | D.EApp ((D.EVar (v, pos_v), p), [(D.ELit D.LUnit, _)]) -> + if not (find ~info:"search for a variable" v ctx).is_pure then + let v' = A.Var.make (Bindlib.name_of v, pos_v) in + (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, + created a variable %a to replace it" Dcalc.Print.format_var v + Print.format_var v'; *) + A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p) + else + Errors.raise_spanned_error pos + "Internal error: an pure variable was found in an unpure environment." | D.EDefault (_exceptions, _just, _cons) -> - let v' = A.Var.make ("default_term", pos) in - (A.make_var (v', pos), A.VarMap.singleton v' e) + let v' = A.Var.make ("default_term", pos) in + A.make_var (v', pos), A.VarMap.singleton v' e | D.ELit D.LEmptyError -> - let v' = A.Var.make ("empty_litteral", pos) in - (A.make_var (v', pos), A.VarMap.singleton v' e) + let v' = A.Var.make ("empty_litteral", pos) in + A.make_var (v', pos), A.VarMap.singleton v' e (* This one is a very special case. It transform an unpure expression environement to a pure expression. *) | ErrorOnEmpty arg -> - (* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} - ] *) - let silent_var = A.Var.make ("_", pos) in - let x = A.Var.make ("non_empty_argument", pos) in + (* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *) + let silent_var = A.Var.make ("_", pos) in + let x = A.Var.make ("non_empty_argument", pos) in - let arg' = translate_expr ctx arg in + let arg' = translate_expr ctx arg in - ( A.make_matchopt_with_abs_arms arg' - (A.make_abs [| silent_var |] - (Bindlib.box (A.ERaise A.NoValueProvided, pos)) - pos - [ (D.TAny, pos) ] - pos) - (A.make_abs [| x |] (A.make_var (x, pos)) pos [ (D.TAny, pos) ] pos), - A.VarMap.empty ) + ( A.make_matchopt_with_abs_arms arg' + (A.make_abs [| silent_var |] + (Bindlib.box (A.ERaise A.NoValueProvided, pos)) + pos + [D.TAny, pos] + pos) + (A.make_abs [| x |] (A.make_var (x, pos)) pos [D.TAny, pos] pos), + A.VarMap.empty ) (* pure terms *) - | D.ELit l -> (A.elit (translate_lit l pos) pos, A.VarMap.empty) + | D.ELit l -> A.elit (translate_lit l pos) pos, A.VarMap.empty | D.EIfThenElse (e1, e2, e3) -> - let e1', h1 = translate_and_hoist ctx e1 in - let e2', h2 = translate_and_hoist ctx e2 in - let e3', h3 = translate_and_hoist ctx e3 in + let e1', h1 = translate_and_hoist ctx e1 in + let e2', h2 = translate_and_hoist ctx e2 in + let e3', h3 = translate_and_hoist ctx e3 in - let e' = A.eifthenelse e1' e2' e3' pos in + let e' = A.eifthenelse e1' e2' e3' pos in - (*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' - = e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *) - (e', disjoint_union_maps pos [ h1; h2; h3 ]) + (*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' = + e3' in (A.EIfThenElse (e1', e2', e3'), pos) in *) + e', disjoint_union_maps pos [h1; h2; h3] | D.EAssert e1 -> - (* same behavior as in the ICFP paper: if e1 is empty, then no error is - raised. *) - let e1', h1 = translate_and_hoist ctx e1 in - (A.eassert e1' pos, h1) + (* same behavior as in the ICFP paper: if e1 is empty, then no error is + raised. *) + let e1', h1 = translate_and_hoist ctx e1 in + A.eassert e1' pos, h1 | D.EAbs ((binder, pos_binder), ts) -> - let vars, body = Bindlib.unmbind binder in - let ctx, lc_vars = - ArrayLabels.fold_right vars ~init:(ctx, []) - ~f:(fun var (ctx, lc_vars) -> - (* we suppose the invariant that when applying a function, its - arguments cannot be of the type "option". + let vars, body = Bindlib.unmbind binder in + let ctx, lc_vars = + ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) -> + (* we suppose the invariant that when applying a function, its + arguments cannot be of the type "option". - The code should behave correctly in the without this assumption - if we put here an is_pure=false, but the types are more - compilcated. (unimplemented for now) *) - let ctx = add_var pos var true ctx in - let lc_var = (find var ctx).var in - (ctx, lc_var :: lc_vars)) - in - let lc_vars = Array.of_list lc_vars in + The code should behave correctly in the without this assumption if + we put here an is_pure=false, but the types are more compilcated. + (unimplemented for now) *) + let ctx = add_var pos var true ctx in + let lc_var = (find var ctx).var in + ctx, lc_var :: lc_vars) + in + let lc_vars = Array.of_list lc_vars in - (* here we take the guess that if we cannot build the closure because one - of the variable is empty, then we cannot build the function. *) - let new_body, hoists = translate_and_hoist ctx body in - let new_binder = Bindlib.bind_mvar lc_vars new_body in + (* here we take the guess that if we cannot build the closure because one of + the variable is empty, then we cannot build the function. *) + let new_body, hoists = translate_and_hoist ctx body in + let new_binder = Bindlib.bind_mvar lc_vars new_body in - ( Bindlib.box_apply - (fun new_binder -> - (A.EAbs ((new_binder, pos_binder), List.map translate_typ ts), pos)) - new_binder, - hoists ) + ( Bindlib.box_apply + (fun new_binder -> + A.EAbs ((new_binder, pos_binder), List.map translate_typ ts), pos) + new_binder, + hoists ) | EApp (e1, args) -> - let e1', h1 = translate_and_hoist ctx e1 in - let args', h_args = - args |> List.map (translate_and_hoist ctx) |> List.split - in + let e1', h1 = translate_and_hoist ctx e1 in + let args', h_args = + args |> List.map (translate_and_hoist ctx) |> List.split + in - let hoists = disjoint_union_maps pos (h1 :: h_args) in - let e' = A.eapp e1' args' pos in - (e', hoists) + let hoists = disjoint_union_maps pos (h1 :: h_args) in + let e' = A.eapp e1' args' pos in + e', hoists | ETuple (args, s) -> - let args', h_args = - args |> List.map (translate_and_hoist ctx) |> List.split - in + let args', h_args = + args |> List.map (translate_and_hoist ctx) |> List.split + in - let hoists = disjoint_union_maps pos h_args in - (A.etuple args' s pos, hoists) + let hoists = disjoint_union_maps pos h_args in + A.etuple args' s pos, hoists | ETupleAccess (e1, i, s, ts) -> - let e1', hoists = translate_and_hoist ctx e1 in - let e1' = A.etupleaccess e1' i s ts pos in - (e1', hoists) + let e1', hoists = translate_and_hoist ctx e1 in + let e1' = A.etupleaccess e1' i s ts pos in + e1', hoists | EInj (e1, i, en, ts) -> - let e1', hoists = translate_and_hoist ctx e1 in - let e1' = A.einj e1' i en ts pos in - (e1', hoists) + let e1', hoists = translate_and_hoist ctx e1 in + let e1' = A.einj e1' i en ts pos in + e1', hoists | EMatch (e1, cases, en) -> - let e1', h1 = translate_and_hoist ctx e1 in - let cases', h_cases = - cases |> List.map (translate_and_hoist ctx) |> List.split - in + let e1', h1 = translate_and_hoist ctx e1 in + let cases', h_cases = + cases |> List.map (translate_and_hoist ctx) |> List.split + in - let hoists = disjoint_union_maps pos (h1 :: h_cases) in - let e' = A.ematch e1' cases' en pos in - (e', hoists) + let hoists = disjoint_union_maps pos (h1 :: h_cases) in + let e' = A.ematch e1' cases' en pos in + e', hoists | EArray es -> - let es', hoists = - es |> List.map (translate_and_hoist ctx) |> List.split - in + let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in - (A.earray es' pos, disjoint_union_maps pos hoists) - | EOp op -> (Bindlib.box (A.EOp op, pos), A.VarMap.empty) + A.earray es' pos, disjoint_union_maps pos hoists + | EOp op -> Bindlib.box (A.EOp op, pos), A.VarMap.empty and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) : A.expr Pos.marked Bindlib.box = @@ -320,48 +315,48 @@ and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) : (* Here we have to handle only the cases appearing in hoists, as defined the [translate_and_hoist] function. *) | D.EVar v -> - (find ~info:"should never happend" (Pos.unmark v) ctx).expr + (find ~info:"should never happend" (Pos.unmark v) ctx).expr | D.EDefault (excep, just, cons) -> - let excep' = List.map (translate_expr ctx) excep in - let just' = translate_expr ctx just in - let cons' = translate_expr ctx cons in - (* calls handle_option. *) - A.make_app - (A.make_var (A.handle_default_opt, pos_hoist)) - [ - Bindlib.box_apply - (fun excep' -> (A.EArray excep', pos_hoist)) - (Bindlib.box_list excep'); - just'; - cons'; - ] - pos_hoist + let excep' = List.map (translate_expr ctx) excep in + let just' = translate_expr ctx just in + let cons' = translate_expr ctx cons in + (* calls handle_option. *) + A.make_app + (A.make_var (A.handle_default_opt, pos_hoist)) + [ + Bindlib.box_apply + (fun excep' -> A.EArray excep', pos_hoist) + (Bindlib.box_list excep'); + just'; + cons'; + ] + pos_hoist | D.ELit D.LEmptyError -> A.make_none pos_hoist | D.EAssert arg -> - let arg' = translate_expr ctx arg in + let arg' = translate_expr ctx arg in - (* [ match arg with | None -> raise NoValueProvided | Some v -> - assert {{ v }} ] *) - let silent_var = A.Var.make ("_", pos_hoist) in - let x = A.Var.make ("assertion_argument", pos_hoist) in + (* [ match arg with | None -> raise NoValueProvided | Some v -> assert + {{ v }} ] *) + let silent_var = A.Var.make ("_", pos_hoist) in + let x = A.Var.make ("assertion_argument", pos_hoist) in - A.make_matchopt_with_abs_arms arg' - (A.make_abs [| silent_var |] - (Bindlib.box (A.ERaise A.NoValueProvided, pos_hoist)) - pos_hoist - [ (D.TAny, pos_hoist) ] - pos_hoist) - (A.make_abs [| x |] - (Bindlib.box_apply - (fun arg -> (A.EAssert arg, pos_hoist)) - (A.make_var (x, pos_hoist))) - pos_hoist - [ (D.TAny, pos_hoist) ] - pos_hoist) + A.make_matchopt_with_abs_arms arg' + (A.make_abs [| silent_var |] + (Bindlib.box (A.ERaise A.NoValueProvided, pos_hoist)) + pos_hoist + [D.TAny, pos_hoist] + pos_hoist) + (A.make_abs [| x |] + (Bindlib.box_apply + (fun arg -> A.EAssert arg, pos_hoist) + (A.make_var (x, pos_hoist))) + pos_hoist + [D.TAny, pos_hoist] + pos_hoist) | _ -> - Errors.raise_spanned_error pos_hoist - "Internal Error: An term was found in a position where it should \ - not be" + Errors.raise_spanned_error pos_hoist + "Internal Error: An term was found in a position where it should \ + not be" in (* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end @@ -375,9 +370,9 @@ let rec translate_scope_let (ctx : ctx) (lets : D.expr D.scope_body_expr) : A.expr D.scope_body_expr Bindlib.box = match lets with | Result e -> - Bindlib.box_apply - (fun e -> D.Result e) - (translate_expr ~append_esome:false ctx e) + Bindlib.box_apply + (fun e -> D.Result e) + (translate_expr ~append_esome:false ctx e) | ScopeLet { scope_let_kind = SubScopeVarDefinition; @@ -386,31 +381,29 @@ let rec translate_scope_let (ctx : ctx) (lets : D.expr D.scope_body_expr) : scope_let_next = next; scope_let_pos = pos; } -> - (* special case : the subscope variable is thunked (context i/o). We - remove this thunking. *) - let _, expr = Bindlib.unmbind binder in + (* special case : the subscope variable is thunked (context i/o). We remove + this thunking. *) + let _, expr = Bindlib.unmbind binder in - let var_is_pure = true in - let var, next = Bindlib.unbind next in - (* Cli.debug_print @@ Format.asprintf "unbinding %a" - Dcalc.Print.format_var var; *) - let ctx' = add_var pos var var_is_pure ctx in - let new_var = - (find ~info:"variable that was just created" var ctx').var - in - let new_next = translate_scope_let ctx' next in - Bindlib.box_apply2 - (fun new_expr new_next -> - D.ScopeLet - { - scope_let_kind = SubScopeVarDefinition; - scope_let_typ = translate_typ typ; - scope_let_expr = new_expr; - scope_let_next = new_next; - scope_let_pos = pos; - }) - (translate_expr ctx ~append_esome:false expr) - (Bindlib.bind_var new_var new_next) + let var_is_pure = true in + let var, next = Bindlib.unbind next in + (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var + var; *) + let ctx' = add_var pos var var_is_pure ctx in + let new_var = (find ~info:"variable that was just created" var ctx').var in + let new_next = translate_scope_let ctx' next in + Bindlib.box_apply2 + (fun new_expr new_next -> + D.ScopeLet + { + scope_let_kind = SubScopeVarDefinition; + scope_let_typ = translate_typ typ; + scope_let_expr = new_expr; + scope_let_next = new_next; + scope_let_pos = pos; + }) + (translate_expr ctx ~append_esome:false expr) + (Bindlib.bind_var new_var new_next) | ScopeLet { scope_let_kind = SubScopeVarDefinition; @@ -419,27 +412,25 @@ let rec translate_scope_let (ctx : ctx) (lets : D.expr D.scope_body_expr) : scope_let_next = next; scope_let_pos = pos; } -> - (* special case: regular input to the subscope *) - let var_is_pure = true in - let var, next = Bindlib.unbind next in - (* Cli.debug_print @@ Format.asprintf "unbinding %a" - Dcalc.Print.format_var var; *) - let ctx' = add_var pos var var_is_pure ctx in - let new_var = - (find ~info:"variable that was just created" var ctx').var - in - Bindlib.box_apply2 - (fun new_expr new_next -> - D.ScopeLet - { - scope_let_kind = SubScopeVarDefinition; - scope_let_typ = translate_typ typ; - scope_let_expr = new_expr; - scope_let_next = new_next; - scope_let_pos = pos; - }) - (translate_expr ctx ~append_esome:false expr) - (Bindlib.bind_var new_var (translate_scope_let ctx' next)) + (* special case: regular input to the subscope *) + let var_is_pure = true in + let var, next = Bindlib.unbind next in + (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var + var; *) + let ctx' = add_var pos var var_is_pure ctx in + let new_var = (find ~info:"variable that was just created" var ctx').var in + Bindlib.box_apply2 + (fun new_expr new_next -> + D.ScopeLet + { + scope_let_kind = SubScopeVarDefinition; + scope_let_typ = translate_typ typ; + scope_let_expr = new_expr; + scope_let_next = new_next; + scope_let_pos = pos; + }) + (translate_expr ctx ~append_esome:false expr) + (Bindlib.bind_var new_var (translate_scope_let ctx' next)) | ScopeLet { scope_let_kind = SubScopeVarDefinition; @@ -447,12 +438,12 @@ let rec translate_scope_let (ctx : ctx) (lets : D.expr D.scope_body_expr) : scope_let_expr = expr; _; } -> - Errors.raise_spanned_error pos - "Internal Error: found an SubScopeVarDefinition that does not satisfy \ - the invariants when translating Dcalc to Lcalc without exceptions: \ - @[%a@]" - (Dcalc.Print.format_expr ctx.decl_ctx) - expr + Errors.raise_spanned_error pos + "Internal Error: found an SubScopeVarDefinition that does not satisfy \ + the invariants when translating Dcalc to Lcalc without exceptions: \ + @[%a@]" + (Dcalc.Print.format_expr ctx.decl_ctx) + expr | ScopeLet { scope_let_kind = kind; @@ -461,82 +452,81 @@ let rec translate_scope_let (ctx : ctx) (lets : D.expr D.scope_body_expr) : scope_let_next = next; scope_let_pos = pos; } -> - let var_is_pure = - match kind with - | DestructuringInputStruct -> ( - (* Here, we have to distinguish between context and input variables. - We can do so by looking at the typ of the destructuring: if it's - thunked, then the variable is context. If it's not thunked, it's - a regular input. *) - match Pos.unmark typ with - | D.TArrow ((D.TLit D.TUnit, _), _) -> false - | _ -> true) - | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope - | DestructuringSubScopeResults | Assertion -> - true - in - let var, next = Bindlib.unbind next in - (* Cli.debug_print @@ Format.asprintf "unbinding %a" - Dcalc.Print.format_var var; *) - let ctx' = add_var pos var var_is_pure ctx in - let new_var = - (find ~info:"variable that was just created" var ctx').var - in - Bindlib.box_apply2 - (fun new_expr new_next -> - D.ScopeLet - { - scope_let_kind = kind; - scope_let_typ = translate_typ typ; - scope_let_expr = new_expr; - scope_let_next = new_next; - scope_let_pos = pos; - }) - (translate_expr ctx ~append_esome:false expr) - (Bindlib.bind_var new_var (translate_scope_let ctx' next)) + let var_is_pure = + match kind with + | DestructuringInputStruct -> ( + (* Here, we have to distinguish between context and input variables. We + can do so by looking at the typ of the destructuring: if it's + thunked, then the variable is context. If it's not thunked, it's a + regular input. *) + match Pos.unmark typ with + | D.TArrow ((D.TLit D.TUnit, _), _) -> false + | _ -> true) + | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope + | DestructuringSubScopeResults | Assertion -> + true + in + let var, next = Bindlib.unbind next in + (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var + var; *) + let ctx' = add_var pos var var_is_pure ctx in + let new_var = (find ~info:"variable that was just created" var ctx').var in + Bindlib.box_apply2 + (fun new_expr new_next -> + D.ScopeLet + { + scope_let_kind = kind; + scope_let_typ = translate_typ typ; + scope_let_expr = new_expr; + scope_let_next = new_next; + scope_let_pos = pos; + }) + (translate_expr ctx ~append_esome:false expr) + (Bindlib.bind_var new_var (translate_scope_let ctx' next)) let translate_scope_body - (scope_pos : Pos.t) (ctx : ctx) (body : D.expr D.scope_body) : - A.expr D.scope_body Bindlib.box = + (scope_pos : Pos.t) + (ctx : ctx) + (body : D.expr D.scope_body) : A.expr D.scope_body Bindlib.box = match body with | { scope_body_expr = result; scope_body_input_struct = input_struct; scope_body_output_struct = output_struct; } -> - let v, lets = Bindlib.unbind result in - let ctx' = add_var scope_pos v true ctx in - let v' = (find ~info:"variable that was just created" v ctx').var in - Bindlib.box_apply - (fun new_expr -> - { - D.scope_body_expr = new_expr; - scope_body_input_struct = input_struct; - scope_body_output_struct = output_struct; - }) - (Bindlib.bind_var v' (translate_scope_let ctx' lets)) + let v, lets = Bindlib.unbind result in + let ctx' = add_var scope_pos v true ctx in + let v' = (find ~info:"variable that was just created" v ctx').var in + Bindlib.box_apply + (fun new_expr -> + { + D.scope_body_expr = new_expr; + scope_body_input_struct = input_struct; + scope_body_output_struct = output_struct; + }) + (Bindlib.bind_var v' (translate_scope_let ctx' lets)) let rec translate_scopes (ctx : ctx) (scopes : D.expr D.scopes) : A.expr D.scopes Bindlib.box = match scopes with | Nil -> Bindlib.box D.Nil | ScopeDef { scope_name; scope_body; scope_next } -> - let scope_var, next = Bindlib.unbind scope_next in - let new_ctx = add_var Pos.no_pos scope_var true ctx in - let new_scope_name = - (find ~info:"variable that was just created" scope_var new_ctx).var - in + let scope_var, next = Bindlib.unbind scope_next in + let new_ctx = add_var Pos.no_pos scope_var true ctx in + let new_scope_name = + (find ~info:"variable that was just created" scope_var new_ctx).var + in - let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in + let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in - let new_body = translate_scope_body scope_pos ctx scope_body in - let tail = translate_scopes new_ctx next in + let new_body = translate_scope_body scope_pos ctx scope_body in + let tail = translate_scopes new_ctx next in - Bindlib.box_apply2 - (fun body tail -> - D.ScopeDef { scope_name; scope_body = body; scope_next = tail }) - new_body - (Bindlib.bind_var new_scope_name tail) + Bindlib.box_apply2 + (fun body tail -> + D.ScopeDef { scope_name; scope_body = body; scope_next = tail }) + new_body + (Bindlib.bind_var new_scope_name tail) let translate_program (prgm : D.program) : A.program = let inputs_structs = @@ -567,7 +557,7 @@ let translate_program (prgm : D.program) : A.program = @@ Format.asprintf "Output type: %a" (Dcalc.Print.format_typ decl_ctx) (translate_typ tau); *) - (n, translate_typ tau)) + n, translate_typ tau) else l); } in diff --git a/compiler/lcalc/optimizations.ml b/compiler/lcalc/optimizations.ml index a7a7545f..3c319848 100644 --- a/compiler/lcalc/optimizations.ml +++ b/compiler/lcalc/optimizations.ml @@ -28,42 +28,42 @@ let visitor_map let default_mark e' = Pos.same_pos_as e' e in match Pos.unmark e with | EVar (v, _pos) -> - let+ v = Bindlib.box_var v in - default_mark @@ v + let+ v = Bindlib.box_var v in + default_mark @@ v | ETuple (args, n) -> - let+ args = args |> List.map (t ctx) |> Bindlib.box_list in - default_mark @@ ETuple (args, n) + let+ args = args |> List.map (t ctx) |> Bindlib.box_list in + default_mark @@ ETuple (args, n) | ETupleAccess (e1, i, n, ts) -> - let+ e1 = t ctx e1 in - default_mark @@ ETupleAccess (e1, i, n, ts) + let+ e1 = t ctx e1 in + default_mark @@ ETupleAccess (e1, i, n, ts) | EInj (e1, i, n, ts) -> - let+ e1 = t ctx e1 in - default_mark @@ EInj (e1, i, n, ts) + let+ e1 = t ctx e1 in + default_mark @@ EInj (e1, i, n, ts) | EMatch (arg, cases, n) -> - let+ arg = t ctx arg - and+ cases = cases |> List.map (t ctx) |> Bindlib.box_list in - default_mark @@ EMatch (arg, cases, n) + let+ arg = t ctx arg + and+ cases = cases |> List.map (t ctx) |> Bindlib.box_list in + default_mark @@ EMatch (arg, cases, n) | EArray args -> - let+ args = args |> List.map (t ctx) |> Bindlib.box_list in - default_mark @@ EArray args + let+ args = args |> List.map (t ctx) |> Bindlib.box_list in + default_mark @@ EArray args | EAbs ((binder, pos_binder), ts) -> - let vars, body = Bindlib.unmbind binder in - let body = t ctx body in - let+ binder = Bindlib.bind_mvar vars body in - default_mark @@ EAbs ((binder, pos_binder), ts) + let vars, body = Bindlib.unmbind binder in + let body = t ctx body in + let+ binder = Bindlib.bind_mvar vars body in + default_mark @@ EAbs ((binder, pos_binder), ts) | EApp (e1, args) -> - let+ e1 = t ctx e1 - and+ args = args |> List.map (t ctx) |> Bindlib.box_list in - default_mark @@ EApp (e1, args) + let+ e1 = t ctx e1 + and+ args = args |> List.map (t ctx) |> Bindlib.box_list in + default_mark @@ EApp (e1, args) | EAssert e1 -> - let+ e1 = t ctx e1 in - default_mark @@ EAssert e1 + let+ e1 = t ctx e1 in + default_mark @@ EAssert e1 | EIfThenElse (e1, e2, e3) -> - let+ e1 = t ctx e1 and+ e2 = t ctx e2 and+ e3 = t ctx e3 in - default_mark @@ EIfThenElse (e1, e2, e3) + let+ e1 = t ctx e1 and+ e2 = t ctx e2 and+ e3 = t ctx e3 in + default_mark @@ EIfThenElse (e1, e2, e3) | ECatch (e1, exn, e2) -> - let+ e1 = t ctx e1 and+ e2 = t ctx e2 in - default_mark @@ ECatch (e1, exn, e2) + let+ e1 = t ctx e1 and+ e2 = t ctx e2 in + default_mark @@ ECatch (e1, exn, e2) | ERaise _ | ELit _ | EOp _ -> Bindlib.box e let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box @@ -72,18 +72,18 @@ let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box match Pos.unmark e with | EMatch ((EInj (e1, i, n', _ts), _), cases, n) when Dcalc.Ast.EnumName.compare n n' = 0 -> - let+ e1 = visitor_map iota_expr () e1 - and+ case = visitor_map iota_expr () (List.nth cases i) in - default_mark @@ EApp (case, [ e1 ]) + let+ e1 = visitor_map iota_expr () e1 + and+ case = visitor_map iota_expr () (List.nth cases i) in + default_mark @@ EApp (case, [e1]) | EMatch (e', cases, n) when cases |> List.mapi (fun i (case, _pos) -> match case with | EInj (_ei, i', n', _ts') -> - i = i' && (* n = n' *) Dcalc.Ast.EnumName.compare n n' = 0 + i = i' && (* n = n' *) Dcalc.Ast.EnumName.compare n n' = 0 | _ -> false) |> List.for_all Fun.id -> - visitor_map iota_expr () e' + visitor_map iota_expr () e' | _ -> visitor_map iota_expr () e let rec beta_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box @@ -91,13 +91,13 @@ let rec beta_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box let default_mark e' = Pos.same_pos_as e' e in match Pos.unmark e with | EApp (e1, args) -> ( - let+ e1 = beta_expr () e1 - and+ args = List.map (beta_expr ()) args |> Bindlib.box_list in - match Pos.unmark e1 with - | EAbs ((binder, _pos_binder), _ts) -> - let (_ : (_, _) Bindlib.mbinder) = binder in - Bindlib.msubst binder (List.map fst args |> Array.of_list) - | _ -> default_mark @@ EApp (e1, args)) + let+ e1 = beta_expr () e1 + and+ args = List.map (beta_expr ()) args |> Bindlib.box_list in + match Pos.unmark e1 with + | EAbs ((binder, _pos_binder), _ts) -> + let (_ : (_, _) Bindlib.mbinder) = binder in + Bindlib.msubst binder (List.map fst args |> Array.of_list) + | _ -> default_mark @@ EApp (e1, args)) | _ -> visitor_map beta_expr () e let iota_optimizations (p : program) : program = @@ -118,26 +118,26 @@ let rec peephole_expr (_ : unit) (e : expr Pos.marked) : match Pos.unmark e with | EIfThenElse (e1, e2, e3) -> ( - let+ e1 = peephole_expr () e1 - and+ e2 = peephole_expr () e2 - and+ e3 = peephole_expr () e3 in - match Pos.unmark e1 with - | ELit (LBool true) - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) -> - e2 - | ELit (LBool false) - | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) -> - e3 - | _ -> default_mark @@ EIfThenElse (e1, e2, e3)) + let+ e1 = peephole_expr () e1 + and+ e2 = peephole_expr () e2 + and+ e3 = peephole_expr () e3 in + match Pos.unmark e1 with + | ELit (LBool true) + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool true), _)]) -> + e2 + | ELit (LBool false) + | EApp ((EOp (Unop (Log _)), _), [(ELit (LBool false), _)]) -> + e3 + | _ -> default_mark @@ EIfThenElse (e1, e2, e3)) | ECatch (e1, except, e2) -> ( - let+ e1 = peephole_expr () e1 and+ e2 = peephole_expr () e2 in - match (Pos.unmark e1, Pos.unmark e2) with - | ERaise except', ERaise except'' - when except' = except && except = except'' -> - default_mark @@ ERaise except - | ERaise except', _ when except' = except -> e2 - | _, ERaise except' when except' = except -> e1 - | _ -> default_mark @@ ECatch (e1, except, e2)) + let+ e1 = peephole_expr () e1 and+ e2 = peephole_expr () e2 in + match Pos.unmark e1, Pos.unmark e2 with + | ERaise except', ERaise except'' when except' = except && except = except'' + -> + default_mark @@ ERaise except + | ERaise except', _ when except' = except -> e2 + | _, ERaise except' when except' = except -> e1 + | _ -> default_mark @@ ECatch (e1, except, e2)) | _ -> visitor_map peephole_expr () e let peephole_optimizations (p : program) : program = diff --git a/compiler/lcalc/print.ml b/compiler/lcalc/print.ml index 20572a00..a07e5ca7 100644 --- a/compiler/lcalc/print.ml +++ b/compiler/lcalc/print.ml @@ -36,22 +36,22 @@ let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit = | LInt i -> Dcalc.Print.format_lit_style fmt (Runtime.integer_to_string i) | LUnit -> Dcalc.Print.format_lit_style fmt "()" | LRat i -> - Dcalc.Print.format_lit_style fmt - (Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) + Dcalc.Print.format_lit_style fmt + (Runtime.decimal_to_string ~max_prec_digits:!Utils.Cli.max_prec_digits i) | LMoney e -> ( - match !Utils.Cli.locale_lang with - | En -> - Dcalc.Print.format_lit_style fmt - (Format.asprintf "$%s" (Runtime.money_to_string e)) - | Fr -> - Dcalc.Print.format_lit_style fmt - (Format.asprintf "%s €" (Runtime.money_to_string e)) - | Pl -> - Dcalc.Print.format_lit_style fmt - (Format.asprintf "%s PLN" (Runtime.money_to_string e))) + match !Utils.Cli.locale_lang with + | En -> + Dcalc.Print.format_lit_style fmt + (Format.asprintf "$%s" (Runtime.money_to_string e)) + | Fr -> + Dcalc.Print.format_lit_style fmt + (Format.asprintf "%s €" (Runtime.money_to_string e)) + | Pl -> + Dcalc.Print.format_lit_style fmt + (Format.asprintf "%s PLN" (Runtime.money_to_string e))) | LDate d -> Dcalc.Print.format_lit_style fmt (Runtime.date_to_string d) | LDuration d -> - Dcalc.Print.format_lit_style fmt (Runtime.duration_to_string d) + Dcalc.Print.format_lit_style fmt (Runtime.duration_to_string d) let format_exception (fmt : Format.formatter) (exn : except) : unit = Dcalc.Print.format_operator fmt @@ -62,10 +62,10 @@ let format_exception (fmt : Format.formatter) (exn : except) : unit = | NoValueProvided -> "NoValueProvided") let format_keyword (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.red ]) s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.red]) s let format_punctuation (fmt : Format.formatter) (s : string) : unit = - Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ ANSITerminal.cyan ]) s + Format.fprintf fmt "%a" (Utils.Cli.format_with_style [ANSITerminal.cyan]) s let needs_parens (e : expr Pos.marked) : bool = match Pos.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false @@ -88,120 +88,118 @@ let rec format_expr match Pos.unmark e with | EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v) | ETuple (es, None) -> - Format.fprintf fmt "@[%a%a%a@]" format_punctuation "(" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) - es format_punctuation ")" + Format.fprintf fmt "@[%a%a%a@]" format_punctuation "(" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) + es format_punctuation ")" | ETuple (es, Some s) -> - Format.fprintf fmt "@[%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s - format_punctuation "{" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt (e, struct_field) -> - Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\"" - Dcalc.Ast.StructFieldName.format_t struct_field - format_punctuation "\"" format_punctuation ":" format_expr e)) - (List.combine es - (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) - format_punctuation "}" + Format.fprintf fmt "@[%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s + format_punctuation "{" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt (e, struct_field) -> + Format.fprintf fmt "%a%a%a%a %a" format_punctuation "\"" + Dcalc.Ast.StructFieldName.format_t struct_field format_punctuation + "\"" format_punctuation ":" format_expr e)) + (List.combine es + (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) + format_punctuation "}" | EArray es -> - Format.fprintf fmt "@[%a%a%a@]" format_punctuation "[" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) - es format_punctuation "]" + Format.fprintf fmt "@[%a%a%a@]" format_punctuation "[" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) + es format_punctuation "]" | ETupleAccess (e1, n, s, _ts) -> ( - match s with - | None -> - Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n - | Some s -> - Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "." - format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t - (fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) - format_punctuation "\"") + match s with + | None -> + Format.fprintf fmt "%a%a%d" format_expr e1 format_punctuation "." n + | Some s -> + Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "." + format_punctuation "\"" Dcalc.Ast.StructFieldName.format_t + (fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) + format_punctuation "\"") | EInj (e, n, en, _ts) -> - Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_enum_constructor - (fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n)) - format_expr e + Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_enum_constructor + (fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n)) + format_expr e | EMatch (e, es, e_name) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@]" format_keyword "match" - format_expr e format_keyword "with" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt (e, c) -> - Format.fprintf fmt "@[%a %a%a@ %a@]" format_punctuation "|" - Dcalc.Print.format_enum_constructor c format_punctuation ":" - format_expr e)) - (List.combine es - (List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums))) + Format.fprintf fmt "@[%a@ %a@ %a@ %a@]" format_keyword "match" + format_expr e format_keyword "with" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt (e, c) -> + Format.fprintf fmt "@[%a %a%a@ %a@]" format_punctuation "|" + Dcalc.Print.format_enum_constructor c format_punctuation ":" + format_expr e)) + (List.combine es + (List.map fst (Dcalc.Ast.EnumMap.find e_name ctx.ctx_enums))) | ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e) | EApp ((EAbs ((binder, _), taus), _), args) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - let xs_tau_arg = - List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args - in - Format.fprintf fmt "%a%a" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "") - (fun fmt (x, tau, arg) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" - format_keyword "let" format_var x format_punctuation ":" - (Dcalc.Print.format_typ ctx) - tau format_punctuation "=" format_expr arg format_keyword "in")) - xs_tau_arg format_expr body + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in + Format.fprintf fmt "%a%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "") + (fun fmt (x, tau, arg) -> + Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@ %a@]@\n" + format_keyword "let" format_var x format_punctuation ":" + (Dcalc.Print.format_typ ctx) + tau format_punctuation "=" format_expr arg format_keyword "in")) + xs_tau_arg format_expr body | EAbs ((binder, _), taus) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - Format.fprintf fmt "@[%a %a %a@ %a@]" format_punctuation "λ" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - (fun fmt (x, tau) -> - Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var - x format_punctuation ":" - (Dcalc.Print.format_typ ctx) - tau format_punctuation ")")) - xs_tau format_punctuation "→" format_expr body + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + Format.fprintf fmt "@[%a %a %a@ %a@]" format_punctuation "λ" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + (fun fmt (x, tau) -> + Format.fprintf fmt "%a%a%a %a%a" format_punctuation "(" format_var x + format_punctuation ":" + (Dcalc.Print.format_typ ctx) + tau format_punctuation ")")) + xs_tau format_punctuation "→" format_expr body | EApp - ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), - [ arg1; arg2 ] ) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" Dcalc.Print.format_binop - (op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2 - | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 - Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2 - | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> - Format.fprintf fmt "%a" format_with_parens arg1 - | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop - (op, Pos.no_pos) format_with_parens arg1 + ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + -> + Format.fprintf fmt "@[%a@ %a@ %a@]" Dcalc.Print.format_binop + (op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2 + | EApp ((EOp (Binop op), _), [arg1; arg2]) -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 + Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2 + | EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> + Format.fprintf fmt "%a" format_with_parens arg1 + | EApp ((EOp (Unop op), _), [arg1]) -> + Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop + (op, Pos.no_pos) format_with_parens arg1 | EApp (f, args) -> - Format.fprintf fmt "@[%a@ %a@]" format_expr f - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - format_with_parens) - args + Format.fprintf fmt "@[%a@ %a@]" format_expr f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + format_with_parens) + args | EIfThenElse (e1, e2, e3) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if" - format_expr e1 format_keyword "then" format_expr e2 format_keyword - "else" format_expr e3 + Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@]" format_keyword "if" + format_expr e1 format_keyword "then" format_expr e2 format_keyword "else" + format_expr e3 | EOp (Ternop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) | EOp (Binop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) | EOp (Unop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) | ECatch (e1, exn, e2) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a ->@ %a@]" format_keyword "try" - format_with_parens e1 format_keyword "with" format_exception exn - format_with_parens e2 + Format.fprintf fmt "@[%a@ %a@ %a@ %a ->@ %a@]" format_keyword "try" + format_with_parens e1 format_keyword "with" format_exception exn + format_with_parens e2 | ERaise exn -> - Format.fprintf fmt "@[%a@ %a@]" format_keyword "raise" - format_exception exn + Format.fprintf fmt "@[%a@ %a@]" format_keyword "raise" + format_exception exn | EAssert e' -> - Format.fprintf fmt "@[%a@ %a%a%a@]" format_keyword "assert" - format_punctuation "(" format_expr e' format_punctuation ")" + Format.fprintf fmt "@[%a@ %a%a%a@]" format_keyword "assert" + format_punctuation "(" format_expr e' format_punctuation ")" let format_scope ?(debug : bool = false) diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 92a5e650..87eb1f55 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -40,25 +40,24 @@ let find_enum (en : D.EnumName.t) (ctx : D.decl_ctx) : let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit = match Pos.unmark l with | LBool b -> - Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LBool b) l) + Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LBool b) l) | LInt i -> - Format.fprintf fmt "integer_of_string@ \"%s\"" - (Runtime.integer_to_string i) + Format.fprintf fmt "integer_of_string@ \"%s\"" (Runtime.integer_to_string i) | LUnit -> Dcalc.Print.format_lit fmt (Pos.same_pos_as Dcalc.Ast.LUnit l) | LRat i -> - Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit - (Pos.same_pos_as (Dcalc.Ast.LRat i) l) + Format.fprintf fmt "decimal_of_string \"%a\"" Dcalc.Print.format_lit + (Pos.same_pos_as (Dcalc.Ast.LRat i) l) | LMoney e -> - Format.fprintf fmt "money_of_cents_string@ \"%s\"" - (Runtime.integer_to_string (Runtime.money_to_cents e)) + Format.fprintf fmt "money_of_cents_string@ \"%s\"" + (Runtime.integer_to_string (Runtime.money_to_cents e)) | LDate d -> - Format.fprintf fmt "date_of_numbers %d %d %d" - (Runtime.integer_to_int (Runtime.year_of_date d)) - (Runtime.integer_to_int (Runtime.month_number_of_date d)) - (Runtime.integer_to_int (Runtime.day_of_month_of_date d)) + Format.fprintf fmt "date_of_numbers %d %d %d" + (Runtime.integer_to_int (Runtime.year_of_date d)) + (Runtime.integer_to_int (Runtime.month_number_of_date d)) + (Runtime.integer_to_int (Runtime.day_of_month_of_date d)) | LDuration d -> - let years, months, days = Runtime.duration_to_years_months_days d in - Format.fprintf fmt "duration_of_numbers %d %d %d" years months days + let years, months, days = Runtime.duration_to_years_months_days d in + Format.fprintf fmt "duration_of_numbers %d %d %d" years months days let format_op_kind (fmt : Format.formatter) (k : Dcalc.Ast.op_kind) = Format.fprintf fmt "%s" @@ -114,9 +113,9 @@ let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit | Minus k -> Format.fprintf fmt "~-%a" format_op_kind k | Not -> Format.fprintf fmt "%s" "not" | Log (_entry, _infos) -> - Errors.raise_spanned_error (Pos.get_position op) - "Internal error: a log operator has not been caught by the expression \ - match" + Errors.raise_spanned_error (Pos.get_position op) + "Internal error: a log operator has not been caught by the expression \ + match" | Length -> Format.fprintf fmt "%s" "array_length" | IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer" | GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date" @@ -138,7 +137,7 @@ let avoid_keywords (s : string) : string = | "nonrec" | "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct" | "then" | "to" | "true" | "try" | "type" | "val" | "virtual" | "when" | "while" | "with" -> - true + true | _ -> false then s ^ "_" else s @@ -151,7 +150,8 @@ let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v)))) let format_struct_field_name - (fmt : Format.formatter) (v : Dcalc.Ast.StructFieldName.t) : unit = + (fmt : Format.formatter) + (v : Dcalc.Ast.StructFieldName.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) @@ -164,7 +164,8 @@ let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) let format_enum_cons_name - (fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructor.t) : unit = + (fmt : Format.formatter) + (v : Dcalc.Ast.EnumConstructor.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) @@ -180,7 +181,7 @@ let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Pos.marked) : | D.TLit D.TDate -> Format.fprintf fmt "embed_date" | D.TLit D.TDuration -> Format.fprintf fmt "embed_duration" | D.TTuple (_, Some s_name) -> - Format.fprintf fmt "embed_%a" format_struct_name s_name + Format.fprintf fmt "embed_%a" format_struct_name s_name | D.TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name | D.TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty | _ -> Format.fprintf fmt "unembeddable" @@ -192,30 +193,31 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : unit = let format_typ = format_typ in let format_typ_with_parens - (fmt : Format.formatter) (t : Dcalc.Ast.typ Pos.marked) = + (fmt : Format.formatter) + (t : Dcalc.Ast.typ Pos.marked) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t else Format.fprintf fmt "%a" format_typ t in match Pos.unmark typ with | TLit l -> Format.fprintf fmt "%a" Dcalc.Print.format_tlit l | TTuple (ts, None) -> - Format.fprintf fmt "@[(%a)@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") - format_typ_with_parens) - ts + Format.fprintf fmt "@[(%a)@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") + format_typ_with_parens) + ts | TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s - | TEnum ([ t ], e) when D.EnumName.compare e Ast.option_enum = 0 -> - Format.fprintf fmt "@[(%a)@] %a" format_typ_with_parens t - format_enum_name e + | TEnum ([t], e) when D.EnumName.compare e Ast.option_enum = 0 -> + Format.fprintf fmt "@[(%a)@] %a" format_typ_with_parens t + format_enum_name e | TEnum (_, e) when D.EnumName.compare e Ast.option_enum = 0 -> - Errors.raise_spanned_error (Pos.get_position typ) - "Internal Error: found an typing parameter for an eoption type of the \ - wrong lenght." + Errors.raise_spanned_error (Pos.get_position typ) + "Internal Error: found an typing parameter for an eoption type of the \ + wrong lenght." | TEnum (_ts, e) -> Format.fprintf fmt "%a" format_enum_name e | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a ->@ %a@]" format_typ_with_parens t1 - format_typ_with_parens t2 + Format.fprintf fmt "@[%a ->@ %a@]" format_typ_with_parens t1 + format_typ_with_parens t2 | TArray t1 -> Format.fprintf fmt "@[%a@ array@]" format_typ_with_parens t1 | TAny -> Format.fprintf fmt "_" @@ -228,7 +230,7 @@ let format_var (fmt : Format.formatter) (v : Var.t) : unit = in let lowercase_name = avoid_keywords (to_ascii lowercase_name) in if - List.mem lowercase_name [ "handle_default"; "handle_default_opt" ] + List.mem lowercase_name ["handle_default"; "handle_default_opt"] || Dcalc.Print.begins_with_uppercase (Bindlib.name_of v) then Format.fprintf fmt "%s" lowercase_name else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name @@ -239,7 +241,7 @@ let needs_parens (e : expr Pos.marked) : bool = | EApp ((EAbs (_, _), _), _) | ELit (LBool _ | LUnit) | EVar _ | ETuple _ | EOp _ -> - false + false | _ -> true let format_exception (fmt : Format.formatter) (exc : except Pos.marked) : unit = @@ -248,17 +250,18 @@ let format_exception (fmt : Format.formatter) (exc : except Pos.marked) : unit = | EmptyError -> Format.fprintf fmt "EmptyError" | Crash -> Format.fprintf fmt "Crash" | NoValueProvided -> - let pos = Pos.get_position exc in - Format.fprintf fmt - "(NoValueProvided@ @[{filename = \"%s\";@ start_line=%d;@ \ - start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@])" - (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) - (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list - (Pos.get_law_info pos) + let pos = Pos.get_position exc in + Format.fprintf fmt + "(NoValueProvided@ @[{filename = \"%s\";@ start_line=%d;@ \ + start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@])" + (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) + (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list + (Pos.get_law_info pos) let rec format_expr - (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : - unit = + (ctx : Dcalc.Ast.decl_ctx) + (fmt : Format.formatter) + (e : expr Pos.marked) : unit = let format_expr = format_expr ctx in let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = if needs_parens e then Format.fprintf fmt "(%a)" format_expr e @@ -267,149 +270,140 @@ let rec format_expr match Pos.unmark e with | EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v) | ETuple (es, None) -> - Format.fprintf fmt "@[(%a)@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) - es + Format.fprintf fmt "@[(%a)@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) + es | ETuple (es, Some s) -> - if List.length es = 0 then Format.fprintf fmt "()" - else - Format.fprintf fmt "{@[%a@]}" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt (e, struct_field) -> - Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name - struct_field format_with_parens e)) - (List.combine es (List.map fst (find_struct s ctx))) - | EArray es -> - Format.fprintf fmt "@[[|%a|]@]" + if List.length es = 0 then Format.fprintf fmt "()" + else + Format.fprintf fmt "{@[%a@]}" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) - es + (fun fmt (e, struct_field) -> + Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name + struct_field format_with_parens e)) + (List.combine es (List.map fst (find_struct s ctx))) + | EArray es -> + Format.fprintf fmt "@[[|%a|]@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) + es | ETupleAccess (e1, n, s, ts) -> ( - match s with - | None -> - Format.fprintf fmt "let@ %a@ = %a@ in@ x" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt i -> - Format.fprintf fmt "%s" (if i = n then "x" else "_"))) - (List.mapi (fun i _ -> i) ts) - format_with_parens e1 - | Some s -> - Format.fprintf fmt "%a.%a" format_with_parens e1 - format_struct_field_name - (fst (List.nth (find_struct s ctx) n))) - | EInj (e, n, en, _ts) -> - Format.fprintf fmt "@[%a@ %a@]" format_enum_cons_name - (fst (List.nth (find_enum en ctx) n)) - format_with_parens e - | EMatch (e, es, e_name) -> - Format.fprintf fmt "@[match@ %a@]@ with@\n%a" format_with_parens e + match s with + | None -> + Format.fprintf fmt "let@ %a@ = %a@ in@ x" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n| ") - (fun fmt (e, c) -> - Format.fprintf fmt "%a %a" format_enum_cons_name c - (fun fmt e -> - match Pos.unmark e with - | EAbs ((binder, _), _) -> - let xs, body = Bindlib.unmbind binder in - Format.fprintf fmt "%a ->@[@ %a@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@,") - (fun fmt x -> Format.fprintf fmt "%a" format_var x)) - (Array.to_list xs) format_with_parens body - | _ -> assert false - (* should not happen *)) - e)) - (List.combine es (List.map fst (find_enum e_name ctx))) + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt i -> Format.fprintf fmt "%s" (if i = n then "x" else "_"))) + (List.mapi (fun i _ -> i) ts) + format_with_parens e1 + | Some s -> + Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name + (fst (List.nth (find_struct s ctx) n))) + | EInj (e, n, en, _ts) -> + Format.fprintf fmt "@[%a@ %a@]" format_enum_cons_name + (fst (List.nth (find_enum en ctx) n)) + format_with_parens e + | EMatch (e, es, e_name) -> + Format.fprintf fmt "@[match@ %a@]@ with@\n%a" format_with_parens e + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n| ") + (fun fmt (e, c) -> + Format.fprintf fmt "%a %a" format_enum_cons_name c + (fun fmt e -> + match Pos.unmark e with + | EAbs ((binder, _), _) -> + let xs, body = Bindlib.unmbind binder in + Format.fprintf fmt "%a ->@[@ %a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@,") + (fun fmt x -> Format.fprintf fmt "%a" format_var x)) + (Array.to_list xs) format_with_parens body + | _ -> assert false + (* should not happen *)) + e)) + (List.combine es (List.map fst (find_enum e_name ctx))) | ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e) | EApp ((EAbs ((binder, _), taus), _), args) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - let xs_tau_arg = - List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args - in - Format.fprintf fmt "(%a%a)" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "") - (fun fmt (x, tau, arg) -> - Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ in@\n" - format_var x format_typ tau format_with_parens arg)) - xs_tau_arg format_with_parens body + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in + Format.fprintf fmt "(%a%a)" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "") + (fun fmt (x, tau, arg) -> + Format.fprintf fmt "@[let@ %a@ :@ %a@ =@ %a@]@ in@\n" + format_var x format_typ tau format_with_parens arg)) + xs_tau_arg format_with_parens body | EAbs ((binder, _), taus) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - Format.fprintf fmt "@[fun@ %a ->@ %a@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - (fun fmt (x, tau) -> - Format.fprintf fmt "@[(%a:@ %a)@]" format_var x format_typ - tau)) - xs_tau format_expr body + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + Format.fprintf fmt "@[fun@ %a ->@ %a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + (fun fmt (x, tau) -> + Format.fprintf fmt "@[(%a:@ %a)@]" format_var x format_typ tau)) + xs_tau format_expr body | EApp - ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), - [ arg1; arg2 ] ) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop (op, Pos.no_pos) - format_with_parens arg1 format_with_parens arg2 - | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 - format_binop (op, Pos.no_pos) format_with_parens arg2 - | EApp - ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ]) + ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop (op, Pos.no_pos) + format_with_parens arg1 format_with_parens arg2 + | EApp ((EOp (Binop op), _), [arg1; arg2]) -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 + format_binop (op, Pos.no_pos) format_with_parens arg2 + | EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [f]), _), [arg]) when !Cli.trace_flag -> - Format.fprintf fmt "(log_begin_call@ %a@ %a@ %a)" format_uid_list info - format_with_parens f format_with_parens arg - | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ]) + Format.fprintf fmt "(log_begin_call@ %a@ %a@ %a)" format_uid_list info + format_with_parens f format_with_parens arg + | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [arg1]) when !Cli.trace_flag -> - Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" - format_uid_list info typ_embedding_name (tau, Pos.no_pos) - format_with_parens arg1 - | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ]) + Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list + info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1 + | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [arg1]) when !Cli.trace_flag -> - Format.fprintf fmt - "(log_decision_taken@ @[{filename = \"%s\";@ start_line=%d;@ \ - start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ \ - %a)" - (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) - (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list - (Pos.get_law_info pos) format_with_parens arg1 - | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ]) + Format.fprintf fmt + "(log_decision_taken@ @[{filename = \"%s\";@ start_line=%d;@ \ + start_column=%d;@ end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)" + (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) + (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list + (Pos.get_law_info pos) format_with_parens arg1 + | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [arg1]) when !Cli.trace_flag -> - Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info - format_with_parens arg1 - | EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) -> - Format.fprintf fmt "%a" format_with_parens arg1 - | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "@[%a@ %a@]" format_unop (op, Pos.no_pos) - format_with_parens arg1 + Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info + format_with_parens arg1 + | EApp ((EOp (Unop (D.Log _)), _), [arg1]) -> + Format.fprintf fmt "%a" format_with_parens arg1 + | EApp ((EOp (Unop op), _), [arg1]) -> + Format.fprintf fmt "@[%a@ %a@]" format_unop (op, Pos.no_pos) + format_with_parens arg1 | EApp (f, args) -> - Format.fprintf fmt "@[%a@ %a@]" format_with_parens f - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - format_with_parens) - args + Format.fprintf fmt "@[%a@ %a@]" format_with_parens f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + format_with_parens) + args | EIfThenElse (e1, e2, e3) -> - Format.fprintf fmt - "@[ if@ @[%a@]@ then@ @[%a@]@ else@ @[%a@]@]" - format_with_parens e1 format_with_parens e2 format_with_parens e3 + Format.fprintf fmt + "@[ if@ @[%a@]@ then@ @[%a@]@ else@ @[%a@]@]" + format_with_parens e1 format_with_parens e2 format_with_parens e3 | EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) | EAssert e' -> - Format.fprintf fmt - "@[if @ %a@ then@ ()@ else@ raise AssertionFailed@]" - format_with_parens e' + Format.fprintf fmt + "@[if @ %a@ then@ ()@ else@ raise AssertionFailed@]" + format_with_parens e' | ERaise exc -> - Format.fprintf fmt "raise@ %a" format_exception (exc, Pos.get_position e) + Format.fprintf fmt "raise@ %a" format_exception (exc, Pos.get_position e) | ECatch (e1, exc, e2) -> - Format.fprintf fmt "@[try@ %a@ with@ %a@ ->@ %a@]" - format_with_parens e1 format_exception - (exc, Pos.get_position e) - format_with_parens e2 + Format.fprintf fmt "@[try@ %a@ with@ %a@ ->@ %a@]" format_with_parens + e1 format_exception + (exc, Pos.get_position e) + format_with_parens e2 let format_struct_embedding (fmt : Format.formatter) @@ -508,9 +502,9 @@ let format_ctx (fun struct_or_enum -> match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> - Format.fprintf fmt "%a@\n@\n" format_struct_decl (s, find_struct s ctx) + Format.fprintf fmt "%a@\n@\n" format_struct_decl (s, find_struct s ctx) | Scopelang.Dependency.TVertex.Enum e -> - Format.fprintf fmt "%a@\n@\n" format_enum_decl (e, find_enum e ctx)) + Format.fprintf fmt "%a@\n@\n" format_enum_decl (e, find_enum e ctx)) (type_ordering @ scope_structs) let rec format_scope_body_expr @@ -520,14 +514,14 @@ let rec format_scope_body_expr match scope_lets with | Dcalc.Ast.Result e -> format_expr ctx fmt e | Dcalc.Ast.ScopeLet scope_let -> - let scope_let_var, scope_let_next = - Bindlib.unbind scope_let.scope_let_next - in - Format.fprintf fmt "@[let %a: %a = %a in@]@\n%a" format_var - scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) - scope_let.scope_let_expr - (format_scope_body_expr ctx) - scope_let_next + let scope_let_var, scope_let_next = + Bindlib.unbind scope_let.scope_let_next + in + Format.fprintf fmt "@[let %a: %a = %a in@]@\n%a" format_var + scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) + scope_let.scope_let_expr + (format_scope_body_expr ctx) + scope_let_next let rec format_scopes (ctx : Dcalc.Ast.decl_ctx) @@ -536,16 +530,16 @@ let rec format_scopes match scopes with | Dcalc.Ast.Nil -> () | Dcalc.Ast.ScopeDef scope_def -> - let scope_input_var, scope_body_expr = - Bindlib.unbind scope_def.scope_body.scope_body_expr - in - let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in - Format.fprintf fmt "@\n@\n@[let %a (%a: %a) : %a =@\n%a@]%a" - format_var scope_var format_var scope_input_var format_struct_name - scope_def.scope_body.scope_body_input_struct format_struct_name - scope_def.scope_body.scope_body_output_struct - (format_scope_body_expr ctx) - scope_body_expr (format_scopes ctx) scope_next + let scope_input_var, scope_body_expr = + Bindlib.unbind scope_def.scope_body.scope_body_expr + in + let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in + Format.fprintf fmt "@\n@\n@[let %a (%a: %a) : %a =@\n%a@]%a" + format_var scope_var format_var scope_input_var format_struct_name + scope_def.scope_body.scope_body_input_struct format_struct_name + scope_def.scope_body.scope_body_output_struct + (format_scope_body_expr ctx) + scope_body_expr (format_scopes ctx) scope_next let format_program (fmt : Format.formatter) diff --git a/compiler/literate/html.ml b/compiler/literate/html.ml index f242de17..312a5124 100644 --- a/compiler/literate/html.ml +++ b/compiler/literate/html.ml @@ -165,48 +165,46 @@ let pygmentize_code (c : string Pos.marked) (language : C.backend_lang) : string (** {1 Weaving} *) let rec law_structure_to_html - (language : C.backend_lang) (fmt : Format.formatter) (i : A.law_structure) : - unit = + (language : C.backend_lang) + (fmt : Format.formatter) + (i : A.law_structure) : unit = match i with | A.LawText t -> - let t = pre_html t in - if t = "" then () - else Format.fprintf fmt "
%s
" t + let t = pre_html t in + if t = "" then () else Format.fprintf fmt "
%s
" t | A.CodeBlock (_, c, metadata) -> - Format.fprintf fmt - "
\n\ -
%s
\n\ - %s\n\ -
" - (if metadata then " code-metadata" else "") - (Pos.get_file (Pos.get_position c)) - (pygmentize_code - (Pos.same_pos_as ("```catala\n" ^ Pos.unmark c ^ "```") c) - language) + Format.fprintf fmt + "
\n
%s
\n%s\n
" + (if metadata then " code-metadata" else "") + (Pos.get_file (Pos.get_position c)) + (pygmentize_code + (Pos.same_pos_as ("```catala\n" ^ Pos.unmark c ^ "```") c) + language) | A.LawHeading (heading, children) -> - let h_number = heading.law_heading_precedence + 1 in - Format.fprintf fmt "%s\n" - h_number - (match (heading.law_heading_id, language) with - | Some id, Fr -> - let ltime = Unix.localtime (Unix.time ()) in - P.sprintf "https://legifrance.gouv.fr/codes/id/%s/%d-%02d-%02d" id - (1900 + ltime.Unix.tm_year) - (ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday - | _ -> "#") - (pre_html (Pos.unmark heading.law_heading_name)) - h_number; - Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n") - (law_structure_to_html language) - fmt children + let h_number = heading.law_heading_precedence + 1 in + Format.fprintf fmt "%s\n" + h_number + (match heading.law_heading_id, language with + | Some id, Fr -> + let ltime = Unix.localtime (Unix.time ()) in + P.sprintf "https://legifrance.gouv.fr/codes/id/%s/%d-%02d-%02d" id + (1900 + ltime.Unix.tm_year) + (ltime.Unix.tm_mon + 1) ltime.Unix.tm_mday + | _ -> "#") + (pre_html (Pos.unmark heading.law_heading_name)) + h_number; + Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n") + (law_structure_to_html language) + fmt children | A.LawInclude _ -> () (** {1 API} *) let ast_to_html - (language : C.backend_lang) (fmt : Format.formatter) (program : A.program) : - unit = + (language : C.backend_lang) + (fmt : Format.formatter) + (program : A.program) : unit = Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") (law_structure_to_html language) diff --git a/compiler/literate/latex.ml b/compiler/literate/latex.ml index 05174f94..09f1604a 100644 --- a/compiler/literate/latex.ml +++ b/compiler/literate/latex.ml @@ -178,99 +178,103 @@ let wrap_latex (** [check_exceeding_lines max_len start_line filename content] prints a warning message for each lines of [content] exceeding [max_len] characters. *) let check_exceeding_lines - ?(max_len = 80) (start_line : int) (filename : string) (content : string) = + ?(max_len = 80) + (start_line : int) + (filename : string) + (content : string) = content |> String.split_on_char '\n' |> List.iteri (fun i s -> if CamomileLibrary.UTF8.length s > max_len then ( Cli.warning_print "The line %s in %s is exceeding %s characters:" (Cli.with_style - ANSITerminal.[ Bold; yellow ] + ANSITerminal.[Bold; yellow] "%d" (start_line + i + 1)) - (Cli.with_style ANSITerminal.[ Bold; magenta ] "%s" filename) - (Cli.with_style ANSITerminal.[ Bold; red ] "%d" max_len); + (Cli.with_style ANSITerminal.[Bold; magenta] "%s" filename) + (Cli.with_style ANSITerminal.[Bold; red] "%d" max_len); Cli.warning_print "%s%s" (String.sub s 0 max_len) (Cli.with_style - ANSITerminal.[ red ] + ANSITerminal.[red] "%s" String.(sub s max_len (length s - max_len))))) let rec law_structure_to_latex - (language : C.backend_lang) (fmt : Format.formatter) (i : A.law_structure) : - unit = + (language : C.backend_lang) + (fmt : Format.formatter) + (i : A.law_structure) : unit = match i with | A.LawHeading (heading, children) -> - Format.fprintf fmt "\\%s{%s}\n\n" - (match heading.law_heading_precedence with - | 0 -> "section" - | 1 -> "subsection" - | 2 -> "subsubsection" - | 3 -> "subsubsubsection" - | 4 -> "subsubsubsubsection" - | 5 -> "subsubsubsubsubsection" - | 6 -> "subsubsubsubsubsubsection" - | 7 -> "paragraph" - | _ -> "subparagraph") - (pre_latexify (Pos.unmark heading.law_heading_name)); - Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") - (law_structure_to_latex language) - fmt children + Format.fprintf fmt "\\%s{%s}\n\n" + (match heading.law_heading_precedence with + | 0 -> "section" + | 1 -> "subsection" + | 2 -> "subsubsection" + | 3 -> "subsubsubsection" + | 4 -> "subsubsubsubsection" + | 5 -> "subsubsubsubsubsection" + | 6 -> "subsubsubsubsubsubsection" + | 7 -> "paragraph" + | _ -> "subparagraph") + (pre_latexify (Pos.unmark heading.law_heading_name)); + Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") + (law_structure_to_latex language) + fmt children | A.LawInclude (A.PdfFile ((file, _), page)) -> - let label = - file - ^ match page with None -> "" | Some p -> Format.sprintf "_page_%d," p - in - Format.fprintf fmt - "\\begin{center}\\textit{Annexe incluse, retranscrite page \ - \\pageref{%s}}\\end{center} \ - \\begin{figure}[p]\\begin{center}\\includegraphics[%swidth=\\textwidth]{%s}\\label{%s}\\end{center}\\end{figure}" - label - (match page with None -> "" | Some p -> Format.sprintf "page=%d," p) - file label + let label = + file + ^ match page with None -> "" | Some p -> Format.sprintf "_page_%d," p + in + Format.fprintf fmt + "\\begin{center}\\textit{Annexe incluse, retranscrite page \ + \\pageref{%s}}\\end{center} \ + \\begin{figure}[p]\\begin{center}\\includegraphics[%swidth=\\textwidth]{%s}\\label{%s}\\end{center}\\end{figure}" + label + (match page with None -> "" | Some p -> Format.sprintf "page=%d," p) + file label | A.LawInclude (A.CatalaFile _ | A.LegislativeText _) -> () | A.LawText t -> Format.fprintf fmt "%s" (pre_latexify t) | A.CodeBlock (_, c, false) -> - Format.fprintf fmt - "\\begin{minted}[label={\\hspace*{\\fill}\\texttt{%s}},firstnumber=%d]{%s}\n\ - ```catala\n\ - %s```\n\ - \\end{minted}" - (pre_latexify (Filename.basename (Pos.get_file (Pos.get_position c)))) - (Pos.get_start_line (Pos.get_position c) - 1) - (get_language_extension language) - (Pos.unmark c) + Format.fprintf fmt + "\\begin{minted}[label={\\hspace*{\\fill}\\texttt{%s}},firstnumber=%d]{%s}\n\ + ```catala\n\ + %s```\n\ + \\end{minted}" + (pre_latexify (Filename.basename (Pos.get_file (Pos.get_position c)))) + (Pos.get_start_line (Pos.get_position c) - 1) + (get_language_extension language) + (Pos.unmark c) | A.CodeBlock (_, c, true) -> - let metadata_title = - match language with - | Fr -> "Métadonnées" - | En -> "Metadata" - | Pl -> "Metadane" - in - let start_line = Pos.get_start_line (Pos.get_position c) - 1 in - let filename = Filename.basename (Pos.get_file (Pos.get_position c)) in - let block_content = Pos.unmark c in - check_exceeding_lines start_line filename block_content; - Format.fprintf fmt - "\\begin{tcolorbox}[colframe=OliveGreen, breakable, \ - title=\\textcolor{black}{\\texttt{%s}},title after \ - break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after \ - skip=1em]\n\ - \\begin{minted}[numbersep=9mm, firstnumber=%d, \ - label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\ - ```catala\n\ - %s```\n\ - \\end{minted}\n\ - \\end{tcolorbox}" - metadata_title metadata_title start_line (pre_latexify filename) - (get_language_extension language) - block_content + let metadata_title = + match language with + | Fr -> "Métadonnées" + | En -> "Metadata" + | Pl -> "Metadane" + in + let start_line = Pos.get_start_line (Pos.get_position c) - 1 in + let filename = Filename.basename (Pos.get_file (Pos.get_position c)) in + let block_content = Pos.unmark c in + check_exceeding_lines start_line filename block_content; + Format.fprintf fmt + "\\begin{tcolorbox}[colframe=OliveGreen, breakable, \ + title=\\textcolor{black}{\\texttt{%s}},title after \ + break=\\textcolor{black}{\\texttt{%s}},before skip=1em, after skip=1em]\n\ + \\begin{minted}[numbersep=9mm, firstnumber=%d, \ + label={\\hspace*{\\fill}\\texttt{%s}}]{%s}\n\ + ```catala\n\ + %s```\n\ + \\end{minted}\n\ + \\end{tcolorbox}" + metadata_title metadata_title start_line (pre_latexify filename) + (get_language_extension language) + block_content (** {1 API} *) let ast_to_latex - (language : C.backend_lang) (fmt : Format.formatter) (program : A.program) : - unit = + (language : C.backend_lang) + (fmt : Format.formatter) + (program : A.program) : unit = Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") (law_structure_to_latex language) diff --git a/compiler/literate/literate_common.ml b/compiler/literate/literate_common.ml index 0f5826b8..a28ab964 100644 --- a/compiler/literate/literate_common.ml +++ b/compiler/literate/literate_common.ml @@ -34,24 +34,24 @@ let literal_source_files = function let literal_disclaimer_and_link = function | En -> - "This document was produced from a set of source files written in the \ - Catala programming language, mixing together the legislative text and \ - the computer code that translates it. For more information about the \ - methodology and how to read the code, please visit \ - [https://catala-lang.org](https://catala-lang.org)." + "This document was produced from a set of source files written in the \ + Catala programming language, mixing together the legislative text and the \ + computer code that translates it. For more information about the \ + methodology and how to read the code, please visit \ + [https://catala-lang.org](https://catala-lang.org)." | Fr -> - "Ce document a été produit à partir d'un ensemble de fichiers sources \ - écrits dans le langage de programmation Catala, mêlant le texte \ - législatif et le code informatique qui le traduit. Pour plus \ - d'informations sur la méthodologie et sur la façon de lire le code, \ - veuillez consulter le site \ - [https://catala-lang.org](https://catala-lang.org)." + "Ce document a été produit à partir d'un ensemble de fichiers sources \ + écrits dans le langage de programmation Catala, mêlant le texte \ + législatif et le code informatique qui le traduit. Pour plus \ + d'informations sur la méthodologie et sur la façon de lire le code, \ + veuillez consulter le site \ + [https://catala-lang.org](https://catala-lang.org)." | Pl -> - "Niniejszy dokument został opracowany na podstawie zestawu plików \ - źródłowych napisanych w języku programowania Catala, łączących tekst \ - legislacyjny z kodem komputerowym, który go tłumaczy. Więcej informacji \ - na temat metodologii i sposobu odczytywania kodu można znaleźć na \ - stronie [https://catala-lang.org](https://catala-lang.org)" + "Niniejszy dokument został opracowany na podstawie zestawu plików \ + źródłowych napisanych w języku programowania Catala, łączących tekst \ + legislacyjny z kodem komputerowym, który go tłumaczy. Więcej informacji \ + na temat metodologii i sposobu odczytywania kodu można znaleźć na stronie \ + [https://catala-lang.org](https://catala-lang.org)" let literal_last_modification = function | En -> "last modification" diff --git a/compiler/runtime.ml b/compiler/runtime.ml index 095c2354..a61a2ab9 100644 --- a/compiler/runtime.ml +++ b/compiler/runtime.ml @@ -178,18 +178,16 @@ let duration_of_numbers (year : int) (month : int) (day : int) : duration = let duration_to_string (d : duration) : string = let x, y, z = CalendarLib.Date.Period.ymd d in let to_print = - List.filter - (fun (a, _) -> a <> 0) - [ (x, "years"); (y, "months"); (z, "days") ] + List.filter (fun (a, _) -> a <> 0) [x, "years"; y, "months"; z, "days"] in match to_print with | [] -> "empty duration" | _ -> - Format.asprintf "%a" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt (d, l) -> Format.fprintf fmt "%d %s" d l)) - to_print + Format.asprintf "%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt (d, l) -> Format.fprintf fmt "%d %s" d l)) + to_print let duration_to_years_months_days (d : duration) : int * int * int = CalendarLib.Date.Period.ymd d @@ -201,7 +199,7 @@ let handle_default : Array.fold_left (fun acc except -> let new_val = try Some (except ()) with EmptyError -> None in - match (acc, new_val) with + match acc, new_val with | None, _ -> new_val | Some _, None -> acc | Some _, Some _ -> raise ConflictError) @@ -212,12 +210,13 @@ let handle_default : | None -> if just () then cons () else raise EmptyError let handle_default_opt - (exceptions : 'a eoption array) (just : bool eoption) (cons : 'a eoption) : - 'a eoption = + (exceptions : 'a eoption array) + (just : bool eoption) + (cons : 'a eoption) : 'a eoption = let except = Array.fold_left (fun acc except -> - match (acc, except) with + match acc, except with | ENone _, _ -> except | ESome _, ENone _ -> acc | ESome _, ESome _ -> raise ConflictError) @@ -226,9 +225,9 @@ let handle_default_opt match except with | ESome _ -> except | ENone _ -> ( - match just with - | ESome b -> if b then cons else ENone () - | ENone _ -> ENone ()) + match just with + | ESome b -> if b then cons else ENone () + | ENone _ -> ENone ()) let no_input : unit -> 'a = fun _ -> raise EmptyError @@ -308,7 +307,8 @@ let ( <@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 < 0 let ( =@ ) (d1 : date) (d2 : date) : bool = CalendarLib.Date.compare d1 d2 = 0 let compare_periods - (p1 : CalendarLib.Date.Period.t) (p2 : CalendarLib.Date.Period.t) : int = + (p1 : CalendarLib.Date.Period.t) + (p2 : CalendarLib.Date.Period.t) : int = try let p1_days = CalendarLib.Date.Period.nb_days p1 in let p2_days = CalendarLib.Date.Period.nb_days p2 in diff --git a/compiler/scalc/compile_from_lambda.ml b/compiler/scalc/compile_from_lambda.ml index 5494f49a..293f6812 100644 --- a/compiler/scalc/compile_from_lambda.ml +++ b/compiler/scalc/compile_from_lambda.ml @@ -33,248 +33,233 @@ let rec translate_expr (ctxt : ctxt) (expr : L.expr Pos.marked) : A.block * A.expr Pos.marked = match Pos.unmark expr with | L.EVar v -> - let local_var = - try A.EVar (L.VarMap.find (Pos.unmark v) ctxt.var_dict) - with Not_found -> - A.EFunc (L.VarMap.find (Pos.unmark v) ctxt.func_dict) - in - ([], (local_var, Pos.get_position v)) + let local_var = + try A.EVar (L.VarMap.find (Pos.unmark v) ctxt.var_dict) + with Not_found -> A.EFunc (L.VarMap.find (Pos.unmark v) ctxt.func_dict) + in + [], (local_var, Pos.get_position v) | L.ETuple (args, Some s_name) -> - let args_stmts, new_args = - List.fold_left - (fun (args_stmts, new_args) arg -> - let arg_stmts, new_arg = translate_expr ctxt arg in - (arg_stmts @ args_stmts, new_arg :: new_args)) - ([], []) args - in - let new_args = List.rev new_args in - let args_stmts = List.rev args_stmts in - (args_stmts, (A.EStruct (new_args, s_name), Pos.get_position expr)) + let args_stmts, new_args = + List.fold_left + (fun (args_stmts, new_args) arg -> + let arg_stmts, new_arg = translate_expr ctxt arg in + arg_stmts @ args_stmts, new_arg :: new_args) + ([], []) args + in + let new_args = List.rev new_args in + let args_stmts = List.rev args_stmts in + args_stmts, (A.EStruct (new_args, s_name), Pos.get_position expr) | L.ETuple (_, None) -> - failwith "Non-struct tuples cannot be compiled to scalc" + failwith "Non-struct tuples cannot be compiled to scalc" | L.ETupleAccess (e1, num_field, Some s_name, _) -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in - let field_name = - fst - (List.nth - (D.StructMap.find s_name ctxt.decl_ctx.ctx_structs) - num_field) - in - ( e1_stmts, - ( A.EStructFieldAccess (new_e1, field_name, s_name), - Pos.get_position expr ) ) + let e1_stmts, new_e1 = translate_expr ctxt e1 in + let field_name = + fst + (List.nth (D.StructMap.find s_name ctxt.decl_ctx.ctx_structs) num_field) + in + ( e1_stmts, + (A.EStructFieldAccess (new_e1, field_name, s_name), Pos.get_position expr) + ) | L.ETupleAccess (_, _, None, _) -> - failwith "Non-struct tuples cannot be compiled to scalc" + failwith "Non-struct tuples cannot be compiled to scalc" | L.EInj (e1, num_cons, e_name, _) -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in - let cons_name = - fst (List.nth (D.EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) - in - (e1_stmts, (A.EInj (new_e1, cons_name, e_name), Pos.get_position expr)) + let e1_stmts, new_e1 = translate_expr ctxt e1 in + let cons_name = + fst (List.nth (D.EnumMap.find e_name ctxt.decl_ctx.ctx_enums) num_cons) + in + e1_stmts, (A.EInj (new_e1, cons_name, e_name), Pos.get_position expr) | L.EApp (f, args) -> - let f_stmts, new_f = translate_expr ctxt f in - let args_stmts, new_args = - List.fold_left - (fun (args_stmts, new_args) arg -> - let arg_stmts, new_arg = translate_expr ctxt arg in - (arg_stmts @ args_stmts, new_arg :: new_args)) - ([], []) args - in - let new_args = List.rev new_args in - (f_stmts @ args_stmts, (A.EApp (new_f, new_args), Pos.get_position expr)) + let f_stmts, new_f = translate_expr ctxt f in + let args_stmts, new_args = + List.fold_left + (fun (args_stmts, new_args) arg -> + let arg_stmts, new_arg = translate_expr ctxt arg in + arg_stmts @ args_stmts, new_arg :: new_args) + ([], []) args + in + let new_args = List.rev new_args in + f_stmts @ args_stmts, (A.EApp (new_f, new_args), Pos.get_position expr) | L.EArray args -> - let args_stmts, new_args = - List.fold_left - (fun (args_stmts, new_args) arg -> - let arg_stmts, new_arg = translate_expr ctxt arg in - (arg_stmts @ args_stmts, new_arg :: new_args)) - ([], []) args - in - let new_args = List.rev new_args in - (args_stmts, (A.EArray new_args, Pos.get_position expr)) - | L.EOp op -> ([], (A.EOp op, Pos.get_position expr)) - | L.ELit l -> ([], (A.ELit l, Pos.get_position expr)) + let args_stmts, new_args = + List.fold_left + (fun (args_stmts, new_args) arg -> + let arg_stmts, new_arg = translate_expr ctxt arg in + arg_stmts @ args_stmts, new_arg :: new_args) + ([], []) args + in + let new_args = List.rev new_args in + args_stmts, (A.EArray new_args, Pos.get_position expr) + | L.EOp op -> [], (A.EOp op, Pos.get_position expr) + | L.ELit l -> [], (A.ELit l, Pos.get_position expr) | _ -> - let tmp_var = - A.LocalName.fresh - ( (*This piece of logic is used to make the code more readable. TODO: - should be removed when - https://github.com/CatalaLang/catala/issues/240 is fixed. *) - (match ctxt.inside_definition_of with - | None -> ctxt.context_name - | Some v -> - let v = Pos.unmark (A.LocalName.get_info v) in - let tmp_rex = Re.Pcre.regexp "^temp_" in - if Re.Pcre.pmatch ~rex:tmp_rex v then v else "temp_" ^ v), - Pos.get_position expr ) - in - let ctxt = - { - ctxt with - inside_definition_of = Some tmp_var; - context_name = Pos.unmark (A.LocalName.get_info tmp_var); - } - in - let tmp_stmts = translate_statements ctxt expr in - ( ( A.SLocalDecl - ((tmp_var, Pos.get_position expr), (D.TAny, Pos.get_position expr)), + let tmp_var = + A.LocalName.fresh + ( (*This piece of logic is used to make the code more readable. TODO: + should be removed when + https://github.com/CatalaLang/catala/issues/240 is fixed. *) + (match ctxt.inside_definition_of with + | None -> ctxt.context_name + | Some v -> + let v = Pos.unmark (A.LocalName.get_info v) in + let tmp_rex = Re.Pcre.regexp "^temp_" in + if Re.Pcre.pmatch ~rex:tmp_rex v then v else "temp_" ^ v), Pos.get_position expr ) - :: tmp_stmts, - (A.EVar tmp_var, Pos.get_position expr) ) + in + let ctxt = + { + ctxt with + inside_definition_of = Some tmp_var; + context_name = Pos.unmark (A.LocalName.get_info tmp_var); + } + in + let tmp_stmts = translate_statements ctxt expr in + ( ( A.SLocalDecl + ((tmp_var, Pos.get_position expr), (D.TAny, Pos.get_position expr)), + Pos.get_position expr ) + :: tmp_stmts, + (A.EVar tmp_var, Pos.get_position expr) ) and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.block = match Pos.unmark block_expr with | L.EAssert e -> - (* Assertions are always encapsulated in a unit-typed let binding *) - let e_stmts, new_e = translate_expr ctxt e in - e_stmts @ [ (A.SAssert (Pos.unmark new_e), Pos.get_position block_expr) ] + (* Assertions are always encapsulated in a unit-typed let binding *) + let e_stmts, new_e = translate_expr ctxt e in + e_stmts @ [A.SAssert (Pos.unmark new_e), Pos.get_position block_expr] | L.EApp ((L.EAbs ((binder, binder_pos), taus), eabs_pos), args) -> - (* This defines multiple local variables at the time *) - let vars, body = Bindlib.unmbind binder in - let vars_tau = - List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus - in - let ctxt = - { - ctxt with - var_dict = - List.fold_left - (fun var_dict (x, _) -> - L.VarMap.add x - (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) - var_dict) - ctxt.var_dict vars_tau; - } - in - let local_decls = - List.map - (fun (x, tau) -> - ( A.SLocalDecl ((L.VarMap.find x ctxt.var_dict, binder_pos), tau), - eabs_pos )) - vars_tau - in - let vars_args = - List.map2 - (fun (x, tau) arg -> - ((L.VarMap.find x ctxt.var_dict, binder_pos), tau, arg)) - vars_tau args - in - let def_blocks = - List.map - (fun (x, _tau, arg) -> - let ctxt = - { - ctxt with - inside_definition_of = Some (Pos.unmark x); - context_name = Pos.unmark (A.LocalName.get_info (Pos.unmark x)); - } - in - let arg_stmts, new_arg = translate_expr ctxt arg in - arg_stmts @ [ (A.SLocalDef (x, new_arg), binder_pos) ]) - vars_args - in - let rest_of_block = translate_statements ctxt body in - local_decls @ List.flatten def_blocks @ rest_of_block + (* This defines multiple local variables at the time *) + let vars, body = Bindlib.unmbind binder in + let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in + let ctxt = + { + ctxt with + var_dict = + List.fold_left + (fun var_dict (x, _) -> + L.VarMap.add x + (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) + var_dict) + ctxt.var_dict vars_tau; + } + in + let local_decls = + List.map + (fun (x, tau) -> + ( A.SLocalDecl ((L.VarMap.find x ctxt.var_dict, binder_pos), tau), + eabs_pos )) + vars_tau + in + let vars_args = + List.map2 + (fun (x, tau) arg -> + (L.VarMap.find x ctxt.var_dict, binder_pos), tau, arg) + vars_tau args + in + let def_blocks = + List.map + (fun (x, _tau, arg) -> + let ctxt = + { + ctxt with + inside_definition_of = Some (Pos.unmark x); + context_name = Pos.unmark (A.LocalName.get_info (Pos.unmark x)); + } + in + let arg_stmts, new_arg = translate_expr ctxt arg in + arg_stmts @ [A.SLocalDef (x, new_arg), binder_pos]) + vars_args + in + let rest_of_block = translate_statements ctxt body in + local_decls @ List.flatten def_blocks @ rest_of_block | L.EAbs ((binder, binder_pos), taus) -> - let vars, body = Bindlib.unmbind binder in - let vars_tau = - List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus - in - let closure_name = - match ctxt.inside_definition_of with - | None -> - A.LocalName.fresh (ctxt.context_name, Pos.get_position block_expr) - | Some x -> x - in - let ctxt = - { - ctxt with - var_dict = - List.fold_left - (fun var_dict (x, _) -> - L.VarMap.add x - (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) - var_dict) - ctxt.var_dict vars_tau; - inside_definition_of = None; - } - in - let new_body = translate_statements ctxt body in - [ - ( A.SInnerFuncDef - ( (closure_name, binder_pos), - { - func_params = - List.map - (fun (var, tau) -> - ((L.VarMap.find var ctxt.var_dict, binder_pos), tau)) - vars_tau; - func_body = new_body; - } ), - binder_pos ); - ] + let vars, body = Bindlib.unmbind binder in + let vars_tau = List.map2 (fun x tau -> x, tau) (Array.to_list vars) taus in + let closure_name = + match ctxt.inside_definition_of with + | None -> + A.LocalName.fresh (ctxt.context_name, Pos.get_position block_expr) + | Some x -> x + in + let ctxt = + { + ctxt with + var_dict = + List.fold_left + (fun var_dict (x, _) -> + L.VarMap.add x + (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) + var_dict) + ctxt.var_dict vars_tau; + inside_definition_of = None; + } + in + let new_body = translate_statements ctxt body in + [ + ( A.SInnerFuncDef + ( (closure_name, binder_pos), + { + func_params = + List.map + (fun (var, tau) -> + (L.VarMap.find var ctxt.var_dict, binder_pos), tau) + vars_tau; + func_body = new_body; + } ), + binder_pos ); + ] | L.EMatch (e1, args, e_name) -> - let e1_stmts, new_e1 = translate_expr ctxt e1 in - let new_args = - List.fold_left - (fun new_args arg -> - match Pos.unmark arg with - | L.EAbs ((binder, pos_binder), _) -> - let vars, body = Bindlib.unmbind binder in - assert (Array.length vars = 1); - let var = vars.(0) in - let scalc_var = - A.LocalName.fresh (Bindlib.name_of var, pos_binder) - in - let ctxt = - { - ctxt with - var_dict = L.VarMap.add var scalc_var ctxt.var_dict; - } - in - let new_arg = translate_statements ctxt body in - (new_arg, scalc_var) :: new_args - | _ -> assert false - (* should not happen *)) - [] args - in - let new_args = List.rev new_args in - e1_stmts - @ [ (A.SSwitch (new_e1, e_name, new_args), Pos.get_position block_expr) ] + let e1_stmts, new_e1 = translate_expr ctxt e1 in + let new_args = + List.fold_left + (fun new_args arg -> + match Pos.unmark arg with + | L.EAbs ((binder, pos_binder), _) -> + let vars, body = Bindlib.unmbind binder in + assert (Array.length vars = 1); + let var = vars.(0) in + let scalc_var = + A.LocalName.fresh (Bindlib.name_of var, pos_binder) + in + let ctxt = + { ctxt with var_dict = L.VarMap.add var scalc_var ctxt.var_dict } + in + let new_arg = translate_statements ctxt body in + (new_arg, scalc_var) :: new_args + | _ -> assert false + (* should not happen *)) + [] args + in + let new_args = List.rev new_args in + e1_stmts + @ [A.SSwitch (new_e1, e_name, new_args), Pos.get_position block_expr] | L.EIfThenElse (cond, e_true, e_false) -> - let cond_stmts, s_cond = translate_expr ctxt cond in - let s_e_true = translate_statements ctxt e_true in - let s_e_false = translate_statements ctxt e_false in - cond_stmts - @ [ - ( A.SIfThenElse (s_cond, s_e_true, s_e_false), - Pos.get_position block_expr ); - ] + let cond_stmts, s_cond = translate_expr ctxt cond in + let s_e_true = translate_statements ctxt e_true in + let s_e_false = translate_statements ctxt e_false in + cond_stmts + @ [A.SIfThenElse (s_cond, s_e_true, s_e_false), Pos.get_position block_expr] | L.ECatch (e_try, except, e_catch) -> - let s_e_try = translate_statements ctxt e_try in - let s_e_catch = translate_statements ctxt e_catch in - [ - (A.STryExcept (s_e_try, except, s_e_catch), Pos.get_position block_expr); - ] - | L.ERaise except -> [ (A.SRaise except, Pos.get_position block_expr) ] + let s_e_try = translate_statements ctxt e_try in + let s_e_catch = translate_statements ctxt e_catch in + [A.STryExcept (s_e_try, except, s_e_catch), Pos.get_position block_expr] + | L.ERaise except -> [A.SRaise except, Pos.get_position block_expr] | _ -> ( - let e_stmts, new_e = translate_expr ctxt block_expr in - e_stmts - @ - match e_stmts with - | (A.SRaise _, _) :: _ -> - (* if the last statement raises an exception, then we don't need to - return or to define the current variable since this code will be - unreachable *) - [] - | _ -> - [ - ( (match ctxt.inside_definition_of with - | None -> A.SReturn (Pos.unmark new_e) - | Some x -> A.SLocalDef (Pos.same_pos_as x new_e, new_e)), - Pos.get_position block_expr ); - ]) + let e_stmts, new_e = translate_expr ctxt block_expr in + e_stmts + @ + match e_stmts with + | (A.SRaise _, _) :: _ -> + (* if the last statement raises an exception, then we don't need to return + or to define the current variable since this code will be + unreachable *) + [] + | _ -> + [ + ( (match ctxt.inside_definition_of with + | None -> A.SReturn (Pos.unmark new_e) + | Some x -> A.SLocalDef (Pos.same_pos_as x new_e, new_e)), + Pos.get_position block_expr ); + ]) let rec translate_scope_body_expr (scope_name : D.ScopeName.t) @@ -284,58 +269,57 @@ let rec translate_scope_body_expr (scope_expr : L.expr D.scope_body_expr) : A.block = match scope_expr with | Result e -> - let block, new_e = + let block, new_e = + translate_expr + { + decl_ctx; + func_dict; + var_dict; + inside_definition_of = None; + context_name = Pos.unmark (D.ScopeName.get_info scope_name); + } + e + in + block @ [A.SReturn (Pos.unmark new_e), Pos.get_position new_e] + | ScopeLet scope_let -> + let let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in + let let_var_id = + A.LocalName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos) + in + let new_var_dict = L.VarMap.add let_var let_var_id var_dict in + (match scope_let.scope_let_kind with + | D.Assertion -> + translate_statements + { + decl_ctx; + func_dict; + var_dict; + inside_definition_of = Some let_var_id; + context_name = Pos.unmark (D.ScopeName.get_info scope_name); + } + scope_let.scope_let_expr + | _ -> + let let_expr_stmts, new_let_expr = translate_expr { decl_ctx; func_dict; var_dict; - inside_definition_of = None; + inside_definition_of = Some let_var_id; context_name = Pos.unmark (D.ScopeName.get_info scope_name); } - e + scope_let.scope_let_expr in - block @ [ (A.SReturn (Pos.unmark new_e), Pos.get_position new_e) ] - | ScopeLet scope_let -> - let let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in - let let_var_id = - A.LocalName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos) - in - let new_var_dict = L.VarMap.add let_var let_var_id var_dict in - (match scope_let.scope_let_kind with - | D.Assertion -> - translate_statements - { - decl_ctx; - func_dict; - var_dict; - inside_definition_of = Some let_var_id; - context_name = Pos.unmark (D.ScopeName.get_info scope_name); - } - scope_let.scope_let_expr - | _ -> - let let_expr_stmts, new_let_expr = - translate_expr - { - decl_ctx; - func_dict; - var_dict; - inside_definition_of = Some let_var_id; - context_name = Pos.unmark (D.ScopeName.get_info scope_name); - } - scope_let.scope_let_expr - in - let_expr_stmts - @ [ - ( A.SLocalDecl - ( (let_var_id, scope_let.scope_let_pos), - scope_let.scope_let_typ ), - scope_let.scope_let_pos ); - ( A.SLocalDef ((let_var_id, scope_let.scope_let_pos), new_let_expr), - scope_let.scope_let_pos ); - ]) - @ translate_scope_body_expr scope_name decl_ctx new_var_dict func_dict - scope_let_next + let_expr_stmts + @ [ + ( A.SLocalDecl + ((let_var_id, scope_let.scope_let_pos), scope_let.scope_let_typ), + scope_let.scope_let_pos ); + ( A.SLocalDef ((let_var_id, scope_let.scope_let_pos), new_let_expr), + scope_let.scope_let_pos ); + ]) + @ translate_scope_body_expr scope_name decl_ctx new_var_dict func_dict + scope_let_next let translate_program (p : L.program) : A.program = { diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index e5390757..5010615d 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -39,69 +39,68 @@ let rec format_expr | EVar v -> Format.fprintf fmt "%a" format_local_name v | EFunc v -> Format.fprintf fmt "%a" TopLevelName.format_t v | EStruct (es, s) -> - Format.fprintf fmt "@[%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s - Dcalc.Print.format_punctuation "{" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt (e, struct_field) -> - Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation - "\"" Dcalc.Ast.StructFieldName.format_t struct_field - Dcalc.Print.format_punctuation "\"" - Dcalc.Print.format_punctuation ":" format_expr e)) - (List.combine es - (List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) - Dcalc.Print.format_punctuation "}" + Format.fprintf fmt "@[%a@ %a%a%a@]" Dcalc.Ast.StructName.format_t s + Dcalc.Print.format_punctuation "{" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt (e, struct_field) -> + Format.fprintf fmt "%a%a%a%a %a" Dcalc.Print.format_punctuation "\"" + Dcalc.Ast.StructFieldName.format_t struct_field + Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation + ":" format_expr e)) + (List.combine es + (List.map fst (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) + Dcalc.Print.format_punctuation "}" | EArray es -> - Format.fprintf fmt "@[%a%a%a@]" Dcalc.Print.format_punctuation "[" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) - es Dcalc.Print.format_punctuation "]" + Format.fprintf fmt "@[%a%a%a@]" Dcalc.Print.format_punctuation "[" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + (fun fmt e -> Format.fprintf fmt "%a" format_expr e)) + es Dcalc.Print.format_punctuation "]" | EStructFieldAccess (e1, field, s) -> - Format.fprintf fmt "%a%a%a%a%a" format_expr e1 - Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" - Dcalc.Ast.StructFieldName.format_t - (fst - (List.find - (fun (field', _) -> - Dcalc.Ast.StructFieldName.compare field' field = 0) - (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) - Dcalc.Print.format_punctuation "\"" + Format.fprintf fmt "%a%a%a%a%a" format_expr e1 + Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" + Dcalc.Ast.StructFieldName.format_t + (fst + (List.find + (fun (field', _) -> + Dcalc.Ast.StructFieldName.compare field' field = 0) + (Dcalc.Ast.StructMap.find s decl_ctx.ctx_structs))) + Dcalc.Print.format_punctuation "\"" | EInj (e, case, enum) -> - Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_enum_constructor - (fst - (List.find - (fun (case', _) -> - Dcalc.Ast.EnumConstructor.compare case' case = 0) - (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums))) - format_expr e + Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_enum_constructor + (fst + (List.find + (fun (case', _) -> Dcalc.Ast.EnumConstructor.compare case' case = 0) + (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums))) + format_expr e | ELit l -> - Format.fprintf fmt "%a" Lcalc.Print.format_lit (Pos.same_pos_as l e) + Format.fprintf fmt "%a" Lcalc.Print.format_lit (Pos.same_pos_as l e) | EApp - ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), - [ arg1; arg2 ] ) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" Dcalc.Print.format_binop - (op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2 - | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 - Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2 - | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> - Format.fprintf fmt "%a" format_with_parens arg1 - | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop - (op, Pos.no_pos) format_with_parens arg1 + ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + -> + Format.fprintf fmt "@[%a@ %a@ %a@]" Dcalc.Print.format_binop + (op, Pos.no_pos) format_with_parens arg1 format_with_parens arg2 + | EApp ((EOp (Binop op), _), [arg1; arg2]) -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 + Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2 + | EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> + Format.fprintf fmt "%a" format_with_parens arg1 + | EApp ((EOp (Unop op), _), [arg1]) -> + Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop + (op, Pos.no_pos) format_with_parens arg1 | EApp (f, args) -> - Format.fprintf fmt "@[%a@ %a@]" format_expr f - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - format_with_parens) - args + Format.fprintf fmt "@[%a@ %a@]" format_expr f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + format_with_parens) + args | EOp (Ternop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) | EOp (Binop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) | EOp (Unop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) let rec format_statement (decl_ctx : Dcalc.Ast.decl_ctx) @@ -111,74 +110,74 @@ let rec format_statement if debug then () else (); match Pos.unmark stmt with | SInnerFuncDef (name, func) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@]@\n@[ %a@]" - Dcalc.Print.format_keyword "let" LocalName.format_t (Pos.unmark name) - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - (fun fmt ((name, _), typ) -> - Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation - "(" LocalName.format_t name Dcalc.Print.format_punctuation ":" - (Dcalc.Print.format_typ decl_ctx) - typ Dcalc.Print.format_punctuation ")")) - func.func_params Dcalc.Print.format_punctuation "=" - (format_block decl_ctx ~debug) - func.func_body + Format.fprintf fmt "@[%a@ %a@ %a@ %a@]@\n@[ %a@]" + Dcalc.Print.format_keyword "let" LocalName.format_t (Pos.unmark name) + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + (fun fmt ((name, _), typ) -> + Format.fprintf fmt "%a%a %a@ %a%a" Dcalc.Print.format_punctuation "(" + LocalName.format_t name Dcalc.Print.format_punctuation ":" + (Dcalc.Print.format_typ decl_ctx) + typ Dcalc.Print.format_punctuation ")")) + func.func_params Dcalc.Print.format_punctuation "=" + (format_block decl_ctx ~debug) + func.func_body | SLocalDecl (name, typ) -> - Format.fprintf fmt "@[%a %a %a@ %a@]" Dcalc.Print.format_keyword - "decl" LocalName.format_t (Pos.unmark name) - Dcalc.Print.format_punctuation ":" - (Dcalc.Print.format_typ decl_ctx) - typ + Format.fprintf fmt "@[%a %a %a@ %a@]" Dcalc.Print.format_keyword + "decl" LocalName.format_t (Pos.unmark name) Dcalc.Print.format_punctuation + ":" + (Dcalc.Print.format_typ decl_ctx) + typ | SLocalDef (name, expr) -> - Format.fprintf fmt "@[%a %a@ %a@]" LocalName.format_t - (Pos.unmark name) Dcalc.Print.format_punctuation "=" - (format_expr decl_ctx ~debug) - expr + Format.fprintf fmt "@[%a %a@ %a@]" LocalName.format_t + (Pos.unmark name) Dcalc.Print.format_punctuation "=" + (format_expr decl_ctx ~debug) + expr | STryExcept (b_try, except, b_with) -> - Format.fprintf fmt "@[%a%a@ %a@]@\n@[%a %a%a@ %a@]" - Dcalc.Print.format_keyword "try" Dcalc.Print.format_punctuation ":" - (format_block decl_ctx ~debug) - b_try Dcalc.Print.format_keyword "with" Lcalc.Print.format_exception - except Dcalc.Print.format_punctuation ":" - (format_block decl_ctx ~debug) - b_with + Format.fprintf fmt "@[%a%a@ %a@]@\n@[%a %a%a@ %a@]" + Dcalc.Print.format_keyword "try" Dcalc.Print.format_punctuation ":" + (format_block decl_ctx ~debug) + b_try Dcalc.Print.format_keyword "with" Lcalc.Print.format_exception + except Dcalc.Print.format_punctuation ":" + (format_block decl_ctx ~debug) + b_with | SRaise except -> - Format.fprintf fmt "@[%a %a@]" Dcalc.Print.format_keyword "raise" - Lcalc.Print.format_exception except + Format.fprintf fmt "@[%a %a@]" Dcalc.Print.format_keyword "raise" + Lcalc.Print.format_exception except | SIfThenElse (e_if, b_true, b_false) -> - Format.fprintf fmt "@[%a @[%a@]%a@ %a@ @]@[%a%a@ %a@]" - Dcalc.Print.format_keyword "if" - (format_expr decl_ctx ~debug) - e_if Dcalc.Print.format_punctuation ":" - (format_block decl_ctx ~debug) - b_true Dcalc.Print.format_keyword "else" Dcalc.Print.format_punctuation - ":" - (format_block decl_ctx ~debug) - b_false + Format.fprintf fmt "@[%a @[%a@]%a@ %a@ @]@[%a%a@ %a@]" + Dcalc.Print.format_keyword "if" + (format_expr decl_ctx ~debug) + e_if Dcalc.Print.format_punctuation ":" + (format_block decl_ctx ~debug) + b_true Dcalc.Print.format_keyword "else" Dcalc.Print.format_punctuation + ":" + (format_block decl_ctx ~debug) + b_false | SReturn ret -> - Format.fprintf fmt "@[%a %a@]" Dcalc.Print.format_keyword "return" - (format_expr decl_ctx ~debug) - (ret, Pos.get_position stmt) + Format.fprintf fmt "@[%a %a@]" Dcalc.Print.format_keyword "return" + (format_expr decl_ctx ~debug) + (ret, Pos.get_position stmt) | SAssert expr -> - Format.fprintf fmt "@[%a %a@]" Dcalc.Print.format_keyword "assert" - (format_expr decl_ctx ~debug) - (expr, Pos.get_position stmt) + Format.fprintf fmt "@[%a %a@]" Dcalc.Print.format_keyword "assert" + (format_expr decl_ctx ~debug) + (expr, Pos.get_position stmt) | SSwitch (e_switch, enum, arms) -> - Format.fprintf fmt "@[%a @[%a@]%a@]%a" - Dcalc.Print.format_keyword "switch" - (format_expr decl_ctx ~debug) - e_switch Dcalc.Print.format_punctuation ":" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt ((case, _), (arm_block, payload_name)) -> - Format.fprintf fmt "%a %a%a@ %a @[%a@ %a@]" - Dcalc.Print.format_punctuation "|" - Dcalc.Print.format_enum_constructor case - Dcalc.Print.format_punctuation ":" LocalName.format_t - payload_name Dcalc.Print.format_punctuation "→" - (format_block decl_ctx ~debug) - arm_block)) - (List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms) + Format.fprintf fmt "@[%a @[%a@]%a@]%a" + Dcalc.Print.format_keyword "switch" + (format_expr decl_ctx ~debug) + e_switch Dcalc.Print.format_punctuation ":" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt ((case, _), (arm_block, payload_name)) -> + Format.fprintf fmt "%a %a%a@ %a @[%a@ %a@]" + Dcalc.Print.format_punctuation "|" + Dcalc.Print.format_enum_constructor case + Dcalc.Print.format_punctuation ":" LocalName.format_t payload_name + Dcalc.Print.format_punctuation "→" + (format_block decl_ctx ~debug) + arm_block)) + (List.combine (Dcalc.Ast.EnumMap.find enum decl_ctx.ctx_enums) arms) and format_block (decl_ctx : Dcalc.Ast.decl_ctx) diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 3860bdf1..5adad7a3 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -26,23 +26,22 @@ let format_lit (fmt : Format.formatter) (l : L.lit Pos.marked) : unit = | LBool true -> Format.fprintf fmt "True" | LBool false -> Format.fprintf fmt "False" | LInt i -> - Format.fprintf fmt "integer_of_string(\"%s\")" - (Runtime.integer_to_string i) + Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i) | LUnit -> Format.fprintf fmt "Unit()" | LRat i -> - Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit - (Pos.same_pos_as (Dcalc.Ast.LRat i) l) + Format.fprintf fmt "decimal_of_string(\"%a\")" Dcalc.Print.format_lit + (Pos.same_pos_as (Dcalc.Ast.LRat i) l) | LMoney e -> - Format.fprintf fmt "money_of_cents_string(\"%s\")" - (Runtime.integer_to_string (Runtime.money_to_cents e)) + Format.fprintf fmt "money_of_cents_string(\"%s\")" + (Runtime.integer_to_string (Runtime.money_to_cents e)) | LDate d -> - Format.fprintf fmt "date_of_numbers(%d,%d,%d)" - (Runtime.integer_to_int (Runtime.year_of_date d)) - (Runtime.integer_to_int (Runtime.month_number_of_date d)) - (Runtime.integer_to_int (Runtime.day_of_month_of_date d)) + Format.fprintf fmt "date_of_numbers(%d,%d,%d)" + (Runtime.integer_to_int (Runtime.year_of_date d)) + (Runtime.integer_to_int (Runtime.month_number_of_date d)) + (Runtime.integer_to_int (Runtime.day_of_month_of_date d)) | LDuration d -> - let years, months, days = Runtime.duration_to_years_months_days d in - Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days + let years, months, days = Runtime.duration_to_years_months_days d in + Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days let format_log_entry (fmt : Format.formatter) (entry : Dcalc.Ast.log_entry) : unit = @@ -115,7 +114,7 @@ let avoid_keywords (s : string) : string = | "except" | "finally" | "for" | "from" | "global" | "if" | "import" | "in" | "is" | "lambda" | "nonlocal" | "not" | "or" | "pass" | "raise" | "return" | "try" | "while" | "with" | "yield" -> - true + true | _ -> false then s ^ "_" else s @@ -128,7 +127,8 @@ let format_struct_name (fmt : Format.formatter) (v : Dcalc.Ast.StructName.t) : (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructName.format_t v)))) let format_struct_field_name - (fmt : Format.formatter) (v : Dcalc.Ast.StructFieldName.t) : unit = + (fmt : Format.formatter) + (v : Dcalc.Ast.StructFieldName.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.StructFieldName.format_t v))) @@ -141,7 +141,8 @@ let format_enum_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumName.t) : unit (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumName.format_t v)))) let format_enum_cons_name - (fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructor.t) : unit = + (fmt : Format.formatter) + (v : Dcalc.Ast.EnumConstructor.t) : unit = Format.fprintf fmt "%s" (avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) @@ -153,7 +154,8 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : unit = let format_typ = format_typ in let format_typ_with_parens - (fmt : Format.formatter) (t : Dcalc.Ast.typ Pos.marked) = + (fmt : Format.formatter) + (t : Dcalc.Ast.typ Pos.marked) = if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t else Format.fprintf fmt "%a" format_typ t in @@ -166,19 +168,19 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : | TLit TDuration -> Format.fprintf fmt "Duration" | TLit TBool -> Format.fprintf fmt "bool" | TTuple (ts, None) -> - Format.fprintf fmt "Tuple[%a]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") - (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) - ts + Format.fprintf fmt "Tuple[%a]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt t -> Format.fprintf fmt "%a" format_typ_with_parens t)) + ts | TTuple (_, Some s) -> Format.fprintf fmt "%a" format_struct_name s - | TEnum ([ _; some_typ ], e) when D.EnumName.compare e L.option_enum = 0 -> - (* We translate the option type with an overloading by Python's [None] *) - Format.fprintf fmt "Optional[%a]" format_typ some_typ + | TEnum ([_; some_typ], e) when D.EnumName.compare e L.option_enum = 0 -> + (* We translate the option type with an overloading by Python's [None] *) + Format.fprintf fmt "Optional[%a]" format_typ some_typ | TEnum (_, e) -> Format.fprintf fmt "%a" format_enum_name e | TArrow (t1, t2) -> - Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 - format_typ_with_parens t2 + Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 + format_typ_with_parens t2 | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 | TAny -> Format.fprintf fmt "Any" @@ -208,25 +210,25 @@ let format_var (fmt : Format.formatter) (v : LocalName.t) : unit = let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with - | None -> - let max_id = - snd - (List.hd - (List.fast_sort - (fun (_, x) (_, y) -> Int.compare y x) - (IntMap.bindings ids))) - in - string_counter_map := - StringMap.add v_str - (IntMap.add hash (max_id + 1) ids) - !string_counter_map; - max_id + 1 - | Some local_id -> local_id) - | None -> + match IntMap.find_opt hash ids with + | None -> + let max_id = + snd + (List.hd + (List.fast_sort + (fun (_, x) (_, y) -> Int.compare y x) + (IntMap.bindings ids))) + in string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; - 0 + StringMap.add v_str + (IntMap.add hash (max_id + 1) ids) + !string_counter_map; + max_id + 1 + | Some local_id -> local_id) + | None -> + string_counter_map := + StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + 0 in if v_str = "_" then Format.fprintf fmt "_" (* special case for the unit pattern *) @@ -249,167 +251,167 @@ let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit | EmptyError -> Format.fprintf fmt "EmptyError" | Crash -> Format.fprintf fmt "Crash" | NoValueProvided -> - let pos = Pos.get_position exc in - Format.fprintf fmt - "NoValueProvided(@[SourcePosition(@[filename=\"%s\",@ \ - start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \ - law_headings=%a)@])@]" - (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) - (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list - (Pos.get_law_info pos) + let pos = Pos.get_position exc in + Format.fprintf fmt + "NoValueProvided(@[SourcePosition(@[filename=\"%s\",@ \ + start_line=%d,@ start_column=%d,@ end_line=%d,@ end_column=%d,@ \ + law_headings=%a)@])@]" + (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) + (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list + (Pos.get_law_info pos) let rec format_expression - (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : - unit = + (ctx : Dcalc.Ast.decl_ctx) + (fmt : Format.formatter) + (e : expr Pos.marked) : unit = match Pos.unmark e with | EVar v -> format_var fmt v | EFunc f -> format_toplevel_name fmt f | EStruct (es, s) -> - Format.fprintf fmt "%a(%a)" format_struct_name s - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt (e, struct_field) -> - Format.fprintf fmt "%a = %a" format_struct_field_name struct_field - (format_expression ctx) e)) - (List.combine es - (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) + Format.fprintf fmt "%a(%a)" format_struct_name s + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt (e, struct_field) -> + Format.fprintf fmt "%a = %a" format_struct_field_name struct_field + (format_expression ctx) e)) + (List.combine es + (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) | EStructFieldAccess (e1, field, _) -> - Format.fprintf fmt "%a.%a" (format_expression ctx) e1 - format_struct_field_name field + Format.fprintf fmt "%a.%a" (format_expression ctx) e1 + format_struct_field_name field | EInj (_, cons, e_name) when D.EnumName.compare e_name L.option_enum = 0 && D.EnumConstructor.compare cons L.none_constr = 0 -> - (* We translate the option type with an overloading by Python's [None] *) - Format.fprintf fmt "None" + (* We translate the option type with an overloading by Python's [None] *) + Format.fprintf fmt "None" | EInj (e, cons, e_name) when D.EnumName.compare e_name L.option_enum = 0 && D.EnumConstructor.compare cons L.some_constr = 0 -> - (* We translate the option type with an overloading by Python's [None] *) - format_expression ctx fmt e + (* We translate the option type with an overloading by Python's [None] *) + format_expression ctx fmt e | EInj (e, cons, enum_name) -> - Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name - format_enum_name enum_name format_enum_cons_name cons - (format_expression ctx) e + Format.fprintf fmt "%a(%a_Code.%a,@ %a)" format_enum_name enum_name + format_enum_name enum_name format_enum_cons_name cons + (format_expression ctx) e | EArray es -> - Format.fprintf fmt "[%a]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e)) - es + Format.fprintf fmt "[%a]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e)) + es | ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e) | EApp - ( (EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), - [ arg1; arg2 ] ) -> - Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) - (format_expression ctx) arg1 (format_expression ctx) arg2 - | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop - (op, Pos.no_pos) (format_expression ctx) arg2 - | EApp - ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ]) + ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [arg1; arg2]) + -> + Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos) + (format_expression ctx) arg1 (format_expression ctx) arg2 + | EApp ((EOp (Binop op), _), [arg1; arg2]) -> + Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop + (op, Pos.no_pos) (format_expression ctx) arg2 + | EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [f]), _), [arg]) when !Cli.trace_flag -> - Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info - (format_expression ctx) f (format_expression ctx) arg - | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ]) + Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info + (format_expression ctx) f (format_expression ctx) arg + | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [arg1]) when !Cli.trace_flag -> - Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info - (format_expression ctx) arg1 - | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ]) + Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info + (format_expression ctx) arg1 + | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [arg1]) when !Cli.trace_flag -> - Format.fprintf fmt - "log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \ - start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)" - (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) - (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list - (Pos.get_law_info pos) (format_expression ctx) arg1 - | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ]) + Format.fprintf fmt + "log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \ + start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)" + (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) + (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list + (Pos.get_law_info pos) (format_expression ctx) arg1 + | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [arg1]) when !Cli.trace_flag -> - Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info - (format_expression ctx) arg1 - | EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) -> - Format.fprintf fmt "%a" (format_expression ctx) arg1 - | EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [ arg1 ]) -> - Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos) - (format_expression ctx) arg1 - | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos) - (format_expression ctx) arg1 + Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info + (format_expression ctx) arg1 + | EApp ((EOp (Unop (D.Log _)), _), [arg1]) -> + Format.fprintf fmt "%a" (format_expression ctx) arg1 + | EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) -> + Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos) + (format_expression ctx) arg1 + | EApp ((EOp (Unop op), _), [arg1]) -> + Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos) + (format_expression ctx) arg1 | EApp (f, args) -> - Format.fprintf fmt "%a(@[%a)@]" (format_expression ctx) f - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (format_expression ctx)) - args + Format.fprintf fmt "%a(@[%a)@]" (format_expression ctx) f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (format_expression ctx)) + args | EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) let rec format_statement - (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s : stmt Pos.marked) : - unit = + (ctx : Dcalc.Ast.decl_ctx) + (fmt : Format.formatter) + (s : stmt Pos.marked) : unit = match Pos.unmark s with | SInnerFuncDef (name, { func_params; func_body }) -> - Format.fprintf fmt "@[def %a(%a):@\n%a@]" format_var - (Pos.unmark name) - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") - (fun fmt (var, typ) -> - Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ - typ)) - func_params (format_block ctx) func_body + Format.fprintf fmt "@[def %a(%a):@\n%a@]" format_var + (Pos.unmark name) + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (var, typ) -> + Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ typ)) + func_params (format_block ctx) func_body | SLocalDecl _ -> - assert false (* We don't need to declare variables in Python *) + assert false (* We don't need to declare variables in Python *) | SLocalDef (v, e) -> - Format.fprintf fmt "@[%a = %a@]" format_var (Pos.unmark v) - (format_expression ctx) e + Format.fprintf fmt "@[%a = %a@]" format_var (Pos.unmark v) + (format_expression ctx) e | STryExcept (try_b, except, catch_b) -> - Format.fprintf fmt "@[try:@\n%a@]@\n@[except %a:@\n%a@]" - (format_block ctx) try_b format_exception (except, Pos.no_pos) - (format_block ctx) catch_b + Format.fprintf fmt "@[try:@\n%a@]@\n@[except %a:@\n%a@]" + (format_block ctx) try_b format_exception (except, Pos.no_pos) + (format_block ctx) catch_b | SRaise except -> - Format.fprintf fmt "@[raise %a@]" format_exception - (except, Pos.get_position s) + Format.fprintf fmt "@[raise %a@]" format_exception + (except, Pos.get_position s) | SIfThenElse (cond, b1, b2) -> - Format.fprintf fmt "@[if %a:@\n%a@]@\n@[else:@\n%a@]" - (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 - | SSwitch (e1, e_name, [ (case_none, _); (case_some, case_some_var) ]) + Format.fprintf fmt "@[if %a:@\n%a@]@\n@[else:@\n%a@]" + (format_expression ctx) cond (format_block ctx) b1 (format_block ctx) b2 + | SSwitch (e1, e_name, [(case_none, _); (case_some, case_some_var)]) when D.EnumName.compare e_name L.option_enum = 0 -> - (* We translate the option type with an overloading by Python's [None] *) - let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in - Format.fprintf fmt - "%a = %a@\n\ - @[if %a is None:@\n\ - %a@]@\n\ - @[else:@\n\ - %a = %a@\n\ - %a@]" - format_var tmp_var (format_expression ctx) e1 format_var tmp_var - (format_block ctx) case_none format_var case_some_var format_var tmp_var - (format_block ctx) case_some + (* We translate the option type with an overloading by Python's [None] *) + let tmp_var = LocalName.fresh ("perhaps_none_arg", Pos.no_pos) in + Format.fprintf fmt + "%a = %a@\n\ + @[if %a is None:@\n\ + %a@]@\n\ + @[else:@\n\ + %a = %a@\n\ + %a@]" + format_var tmp_var (format_expression ctx) e1 format_var tmp_var + (format_block ctx) case_none format_var case_some_var format_var tmp_var + (format_block ctx) case_some | SSwitch (e1, e_name, cases) -> - let cases = - List.map2 - (fun (x, y) (cons, _) -> (x, y, cons)) - cases - (D.EnumMap.find e_name ctx.ctx_enums) - in - let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in - Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var - (format_expression ctx) e1 - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") - (fun fmt (case_block, payload_var, cons_name) -> - Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var format_enum_name e_name format_enum_cons_name - cons_name format_var payload_var format_var tmp_var - (format_block ctx) case_block)) + let cases = + List.map2 + (fun (x, y) (cons, _) -> x, y, cons) cases + (D.EnumMap.find e_name ctx.ctx_enums) + in + let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in + Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var + (format_expression ctx) e1 + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") + (fun fmt (case_block, payload_var, cons_name) -> + Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" + format_var tmp_var format_enum_name e_name format_enum_cons_name + cons_name format_var payload_var format_var tmp_var + (format_block ctx) case_block)) + cases | SReturn e1 -> - Format.fprintf fmt "@[return %a@]" (format_expression ctx) - (e1, Pos.get_position s) + Format.fprintf fmt "@[return %a@]" (format_expression ctx) + (e1, Pos.get_position s) | SAssert e1 -> - Format.fprintf fmt "@[assert %a@]" (format_expression ctx) - (e1, Pos.get_position s) + Format.fprintf fmt "@[assert %a@]" (format_expression ctx) + (e1, Pos.get_position s) and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block) : unit = @@ -506,7 +508,7 @@ let format_ctx ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun _fmt (i, enum_cons, enum_cons_type) -> Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i)) - (List.mapi (fun i (x, y) -> (i, x, y)) enum_cons) + (List.mapi (fun i (x, y) -> i, x, y) enum_cons) format_enum_name enum_name format_enum_name enum_name format_enum_name enum_name in @@ -531,11 +533,11 @@ let format_ctx (fun struct_or_enum -> match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> - Format.fprintf fmt "%a@\n@\n" format_struct_decl - (s, Dcalc.Ast.StructMap.find s ctx.Dcalc.Ast.ctx_structs) + Format.fprintf fmt "%a@\n@\n" format_struct_decl + (s, Dcalc.Ast.StructMap.find s ctx.Dcalc.Ast.ctx_structs) | Scopelang.Dependency.TVertex.Enum e -> - Format.fprintf fmt "%a@\n@\n" format_enum_decl - (e, Dcalc.Ast.EnumMap.find e ctx.Dcalc.Ast.ctx_enums)) + Format.fprintf fmt "%a@\n@\n" format_enum_decl + (e, Dcalc.Ast.EnumMap.find e ctx.Dcalc.Ast.ctx_enums)) (type_ordering @ scope_structs) let format_program diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 5793112f..13dce2d3 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -60,12 +60,12 @@ Set.Make (struct type t = location Pos.marked let compare x y = - match (Pos.unmark x, Pos.unmark y) with + match Pos.unmark x, Pos.unmark y with | ScopeVar (vx, _), ScopeVar (vy, _) -> ScopeVar.compare vx vy | ( SubScopeVar (_, (xsubindex, _), (xsubvar, _)), SubScopeVar (_, (ysubindex, _), (ysubvar, _)) ) -> - let c = SubScopeName.compare xsubindex ysubindex in - if c = 0 then ScopeVar.compare xsubvar ysubvar else c + let c = SubScopeName.compare xsubindex ysubindex in + if c = 0 then ScopeVar.compare xsubvar ysubvar else c | ScopeVar _, SubScopeVar _ -> -1 | SubScopeVar _, ScopeVar _ -> 1 end) @@ -101,34 +101,34 @@ let rec locations_used (e : expr Pos.marked) : LocationSet.t = | ELocation l -> LocationSet.singleton (l, Pos.get_position e) | EVar _ | ELit _ | EOp _ -> LocationSet.empty | EAbs ((binder, _), _) -> - let _, body = Bindlib.unmbind binder in - locations_used body + let _, body = Bindlib.unmbind binder in + locations_used body | EStruct (_, es) -> - StructFieldMap.fold - (fun _ e' acc -> LocationSet.union acc (locations_used e')) - es LocationSet.empty + StructFieldMap.fold + (fun _ e' acc -> LocationSet.union acc (locations_used e')) + es LocationSet.empty | EStructAccess (e1, _, _) -> locations_used e1 | EEnumInj (e1, _, _) -> locations_used e1 | EMatch (e1, _, es) -> - EnumConstructorMap.fold - (fun _ e' acc -> LocationSet.union acc (locations_used e')) - es (locations_used e1) + EnumConstructorMap.fold + (fun _ e' acc -> LocationSet.union acc (locations_used e')) + es (locations_used e1) | EApp (e1, args) -> - List.fold_left - (fun acc arg -> LocationSet.union (locations_used arg) acc) - (locations_used e1) args + List.fold_left + (fun acc arg -> LocationSet.union (locations_used arg) acc) + (locations_used e1) args | EIfThenElse (e1, e2, e3) -> - LocationSet.union (locations_used e1) - (LocationSet.union (locations_used e2) (locations_used e3)) + LocationSet.union (locations_used e1) + (LocationSet.union (locations_used e2) (locations_used e3)) | EDefault (excepts, just, cons) -> - List.fold_left - (fun acc except -> LocationSet.union (locations_used except) acc) - (LocationSet.union (locations_used just) (locations_used cons)) - excepts + List.fold_left + (fun acc except -> LocationSet.union (locations_used except) acc) + (LocationSet.union (locations_used just) (locations_used cons)) + excepts | EArray es -> - List.fold_left - (fun acc e' -> LocationSet.union acc (locations_used e')) - LocationSet.empty es + List.fold_left + (fun acc e' -> LocationSet.union acc (locations_used e')) + LocationSet.empty es | ErrorOnEmpty e' -> locations_used e' type io_input = NoInput | OnlyInput | Reentrant @@ -168,7 +168,7 @@ end type vars = expr Bindlib.mvar let make_var ((x, pos) : Var.t Pos.marked) : expr Pos.marked Bindlib.box = - Bindlib.box_apply (fun v -> (v, pos)) (Bindlib.box_var x) + Bindlib.box_apply (fun v -> v, pos) (Bindlib.box_var x) let make_abs (xs : vars) @@ -177,14 +177,14 @@ let make_abs (taus : typ Pos.marked list) (pos : Pos.t) : expr Pos.marked Bindlib.box = Bindlib.box_apply - (fun b -> (EAbs ((b, pos_binder), taus), pos)) + (fun b -> EAbs ((b, pos_binder), taus), pos) (Bindlib.bind_mvar xs e) let make_app (e : expr Pos.marked Bindlib.box) (u : expr Pos.marked Bindlib.box list) (pos : Pos.t) : expr Pos.marked Bindlib.box = - Bindlib.box_apply2 (fun e u -> (EApp (e, u), pos)) e (Bindlib.box_list u) + Bindlib.box_apply2 (fun e u -> EApp (e, u), pos) e (Bindlib.box_list u) let make_let_in (x : Var.t) @@ -192,13 +192,11 @@ let make_let_in (e1 : expr Pos.marked Bindlib.box) (e2 : expr Pos.marked Bindlib.box) : expr Pos.marked Bindlib.box = Bindlib.box_apply2 - (fun e u -> (EApp (e, u), Pos.get_position (Bindlib.unbox e2))) - (make_abs - (Array.of_list [ x ]) - e2 + (fun e u -> EApp (e, u), Pos.get_position (Bindlib.unbox e2)) + (make_abs (Array.of_list [x]) e2 (Pos.get_position (Bindlib.unbox e2)) - [ tau ] + [tau] (Pos.get_position (Bindlib.unbox e2))) - (Bindlib.box_list [ e1 ]) + (Bindlib.box_list [e1]) module VarMap = Map.Make (Var) diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index 37d68123..99b6d0ac 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -59,17 +59,17 @@ let build_program_dep_graph (prgm : Ast.program) : SDependencies.t = match r with | Ast.Definition _ | Ast.Assertion _ -> acc | Ast.Call (subscope, subindex) -> - if subscope = scope_name then - Errors.raise_spanned_error - (Pos.get_position - (Ast.ScopeName.get_info scope.Ast.scope_decl_name)) - "The scope %a is calling into itself as a subscope, which \ - is forbidden since Catala does not provide recursion" - Ast.ScopeName.format_t scope.Ast.scope_decl_name - else - Ast.ScopeMap.add subscope - (Pos.get_position (Ast.SubScopeName.get_info subindex)) - acc) + if subscope = scope_name then + Errors.raise_spanned_error + (Pos.get_position + (Ast.ScopeName.get_info scope.Ast.scope_decl_name)) + "The scope %a is calling into itself as a subscope, which is \ + forbidden since Catala does not provide recursion" + Ast.ScopeName.format_t scope.Ast.scope_decl_name + else + Ast.ScopeMap.add subscope + (Pos.get_position (Ast.SubScopeName.get_info subindex)) + acc) Ast.ScopeMap.empty scope.Ast.scope_decl_rules in Ast.ScopeMap.fold @@ -123,14 +123,14 @@ module TVertex = struct | Enum x -> Ast.EnumName.hash x let compare x y = - match (x, y) with + match x, y with | Struct x, Struct y -> Ast.StructName.compare x y | Enum x, Enum y -> Ast.EnumName.compare x y | Struct _, Enum _ -> 1 | Enum _, Struct _ -> -1 let equal x y = - match (x, y) with + match x, y with | Struct x, Struct y -> Ast.StructName.compare x y = 0 | Enum x, Enum y -> Ast.EnumName.compare x y = 0 | _ -> false @@ -170,9 +170,9 @@ let rec get_structs_or_enums_in_type (t : Ast.typ Pos.marked) : TVertexSet.t = | Ast.TStruct s -> TVertexSet.singleton (TVertex.Struct s) | Ast.TEnum e -> TVertexSet.singleton (TVertex.Enum e) | Ast.TArrow (t1, t2) -> - TVertexSet.union - (get_structs_or_enums_in_type t1) - (get_structs_or_enums_in_type t2) + TVertexSet.union + (get_structs_or_enums_in_type t1) + (get_structs_or_enums_in_type t2) | Ast.TLit _ | Ast.TAny -> TVertexSet.empty | Ast.TArray t1 -> get_structs_or_enums_in_type (Pos.same_pos_as t1 t) @@ -242,7 +242,7 @@ let check_type_cycles (structs : Ast.struct_ctx) (enums : Ast.enum_ctx) : (List.map (fun v -> let var_str, var_info = - (Format.asprintf "%a" TVertex.format_t v, TVertex.get_info v) + Format.asprintf "%a" TVertex.format_t v, TVertex.get_info v in let succs = TDependencies.succ_e g v in let _, edge_pos, succ = diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index f1b9b35b..ede11881 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -27,8 +27,8 @@ let format_location (fmt : Format.formatter) (l : location) : unit = match l with | ScopeVar v -> Format.fprintf fmt "%a" ScopeVar.format_t (Pos.unmark v) | SubScopeVar (_, subindex, subvar) -> - Format.fprintf fmt "%a.%a" SubScopeName.format_t (Pos.unmark subindex) - ScopeVar.format_t (Pos.unmark subvar) + Format.fprintf fmt "%a.%a" SubScopeName.format_t (Pos.unmark subindex) + ScopeVar.format_t (Pos.unmark subvar) let typ_needs_parens (e : typ Pos.marked) : bool = match Pos.unmark e with TArrow _ -> true | _ -> false @@ -45,16 +45,17 @@ let rec format_typ (fmt : Format.formatter) (typ : typ Pos.marked) : unit = | TStruct s -> Format.fprintf fmt "%a" Ast.StructName.format_t s | TEnum e -> Format.fprintf fmt "%a" Ast.EnumName.format_t e | TArrow (t1, t2) -> - Format.fprintf fmt "@[%a %a@ %a@]" format_typ_with_parens t1 - Dcalc.Print.format_operator "→" format_typ t2 + Format.fprintf fmt "@[%a %a@ %a@]" format_typ_with_parens t1 + Dcalc.Print.format_operator "→" format_typ t2 | TArray t1 -> - Format.fprintf fmt "@[%a@ %a@]" format_typ (Pos.same_pos_as t1 typ) - Dcalc.Print.format_base_type "array" + Format.fprintf fmt "@[%a@ %a@]" format_typ (Pos.same_pos_as t1 typ) + Dcalc.Print.format_base_type "array" | TAny -> Format.fprintf fmt "any" let rec format_expr - ?(debug : bool = false) (fmt : Format.formatter) (e : expr Pos.marked) : - unit = + ?(debug : bool = false) + (fmt : Format.formatter) + (e : expr Pos.marked) : unit = let format_expr = format_expr ~debug in let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = if needs_parens e then Format.fprintf fmt "(%a)" format_expr e @@ -64,115 +65,110 @@ let rec format_expr | ELocation l -> Format.fprintf fmt "%a" format_location l | EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v) | ELit l -> - Format.fprintf fmt "%a" Dcalc.Print.format_lit (Pos.same_pos_as l e) + Format.fprintf fmt "%a" Dcalc.Print.format_lit (Pos.same_pos_as l e) | EStruct (name, fields) -> - Format.fprintf fmt " @[%a@ %a@ %a@ %a@]" Ast.StructName.format_t - name Dcalc.Print.format_punctuation "{" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> - Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") - (fun fmt (field_name, field_expr) -> - Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation - "\"" Ast.StructFieldName.format_t field_name - Dcalc.Print.format_punctuation "\"" - Dcalc.Print.format_punctuation "=" format_expr field_expr)) - (Ast.StructFieldMap.bindings fields) - Dcalc.Print.format_punctuation "}" + Format.fprintf fmt " @[%a@ %a@ %a@ %a@]" Ast.StructName.format_t name + Dcalc.Print.format_punctuation "{" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> + Format.fprintf fmt "%a@ " Dcalc.Print.format_punctuation ";") + (fun fmt (field_name, field_expr) -> + Format.fprintf fmt "%a%a%a%a@ %a" Dcalc.Print.format_punctuation "\"" + Ast.StructFieldName.format_t field_name + Dcalc.Print.format_punctuation "\"" Dcalc.Print.format_punctuation + "=" format_expr field_expr)) + (Ast.StructFieldMap.bindings fields) + Dcalc.Print.format_punctuation "}" | EStructAccess (e1, field, _) -> - Format.fprintf fmt "%a%a%a%a%a" format_expr e1 - Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" - Ast.StructFieldName.format_t field Dcalc.Print.format_punctuation "\"" + Format.fprintf fmt "%a%a%a%a%a" format_expr e1 + Dcalc.Print.format_punctuation "." Dcalc.Print.format_punctuation "\"" + Ast.StructFieldName.format_t field Dcalc.Print.format_punctuation "\"" | EEnumInj (e1, cons, _) -> - Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr - e1 + Format.fprintf fmt "%a@ %a" Ast.EnumConstructor.format_t cons format_expr e1 | EMatch (e1, _, cases) -> - Format.fprintf fmt "@[%a@ @[%a@]@ %a@ %a@]" - Dcalc.Print.format_keyword "match" format_expr e1 - Dcalc.Print.format_keyword "with" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt (cons_name, case_expr) -> - Format.fprintf fmt "@[%a %a@ %a@ %a@]" - Dcalc.Print.format_punctuation "|" - Dcalc.Print.format_enum_constructor cons_name - Dcalc.Print.format_punctuation "→" format_expr case_expr)) - (Ast.EnumConstructorMap.bindings cases) + Format.fprintf fmt "@[%a@ @[%a@]@ %a@ %a@]" + Dcalc.Print.format_keyword "match" format_expr e1 + Dcalc.Print.format_keyword "with" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt (cons_name, case_expr) -> + Format.fprintf fmt "@[%a %a@ %a@ %a@]" + Dcalc.Print.format_punctuation "|" + Dcalc.Print.format_enum_constructor cons_name + Dcalc.Print.format_punctuation "→" format_expr case_expr)) + (Ast.EnumConstructorMap.bindings cases) | EApp ((EAbs ((binder, _), taus), _), args) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - let xs_tau_arg = - List.map2 (fun (x, tau) arg -> (x, tau, arg)) xs_tau args - in - Format.fprintf fmt "@[%a%a@]" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt " ") - (fun fmt (x, tau, arg) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@ %a@\n@]" - Dcalc.Print.format_keyword "let" format_var x - Dcalc.Print.format_punctuation ":" format_typ tau - Dcalc.Print.format_punctuation "=" format_expr arg - Dcalc.Print.format_keyword "in")) - xs_tau_arg format_expr body + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in + Format.fprintf fmt "@[%a%a@]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " ") + (fun fmt (x, tau, arg) -> + Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@ %a@\n@]" + Dcalc.Print.format_keyword "let" format_var x + Dcalc.Print.format_punctuation ":" format_typ tau + Dcalc.Print.format_punctuation "=" format_expr arg + Dcalc.Print.format_keyword "in")) + xs_tau_arg format_expr body | EAbs ((binder, _), taus) -> - let xs, body = Bindlib.unmbind binder in - let xs_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - Format.fprintf fmt "@[%a@ %a@ %a@ %a@]" - Dcalc.Print.format_punctuation "λ" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt " ") - (fun fmt (x, tau) -> - Format.fprintf fmt "@[%a%a%a@ %a%a@]" - Dcalc.Print.format_punctuation "(" format_var x - Dcalc.Print.format_punctuation ":" format_typ tau - Dcalc.Print.format_punctuation ")")) - xs_tau Dcalc.Print.format_punctuation "→" format_expr body - | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 - Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2 - | EApp ((EOp (Unop (Log _)), _), [ arg1 ]) when not debug -> - format_expr fmt arg1 - | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos) - format_with_parens arg1 + let xs, body = Bindlib.unmbind binder in + let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in + Format.fprintf fmt "@[%a@ %a@ %a@ %a@]" + Dcalc.Print.format_punctuation "λ" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " ") + (fun fmt (x, tau) -> + Format.fprintf fmt "@[%a%a%a@ %a%a@]" Dcalc.Print.format_punctuation + "(" format_var x Dcalc.Print.format_punctuation ":" format_typ tau + Dcalc.Print.format_punctuation ")")) + xs_tau Dcalc.Print.format_punctuation "→" format_expr body + | EApp ((EOp (Binop op), _), [arg1; arg2]) -> + Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 + Dcalc.Print.format_binop (op, Pos.no_pos) format_with_parens arg2 + | EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug -> + format_expr fmt arg1 + | EApp ((EOp (Unop op), _), [arg1]) -> + Format.fprintf fmt "@[%a@ %a@]" Dcalc.Print.format_unop (op, Pos.no_pos) + format_with_parens arg1 | EApp (f, args) -> - Format.fprintf fmt "@[%a@ %a@]" format_expr f - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") - format_with_parens) - args + Format.fprintf fmt "@[%a@ %a@]" format_expr f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") + format_with_parens) + args | EIfThenElse (e1, e2, e3) -> - Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@]" - Dcalc.Print.format_keyword "if" format_expr e1 - Dcalc.Print.format_keyword "then" format_expr e2 - Dcalc.Print.format_keyword "else" format_expr e3 + Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@ %a@]" + Dcalc.Print.format_keyword "if" format_expr e1 Dcalc.Print.format_keyword + "then" format_expr e2 Dcalc.Print.format_keyword "else" format_expr e3 | EOp (Ternop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_ternop (op, Pos.no_pos) | EOp (Binop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_binop (op, Pos.no_pos) | EOp (Unop op) -> - Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) + Format.fprintf fmt "%a" Dcalc.Print.format_unop (op, Pos.no_pos) | EDefault (excepts, just, cons) -> - if List.length excepts = 0 then - Format.fprintf fmt "@[%a%a %a@ %a%a@]" Dcalc.Print.format_punctuation - "⟨" format_expr just Dcalc.Print.format_punctuation "⊢" format_expr - cons Dcalc.Print.format_punctuation "⟩" - else - Format.fprintf fmt "@[%a%a@ %a@ %a %a@ %a%a@]" - Dcalc.Print.format_punctuation "⟨" - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - format_expr) - excepts Dcalc.Print.format_punctuation "|" format_expr just - Dcalc.Print.format_punctuation "⊢" format_expr cons - Dcalc.Print.format_punctuation "⟩" - | ErrorOnEmpty e' -> - Format.fprintf fmt "error_empty@ %a" format_with_parens e' - | EArray es -> - Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "[" + if List.length excepts = 0 then + Format.fprintf fmt "@[%a%a %a@ %a%a@]" Dcalc.Print.format_punctuation "⟨" + format_expr just Dcalc.Print.format_punctuation "⊢" format_expr cons + Dcalc.Print.format_punctuation "⟩" + else + Format.fprintf fmt "@[%a%a@ %a@ %a %a@ %a%a@]" + Dcalc.Print.format_punctuation "⟨" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Dcalc.Print.format_punctuation fmt ";") - (fun fmt e -> Format.fprintf fmt "@[%a@]" format_expr e)) - es Dcalc.Print.format_punctuation "]" + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + format_expr) + excepts Dcalc.Print.format_punctuation "|" format_expr just + Dcalc.Print.format_punctuation "⊢" format_expr cons + Dcalc.Print.format_punctuation "⟩" + | ErrorOnEmpty e' -> + Format.fprintf fmt "error_empty@ %a" format_with_parens e' + | EArray es -> + Format.fprintf fmt "%a%a%a" Dcalc.Print.format_punctuation "[" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Dcalc.Print.format_punctuation fmt ";") + (fun fmt e -> Format.fprintf fmt "@[%a@]" format_expr e)) + es Dcalc.Print.format_punctuation "]" let format_struct (fmt : Format.formatter) @@ -233,36 +229,38 @@ let format_scope (fun fmt rule -> match rule with | Definition (loc, typ, _, e) -> - Format.fprintf fmt "@[%a %a %a %a %a@ %a@]" - Dcalc.Print.format_keyword "let" format_location (Pos.unmark loc) - Dcalc.Print.format_punctuation ":" format_typ typ - Dcalc.Print.format_punctuation "=" - (fun fmt e -> - match Pos.unmark loc with - | SubScopeVar _ -> format_expr fmt e - | ScopeVar v -> ( - match - Pos.unmark - (snd (ScopeVarMap.find (Pos.unmark v) decl.scope_sig)) - .io_input - with - | Reentrant -> - Format.fprintf fmt "%a@ %a" Dcalc.Print.format_operator - "reentrant or by default" (format_expr ~debug) e - | _ -> Format.fprintf fmt "%a" (format_expr ~debug) e)) - e + Format.fprintf fmt "@[%a %a %a %a %a@ %a@]" + Dcalc.Print.format_keyword "let" format_location (Pos.unmark loc) + Dcalc.Print.format_punctuation ":" format_typ typ + Dcalc.Print.format_punctuation "=" + (fun fmt e -> + match Pos.unmark loc with + | SubScopeVar _ -> format_expr fmt e + | ScopeVar v -> ( + match + Pos.unmark + (snd (ScopeVarMap.find (Pos.unmark v) decl.scope_sig)) + .io_input + with + | Reentrant -> + Format.fprintf fmt "%a@ %a" Dcalc.Print.format_operator + "reentrant or by default" (format_expr ~debug) e + | _ -> Format.fprintf fmt "%a" (format_expr ~debug) e)) + e | Assertion e -> - Format.fprintf fmt "%a %a" Dcalc.Print.format_keyword "assert" - (format_expr ~debug) e + Format.fprintf fmt "%a %a" Dcalc.Print.format_keyword "assert" + (format_expr ~debug) e | Call (scope_name, subscope_name) -> - Format.fprintf fmt "%a %a%a%a%a" Dcalc.Print.format_keyword "call" - ScopeName.format_t scope_name Dcalc.Print.format_punctuation "[" - SubScopeName.format_t subscope_name - Dcalc.Print.format_punctuation "]")) + Format.fprintf fmt "%a %a%a%a%a" Dcalc.Print.format_keyword "call" + ScopeName.format_t scope_name Dcalc.Print.format_punctuation "[" + SubScopeName.format_t subscope_name Dcalc.Print.format_punctuation + "]")) decl.scope_decl_rules let format_program - ?(debug : bool = false) (fmt : Format.formatter) (p : program) : unit = + ?(debug : bool = false) + (fmt : Format.formatter) + (p : program) : unit = Format.fprintf fmt "%a%a%a%a%a" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") diff --git a/compiler/scopelang/scope_to_dcalc.ml b/compiler/scopelang/scope_to_dcalc.ml index 1076638f..e9596dba 100644 --- a/compiler/scopelang/scope_to_dcalc.ml +++ b/compiler/scopelang/scope_to_dcalc.ml @@ -66,17 +66,17 @@ let rec translate_typ (ctx : ctx) (t : Ast.typ Pos.marked) : (match Pos.unmark t with | Ast.TLit l -> Dcalc.Ast.TLit l | Ast.TArrow (t1, t2) -> - Dcalc.Ast.TArrow (translate_typ ctx t1, translate_typ ctx t2) + Dcalc.Ast.TArrow (translate_typ ctx t1, translate_typ ctx t2) | Ast.TStruct s_uid -> - let s_fields = Ast.StructMap.find s_uid ctx.structs in - Dcalc.Ast.TTuple - (List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid) + let s_fields = Ast.StructMap.find s_uid ctx.structs in + Dcalc.Ast.TTuple + (List.map (fun (_, t) -> translate_typ ctx t) s_fields, Some s_uid) | Ast.TEnum e_uid -> - let e_cases = Ast.EnumMap.find e_uid ctx.enums in - Dcalc.Ast.TEnum - (List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid) + let e_cases = Ast.EnumMap.find e_uid ctx.enums in + Dcalc.Ast.TEnum + (List.map (fun (_, t) -> translate_typ ctx t) e_cases, e_uid) | Ast.TArray t1 -> - Dcalc.Ast.TArray (translate_typ ctx (Pos.same_pos_as t1 t)) + Dcalc.Ast.TArray (translate_typ ctx (Pos.same_pos_as t1 t)) | Ast.TAny -> Dcalc.Ast.TAny) t @@ -86,14 +86,14 @@ let merge_defaults Dcalc.Ast.expr Pos.marked Bindlib.box = let caller = Dcalc.Ast.make_app caller - [ Bindlib.box (Dcalc.Ast.ELit Dcalc.Ast.LUnit, Pos.no_pos) ] + [Bindlib.box (Dcalc.Ast.ELit Dcalc.Ast.LUnit, Pos.no_pos)] Pos.no_pos in let body = Bindlib.box_apply2 (fun caller callee -> ( Dcalc.Ast.EDefault - ( [ caller ], + ( [caller], (Dcalc.Ast.ELit (Dcalc.Ast.LBool true), Pos.no_pos), callee ), Pos.no_pos )) @@ -111,7 +111,7 @@ let tag_with_log_entry ( Dcalc.Ast.EApp ( ( Dcalc.Ast.EOp (Dcalc.Ast.Unop (Dcalc.Ast.Log (l, markings))), Pos.get_position e ), - [ e ] ), + [e] ), Pos.get_position e )) e @@ -123,215 +123,211 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : | EVar v -> Bindlib.box_var (Ast.VarMap.find (Pos.unmark v) ctx.local_vars) | ELit l -> Bindlib.box (Dcalc.Ast.ELit l) | EStruct (struct_name, e_fields) -> - let struct_sig = Ast.StructMap.find struct_name ctx.structs in - let d_fields, remaining_e_fields = - List.fold_right - (fun (field_name, _) (d_fields, e_fields) -> - let field_e = Ast.StructFieldMap.find field_name e_fields in - let field_d = translate_expr ctx field_e in - ( field_d :: d_fields, - Ast.StructFieldMap.remove field_name e_fields )) - struct_sig ([], e_fields) - in - if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then - Errors.raise_spanned_error (Pos.get_position e) - "The fields \"%a\" do not belong to the structure %a" - Ast.StructName.format_t struct_name - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") - (fun fmt (field_name, _) -> - Format.fprintf fmt "%a" Ast.StructFieldName.format_t field_name)) - (Ast.StructFieldMap.bindings remaining_e_fields) - else - Bindlib.box_apply - (fun d_fields -> Dcalc.Ast.ETuple (d_fields, Some struct_name)) - (Bindlib.box_list d_fields) + let struct_sig = Ast.StructMap.find struct_name ctx.structs in + let d_fields, remaining_e_fields = + List.fold_right + (fun (field_name, _) (d_fields, e_fields) -> + let field_e = Ast.StructFieldMap.find field_name e_fields in + let field_d = translate_expr ctx field_e in + field_d :: d_fields, Ast.StructFieldMap.remove field_name e_fields) + struct_sig ([], e_fields) + in + if Ast.StructFieldMap.cardinal remaining_e_fields > 0 then + Errors.raise_spanned_error (Pos.get_position e) + "The fields \"%a\" do not belong to the structure %a" + Ast.StructName.format_t struct_name + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (field_name, _) -> + Format.fprintf fmt "%a" Ast.StructFieldName.format_t field_name)) + (Ast.StructFieldMap.bindings remaining_e_fields) + else + Bindlib.box_apply + (fun d_fields -> Dcalc.Ast.ETuple (d_fields, Some struct_name)) + (Bindlib.box_list d_fields) | EStructAccess (e1, field_name, struct_name) -> - let struct_sig = Ast.StructMap.find struct_name ctx.structs in - let _, field_index = - try - List.assoc field_name - (List.mapi (fun i (x, y) -> (x, (y, i))) struct_sig) - with Not_found -> - Errors.raise_spanned_error (Pos.get_position e) - "The field \"%a\" does not belong to the structure %a" - Ast.StructFieldName.format_t field_name Ast.StructName.format_t - struct_name - in - let e1 = translate_expr ctx e1 in - Bindlib.box_apply - (fun e1 -> - Dcalc.Ast.ETupleAccess - ( e1, - field_index, - Some struct_name, - List.map (fun (_, t) -> translate_typ ctx t) struct_sig )) - e1 - | EEnumInj (e1, constructor, enum_name) -> - let enum_sig = Ast.EnumMap.find enum_name ctx.enums in - let _, constructor_index = - try - List.assoc constructor - (List.mapi (fun i (x, y) -> (x, (y, i))) enum_sig) - with Not_found -> - Errors.raise_spanned_error (Pos.get_position e) - "The constructor \"%a\" does not belong to the enum %a" - Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t - enum_name - in - let e1 = translate_expr ctx e1 in - Bindlib.box_apply - (fun e1 -> - Dcalc.Ast.EInj - ( e1, - constructor_index, - enum_name, - List.map (fun (_, t) -> translate_typ ctx t) enum_sig )) - e1 - | EMatch (e1, enum_name, cases) -> - let enum_sig = Ast.EnumMap.find enum_name ctx.enums in - let d_cases, remaining_e_cases = - List.fold_right - (fun (constructor, _) (d_cases, e_cases) -> - let case_e = - try Ast.EnumConstructorMap.find constructor e_cases - with Not_found -> - Errors.raise_spanned_error (Pos.get_position e) - "The constructor %a of enum %a is missing from this \ - pattern matching" - Ast.EnumConstructor.format_t constructor - Ast.EnumName.format_t enum_name - in - let case_d = translate_expr ctx case_e in - ( case_d :: d_cases, - Ast.EnumConstructorMap.remove constructor e_cases )) - enum_sig ([], cases) - in - if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then - Errors.raise_spanned_error (Pos.get_position e) - "Patter matching is incomplete for enum %a: missing cases %a" - Ast.EnumName.format_t enum_name - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") - (fun fmt (case_name, _) -> - Format.fprintf fmt "%a" Ast.EnumConstructor.format_t case_name)) - (Ast.EnumConstructorMap.bindings remaining_e_cases) - else - let e1 = translate_expr ctx e1 in - Bindlib.box_apply2 - (fun d_fields e1 -> Dcalc.Ast.EMatch (e1, d_fields, enum_name)) - (Bindlib.box_list d_cases) e1 - | EApp (e1, args) -> - (* We insert various log calls to record arguments and outputs of - user-defined functions belonging to scopes *) - let e1_func = translate_expr ctx e1 in - let markings l = - match l with - | Ast.ScopeVar (v, _) -> - [ Ast.ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v ] - | Ast.SubScopeVar (s, _, (v, _)) -> - [ Ast.ScopeName.get_info s; Ast.ScopeVar.get_info v ] - in - let e1_func = - match Pos.unmark e1 with - | ELocation l -> - tag_with_log_entry e1_func Dcalc.Ast.BeginCall (markings l) - | _ -> e1_func - in - let new_args = List.map (translate_expr ctx) args in - let new_args = - match (Pos.unmark e1, new_args) with - | ELocation l, [ new_arg ] -> - [ - tag_with_log_entry new_arg (Dcalc.Ast.VarDef Dcalc.Ast.TAny) - (markings l @ [ Pos.same_pos_as "input" e ]); - ] - | _ -> new_args - in - let new_e = - Bindlib.box_apply2 - (fun e' u -> (Dcalc.Ast.EApp (e', u), Pos.get_position e)) - e1_func - (Bindlib.box_list new_args) - in - let new_e = - match Pos.unmark e1 with - | ELocation l -> - tag_with_log_entry - (tag_with_log_entry new_e (Dcalc.Ast.VarDef Dcalc.Ast.TAny) - (markings l @ [ Pos.same_pos_as "output" e ])) - Dcalc.Ast.EndCall (markings l) - | _ -> new_e - in - Bindlib.box_apply Pos.unmark new_e - | EAbs ((binder, pos_binder), typ) -> - let xs, body = Bindlib.unmbind binder in - let new_xs = - Array.map - (fun x -> Dcalc.Ast.Var.make (Bindlib.name_of x, Pos.no_pos)) - xs - in - let both_xs = Array.map2 (fun x new_x -> (x, new_x)) xs new_xs in - let body = - translate_expr - { - ctx with - local_vars = - Array.fold_left - (fun local_vars (x, new_x) -> - Ast.VarMap.add x new_x local_vars) - ctx.local_vars both_xs; - } - body - in - let binder = Bindlib.bind_mvar new_xs body in - Bindlib.box_apply - (fun b -> - Dcalc.Ast.EAbs ((b, pos_binder), List.map (translate_typ ctx) typ)) - binder - | EDefault (excepts, just, cons) -> - Bindlib.box_apply3 - (fun e j c -> Dcalc.Ast.EDefault (e, j, c)) - (Bindlib.box_list (List.map (translate_expr ctx) excepts)) - (translate_expr ctx just) (translate_expr ctx cons) - | ELocation (ScopeVar a) -> - let v, _, _ = Ast.ScopeVarMap.find (Pos.unmark a) ctx.scope_vars in - Bindlib.box_var v - | ELocation (SubScopeVar (_, s, a)) -> ( + let struct_sig = Ast.StructMap.find struct_name ctx.structs in + let _, field_index = try - let v, _, _ = - Ast.ScopeVarMap.find (Pos.unmark a) - (Ast.SubScopeMap.find (Pos.unmark s) ctx.subscope_vars) - in - Bindlib.box_var v + List.assoc field_name + (List.mapi (fun i (x, y) -> x, (y, i)) struct_sig) with Not_found -> - Errors.raise_multispanned_error - [ - (Some "Incriminated variable usage:", Pos.get_position e); - ( Some "Incriminated subscope variable declaration:", - Pos.get_position (Ast.ScopeVar.get_info (Pos.unmark a)) ); - ( Some "Incriminated subscope declaration:", - Pos.get_position (Ast.SubScopeName.get_info (Pos.unmark s)) ); - ] - "The variable %a.%a cannot be used here, as it is not part \ - subscope %a's results. Maybe you forgot to qualify it as an \ - output?" - Ast.SubScopeName.format_t (Pos.unmark s) Ast.ScopeVar.format_t - (Pos.unmark a) Ast.SubScopeName.format_t (Pos.unmark s)) + Errors.raise_spanned_error (Pos.get_position e) + "The field \"%a\" does not belong to the structure %a" + Ast.StructFieldName.format_t field_name Ast.StructName.format_t + struct_name + in + let e1 = translate_expr ctx e1 in + Bindlib.box_apply + (fun e1 -> + Dcalc.Ast.ETupleAccess + ( e1, + field_index, + Some struct_name, + List.map (fun (_, t) -> translate_typ ctx t) struct_sig )) + e1 + | EEnumInj (e1, constructor, enum_name) -> + let enum_sig = Ast.EnumMap.find enum_name ctx.enums in + let _, constructor_index = + try + List.assoc constructor + (List.mapi (fun i (x, y) -> x, (y, i)) enum_sig) + with Not_found -> + Errors.raise_spanned_error (Pos.get_position e) + "The constructor \"%a\" does not belong to the enum %a" + Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t + enum_name + in + let e1 = translate_expr ctx e1 in + Bindlib.box_apply + (fun e1 -> + Dcalc.Ast.EInj + ( e1, + constructor_index, + enum_name, + List.map (fun (_, t) -> translate_typ ctx t) enum_sig )) + e1 + | EMatch (e1, enum_name, cases) -> + let enum_sig = Ast.EnumMap.find enum_name ctx.enums in + let d_cases, remaining_e_cases = + List.fold_right + (fun (constructor, _) (d_cases, e_cases) -> + let case_e = + try Ast.EnumConstructorMap.find constructor e_cases + with Not_found -> + Errors.raise_spanned_error (Pos.get_position e) + "The constructor %a of enum %a is missing from this pattern \ + matching" + Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t + enum_name + in + let case_d = translate_expr ctx case_e in + case_d :: d_cases, Ast.EnumConstructorMap.remove constructor e_cases) + enum_sig ([], cases) + in + if Ast.EnumConstructorMap.cardinal remaining_e_cases > 0 then + Errors.raise_spanned_error (Pos.get_position e) + "Patter matching is incomplete for enum %a: missing cases %a" + Ast.EnumName.format_t enum_name + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (case_name, _) -> + Format.fprintf fmt "%a" Ast.EnumConstructor.format_t case_name)) + (Ast.EnumConstructorMap.bindings remaining_e_cases) + else + let e1 = translate_expr ctx e1 in + Bindlib.box_apply2 + (fun d_fields e1 -> Dcalc.Ast.EMatch (e1, d_fields, enum_name)) + (Bindlib.box_list d_cases) e1 + | EApp (e1, args) -> + (* We insert various log calls to record arguments and outputs of + user-defined functions belonging to scopes *) + let e1_func = translate_expr ctx e1 in + let markings l = + match l with + | Ast.ScopeVar (v, _) -> + [Ast.ScopeName.get_info ctx.scope_name; Ast.ScopeVar.get_info v] + | Ast.SubScopeVar (s, _, (v, _)) -> + [Ast.ScopeName.get_info s; Ast.ScopeVar.get_info v] + in + let e1_func = + match Pos.unmark e1 with + | ELocation l -> + tag_with_log_entry e1_func Dcalc.Ast.BeginCall (markings l) + | _ -> e1_func + in + let new_args = List.map (translate_expr ctx) args in + let new_args = + match Pos.unmark e1, new_args with + | ELocation l, [new_arg] -> + [ + tag_with_log_entry new_arg (Dcalc.Ast.VarDef Dcalc.Ast.TAny) + (markings l @ [Pos.same_pos_as "input" e]); + ] + | _ -> new_args + in + let new_e = + Bindlib.box_apply2 + (fun e' u -> Dcalc.Ast.EApp (e', u), Pos.get_position e) + e1_func + (Bindlib.box_list new_args) + in + let new_e = + match Pos.unmark e1 with + | ELocation l -> + tag_with_log_entry + (tag_with_log_entry new_e (Dcalc.Ast.VarDef Dcalc.Ast.TAny) + (markings l @ [Pos.same_pos_as "output" e])) + Dcalc.Ast.EndCall (markings l) + | _ -> new_e + in + Bindlib.box_apply Pos.unmark new_e + | EAbs ((binder, pos_binder), typ) -> + let xs, body = Bindlib.unmbind binder in + let new_xs = + Array.map + (fun x -> Dcalc.Ast.Var.make (Bindlib.name_of x, Pos.no_pos)) + xs + in + let both_xs = Array.map2 (fun x new_x -> x, new_x) xs new_xs in + let body = + translate_expr + { + ctx with + local_vars = + Array.fold_left + (fun local_vars (x, new_x) -> Ast.VarMap.add x new_x local_vars) + ctx.local_vars both_xs; + } + body + in + let binder = Bindlib.bind_mvar new_xs body in + Bindlib.box_apply + (fun b -> + Dcalc.Ast.EAbs ((b, pos_binder), List.map (translate_typ ctx) typ)) + binder + | EDefault (excepts, just, cons) -> + Bindlib.box_apply3 + (fun e j c -> Dcalc.Ast.EDefault (e, j, c)) + (Bindlib.box_list (List.map (translate_expr ctx) excepts)) + (translate_expr ctx just) (translate_expr ctx cons) + | ELocation (ScopeVar a) -> + let v, _, _ = Ast.ScopeVarMap.find (Pos.unmark a) ctx.scope_vars in + Bindlib.box_var v + | ELocation (SubScopeVar (_, s, a)) -> ( + try + let v, _, _ = + Ast.ScopeVarMap.find (Pos.unmark a) + (Ast.SubScopeMap.find (Pos.unmark s) ctx.subscope_vars) + in + Bindlib.box_var v + with Not_found -> + Errors.raise_multispanned_error + [ + Some "Incriminated variable usage:", Pos.get_position e; + ( Some "Incriminated subscope variable declaration:", + Pos.get_position (Ast.ScopeVar.get_info (Pos.unmark a)) ); + ( Some "Incriminated subscope declaration:", + Pos.get_position (Ast.SubScopeName.get_info (Pos.unmark s)) ); + ] + "The variable %a.%a cannot be used here, as it is not part subscope \ + %a's results. Maybe you forgot to qualify it as an output?" + Ast.SubScopeName.format_t (Pos.unmark s) Ast.ScopeVar.format_t + (Pos.unmark a) Ast.SubScopeName.format_t (Pos.unmark s)) | EIfThenElse (cond, et, ef) -> - Bindlib.box_apply3 - (fun c t f -> Dcalc.Ast.EIfThenElse (c, t, f)) - (translate_expr ctx cond) (translate_expr ctx et) - (translate_expr ctx ef) + Bindlib.box_apply3 + (fun c t f -> Dcalc.Ast.EIfThenElse (c, t, f)) + (translate_expr ctx cond) (translate_expr ctx et) + (translate_expr ctx ef) | EOp op -> Bindlib.box (Dcalc.Ast.EOp op) | ErrorOnEmpty e' -> - Bindlib.box_apply - (fun e' -> Dcalc.Ast.ErrorOnEmpty e') - (translate_expr ctx e') + Bindlib.box_apply + (fun e' -> Dcalc.Ast.ErrorOnEmpty e') + (translate_expr ctx e') | EArray es -> - Bindlib.box_apply - (fun es -> Dcalc.Ast.EArray es) - (Bindlib.box_list (List.map (translate_expr ctx) es))) + Bindlib.box_apply + (fun es -> Dcalc.Ast.EArray es) + (Bindlib.box_list (List.map (translate_expr ctx) es))) (** The result of a rule translation is a list of assignment, with variables and expressions. We also return the new translation context available after the @@ -347,298 +343,294 @@ let translate_rule * ctx = match rule with | Definition ((ScopeVar a, var_def_pos), tau, a_io, e) -> - let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in - let a_var = Dcalc.Ast.Var.make a_name in - let tau = translate_typ ctx tau in - let new_e = translate_expr ctx e in - let a_expr = Dcalc.Ast.make_var (a_var, var_def_pos) in - let merged_expr = - Bindlib.box_apply - (fun merged_expr -> - (Dcalc.Ast.ErrorOnEmpty merged_expr, Pos.get_position a_name)) - (match Pos.unmark a_io.io_input with - | OnlyInput -> - failwith "should not happen" - (* scopelang should not contain any definitions of input only - variables *) - | Reentrant -> merge_defaults a_expr new_e - | NoInput -> new_e) - in - let merged_expr = - tag_with_log_entry merged_expr - (Dcalc.Ast.VarDef (Pos.unmark tau)) - [ (sigma_name, pos_sigma); a_name ] - in - ( (fun next -> - Bindlib.box_apply2 - (fun next merged_expr -> - Dcalc.Ast.ScopeLet - { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_typ = tau; - Dcalc.Ast.scope_let_expr = merged_expr; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.ScopeVarDefinition; - Dcalc.Ast.scope_let_pos = Pos.get_position a; - }) - (Bindlib.bind_var a_var next) - merged_expr), - { - ctx with - scope_vars = - Ast.ScopeVarMap.add (Pos.unmark a) - (a_var, Pos.unmark tau, a_io) - ctx.scope_vars; - } ) + let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in + let a_var = Dcalc.Ast.Var.make a_name in + let tau = translate_typ ctx tau in + let new_e = translate_expr ctx e in + let a_expr = Dcalc.Ast.make_var (a_var, var_def_pos) in + let merged_expr = + Bindlib.box_apply + (fun merged_expr -> + Dcalc.Ast.ErrorOnEmpty merged_expr, Pos.get_position a_name) + (match Pos.unmark a_io.io_input with + | OnlyInput -> + failwith "should not happen" + (* scopelang should not contain any definitions of input only + variables *) + | Reentrant -> merge_defaults a_expr new_e + | NoInput -> new_e) + in + let merged_expr = + tag_with_log_entry merged_expr + (Dcalc.Ast.VarDef (Pos.unmark tau)) + [sigma_name, pos_sigma; a_name] + in + ( (fun next -> + Bindlib.box_apply2 + (fun next merged_expr -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_typ = tau; + Dcalc.Ast.scope_let_expr = merged_expr; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.ScopeVarDefinition; + Dcalc.Ast.scope_let_pos = Pos.get_position a; + }) + (Bindlib.bind_var a_var next) + merged_expr), + { + ctx with + scope_vars = + Ast.ScopeVarMap.add (Pos.unmark a) + (a_var, Pos.unmark tau, a_io) + ctx.scope_vars; + } ) | Definition ( (SubScopeVar (_subs_name, subs_index, subs_var), var_def_pos), tau, a_io, e ) -> - let a_name = - Pos.map_under_mark - (fun str -> - str ^ "." ^ Pos.unmark (Ast.ScopeVar.get_info (Pos.unmark subs_var))) - (Ast.SubScopeName.get_info (Pos.unmark subs_index)) - in - let a_var = Dcalc.Ast.Var.make a_name in - let tau = translate_typ ctx tau in - let new_e = - tag_with_log_entry (translate_expr ctx e) - (Dcalc.Ast.VarDef (Pos.unmark tau)) - [ (sigma_name, pos_sigma); a_name ] - in - let silent_var = Dcalc.Ast.Var.make ("_", Pos.no_pos) in - let thunked_or_nonempty_new_e = - match Pos.unmark a_io.io_input with - | NoInput -> failwith "should not happen" - | OnlyInput -> - Bindlib.box_apply - (fun new_e -> - (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position subs_var)) - new_e - | Reentrant -> - Dcalc.Ast.make_abs - (Array.of_list [ silent_var ]) - new_e var_def_pos - [ (Dcalc.Ast.TLit TUnit, var_def_pos) ] - var_def_pos - in - ( (fun next -> - Bindlib.box_apply2 - (fun next thunked_or_nonempty_new_e -> - Dcalc.Ast.ScopeLet - { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = Pos.get_position a_name; - Dcalc.Ast.scope_let_typ = - (match Pos.unmark a_io.io_input with - | NoInput -> failwith "should not happen" - | OnlyInput -> tau - | Reentrant -> - ( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), - var_def_pos )); - Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; - }) - (Bindlib.bind_var a_var next) - thunked_or_nonempty_new_e), - { - ctx with - subscope_vars = - Ast.SubScopeMap.update (Pos.unmark subs_index) - (fun map -> - match map with - | Some map -> - Some - (Ast.ScopeVarMap.add (Pos.unmark subs_var) - (a_var, Pos.unmark tau, a_io) - map) - | None -> - Some - (Ast.ScopeVarMap.singleton (Pos.unmark subs_var) - (a_var, Pos.unmark tau, a_io))) - ctx.subscope_vars; - } ) - | Call (subname, subindex) -> - let subscope_sig = Ast.ScopeMap.find subname ctx.scopes_parameters in - let all_subscope_vars = subscope_sig.scope_sig_local_vars in - let all_subscope_input_vars = - List.filter - (fun var_ctx -> - match Pos.unmark var_ctx.scope_var_io.Ast.io_input with - | NoInput -> false - | _ -> true) - all_subscope_vars - in - let all_subscope_output_vars = - List.filter - (fun var_ctx -> Pos.unmark var_ctx.scope_var_io.Ast.io_output) - all_subscope_vars - in - let scope_dcalc_var = subscope_sig.scope_sig_scope_var in - let called_scope_input_struct = subscope_sig.scope_sig_input_struct in - let called_scope_return_struct = subscope_sig.scope_sig_output_struct in - let subscope_vars_defined = - try Ast.SubScopeMap.find subindex ctx.subscope_vars - with Not_found -> Ast.ScopeVarMap.empty - in - let subscope_var_not_yet_defined subvar = - not (Ast.ScopeVarMap.mem subvar subscope_vars_defined) - in - let pos_call = Pos.get_position (Ast.SubScopeName.get_info subindex) in - let subscope_args = - List.map - (fun (subvar : scope_var_ctx) -> - if subscope_var_not_yet_defined subvar.scope_var_name then - (* This is a redundant check. Normally, all subscope varaibles - should have been defined (even an empty definition, if they're - not defined by any rule in the source code) by the translation - from desugared to the scope language. *) - Bindlib.box Dcalc.Ast.empty_thunked_term - else - let a_var, _, _ = - Ast.ScopeVarMap.find subvar.scope_var_name subscope_vars_defined - in - Dcalc.Ast.make_var (a_var, pos_call)) - all_subscope_input_vars - in - let subscope_struct_arg = + let a_name = + Pos.map_under_mark + (fun str -> + str ^ "." ^ Pos.unmark (Ast.ScopeVar.get_info (Pos.unmark subs_var))) + (Ast.SubScopeName.get_info (Pos.unmark subs_index)) + in + let a_var = Dcalc.Ast.Var.make a_name in + let tau = translate_typ ctx tau in + let new_e = + tag_with_log_entry (translate_expr ctx e) + (Dcalc.Ast.VarDef (Pos.unmark tau)) + [sigma_name, pos_sigma; a_name] + in + let silent_var = Dcalc.Ast.Var.make ("_", Pos.no_pos) in + let thunked_or_nonempty_new_e = + match Pos.unmark a_io.io_input with + | NoInput -> failwith "should not happen" + | OnlyInput -> Bindlib.box_apply - (fun subscope_args -> - ( Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct), - pos_call )) - (Bindlib.box_list subscope_args) - in - let all_subscope_output_vars_dcalc = - List.map - (fun (subvar : scope_var_ctx) -> - let sub_dcalc_var = - Dcalc.Ast.Var.make - (Pos.map_under_mark - (fun s -> - Pos.unmark (Ast.SubScopeName.get_info subindex) ^ "." ^ s) - (Ast.ScopeVar.get_info subvar.scope_var_name)) - in - (subvar, sub_dcalc_var)) - all_subscope_output_vars - in - let subscope_func = - tag_with_log_entry - (Dcalc.Ast.make_var - ( scope_dcalc_var, - Pos.get_position (Ast.SubScopeName.get_info subindex) )) - Dcalc.Ast.BeginCall - [ - (sigma_name, pos_sigma); - Ast.SubScopeName.get_info subindex; - Ast.ScopeName.get_info subname; - ] - in - let call_expr = - tag_with_log_entry - (Bindlib.box_apply2 - (fun e u -> (Dcalc.Ast.EApp (e, [ u ]), Pos.no_pos)) - subscope_func subscope_struct_arg) - Dcalc.Ast.EndCall - [ - (sigma_name, pos_sigma); - Ast.SubScopeName.get_info subindex; - Ast.ScopeName.get_info subname; - ] - in - let result_tuple_var = Dcalc.Ast.Var.make ("result", pos_sigma) in - let result_tuple_typ = - ( Dcalc.Ast.TTuple - ( List.map - (fun (subvar, _) -> (subvar.scope_var_typ, pos_sigma)) - all_subscope_output_vars_dcalc, - Some called_scope_return_struct ), - pos_sigma ) - in - let call_scope_let - (next : Dcalc.Ast.expr Dcalc.Ast.scope_body_expr Bindlib.box) = + (fun new_e -> Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position subs_var) + new_e + | Reentrant -> + Dcalc.Ast.make_abs + (Array.of_list [silent_var]) + new_e var_def_pos + [Dcalc.Ast.TLit TUnit, var_def_pos] + var_def_pos + in + ( (fun next -> Bindlib.box_apply2 - (fun next call_expr -> + (fun next thunked_or_nonempty_new_e -> Dcalc.Ast.ScopeLet { Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = pos_sigma; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.CallingSubScope; - Dcalc.Ast.scope_let_typ = result_tuple_typ; - Dcalc.Ast.scope_let_expr = call_expr; + Dcalc.Ast.scope_let_pos = Pos.get_position a_name; + Dcalc.Ast.scope_let_typ = + (match Pos.unmark a_io.io_input with + | NoInput -> failwith "should not happen" + | OnlyInput -> tau + | Reentrant -> + ( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), + var_def_pos )); + Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; }) - (Bindlib.bind_var result_tuple_var next) - call_expr - in - let result_bindings_lets - (next : Dcalc.Ast.expr Dcalc.Ast.scope_body_expr Bindlib.box) = - List.fold_right - (fun (var_ctx, v) (next, i) -> - ( Bindlib.box_apply2 - (fun next r -> - Dcalc.Ast.ScopeLet - { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = pos_sigma; - Dcalc.Ast.scope_let_typ = - (var_ctx.scope_var_typ, pos_sigma); - Dcalc.Ast.scope_let_kind = - Dcalc.Ast.DestructuringSubScopeResults; - Dcalc.Ast.scope_let_expr = - ( Dcalc.Ast.ETupleAccess - ( r, - i, - Some called_scope_return_struct, - List.map - (fun (var_ctx, _) -> - (var_ctx.scope_var_typ, pos_sigma)) - all_subscope_output_vars_dcalc ), - pos_sigma ); - }) - (Bindlib.bind_var v next) - (Dcalc.Ast.make_var (result_tuple_var, pos_sigma)), - i - 1 )) - all_subscope_output_vars_dcalc - (next, List.length all_subscope_output_vars_dcalc - 1) - in - ( (fun next -> call_scope_let (fst (result_bindings_lets next))), - { - ctx with - subscope_vars = - Ast.SubScopeMap.add subindex - (List.fold_left - (fun acc (var_ctx, dvar) -> - Ast.ScopeVarMap.add var_ctx.scope_var_name - (dvar, var_ctx.scope_var_typ, var_ctx.scope_var_io) - acc) - Ast.ScopeVarMap.empty all_subscope_output_vars_dcalc) - ctx.subscope_vars; - } ) + (Bindlib.bind_var a_var next) + thunked_or_nonempty_new_e), + { + ctx with + subscope_vars = + Ast.SubScopeMap.update (Pos.unmark subs_index) + (fun map -> + match map with + | Some map -> + Some + (Ast.ScopeVarMap.add (Pos.unmark subs_var) + (a_var, Pos.unmark tau, a_io) + map) + | None -> + Some + (Ast.ScopeVarMap.singleton (Pos.unmark subs_var) + (a_var, Pos.unmark tau, a_io))) + ctx.subscope_vars; + } ) + | Call (subname, subindex) -> + let subscope_sig = Ast.ScopeMap.find subname ctx.scopes_parameters in + let all_subscope_vars = subscope_sig.scope_sig_local_vars in + let all_subscope_input_vars = + List.filter + (fun var_ctx -> + match Pos.unmark var_ctx.scope_var_io.Ast.io_input with + | NoInput -> false + | _ -> true) + all_subscope_vars + in + let all_subscope_output_vars = + List.filter + (fun var_ctx -> Pos.unmark var_ctx.scope_var_io.Ast.io_output) + all_subscope_vars + in + let scope_dcalc_var = subscope_sig.scope_sig_scope_var in + let called_scope_input_struct = subscope_sig.scope_sig_input_struct in + let called_scope_return_struct = subscope_sig.scope_sig_output_struct in + let subscope_vars_defined = + try Ast.SubScopeMap.find subindex ctx.subscope_vars + with Not_found -> Ast.ScopeVarMap.empty + in + let subscope_var_not_yet_defined subvar = + not (Ast.ScopeVarMap.mem subvar subscope_vars_defined) + in + let pos_call = Pos.get_position (Ast.SubScopeName.get_info subindex) in + let subscope_args = + List.map + (fun (subvar : scope_var_ctx) -> + if subscope_var_not_yet_defined subvar.scope_var_name then + (* This is a redundant check. Normally, all subscope varaibles + should have been defined (even an empty definition, if they're + not defined by any rule in the source code) by the translation + from desugared to the scope language. *) + Bindlib.box Dcalc.Ast.empty_thunked_term + else + let a_var, _, _ = + Ast.ScopeVarMap.find subvar.scope_var_name subscope_vars_defined + in + Dcalc.Ast.make_var (a_var, pos_call)) + all_subscope_input_vars + in + let subscope_struct_arg = + Bindlib.box_apply + (fun subscope_args -> + ( Dcalc.Ast.ETuple (subscope_args, Some called_scope_input_struct), + pos_call )) + (Bindlib.box_list subscope_args) + in + let all_subscope_output_vars_dcalc = + List.map + (fun (subvar : scope_var_ctx) -> + let sub_dcalc_var = + Dcalc.Ast.Var.make + (Pos.map_under_mark + (fun s -> + Pos.unmark (Ast.SubScopeName.get_info subindex) ^ "." ^ s) + (Ast.ScopeVar.get_info subvar.scope_var_name)) + in + subvar, sub_dcalc_var) + all_subscope_output_vars + in + let subscope_func = + tag_with_log_entry + (Dcalc.Ast.make_var + ( scope_dcalc_var, + Pos.get_position (Ast.SubScopeName.get_info subindex) )) + Dcalc.Ast.BeginCall + [ + sigma_name, pos_sigma; + Ast.SubScopeName.get_info subindex; + Ast.ScopeName.get_info subname; + ] + in + let call_expr = + tag_with_log_entry + (Bindlib.box_apply2 + (fun e u -> Dcalc.Ast.EApp (e, [u]), Pos.no_pos) + subscope_func subscope_struct_arg) + Dcalc.Ast.EndCall + [ + sigma_name, pos_sigma; + Ast.SubScopeName.get_info subindex; + Ast.ScopeName.get_info subname; + ] + in + let result_tuple_var = Dcalc.Ast.Var.make ("result", pos_sigma) in + let result_tuple_typ = + ( Dcalc.Ast.TTuple + ( List.map + (fun (subvar, _) -> subvar.scope_var_typ, pos_sigma) + all_subscope_output_vars_dcalc, + Some called_scope_return_struct ), + pos_sigma ) + in + let call_scope_let + (next : Dcalc.Ast.expr Dcalc.Ast.scope_body_expr Bindlib.box) = + Bindlib.box_apply2 + (fun next call_expr -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = pos_sigma; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.CallingSubScope; + Dcalc.Ast.scope_let_typ = result_tuple_typ; + Dcalc.Ast.scope_let_expr = call_expr; + }) + (Bindlib.bind_var result_tuple_var next) + call_expr + in + let result_bindings_lets + (next : Dcalc.Ast.expr Dcalc.Ast.scope_body_expr Bindlib.box) = + List.fold_right + (fun (var_ctx, v) (next, i) -> + ( Bindlib.box_apply2 + (fun next r -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = pos_sigma; + Dcalc.Ast.scope_let_typ = var_ctx.scope_var_typ, pos_sigma; + Dcalc.Ast.scope_let_kind = + Dcalc.Ast.DestructuringSubScopeResults; + Dcalc.Ast.scope_let_expr = + ( Dcalc.Ast.ETupleAccess + ( r, + i, + Some called_scope_return_struct, + List.map + (fun (var_ctx, _) -> + var_ctx.scope_var_typ, pos_sigma) + all_subscope_output_vars_dcalc ), + pos_sigma ); + }) + (Bindlib.bind_var v next) + (Dcalc.Ast.make_var (result_tuple_var, pos_sigma)), + i - 1 )) + all_subscope_output_vars_dcalc + (next, List.length all_subscope_output_vars_dcalc - 1) + in + ( (fun next -> call_scope_let (fst (result_bindings_lets next))), + { + ctx with + subscope_vars = + Ast.SubScopeMap.add subindex + (List.fold_left + (fun acc (var_ctx, dvar) -> + Ast.ScopeVarMap.add var_ctx.scope_var_name + (dvar, var_ctx.scope_var_typ, var_ctx.scope_var_io) + acc) + Ast.ScopeVarMap.empty all_subscope_output_vars_dcalc) + ctx.subscope_vars; + } ) | Assertion e -> - let new_e = translate_expr ctx e in - ( (fun next -> - Bindlib.box_apply2 - (fun next new_e -> - Dcalc.Ast.ScopeLet - { - Dcalc.Ast.scope_let_next = next; - Dcalc.Ast.scope_let_pos = Pos.get_position e; - Dcalc.Ast.scope_let_typ = - (Dcalc.Ast.TLit TUnit, Pos.get_position e); - Dcalc.Ast.scope_let_expr = - (* To ensure that we throw an error if the value is not - defined, we add an check "ErrorOnEmpty" here. *) - Pos.same_pos_as - (Dcalc.Ast.EAssert - (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position e)) - new_e; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; - }) - (Bindlib.bind_var - (Dcalc.Ast.Var.make ("_", Pos.get_position e)) - next) - new_e), - ctx ) + let new_e = translate_expr ctx e in + ( (fun next -> + Bindlib.box_apply2 + (fun next new_e -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = Pos.get_position e; + Dcalc.Ast.scope_let_typ = + Dcalc.Ast.TLit TUnit, Pos.get_position e; + Dcalc.Ast.scope_let_expr = + (* To ensure that we throw an error if the value is not + defined, we add an check "ErrorOnEmpty" here. *) + Pos.same_pos_as + (Dcalc.Ast.EAssert + (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position e)) + new_e; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; + }) + (Bindlib.bind_var (Dcalc.Ast.Var.make ("_", Pos.get_position e)) next) + new_e), + ctx ) let translate_rules (ctx : ctx) @@ -652,7 +644,7 @@ let translate_rules let new_scope_lets, new_ctx = translate_rule ctx rule (sigma_name, pos_sigma) in - ((fun next -> scope_lets (new_scope_lets next)), new_ctx)) + (fun next -> scope_lets (new_scope_lets next)), new_ctx) ((fun next -> next), ctx) rules in @@ -665,7 +657,7 @@ let translate_rules let return_exp = Bindlib.box_apply (fun args -> - (Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma)) + Dcalc.Ast.ETuple (args, Some sigma_return_struct_name), pos_sigma) (Bindlib.box_list (List.map (fun (_, (dcalc_var, _, _)) -> @@ -695,19 +687,17 @@ let translate_scope_decl (fun ctx scope_var -> match Pos.unmark scope_var.scope_var_io.io_input with | OnlyInput -> - let scope_var_name = - Ast.ScopeVar.get_info scope_var.scope_var_name - in - let scope_var_dcalc = Dcalc.Ast.Var.make scope_var_name in - { - ctx with - scope_vars = - Ast.ScopeVarMap.add scope_var.scope_var_name - ( scope_var_dcalc, - scope_var.scope_var_typ, - scope_var.scope_var_io ) - ctx.scope_vars; - } + let scope_var_name = Ast.ScopeVar.get_info scope_var.scope_var_name in + let scope_var_dcalc = Dcalc.Ast.Var.make scope_var_name in + { + ctx with + scope_vars = + Ast.ScopeVarMap.add scope_var.scope_var_name + ( scope_var_dcalc, + scope_var.scope_var_typ, + scope_var.scope_var_io ) + ctx.scope_vars; + } | _ -> ctx) (empty_ctx struct_ctx enum_ctx sctx scope_name) scope_variables @@ -726,7 +716,7 @@ let translate_scope_decl let dcalc_x, _, _ = Ast.ScopeVarMap.find var_ctx.scope_var_name ctx.scope_vars in - (var_ctx, dcalc_x)) + var_ctx, dcalc_x) scope_variables in (* first we create variables from the fields of the input struct *) @@ -745,12 +735,11 @@ let translate_scope_decl in let input_var_typ (var_ctx : scope_var_ctx) = match Pos.unmark var_ctx.scope_var_io.io_input with - | OnlyInput -> (var_ctx.scope_var_typ, pos_sigma) + | OnlyInput -> var_ctx.scope_var_typ, pos_sigma | Reentrant -> - ( Dcalc.Ast.TArrow - ( (Dcalc.Ast.TLit TUnit, pos_sigma), - (var_ctx.scope_var_typ, pos_sigma) ), - pos_sigma ) + ( Dcalc.Ast.TArrow + ((Dcalc.Ast.TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), + pos_sigma ) | NoInput -> failwith "should not happen" in let input_destructurings @@ -789,7 +778,7 @@ let translate_scope_decl let struct_field_name = Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_out", pos_sigma) in - (struct_field_name, (var_ctx.scope_var_typ, pos_sigma))) + struct_field_name, (var_ctx.scope_var_typ, pos_sigma)) scope_output_variables in let scope_input_struct_fields = @@ -798,7 +787,7 @@ let translate_scope_decl let struct_field_name = Ast.StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma) in - (struct_field_name, input_var_typ var_ctx)) + struct_field_name, input_var_typ var_ctx) scope_input_variables in let new_struct_ctx = @@ -836,12 +825,12 @@ let translate_program (prgm : Ast.program) : Dcalc.Ast.ctx_structs = Ast.StructMap.map (List.map (fun (x, y) -> - (x, translate_typ (ctx_for_typ_translation dummy_scope) y))) + x, translate_typ (ctx_for_typ_translation dummy_scope) y)) struct_ctx; Dcalc.Ast.ctx_enums = Ast.EnumMap.map (List.map (fun (x, y) -> - (x, (translate_typ (ctx_for_typ_translation dummy_scope)) y))) + x, (translate_typ (ctx_for_typ_translation dummy_scope)) y)) enum_ctx; } in @@ -915,8 +904,8 @@ let translate_program (prgm : Ast.program) : Dcalc.Ast.ScopeDef { scope_name; scope_body; scope_next }) scope_body scope_next in - (new_scopes, decl_ctx)) + new_scopes, decl_ctx) scope_ordering (Bindlib.box Dcalc.Ast.Nil, decl_ctx) in - ({ scopes = Bindlib.unbox scopes; decl_ctx }, types_ordering) + { scopes = Bindlib.unbox scopes; decl_ctx }, types_ordering diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index 356434a6..26afa54d 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -47,13 +47,13 @@ type qident = ident Pos.marked list visitors { variety = "map"; - ancestors = [ "Pos.marked_map"; "ident_map" ]; + ancestors = ["Pos.marked_map"; "ident_map"]; name = "qident_map"; }, visitors { variety = "iter"; - ancestors = [ "Pos.marked_iter"; "ident_iter" ]; + ancestors = ["Pos.marked_iter"; "ident_iter"]; name = "qident_iter"; }] @@ -70,13 +70,13 @@ type primitive_typ = visitors { variety = "map"; - ancestors = [ "constructor_map" ]; + ancestors = ["constructor_map"]; name = "primitive_typ_map"; }, visitors { variety = "iter"; - ancestors = [ "constructor_iter" ]; + ancestors = ["constructor_iter"]; name = "primitive_typ_iter"; }] @@ -87,13 +87,13 @@ type base_typ_data = visitors { variety = "map"; - ancestors = [ "Pos.marked_map"; "primitive_typ_map" ]; + ancestors = ["Pos.marked_map"; "primitive_typ_map"]; name = "base_typ_data_map"; }, visitors { variety = "iter"; - ancestors = [ "Pos.marked_iter"; "primitive_typ_iter" ]; + ancestors = ["Pos.marked_iter"; "primitive_typ_iter"]; name = "base_typ_data_iter"; }] @@ -102,14 +102,14 @@ type base_typ = Condition | Data of base_typ_data visitors { variety = "map"; - ancestors = [ "base_typ_data_map" ]; + ancestors = ["base_typ_data_map"]; name = "base_typ_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "base_typ_data_iter" ]; + ancestors = ["base_typ_data_iter"]; name = "base_typ_iter"; nude = true; }] @@ -122,14 +122,14 @@ type func_typ = { visitors { variety = "map"; - ancestors = [ "base_typ_map" ]; + ancestors = ["base_typ_map"]; name = "func_typ_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "base_typ_iter" ]; + ancestors = ["base_typ_iter"]; name = "func_typ_iter"; nude = true; }] @@ -139,14 +139,14 @@ type typ = Base of base_typ | Func of func_typ visitors { variety = "map"; - ancestors = [ "func_typ_map" ]; + ancestors = ["func_typ_map"]; name = "typ_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "func_typ_iter" ]; + ancestors = ["func_typ_iter"]; name = "typ_iter"; nude = true; }] @@ -159,13 +159,13 @@ type struct_decl_field = { visitors { variety = "map"; - ancestors = [ "typ_map"; "ident_map" ]; + ancestors = ["typ_map"; "ident_map"]; name = "struct_decl_field_map"; }, visitors { variety = "iter"; - ancestors = [ "typ_iter"; "ident_iter" ]; + ancestors = ["typ_iter"; "ident_iter"]; name = "struct_decl_field_iter"; }] @@ -177,13 +177,13 @@ type struct_decl = { visitors { variety = "map"; - ancestors = [ "struct_decl_field_map" ]; + ancestors = ["struct_decl_field_map"]; name = "struct_decl_map"; }, visitors { variety = "iter"; - ancestors = [ "struct_decl_field_iter" ]; + ancestors = ["struct_decl_field_iter"]; name = "struct_decl_iter"; }] @@ -195,14 +195,14 @@ type enum_decl_case = { visitors { variety = "map"; - ancestors = [ "typ_map" ]; + ancestors = ["typ_map"]; name = "enum_decl_case_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "typ_iter" ]; + ancestors = ["typ_iter"]; name = "enum_decl_case_iter"; nude = true; }] @@ -215,14 +215,14 @@ type enum_decl = { visitors { variety = "map"; - ancestors = [ "enum_decl_case_map" ]; + ancestors = ["enum_decl_case_map"]; name = "enum_decl_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "enum_decl_case_iter" ]; + ancestors = ["enum_decl_case_iter"]; name = "enum_decl_iter"; nude = true; }] @@ -234,13 +234,13 @@ type match_case_pattern = visitors { variety = "map"; - ancestors = [ "ident_map"; "constructor_map"; "Pos.marked_map" ]; + ancestors = ["ident_map"; "constructor_map"; "Pos.marked_map"]; name = "match_case_pattern_map"; }, visitors { variety = "iter"; - ancestors = [ "ident_iter"; "constructor_iter"; "Pos.marked_iter" ]; + ancestors = ["ident_iter"; "constructor_iter"; "Pos.marked_iter"]; name = "match_case_pattern_iter"; }] @@ -268,14 +268,14 @@ type binop = visitors { variety = "map"; - ancestors = [ "op_kind_map" ]; + ancestors = ["op_kind_map"]; name = "binop_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "op_kind_iter" ]; + ancestors = ["op_kind_iter"]; name = "binop_iter"; nude = true; }] @@ -285,14 +285,14 @@ type unop = Not | Minus of op_kind visitors { variety = "map"; - ancestors = [ "op_kind_map" ]; + ancestors = ["op_kind_map"]; name = "unop_map"; nude = true; }, visitors { variety = "iter"; - ancestors = [ "op_kind_iter" ]; + ancestors = ["op_kind_iter"]; name = "unop_iter"; nude = true; }] @@ -318,13 +318,13 @@ type literal_date = { visitors { variety = "map"; - ancestors = [ "Pos.marked_map" ]; + ancestors = ["Pos.marked_map"]; name = "literal_date_map"; }, visitors { variety = "iter"; - ancestors = [ "Pos.marked_iter" ]; + ancestors = ["Pos.marked_iter"]; name = "literal_date_iter"; }] @@ -468,13 +468,13 @@ type exception_to = visitors { variety = "map"; - ancestors = [ "ident_map"; "Pos.marked_map" ]; + ancestors = ["ident_map"; "Pos.marked_map"]; name = "exception_to_map"; }, visitors { variety = "iter"; - ancestors = [ "ident_iter"; "Pos.marked_iter" ]; + ancestors = ["ident_iter"; "Pos.marked_iter"]; name = "exception_to_iter"; }] @@ -492,13 +492,13 @@ type rule = { visitors { variety = "map"; - ancestors = [ "expression_map"; "qident_map"; "exception_to_map" ]; + ancestors = ["expression_map"; "qident_map"; "exception_to_map"]; name = "rule_map"; }, visitors { variety = "iter"; - ancestors = [ "expression_iter"; "qident_iter"; "exception_to_iter" ]; + ancestors = ["expression_iter"; "qident_iter"; "exception_to_iter"]; name = "rule_iter"; }] @@ -516,13 +516,13 @@ type definition = { visitors { variety = "map"; - ancestors = [ "expression_map"; "qident_map"; "exception_to_map" ]; + ancestors = ["expression_map"; "qident_map"; "exception_to_map"]; name = "definition_map"; }, visitors { variety = "iter"; - ancestors = [ "expression_iter"; "qident_iter"; "exception_to_iter" ]; + ancestors = ["expression_iter"; "qident_iter"; "exception_to_iter"]; name = "definition_iter"; }] @@ -541,13 +541,13 @@ type meta_assertion = visitors { variety = "map"; - ancestors = [ "variation_typ_map"; "qident_map"; "expression_map" ]; + ancestors = ["variation_typ_map"; "qident_map"; "expression_map"]; name = "meta_assertion_map"; }, visitors { variety = "iter"; - ancestors = [ "variation_typ_iter"; "qident_iter"; "expression_iter" ]; + ancestors = ["variation_typ_iter"; "qident_iter"; "expression_iter"]; name = "meta_assertion_iter"; }] @@ -557,15 +557,11 @@ type assertion = { } [@@deriving visitors - { - variety = "map"; - ancestors = [ "expression_map" ]; - name = "assertion_map"; - }, + { variety = "map"; ancestors = ["expression_map"]; name = "assertion_map" }, visitors { variety = "iter"; - ancestors = [ "expression_iter" ]; + ancestors = ["expression_iter"]; name = "assertion_iter"; }] @@ -579,7 +575,7 @@ type scope_use_item = { variety = "map"; ancestors = - [ "meta_assertion_map"; "definition_map"; "assertion_map"; "rule_map" ]; + ["meta_assertion_map"; "definition_map"; "assertion_map"; "rule_map"]; name = "scope_use_item_map"; }, visitors @@ -604,13 +600,13 @@ type scope_use = { visitors { variety = "map"; - ancestors = [ "expression_map"; "scope_use_item_map" ]; + ancestors = ["expression_map"; "scope_use_item_map"]; name = "scope_use_map"; }, visitors { variety = "iter"; - ancestors = [ "expression_iter"; "scope_use_item_iter" ]; + ancestors = ["expression_iter"; "scope_use_item_iter"]; name = "scope_use_iter"; }] @@ -627,13 +623,13 @@ type scope_decl_context_io = { visitors { variety = "map"; - ancestors = [ "io_input_map"; "Pos.marked_map" ]; + ancestors = ["io_input_map"; "Pos.marked_map"]; name = "scope_decl_context_io_map"; }, visitors { variety = "iter"; - ancestors = [ "io_input_iter"; "Pos.marked_iter" ]; + ancestors = ["io_input_iter"; "Pos.marked_iter"]; name = "scope_decl_context_io_iter"; }] @@ -678,13 +674,13 @@ type scope_decl_context_data = { visitors { variety = "map"; - ancestors = [ "typ_map"; "scope_decl_context_io_map"; "ident_map" ]; + ancestors = ["typ_map"; "scope_decl_context_io_map"; "ident_map"]; name = "scope_decl_context_data_map"; }, visitors { variety = "iter"; - ancestors = [ "typ_iter"; "scope_decl_context_io_iter"; "ident_iter" ]; + ancestors = ["typ_iter"; "scope_decl_context_io_iter"; "ident_iter"]; name = "scope_decl_context_data_iter"; }] @@ -696,14 +692,14 @@ type scope_decl_context_item = { variety = "map"; ancestors = - [ "scope_decl_context_data_map"; "scope_decl_context_scope_map" ]; + ["scope_decl_context_data_map"; "scope_decl_context_scope_map"]; name = "scope_decl_context_item_map"; }, visitors { variety = "iter"; ancestors = - [ "scope_decl_context_data_iter"; "scope_decl_context_scope_iter" ]; + ["scope_decl_context_data_iter"; "scope_decl_context_scope_iter"]; name = "scope_decl_context_item_iter"; }] @@ -715,13 +711,13 @@ type scope_decl = { visitors { variety = "map"; - ancestors = [ "scope_decl_context_item_map" ]; + ancestors = ["scope_decl_context_item_map"]; name = "scope_decl_map"; }, visitors { variety = "iter"; - ancestors = [ "scope_decl_context_item_iter" ]; + ancestors = ["scope_decl_context_item_iter"]; name = "scope_decl_iter"; }] @@ -735,9 +731,7 @@ type code_item = { variety = "map"; ancestors = - [ - "scope_decl_map"; "enum_decl_map"; "struct_decl_map"; "scope_use_map"; - ]; + ["scope_decl_map"; "enum_decl_map"; "struct_decl_map"; "scope_use_map"]; name = "code_item_map"; }, visitors @@ -756,15 +750,11 @@ type code_item = type code_block = code_item Pos.marked list [@@deriving visitors - { - variety = "map"; - ancestors = [ "code_item_map" ]; - name = "code_block_map"; - }, + { variety = "map"; ancestors = ["code_item_map"]; name = "code_block_map" }, visitors { variety = "iter"; - ancestors = [ "code_item_iter" ]; + ancestors = ["code_item_iter"]; name = "code_block_iter"; }] @@ -773,13 +763,13 @@ type source_repr = (string[@opaque]) Pos.marked visitors { variety = "map"; - ancestors = [ "Pos.marked_map" ]; + ancestors = ["Pos.marked_map"]; name = "source_repr_map"; }, visitors { variety = "iter"; - ancestors = [ "Pos.marked_iter" ]; + ancestors = ["Pos.marked_iter"]; name = "source_repr_iter"; }] @@ -793,13 +783,13 @@ type law_heading = { visitors { variety = "map"; - ancestors = [ "Pos.marked_map" ]; + ancestors = ["Pos.marked_map"]; name = "law_heading_map"; }, visitors { variety = "iter"; - ancestors = [ "Pos.marked_iter" ]; + ancestors = ["Pos.marked_iter"]; name = "law_heading_iter"; }] @@ -811,13 +801,13 @@ type law_include = visitors { variety = "map"; - ancestors = [ "Pos.marked_map" ]; + ancestors = ["Pos.marked_map"]; name = "law_include_map"; }, visitors { variety = "iter"; - ancestors = [ "Pos.marked_iter" ]; + ancestors = ["Pos.marked_iter"]; name = "law_include_iter"; }] @@ -858,15 +848,11 @@ type program = { } [@@deriving visitors - { - variety = "map"; - ancestors = [ "law_structure_map" ]; - name = "program_map"; - }, + { variety = "map"; ancestors = ["law_structure_map"]; name = "program_map" }, visitors { variety = "iter"; - ancestors = [ "law_structure_iter" ]; + ancestors = ["law_structure_iter"]; name = "program_iter"; }] @@ -884,6 +870,6 @@ let rule_to_def (rule : rule) : definition = definition_parameter = rule.rule_parameter; definition_condition = rule.rule_condition; definition_id = rule.rule_id; - definition_expr = (consequence_expr, Pos.get_position rule.rule_consequence); + definition_expr = consequence_expr, Pos.get_position rule.rule_consequence; definition_state = rule.rule_state; } diff --git a/compiler/surface/desugaring.ml b/compiler/surface/desugaring.ml index cd19696c..f6f2d519 100644 --- a/compiler/surface/desugaring.ml +++ b/compiler/surface/desugaring.ml @@ -66,10 +66,10 @@ let disambiguate_constructor (pos : Pos.t) : Scopelang.Ast.EnumName.t * Scopelang.Ast.EnumConstructor.t = let enum, constructor = match constructor with - | [ c ] -> c + | [c] -> c | _ -> - Errors.raise_spanned_error pos - "The deep pattern matching syntactic sugar is not yet supported" + Errors.raise_spanned_error pos + "The deep pattern matching syntactic sugar is not yet supported" in let possible_c_uids = try @@ -83,32 +83,32 @@ let disambiguate_constructor in match enum with | None -> - if Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 then - Errors.raise_spanned_error - (Pos.get_position constructor) - "This constructor name is ambiguous, it can belong to %a. \ - Disambiguate it by prefixing it with the enum name." - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") - (fun fmt (s_name, _) -> - Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) - (Scopelang.Ast.EnumMap.bindings possible_c_uids); - Scopelang.Ast.EnumMap.choose possible_c_uids + if Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 then + Errors.raise_spanned_error + (Pos.get_position constructor) + "This constructor name is ambiguous, it can belong to %a. Disambiguate \ + it by prefixing it with the enum name." + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") + (fun fmt (s_name, _) -> + Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) + (Scopelang.Ast.EnumMap.bindings possible_c_uids); + Scopelang.Ast.EnumMap.choose possible_c_uids | Some enum -> ( + try + (* The path is fully qualified *) + let e_uid = + Desugared.Ast.IdentMap.find (Pos.unmark enum) ctxt.enum_idmap + in try - (* The path is fully qualified *) - let e_uid = - Desugared.Ast.IdentMap.find (Pos.unmark enum) ctxt.enum_idmap - in - try - let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in - (e_uid, c_uid) - with Not_found -> - Errors.raise_spanned_error pos "Enum %s does not contain case %s" - (Pos.unmark enum) (Pos.unmark constructor) + let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in + e_uid, c_uid with Not_found -> - Errors.raise_spanned_error (Pos.get_position enum) - "Enum %s has not been defined before" (Pos.unmark enum)) + Errors.raise_spanned_error pos "Enum %s does not contain case %s" + (Pos.unmark enum) (Pos.unmark constructor) + with Not_found -> + Errors.raise_spanned_error (Pos.get_position enum) + "Enum %s has not been defined before" (Pos.unmark enum)) (** Usage: [translate_expr scope ctxt expr] @@ -128,770 +128,743 @@ let rec translate_expr ( TestMatchCase (e1_sub, ((constructors, Some binding), pos_pattern)), _pos_e1 ), e2 ) -> - (* This sugar corresponds to [e is P x && e'] and should desugar to [match - e with P x -> e' | _ -> false] *) - let enum_uid, c_uid = - disambiguate_constructor ctxt constructors pos_pattern - in - let cases = - Scopelang.Ast.EnumConstructorMap.mapi - (fun c_uid' tau -> - if Scopelang.Ast.EnumConstructor.compare c_uid c_uid' <> 0 then - let nop_var = Desugared.Ast.Var.make ("_", pos) in - Bindlib.unbox - (Desugared.Ast.make_abs [| nop_var |] - (Bindlib.box - (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos)) - pos [ tau ] pos) - else - let ctxt, binding_var = - Name_resolution.add_def_local_var ctxt binding - in - let e2 = translate_expr scope inside_definition_of ctxt e2 in - Bindlib.unbox - (Desugared.Ast.make_abs [| binding_var |] e2 pos [ tau ] pos)) - (Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) - in - Bindlib.box_apply - (fun e1_sub -> (Desugared.Ast.EMatch (e1_sub, enum_uid, cases), pos)) - (translate_expr scope inside_definition_of ctxt e1_sub) - | IfThenElse (e_if, e_then, e_else) -> - Bindlib.box_apply3 - (fun e_if e_then e_else -> - (Desugared.Ast.EIfThenElse (e_if, e_then, e_else), pos)) - (rec_helper e_if) (rec_helper e_then) (rec_helper e_else) - | Binop (op, e1, e2) -> - let op_term = - Pos.same_pos_as - (Desugared.Ast.EOp (Dcalc.Ast.Binop (translate_binop (Pos.unmark op)))) - op - in - Bindlib.box_apply2 - (fun e1 e2 -> (Desugared.Ast.EApp (op_term, [ e1; e2 ]), pos)) - (rec_helper e1) (rec_helper e2) - | Unop (op, e) -> - let op_term = - Pos.same_pos_as - (Desugared.Ast.EOp (Dcalc.Ast.Unop (translate_unop (Pos.unmark op)))) - op - in - Bindlib.box_apply - (fun e -> (Desugared.Ast.EApp (op_term, [ e ]), pos)) - (rec_helper e) - | Literal l -> - let untyped_term = - match l with - | LNumber ((Int i, _), None) -> - Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_string i)) - | LNumber ((Int i, _), Some (Percent, _)) -> - Desugared.Ast.ELit - (Dcalc.Ast.LRat - Runtime.(decimal_of_string i /& decimal_of_string "100")) - | LNumber ((Dec (i, f), _), None) -> - Desugared.Ast.ELit - (Dcalc.Ast.LRat Runtime.(decimal_of_string (i ^ "." ^ f))) - | LNumber ((Dec (i, f), _), Some (Percent, _)) -> - Desugared.Ast.ELit - (Dcalc.Ast.LRat - Runtime.( - decimal_of_string (i ^ "." ^ f) /& decimal_of_string "100")) - | LBool b -> Desugared.Ast.ELit (Dcalc.Ast.LBool b) - | LMoneyAmount i -> - Desugared.Ast.ELit - (Dcalc.Ast.LMoney - Runtime.( - money_of_cents_integer - (integer_of_string i.money_amount_units - *! integer_of_int 100 - +! integer_of_string i.money_amount_cents))) - | LNumber ((Int i, _), Some (Year, _)) -> - Desugared.Ast.ELit - (Dcalc.Ast.LDuration - (Runtime.duration_of_numbers (int_of_string i) 0 0)) - | LNumber ((Int i, _), Some (Month, _)) -> - Desugared.Ast.ELit - (Dcalc.Ast.LDuration - (Runtime.duration_of_numbers 0 (int_of_string i) 0)) - | LNumber ((Int i, _), Some (Day, _)) -> - Desugared.Ast.ELit - (Dcalc.Ast.LDuration - (Runtime.duration_of_numbers 0 0 (int_of_string i))) - | LNumber ((Dec (_, _), _), Some ((Year | Month | Day), _)) -> - Errors.raise_spanned_error pos - "Impossible to specify decimal amounts of days, months or years" - | LDate date -> - if Pos.unmark date.literal_date_month > 12 then - Errors.raise_spanned_error - (Pos.get_position date.literal_date_month) - "There is an error in this date: the month number is bigger \ - than 12"; - if Pos.unmark date.literal_date_day > 31 then - Errors.raise_spanned_error - (Pos.get_position date.literal_date_day) - "There is an error in this date: the day number is bigger than \ - 31"; - Desugared.Ast.ELit - (Dcalc.Ast.LDate - (try - Runtime.date_of_numbers - (Pos.unmark date.literal_date_year) - (Pos.unmark date.literal_date_month) - (Pos.unmark date.literal_date_day) - with Runtime.ImpossibleDate -> - Errors.raise_spanned_error pos - "There is an error in this date, it does not correspond \ - to a correct calendar day")) - in - Bindlib.box (untyped_term, pos) - | Ident x -> ( - (* first we check whether this is a local var, then we resort to - scope-wide variables *) - match Desugared.Ast.IdentMap.find_opt x ctxt.local_var_idmap with - | None -> ( - match Desugared.Ast.IdentMap.find_opt x scope_ctxt.var_idmap with - | Some uid -> - (* If the referenced variable has states, then here are the rules - to desambiguate. In general, only the last state can be - referenced. Except if defining a state of the same variable, - then it references the previous state in the chain. *) - let x_sig = Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs in - let x_state = - match x_sig.var_sig_states_list with - | [] -> None - | states -> ( - match inside_definition_of with - | Some (Desugared.Ast.ScopeDef.Var (x'_uid, sx'), _) - when Desugared.Ast.ScopeVar.compare uid x'_uid = 0 -> ( - match sx' with - | None -> - failwith - "inconsistent state: inside a definition of a \ - variable with no state but variable has states" - | Some inside_def_state -> - if - Desugared.Ast.StateName.compare inside_def_state - (List.hd states) - = 0 - then - Errors.raise_spanned_error pos - "It is impossible to refer to the variable you \ - are defining when defining its first state." - else - (* Tricky: we have to retrieve in the list the - previous state with respect to the state that - we are defining. *) - let correct_state = ref None in - ignore - (List.fold_left - (fun previous_state state -> - if - Desugared.Ast.StateName.compare - inside_def_state state - = 0 - then correct_state := previous_state; - Some state) - None states); - !correct_state) - | _ -> - (* we take the last state in the chain *) - Some (List.hd (List.rev states))) - in - Bindlib.box - (Desugared.Ast.ELocation (ScopeVar ((uid, pos), x_state)), pos) - | None -> - Name_resolution.raise_unknown_identifier - "for a local or scope-wide variable" (x, pos)) - | Some uid -> - Desugared.Ast.make_var (uid, pos) - (* the whole box thing is to accomodate for this case *)) - | Dotted (e, c, x) -> ( - match Pos.unmark e with - | Ident y when Name_resolution.is_subscope_uid scope ctxt y -> - (* In this case, y.x is a subscope variable *) - let subscope_uid : Scopelang.Ast.SubScopeName.t = - Name_resolution.get_subscope_uid scope ctxt (Pos.same_pos_as y e) - in - let subscope_real_uid : Scopelang.Ast.ScopeName.t = - Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes - in - let subscope_var_uid = - Name_resolution.get_var_uid subscope_real_uid ctxt x - in - Bindlib.box - ( Desugared.Ast.ELocation - (SubScopeVar - ( subscope_real_uid, - (subscope_uid, pos), - (subscope_var_uid, pos) )), - pos ) - | _ -> ( - (* In this case e.x is the struct field x access of expression e *) - let e = translate_expr scope inside_definition_of ctxt e in - let x_possible_structs = - try Desugared.Ast.IdentMap.find (Pos.unmark x) ctxt.field_idmap - with Not_found -> - Errors.raise_spanned_error (Pos.get_position x) - "Unknown subscope or struct field name" - in - match c with - | None -> - (* No constructor name was specified *) - if Scopelang.Ast.StructMap.cardinal x_possible_structs > 1 then - Errors.raise_spanned_error (Pos.get_position x) - "This struct field name is ambiguous, it can belong to %a. \ - Disambiguate it by prefixing it with the struct name." - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") - (fun fmt (s_name, _) -> - Format.fprintf fmt "%a" Scopelang.Ast.StructName.format_t - s_name)) - (Scopelang.Ast.StructMap.bindings x_possible_structs) - else - let s_uid, f_uid = - Scopelang.Ast.StructMap.choose x_possible_structs - in - Bindlib.box_apply - (fun e -> - (Desugared.Ast.EStructAccess (e, f_uid, s_uid), pos)) - e - | Some c_name -> ( - try - let c_uid = - Desugared.Ast.IdentMap.find (Pos.unmark c_name) - ctxt.struct_idmap - in - try - let f_uid = - Scopelang.Ast.StructMap.find c_uid x_possible_structs - in - Bindlib.box_apply - (fun e -> - (Desugared.Ast.EStructAccess (e, f_uid, c_uid), pos)) - e - with Not_found -> - Errors.raise_spanned_error pos - "Struct %s does not contain field %s" (Pos.unmark c_name) - (Pos.unmark x) - with Not_found -> - Errors.raise_spanned_error (Pos.get_position c_name) - "Struct %s has not been defined before" (Pos.unmark c_name)))) - | FunCall (f, arg) -> - Bindlib.box_apply2 - (fun f arg -> (Desugared.Ast.EApp (f, [ arg ]), pos)) - (rec_helper f) (rec_helper arg) - | StructLit (s_name, fields) -> - let s_uid = - try Desugared.Ast.IdentMap.find (Pos.unmark s_name) ctxt.struct_idmap - with Not_found -> - Errors.raise_spanned_error (Pos.get_position s_name) - "This identifier should refer to a struct name" - in - - let s_fields = - List.fold_left - (fun s_fields (f_name, f_e) -> - let f_uid = - try - Scopelang.Ast.StructMap.find s_uid - (Desugared.Ast.IdentMap.find (Pos.unmark f_name) - ctxt.field_idmap) - with Not_found -> - Errors.raise_spanned_error (Pos.get_position f_name) - "This identifier should refer to a field of struct %s" - (Pos.unmark s_name) - in - (match Scopelang.Ast.StructFieldMap.find_opt f_uid s_fields with - | None -> () - | Some e_field -> - Errors.raise_multispanned_error - [ - (None, Pos.get_position f_e); - (None, Pos.get_position (Bindlib.unbox e_field)); - ] - "The field %a has been defined twice:" - Scopelang.Ast.StructFieldName.format_t f_uid); - let f_e = translate_expr scope inside_definition_of ctxt f_e in - Scopelang.Ast.StructFieldMap.add f_uid f_e s_fields) - Scopelang.Ast.StructFieldMap.empty fields - in - let expected_s_fields = Scopelang.Ast.StructMap.find s_uid ctxt.structs in - Scopelang.Ast.StructFieldMap.iter - (fun expected_f _ -> - if not (Scopelang.Ast.StructFieldMap.mem expected_f s_fields) then - Errors.raise_spanned_error pos - "Missing field for structure %a: \"%a\"" - Scopelang.Ast.StructName.format_t s_uid - Scopelang.Ast.StructFieldName.format_t expected_f) - expected_s_fields; - - Bindlib.box_apply - (fun s_fields -> (Desugared.Ast.EStruct (s_uid, s_fields), pos)) - (LiftStructFieldMap.lift_box s_fields) - | EnumInject (enum, constructor, payload) -> ( - let possible_c_uids = - try - Desugared.Ast.IdentMap.find (Pos.unmark constructor) - ctxt.constructor_idmap - with Not_found -> - Errors.raise_spanned_error - (Pos.get_position constructor) - "The name of this constructor has not been defined before, maybe \ - it is a typo?" - in - - match enum with - | None -> - if - (* No constructor name was specified *) - Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 - then - Errors.raise_spanned_error - (Pos.get_position constructor) - "This constructor name is ambiguous, it can belong to %a. \ - Desambiguate it by prefixing it with the enum name." - (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") - (fun fmt (s_name, _) -> - Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t - s_name)) - (Scopelang.Ast.EnumMap.bindings possible_c_uids) - else - let e_uid, c_uid = Scopelang.Ast.EnumMap.choose possible_c_uids in - let payload = - Option.map - (translate_expr scope inside_definition_of ctxt) - payload - in - Bindlib.box_apply - (fun payload -> - ( Desugared.Ast.EEnumInj - ( (match payload with - | Some e' -> e' - | None -> - ( Desugared.Ast.ELit Dcalc.Ast.LUnit, - Pos.get_position constructor )), - c_uid, - e_uid ), - pos )) - (Bindlib.box_opt payload) - | Some enum -> ( - try - (* The path has been fully qualified *) - let e_uid = - Desugared.Ast.IdentMap.find (Pos.unmark enum) ctxt.enum_idmap - in - try - let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in - let payload = - Option.map - (translate_expr scope inside_definition_of ctxt) - payload - in - Bindlib.box_apply - (fun payload -> - ( Desugared.Ast.EEnumInj - ( (match payload with - | Some e' -> e' - | None -> - ( Desugared.Ast.ELit Dcalc.Ast.LUnit, - Pos.get_position constructor )), - c_uid, - e_uid ), - pos )) - (Bindlib.box_opt payload) - with Not_found -> - Errors.raise_spanned_error pos "Enum %s does not contain case %s" - (Pos.unmark enum) (Pos.unmark constructor) - with Not_found -> - Errors.raise_spanned_error (Pos.get_position enum) - "Enum %s has not been defined before" (Pos.unmark enum))) - | MatchWith (e1, (cases, _cases_pos)) -> - let e1 = translate_expr scope inside_definition_of ctxt e1 in - let cases_d, e_uid = - disambiguate_match_and_build_expression scope inside_definition_of ctxt - cases - in - Bindlib.box_apply2 - (fun e1 cases_d -> (Desugared.Ast.EMatch (e1, e_uid, cases_d), pos)) - e1 - (LiftEnumConstructorMap.lift_box cases_d) - | TestMatchCase (e1, pattern) -> - (match snd (Pos.unmark pattern) with - | None -> () - | Some binding -> - Errors.format_spanned_warning (Pos.get_position binding) - "This binding will be ignored (remove it to suppress warning)"); - let enum_uid, c_uid = - disambiguate_constructor ctxt - (fst (Pos.unmark pattern)) - (Pos.get_position pattern) - in - let cases = - Scopelang.Ast.EnumConstructorMap.mapi - (fun c_uid' tau -> + (* This sugar corresponds to [e is P x && e'] and should desugar to [match e + with P x -> e' | _ -> false] *) + let enum_uid, c_uid = + disambiguate_constructor ctxt constructors pos_pattern + in + let cases = + Scopelang.Ast.EnumConstructorMap.mapi + (fun c_uid' tau -> + if Scopelang.Ast.EnumConstructor.compare c_uid c_uid' <> 0 then let nop_var = Desugared.Ast.Var.make ("_", pos) in Bindlib.unbox (Desugared.Ast.make_abs [| nop_var |] - (Bindlib.box - ( Desugared.Ast.ELit - (Dcalc.Ast.LBool - (Scopelang.Ast.EnumConstructor.compare c_uid c_uid' - = 0)), - pos )) - pos [ tau ] pos)) - (Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) + (Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos)) + pos [tau] pos) + else + let ctxt, binding_var = + Name_resolution.add_def_local_var ctxt binding + in + let e2 = translate_expr scope inside_definition_of ctxt e2 in + Bindlib.unbox + (Desugared.Ast.make_abs [| binding_var |] e2 pos [tau] pos)) + (Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) + in + Bindlib.box_apply + (fun e1_sub -> Desugared.Ast.EMatch (e1_sub, enum_uid, cases), pos) + (translate_expr scope inside_definition_of ctxt e1_sub) + | IfThenElse (e_if, e_then, e_else) -> + Bindlib.box_apply3 + (fun e_if e_then e_else -> + Desugared.Ast.EIfThenElse (e_if, e_then, e_else), pos) + (rec_helper e_if) (rec_helper e_then) (rec_helper e_else) + | Binop (op, e1, e2) -> + let op_term = + Pos.same_pos_as + (Desugared.Ast.EOp (Dcalc.Ast.Binop (translate_binop (Pos.unmark op)))) + op + in + Bindlib.box_apply2 + (fun e1 e2 -> Desugared.Ast.EApp (op_term, [e1; e2]), pos) + (rec_helper e1) (rec_helper e2) + | Unop (op, e) -> + let op_term = + Pos.same_pos_as + (Desugared.Ast.EOp (Dcalc.Ast.Unop (translate_unop (Pos.unmark op)))) + op + in + Bindlib.box_apply + (fun e -> Desugared.Ast.EApp (op_term, [e]), pos) + (rec_helper e) + | Literal l -> + let untyped_term = + match l with + | LNumber ((Int i, _), None) -> + Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_string i)) + | LNumber ((Int i, _), Some (Percent, _)) -> + Desugared.Ast.ELit + (Dcalc.Ast.LRat + Runtime.(decimal_of_string i /& decimal_of_string "100")) + | LNumber ((Dec (i, f), _), None) -> + Desugared.Ast.ELit + (Dcalc.Ast.LRat Runtime.(decimal_of_string (i ^ "." ^ f))) + | LNumber ((Dec (i, f), _), Some (Percent, _)) -> + Desugared.Ast.ELit + (Dcalc.Ast.LRat + Runtime.( + decimal_of_string (i ^ "." ^ f) /& decimal_of_string "100")) + | LBool b -> Desugared.Ast.ELit (Dcalc.Ast.LBool b) + | LMoneyAmount i -> + Desugared.Ast.ELit + (Dcalc.Ast.LMoney + Runtime.( + money_of_cents_integer + ((integer_of_string i.money_amount_units *! integer_of_int 100) + +! integer_of_string i.money_amount_cents))) + | LNumber ((Int i, _), Some (Year, _)) -> + Desugared.Ast.ELit + (Dcalc.Ast.LDuration + (Runtime.duration_of_numbers (int_of_string i) 0 0)) + | LNumber ((Int i, _), Some (Month, _)) -> + Desugared.Ast.ELit + (Dcalc.Ast.LDuration + (Runtime.duration_of_numbers 0 (int_of_string i) 0)) + | LNumber ((Int i, _), Some (Day, _)) -> + Desugared.Ast.ELit + (Dcalc.Ast.LDuration + (Runtime.duration_of_numbers 0 0 (int_of_string i))) + | LNumber ((Dec (_, _), _), Some ((Year | Month | Day), _)) -> + Errors.raise_spanned_error pos + "Impossible to specify decimal amounts of days, months or years" + | LDate date -> + if Pos.unmark date.literal_date_month > 12 then + Errors.raise_spanned_error + (Pos.get_position date.literal_date_month) + "There is an error in this date: the month number is bigger than 12"; + if Pos.unmark date.literal_date_day > 31 then + Errors.raise_spanned_error + (Pos.get_position date.literal_date_day) + "There is an error in this date: the day number is bigger than 31"; + Desugared.Ast.ELit + (Dcalc.Ast.LDate + (try + Runtime.date_of_numbers + (Pos.unmark date.literal_date_year) + (Pos.unmark date.literal_date_month) + (Pos.unmark date.literal_date_day) + with Runtime.ImpossibleDate -> + Errors.raise_spanned_error pos + "There is an error in this date, it does not correspond to a \ + correct calendar day")) + in + Bindlib.box (untyped_term, pos) + | Ident x -> ( + (* first we check whether this is a local var, then we resort to scope-wide + variables *) + match Desugared.Ast.IdentMap.find_opt x ctxt.local_var_idmap with + | None -> ( + match Desugared.Ast.IdentMap.find_opt x scope_ctxt.var_idmap with + | Some uid -> + (* If the referenced variable has states, then here are the rules to + desambiguate. In general, only the last state can be referenced. + Except if defining a state of the same variable, then it references + the previous state in the chain. *) + let x_sig = Desugared.Ast.ScopeVarMap.find uid ctxt.var_typs in + let x_state = + match x_sig.var_sig_states_list with + | [] -> None + | states -> ( + match inside_definition_of with + | Some (Desugared.Ast.ScopeDef.Var (x'_uid, sx'), _) + when Desugared.Ast.ScopeVar.compare uid x'_uid = 0 -> ( + match sx' with + | None -> + failwith + "inconsistent state: inside a definition of a variable with \ + no state but variable has states" + | Some inside_def_state -> + if + Desugared.Ast.StateName.compare inside_def_state + (List.hd states) + = 0 + then + Errors.raise_spanned_error pos + "It is impossible to refer to the variable you are \ + defining when defining its first state." + else + (* Tricky: we have to retrieve in the list the previous state + with respect to the state that we are defining. *) + let correct_state = ref None in + ignore + (List.fold_left + (fun previous_state state -> + if + Desugared.Ast.StateName.compare inside_def_state + state + = 0 + then correct_state := previous_state; + Some state) + None states); + !correct_state) + | _ -> + (* we take the last state in the chain *) + Some (List.hd (List.rev states))) + in + Bindlib.box + (Desugared.Ast.ELocation (ScopeVar ((uid, pos), x_state)), pos) + | None -> + Name_resolution.raise_unknown_identifier + "for a local or scope-wide variable" (x, pos)) + | Some uid -> + Desugared.Ast.make_var (uid, pos) + (* the whole box thing is to accomodate for this case *)) + | Dotted (e, c, x) -> ( + match Pos.unmark e with + | Ident y when Name_resolution.is_subscope_uid scope ctxt y -> + (* In this case, y.x is a subscope variable *) + let subscope_uid : Scopelang.Ast.SubScopeName.t = + Name_resolution.get_subscope_uid scope ctxt (Pos.same_pos_as y e) in - Bindlib.box_apply - (fun e -> (Desugared.Ast.EMatch (e, enum_uid, cases), pos)) - (translate_expr scope inside_definition_of ctxt e1) + let subscope_real_uid : Scopelang.Ast.ScopeName.t = + Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes + in + let subscope_var_uid = + Name_resolution.get_var_uid subscope_real_uid ctxt x + in + Bindlib.box + ( Desugared.Ast.ELocation + (SubScopeVar + (subscope_real_uid, (subscope_uid, pos), (subscope_var_uid, pos))), + pos ) + | _ -> ( + (* In this case e.x is the struct field x access of expression e *) + let e = translate_expr scope inside_definition_of ctxt e in + let x_possible_structs = + try Desugared.Ast.IdentMap.find (Pos.unmark x) ctxt.field_idmap + with Not_found -> + Errors.raise_spanned_error (Pos.get_position x) + "Unknown subscope or struct field name" + in + match c with + | None -> + (* No constructor name was specified *) + if Scopelang.Ast.StructMap.cardinal x_possible_structs > 1 then + Errors.raise_spanned_error (Pos.get_position x) + "This struct field name is ambiguous, it can belong to %a. \ + Disambiguate it by prefixing it with the struct name." + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") + (fun fmt (s_name, _) -> + Format.fprintf fmt "%a" Scopelang.Ast.StructName.format_t + s_name)) + (Scopelang.Ast.StructMap.bindings x_possible_structs) + else + let s_uid, f_uid = + Scopelang.Ast.StructMap.choose x_possible_structs + in + Bindlib.box_apply + (fun e -> Desugared.Ast.EStructAccess (e, f_uid, s_uid), pos) + e + | Some c_name -> ( + try + let c_uid = + Desugared.Ast.IdentMap.find (Pos.unmark c_name) ctxt.struct_idmap + in + try + let f_uid = Scopelang.Ast.StructMap.find c_uid x_possible_structs in + Bindlib.box_apply + (fun e -> Desugared.Ast.EStructAccess (e, f_uid, c_uid), pos) + e + with Not_found -> + Errors.raise_spanned_error pos "Struct %s does not contain field %s" + (Pos.unmark c_name) (Pos.unmark x) + with Not_found -> + Errors.raise_spanned_error (Pos.get_position c_name) + "Struct %s has not been defined before" (Pos.unmark c_name)))) + | FunCall (f, arg) -> + Bindlib.box_apply2 + (fun f arg -> Desugared.Ast.EApp (f, [arg]), pos) + (rec_helper f) (rec_helper arg) + | StructLit (s_name, fields) -> + let s_uid = + try Desugared.Ast.IdentMap.find (Pos.unmark s_name) ctxt.struct_idmap + with Not_found -> + Errors.raise_spanned_error (Pos.get_position s_name) + "This identifier should refer to a struct name" + in + + let s_fields = + List.fold_left + (fun s_fields (f_name, f_e) -> + let f_uid = + try + Scopelang.Ast.StructMap.find s_uid + (Desugared.Ast.IdentMap.find (Pos.unmark f_name) + ctxt.field_idmap) + with Not_found -> + Errors.raise_spanned_error (Pos.get_position f_name) + "This identifier should refer to a field of struct %s" + (Pos.unmark s_name) + in + (match Scopelang.Ast.StructFieldMap.find_opt f_uid s_fields with + | None -> () + | Some e_field -> + Errors.raise_multispanned_error + [ + None, Pos.get_position f_e; + None, Pos.get_position (Bindlib.unbox e_field); + ] + "The field %a has been defined twice:" + Scopelang.Ast.StructFieldName.format_t f_uid); + let f_e = translate_expr scope inside_definition_of ctxt f_e in + Scopelang.Ast.StructFieldMap.add f_uid f_e s_fields) + Scopelang.Ast.StructFieldMap.empty fields + in + let expected_s_fields = Scopelang.Ast.StructMap.find s_uid ctxt.structs in + Scopelang.Ast.StructFieldMap.iter + (fun expected_f _ -> + if not (Scopelang.Ast.StructFieldMap.mem expected_f s_fields) then + Errors.raise_spanned_error pos + "Missing field for structure %a: \"%a\"" + Scopelang.Ast.StructName.format_t s_uid + Scopelang.Ast.StructFieldName.format_t expected_f) + expected_s_fields; + + Bindlib.box_apply + (fun s_fields -> Desugared.Ast.EStruct (s_uid, s_fields), pos) + (LiftStructFieldMap.lift_box s_fields) + | EnumInject (enum, constructor, payload) -> ( + let possible_c_uids = + try + Desugared.Ast.IdentMap.find (Pos.unmark constructor) + ctxt.constructor_idmap + with Not_found -> + Errors.raise_spanned_error + (Pos.get_position constructor) + "The name of this constructor has not been defined before, maybe it \ + is a typo?" + in + + match enum with + | None -> + if + (* No constructor name was specified *) + Scopelang.Ast.EnumMap.cardinal possible_c_uids > 1 + then + Errors.raise_spanned_error + (Pos.get_position constructor) + "This constructor name is ambiguous, it can belong to %a. \ + Desambiguate it by prefixing it with the enum name." + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") + (fun fmt (s_name, _) -> + Format.fprintf fmt "%a" Scopelang.Ast.EnumName.format_t s_name)) + (Scopelang.Ast.EnumMap.bindings possible_c_uids) + else + let e_uid, c_uid = Scopelang.Ast.EnumMap.choose possible_c_uids in + let payload = + Option.map (translate_expr scope inside_definition_of ctxt) payload + in + Bindlib.box_apply + (fun payload -> + ( Desugared.Ast.EEnumInj + ( (match payload with + | Some e' -> e' + | None -> + ( Desugared.Ast.ELit Dcalc.Ast.LUnit, + Pos.get_position constructor )), + c_uid, + e_uid ), + pos )) + (Bindlib.box_opt payload) + | Some enum -> ( + try + (* The path has been fully qualified *) + let e_uid = + Desugared.Ast.IdentMap.find (Pos.unmark enum) ctxt.enum_idmap + in + try + let c_uid = Scopelang.Ast.EnumMap.find e_uid possible_c_uids in + let payload = + Option.map (translate_expr scope inside_definition_of ctxt) payload + in + Bindlib.box_apply + (fun payload -> + ( Desugared.Ast.EEnumInj + ( (match payload with + | Some e' -> e' + | None -> + ( Desugared.Ast.ELit Dcalc.Ast.LUnit, + Pos.get_position constructor )), + c_uid, + e_uid ), + pos )) + (Bindlib.box_opt payload) + with Not_found -> + Errors.raise_spanned_error pos "Enum %s does not contain case %s" + (Pos.unmark enum) (Pos.unmark constructor) + with Not_found -> + Errors.raise_spanned_error (Pos.get_position enum) + "Enum %s has not been defined before" (Pos.unmark enum))) + | MatchWith (e1, (cases, _cases_pos)) -> + let e1 = translate_expr scope inside_definition_of ctxt e1 in + let cases_d, e_uid = + disambiguate_match_and_build_expression scope inside_definition_of ctxt + cases + in + Bindlib.box_apply2 + (fun e1 cases_d -> Desugared.Ast.EMatch (e1, e_uid, cases_d), pos) + e1 + (LiftEnumConstructorMap.lift_box cases_d) + | TestMatchCase (e1, pattern) -> + (match snd (Pos.unmark pattern) with + | None -> () + | Some binding -> + Errors.format_spanned_warning (Pos.get_position binding) + "This binding will be ignored (remove it to suppress warning)"); + let enum_uid, c_uid = + disambiguate_constructor ctxt + (fst (Pos.unmark pattern)) + (Pos.get_position pattern) + in + let cases = + Scopelang.Ast.EnumConstructorMap.mapi + (fun c_uid' tau -> + let nop_var = Desugared.Ast.Var.make ("_", pos) in + Bindlib.unbox + (Desugared.Ast.make_abs [| nop_var |] + (Bindlib.box + ( Desugared.Ast.ELit + (Dcalc.Ast.LBool + (Scopelang.Ast.EnumConstructor.compare c_uid c_uid' = 0)), + pos )) + pos [tau] pos)) + (Scopelang.Ast.EnumMap.find enum_uid ctxt.enums) + in + Bindlib.box_apply + (fun e -> Desugared.Ast.EMatch (e, enum_uid, cases), pos) + (translate_expr scope inside_definition_of ctxt e1) | ArrayLit es -> - Bindlib.box_apply - (fun es -> (Desugared.Ast.EArray es, pos)) - (Bindlib.box_list (List.map rec_helper es)) + Bindlib.box_apply + (fun es -> Desugared.Ast.EArray es, pos) + (Bindlib.box_list (List.map rec_helper es)) | CollectionOp ( (((Ast.Filter | Ast.Map) as op'), _pos_op'), param', collection, predicate ) -> - let collection = rec_helper collection in - let ctxt, param = Name_resolution.add_def_local_var ctxt param' in - let f_pred = - Desugared.Ast.make_abs [| param |] - (translate_expr scope inside_definition_of ctxt predicate) - pos - [ (Scopelang.Ast.TAny, pos) ] - pos - in - Bindlib.box_apply2 - (fun f_pred collection -> - ( Desugared.Ast.EApp - ( ( Desugared.Ast.EOp - (match op' with - | Ast.Map -> Dcalc.Ast.Binop Dcalc.Ast.Map - | Ast.Filter -> Dcalc.Ast.Binop Dcalc.Ast.Filter - | _ -> assert false (* should not happen *)), - pos ), - [ f_pred; collection ] ), - pos )) - f_pred collection + let collection = rec_helper collection in + let ctxt, param = Name_resolution.add_def_local_var ctxt param' in + let f_pred = + Desugared.Ast.make_abs [| param |] + (translate_expr scope inside_definition_of ctxt predicate) + pos + [Scopelang.Ast.TAny, pos] + pos + in + Bindlib.box_apply2 + (fun f_pred collection -> + ( Desugared.Ast.EApp + ( ( Desugared.Ast.EOp + (match op' with + | Ast.Map -> Dcalc.Ast.Binop Dcalc.Ast.Map + | Ast.Filter -> Dcalc.Ast.Binop Dcalc.Ast.Filter + | _ -> assert false (* should not happen *)), + pos ), + [f_pred; collection] ), + pos )) + f_pred collection | CollectionOp ( ( Ast.Aggregate (Ast.AggregateArgExtremum (max_or_min, pred_typ, init)), pos_op' ), param', collection, predicate ) -> - let init = rec_helper init in - let collection = rec_helper collection in - let ctxt, param = Name_resolution.add_def_local_var ctxt param' in - let op_kind = - match pred_typ with - | Ast.Integer -> Dcalc.Ast.KInt - | Ast.Decimal -> Dcalc.Ast.KRat - | Ast.Money -> Dcalc.Ast.KMoney - | Ast.Duration -> Dcalc.Ast.KDuration - | Ast.Date -> Dcalc.Ast.KDate - | _ -> - Errors.raise_spanned_error pos - "It is impossible to compute the arg-%s of two values of type %a" - (if max_or_min then "max" else "min") - Print.format_primitive_typ pred_typ - in - let cmp_op = - if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind - in - let f_pred = - Desugared.Ast.make_abs [| param |] - (translate_expr scope inside_definition_of ctxt predicate) - pos - [ (Scopelang.Ast.TAny, pos) ] - pos - in - let f_pred_var = - Desugared.Ast.Var.make ("predicate", Pos.get_position predicate) - in - let f_pred_var_e = - Desugared.Ast.make_var (f_pred_var, Pos.get_position predicate) - in - let acc_var = Desugared.Ast.Var.make ("acc", pos) in - let acc_var_e = Desugared.Ast.make_var (acc_var, pos) in - let item_var = - Desugared.Ast.Var.make - ("item", Pos.get_position (Bindlib.unbox collection)) - in - let item_var_e = - Desugared.Ast.make_var - (item_var, Pos.get_position (Bindlib.unbox collection)) - in - let fold_body = - Bindlib.box_apply3 - (fun acc_var_e item_var_e f_pred_var_e -> - ( Desugared.Ast.EIfThenElse - ( ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), pos_op'), - [ - (Desugared.Ast.EApp (f_pred_var_e, [ acc_var_e ]), pos); - ( Desugared.Ast.EApp (f_pred_var_e, [ item_var_e ]), - pos ); - ] ), - pos ), - acc_var_e, - item_var_e ), - pos )) - acc_var_e item_var_e f_pred_var_e - in - let fold_f = - Desugared.Ast.make_abs [| acc_var; item_var |] fold_body pos - [ (Scopelang.Ast.TAny, pos); (Scopelang.Ast.TAny, pos) ] - pos - in - let fold = - Bindlib.box_apply3 - (fun fold_f collection init -> - ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), - [ fold_f; init; collection ] ), - pos )) - fold_f collection init - in - Desugared.Ast.make_let_in f_pred_var (Scopelang.Ast.TAny, pos) f_pred fold - | CollectionOp (op', param', collection, predicate) -> - let ctxt, param = Name_resolution.add_def_local_var ctxt param' in - let collection = rec_helper collection in - let init = - match Pos.unmark op' with - | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> - assert false (* should not happen *) - | Ast.Exists -> - Bindlib.box - (Desugared.Ast.ELit (Dcalc.Ast.LBool false), Pos.get_position op') - | Ast.Forall -> - Bindlib.box - (Desugared.Ast.ELit (Dcalc.Ast.LBool true), Pos.get_position op') - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> - Bindlib.box - ( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), - Pos.get_position op' ) - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> - Bindlib.box - ( Desugared.Ast.ELit - (Dcalc.Ast.LRat (Runtime.decimal_of_string "0")), - Pos.get_position op' ) - | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> - Bindlib.box - ( Desugared.Ast.ELit - (Dcalc.Ast.LMoney - (Runtime.money_of_cents_integer (Runtime.integer_of_int 0))), - Pos.get_position op' ) - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> - Bindlib.box - ( Desugared.Ast.ELit - (Dcalc.Ast.LDuration (Runtime.duration_of_numbers 0 0 0)), - Pos.get_position op' ) - | Ast.Aggregate (Ast.AggregateSum t) -> - Errors.raise_spanned_error pos - "It is impossible to sum two values of type %a together" - Print.format_primitive_typ t - | Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init - | Ast.Aggregate Ast.AggregateCount -> - Bindlib.box - ( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), - Pos.get_position op' ) - in - let acc_var = Desugared.Ast.Var.make ("acc", Pos.get_position param') in - let acc = Desugared.Ast.make_var (acc_var, Pos.get_position param') in - let f_body = - let make_body (op : Dcalc.Ast.binop) = - Bindlib.box_apply2 - (fun predicate acc -> - ( Desugared.Ast.EApp - ( ( Desugared.Ast.EOp (Dcalc.Ast.Binop op), - Pos.get_position op' ), - [ acc; predicate ] ), - pos )) - (translate_expr scope inside_definition_of ctxt predicate) - acc - in - let make_extr_body - (cmp_op : Dcalc.Ast.binop) (t : Scopelang.Ast.typ Pos.marked) = - let tmp_var = - Desugared.Ast.Var.make ("tmp", Pos.get_position param') - in - let tmp = Desugared.Ast.make_var (tmp_var, Pos.get_position param') in - Desugared.Ast.make_let_in tmp_var t - (translate_expr scope inside_definition_of ctxt predicate) - (Bindlib.box_apply2 - (fun acc tmp -> - ( Desugared.Ast.EIfThenElse - ( ( Desugared.Ast.EApp - ( ( Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), - Pos.get_position op' ), - [ acc; tmp ] ), - pos ), - acc, - tmp ), - pos )) - acc tmp) - in - match Pos.unmark op' with - | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> - assert false (* should not happen *) - | Ast.Exists -> make_body Dcalc.Ast.Or - | Ast.Forall -> make_body Dcalc.Ast.And - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KInt) - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KRat) - | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KMoney) - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> - make_body (Dcalc.Ast.Add Dcalc.Ast.KDuration) - | Ast.Aggregate (Ast.AggregateSum _) -> - assert false (* should not happen *) - | Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) -> - let op_kind, typ = - match t with - | Ast.Integer -> (Dcalc.Ast.KInt, (Scopelang.Ast.TLit TInt, pos)) - | Ast.Decimal -> (Dcalc.Ast.KRat, (Scopelang.Ast.TLit TRat, pos)) - | Ast.Money -> (Dcalc.Ast.KMoney, (Scopelang.Ast.TLit TMoney, pos)) - | Ast.Duration -> - (Dcalc.Ast.KDuration, (Scopelang.Ast.TLit TDuration, pos)) - | Ast.Date -> (Dcalc.Ast.KDate, (Scopelang.Ast.TLit TDate, pos)) - | _ -> - Errors.raise_spanned_error pos - "It is impossible to compute the %s of two values of type \ - %a" - (if max_or_min then "max" else "min") - Print.format_primitive_typ t - in - let cmp_op = - if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind - in - make_extr_body cmp_op typ - | Ast.Aggregate Ast.AggregateCount -> - Bindlib.box_apply2 - (fun predicate acc -> - ( Desugared.Ast.EIfThenElse - ( predicate, - ( Desugared.Ast.EApp - ( ( Desugared.Ast.EOp - (Dcalc.Ast.Binop (Dcalc.Ast.Add Dcalc.Ast.KInt)), - Pos.get_position op' ), - [ - acc; - ( Desugared.Ast.ELit - (Dcalc.Ast.LInt (Runtime.integer_of_int 1)), - Pos.get_position predicate ); - ] ), - pos ), - acc ), - pos )) - (translate_expr scope inside_definition_of ctxt predicate) - acc - in - let f = - let make_f (t : Dcalc.Ast.typ_lit) = - Bindlib.box_apply - (fun binder -> - ( Desugared.Ast.EAbs - ( (binder, pos), - [ - (Scopelang.Ast.TLit t, Pos.get_position op'); - (Scopelang.Ast.TAny, pos) - (* we put any here because the type of the elements of the - arrays is not always the type of the accumulator; for - instance in AggregateCount. *); - ] ), - pos )) - (Bindlib.bind_mvar [| acc_var; param |] f_body) - in - match Pos.unmark op' with - | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> - assert false (* should not happen *) - | Ast.Exists -> make_f Dcalc.Ast.TBool - | Ast.Forall -> make_f Dcalc.Ast.TBool - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) -> - make_f Dcalc.Ast.TInt - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) -> - make_f Dcalc.Ast.TRat - | Ast.Aggregate (Ast.AggregateSum Ast.Money) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) -> - make_f Dcalc.Ast.TMoney - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) -> - make_f Dcalc.Ast.TDuration - | Ast.Aggregate (Ast.AggregateSum _) - | Ast.Aggregate (Ast.AggregateExtremum _) -> - assert false (* should not happen *) - | Ast.Aggregate Ast.AggregateCount -> make_f Dcalc.Ast.TInt - in + let init = rec_helper init in + let collection = rec_helper collection in + let ctxt, param = Name_resolution.add_def_local_var ctxt param' in + let op_kind = + match pred_typ with + | Ast.Integer -> Dcalc.Ast.KInt + | Ast.Decimal -> Dcalc.Ast.KRat + | Ast.Money -> Dcalc.Ast.KMoney + | Ast.Duration -> Dcalc.Ast.KDuration + | Ast.Date -> Dcalc.Ast.KDate + | _ -> + Errors.raise_spanned_error pos + "It is impossible to compute the arg-%s of two values of type %a" + (if max_or_min then "max" else "min") + Print.format_primitive_typ pred_typ + in + let cmp_op = + if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind + in + let f_pred = + Desugared.Ast.make_abs [| param |] + (translate_expr scope inside_definition_of ctxt predicate) + pos + [Scopelang.Ast.TAny, pos] + pos + in + let f_pred_var = + Desugared.Ast.Var.make ("predicate", Pos.get_position predicate) + in + let f_pred_var_e = + Desugared.Ast.make_var (f_pred_var, Pos.get_position predicate) + in + let acc_var = Desugared.Ast.Var.make ("acc", pos) in + let acc_var_e = Desugared.Ast.make_var (acc_var, pos) in + let item_var = + Desugared.Ast.Var.make + ("item", Pos.get_position (Bindlib.unbox collection)) + in + let item_var_e = + Desugared.Ast.make_var + (item_var, Pos.get_position (Bindlib.unbox collection)) + in + let fold_body = Bindlib.box_apply3 - (fun f collection init -> + (fun acc_var_e item_var_e f_pred_var_e -> + ( Desugared.Ast.EIfThenElse + ( ( Desugared.Ast.EApp + ( (Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), pos_op'), + [ + Desugared.Ast.EApp (f_pred_var_e, [acc_var_e]), pos; + Desugared.Ast.EApp (f_pred_var_e, [item_var_e]), pos; + ] ), + pos ), + acc_var_e, + item_var_e ), + pos )) + acc_var_e item_var_e f_pred_var_e + in + let fold_f = + Desugared.Ast.make_abs [| acc_var; item_var |] fold_body pos + [Scopelang.Ast.TAny, pos; Scopelang.Ast.TAny, pos] + pos + in + let fold = + Bindlib.box_apply3 + (fun fold_f collection init -> ( Desugared.Ast.EApp ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), - [ f; init; collection ] ), + [fold_f; init; collection] ), pos )) - f collection init - | MemCollection (member, collection) -> - let param_var = Desugared.Ast.Var.make ("collection_member", pos) in - let param = Desugared.Ast.make_var (param_var, pos) in - let collection = rec_helper collection in - let init = - Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos) - in - let acc_var = Desugared.Ast.Var.make ("acc", pos) in - let acc = Desugared.Ast.make_var (acc_var, pos) in - let f_body = - Bindlib.box_apply3 - (fun member acc param -> + fold_f collection init + in + Desugared.Ast.make_let_in f_pred_var (Scopelang.Ast.TAny, pos) f_pred fold + | CollectionOp (op', param', collection, predicate) -> + let ctxt, param = Name_resolution.add_def_local_var ctxt param' in + let collection = rec_helper collection in + let init = + match Pos.unmark op' with + | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> + assert false (* should not happen *) + | Ast.Exists -> + Bindlib.box + (Desugared.Ast.ELit (Dcalc.Ast.LBool false), Pos.get_position op') + | Ast.Forall -> + Bindlib.box + (Desugared.Ast.ELit (Dcalc.Ast.LBool true), Pos.get_position op') + | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> + Bindlib.box + ( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), + Pos.get_position op' ) + | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> + Bindlib.box + ( Desugared.Ast.ELit (Dcalc.Ast.LRat (Runtime.decimal_of_string "0")), + Pos.get_position op' ) + | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> + Bindlib.box + ( Desugared.Ast.ELit + (Dcalc.Ast.LMoney + (Runtime.money_of_cents_integer (Runtime.integer_of_int 0))), + Pos.get_position op' ) + | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> + Bindlib.box + ( Desugared.Ast.ELit + (Dcalc.Ast.LDuration (Runtime.duration_of_numbers 0 0 0)), + Pos.get_position op' ) + | Ast.Aggregate (Ast.AggregateSum t) -> + Errors.raise_spanned_error pos + "It is impossible to sum two values of type %a together" + Print.format_primitive_typ t + | Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init + | Ast.Aggregate Ast.AggregateCount -> + Bindlib.box + ( Desugared.Ast.ELit (Dcalc.Ast.LInt (Runtime.integer_of_int 0)), + Pos.get_position op' ) + in + let acc_var = Desugared.Ast.Var.make ("acc", Pos.get_position param') in + let acc = Desugared.Ast.make_var (acc_var, Pos.get_position param') in + let f_body = + let make_body (op : Dcalc.Ast.binop) = + Bindlib.box_apply2 + (fun predicate acc -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Or), pos), - [ - ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Eq), pos), - [ member; param ] ), - pos ); - acc; - ] ), + ( (Desugared.Ast.EOp (Dcalc.Ast.Binop op), Pos.get_position op'), + [acc; predicate] ), pos )) - (translate_expr scope inside_definition_of ctxt member) - acc param + (translate_expr scope inside_definition_of ctxt predicate) + acc in - let f = + let make_extr_body + (cmp_op : Dcalc.Ast.binop) + (t : Scopelang.Ast.typ Pos.marked) = + let tmp_var = Desugared.Ast.Var.make ("tmp", Pos.get_position param') in + let tmp = Desugared.Ast.make_var (tmp_var, Pos.get_position param') in + Desugared.Ast.make_let_in tmp_var t + (translate_expr scope inside_definition_of ctxt predicate) + (Bindlib.box_apply2 + (fun acc tmp -> + ( Desugared.Ast.EIfThenElse + ( ( Desugared.Ast.EApp + ( ( Desugared.Ast.EOp (Dcalc.Ast.Binop cmp_op), + Pos.get_position op' ), + [acc; tmp] ), + pos ), + acc, + tmp ), + pos )) + acc tmp) + in + match Pos.unmark op' with + | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> + assert false (* should not happen *) + | Ast.Exists -> make_body Dcalc.Ast.Or + | Ast.Forall -> make_body Dcalc.Ast.And + | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> + make_body (Dcalc.Ast.Add Dcalc.Ast.KInt) + | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> + make_body (Dcalc.Ast.Add Dcalc.Ast.KRat) + | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> + make_body (Dcalc.Ast.Add Dcalc.Ast.KMoney) + | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> + make_body (Dcalc.Ast.Add Dcalc.Ast.KDuration) + | Ast.Aggregate (Ast.AggregateSum _) -> + assert false (* should not happen *) + | Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) -> + let op_kind, typ = + match t with + | Ast.Integer -> Dcalc.Ast.KInt, (Scopelang.Ast.TLit TInt, pos) + | Ast.Decimal -> Dcalc.Ast.KRat, (Scopelang.Ast.TLit TRat, pos) + | Ast.Money -> Dcalc.Ast.KMoney, (Scopelang.Ast.TLit TMoney, pos) + | Ast.Duration -> + Dcalc.Ast.KDuration, (Scopelang.Ast.TLit TDuration, pos) + | Ast.Date -> Dcalc.Ast.KDate, (Scopelang.Ast.TLit TDate, pos) + | _ -> + Errors.raise_spanned_error pos + "It is impossible to compute the %s of two values of type %a" + (if max_or_min then "max" else "min") + Print.format_primitive_typ t + in + let cmp_op = + if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind + in + make_extr_body cmp_op typ + | Ast.Aggregate Ast.AggregateCount -> + Bindlib.box_apply2 + (fun predicate acc -> + ( Desugared.Ast.EIfThenElse + ( predicate, + ( Desugared.Ast.EApp + ( ( Desugared.Ast.EOp + (Dcalc.Ast.Binop (Dcalc.Ast.Add Dcalc.Ast.KInt)), + Pos.get_position op' ), + [ + acc; + ( Desugared.Ast.ELit + (Dcalc.Ast.LInt (Runtime.integer_of_int 1)), + Pos.get_position predicate ); + ] ), + pos ), + acc ), + pos )) + (translate_expr scope inside_definition_of ctxt predicate) + acc + in + let f = + let make_f (t : Dcalc.Ast.typ_lit) = Bindlib.box_apply (fun binder -> ( Desugared.Ast.EAbs ( (binder, pos), [ - (Scopelang.Ast.TLit Dcalc.Ast.TBool, pos); - (Scopelang.Ast.TAny, pos); + Scopelang.Ast.TLit t, Pos.get_position op'; + Scopelang.Ast.TAny, pos + (* we put any here because the type of the elements of the + arrays is not always the type of the accumulator; for + instance in AggregateCount. *); ] ), pos )) - (Bindlib.bind_mvar [| acc_var; param_var |] f_body) + (Bindlib.bind_mvar [| acc_var; param |] f_body) in + match Pos.unmark op' with + | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> + assert false (* should not happen *) + | Ast.Exists -> make_f Dcalc.Ast.TBool + | Ast.Forall -> make_f Dcalc.Ast.TBool + | Ast.Aggregate (Ast.AggregateSum Ast.Integer) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) -> + make_f Dcalc.Ast.TInt + | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) -> + make_f Dcalc.Ast.TRat + | Ast.Aggregate (Ast.AggregateSum Ast.Money) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) -> + make_f Dcalc.Ast.TMoney + | Ast.Aggregate (Ast.AggregateSum Ast.Duration) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) -> + make_f Dcalc.Ast.TDuration + | Ast.Aggregate (Ast.AggregateSum _) + | Ast.Aggregate (Ast.AggregateExtremum _) -> + assert false (* should not happen *) + | Ast.Aggregate Ast.AggregateCount -> make_f Dcalc.Ast.TInt + in + Bindlib.box_apply3 + (fun f collection init -> + ( Desugared.Ast.EApp + ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), + [f; init; collection] ), + pos )) + f collection init + | MemCollection (member, collection) -> + let param_var = Desugared.Ast.Var.make ("collection_member", pos) in + let param = Desugared.Ast.make_var (param_var, pos) in + let collection = rec_helper collection in + let init = Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool false), pos) in + let acc_var = Desugared.Ast.Var.make ("acc", pos) in + let acc = Desugared.Ast.make_var (acc_var, pos) in + let f_body = Bindlib.box_apply3 - (fun f collection init -> + (fun member acc param -> ( Desugared.Ast.EApp - ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), - [ f; init; collection ] ), + ( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Or), pos), + [ + ( Desugared.Ast.EApp + ( (Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.Eq), pos), + [member; param] ), + pos ); + acc; + ] ), pos )) - f collection init + (translate_expr scope inside_definition_of ctxt member) + acc param + in + let f = + Bindlib.box_apply + (fun binder -> + ( Desugared.Ast.EAbs + ( (binder, pos), + [ + Scopelang.Ast.TLit Dcalc.Ast.TBool, pos; + Scopelang.Ast.TAny, pos; + ] ), + pos )) + (Bindlib.bind_mvar [| acc_var; param_var |] f_body) + in + Bindlib.box_apply3 + (fun f collection init -> + ( Desugared.Ast.EApp + ( (Desugared.Ast.EOp (Dcalc.Ast.Ternop Dcalc.Ast.Fold), pos), + [f; init; collection] ), + pos )) + f collection init | Builtin IntToDec -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.IntToRat), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.IntToRat), pos) | Builtin Cardinal -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.Length), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.Length), pos) | Builtin GetDay -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetDay), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetDay), pos) | Builtin GetMonth -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetMonth), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetMonth), pos) | Builtin GetYear -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetYear), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.GetYear), pos) | Builtin RoundMoney -> - Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundMoney), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundMoney), pos) | Builtin RoundDecimal -> - Bindlib.box - (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundDecimal), pos) + Bindlib.box (Desugared.Ast.EOp (Dcalc.Ast.Unop Dcalc.Ast.RoundDecimal), pos) and disambiguate_match_and_build_expression (scope : Scopelang.Ast.ScopeName.t) @@ -901,10 +874,10 @@ and disambiguate_match_and_build_expression Desugared.Ast.expr Pos.marked Bindlib.box Scopelang.Ast.EnumConstructorMap.t * Scopelang.Ast.EnumName.t = let create_var = function - | None -> (ctxt, (Desugared.Ast.Var.make ("_", Pos.no_pos), Pos.no_pos)) + | None -> ctxt, (Desugared.Ast.Var.make ("_", Pos.no_pos), Pos.no_pos) | Some param -> - let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in - (ctxt, (param_var, Pos.get_position param)) + let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in + ctxt, (param_var, Pos.get_position param) in let bind_case_body (c_uid : Dcalc.Ast.EnumConstructor.t) @@ -930,122 +903,117 @@ and disambiguate_match_and_build_expression let bind_match_cases (cases_d, e_uid, curr_index) (case, case_pos) = match case with | Ast.MatchCase case -> - let constructor, binding = Pos.unmark case.Ast.match_case_pattern in - let e_uid', c_uid = - disambiguate_constructor ctxt constructor - (Pos.get_position case.Ast.match_case_pattern) + let constructor, binding = Pos.unmark case.Ast.match_case_pattern in + let e_uid', c_uid = + disambiguate_constructor ctxt constructor + (Pos.get_position case.Ast.match_case_pattern) + in + let e_uid = + match e_uid with + | None -> e_uid' + | Some e_uid -> + if e_uid = e_uid' then e_uid + else + Errors.raise_spanned_error + (Pos.get_position case.Ast.match_case_pattern) + "This case matches a constructor of enumeration %a but previous \ + case were matching constructors of enumeration %a" + Scopelang.Ast.EnumName.format_t e_uid + Scopelang.Ast.EnumName.format_t e_uid' + in + (match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with + | None -> () + | Some e_case -> + Errors.raise_multispanned_error + [ + None, Pos.get_position case.match_case_expr; + None, Pos.get_position (Bindlib.unbox e_case); + ] + "The constructor %a has been matched twice:" + Scopelang.Ast.EnumConstructor.format_t c_uid); + let ctxt, (param_var, param_pos) = create_var binding in + let case_body = + translate_expr scope inside_definition_of ctxt case.Ast.match_case_expr + in + let e_binder = Bindlib.bind_mvar (Array.of_list [param_var]) case_body in + let case_expr = + bind_case_body c_uid e_uid ctxt param_pos case_body e_binder + in + ( Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, + Some e_uid, + curr_index + 1 ) + | Ast.WildCard match_case_expr -> ( + let nb_cases = List.length cases in + let raise_wildcard_not_last_case_err () = + Errors.raise_multispanned_error + [ + Some "Not ending wildcard:", case_pos; + ( Some "Next reachable case:", + curr_index + 1 |> List.nth cases |> Pos.get_position ); + ] + "Wildcard must be the last match case" + in + match e_uid with + | None -> + if 1 = nb_cases then + Errors.raise_spanned_error case_pos + "Couldn't infer the enumeration name from lonely wildcard \ + (wildcard cannot be used as single match case)" + else raise_wildcard_not_last_case_err () + | Some e_uid -> + if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); + let missing_constructors = + Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums + |> Scopelang.Ast.EnumConstructorMap.filter_map (fun c_uid _ -> + match + Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d + with + | Some _ -> None + | None -> Some c_uid) in - let e_uid = - match e_uid with - | None -> e_uid' - | Some e_uid -> - if e_uid = e_uid' then e_uid - else - Errors.raise_spanned_error - (Pos.get_position case.Ast.match_case_pattern) - "This case matches a constructor of enumeration %a but \ - previous case were matching constructors of enumeration %a" - Scopelang.Ast.EnumName.format_t e_uid - Scopelang.Ast.EnumName.format_t e_uid' - in - (match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with - | None -> () - | Some e_case -> - Errors.raise_multispanned_error - [ - (None, Pos.get_position case.match_case_expr); - (None, Pos.get_position (Bindlib.unbox e_case)); - ] - "The constructor %a has been matched twice:" - Scopelang.Ast.EnumConstructor.format_t c_uid); - let ctxt, (param_var, param_pos) = create_var binding in + if Scopelang.Ast.EnumConstructorMap.is_empty missing_constructors then + Errors.format_spanned_warning case_pos + "Unreachable match case, all constructors of the enumeration %a \ + are already specified" + Scopelang.Ast.EnumName.format_t e_uid; + (* The current used strategy is to replace the wildcard branch: + match foo with + | Case1 x -> x + | _ -> 1 + with: + let wildcard_payload = 1 in + match foo with + | Case1 x -> x + | Case2 -> wildcard_payload + ... + | CaseN -> wildcard_payload *) + (* Creates the wildcard payload *) + let ctxt, (payload_var, var_pos) = create_var None in let case_body = - translate_expr scope inside_definition_of ctxt - case.Ast.match_case_expr + translate_expr scope inside_definition_of ctxt match_case_expr in let e_binder = - Bindlib.bind_mvar (Array.of_list [ param_var ]) case_body + Bindlib.bind_mvar (Array.of_list [payload_var]) case_body in - let case_expr = - bind_case_body c_uid e_uid ctxt param_pos case_body e_binder - in - ( Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, - Some e_uid, - curr_index + 1 ) - | Ast.WildCard match_case_expr -> ( - let nb_cases = List.length cases in - let raise_wildcard_not_last_case_err () = - Errors.raise_multispanned_error - [ - (Some "Not ending wildcard:", case_pos); - ( Some "Next reachable case:", - curr_index + 1 |> List.nth cases |> Pos.get_position ); - ] - "Wildcard must be the last match case" - in - match e_uid with - | None -> - if 1 = nb_cases then - Errors.raise_spanned_error case_pos - "Couldn't infer the enumeration name from lonely wildcard \ - (wildcard cannot be used as single match case)" - else raise_wildcard_not_last_case_err () - | Some e_uid -> - if curr_index < nb_cases - 1 then - raise_wildcard_not_last_case_err (); - let missing_constructors = - Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums - |> Scopelang.Ast.EnumConstructorMap.filter_map (fun c_uid _ -> - match - Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d - with - | Some _ -> None - | None -> Some c_uid) - in - if Scopelang.Ast.EnumConstructorMap.is_empty missing_constructors - then - Errors.format_spanned_warning case_pos - "Unreachable match case, all constructors of the enumeration \ - %a are already specified" - Scopelang.Ast.EnumName.format_t e_uid; - (* The current used strategy is to replace the wildcard branch: - match foo with - | Case1 x -> x - | _ -> 1 - with: - let wildcard_payload = 1 in - match foo with - | Case1 x -> x - | Case2 -> wildcard_payload - ... - | CaseN -> wildcard_payload *) - (* Creates the wildcard payload *) - let ctxt, (payload_var, var_pos) = create_var None in - let case_body = - translate_expr scope inside_definition_of ctxt match_case_expr - in - let e_binder = - Bindlib.bind_mvar (Array.of_list [ payload_var ]) case_body - in - (* For each missing cases, binds the wildcard payload. *) - Scopelang.Ast.EnumConstructorMap.fold - (fun c_uid _ (cases_d, e_uid_opt, curr_index) -> - let case_expr = - bind_case_body c_uid e_uid ctxt var_pos case_body e_binder - in - ( Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, - e_uid_opt, - curr_index + 1 )) - missing_constructors - (cases_d, Some e_uid, curr_index)) + (* For each missing cases, binds the wildcard payload. *) + Scopelang.Ast.EnumConstructorMap.fold + (fun c_uid _ (cases_d, e_uid_opt, curr_index) -> + let case_expr = + bind_case_body c_uid e_uid ctxt var_pos case_body e_binder + in + ( Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, + e_uid_opt, + curr_index + 1 )) + missing_constructors + (cases_d, Some e_uid, curr_index)) in let expr, e_name, _ = List.fold_left bind_match_cases (Scopelang.Ast.EnumConstructorMap.empty, None, 0) cases in - (expr, Option.get e_name) + expr, Option.get e_name [@@ocamlformat "wrap-comments=false"] (** {1 Translating scope definitions} *) @@ -1057,24 +1025,21 @@ let merge_conditions (precond : Desugared.Ast.expr Pos.marked Bindlib.box option) (cond : Desugared.Ast.expr Pos.marked Bindlib.box option) (default_pos : Pos.t) : Desugared.Ast.expr Pos.marked Bindlib.box = - match (precond, cond) with + match precond, cond with | Some precond, Some cond -> - let op_term = - ( Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.And), - Pos.get_position (Bindlib.unbox cond) ) - in - Bindlib.box_apply2 - (fun precond cond -> - ( Desugared.Ast.EApp (op_term, [ precond; cond ]), - Pos.get_position cond )) - precond cond + let op_term = + ( Desugared.Ast.EOp (Dcalc.Ast.Binop Dcalc.Ast.And), + Pos.get_position (Bindlib.unbox cond) ) + in + Bindlib.box_apply2 + (fun precond cond -> + Desugared.Ast.EApp (op_term, [precond; cond]), Pos.get_position cond) + precond cond | Some precond, None -> - Bindlib.box_apply - (fun precond -> (Pos.unmark precond, default_pos)) - precond + Bindlib.box_apply (fun precond -> Pos.unmark precond, default_pos) precond | None, Some cond -> cond | None, None -> - Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool true), default_pos) + Bindlib.box (Desugared.Ast.ELit (Dcalc.Ast.LBool true), default_pos) (** Translates a surface definition into condition into a desugared {!type: Desugared.Ast.rule} *) @@ -1102,17 +1067,17 @@ let process_default (let def_key_typ = Name_resolution.get_def_typ ctxt (Pos.unmark def_key) in - match (Pos.unmark def_key_typ, param_uid) with + match Pos.unmark def_key_typ, param_uid with | Scopelang.Ast.TArrow (t_in, _), Some param_uid -> - Some (Pos.unmark param_uid, t_in) + Some (Pos.unmark param_uid, t_in) | Scopelang.Ast.TArrow _, None -> - Errors.raise_spanned_error - (Pos.get_position (Bindlib.unbox cons)) - "This definition has a function type but the parameter is missing" + Errors.raise_spanned_error + (Pos.get_position (Bindlib.unbox cons)) + "This definition has a function type but the parameter is missing" | _, Some _ -> - Errors.raise_spanned_error - (Pos.get_position (Bindlib.unbox cons)) - "This definition has a parameter but its type is not a function" + Errors.raise_spanned_error + (Pos.get_position (Bindlib.unbox cons)) + "This definition has a parameter but its type is not a function" | _ -> None); rule_exception_to_rules = exception_to_rules; rule_id; @@ -1142,10 +1107,10 @@ let process_def (* We add to the name resolution context the name of the parameter variable *) let param_uid, new_ctxt = match def.definition_parameter with - | None -> (None, ctxt) + | None -> None, ctxt | Some param -> - let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in - (Some (Pos.same_pos_as param_var param), ctxt) + let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in + Some (Pos.same_pos_as param_var param), ctxt in let scope_updated = let scope_def = Desugared.Ast.ScopeDefMap.find def_key scope.scope_defs in @@ -1153,28 +1118,27 @@ let process_def let parent_rules = match def.Ast.definition_exception_to with | NotAnException -> - (Desugared.Ast.RuleSet.empty, Pos.get_position def.Ast.definition_name) + Desugared.Ast.RuleSet.empty, Pos.get_position def.Ast.definition_name | UnlabeledException -> ( - match scope_def_ctxt.default_exception_rulename with - (* This should have been caught previously by - check_unlabeled_exception *) - | None | Some (Name_resolution.Ambiguous _) -> - assert false (* should not happen *) - | Some (Name_resolution.Unique (name, pos)) -> - (Desugared.Ast.RuleSet.singleton name, pos)) + match scope_def_ctxt.default_exception_rulename with + (* This should have been caught previously by + check_unlabeled_exception *) + | None | Some (Name_resolution.Ambiguous _) -> + assert false (* should not happen *) + | Some (Name_resolution.Unique (name, pos)) -> + Desugared.Ast.RuleSet.singleton name, pos) | ExceptionToLabel label -> ( - try - let label_id = - Desugared.Ast.IdentMap.find (Pos.unmark label) - scope_def_ctxt.label_idmap - in - ( Desugared.Ast.LabelMap.find label_id - scope_def.scope_def_label_groups, - Pos.get_position def.Ast.definition_name ) - with Not_found -> - Errors.raise_spanned_error (Pos.get_position label) - "Unknown label for the scope variable %a: \"%s\"" - Desugared.Ast.ScopeDef.format_t def_key (Pos.unmark label)) + try + let label_id = + Desugared.Ast.IdentMap.find (Pos.unmark label) + scope_def_ctxt.label_idmap + in + ( Desugared.Ast.LabelMap.find label_id scope_def.scope_def_label_groups, + Pos.get_position def.Ast.definition_name ) + with Not_found -> + Errors.raise_spanned_error (Pos.get_position label) + "Unknown label for the scope variable %a: \"%s\"" + Desugared.Ast.ScopeDef.format_t def_key (Pos.unmark label)) in let scope_def = { @@ -1225,24 +1189,24 @@ let process_assert (match ass.Ast.assertion_condition with | None -> ass.Ast.assertion_content | Some cond -> - ( Ast.IfThenElse - ( cond, - ass.Ast.assertion_content, - Pos.same_pos_as (Ast.Literal (Ast.LBool true)) cond ), - Pos.get_position cond )) + ( Ast.IfThenElse + ( cond, + ass.Ast.assertion_content, + Pos.same_pos_as (Ast.Literal (Ast.LBool true)) cond ), + Pos.get_position cond )) in let ass = match precond with | Some precond -> - Bindlib.box_apply2 - (fun precond ass -> - ( Desugared.Ast.EIfThenElse - ( precond, - ass, - Pos.same_pos_as (Desugared.Ast.ELit (Dcalc.Ast.LBool true)) - precond ), - Pos.get_position precond )) - precond ass + Bindlib.box_apply2 + (fun precond ass -> + ( Desugared.Ast.EIfThenElse + ( precond, + ass, + Pos.same_pos_as (Desugared.Ast.ELit (Dcalc.Ast.LBool true)) + precond ), + Pos.get_position precond )) + precond ass | None -> ass in let new_scope = @@ -1279,42 +1243,42 @@ let check_unlabeled_exception let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in match Pos.unmark item with | Ast.Rule _ | Ast.Definition _ -> ( - let def_key, exception_to = - match Pos.unmark item with - | Ast.Rule rule -> - ( Name_resolution.get_def_key - (Pos.unmark rule.rule_name) - rule.rule_state scope ctxt - (Pos.get_position rule.rule_name), - rule.rule_exception_to ) - | Ast.Definition def -> - ( Name_resolution.get_def_key - (Pos.unmark def.definition_name) - def.definition_state scope ctxt - (Pos.get_position def.definition_name), - def.definition_exception_to ) - | _ -> assert false - (* should not happen *) - in - let scope_def_ctxt = - Desugared.Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts - in - match exception_to with - | Ast.NotAnException | Ast.ExceptionToLabel _ -> () - (* If this is an unlabeled exception, we check that it has a unique - default definition *) - | Ast.UnlabeledException -> ( - match scope_def_ctxt.default_exception_rulename with - | None -> - Errors.raise_spanned_error (Pos.get_position item) - "This exception does not have a corresponding definition" - | Some (Ambiguous pos) -> - Errors.raise_multispanned_error - ([ (Some "Ambiguous exception", Pos.get_position item) ] - @ List.map (fun p -> (Some "Candidate definition", p)) pos) - "This exception can refer to several definitions. Try using \ - labels to disambiguate" - | Some (Unique _) -> ())) + let def_key, exception_to = + match Pos.unmark item with + | Ast.Rule rule -> + ( Name_resolution.get_def_key + (Pos.unmark rule.rule_name) + rule.rule_state scope ctxt + (Pos.get_position rule.rule_name), + rule.rule_exception_to ) + | Ast.Definition def -> + ( Name_resolution.get_def_key + (Pos.unmark def.definition_name) + def.definition_state scope ctxt + (Pos.get_position def.definition_name), + def.definition_exception_to ) + | _ -> assert false + (* should not happen *) + in + let scope_def_ctxt = + Desugared.Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts + in + match exception_to with + | Ast.NotAnException | Ast.ExceptionToLabel _ -> () + (* If this is an unlabeled exception, we check that it has a unique default + definition *) + | Ast.UnlabeledException -> ( + match scope_def_ctxt.default_exception_rulename with + | None -> + Errors.raise_spanned_error (Pos.get_position item) + "This exception does not have a corresponding definition" + | Some (Ambiguous pos) -> + Errors.raise_multispanned_error + ([Some "Ambiguous exception", Pos.get_position item] + @ List.map (fun p -> Some "Candidate definition", p) pos) + "This exception can refer to several definitions. Try using labels \ + to disambiguate" + | Some (Unique _) -> ())) | _ -> () (** Translates a surface scope use, which is a bunch of definitions *) @@ -1373,11 +1337,10 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : in match v_sig.var_sig_states_list with | [] -> - Desugared.Ast.ScopeVarMap.add v Desugared.Ast.WholeVar - acc + Desugared.Ast.ScopeVarMap.add v Desugared.Ast.WholeVar acc | states -> - Desugared.Ast.ScopeVarMap.add v - (Desugared.Ast.States states) acc) + Desugared.Ast.ScopeVarMap.add v + (Desugared.Ast.States states) acc) s_context.Name_resolution.var_idmap Desugared.Ast.ScopeVarMap.empty; Desugared.Ast.scope_sub_scopes = @@ -1394,70 +1357,70 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : in match v_sig.var_sig_states_list with | [] -> - let def_key = Desugared.Ast.ScopeDef.Var (v, None) in - Desugared.Ast.ScopeDefMap.add def_key - { - Desugared.Ast.scope_def_rules = - Desugared.Ast.RuleMap.empty; - Desugared.Ast.scope_def_typ = v_sig.var_sig_typ; - Desugared.Ast.scope_def_label_groups = - Name_resolution.label_groups ctxt s_uid def_key; - Desugared.Ast.scope_def_is_condition = - v_sig.var_sig_is_condition; - Desugared.Ast.scope_def_io = - attribute_to_io v_sig.var_sig_io; - } - acc + let def_key = Desugared.Ast.ScopeDef.Var (v, None) in + Desugared.Ast.ScopeDefMap.add def_key + { + Desugared.Ast.scope_def_rules = + Desugared.Ast.RuleMap.empty; + Desugared.Ast.scope_def_typ = v_sig.var_sig_typ; + Desugared.Ast.scope_def_label_groups = + Name_resolution.label_groups ctxt s_uid def_key; + Desugared.Ast.scope_def_is_condition = + v_sig.var_sig_is_condition; + Desugared.Ast.scope_def_io = + attribute_to_io v_sig.var_sig_io; + } + acc | states -> - fst - (List.fold_left - (fun (acc, i) state -> - let def_key = - Desugared.Ast.ScopeDef.Var (v, Some state) - in - ( Desugared.Ast.ScopeDefMap.add def_key - { - Desugared.Ast.scope_def_rules = - Desugared.Ast.RuleMap.empty; - Desugared.Ast.scope_def_typ = - v_sig.var_sig_typ; - Desugared.Ast.scope_def_label_groups = - Name_resolution.label_groups ctxt - s_uid def_key; - Desugared.Ast.scope_def_is_condition = - v_sig.var_sig_is_condition; - Desugared.Ast.scope_def_io = - (* The first state should have the - input I/O of the original variable, - and the last state should have the - output I/O of the original - variable. All intermediate states - shall have "internal" I/O.*) - (let original_io = - attribute_to_io v_sig.var_sig_io - in - let io_input = - if i = 0 then original_io.io_input - else - ( Scopelang.Ast.NoInput, - Pos.get_position - (Desugared.Ast.StateName - .get_info state) ) - in - let io_output = - if i = List.length states - 1 then - original_io.io_output - else - ( false, - Pos.get_position - (Desugared.Ast.StateName - .get_info state) ) - in - { io_input; io_output }); - } - acc, - i + 1 )) - (acc, 0) states)) + fst + (List.fold_left + (fun (acc, i) state -> + let def_key = + Desugared.Ast.ScopeDef.Var (v, Some state) + in + ( Desugared.Ast.ScopeDefMap.add def_key + { + Desugared.Ast.scope_def_rules = + Desugared.Ast.RuleMap.empty; + Desugared.Ast.scope_def_typ = + v_sig.var_sig_typ; + Desugared.Ast.scope_def_label_groups = + Name_resolution.label_groups ctxt s_uid + def_key; + Desugared.Ast.scope_def_is_condition = + v_sig.var_sig_is_condition; + Desugared.Ast.scope_def_io = + (* The first state should have the input + I/O of the original variable, and the + last state should have the output I/O + of the original variable. All + intermediate states shall have + "internal" I/O.*) + (let original_io = + attribute_to_io v_sig.var_sig_io + in + let io_input = + if i = 0 then original_io.io_input + else + ( Scopelang.Ast.NoInput, + Pos.get_position + (Desugared.Ast.StateName + .get_info state) ) + in + let io_output = + if i = List.length states - 1 then + original_io.io_output + else + ( false, + Pos.get_position + (Desugared.Ast.StateName + .get_info state) ) + in + { io_input; io_output }); + } + acc, + i + 1 )) + (acc, 0) states)) s_context.Name_resolution.var_idmap Desugared.Ast.ScopeDefMap.empty in @@ -1502,20 +1465,20 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : } in let rec processer_structure - (prgm : Desugared.Ast.program) (item : Ast.law_structure) : - Desugared.Ast.program = + (prgm : Desugared.Ast.program) + (item : Ast.law_structure) : Desugared.Ast.program = match item with | LawHeading (_, children) -> - List.fold_left - (fun prgm child -> processer_structure prgm child) - prgm children + List.fold_left + (fun prgm child -> processer_structure prgm child) + prgm children | CodeBlock (block, _, _) -> - List.fold_left - (fun prgm item -> - match Pos.unmark item with - | Ast.ScopeUse use -> process_scope_use ctxt prgm use - | _ -> prgm) - prgm block + List.fold_left + (fun prgm item -> + match Pos.unmark item with + | Ast.ScopeUse use -> process_scope_use ctxt prgm use + | _ -> prgm) + prgm block | LawInclude _ | LawText _ -> prgm in List.fold_left processer_structure empty_prgm prgm.program_items diff --git a/compiler/surface/fill_positions.ml b/compiler/surface/fill_positions.ml index f1caec81..f22e7273 100644 --- a/compiler/surface/fill_positions.ml +++ b/compiler/surface/fill_positions.ml @@ -22,7 +22,7 @@ let fill_pos_with_legislative_info (p : Ast.program) : Ast.program = inherit [_] Ast.program_map as super method! visit_marked f env x = - (f env (Pos.unmark x), Pos.overwrite_law_info (Pos.get_position x) env) + f env (Pos.unmark x), Pos.overwrite_law_info (Pos.get_position x) env method! visit_LawHeading (env : string list) diff --git a/compiler/surface/lexer_common.ml b/compiler/surface/lexer_common.ml index c081d192..2b8f1705 100644 --- a/compiler/surface/lexer_common.ml +++ b/compiler/surface/lexer_common.ml @@ -73,27 +73,27 @@ let raise_lexer_error (loc : Pos.t) (token : string) = (English, French, etc.) *) let token_list_language_agnostic : (string * token) list = [ - (".", DOT); - ("<=", LESSER_EQUAL); - (">=", GREATER_EQUAL); - (">", GREATER); - ("!=", NOT_EQUAL); - ("=", EQUAL); - ("(", LPAREN); - (")", RPAREN); - ("{", LBRACKET); - ("}", RBRACKET); - ("{", LSQUARE); - ("}", RSQUARE); - ("+", PLUS); - ("-", MINUS); - ("*", MULT); - ("/", DIV); - ("|", VERTICAL); - (":", COLON); - (";", SEMICOLON); - ("--", ALT); - ("++", PLUSPLUS); + ".", DOT; + "<=", LESSER_EQUAL; + ">=", GREATER_EQUAL; + ">", GREATER; + "!=", NOT_EQUAL; + "=", EQUAL; + "(", LPAREN; + ")", RPAREN; + "{", LBRACKET; + "}", RBRACKET; + "{", LSQUARE; + "}", RSQUARE; + "+", PLUS; + "-", MINUS; + "*", MULT; + "/", DIV; + "|", VERTICAL; + ":", COLON; + ";", SEMICOLON; + "--", ALT; + "++", PLUSPLUS; ] module type LocalisedLexer = sig diff --git a/compiler/surface/name_resolution.ml b/compiler/surface/name_resolution.ml index 4ef0b63d..db403841 100644 --- a/compiler/surface/name_resolution.ml +++ b/compiler/surface/name_resolution.ml @@ -104,7 +104,7 @@ let raise_unsupported_feature (msg : string) (pos : Pos.t) = let raise_unknown_identifier (msg : string) (ident : ident Pos.marked) = Errors.raise_spanned_error (Pos.get_position ident) "\"%s\": unknown identifier %s" - (Utils.Cli.with_style [ ANSITerminal.yellow ] "%s" (Pos.unmark ident)) + (Utils.Cli.with_style [ANSITerminal.yellow] "%s" (Pos.unmark ident)) msg (** Gets the type associated to an uid *) @@ -127,10 +127,10 @@ let get_var_uid let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in match Desugared.Ast.IdentMap.find_opt x scope.var_idmap with | None -> - raise_unknown_identifier - (Format.asprintf "for a variable of scope %a" - Scopelang.Ast.ScopeName.format_t scope_uid) - (x, pos) + raise_unknown_identifier + (Format.asprintf "for a variable of scope %a" + Scopelang.Ast.ScopeName.format_t scope_uid) + (x, pos) | Some uid -> uid (** Get the subscope uid inside the scope given in argument *) @@ -146,8 +146,9 @@ let get_subscope_uid (** [is_subscope_uid scope_uid ctxt y] returns true if [y] belongs to the subscopes of [scope_uid]. *) let is_subscope_uid - (scope_uid : Scopelang.Ast.ScopeName.t) (ctxt : context) (y : ident) : bool - = + (scope_uid : Scopelang.Ast.ScopeName.t) + (ctxt : context) + (y : ident) : bool = let scope = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in Desugared.Ast.IdentMap.mem y scope.sub_scopes_idmap @@ -169,7 +170,7 @@ let get_def_typ (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : (* we don't need to look at the subscope prefix because [x] is already the uid referring back to the original subscope *) | Desugared.Ast.ScopeDef.Var (x, _) -> - get_var_typ ctxt x + get_var_typ ctxt x let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool = match def with @@ -177,7 +178,7 @@ let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool = (* we don't need to look at the subscope prefix because [x] is already the uid referring back to the original subscope *) | Desugared.Ast.ScopeDef.Var (x, _) -> - is_var_cond ctxt x + is_var_cond ctxt x let label_groups (ctxt : context) @@ -204,78 +205,76 @@ let process_subscope_decl Desugared.Ast.IdentMap.find_opt subscope scope_ctxt.sub_scopes_idmap with | Some use -> - Errors.raise_multispanned_error - [ - ( Some "first use", - Pos.get_position (Scopelang.Ast.SubScopeName.get_info use) ); - (Some "second use", s_pos); - ] - "Subscope name \"%a\" already used" - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) - subscope + Errors.raise_multispanned_error + [ + ( Some "first use", + Pos.get_position (Scopelang.Ast.SubScopeName.get_info use) ); + Some "second use", s_pos; + ] + "Subscope name \"%a\" already used" + (Utils.Cli.format_with_style [ANSITerminal.yellow]) + subscope | None -> - let sub_scope_uid = Scopelang.Ast.SubScopeName.fresh (name, name_pos) in - let original_subscope_uid = - match Desugared.Ast.IdentMap.find_opt subscope ctxt.scope_idmap with - | None -> raise_unknown_identifier "for a scope" (subscope, s_pos) - | Some id -> id - in - let scope_ctxt = - { - scope_ctxt with - sub_scopes_idmap = - Desugared.Ast.IdentMap.add name sub_scope_uid - scope_ctxt.sub_scopes_idmap; - sub_scopes = - Scopelang.Ast.SubScopeMap.add sub_scope_uid original_subscope_uid - scope_ctxt.sub_scopes; - } - in + let sub_scope_uid = Scopelang.Ast.SubScopeName.fresh (name, name_pos) in + let original_subscope_uid = + match Desugared.Ast.IdentMap.find_opt subscope ctxt.scope_idmap with + | None -> raise_unknown_identifier "for a scope" (subscope, s_pos) + | Some id -> id + in + let scope_ctxt = { - ctxt with - scopes = Scopelang.Ast.ScopeMap.add scope scope_ctxt ctxt.scopes; + scope_ctxt with + sub_scopes_idmap = + Desugared.Ast.IdentMap.add name sub_scope_uid + scope_ctxt.sub_scopes_idmap; + sub_scopes = + Scopelang.Ast.SubScopeMap.add sub_scope_uid original_subscope_uid + scope_ctxt.sub_scopes; } + in + { + ctxt with + scopes = Scopelang.Ast.ScopeMap.add scope scope_ctxt ctxt.scopes; + } let is_type_cond ((typ, _) : Ast.typ Pos.marked) = match typ with | Ast.Base Ast.Condition | Ast.Func { arg_typ = _; return_typ = Ast.Condition, _ } -> - true + true | _ -> false (** Process a basic type (all types except function types) *) let rec process_base_typ - (ctxt : context) ((typ, typ_pos) : Ast.base_typ Pos.marked) : - Scopelang.Ast.typ Pos.marked = + (ctxt : context) + ((typ, typ_pos) : Ast.base_typ Pos.marked) : Scopelang.Ast.typ Pos.marked = match typ with - | Ast.Condition -> (Scopelang.Ast.TLit TBool, typ_pos) + | Ast.Condition -> Scopelang.Ast.TLit TBool, typ_pos | Ast.Data (Ast.Collection t) -> - ( Scopelang.Ast.TArray - (Pos.unmark - (process_base_typ ctxt - (Ast.Data (Pos.unmark t), Pos.get_position t))), - typ_pos ) + ( Scopelang.Ast.TArray + (Pos.unmark + (process_base_typ ctxt (Ast.Data (Pos.unmark t), Pos.get_position t))), + typ_pos ) | Ast.Data (Ast.Primitive prim) -> ( - match prim with - | Ast.Integer -> (Scopelang.Ast.TLit TInt, typ_pos) - | Ast.Decimal -> (Scopelang.Ast.TLit TRat, typ_pos) - | Ast.Money -> (Scopelang.Ast.TLit TMoney, typ_pos) - | Ast.Duration -> (Scopelang.Ast.TLit TDuration, typ_pos) - | Ast.Date -> (Scopelang.Ast.TLit TDate, typ_pos) - | Ast.Boolean -> (Scopelang.Ast.TLit TBool, typ_pos) - | Ast.Text -> raise_unsupported_feature "text type" typ_pos - | Ast.Named ident -> ( - match Desugared.Ast.IdentMap.find_opt ident ctxt.struct_idmap with - | Some s_uid -> (Scopelang.Ast.TStruct s_uid, typ_pos) - | None -> ( - match Desugared.Ast.IdentMap.find_opt ident ctxt.enum_idmap with - | Some e_uid -> (Scopelang.Ast.TEnum e_uid, typ_pos) - | None -> - Errors.raise_spanned_error typ_pos - "Unknown type \"%a\", not a struct or enum previously \ - declared" - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) - ident))) + match prim with + | Ast.Integer -> Scopelang.Ast.TLit TInt, typ_pos + | Ast.Decimal -> Scopelang.Ast.TLit TRat, typ_pos + | Ast.Money -> Scopelang.Ast.TLit TMoney, typ_pos + | Ast.Duration -> Scopelang.Ast.TLit TDuration, typ_pos + | Ast.Date -> Scopelang.Ast.TLit TDate, typ_pos + | Ast.Boolean -> Scopelang.Ast.TLit TBool, typ_pos + | Ast.Text -> raise_unsupported_feature "text type" typ_pos + | Ast.Named ident -> ( + match Desugared.Ast.IdentMap.find_opt ident ctxt.struct_idmap with + | Some s_uid -> Scopelang.Ast.TStruct s_uid, typ_pos + | None -> ( + match Desugared.Ast.IdentMap.find_opt ident ctxt.enum_idmap with + | Some e_uid -> Scopelang.Ast.TEnum e_uid, typ_pos + | None -> + Errors.raise_spanned_error typ_pos + "Unknown type \"%a\", not a struct or enum previously declared" + (Utils.Cli.format_with_style [ANSITerminal.yellow]) + ident))) (** Process a type (function or not) *) let process_type (ctxt : context) ((typ, typ_pos) : Ast.typ Pos.marked) : @@ -283,9 +282,9 @@ let process_type (ctxt : context) ((typ, typ_pos) : Ast.typ Pos.marked) : match typ with | Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos) | Ast.Func { arg_typ; return_typ } -> - ( Scopelang.Ast.TArrow - (process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ), - typ_pos ) + ( Scopelang.Ast.TArrow + (process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ), + typ_pos ) (** Process data declaration *) let process_data_decl @@ -299,47 +298,46 @@ let process_data_decl let scope_ctxt = Scopelang.Ast.ScopeMap.find scope ctxt.scopes in match Desugared.Ast.IdentMap.find_opt name scope_ctxt.var_idmap with | Some use -> - Errors.raise_multispanned_error - [ - ( Some "First use:", - Pos.get_position (Desugared.Ast.ScopeVar.get_info use) ); - (Some "Second use:", pos); - ] - "Variable name \"%a\" already used" - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) - name + Errors.raise_multispanned_error + [ + Some "First use:", Pos.get_position (Desugared.Ast.ScopeVar.get_info use); + Some "Second use:", pos; + ] + "Variable name \"%a\" already used" + (Utils.Cli.format_with_style [ANSITerminal.yellow]) + name | None -> - let uid = Desugared.Ast.ScopeVar.fresh (name, pos) in - let scope_ctxt = - { - scope_ctxt with - var_idmap = Desugared.Ast.IdentMap.add name uid scope_ctxt.var_idmap; - } - in - let states_idmap, states_list = - List.fold_right - (fun state_id (states_idmap, states_list) -> - let state_uid = Desugared.Ast.StateName.fresh state_id in - ( Desugared.Ast.IdentMap.add (Pos.unmark state_id) state_uid - states_idmap, - state_uid :: states_list )) - decl.scope_decl_context_item_states - (Desugared.Ast.IdentMap.empty, []) - in + let uid = Desugared.Ast.ScopeVar.fresh (name, pos) in + let scope_ctxt = { - ctxt with - scopes = Scopelang.Ast.ScopeMap.add scope scope_ctxt ctxt.scopes; - var_typs = - Desugared.Ast.ScopeVarMap.add uid - { - var_sig_typ = data_typ; - var_sig_is_condition = is_cond; - var_sig_io = decl.scope_decl_context_item_attribute; - var_sig_states_idmap = states_idmap; - var_sig_states_list = states_list; - } - ctxt.var_typs; + scope_ctxt with + var_idmap = Desugared.Ast.IdentMap.add name uid scope_ctxt.var_idmap; } + in + let states_idmap, states_list = + List.fold_right + (fun state_id (states_idmap, states_list) -> + let state_uid = Desugared.Ast.StateName.fresh state_id in + ( Desugared.Ast.IdentMap.add (Pos.unmark state_id) state_uid + states_idmap, + state_uid :: states_list )) + decl.scope_decl_context_item_states + (Desugared.Ast.IdentMap.empty, []) + in + { + ctxt with + scopes = Scopelang.Ast.ScopeMap.add scope scope_ctxt ctxt.scopes; + var_typs = + Desugared.Ast.ScopeVarMap.add uid + { + var_sig_typ = data_typ; + var_sig_is_condition = is_cond; + var_sig_io = decl.scope_decl_context_item_attribute; + var_sig_states_idmap = states_idmap; + var_sig_states_list = states_list; + } + ctxt.var_typs; + } (** Process an item declaration *) let process_item_decl @@ -362,7 +360,7 @@ let add_def_local_var (ctxt : context) (name : ident Pos.marked) : ctxt.local_var_idmap; } in - (ctxt, local_var_uid) + ctxt, local_var_uid (** Process a scope declaration *) let process_scope_decl (ctxt : context) (decl : Ast.scope_decl) : context = @@ -398,7 +396,7 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = match uids with | None -> Some (Scopelang.Ast.StructMap.singleton s_uid f_uid) | Some uids -> - Some (Scopelang.Ast.StructMap.add s_uid f_uid uids)) + Some (Scopelang.Ast.StructMap.add s_uid f_uid uids)) ctxt.field_idmap; } in @@ -409,14 +407,14 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = (fun fields -> match fields with | None -> - Some - (Scopelang.Ast.StructFieldMap.singleton f_uid - (process_type ctxt fdecl.Ast.struct_decl_field_typ)) + Some + (Scopelang.Ast.StructFieldMap.singleton f_uid + (process_type ctxt fdecl.Ast.struct_decl_field_typ)) | Some fields -> - Some - (Scopelang.Ast.StructFieldMap.add f_uid - (process_type ctxt fdecl.Ast.struct_decl_field_typ) - fields)) + Some + (Scopelang.Ast.StructFieldMap.add f_uid + (process_type ctxt fdecl.Ast.struct_decl_field_typ) + fields)) ctxt.structs; }) ctxt sdecl.struct_decl_fields @@ -457,14 +455,14 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = (fun cases -> let typ = match cdecl.Ast.enum_decl_case_typ with - | None -> (Scopelang.Ast.TLit TUnit, cdecl_pos) + | None -> Scopelang.Ast.TLit TUnit, cdecl_pos | Some typ -> process_type ctxt typ in match cases with | None -> - Some (Scopelang.Ast.EnumConstructorMap.singleton c_uid typ) + Some (Scopelang.Ast.EnumConstructorMap.singleton c_uid typ) | Some fields -> - Some (Scopelang.Ast.EnumConstructorMap.add c_uid typ fields)) + Some (Scopelang.Ast.EnumConstructorMap.add c_uid typ fields)) ctxt.enums; }) ctxt edecl.enum_decl_cases @@ -475,71 +473,70 @@ let process_name_item (ctxt : context) (item : Ast.code_item Pos.marked) : let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg = Errors.raise_multispanned_error [ - (Some "First definition:", Pos.get_position use); - (Some "Second definition:", pos); + Some "First definition:", Pos.get_position use; + Some "Second definition:", pos; ] "%s name \"%a\" already defined" msg - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) + (Utils.Cli.format_with_style [ANSITerminal.yellow]) name in match Pos.unmark item with | ScopeDecl decl -> ( - let name, pos = decl.scope_decl_name in - (* Checks if the name is already used *) - match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with - | Some use -> - raise_already_defined_error - (Scopelang.Ast.ScopeName.get_info use) - name pos "scope" - | None -> - let scope_uid = Scopelang.Ast.ScopeName.fresh (name, pos) in - { - ctxt with - scope_idmap = - Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap; - scopes = - Scopelang.Ast.ScopeMap.add scope_uid - { - var_idmap = Desugared.Ast.IdentMap.empty; - scope_defs_contexts = Desugared.Ast.ScopeDefMap.empty; - sub_scopes_idmap = Desugared.Ast.IdentMap.empty; - sub_scopes = Scopelang.Ast.SubScopeMap.empty; - } - ctxt.scopes; - }) + let name, pos = decl.scope_decl_name in + (* Checks if the name is already used *) + match Desugared.Ast.IdentMap.find_opt name ctxt.scope_idmap with + | Some use -> + raise_already_defined_error + (Scopelang.Ast.ScopeName.get_info use) + name pos "scope" + | None -> + let scope_uid = Scopelang.Ast.ScopeName.fresh (name, pos) in + { + ctxt with + scope_idmap = Desugared.Ast.IdentMap.add name scope_uid ctxt.scope_idmap; + scopes = + Scopelang.Ast.ScopeMap.add scope_uid + { + var_idmap = Desugared.Ast.IdentMap.empty; + scope_defs_contexts = Desugared.Ast.ScopeDefMap.empty; + sub_scopes_idmap = Desugared.Ast.IdentMap.empty; + sub_scopes = Scopelang.Ast.SubScopeMap.empty; + } + ctxt.scopes; + }) | StructDecl sdecl -> ( - let name, pos = sdecl.struct_decl_name in - match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with - | Some use -> - raise_already_defined_error - (Scopelang.Ast.StructName.get_info use) - name pos "struct" - | None -> - let s_uid = Scopelang.Ast.StructName.fresh sdecl.struct_decl_name in - { - ctxt with - struct_idmap = - Desugared.Ast.IdentMap.add - (Pos.unmark sdecl.struct_decl_name) - s_uid ctxt.struct_idmap; - }) + let name, pos = sdecl.struct_decl_name in + match Desugared.Ast.IdentMap.find_opt name ctxt.struct_idmap with + | Some use -> + raise_already_defined_error + (Scopelang.Ast.StructName.get_info use) + name pos "struct" + | None -> + let s_uid = Scopelang.Ast.StructName.fresh sdecl.struct_decl_name in + { + ctxt with + struct_idmap = + Desugared.Ast.IdentMap.add + (Pos.unmark sdecl.struct_decl_name) + s_uid ctxt.struct_idmap; + }) | EnumDecl edecl -> ( - let name, pos = edecl.enum_decl_name in - match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with - | Some use -> - raise_already_defined_error - (Scopelang.Ast.EnumName.get_info use) - name pos "enum" - | None -> - let e_uid = Scopelang.Ast.EnumName.fresh edecl.enum_decl_name in + let name, pos = edecl.enum_decl_name in + match Desugared.Ast.IdentMap.find_opt name ctxt.enum_idmap with + | Some use -> + raise_already_defined_error + (Scopelang.Ast.EnumName.get_info use) + name pos "enum" + | None -> + let e_uid = Scopelang.Ast.EnumName.fresh edecl.enum_decl_name in - { - ctxt with - enum_idmap = - Desugared.Ast.IdentMap.add - (Pos.unmark edecl.enum_decl_name) - e_uid ctxt.enum_idmap; - }) + { + ctxt with + enum_idmap = + Desugared.Ast.IdentMap.add + (Pos.unmark edecl.enum_decl_name) + e_uid ctxt.enum_idmap; + }) | ScopeUse _ -> ctxt (** Process a code item that is a declaration *) @@ -565,9 +562,9 @@ let rec process_law_structure (process_item : context -> Ast.code_item Pos.marked -> context) : context = match s with | Ast.LawHeading (_, children) -> - List.fold_left - (fun ctxt child -> process_law_structure ctxt child process_item) - ctxt children + List.fold_left + (fun ctxt child -> process_law_structure ctxt child process_item) + ctxt children | Ast.CodeBlock (block, _, _) -> process_code_block ctxt block process_item | Ast.LawInclude _ | Ast.LawText _ -> ctxt @@ -581,57 +578,54 @@ let get_def_key (default_pos : Pos.t) : Desugared.Ast.ScopeDef.t = let scope_ctxt = Scopelang.Ast.ScopeMap.find scope_uid ctxt.scopes in match name with - | [ x ] -> - let x_uid = get_var_uid scope_uid ctxt x in - let var_sig = Desugared.Ast.ScopeVarMap.find x_uid ctxt.var_typs in - Desugared.Ast.ScopeDef.Var - ( x_uid, - match state with - | Some state -> ( - try - Some - (Desugared.Ast.IdentMap.find (Pos.unmark state) - var_sig.var_sig_states_idmap) - with Not_found -> - Errors.raise_multispanned_error - [ - (None, Pos.get_position state); - ( Some "Variable declaration:", - Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid) - ); - ] - "This identifier is not a state declared for variable %a." - Desugared.Ast.ScopeVar.format_t x_uid) - | None -> - if - not - (Desugared.Ast.IdentMap.is_empty var_sig.var_sig_states_idmap) - then - Errors.raise_multispanned_error - [ - (None, Pos.get_position x); - ( Some "Variable declaration:", - Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid) - ); - ] - "This definition does not indicate which state has to be \ - considered for variable %a." - Desugared.Ast.ScopeVar.format_t x_uid - else None ) - | [ y; x ] -> - let subscope_uid : Scopelang.Ast.SubScopeName.t = - get_subscope_uid scope_uid ctxt y - in - let subscope_real_uid : Scopelang.Ast.ScopeName.t = - Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes - in - let x_uid = get_var_uid subscope_real_uid ctxt x in - Desugared.Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid) + | [x] -> + let x_uid = get_var_uid scope_uid ctxt x in + let var_sig = Desugared.Ast.ScopeVarMap.find x_uid ctxt.var_typs in + Desugared.Ast.ScopeDef.Var + ( x_uid, + match state with + | Some state -> ( + try + Some + (Desugared.Ast.IdentMap.find (Pos.unmark state) + var_sig.var_sig_states_idmap) + with Not_found -> + Errors.raise_multispanned_error + [ + None, Pos.get_position state; + ( Some "Variable declaration:", + Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid) ); + ] + "This identifier is not a state declared for variable %a." + Desugared.Ast.ScopeVar.format_t x_uid) + | None -> + if not (Desugared.Ast.IdentMap.is_empty var_sig.var_sig_states_idmap) + then + Errors.raise_multispanned_error + [ + None, Pos.get_position x; + ( Some "Variable declaration:", + Pos.get_position (Desugared.Ast.ScopeVar.get_info x_uid) ); + ] + "This definition does not indicate which state has to be \ + considered for variable %a." + Desugared.Ast.ScopeVar.format_t x_uid + else None ) + | [y; x] -> + let subscope_uid : Scopelang.Ast.SubScopeName.t = + get_subscope_uid scope_uid ctxt y + in + let subscope_real_uid : Scopelang.Ast.ScopeName.t = + Scopelang.Ast.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes + in + let x_uid = get_var_uid subscope_real_uid ctxt x in + Desugared.Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid) | _ -> Errors.raise_spanned_error default_pos "Structs are not handled yet" let process_definition - (ctxt : context) (s_name : Scopelang.Ast.ScopeName.t) (d : Ast.definition) : - context = + (ctxt : context) + (s_name : Scopelang.Ast.ScopeName.t) + (d : Ast.definition) : context = (* We update the definition context inside the big context *) { ctxt with @@ -647,124 +641,115 @@ let process_definition match s_ctxt with | None -> assert false (* should not happen *) | Some s_ctxt -> - Some - { - s_ctxt with - scope_defs_contexts = - Desugared.Ast.ScopeDefMap.update def_key - (fun def_key_ctx -> - let def_key_ctx : scope_def_context = - Option.fold - ~none: - { - (* Here, this is the first time we encounter a - definition for this definition key *) - default_exception_rulename = None; - label_idmap = Desugared.Ast.IdentMap.empty; - label_groups = Desugared.Ast.LabelMap.empty; - } - ~some:(fun x -> x) - def_key_ctx - in - (* First, we update the def key context with information - about the definition's label*) - let def_key_ctx = - match d.Ast.definition_label with - | None -> def_key_ctx - | Some label -> - let new_label_idmap = - Desugared.Ast.IdentMap.update (Pos.unmark label) - (fun existing_label -> - match existing_label with - | Some existing_label -> Some existing_label - | None -> - Some - (Desugared.Ast.LabelName.fresh label)) - def_key_ctx.label_idmap - in - let label_id = - Desugared.Ast.IdentMap.find (Pos.unmark label) - new_label_idmap - in + Some + { + s_ctxt with + scope_defs_contexts = + Desugared.Ast.ScopeDefMap.update def_key + (fun def_key_ctx -> + let def_key_ctx : scope_def_context = + Option.fold + ~none: + { + (* Here, this is the first time we encounter a + definition for this definition key *) + default_exception_rulename = None; + label_idmap = Desugared.Ast.IdentMap.empty; + label_groups = Desugared.Ast.LabelMap.empty; + } + ~some:(fun x -> x) + def_key_ctx + in + (* First, we update the def key context with information + about the definition's label*) + let def_key_ctx = + match d.Ast.definition_label with + | None -> def_key_ctx + | Some label -> + let new_label_idmap = + Desugared.Ast.IdentMap.update (Pos.unmark label) + (fun existing_label -> + match existing_label with + | Some existing_label -> Some existing_label + | None -> + Some (Desugared.Ast.LabelName.fresh label)) + def_key_ctx.label_idmap + in + let label_id = + Desugared.Ast.IdentMap.find (Pos.unmark label) + new_label_idmap + in + { + def_key_ctx with + label_idmap = new_label_idmap; + label_groups = + Desugared.Ast.LabelMap.update label_id + (fun group -> + match group with + | None -> + Some + (Desugared.Ast.RuleSet.singleton + d.definition_id) + | Some existing_group -> + Some + (Desugared.Ast.RuleSet.add d.definition_id + existing_group)) + def_key_ctx.label_groups; + } + in + (* And second, we update the map of default rulenames for + unlabeled exceptions *) + let def_key_ctx = + match d.Ast.definition_exception_to with + (* If this definition is an exception, it cannot be a + default definition *) + | UnlabeledException | ExceptionToLabel _ -> def_key_ctx + (* If it is not an exception, we need to distinguish + between several cases *) + | NotAnException -> ( + match def_key_ctx.default_exception_rulename with + (* There was already a default definition for this + key. If we need it, it is ambiguous *) + | Some old -> + { + def_key_ctx with + default_exception_rulename = + Some + (Ambiguous + ([Pos.get_position d.definition_name] + @ + match old with + | Ambiguous old -> old + | Unique (_, pos) -> [pos])); + } + (* No definition has been set yet for this key *) + | None -> ( + match d.Ast.definition_label with + (* This default definition has a label. This is not + allowed for unlabeled exceptions *) + | Some _ -> { def_key_ctx with - label_idmap = new_label_idmap; - label_groups = - Desugared.Ast.LabelMap.update label_id - (fun group -> - match group with - | None -> - Some - (Desugared.Ast.RuleSet.singleton - d.definition_id) - | Some existing_group -> - Some - (Desugared.Ast.RuleSet.add - d.definition_id existing_group)) - def_key_ctx.label_groups; + default_exception_rulename = + Some + (Ambiguous + [Pos.get_position d.definition_name]); } - in - (* And second, we update the map of default rulenames - for unlabeled exceptions *) - let def_key_ctx = - match d.Ast.definition_exception_to with - (* If this definition is an exception, it cannot be a - default definition *) - | UnlabeledException | ExceptionToLabel _ -> - def_key_ctx - (* If it is not an exception, we need to distinguish - between several cases *) - | NotAnException -> ( - match def_key_ctx.default_exception_rulename with - (* There was already a default definition for this - key. If we need it, it is ambiguous *) - | Some old -> - { - def_key_ctx with - default_exception_rulename = - Some - (Ambiguous - ([ - Pos.get_position d.definition_name; - ] - @ - match old with - | Ambiguous old -> old - | Unique (_, pos) -> [ pos ])); - } - (* No definition has been set yet for this key *) - | None -> ( - match d.Ast.definition_label with - (* This default definition has a label. This - is not allowed for unlabeled exceptions *) - | Some _ -> - { - def_key_ctx with - default_exception_rulename = - Some - (Ambiguous - [ - Pos.get_position - d.definition_name; - ]); - } - (* This is a possible default definition for - this key. We create and store a fresh - rulename *) - | None -> - { - def_key_ctx with - default_exception_rulename = - Some - (Unique - ( d.definition_id, - Pos.get_position - d.definition_name )); - })) - in - Some def_key_ctx) - s_ctxt.scope_defs_contexts; - }) + (* This is a possible default definition for this + key. We create and store a fresh rulename *) + | None -> + { + def_key_ctx with + default_exception_rulename = + Some + (Unique + ( d.definition_id, + Pos.get_position d.definition_name )); + })) + in + Some def_key_ctx) + s_ctxt.scope_defs_contexts; + }) ctxt.scopes; } @@ -787,7 +772,7 @@ let process_scope_use (ctxt : context) (suse : Ast.scope_use) : context = Errors.raise_spanned_error (Pos.get_position suse.Ast.scope_use_name) "\"%a\": this scope has not been declared anywhere, is it a typo?" - (Utils.Cli.format_with_style [ ANSITerminal.yellow ]) + (Utils.Cli.format_with_style [ANSITerminal.yellow]) (Pos.unmark suse.Ast.scope_use_name) in List.fold_left (process_scope_use_item s_name) ctxt suse.Ast.scope_use_items diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index 7a96857a..9e319de9 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -65,40 +65,40 @@ let rec law_struct_list_to_tree (f : Ast.law_structure list) : Ast.law_structure list = match f with | [] -> [] - | [ item ] -> [ item ] + | [item] -> [item] | first_item :: rest -> ( - let rest_tree = law_struct_list_to_tree rest in - match rest_tree with - | [] -> assert false (* there should be at least one rest element *) - | rest_head :: rest_tail -> ( - match first_item with - | CodeBlock _ | LawText _ | LawInclude _ -> - (* if an article or an include is just before a new heading , then - we don't merge it with what comes next *) - first_item :: rest_head :: rest_tail - | LawHeading (heading, _) -> - (* here we have encountered a heading, which is going to "gobble" - everything in the [rest_tree] until it finds a heading of at - least the same precedence *) - let rec split_rest_tree (rest_tree : Ast.law_structure list) : - Ast.law_structure list * Ast.law_structure list = - match rest_tree with - | [] -> ([], []) - | LawHeading (new_heading, _) :: _ - when new_heading.law_heading_precedence - <= heading.law_heading_precedence -> - (* we stop gobbling *) - ([], rest_tree) - | first :: after -> - (* we continue gobbling *) - let after_gobbled, after_out = split_rest_tree after in - (first :: after_gobbled, after_out) - in - let gobbled, rest_out = split_rest_tree rest_tree in - LawHeading (heading, gobbled) :: rest_out)) + let rest_tree = law_struct_list_to_tree rest in + match rest_tree with + | [] -> assert false (* there should be at least one rest element *) + | rest_head :: rest_tail -> ( + match first_item with + | CodeBlock _ | LawText _ | LawInclude _ -> + (* if an article or an include is just before a new heading , then we + don't merge it with what comes next *) + first_item :: rest_head :: rest_tail + | LawHeading (heading, _) -> + (* here we have encountered a heading, which is going to "gobble" + everything in the [rest_tree] until it finds a heading of at least + the same precedence *) + let rec split_rest_tree (rest_tree : Ast.law_structure list) : + Ast.law_structure list * Ast.law_structure list = + match rest_tree with + | [] -> [], [] + | LawHeading (new_heading, _) :: _ + when new_heading.law_heading_precedence + <= heading.law_heading_precedence -> + (* we stop gobbling *) + [], rest_tree + | first :: after -> + (* we continue gobbling *) + let after_gobbled, after_out = split_rest_tree after in + first :: after_gobbled, after_out + in + let gobbled, rest_out = split_rest_tree rest_tree in + LawHeading (heading, gobbled) :: rest_out)) (** Style with which to display syntax hints in the terminal output *) -let syntax_hints_style = [ ANSITerminal.yellow ] +let syntax_hints_style = [ANSITerminal.yellow] (** Usage: [raise_parser_error error_loc last_good_loc token msg] @@ -116,7 +116,7 @@ let raise_parser_error :: (match last_good_loc with | None -> [] - | Some last_good_loc -> [ (Some "Last good token:", last_good_loc) ])) + | Some last_good_loc -> [Some "Last good token:", last_good_loc])) "Syntax error at token %a\n%s" (Cli.format_with_style syntax_hints_style) (Printf.sprintf "\"%s\"" token) @@ -150,15 +150,15 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct let acceptable_tokens, last_positions = match last_input_needed with | Some last_input_needed -> - ( List.filter - (fun (_, t) -> - I.acceptable - (I.input_needed last_input_needed) - t - (fst (lexing_positions lexbuf))) - token_list, - Some (I.positions last_input_needed) ) - | None -> (token_list, None) + ( List.filter + (fun (_, t) -> + I.acceptable + (I.input_needed last_input_needed) + t + (fst (lexing_positions lexbuf))) + token_list, + Some (I.positions last_input_needed) ) + | None -> token_list, None in let similar_acceptable_tokens = List.sort @@ -193,19 +193,18 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct let custom_menhir_message = match Parser_errors.message (state env) with | exception Not_found -> - "Message: " - ^ Cli.with_style syntax_hints_style "%s" "unexpected token" + "Message: " ^ Cli.with_style syntax_hints_style "%s" "unexpected token" | msg -> - "Message: " - ^ Cli.with_style syntax_hints_style "%s" - (String.trim (String.uncapitalize_ascii msg)) + "Message: " + ^ Cli.with_style syntax_hints_style "%s" + (String.trim (String.uncapitalize_ascii msg)) in let msg = match similar_token_msg with | None -> custom_menhir_message | Some similar_token_msg -> - Printf.sprintf "%s\nAutosuggestion: %s" custom_menhir_message - similar_token_msg + Printf.sprintf "%s\nAutosuggestion: %s" custom_menhir_message + similar_token_msg in raise_parser_error (Pos.from_lpos (lexing_positions lexbuf)) @@ -221,17 +220,17 @@ module ParserAux (LocalisedLexer : Lexer_common.LocalisedLexer) = struct (checkpoint : 'semantic_value I.checkpoint) : Ast.source_file = match checkpoint with | I.InputNeeded env -> - let token = next_token () in - let checkpoint = I.offer checkpoint token in - loop next_token token_list lexbuf (Some env) checkpoint + let token = next_token () in + let checkpoint = I.offer checkpoint token in + loop next_token token_list lexbuf (Some env) checkpoint | I.Shifting _ | I.AboutToReduce _ -> - let checkpoint = I.resume checkpoint in - loop next_token token_list lexbuf last_input_needed checkpoint + let checkpoint = I.resume checkpoint in + loop next_token token_list lexbuf last_input_needed checkpoint | I.HandlingError env -> fail lexbuf env token_list last_input_needed | I.Accepted v -> v | I.Rejected -> - (* Cannot happen as we stop at syntax error immediatly *) - assert false + (* Cannot happen as we stop at syntax error immediatly *) + assert false (** Stub that wraps the parsing main loop and handles the Menhir/Sedlex type difference for [lexbuf]. *) @@ -269,17 +268,18 @@ let localised_parser : Cli.backend_lang -> lexbuf -> Ast.source_file = function (** Parses a single source file *) let rec parse_source_file - (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program = + (source_file : Pos.input_file) + (language : Cli.backend_lang) : Ast.program = Cli.debug_print "Parsing %s" (match source_file with FileName s | Contents s -> s); let lexbuf, input = match source_file with | FileName source_file -> ( - try - let input = open_in source_file in - (Sedlexing.Utf8.from_channel input, Some input) - with Sys_error msg -> Errors.raise_error "System error: %s" msg) - | Contents contents -> (Sedlexing.Utf8.from_string contents, None) + try + let input = open_in source_file in + Sedlexing.Utf8.from_channel input, Some input + with Sys_error msg -> Errors.raise_error "System error: %s" msg) + | Contents contents -> Sedlexing.Utf8.from_string contents, None in let source_file_name = match source_file with FileName s -> s | Contents _ -> "stdin" @@ -304,38 +304,36 @@ and expand_includes (fun acc command -> match command with | Ast.LawInclude (Ast.CatalaFile sub_source) -> - let source_dir = Filename.dirname source_file in - let sub_source = Filename.concat source_dir (Pos.unmark sub_source) in - let includ_program = - parse_source_file (FileName sub_source) language - in - { - Ast.program_source_files = - acc.Ast.program_source_files @ includ_program.program_source_files; - Ast.program_items = - acc.Ast.program_items @ includ_program.program_items; - } + let source_dir = Filename.dirname source_file in + let sub_source = Filename.concat source_dir (Pos.unmark sub_source) in + let includ_program = parse_source_file (FileName sub_source) language in + { + Ast.program_source_files = + acc.Ast.program_source_files @ includ_program.program_source_files; + Ast.program_items = + acc.Ast.program_items @ includ_program.program_items; + } | Ast.LawHeading (heading, commands') -> - let { - Ast.program_items = commands'; - Ast.program_source_files = new_sources; - } = - expand_includes source_file commands' language - in - { - Ast.program_source_files = - acc.Ast.program_source_files @ new_sources; - Ast.program_items = - acc.Ast.program_items @ [ Ast.LawHeading (heading, commands') ]; - } - | i -> { acc with Ast.program_items = acc.Ast.program_items @ [ i ] }) + let { + Ast.program_items = commands'; + Ast.program_source_files = new_sources; + } = + expand_includes source_file commands' language + in + { + Ast.program_source_files = acc.Ast.program_source_files @ new_sources; + Ast.program_items = + acc.Ast.program_items @ [Ast.LawHeading (heading, commands')]; + } + | i -> { acc with Ast.program_items = acc.Ast.program_items @ [i] }) { Ast.program_source_files = []; Ast.program_items = [] } commands (** {1 API} *) let parse_top_level_file - (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program = + (source_file : Pos.input_file) + (language : Cli.backend_lang) : Ast.program = let program = parse_source_file source_file language in { program with diff --git a/compiler/utils/cli.ml b/compiler/utils/cli.ml index dd1641ca..061ec00c 100644 --- a/compiler/utils/cli.ml +++ b/compiler/utils/cli.ml @@ -89,22 +89,21 @@ let file = & info [] ~docv:"FILE" ~doc:"Catala master file to be compiled.") let debug = - Arg.(value & flag & info [ "debug"; "d" ] ~doc:"Prints debug information.") + Arg.(value & flag & info ["debug"; "d"] ~doc:"Prints debug information.") let unstyled = Arg.( value & flag - & info [ "unstyled"; "u" ] + & info ["unstyled"; "u"] ~doc:"Removes styling (colors, etc.) from terminal output.") let optimize = - Arg.( - value & flag & info [ "optimize"; "O" ] ~doc:"Run compiler optimizations.") + Arg.(value & flag & info ["optimize"; "O"] ~doc:"Run compiler optimizations.") let trace_opt = Arg.( value & flag - & info [ "trace"; "t" ] + & info ["trace"; "t"] ~doc: "Displays a trace of the interpreter's computation or generates \ logging instructions in translate programs.") @@ -112,19 +111,19 @@ let trace_opt = let avoid_exceptions = Arg.( value & flag - & info [ "avoid_exceptions" ] + & info ["avoid_exceptions"] ~doc:"Compiles the default calculus without exceptions") let closure_conversion = Arg.( value & flag - & info [ "closure_conversion" ] + & info ["closure_conversion"] ~doc:"Performs closure conversion on the lambda calculus") let wrap_weaved_output = Arg.( value & flag - & info [ "wrap"; "w" ] + & info ["wrap"; "w"] ~doc:"Wraps literate programming output with a minimal preamble.") let backend = @@ -139,7 +138,7 @@ let language = Arg.( value & opt (some string) None - & info [ "l"; "language" ] ~docv:"LANG" + & info ["l"; "language"] ~docv:"LANG" ~doc:"Input language among: en, fr, pl.") let max_prec_digits_opt = @@ -147,7 +146,7 @@ let max_prec_digits_opt = value & opt (some int) None & info - [ "p"; "max_digits_printed" ] + ["p"; "max_digits_printed"] ~docv:"DIGITS" ~doc: "Maximum number of significant digits printed for decimal results \ @@ -157,7 +156,7 @@ let disable_counterexamples_opt = Arg.( value & flag & info - [ "disable_counterexamples" ] + ["disable_counterexamples"] ~doc: "Disables the search for counterexamples in proof mode. Useful when \ you want a deterministic output from the Catala compiler, since \ @@ -167,13 +166,13 @@ let ex_scope = Arg.( value & opt (some string) None - & info [ "s"; "scope" ] ~docv:"SCOPE" ~doc:"Scope to be focused on.") + & info ["s"; "scope"] ~docv:"SCOPE" ~doc:"Scope to be focused on.") let output = Arg.( value & opt (some string) None - & info [ "output"; "o" ] ~docv:"OUTPUT" + & info ["output"; "o"] ~docv:"OUTPUT" ~doc: "$(i, OUTPUT) is the file that will contain the output of the \ compiler. Defaults to $(i,FILE).$(i,EXT) where $(i,EXT) depends on \ @@ -315,7 +314,7 @@ let info = "Please file bug reports at https://github.com/CatalaLang/catala/issues"; ] in - let exits = Term.default_exits @ [ Term.exit_info ~doc:"on error." 1 ] in + let exits = Term.default_exits @ [Term.exit_info ~doc:"on error." 1] in Term.info "catala" ~version ~doc ~exits ~man (**{1 Terminal formatting}*) @@ -325,7 +324,8 @@ let info = let time : float ref = ref (Unix.gettimeofday ()) let with_style - (styles : ANSITerminal.style list) (str : ('a, unit, string) format) = + (styles : ANSITerminal.style list) + (str : ('a, unit, string) format) = if !style_flag then ANSITerminal.sprintf styles str else Printf.sprintf str let format_with_style (styles : ANSITerminal.style list) fmt (str : string) = @@ -342,48 +342,49 @@ let time_marker () = if delta > 50. then Printf.printf "%s" (with_style - [ ANSITerminal.Bold; ANSITerminal.black ] + [ANSITerminal.Bold; ANSITerminal.black] "[TIME] %.0f ms\n" delta) (** Prints [\[DEBUG\]] in purple on the terminal standard output *) let debug_marker () = time_marker (); - with_style [ ANSITerminal.Bold; ANSITerminal.magenta ] "[DEBUG] " + with_style [ANSITerminal.Bold; ANSITerminal.magenta] "[DEBUG] " (** Prints [\[ERROR\]] in red on the terminal error output *) let error_marker () = - with_style [ ANSITerminal.Bold; ANSITerminal.red ] "[ERROR] " + with_style [ANSITerminal.Bold; ANSITerminal.red] "[ERROR] " (** Prints [\[WARNING\]] in yellow on the terminal standard output *) let warning_marker () = - with_style [ ANSITerminal.Bold; ANSITerminal.yellow ] "[WARNING] " + with_style [ANSITerminal.Bold; ANSITerminal.yellow] "[WARNING] " (** Prints [\[RESULT\]] in green on the terminal standard output *) let result_marker () = - with_style [ ANSITerminal.Bold; ANSITerminal.green ] "[RESULT] " + with_style [ANSITerminal.Bold; ANSITerminal.green] "[RESULT] " (** Prints [\[LOG\]] in red on the terminal error output *) -let log_marker () = - with_style [ ANSITerminal.Bold; ANSITerminal.black ] "[LOG] " +let log_marker () = with_style [ANSITerminal.Bold; ANSITerminal.black] "[LOG] " (**{2 Printers}*) (** All the printers below print their argument after the correct marker *) let concat_with_line_depending_prefix_and_suffix - (prefix : int -> string) (suffix : int -> string) (ss : string list) = + (prefix : int -> string) + (suffix : int -> string) + (ss : string list) = match ss with | hd :: rest -> - let out, _ = - List.fold_left - (fun (acc, i) s -> - ( (acc ^ prefix i ^ s - ^ if i = List.length ss - 1 then "" else suffix i), - i + 1 )) - ((prefix 0 ^ hd ^ if 0 = List.length ss - 1 then "" else suffix 0), 1) - rest - in - out + let out, _ = + List.fold_left + (fun (acc, i) s -> + ( (acc ^ prefix i ^ s + ^ if i = List.length ss - 1 then "" else suffix i), + i + 1 )) + ((prefix 0 ^ hd ^ if 0 = List.length ss - 1 then "" else suffix 0), 1) + rest + in + out | [] -> prefix 0 (** The int argument of the prefix corresponds to the line number, starting at 0 *) diff --git a/compiler/utils/errors.ml b/compiler/utils/errors.ml index 106165dd..3147a094 100644 --- a/compiler/utils/errors.ml +++ b/compiler/utils/errors.ml @@ -39,7 +39,7 @@ let print_structured_error (msg : string) (pos : (string option * Pos.t) list) : let raise_spanned_error ?(span_msg : string option) (span : Pos.t) format = Format.kasprintf - (fun msg -> raise (StructuredError (msg, [ (span_msg, span) ]))) + (fun msg -> raise (StructuredError (msg, [span_msg, span]))) format let raise_multispanned_error (spans : (string option * Pos.t) list) format = @@ -56,6 +56,6 @@ let format_multispanned_warning (pos : (string option * Pos.t) list) format = format let format_spanned_warning ?(span_msg : string option) (span : Pos.t) format = - format_multispanned_warning [ (span_msg, span) ] format + format_multispanned_warning [span_msg, span] format let format_warning format = format_multispanned_warning [] format diff --git a/compiler/utils/pos.ml b/compiler/utils/pos.ml index 6ef3c34e..fcf86e98 100644 --- a/compiler/utils/pos.ml +++ b/compiler/utils/pos.ml @@ -20,7 +20,11 @@ let from_lpos (p : Lexing.position * Lexing.position) : t = { code_pos = p; law_pos = [] } let from_info - (file : string) (sline : int) (scol : int) (eline : int) (ecol : int) : t = + (file : string) + (sline : int) + (scol : int) + (eline : int) + (ecol : int) : t = let spos = { Lexing.pos_fname = file; @@ -37,7 +41,7 @@ let from_info Lexing.pos_bol = 1; } in - { code_pos = (spos, epos); law_pos = [] } + { code_pos = spos, epos; law_pos = [] } let overwrite_law_info (pos : t) (law_pos : string list) : t = { pos with law_pos } @@ -88,7 +92,7 @@ let indent_number (s : string) : int = let retrieve_loc_text (pos : t) : string = try let filename = get_file pos in - let blue_style = [ ANSITerminal.Bold; ANSITerminal.blue ] in + let blue_style = [ANSITerminal.Bold; ANSITerminal.blue] in if filename = "" then "No position information" else let sline = get_start_line pos in @@ -100,21 +104,21 @@ let retrieve_loc_text (pos : t) : string = let input_line_opt () : string option = match List.nth_opt lines !line_index with | Some l -> - line_index := !line_index + 1; - Some l + line_index := !line_index + 1; + Some l | None -> None in - (None, input_line_opt) + None, input_line_opt else let oc = open_in filename in let input_line_opt () : string option = try Some (input_line oc) with End_of_file -> None in - (Some oc, input_line_opt) + Some oc, input_line_opt in let print_matched_line (line : string) (line_no : int) : string = let line_indent = indent_number line in - let error_indicator_style = [ ANSITerminal.red; ANSITerminal.Bold ] in + let error_indicator_style = [ANSITerminal.red; ANSITerminal.Bold] in line ^ if line_no >= sline && line_no <= eline then @@ -146,12 +150,11 @@ let retrieve_loc_text (pos : t) : string = let rec get_lines (n : int) : string list = match input_line_opt () with | Some line -> - if n < sline - include_extra_count then get_lines (n + 1) - else if - n >= sline - include_extra_count - && n <= eline + include_extra_count - then print_matched_line line n :: get_lines (n + 1) - else [] + if n < sline - include_extra_count then get_lines (n + 1) + else if + n >= sline - include_extra_count && n <= eline + include_extra_count + then print_matched_line line n :: get_lines (n + 1) + else [] | None -> [] in let pos_lines = get_lines 1 in @@ -211,13 +214,13 @@ let no_pos : t = Lexing.pos_bol = 0; } in - { code_pos = (zero_pos, zero_pos); law_pos = [] } + { code_pos = zero_pos, zero_pos; law_pos = [] } -let mark pos e : 'a marked = (e, pos) +let mark pos e : 'a marked = e, pos let unmark ((x, _) : 'a marked) : 'a = x let get_position ((_, x) : 'a marked) : t = x -let map_under_mark (f : 'a -> 'b) ((x, y) : 'a marked) : 'b marked = (f x, y) -let same_pos_as (x : 'a) ((_, y) : 'b marked) : 'a marked = (x, y) +let map_under_mark (f : 'a -> 'b) ((x, y) : 'a marked) : 'b marked = f x, y +let same_pos_as (x : 'a) ((_, y) : 'b marked) : 'a marked = x, y let unmark_option (x : 'a marked option) : 'a option = match x with Some x -> Some (unmark x) | None -> None diff --git a/compiler/verification/conditions.ml b/compiler/verification/conditions.ml index e64c2aaa..870bce15 100644 --- a/compiler/verification/conditions.ml +++ b/compiler/verification/conditions.ml @@ -35,37 +35,35 @@ type ctx = { let conjunction (args : vc_return list) (pos : Pos.t) : vc_return = let acc, list = match args with - | hd :: tl -> (hd, tl) - | [] -> (((ELit (LBool true), pos), VarMap.empty), []) + | hd :: tl -> hd, tl + | [] -> ((ELit (LBool true), pos), VarMap.empty), [] in List.fold_left (fun (acc, acc_ty) (arg, arg_ty) -> - ( (EApp ((EOp (Binop And), pos), [ arg; acc ]), pos), - VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty - )) + ( (EApp ((EOp (Binop And), pos), [arg; acc]), pos), + VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty )) acc list let negation ((arg, arg_ty) : vc_return) (pos : Pos.t) : vc_return = - ((EApp ((EOp (Unop Not), pos), [ arg ]), pos), arg_ty) + (EApp ((EOp (Unop Not), pos), [arg]), pos), arg_ty let disjunction (args : vc_return list) (pos : Pos.t) : vc_return = let acc, list = match args with - | hd :: tl -> (hd, tl) - | [] -> (((ELit (LBool false), pos), VarMap.empty), []) + | hd :: tl -> hd, tl + | [] -> ((ELit (LBool false), pos), VarMap.empty), [] in List.fold_left (fun ((acc, acc_ty) : vc_return) (arg, arg_ty) -> - ( (EApp ((EOp (Binop Or), pos), [ arg; acc ]), pos), - VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty - )) + ( (EApp ((EOp (Binop Or), pos), [arg; acc]), pos), + VarMap.union (fun _ _ _ -> failwith "should not happen") acc_ty arg_ty )) acc list (** [half_product \[a1,...,an\] \[b1,...,bm\] returns \[(a1,b1),...(a1,bn),...(an,b1),...(an,bm)\]] *) let half_product (l1 : 'a list) (l2 : 'b list) : ('a * 'b) list = l1 |> List.mapi (fun i ei -> - List.filteri (fun j _ -> i < j) l2 |> List.map (fun ej -> (ei, ej))) + List.filteri (fun j _ -> i < j) l2 |> List.map (fun ej -> ei, ej)) |> List.concat (** This code skims through the topmost layers of the terms like this: @@ -78,26 +76,26 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : expr Pos.marked) : match Pos.unmark e with | ErrorOnEmpty ( EDefault - ( [ (EApp ((EVar (x, _), _), [ (ELit LUnit, _) ]), _) ], + ( [(EApp ((EVar (x, _), _), [(ELit LUnit, _)]), _)], (ELit (LBool true), _), cons ), _ ) when List.exists (fun x' -> Bindlib.eq_vars x x') ctx.input_vars -> - (* scope variables*) - cons - | EAbs ((binder, _), [ (TLit TUnit, _) ]) -> - (* context sub-scope variables *) - let _, body = Bindlib.unmbind binder in - body + (* scope variables*) + cons + | EAbs ((binder, _), [(TLit TUnit, _)]) -> + (* context sub-scope variables *) + let _, body = Bindlib.unmbind binder in + body | ErrorOnEmpty d -> - d (* input subscope variables and non-input scope variable *) + d (* input subscope variables and non-input scope variable *) | _ -> - Errors.raise_spanned_error (Pos.get_position e) - "Internal error: this expression does not have the structure expected \ - by the VC generator:\n\ - %a" - (Print.format_expr ~debug:true ctx.decl) - e + Errors.raise_spanned_error (Pos.get_position e) + "Internal error: this expression does not have the structure expected by \ + the VC generator:\n\ + %a" + (Print.format_expr ~debug:true ctx.decl) + e (** {1 Verification conditions generator}*) @@ -110,86 +108,86 @@ let rec generate_vc_must_not_return_empty (ctx : ctx) (e : expr Pos.marked) : let out = match Pos.unmark e with | ETuple (args, _) | EArray args -> - conjunction - (List.map (generate_vc_must_not_return_empty ctx) args) - (Pos.get_position e) + conjunction + (List.map (generate_vc_must_not_return_empty ctx) args) + (Pos.get_position e) | EMatch (arg, arms, _) -> - conjunction - (List.map (generate_vc_must_not_return_empty ctx) (arg :: arms)) - (Pos.get_position e) + conjunction + (List.map (generate_vc_must_not_return_empty ctx) (arg :: arms)) + (Pos.get_position e) | ETupleAccess (e1, _, _, _) | EInj (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 -> - (generate_vc_must_not_return_empty ctx) e1 + (generate_vc_must_not_return_empty ctx) e1 | EAbs (binder, typs) -> - (* Hot take: for a function never to return an empty error when called, it has to do - so whatever its input. So we universally quantify over the variable of the function - when inspecting the body, resulting in simply traversing through in the code here. *) - let vars, body = Bindlib.unmbind (Pos.unmark binder) in - let vc_body_expr, vc_body_ty = - (generate_vc_must_not_return_empty ctx) body - in - ( vc_body_expr, - List.fold_left - (fun acc (var, ty) -> VarMap.add var ty acc) - vc_body_ty - (List.map2 (fun x y -> (x, y)) (Array.to_list vars) typs) ) + (* Hot take: for a function never to return an empty error when called, it has to do + so whatever its input. So we universally quantify over the variable of the function + when inspecting the body, resulting in simply traversing through in the code here. *) + let vars, body = Bindlib.unmbind (Pos.unmark binder) in + let vc_body_expr, vc_body_ty = + (generate_vc_must_not_return_empty ctx) body + in + ( vc_body_expr, + List.fold_left + (fun acc (var, ty) -> VarMap.add var ty acc) + vc_body_ty + (List.map2 (fun x y -> x, y) (Array.to_list vars) typs) ) | EApp (f, args) -> - (* We assume here that function calls never return empty error, which implies - all functions have been checked never to return empty errors. *) - conjunction - (List.map (generate_vc_must_not_return_empty ctx) (f :: args)) - (Pos.get_position e) + (* We assume here that function calls never return empty error, which implies + all functions have been checked never to return empty errors. *) + conjunction + (List.map (generate_vc_must_not_return_empty ctx) (f :: args)) + (Pos.get_position e) | EIfThenElse (e1, e2, e3) -> - let e1_vc, vc_typ1 = generate_vc_must_not_return_empty ctx e1 in - let e2_vc, vc_typ2 = generate_vc_must_not_return_empty ctx e2 in - let e3_vc, vc_typ3 = generate_vc_must_not_return_empty ctx e3 in - conjunction - [ - (e1_vc, vc_typ1); - ( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e), - VarMap.union - (fun _ _ _ -> failwith "should not happen") - vc_typ2 vc_typ3 ); - ] - (Pos.get_position e) - | ELit LEmptyError -> (Pos.same_pos_as (ELit (LBool false)) e, VarMap.empty) + let e1_vc, vc_typ1 = generate_vc_must_not_return_empty ctx e1 in + let e2_vc, vc_typ2 = generate_vc_must_not_return_empty ctx e2 in + let e3_vc, vc_typ3 = generate_vc_must_not_return_empty ctx e3 in + conjunction + [ + e1_vc, vc_typ1; + ( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e), + VarMap.union + (fun _ _ _ -> failwith "should not happen") + vc_typ2 vc_typ3 ); + ] + (Pos.get_position e) + | ELit LEmptyError -> Pos.same_pos_as (ELit (LBool false)) e, VarMap.empty | EVar _ (* Per default calculus semantics, you cannot call a function with an argument that evaluates to the empty error. Thus, all variable evaluate to non-empty-error terms. *) | ELit _ | EOp _ -> - (Pos.same_pos_as (ELit (LBool true)) e, VarMap.empty) + Pos.same_pos_as (ELit (LBool true)) e, VarMap.empty | EDefault (exceptions, just, cons) -> - (* never returns empty if and only if: - - first we look if e1 .. en ejust can return empty; - - if no, we check that if ejust is true, whether econs can return empty. - *) - disjunction - (List.map (generate_vc_must_not_return_empty ctx) exceptions - @ [ - conjunction - [ - generate_vc_must_not_return_empty ctx just; - (let vc_just_expr, vc_just_ty = - generate_vc_must_not_return_empty ctx cons - in - ( ( EIfThenElse - ( just, - (* Comment from Alain: the justification is not checked for holding an default term. - In such cases, we need to encode the logic of the default terms within - the generation of the verification condition (Z3encoding.translate_expr). - Answer from Denis: Normally, there is a structural invariant from the - surface language to intermediate representation translation preventing - any default terms to appear in justifications.*) - vc_just_expr, - (ELit (LBool false), Pos.get_position e) ), - Pos.get_position e ), - vc_just_ty )); - ] - (Pos.get_position e); - ]) - (Pos.get_position e) + (* never returns empty if and only if: + - first we look if e1 .. en ejust can return empty; + - if no, we check that if ejust is true, whether econs can return empty. + *) + disjunction + (List.map (generate_vc_must_not_return_empty ctx) exceptions + @ [ + conjunction + [ + generate_vc_must_not_return_empty ctx just; + (let vc_just_expr, vc_just_ty = + generate_vc_must_not_return_empty ctx cons + in + ( ( EIfThenElse + ( just, + (* Comment from Alain: the justification is not checked for holding an default term. + In such cases, we need to encode the logic of the default terms within + the generation of the verification condition (Z3encoding.translate_expr). + Answer from Denis: Normally, there is a structural invariant from the + surface language to intermediate representation translation preventing + any default terms to appear in justifications.*) + vc_just_expr, + (ELit (LBool false), Pos.get_position e) ), + Pos.get_position e ), + vc_just_ty )); + ] + (Pos.get_position e); + ]) + (Pos.get_position e) in out [@@ocamlformat "wrap-comments=false"] @@ -205,73 +203,73 @@ let rec generate_vs_must_not_return_confict (ctx : ctx) (e : expr Pos.marked) : function relies on. *) match Pos.unmark e with | ETuple (args, _) | EArray args -> - conjunction - (List.map (generate_vs_must_not_return_confict ctx) args) - (Pos.get_position e) + conjunction + (List.map (generate_vs_must_not_return_confict ctx) args) + (Pos.get_position e) | EMatch (arg, arms, _) -> - conjunction - (List.map (generate_vs_must_not_return_confict ctx) (arg :: arms)) - (Pos.get_position e) + conjunction + (List.map (generate_vs_must_not_return_confict ctx) (arg :: arms)) + (Pos.get_position e) | ETupleAccess (e1, _, _, _) | EInj (e1, _, _, _) | EAssert e1 | ErrorOnEmpty e1 -> - generate_vs_must_not_return_confict ctx e1 + generate_vs_must_not_return_confict ctx e1 | EAbs (binder, typs) -> - let vars, body = Bindlib.unmbind (Pos.unmark binder) in - let vc_body_expr, vc_body_ty = - (generate_vs_must_not_return_confict ctx) body - in - ( vc_body_expr, - List.fold_left - (fun acc (var, ty) -> VarMap.add var ty acc) - vc_body_ty - (List.map2 (fun x y -> (x, y)) (Array.to_list vars) typs) ) + let vars, body = Bindlib.unmbind (Pos.unmark binder) in + let vc_body_expr, vc_body_ty = + (generate_vs_must_not_return_confict ctx) body + in + ( vc_body_expr, + List.fold_left + (fun acc (var, ty) -> VarMap.add var ty acc) + vc_body_ty + (List.map2 (fun x y -> x, y) (Array.to_list vars) typs) ) | EApp (f, args) -> - conjunction - (List.map (generate_vs_must_not_return_confict ctx) (f :: args)) - (Pos.get_position e) + conjunction + (List.map (generate_vs_must_not_return_confict ctx) (f :: args)) + (Pos.get_position e) | EIfThenElse (e1, e2, e3) -> - let e1_vc, vc_typ1 = generate_vs_must_not_return_confict ctx e1 in - let e2_vc, vc_typ2 = generate_vs_must_not_return_confict ctx e2 in - let e3_vc, vc_typ3 = generate_vs_must_not_return_confict ctx e3 in - conjunction - [ - (e1_vc, vc_typ1); - ( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e), - VarMap.union - (fun _ _ _ -> failwith "should not happen") - vc_typ2 vc_typ3 ); - ] - (Pos.get_position e) + let e1_vc, vc_typ1 = generate_vs_must_not_return_confict ctx e1 in + let e2_vc, vc_typ2 = generate_vs_must_not_return_confict ctx e2 in + let e3_vc, vc_typ3 = generate_vs_must_not_return_confict ctx e3 in + conjunction + [ + e1_vc, vc_typ1; + ( (EIfThenElse (e1, e2_vc, e3_vc), Pos.get_position e), + VarMap.union + (fun _ _ _ -> failwith "should not happen") + vc_typ2 vc_typ3 ); + ] + (Pos.get_position e) | EVar _ | ELit _ | EOp _ -> - (Pos.same_pos_as (ELit (LBool true)) e, VarMap.empty) + Pos.same_pos_as (ELit (LBool true)) e, VarMap.empty | EDefault (exceptions, just, cons) -> - (* never returns conflict if and only if: - - neither e1 nor ... nor en nor ejust nor econs return conflict - - there is no two differents ei ej that are not empty. *) - let quadratic = - negation - (disjunction - (List.map - (fun (e1, e2) -> - conjunction - [ - generate_vc_must_not_return_empty ctx e1; - generate_vc_must_not_return_empty ctx e2; - ] - (Pos.get_position e)) - (half_product exceptions exceptions)) - (Pos.get_position e)) - (Pos.get_position e) - in - let others = - List.map - (generate_vs_must_not_return_confict ctx) - (just :: cons :: exceptions) - in - let out = conjunction (quadratic :: others) (Pos.get_position e) in - out + (* never returns conflict if and only if: + - neither e1 nor ... nor en nor ejust nor econs return conflict + - there is no two differents ei ej that are not empty. *) + let quadratic = + negation + (disjunction + (List.map + (fun (e1, e2) -> + conjunction + [ + generate_vc_must_not_return_empty ctx e1; + generate_vc_must_not_return_empty ctx e2; + ] + (Pos.get_position e)) + (half_product exceptions exceptions)) + (Pos.get_position e)) + (Pos.get_position e) + in + let others = + List.map + (generate_vs_must_not_return_confict ctx) + (just :: cons :: exceptions) + in + let out = conjunction (quadratic :: others) (Pos.get_position e) in + out in out [@@ocamlformat "wrap-comments=false"] @@ -290,133 +288,132 @@ type verification_condition = { } let rec generate_verification_conditions_scope_body_expr - (ctx : ctx) (scope_body_expr : expr scope_body_expr) : - ctx * verification_condition list = + (ctx : ctx) + (scope_body_expr : expr scope_body_expr) : ctx * verification_condition list + = match scope_body_expr with - | Result _ -> (ctx, []) + | Result _ -> ctx, [] | ScopeLet scope_let -> - let scope_let_var, scope_let_next = - Bindlib.unbind scope_let.scope_let_next - in - let new_ctx, vc_list = - match scope_let.scope_let_kind with - | DestructuringInputStruct -> - ({ ctx with input_vars = scope_let_var :: ctx.input_vars }, []) - | ScopeVarDefinition | SubScopeVarDefinition -> - (* For scope variables, we should check both that they never - evaluate to emptyError nor conflictError. But for subscope - variable definitions, what we're really doing is adding - exceptions to something defined in the subscope so we just ought - to verify only that the exceptions overlap. *) - let e = - Bindlib.unbox (remove_logging_calls scope_let.scope_let_expr) + let scope_let_var, scope_let_next = + Bindlib.unbind scope_let.scope_let_next + in + let new_ctx, vc_list = + match scope_let.scope_let_kind with + | DestructuringInputStruct -> + { ctx with input_vars = scope_let_var :: ctx.input_vars }, [] + | ScopeVarDefinition | SubScopeVarDefinition -> + (* For scope variables, we should check both that they never evaluate to + emptyError nor conflictError. But for subscope variable definitions, + what we're really doing is adding exceptions to something defined in + the subscope so we just ought to verify only that the exceptions + overlap. *) + let e = Bindlib.unbox (remove_logging_calls scope_let.scope_let_expr) in + let e = match_and_ignore_outer_reentrant_default ctx e in + let vc_confl, vc_confl_typs = + generate_vs_must_not_return_confict ctx e + in + let vc_confl = + if !Cli.optimize_flag then + Bindlib.unbox (Optimizations.optimize_expr ctx.decl vc_confl) + else vc_confl + in + let vc_list = + [ + { + vc_guard = Pos.same_pos_as (Pos.unmark vc_confl) e; + vc_kind = NoOverlappingExceptions; + vc_free_vars_typ = + VarMap.union + (fun _ _ -> failwith "should not happen") + ctx.scope_variables_typs vc_confl_typs; + vc_scope = ctx.current_scope_name; + vc_variable = scope_let_var, scope_let.scope_let_pos; + }; + ] + in + let vc_list = + match scope_let.scope_let_kind with + | ScopeVarDefinition -> + let vc_empty, vc_empty_typs = + generate_vc_must_not_return_empty ctx e in - let e = match_and_ignore_outer_reentrant_default ctx e in - let vc_confl, vc_confl_typs = - generate_vs_must_not_return_confict ctx e - in - let vc_confl = + let vc_empty = if !Cli.optimize_flag then - Bindlib.unbox (Optimizations.optimize_expr ctx.decl vc_confl) - else vc_confl + Bindlib.unbox (Optimizations.optimize_expr ctx.decl vc_empty) + else vc_empty in - let vc_list = - [ - { - vc_guard = Pos.same_pos_as (Pos.unmark vc_confl) e; - vc_kind = NoOverlappingExceptions; - vc_free_vars_typ = - VarMap.union - (fun _ _ -> failwith "should not happen") - ctx.scope_variables_typs vc_confl_typs; - vc_scope = ctx.current_scope_name; - vc_variable = (scope_let_var, scope_let.scope_let_pos); - }; - ] - in - let vc_list = - match scope_let.scope_let_kind with - | ScopeVarDefinition -> - let vc_empty, vc_empty_typs = - generate_vc_must_not_return_empty ctx e - in - let vc_empty = - if !Cli.optimize_flag then - Bindlib.unbox - (Optimizations.optimize_expr ctx.decl vc_empty) - else vc_empty - in - { - vc_guard = Pos.same_pos_as (Pos.unmark vc_empty) e; - vc_kind = NoEmptyError; - vc_free_vars_typ = - VarMap.union - (fun _ _ -> failwith "should not happen") - ctx.scope_variables_typs vc_empty_typs; - vc_scope = ctx.current_scope_name; - vc_variable = (scope_let_var, scope_let.scope_let_pos); - } - :: vc_list - | _ -> vc_list - in - (ctx, vc_list) - | _ -> (ctx, []) - in - let new_ctx, new_vcs = - generate_verification_conditions_scope_body_expr - { - new_ctx with - scope_variables_typs = - VarMap.add scope_let_var scope_let.scope_let_typ - new_ctx.scope_variables_typs; - } - scope_let_next - in - (new_ctx, vc_list @ new_vcs) + { + vc_guard = Pos.same_pos_as (Pos.unmark vc_empty) e; + vc_kind = NoEmptyError; + vc_free_vars_typ = + VarMap.union + (fun _ _ -> failwith "should not happen") + ctx.scope_variables_typs vc_empty_typs; + vc_scope = ctx.current_scope_name; + vc_variable = scope_let_var, scope_let.scope_let_pos; + } + :: vc_list + | _ -> vc_list + in + ctx, vc_list + | _ -> ctx, [] + in + let new_ctx, new_vcs = + generate_verification_conditions_scope_body_expr + { + new_ctx with + scope_variables_typs = + VarMap.add scope_let_var scope_let.scope_let_typ + new_ctx.scope_variables_typs; + } + scope_let_next + in + new_ctx, vc_list @ new_vcs let rec generate_verification_conditions_scopes - (decl_ctx : decl_ctx) (scopes : expr scopes) (s : ScopeName.t option) : - verification_condition list = + (decl_ctx : decl_ctx) + (scopes : expr scopes) + (s : ScopeName.t option) : verification_condition list = match scopes with | Nil -> [] | ScopeDef scope_def -> - let is_selected_scope = - match s with - | Some s when Dcalc.Ast.ScopeName.compare s scope_def.scope_name = 0 -> - true - | None -> true - | _ -> false - in - let vcs = - if is_selected_scope then - let _scope_input_var, scope_body_expr = - Bindlib.unbind scope_def.scope_body.scope_body_expr - in - let ctx = - { - current_scope_name = scope_def.scope_name; - decl = decl_ctx; - input_vars = []; - scope_variables_typs = - VarMap.empty - (* We don't need to add the typ of the scope input var here - because it will never appear in an expression for which we - generate a verification conditions (the big struct is - destructured with a series of let bindings just after. )*); - } - in - let _, vcs = - generate_verification_conditions_scope_body_expr ctx scope_body_expr - in - vcs - else [] - in - let _scope_var, next = Bindlib.unbind scope_def.scope_next in - generate_verification_conditions_scopes decl_ctx next s @ vcs + let is_selected_scope = + match s with + | Some s when Dcalc.Ast.ScopeName.compare s scope_def.scope_name = 0 -> + true + | None -> true + | _ -> false + in + let vcs = + if is_selected_scope then + let _scope_input_var, scope_body_expr = + Bindlib.unbind scope_def.scope_body.scope_body_expr + in + let ctx = + { + current_scope_name = scope_def.scope_name; + decl = decl_ctx; + input_vars = []; + scope_variables_typs = + VarMap.empty + (* We don't need to add the typ of the scope input var here + because it will never appear in an expression for which we + generate a verification conditions (the big struct is + destructured with a series of let bindings just after. )*); + } + in + let _, vcs = + generate_verification_conditions_scope_body_expr ctx scope_body_expr + in + vcs + else [] + in + let _scope_var, next = Bindlib.unbind scope_def.scope_next in + generate_verification_conditions_scopes decl_ctx next s @ vcs let generate_verification_conditions - (p : program) (s : Dcalc.Ast.ScopeName.t option) : - verification_condition list = + (p : program) + (s : Dcalc.Ast.ScopeName.t option) : verification_condition list = let vcs = generate_verification_conditions_scopes p.decl_ctx p.scopes s in (* We sort this list by scope name and then variable name to ensure consistent output for testing*) diff --git a/compiler/verification/io.ml b/compiler/verification/io.ml index ad32e528..ca845d79 100644 --- a/compiler/verification/io.ml +++ b/compiler/verification/io.ml @@ -96,15 +96,15 @@ module MakeBackendIO (B : Backend) = struct let print_positive_result (vc : Conditions.verification_condition) : string = match vc.Conditions.vc_kind with | Conditions.NoEmptyError -> - Format.asprintf "%s This variable never returns an empty error" - (Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]" - (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) - (Bindlib.name_of (Pos.unmark vc.vc_variable))) + Format.asprintf "%s This variable never returns an empty error" + (Cli.with_style [ANSITerminal.yellow] "[%s.%s]" + (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) + (Bindlib.name_of (Pos.unmark vc.vc_variable))) | Conditions.NoOverlappingExceptions -> - Format.asprintf "%s No two exceptions to ever overlap for this variable" - (Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]" - (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) - (Bindlib.name_of (Pos.unmark vc.vc_variable))) + Format.asprintf "%s No two exceptions to ever overlap for this variable" + (Cli.with_style [ANSITerminal.yellow] "[%s.%s]" + (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) + (Bindlib.name_of (Pos.unmark vc.vc_variable))) let print_negative_result (vc : Conditions.verification_condition) @@ -113,18 +113,18 @@ module MakeBackendIO (B : Backend) = struct let var_and_pos = match vc.Conditions.vc_kind with | Conditions.NoEmptyError -> - Format.asprintf "%s This variable might return an empty error:\n%s" - (Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]" - (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) - (Bindlib.name_of (Pos.unmark vc.vc_variable))) - (Pos.retrieve_loc_text (Pos.get_position vc.vc_variable)) + Format.asprintf "%s This variable might return an empty error:\n%s" + (Cli.with_style [ANSITerminal.yellow] "[%s.%s]" + (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) + (Bindlib.name_of (Pos.unmark vc.vc_variable))) + (Pos.retrieve_loc_text (Pos.get_position vc.vc_variable)) | Conditions.NoOverlappingExceptions -> - Format.asprintf - "%s At least two exceptions overlap for this variable:\n%s" - (Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]" - (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) - (Bindlib.name_of (Pos.unmark vc.vc_variable))) - (Pos.retrieve_loc_text (Pos.get_position vc.vc_variable)) + Format.asprintf + "%s At least two exceptions overlap for this variable:\n%s" + (Cli.with_style [ANSITerminal.yellow] "[%s.%s]" + (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) + (Bindlib.name_of (Pos.unmark vc.vc_variable))) + (Pos.retrieve_loc_text (Pos.get_position vc.vc_variable)) in let counterexample : string option = if !Cli.disable_counterexamples then @@ -132,18 +132,18 @@ module MakeBackendIO (B : Backend) = struct else match model with | None -> - Some - "The solver did not manage to generate a counterexample to \ - explain the faulty behavior." + Some + "The solver did not manage to generate a counterexample to explain \ + the faulty behavior." | Some model -> - if B.is_model_empty model then None - else - Some - (Format.asprintf - "The solver generated the following counterexample to \ - explain the faulty behavior:\n\ - %s" - (B.print_model ctx model)) + if B.is_model_empty model then None + else + Some + (Format.asprintf + "The solver generated the following counterexample to explain \ + the faulty behavior:\n\ + %s" + (B.print_model ctx model)) in var_and_pos ^ @@ -161,28 +161,27 @@ module MakeBackendIO (B : Backend) = struct Cli.debug_print "For this variable:\n%s\n" (Pos.retrieve_loc_text (Pos.get_position vc.Conditions.vc_guard)); Cli.debug_format "This verification condition was generated for %a:@\n%a" - (Cli.format_with_style [ ANSITerminal.yellow ]) + (Cli.format_with_style [ANSITerminal.yellow]) (match vc.vc_kind with | Conditions.NoEmptyError -> - "the variable definition never to return an empty error" + "the variable definition never to return an empty error" | NoOverlappingExceptions -> "no two exceptions to ever overlap") (Dcalc.Print.format_expr decl_ctx) vc.vc_guard; match z3_vc with | Success (encoding, backend_ctx) -> ( - Cli.debug_print "The translation to Z3 is the following:\n%s" - (B.print_encoding encoding); - match B.solve_vc_encoding backend_ctx encoding with - | ProvenTrue -> Cli.result_print "%s" (print_positive_result vc) - | ProvenFalse model -> - Cli.error_print "%s" (print_negative_result vc backend_ctx model) - | Unknown -> - failwith "The solver failed at proving or disproving the VC") + Cli.debug_print "The translation to Z3 is the following:\n%s" + (B.print_encoding encoding); + match B.solve_vc_encoding backend_ctx encoding with + | ProvenTrue -> Cli.result_print "%s" (print_positive_result vc) + | ProvenFalse model -> + Cli.error_print "%s" (print_negative_result vc backend_ctx model) + | Unknown -> failwith "The solver failed at proving or disproving the VC") | Fail msg -> - Cli.error_print "%s The translation to Z3 failed:\n%s" - (Cli.with_style [ ANSITerminal.yellow ] "[%s.%s]" - (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) - (Bindlib.name_of (Pos.unmark vc.vc_variable))) - msg + Cli.error_print "%s The translation to Z3 failed:\n%s" + (Cli.with_style [ANSITerminal.yellow] "[%s.%s]" + (Format.asprintf "%a" ScopeName.format_t vc.vc_scope) + (Bindlib.name_of (Pos.unmark vc.vc_variable))) + msg end diff --git a/compiler/verification/solver.ml b/compiler/verification/solver.ml index 0db1ebb2..c47bf8d8 100644 --- a/compiler/verification/solver.ml +++ b/compiler/verification/solver.ml @@ -20,8 +20,8 @@ open Dcalc.Ast expressions [vcs] corresponding to verification conditions that must be discharged by Z3, and attempts to solve them **) let solve_vc - (decl_ctx : decl_ctx) (vcs : Conditions.verification_condition list) : unit - = + (decl_ctx : decl_ctx) + (vcs : Conditions.verification_condition list) : unit = (* Right now we only use the Z3 backend but the functorial interface should make it easy to mix and match different proof backends. *) Z3backend.Io.init_backend (); diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index 2cc83705..42e09354 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -139,16 +139,16 @@ let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr) | TRat -> Arithmetic.Real.to_decimal_string e !Cli.max_prec_digits (* TODO: Print the right money symbol according to language *) | TMoney -> - let z3_str = Expr.to_string e in - (* The Z3 model returns an integer corresponding to the amount of cents. - We reformat it as dollars *) - let to_dollars s = - Runtime.money_to_string (Runtime.money_of_cents_string s) - in - if String.contains z3_str '-' then - Format.asprintf "-%s $" - (to_dollars (String.sub z3_str 3 (String.length z3_str - 4))) - else Format.asprintf "%s $" (to_dollars z3_str) + let z3_str = Expr.to_string e in + (* The Z3 model returns an integer corresponding to the amount of cents. + We reformat it as dollars *) + let to_dollars s = + Runtime.money_to_string (Runtime.money_of_cents_string s) + in + if String.contains z3_str '-' then + Format.asprintf "-%s $" + (to_dollars (String.sub z3_str 3 (String.length z3_str - 4))) + else Format.asprintf "%s $" (to_dollars z3_str) (* The Z3 date representation corresponds to the number of days since Jan 1, 1900. We pretty-print it as the actual date *) (* TODO: Use differnt dates conventions depending on the language ? *) @@ -159,44 +159,44 @@ let rec print_z3model_expr (ctx : context) (ty : typ Pos.marked) (e : Expr.expr) match Pos.unmark ty with | TLit ty -> print_lit ty | TTuple (_, Some name) -> - let s = StructMap.find name ctx.ctx_decl.ctx_structs in - let get_fieldname (fn : StructFieldName.t) : string = - Pos.unmark (StructFieldName.get_info fn) - in - let fields = - List.map2 - (fun (fn, ty) e -> - Format.asprintf "-- %s : %s" (get_fieldname fn) - (print_z3model_expr ctx ty e)) - s (Expr.get_args e) - in + let s = StructMap.find name ctx.ctx_decl.ctx_structs in + let get_fieldname (fn : StructFieldName.t) : string = + Pos.unmark (StructFieldName.get_info fn) + in + let fields = + List.map2 + (fun (fn, ty) e -> + Format.asprintf "-- %s : %s" (get_fieldname fn) + (print_z3model_expr ctx ty e)) + s (Expr.get_args e) + in - let fields_str = String.concat " " fields in + let fields_str = String.concat " " fields in - Format.asprintf "%s { %s }" - (Pos.unmark (StructName.get_info name)) - fields_str + Format.asprintf "%s { %s }" + (Pos.unmark (StructName.get_info name)) + fields_str | TTuple (_, None) -> - failwith "[Z3 model]: Pretty-printing of unnamed structs not supported" + failwith "[Z3 model]: Pretty-printing of unnamed structs not supported" | TEnum (_tys, name) -> - (* The value associated to the enum is a single argument *) - let e' = List.hd (Expr.get_args e) in - let fd = Expr.get_func_decl e in - let fd_name = Symbol.to_string (FuncDecl.get_name fd) in + (* The value associated to the enum is a single argument *) + let e' = List.hd (Expr.get_args e) in + let fd = Expr.get_func_decl e in + let fd_name = Symbol.to_string (FuncDecl.get_name fd) in - let enum_ctrs = EnumMap.find name ctx.ctx_decl.ctx_enums in - let case = - List.find - (fun (ctr, _) -> - String.equal fd_name (Pos.unmark (EnumConstructor.get_info ctr))) - enum_ctrs - in + let enum_ctrs = EnumMap.find name ctx.ctx_decl.ctx_enums in + let case = + List.find + (fun (ctr, _) -> + String.equal fd_name (Pos.unmark (EnumConstructor.get_info ctr))) + enum_ctrs + in - Format.asprintf "%s (%s)" fd_name (print_z3model_expr ctx (snd case) e') + Format.asprintf "%s (%s)" fd_name (print_z3model_expr ctx (snd case) e') | TArrow _ -> failwith "[Z3 model]: Pretty-printing of arrows not supported" | TArray _ -> - (* For now, only the length of arrays is modeled *) - Format.asprintf "(length = %s)" (Expr.to_string e) + (* For now, only the length of arrays is modeled *) + Format.asprintf "(length = %s)" (Expr.to_string e) | TAny -> failwith "[Z3 model]: Pretty-printing of Any not supported" (** [print_model] pretty prints a Z3 model, used to exhibit counter examples @@ -215,36 +215,32 @@ let print_model (ctx : context) (model : Model.model) : string = match Model.get_const_interp model d with (* TODO: Better handling of this case *) | None -> - failwith - "[Z3 model]: A variable does not have an associated Z3 \ - solution" + failwith + "[Z3 model]: A variable does not have an associated Z3 solution" (* Print "name : value\n" *) | Some e -> - let symbol_name = Symbol.to_string (FuncDecl.get_name d) in - let v = StringMap.find symbol_name ctx.ctx_z3vars in - Format.fprintf fmt "%s %s : %s" - (Cli.with_style [ ANSITerminal.blue ] "%s" "-->") - (Cli.with_style [ ANSITerminal.yellow ] "%s" - (Bindlib.name_of v)) - (print_z3model_expr ctx (VarMap.find v ctx.ctx_var) e) + let symbol_name = Symbol.to_string (FuncDecl.get_name d) in + let v = StringMap.find symbol_name ctx.ctx_z3vars in + Format.fprintf fmt "%s %s : %s" + (Cli.with_style [ANSITerminal.blue] "%s" "-->") + (Cli.with_style [ANSITerminal.yellow] "%s" (Bindlib.name_of v)) + (print_z3model_expr ctx (VarMap.find v ctx.ctx_var) e) else (* Declaration d is a function *) match Model.get_func_interp model d with (* TODO: Better handling of this case *) | None -> - failwith - "[Z3 model]: A variable does not have an associated Z3 \ - solution" + failwith + "[Z3 model]: A variable does not have an associated Z3 solution" (* Print "name : value\n" *) | Some f -> - let symbol_name = Symbol.to_string (FuncDecl.get_name d) in - let v = StringMap.find symbol_name ctx.ctx_z3vars in - Format.fprintf fmt "%s %s : %s" - (Cli.with_style [ ANSITerminal.blue ] "%s" "-->") - (Cli.with_style [ ANSITerminal.yellow ] "%s" - (Bindlib.name_of v)) - (* TODO: Model of a Z3 function should be pretty-printed *) - (Model.FuncInterp.to_string f))) + let symbol_name = Symbol.to_string (FuncDecl.get_name d) in + let v = StringMap.find symbol_name ctx.ctx_z3vars in + Format.fprintf fmt "%s %s : %s" + (Cli.with_style [ANSITerminal.blue] "%s" "-->") + (Cli.with_style [ANSITerminal.yellow] "%s" (Bindlib.name_of v)) + (* TODO: Model of a Z3 function should be pretty-printed *) + (Model.FuncInterp.to_string f))) decls (** [translate_typ_lit] returns the Z3 sort corresponding to the Catala literal @@ -264,16 +260,16 @@ let translate_typ_lit (ctx : context) (t : typ_lit) : Sort.sort = (** [translate_typ] returns the Z3 sort correponding to the Catala type [t] **) let rec translate_typ (ctx : context) (t : typ) : context * Sort.sort = match t with - | TLit t -> (ctx, translate_typ_lit ctx t) + | TLit t -> ctx, translate_typ_lit ctx t | TTuple (_, Some name) -> find_or_create_struct ctx name | TTuple (_, None) -> - failwith "[Z3 encoding] TTuple type of unnamed struct not supported" + failwith "[Z3 encoding] TTuple type of unnamed struct not supported" | TEnum (_, e) -> find_or_create_enum ctx e | TArrow _ -> failwith "[Z3 encoding] TArrow type not supported" | TArray _ -> - (* For now, we are only encoding the (symbolic) length of an array. - Ultimately, the type of an array should also contain its elements *) - (ctx, Arithmetic.Integer.mk_sort ctx.ctx_z3) + (* For now, we are only encoding the (symbolic) length of an array. + Ultimately, the type of an array should also contain its elements *) + ctx, Arithmetic.Integer.mk_sort ctx.ctx_z3 | TAny -> failwith "[Z3 encoding] TAny type not supported" (** [find_or_create_enum] attempts to retrieve the Z3 sort corresponding to the @@ -284,7 +280,8 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : context * Sort.sort = (* Creates a Z3 constructor corresponding to the Catala constructor [c] *) let create_constructor - (ctx : context) (c : EnumConstructor.t * typ Pos.marked) : + (ctx : context) + (c : EnumConstructor.t * typ Pos.marked) : context * Datatype.Constructor.constructor = let name, ty = c in let name = Pos.unmark (EnumConstructor.get_info name) in @@ -303,23 +300,23 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : (* We need a name for the argument of the constructor, we arbitrary pick the name of the constructor to which we append the special character "!" and the integer 0 *) - [ Symbol.mk_string ctx.ctx_z3 (name ^ "!0") ] + [Symbol.mk_string ctx.ctx_z3 (name ^ "!0")] (* The type of the argument, translated to a Z3 sort *) - [ Some arg_z3_ty ] - [ Sort.get_id arg_z3_ty ] ) + [Some arg_z3_ty] + [Sort.get_id arg_z3_ty] ) in match EnumMap.find_opt enum ctx.ctx_z3datatypes with - | Some e -> (ctx, e) + | Some e -> ctx, e | None -> - let ctrs = EnumMap.find enum ctx.ctx_decl.ctx_enums in - let ctx, z3_ctrs = List.fold_left_map create_constructor ctx ctrs in - let z3_enum = - Datatype.mk_sort_s ctx.ctx_z3 - (Pos.unmark (EnumName.get_info enum)) - z3_ctrs - in - (add_z3enum enum z3_enum ctx, z3_enum) + let ctrs = EnumMap.find enum ctx.ctx_decl.ctx_enums in + let ctx, z3_ctrs = List.fold_left_map create_constructor ctx ctrs in + let z3_enum = + Datatype.mk_sort_s ctx.ctx_z3 + (Pos.unmark (EnumName.get_info enum)) + z3_ctrs + in + add_z3enum enum z3_enum ctx, z3_enum (** [find_or_create_struct] attemps to retrieve the Z3 sort corresponding to the struct [s]. If no such sort exists yet, we construct it as a datatype with @@ -328,61 +325,61 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : and find_or_create_struct (ctx : context) (s : StructName.t) : context * Sort.sort = match StructMap.find_opt s ctx.ctx_z3structs with - | Some s -> (ctx, s) + | Some s -> ctx, s | None -> - let s_name = Pos.unmark (StructName.get_info s) in - let fields = StructMap.find s ctx.ctx_decl.ctx_structs in - let z3_fieldnames = - List.map - (fun f -> - Pos.unmark (StructFieldName.get_info (fst f)) - |> Symbol.mk_string ctx.ctx_z3) - fields - in - let ctx, z3_fieldtypes = - List.fold_left_map - (fun ctx f -> Pos.unmark (snd f) |> translate_typ ctx) - ctx fields - in - let z3_sortrefs = List.map Sort.get_id z3_fieldtypes in - let mk_struct_s = "mk!" ^ s_name in - let z3_mk_struct = - Datatype.mk_constructor_s ctx.ctx_z3 mk_struct_s - (Symbol.mk_string ctx.ctx_z3 mk_struct_s) - z3_fieldnames - (List.map (fun x -> Some x) z3_fieldtypes) - z3_sortrefs - in + let s_name = Pos.unmark (StructName.get_info s) in + let fields = StructMap.find s ctx.ctx_decl.ctx_structs in + let z3_fieldnames = + List.map + (fun f -> + Pos.unmark (StructFieldName.get_info (fst f)) + |> Symbol.mk_string ctx.ctx_z3) + fields + in + let ctx, z3_fieldtypes = + List.fold_left_map + (fun ctx f -> Pos.unmark (snd f) |> translate_typ ctx) + ctx fields + in + let z3_sortrefs = List.map Sort.get_id z3_fieldtypes in + let mk_struct_s = "mk!" ^ s_name in + let z3_mk_struct = + Datatype.mk_constructor_s ctx.ctx_z3 mk_struct_s + (Symbol.mk_string ctx.ctx_z3 mk_struct_s) + z3_fieldnames + (List.map (fun x -> Some x) z3_fieldtypes) + z3_sortrefs + in - let z3_struct = Datatype.mk_sort_s ctx.ctx_z3 s_name [ z3_mk_struct ] in - (add_z3struct s z3_struct ctx, z3_struct) + let z3_struct = Datatype.mk_sort_s ctx.ctx_z3 s_name [z3_mk_struct] in + add_z3struct s z3_struct ctx, z3_struct (** [translate_lit] returns the Z3 expression as a literal corresponding to [lit] **) let translate_lit (ctx : context) (l : lit) : Expr.expr = match l with | LBool b -> - if b then Boolean.mk_true ctx.ctx_z3 else Boolean.mk_false ctx.ctx_z3 + if b then Boolean.mk_true ctx.ctx_z3 else Boolean.mk_false ctx.ctx_z3 | LEmptyError -> failwith "[Z3 encoding] LEmptyError literals not supported" | LInt n -> - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (Runtime.integer_to_int n) + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (Runtime.integer_to_int n) | LRat r -> - Arithmetic.Real.mk_numeral_s ctx.ctx_z3 - (string_of_float (Runtime.decimal_to_float r)) + Arithmetic.Real.mk_numeral_s ctx.ctx_z3 + (string_of_float (Runtime.decimal_to_float r)) | LMoney m -> - let z3_m = Runtime.integer_to_int (Runtime.money_to_cents m) in - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 z3_m + let z3_m = Runtime.integer_to_int (Runtime.money_to_cents m) in + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 z3_m | LUnit -> snd ctx.ctx_z3unit (* Encoding a date as an integer corresponding to the number of days since Jan 1, 1900 *) | LDate d -> Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int d) | LDuration d -> - let y, m, d = Runtime.duration_to_years_months_days d in - if y <> 0 || m <> 0 then - failwith - "[Z3 encoding]: Duration literals containing years or months not \ - supported"; - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 d + let y, m, d = Runtime.duration_to_years_months_days d in + if y <> 0 || m <> 0 then + failwith + "[Z3 encoding]: Duration literals containing years or months not \ + supported"; + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 d (** [find_or_create_funcdecl] attempts to retrieve the Z3 function declaration corresponding to the variable [v]. If no such function declaration exists @@ -391,217 +388,208 @@ let translate_lit (ctx : context) (l : lit) : Expr.expr = let find_or_create_funcdecl (ctx : context) (v : Var.t) : context * FuncDecl.func_decl = match VarMap.find_opt v ctx.ctx_funcdecl with - | Some fd -> (ctx, fd) + | Some fd -> ctx, fd | None -> ( - (* Retrieves the Catala type of the function [v] *) - let f_ty = VarMap.find v ctx.ctx_var in - match Pos.unmark f_ty with - | TArrow (t1, t2) -> - let ctx, z3_t1 = translate_typ ctx (Pos.unmark t1) in - let ctx, z3_t2 = translate_typ ctx (Pos.unmark t2) in - let name = unique_name v in - let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [ z3_t1 ] z3_t2 in - let ctx = add_funcdecl v fd ctx in - let ctx = add_z3var name v ctx in - (ctx, fd) - | TAny -> - failwith - "[Z3 Encoding] A function being applied has type TAny, the type \ - was not fully inferred" - | _ -> - failwith - "[Z3 Encoding] Ill-formed VC, a function application does not have \ - a function type") + (* Retrieves the Catala type of the function [v] *) + let f_ty = VarMap.find v ctx.ctx_var in + match Pos.unmark f_ty with + | TArrow (t1, t2) -> + let ctx, z3_t1 = translate_typ ctx (Pos.unmark t1) in + let ctx, z3_t2 = translate_typ ctx (Pos.unmark t2) in + let name = unique_name v in + let fd = FuncDecl.mk_func_decl_s ctx.ctx_z3 name [z3_t1] z3_t2 in + let ctx = add_funcdecl v fd ctx in + let ctx = add_z3var name v ctx in + ctx, fd + | TAny -> + failwith + "[Z3 Encoding] A function being applied has type TAny, the type was \ + not fully inferred" + | _ -> + failwith + "[Z3 Encoding] Ill-formed VC, a function application does not have a \ + function type") (** [translate_op] returns the Z3 expression corresponding to the application of [op] to the arguments [args] **) let rec translate_op - (ctx : context) (op : operator) (args : expr Pos.marked list) : - context * Expr.expr = + (ctx : context) + (op : operator) + (args : expr Pos.marked list) : context * Expr.expr = match op with | Ternop _top -> - let _e1, _e2, _e3 = - match args with - | [ e1; e2; e3 ] -> (e1, e2, e3) - | _ -> - failwith - (Format.asprintf - "[Z3 encoding] Ill-formed ternary operator application: %a" - (Print.format_expr ctx.ctx_decl) - (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) - in + let _e1, _e2, _e3 = + match args with + | [e1; e2; e3] -> e1, e2, e3 + | _ -> + failwith + (Format.asprintf + "[Z3 encoding] Ill-formed ternary operator application: %a" + (Print.format_expr ctx.ctx_decl) + (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) + in - failwith "[Z3 encoding] ternary operator application not supported" + failwith "[Z3 encoding] ternary operator application not supported" | Binop bop -> ( - (* Special case for GetYear comparisons *) - match (bop, args) with - | ( Lt KInt, - [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] ) - -> - let n = Runtime.integer_to_int n in - let ctx, e1 = translate_expr ctx e1 in - let e2 = - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 - (date_to_int (date_of_year n)) - in - (* e2 corresponds to the first day of the year n. GetYear e1 < e2 can - thus be directly translated as < in the Z3 encoding using the - number of days *) - (ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2) - | ( Lte KInt, - [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] ) - -> - let n = Runtime.integer_to_int n in - let ctx, e1 = translate_expr ctx e1 in - let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in - (* We want that the year corresponding to e1 is smaller or equal to n. - We encode this as the day corresponding to e1 is smaller or equal - than the last day of the year [n], which is Jan 1st + 365 days if - [n] is a leap year, Jan 1st + 364 else *) - let e2 = - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 - (date_to_int (date_of_year n) + nb_days) - in - (ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2) - | ( Gt KInt, - [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] ) - -> - let n = Runtime.integer_to_int n in - let ctx, e1 = translate_expr ctx e1 in - let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in - (* We want that the year corresponding to e1 is greater to n. We - encode this as the day corresponding to e1 is greater than the last - day of the year [n], which is Jan 1st + 365 days if [n] is a leap - year, Jan 1st + 364 else *) - let e2 = - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 - (date_to_int (date_of_year n) + nb_days) - in - (ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2) - | ( Gte KInt, - [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] ) - -> - let n = Runtime.integer_to_int n in - let ctx, e1 = translate_expr ctx e1 in - let e2 = - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 - (date_to_int (date_of_year n)) - in - (* e2 corresponds to the first day of the year n. GetYear e1 >= e2 can - thus be directly translated as >= in the Z3 encoding using the - number of days *) - (ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2) - | Eq, [ (EApp ((EOp (Unop GetYear), _), [ e1 ]), _); (ELit (LInt n), _) ] - -> - let n = Runtime.integer_to_int n in - let ctx, e1 = translate_expr ctx e1 in - let min_date = - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 - (date_to_int (date_of_year n)) - in - let max_date = - Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 - (date_to_int (date_of_year (n + 1))) - in - ( ctx, - Boolean.mk_and ctx.ctx_z3 - [ - Arithmetic.mk_ge ctx.ctx_z3 e1 min_date; - Arithmetic.mk_lt ctx.ctx_z3 e1 max_date; - ] ) - | _ -> ( - let ctx, e1, e2 = - match args with - | [ e1; e2 ] -> - let ctx, e1 = translate_expr ctx e1 in - let ctx, e2 = translate_expr ctx e2 in - (ctx, e1, e2) - | _ -> - failwith - (Format.asprintf - "[Z3 encoding] Ill-formed binary operator application: %a" - (Print.format_expr ctx.ctx_decl) - (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) - in - - match bop with - | And -> (ctx, Boolean.mk_and ctx.ctx_z3 [ e1; e2 ]) - | Or -> (ctx, Boolean.mk_or ctx.ctx_z3 [ e1; e2 ]) - | Xor -> (ctx, Boolean.mk_xor ctx.ctx_z3 e1 e2) - | Add KInt | Add KRat | Add KMoney | Add KDate | Add KDuration -> - (ctx, Arithmetic.mk_add ctx.ctx_z3 [ e1; e2 ]) - | Sub KInt | Sub KRat | Sub KMoney | Sub KDate | Sub KDuration -> - (ctx, Arithmetic.mk_sub ctx.ctx_z3 [ e1; e2 ]) - | Mult KInt | Mult KRat | Mult KMoney | Mult KDate | Mult KDuration -> - (ctx, Arithmetic.mk_mul ctx.ctx_z3 [ e1; e2 ]) - | Div KInt | Div KRat | Div KMoney -> - (ctx, Arithmetic.mk_div ctx.ctx_z3 e1 e2) - | Div _ -> - failwith - "[Z3 encoding] application of non-integer binary operator Div \ - not supported" - | Lt KInt | Lt KRat | Lt KMoney | Lt KDate | Lt KDuration -> - (ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2) - | Lte KInt | Lte KRat | Lte KMoney | Lte KDate | Lte KDuration -> - (ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2) - | Gt KInt | Gt KRat | Gt KMoney | Gt KDate | Gt KDuration -> - (ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2) - | Gte KInt | Gte KRat | Gte KMoney | Gte KDate | Gte KDuration -> - (ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2) - | Eq -> (ctx, Boolean.mk_eq ctx.ctx_z3 e1 e2) - | Neq -> - (ctx, Boolean.mk_not ctx.ctx_z3 (Boolean.mk_eq ctx.ctx_z3 e1 e2)) - | Map -> - failwith - "[Z3 encoding] application of binary operator Map not supported" - | Concat -> - failwith - "[Z3 encoding] application of binary operator Concat not \ - supported" - | Filter -> - failwith - "[Z3 encoding] application of binary operator Filter not \ - supported")) - | Unop uop -> ( - let ctx, e1 = + (* Special case for GetYear comparisons *) + match bop, args with + | Lt KInt, [(EApp ((EOp (Unop GetYear), _), [e1]), _); (ELit (LInt n), _)] + -> + let n = Runtime.integer_to_int n in + let ctx, e1 = translate_expr ctx e1 in + let e2 = + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 + (date_to_int (date_of_year n)) + in + (* e2 corresponds to the first day of the year n. GetYear e1 < e2 can thus + be directly translated as < in the Z3 encoding using the number of + days *) + ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2 + | Lte KInt, [(EApp ((EOp (Unop GetYear), _), [e1]), _); (ELit (LInt n), _)] + -> + let n = Runtime.integer_to_int n in + let ctx, e1 = translate_expr ctx e1 in + let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in + (* We want that the year corresponding to e1 is smaller or equal to n. We + encode this as the day corresponding to e1 is smaller or equal than the + last day of the year [n], which is Jan 1st + 365 days if [n] is a leap + year, Jan 1st + 364 else *) + let e2 = + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 + (date_to_int (date_of_year n) + nb_days) + in + ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2 + | Gt KInt, [(EApp ((EOp (Unop GetYear), _), [e1]), _); (ELit (LInt n), _)] + -> + let n = Runtime.integer_to_int n in + let ctx, e1 = translate_expr ctx e1 in + let nb_days = if CalendarLib.Date.is_leap_year n then 365 else 364 in + (* We want that the year corresponding to e1 is greater to n. We encode + this as the day corresponding to e1 is greater than the last day of the + year [n], which is Jan 1st + 365 days if [n] is a leap year, Jan 1st + + 364 else *) + let e2 = + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 + (date_to_int (date_of_year n) + nb_days) + in + ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2 + | Gte KInt, [(EApp ((EOp (Unop GetYear), _), [e1]), _); (ELit (LInt n), _)] + -> + let n = Runtime.integer_to_int n in + let ctx, e1 = translate_expr ctx e1 in + let e2 = + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 + (date_to_int (date_of_year n)) + in + (* e2 corresponds to the first day of the year n. GetYear e1 >= e2 can + thus be directly translated as >= in the Z3 encoding using the number + of days *) + ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2 + | Eq, [(EApp ((EOp (Unop GetYear), _), [e1]), _); (ELit (LInt n), _)] -> + let n = Runtime.integer_to_int n in + let ctx, e1 = translate_expr ctx e1 in + let min_date = + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 + (date_to_int (date_of_year n)) + in + let max_date = + Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 + (date_to_int (date_of_year (n + 1))) + in + ( ctx, + Boolean.mk_and ctx.ctx_z3 + [ + Arithmetic.mk_ge ctx.ctx_z3 e1 min_date; + Arithmetic.mk_lt ctx.ctx_z3 e1 max_date; + ] ) + | _ -> ( + let ctx, e1, e2 = match args with - | [ e1 ] -> translate_expr ctx e1 + | [e1; e2] -> + let ctx, e1 = translate_expr ctx e1 in + let ctx, e2 = translate_expr ctx e2 in + ctx, e1, e2 | _ -> - failwith - (Format.asprintf - "[Z3 encoding] Ill-formed unary operator application: %a" - (Print.format_expr ctx.ctx_decl) - (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) + failwith + (Format.asprintf + "[Z3 encoding] Ill-formed binary operator application: %a" + (Print.format_expr ctx.ctx_decl) + (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) in - match uop with - | Not -> (ctx, Boolean.mk_not ctx.ctx_z3 e1) - | Minus _ -> - failwith - "[Z3 encoding] application of unary operator Minus not supported" - (* Omitting the log from the VC *) - | Log _ -> (ctx, e1) - | Length -> - (* For now, an array is only its symbolic length. We simply return - it *) - (ctx, e1) - | IntToRat -> - failwith - "[Z3 encoding] application of unary operator IntToRat not supported" - | GetDay -> - failwith - "[Z3 encoding] application of unary operator GetDay not supported" - | GetMonth -> - failwith - "[Z3 encoding] application of unary operator GetMonth not supported" - | GetYear -> - failwith - "[Z3 encoding] GetYear operator only supported in comparisons with \ - literal" - | RoundDecimal -> - failwith "[Z3 encoding] RoundDecimal operator not implemented yet" - | RoundMoney -> - failwith "[Z3 encoding] RoundMoney operator not implemented yet") + match bop with + | And -> ctx, Boolean.mk_and ctx.ctx_z3 [e1; e2] + | Or -> ctx, Boolean.mk_or ctx.ctx_z3 [e1; e2] + | Xor -> ctx, Boolean.mk_xor ctx.ctx_z3 e1 e2 + | Add KInt | Add KRat | Add KMoney | Add KDate | Add KDuration -> + ctx, Arithmetic.mk_add ctx.ctx_z3 [e1; e2] + | Sub KInt | Sub KRat | Sub KMoney | Sub KDate | Sub KDuration -> + ctx, Arithmetic.mk_sub ctx.ctx_z3 [e1; e2] + | Mult KInt | Mult KRat | Mult KMoney | Mult KDate | Mult KDuration -> + ctx, Arithmetic.mk_mul ctx.ctx_z3 [e1; e2] + | Div KInt | Div KRat | Div KMoney -> + ctx, Arithmetic.mk_div ctx.ctx_z3 e1 e2 + | Div _ -> + failwith + "[Z3 encoding] application of non-integer binary operator Div not \ + supported" + | Lt KInt | Lt KRat | Lt KMoney | Lt KDate | Lt KDuration -> + ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2 + | Lte KInt | Lte KRat | Lte KMoney | Lte KDate | Lte KDuration -> + ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2 + | Gt KInt | Gt KRat | Gt KMoney | Gt KDate | Gt KDuration -> + ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2 + | Gte KInt | Gte KRat | Gte KMoney | Gte KDate | Gte KDuration -> + ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2 + | Eq -> ctx, Boolean.mk_eq ctx.ctx_z3 e1 e2 + | Neq -> ctx, Boolean.mk_not ctx.ctx_z3 (Boolean.mk_eq ctx.ctx_z3 e1 e2) + | Map -> + failwith + "[Z3 encoding] application of binary operator Map not supported" + | Concat -> + failwith + "[Z3 encoding] application of binary operator Concat not supported" + | Filter -> + failwith + "[Z3 encoding] application of binary operator Filter not supported")) + | Unop uop -> ( + let ctx, e1 = + match args with + | [e1] -> translate_expr ctx e1 + | _ -> + failwith + (Format.asprintf + "[Z3 encoding] Ill-formed unary operator application: %a" + (Print.format_expr ctx.ctx_decl) + (EApp ((EOp op, Pos.no_pos), args), Pos.no_pos)) + in + + match uop with + | Not -> ctx, Boolean.mk_not ctx.ctx_z3 e1 + | Minus _ -> + failwith "[Z3 encoding] application of unary operator Minus not supported" + (* Omitting the log from the VC *) + | Log _ -> ctx, e1 + | Length -> + (* For now, an array is only its symbolic length. We simply return it *) + ctx, e1 + | IntToRat -> + failwith + "[Z3 encoding] application of unary operator IntToRat not supported" + | GetDay -> + failwith + "[Z3 encoding] application of unary operator GetDay not supported" + | GetMonth -> + failwith + "[Z3 encoding] application of unary operator GetMonth not supported" + | GetYear -> + failwith + "[Z3 encoding] GetYear operator only supported in comparisons with \ + literal" + | RoundDecimal -> + failwith "[Z3 encoding] RoundDecimal operator not implemented yet" + | RoundMoney -> + failwith "[Z3 encoding] RoundMoney operator not implemented yet") (** [translate_expr] translate the expression [vc] to its corresponding Z3 expression **) @@ -614,136 +602,134 @@ and translate_expr (ctx : context) (vc : expr Pos.marked) : context * Expr.expr let e, accessors = e in match Pos.unmark e with | EAbs (e, _) -> - (* Create a fresh Catala variable to substitue and obtain the body *) - let fresh_v = Var.make ("arm!tmp", Pos.no_pos) in - let fresh_e = EVar (fresh_v, Pos.no_pos) in + (* Create a fresh Catala variable to substitue and obtain the body *) + let fresh_v = Var.make ("arm!tmp", Pos.no_pos) in + let fresh_e = EVar (fresh_v, Pos.no_pos) in - (* Invariant: Catala enums always have exactly one argument *) - let accessor = List.hd accessors in - let proj = Expr.mk_app ctx.ctx_z3 accessor [ head ] in - (* The fresh variable should be substituted by a projection into the - enum in the body, we add this to the context *) - let ctx = add_z3matchsubst fresh_v proj ctx in + (* Invariant: Catala enums always have exactly one argument *) + let accessor = List.hd accessors in + let proj = Expr.mk_app ctx.ctx_z3 accessor [head] in + (* The fresh variable should be substituted by a projection into the enum + in the body, we add this to the context *) + let ctx = add_z3matchsubst fresh_v proj ctx in - let body = Bindlib.msubst (Pos.unmark e) [| fresh_e |] in - translate_expr ctx body + let body = Bindlib.msubst (Pos.unmark e) [| fresh_e |] in + translate_expr ctx body (* Invariant: Catala match arms are always lambda*) | _ -> failwith "[Z3 encoding] : Arms branches inside VCs should be lambdas" in match Pos.unmark vc with | EVar v -> ( - match VarMap.find_opt (Pos.unmark v) ctx.ctx_z3matchsubsts with - | None -> - (* We are in the standard case, where this is a true Catala - variable *) - let v = Pos.unmark v in - let t = VarMap.find v ctx.ctx_var in - let name = unique_name v in - let ctx = add_z3var name v ctx in - let ctx, ty = translate_typ ctx (Pos.unmark t) in - let z3_var = Expr.mk_const_s ctx.ctx_z3 name ty in - let ctx = - match Pos.unmark t with - (* If we are creating a new array, we need to log that its length is - greater than 0 *) - | TArray _ -> - add_z3constraint - (Arithmetic.mk_ge ctx.ctx_z3 z3_var - (Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 0)) - ctx - | _ -> ctx - in + match VarMap.find_opt (Pos.unmark v) ctx.ctx_z3matchsubsts with + | None -> + (* We are in the standard case, where this is a true Catala variable *) + let v = Pos.unmark v in + let t = VarMap.find v ctx.ctx_var in + let name = unique_name v in + let ctx = add_z3var name v ctx in + let ctx, ty = translate_typ ctx (Pos.unmark t) in + let z3_var = Expr.mk_const_s ctx.ctx_z3 name ty in + let ctx = + match Pos.unmark t with + (* If we are creating a new array, we need to log that its length is + greater than 0 *) + | TArray _ -> + add_z3constraint + (Arithmetic.mk_ge ctx.ctx_z3 z3_var + (Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 0)) + ctx + | _ -> ctx + in - (ctx, z3_var) - | Some e -> - (* This variable is a temporary variable generated during VC - translation of a match. It actually corresponds to applying an - accessor to an enum, the corresponding Z3 expression was previously - stored in the context *) - (ctx, e)) + ctx, z3_var + | Some e -> + (* This variable is a temporary variable generated during VC translation + of a match. It actually corresponds to applying an accessor to an enum, + the corresponding Z3 expression was previously stored in the context *) + ctx, e) | ETuple _ -> failwith "[Z3 encoding] ETuple unsupported" | ETupleAccess (s, idx, oname, _tys) -> - let name = - match oname with - | None -> - failwith "[Z3 encoding]: ETupleAccess of unnamed struct unsupported" - | Some n -> n - in - let ctx, z3_struct = find_or_create_struct ctx name in - (* This datatype should have only one constructor, corresponding to - mk_struct. The accessors of this constructor correspond to the field - accesses *) - let accessors = List.hd (Datatype.get_accessors z3_struct) in - let accessor = List.nth accessors idx in - let ctx, s = translate_expr ctx s in - (ctx, Expr.mk_app ctx.ctx_z3 accessor [ s ]) + let name = + match oname with + | None -> + failwith "[Z3 encoding]: ETupleAccess of unnamed struct unsupported" + | Some n -> n + in + let ctx, z3_struct = find_or_create_struct ctx name in + (* This datatype should have only one constructor, corresponding to + mk_struct. The accessors of this constructor correspond to the field + accesses *) + let accessors = List.hd (Datatype.get_accessors z3_struct) in + let accessor = List.nth accessors idx in + let ctx, s = translate_expr ctx s in + ctx, Expr.mk_app ctx.ctx_z3 accessor [s] | EInj (e, idx, en, _tys) -> - (* This node corresponds to creating a value for the enumeration [en], by - calling the [idx]-th constructor of enum [en], with argument [e] *) - let ctx, z3_enum = find_or_create_enum ctx en in - let ctx, z3_arg = translate_expr ctx e in - let ctrs = Datatype.get_constructors z3_enum in - (* This should always succeed if the expression is well-typed in dcalc *) - let ctr = List.nth ctrs idx in - (ctx, Expr.mk_app ctx.ctx_z3 ctr [ z3_arg ]) + (* This node corresponds to creating a value for the enumeration [en], by + calling the [idx]-th constructor of enum [en], with argument [e] *) + let ctx, z3_enum = find_or_create_enum ctx en in + let ctx, z3_arg = translate_expr ctx e in + let ctrs = Datatype.get_constructors z3_enum in + (* This should always succeed if the expression is well-typed in dcalc *) + let ctr = List.nth ctrs idx in + ctx, Expr.mk_app ctx.ctx_z3 ctr [z3_arg] | EMatch (arg, arms, enum) -> - let ctx, z3_enum = find_or_create_enum ctx enum in - let ctx, z3_arg = translate_expr ctx arg in - let _ctx, z3_arms = - List.fold_left_map - (translate_match_arm z3_arg) - ctx - (List.combine arms (Datatype.get_accessors z3_enum)) - in - let z3_arms = - List.map2 - (fun r arm -> - (* Encodes A? arg ==> body *) - let is_r = Expr.mk_app ctx.ctx_z3 r [ z3_arg ] in - Boolean.mk_implies ctx.ctx_z3 is_r arm) - (Datatype.get_recognizers z3_enum) - z3_arms - in - (ctx, Boolean.mk_and ctx.ctx_z3 z3_arms) + let ctx, z3_enum = find_or_create_enum ctx enum in + let ctx, z3_arg = translate_expr ctx arg in + let _ctx, z3_arms = + List.fold_left_map + (translate_match_arm z3_arg) + ctx + (List.combine arms (Datatype.get_accessors z3_enum)) + in + let z3_arms = + List.map2 + (fun r arm -> + (* Encodes A? arg ==> body *) + let is_r = Expr.mk_app ctx.ctx_z3 r [z3_arg] in + Boolean.mk_implies ctx.ctx_z3 is_r arm) + (Datatype.get_recognizers z3_enum) + z3_arms + in + ctx, Boolean.mk_and ctx.ctx_z3 z3_arms | EArray _ -> failwith "[Z3 encoding] EArray unsupported" - | ELit l -> (ctx, translate_lit ctx l) + | ELit l -> ctx, translate_lit ctx l | EAbs _ -> failwith "[Z3 encoding] EAbs unsupported" | EApp (head, args) -> ( - match Pos.unmark head with - | EOp op -> translate_op ctx op args - | EVar v -> - let ctx, fd = find_or_create_funcdecl ctx (Pos.unmark v) in - (* Fold_right to preserve the order of the arguments: The head - argument is appended at the head *) - let ctx, z3_args = - List.fold_right - (fun arg (ctx, acc) -> - let ctx, z3_arg = translate_expr ctx arg in - (ctx, z3_arg :: acc)) - args (ctx, []) - in - (ctx, Expr.mk_app ctx.ctx_z3 fd z3_args) - | _ -> - failwith - "[Z3 encoding] EApp node: Catala function calls should only \ - include operators or function names") + match Pos.unmark head with + | EOp op -> translate_op ctx op args + | EVar v -> + let ctx, fd = find_or_create_funcdecl ctx (Pos.unmark v) in + (* Fold_right to preserve the order of the arguments: The head argument is + appended at the head *) + let ctx, z3_args = + List.fold_right + (fun arg (ctx, acc) -> + let ctx, z3_arg = translate_expr ctx arg in + ctx, z3_arg :: acc) + args (ctx, []) + in + ctx, Expr.mk_app ctx.ctx_z3 fd z3_args + | _ -> + failwith + "[Z3 encoding] EApp node: Catala function calls should only include \ + operators or function names") | EAssert _ -> failwith "[Z3 encoding] EAssert unsupported" | EOp _ -> failwith "[Z3 encoding] EOp unsupported" | EDefault _ -> failwith "[Z3 encoding] EDefault unsupported" | EIfThenElse (e_if, e_then, e_else) -> - (* Encode this as (e_if ==> e_then) /\ (not e_if ==> e_else) *) - let ctx, z3_if = translate_expr ctx e_if in - let ctx, z3_then = translate_expr ctx e_then in - let ctx, z3_else = translate_expr ctx e_else in - ( ctx, - Boolean.mk_and ctx.ctx_z3 - [ - Boolean.mk_implies ctx.ctx_z3 z3_if z3_then; - Boolean.mk_implies ctx.ctx_z3 - (Boolean.mk_not ctx.ctx_z3 z3_if) - z3_else; - ] ) + (* Encode this as (e_if ==> e_then) /\ (not e_if ==> e_else) *) + let ctx, z3_if = translate_expr ctx e_if in + let ctx, z3_then = translate_expr ctx e_then in + let ctx, z3_else = translate_expr ctx e_else in + ( ctx, + Boolean.mk_and ctx.ctx_z3 + [ + Boolean.mk_implies ctx.ctx_z3 z3_if z3_then; + Boolean.mk_implies ctx.ctx_z3 + (Boolean.mk_not ctx.ctx_z3 z3_if) + z3_else; + ] ) | ErrorOnEmpty _ -> failwith "[Z3 encoding] ErrorOnEmpty unsupported" (** [create_z3unit] creates a Z3 sort and expression corresponding to the unit @@ -753,7 +739,7 @@ let create_z3unit (ctx : Z3.context) : Z3.context * (Sort.sort * Expr.expr) = let unit_sort = Tuple.mk_sort ctx (Symbol.mk_string ctx "unit") [] [] in let mk_unit = Tuple.get_mk_decl unit_sort in let unit_val = Expr.mk_app ctx mk_unit [] in - (ctx, (unit_sort, unit_val)) + ctx, (unit_sort, unit_val) module Backend = struct type backend_context = context @@ -790,11 +776,11 @@ module Backend = struct Cli.debug_print "Running Z3 version %s" Version.to_string let make_context - (decl_ctx : decl_ctx) (free_vars_typ : typ Pos.marked VarMap.t) : - backend_context = + (decl_ctx : decl_ctx) + (free_vars_typ : typ Pos.marked VarMap.t) : backend_context = let cfg = - (if !Cli.disable_counterexamples then [] else [ ("model", "true") ]) - @ [ ("proof", "false") ] + (if !Cli.disable_counterexamples then [] else ["model", "true"]) + @ ["proof", "false"] in let z3_ctx = mk_context cfg in let z3_ctx, z3unit = create_z3unit z3_ctx in diff --git a/french_law/ocaml/api_web.ml b/french_law/ocaml/api_web.ml index 413f9cc6..9a8a36d2 100644 --- a/french_law/ocaml/api_web.ml +++ b/french_law/ocaml/api_web.ml @@ -83,57 +83,56 @@ let rec embed_to_js (v : runtime_value) : Js.Unsafe.any = | Decimal d -> Js.Unsafe.inject (decimal_to_float d) | Money m -> Js.Unsafe.inject (money_to_float m) | Date d -> - let date = new%js Js.date_now in - ignore (date##setUTCFullYear (integer_to_int @@ year_of_date d)); - ignore (date##setUTCMonth (integer_to_int @@ month_number_of_date d)); - ignore (date##setUTCDate (integer_to_int @@ day_of_month_of_date d)); - ignore (date##setUTCHours 0); - ignore (date##setUTCMinutes 0); - ignore (date##setUTCSeconds 0); - ignore (date##setUTCMilliseconds 0); - Js.Unsafe.inject date + let date = new%js Js.date_now in + ignore (date##setUTCFullYear (integer_to_int @@ year_of_date d)); + ignore (date##setUTCMonth (integer_to_int @@ month_number_of_date d)); + ignore (date##setUTCDate (integer_to_int @@ day_of_month_of_date d)); + ignore (date##setUTCHours 0); + ignore (date##setUTCMinutes 0); + ignore (date##setUTCSeconds 0); + ignore (date##setUTCMilliseconds 0); + Js.Unsafe.inject date | Duration d -> - let days, months, years = duration_to_years_months_days d in - Js.Unsafe.inject - (Js.string (Printf.sprintf "%dD%dM%dY" days months years)) + let days, months, years = duration_to_years_months_days d in + Js.Unsafe.inject (Js.string (Printf.sprintf "%dD%dM%dY" days months years)) | Struct (name, fields) -> - Js.Unsafe.inject - (object%js - val mutable structName = - if List.length name = 1 then - Js.Unsafe.inject (Js.string (List.hd name)) - else - Js.Unsafe.inject - (Js.array (Array.of_list (List.map Js.string name))) - - val mutable structFields = + Js.Unsafe.inject + (object%js + val mutable structName = + if List.length name = 1 then + Js.Unsafe.inject (Js.string (List.hd name)) + else Js.Unsafe.inject - (Js.array - (Array.of_list - (List.map - (fun (name, v) -> - object%js - val mutable fieldName = - Js.Unsafe.inject (Js.string name) + (Js.array (Array.of_list (List.map Js.string name))) - val mutable fieldValue = - Js.Unsafe.inject (embed_to_js v) - end) - fields))) - end) + val mutable structFields = + Js.Unsafe.inject + (Js.array + (Array.of_list + (List.map + (fun (name, v) -> + object%js + val mutable fieldName = + Js.Unsafe.inject (Js.string name) + + val mutable fieldValue = + Js.Unsafe.inject (embed_to_js v) + end) + fields))) + end) | Enum (name, (case, v)) -> - Js.Unsafe.inject - (object%js - val mutable enumName = - if List.length name = 1 then - Js.Unsafe.inject (Js.string (List.hd name)) - else - Js.Unsafe.inject - (Js.array (Array.of_list (List.map Js.string name))) + Js.Unsafe.inject + (object%js + val mutable enumName = + if List.length name = 1 then + Js.Unsafe.inject (Js.string (List.hd name)) + else + Js.Unsafe.inject + (Js.array (Array.of_list (List.map Js.string name))) - val mutable enumCase = Js.Unsafe.inject (Js.string case) - val mutable enumPayload = Js.Unsafe.inject (embed_to_js v) - end) + val mutable enumCase = Js.Unsafe.inject (Js.string case) + val mutable enumPayload = Js.Unsafe.inject (embed_to_js v) + end) | Array vs -> Js.Unsafe.inject (Js.array (Array.map embed_to_js vs)) | Unembeddable -> Js.Unsafe.inject Js.null @@ -165,33 +164,31 @@ let _ = | BeginCall info | EndCall info | VariableDefinition (info, _) -> - List.map Js.string info + List.map Js.string info | DecisionTaken _ -> [])) val mutable loggedValue = match evt with | VariableDefinition (_, v) -> embed_to_js v | EndCall _ | BeginCall _ | DecisionTaken _ -> - Js.Unsafe.inject Js.undefined + Js.Unsafe.inject Js.undefined val mutable sourcePosition = match evt with | DecisionTaken pos -> - Js.def - (object%js - val mutable fileName = - Js.string pos.filename + Js.def + (object%js + val mutable fileName = Js.string pos.filename + val mutable startLine = pos.start_line + val mutable endLine = pos.end_line + val mutable startColumn = pos.start_column + val mutable endColumn = pos.end_column - val mutable startLine = pos.start_line - val mutable endLine = pos.end_line - val mutable startColumn = pos.start_column - val mutable endColumn = pos.end_column - - val mutable lawHeadings = - Js.array - (Array.of_list - (List.map Js.string pos.law_headings)) - end) + val mutable lawHeadings = + Js.array + (Array.of_list + (List.map Js.string pos.law_headings)) + end) | _ -> Js.undefined end) (retrieve_log ())))) @@ -229,18 +226,18 @@ let _ = AF.d_prise_en_charge = (match Js.to_string child##.priseEnCharge with | "Effective et permanente" -> - EffectiveEtPermanente () + EffectiveEtPermanente () | "Garde alternée, allocataire unique" -> - GardeAlterneeAllocataireUnique () + GardeAlterneeAllocataireUnique () | "Garde alternée, partage des allocations" -> - GardeAlterneePartageAllocations () + GardeAlterneePartageAllocations () | "Confié aux service sociaux, allocation versée \ à la famille" -> - ServicesSociauxAllocationVerseeALaFamille () + ServicesSociauxAllocationVerseeALaFamille () | "Confié aux service sociaux, allocation versée \ aux services sociaux" -> - ServicesSociauxAllocationVerseeAuxServicesSociaux - () + ServicesSociauxAllocationVerseeAuxServicesSociaux + () | _ -> failwith "Unknown prise en charge"); AF.d_remuneration_mensuelle = money_of_units_int child##.remunerationMensuelle; diff --git a/french_law/ocaml/bench.ml b/french_law/ocaml/bench.ml index cf2f0054..8886e7f1 100644 --- a/french_law/ocaml/bench.ml +++ b/french_law/ocaml/bench.ml @@ -56,13 +56,13 @@ let format_prise_en_charge (fmt : Format.formatter) (g : AF.prise_en_charge) : (match g with | AF.EffectiveEtPermanente _ -> "Effective et permanente" | AF.GardeAlterneePartageAllocations _ -> - "Garde alternée, allocations partagée" + "Garde alternée, allocations partagée" | AF.GardeAlterneeAllocataireUnique _ -> - "Garde alternée, allocataire unique" + "Garde alternée, allocataire unique" | AF.ServicesSociauxAllocationVerseeALaFamille _ -> - "Oui, allocations versée à la famille" + "Oui, allocations versée à la famille" | AF.ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - "Oui, allocations versée aux services sociaux") + "Oui, allocations versée aux services sociaux") let num_successful = ref 0 let total_amount = ref 0. @@ -89,25 +89,25 @@ let run_test () = total_amount := Float.add !total_amount amount with | (NoValueProvided _ | ConflictError) as err -> - Format.printf "%s\n%a\nincome: %d\ncurrent_date: %s\nresidence: %a\n" - (match err with - | NoValueProvided _ -> "No value provided somewhere!" - | ConflictError -> "Conflict error!" - | _ -> failwith "impossible") - (Format.pp_print_list (fun fmt child -> - Format.fprintf fmt - "Child %d:\n\ - \ income: %.2f\n\ - \ birth date: %s\n\ - \ prise en charge: %a" - (integer_to_int child.AF.d_identifiant) - (money_to_float child.AF.d_remuneration_mensuelle) - (Runtime.date_to_string child.AF.d_date_de_naissance) - format_prise_en_charge child.AF.d_prise_en_charge)) - (Array.to_list children) income - (Runtime.date_to_string current_date) - format_residence residence; - exit (-1) + Format.printf "%s\n%a\nincome: %d\ncurrent_date: %s\nresidence: %a\n" + (match err with + | NoValueProvided _ -> "No value provided somewhere!" + | ConflictError -> "Conflict error!" + | _ -> failwith "impossible") + (Format.pp_print_list (fun fmt child -> + Format.fprintf fmt + "Child %d:\n\ + \ income: %.2f\n\ + \ birth date: %s\n\ + \ prise en charge: %a" + (integer_to_int child.AF.d_identifiant) + (money_to_float child.AF.d_remuneration_mensuelle) + (Runtime.date_to_string child.AF.d_date_de_naissance) + format_prise_en_charge child.AF.d_prise_en_charge)) + (Array.to_list children) income + (Runtime.date_to_string current_date) + format_residence residence; + exit (-1) | AssertionFailed -> () let bench = diff --git a/french_law/ocaml/law_source/allocations_familiales.ml b/french_law/ocaml/law_source/allocations_familiales.ml index f361e01c..c5ec3ffa 100644 --- a/french_law/ocaml/law_source/allocations_familiales.ml +++ b/french_law/ocaml/law_source/allocations_familiales.ml @@ -13,17 +13,17 @@ type prise_en_charge = let embed_prise_en_charge (x : prise_en_charge) : runtime_value = Enum - ( [ "PriseEnCharge" ], + ( ["PriseEnCharge"], match x with | GardeAlterneePartageAllocations x -> - ("GardeAlternéePartageAllocations", embed_unit x) + "GardeAlternéePartageAllocations", embed_unit x | GardeAlterneeAllocataireUnique x -> - ("GardeAlternéeAllocataireUnique", embed_unit x) - | EffectiveEtPermanente x -> ("EffectiveEtPermanente", embed_unit x) + "GardeAlternéeAllocataireUnique", embed_unit x + | EffectiveEtPermanente x -> "EffectiveEtPermanente", embed_unit x | ServicesSociauxAllocationVerseeALaFamille x -> - ("ServicesSociauxAllocationVerséeÀLaFamille", embed_unit x) + "ServicesSociauxAllocationVerséeÀLaFamille", embed_unit x | ServicesSociauxAllocationVerseeAuxServicesSociaux x -> - ("ServicesSociauxAllocationVerséeAuxServicesSociaux", embed_unit x) ) + "ServicesSociauxAllocationVerséeAuxServicesSociaux", embed_unit x ) type situation_obligation_scolaire = | Avant of unit @@ -33,21 +33,21 @@ type situation_obligation_scolaire = let embed_situation_obligation_scolaire (x : situation_obligation_scolaire) : runtime_value = Enum - ( [ "SituationObligationScolaire" ], + ( ["SituationObligationScolaire"], match x with - | Avant x -> ("Avant", embed_unit x) - | Pendant x -> ("Pendant", embed_unit x) - | Apres x -> ("Après", embed_unit x) ) + | Avant x -> "Avant", embed_unit x + | Pendant x -> "Pendant", embed_unit x + | Apres x -> "Après", embed_unit x ) type prise_en_compte = Complete of unit | Partagee of unit | Zero of unit let embed_prise_en_compte (x : prise_en_compte) : runtime_value = Enum - ( [ "PriseEnCompte" ], + ( ["PriseEnCompte"], match x with - | Complete x -> ("Complète", embed_unit x) - | Partagee x -> ("Partagée", embed_unit x) - | Zero x -> ("Zéro", embed_unit x) ) + | Complete x -> "Complète", embed_unit x + | Partagee x -> "Partagée", embed_unit x + | Zero x -> "Zéro", embed_unit x ) type versement_allocations = | Normal of unit @@ -55,11 +55,11 @@ type versement_allocations = let embed_versement_allocations (x : versement_allocations) : runtime_value = Enum - ( [ "VersementAllocations" ], + ( ["VersementAllocations"], match x with - | Normal x -> ("Normal", embed_unit x) + | Normal x -> "Normal", embed_unit x | AllocationVerseeAuxServicesSociaux x -> - ("AllocationVerséeAuxServicesSociaux", embed_unit x) ) + "AllocationVerséeAuxServicesSociaux", embed_unit x ) type element_prestations_familiales = | PrestationAccueilJeuneEnfant of unit @@ -74,21 +74,19 @@ type element_prestations_familiales = let embed_element_prestations_familiales (x : element_prestations_familiales) : runtime_value = Enum - ( [ "ÉlémentPrestationsFamiliales" ], + ( ["ÉlémentPrestationsFamiliales"], match x with | PrestationAccueilJeuneEnfant x -> - ("PrestationAccueilJeuneEnfant", embed_unit x) - | AllocationsFamiliales x -> ("AllocationsFamiliales", embed_unit x) - | ComplementFamilial x -> ("ComplémentFamilial", embed_unit x) - | AllocationLogement x -> ("AllocationLogement", embed_unit x) + "PrestationAccueilJeuneEnfant", embed_unit x + | AllocationsFamiliales x -> "AllocationsFamiliales", embed_unit x + | ComplementFamilial x -> "ComplémentFamilial", embed_unit x + | AllocationLogement x -> "AllocationLogement", embed_unit x | AllocationEducationEnfantHandicape x -> - ("AllocationÉducationEnfantHandicapé", embed_unit x) - | AllocationSoutienFamilial x -> - ("AllocationSoutienFamilial", embed_unit x) - | AllocationRentreeScolaire x -> - ("AllocationRentréeScolaire", embed_unit x) + "AllocationÉducationEnfantHandicapé", embed_unit x + | AllocationSoutienFamilial x -> "AllocationSoutienFamilial", embed_unit x + | AllocationRentreeScolaire x -> "AllocationRentréeScolaire", embed_unit x | AllocationJournalierePresenceParentale x -> - ("AllocationJournalièrePresenceParentale", embed_unit x) ) + "AllocationJournalièrePresenceParentale", embed_unit x ) type collectivite = | Guadeloupe of unit @@ -103,17 +101,17 @@ type collectivite = let embed_collectivite (x : collectivite) : runtime_value = Enum - ( [ "Collectivité" ], + ( ["Collectivité"], match x with - | Guadeloupe x -> ("Guadeloupe", embed_unit x) - | Guyane x -> ("Guyane", embed_unit x) - | Martinique x -> ("Martinique", embed_unit x) - | LaReunion x -> ("LaRéunion", embed_unit x) - | SaintBarthelemy x -> ("SaintBarthélemy", embed_unit x) - | SaintMartin x -> ("SaintMartin", embed_unit x) - | Metropole x -> ("Métropole", embed_unit x) - | SaintPierreEtMiquelon x -> ("SaintPierreEtMiquelon", embed_unit x) - | Mayotte x -> ("Mayotte", embed_unit x) ) + | Guadeloupe x -> "Guadeloupe", embed_unit x + | Guyane x -> "Guyane", embed_unit x + | Martinique x -> "Martinique", embed_unit x + | LaReunion x -> "LaRéunion", embed_unit x + | SaintBarthelemy x -> "SaintBarthélemy", embed_unit x + | SaintMartin x -> "SaintMartin", embed_unit x + | Metropole x -> "Métropole", embed_unit x + | SaintPierreEtMiquelon x -> "SaintPierreEtMiquelon", embed_unit x + | Mayotte x -> "Mayotte", embed_unit x ) type enfant_entree = { d_identifiant : integer; @@ -126,12 +124,12 @@ type enfant_entree = { let embed_enfant_entree (x : enfant_entree) : runtime_value = Struct - ( [ "EnfantEntrée" ], + ( ["EnfantEntrée"], [ - ("d_identifiant", embed_integer x.d_identifiant); - ("d_rémuneration_mensuelle", embed_money x.d_remuneration_mensuelle); - ("d_date_de_naissance", embed_date x.d_date_de_naissance); - ("d_prise_en_charge", embed_prise_en_charge x.d_prise_en_charge); + "d_identifiant", embed_integer x.d_identifiant; + "d_rémuneration_mensuelle", embed_money x.d_remuneration_mensuelle; + "d_date_de_naissance", embed_date x.d_date_de_naissance; + "d_prise_en_charge", embed_prise_en_charge x.d_prise_en_charge; ( "d_a_déjà_ouvert_droit_aux_allocations_familiales", embed_bool x.d_a_deja_ouvert_droit_aux_allocations_familiales ); ( "d_bénéficie_titre_personnel_aide_personnelle_logement", @@ -151,15 +149,15 @@ type enfant = { let embed_enfant (x : enfant) : runtime_value = Struct - ( [ "Enfant" ], + ( ["Enfant"], [ - ("identifiant", embed_integer x.identifiant); + "identifiant", embed_integer x.identifiant; ( "obligation_scolaire", embed_situation_obligation_scolaire x.obligation_scolaire ); - ("rémuneration_mensuelle", embed_money x.remuneration_mensuelle); - ("date_de_naissance", embed_date x.date_de_naissance); - ("âge", embed_integer x.age); - ("prise_en_charge", embed_prise_en_charge x.prise_en_charge); + "rémuneration_mensuelle", embed_money x.remuneration_mensuelle; + "date_de_naissance", embed_date x.date_de_naissance; + "âge", embed_integer x.age; + "prise_en_charge", embed_prise_en_charge x.prise_en_charge; ( "a_déjà_ouvert_droit_aux_allocations_familiales", embed_bool x.a_deja_ouvert_droit_aux_allocations_familiales ); ( "bénéficie_titre_personnel_aide_personnelle_logement", @@ -176,12 +174,12 @@ type prestations_familiales_out = { let embed_prestations_familiales_out (x : prestations_familiales_out) : runtime_value = Struct - ( [ "PrestationsFamiliales_out" ], + ( ["PrestationsFamiliales_out"], [ - ("droit_ouvert_out", unembeddable x.droit_ouvert_out); - ("conditions_hors_âge_out", unembeddable x.conditions_hors_age_out); - ("âge_l512_3_2_out", embed_integer x.age_l512_3_2_out); - ("régime_outre_mer_l751_1_out", embed_bool x.regime_outre_mer_l751_1_out); + "droit_ouvert_out", unembeddable x.droit_ouvert_out; + "conditions_hors_âge_out", unembeddable x.conditions_hors_age_out; + "âge_l512_3_2_out", embed_integer x.age_l512_3_2_out; + "régime_outre_mer_l751_1_out", embed_bool x.regime_outre_mer_l751_1_out; ] ) type prestations_familiales_in = { @@ -193,12 +191,12 @@ type prestations_familiales_in = { let embed_prestations_familiales_in (x : prestations_familiales_in) : runtime_value = Struct - ( [ "PrestationsFamiliales_in" ], + ( ["PrestationsFamiliales_in"], [ - ("date_courante_in", embed_date x.date_courante_in); + "date_courante_in", embed_date x.date_courante_in; ( "prestation_courante_in", embed_element_prestations_familiales x.prestation_courante_in ); - ("résidence_in", embed_collectivite x.residence_in); + "résidence_in", embed_collectivite x.residence_in; ] ) type allocation_familiales_avril2008_out = { @@ -208,7 +206,7 @@ type allocation_familiales_avril2008_out = { let embed_allocation_familiales_avril2008_out (x : allocation_familiales_avril2008_out) : runtime_value = Struct - ( [ "AllocationFamilialesAvril2008_out" ], + ( ["AllocationFamilialesAvril2008_out"], [ ( "âge_minimum_alinéa_1_l521_3_out", embed_integer x.age_minimum_alinea_1_l521_3_out ); @@ -224,23 +222,23 @@ type enfant_le_plus_age_out = { le_plus_age_out : enfant } let embed_enfant_le_plus_age_out (x : enfant_le_plus_age_out) : runtime_value = Struct - ( [ "EnfantLePlusÂgé_out" ], - [ ("le_plus_âgé_out", embed_enfant x.le_plus_age_out) ] ) + ( ["EnfantLePlusÂgé_out"], + ["le_plus_âgé_out", embed_enfant x.le_plus_age_out] ) type enfant_le_plus_age_in = { enfants_in : enfant array } let embed_enfant_le_plus_age_in (x : enfant_le_plus_age_in) : runtime_value = Struct - ( [ "EnfantLePlusÂgé_in" ], - [ ("enfants_in", embed_array embed_enfant x.enfants_in) ] ) + ( ["EnfantLePlusÂgé_in"], + ["enfants_in", embed_array embed_enfant x.enfants_in] ) type allocations_familiales_out = { montant_verse_out : money } let embed_allocations_familiales_out (x : allocations_familiales_out) : runtime_value = Struct - ( [ "AllocationsFamiliales_out" ], - [ ("montant_versé_out", embed_money x.montant_verse_out) ] ) + ( ["AllocationsFamiliales_out"], + ["montant_versé_out", embed_money x.montant_verse_out] ) type allocations_familiales_in = { personne_charge_effective_permanente_est_parent_in : bool; @@ -255,17 +253,16 @@ type allocations_familiales_in = { let embed_allocations_familiales_in (x : allocations_familiales_in) : runtime_value = Struct - ( [ "AllocationsFamiliales_in" ], + ( ["AllocationsFamiliales_in"], [ ( "personne_charge_effective_permanente_est_parent_in", embed_bool x.personne_charge_effective_permanente_est_parent_in ); ( "personne_charge_effective_permanente_remplit_titre_I_in", - embed_bool x.personne_charge_effective_permanente_remplit_titre_I_in - ); - ("ressources_ménage_in", embed_money x.ressources_menage_in); - ("résidence_in", embed_collectivite x.residence_in); - ("date_courante_in", embed_date x.date_courante_in); - ("enfants_à_charge_in", embed_array embed_enfant x.enfants_a_charge_in); + embed_bool x.personne_charge_effective_permanente_remplit_titre_I_in ); + "ressources_ménage_in", embed_money x.ressources_menage_in; + "résidence_in", embed_collectivite x.residence_in; + "date_courante_in", embed_date x.date_courante_in; + "enfants_à_charge_in", embed_array embed_enfant x.enfants_a_charge_in; ( "avait_enfant_à_charge_avant_1er_janvier_2012_in", embed_bool x.avait_enfant_a_charge_avant_1er_janvier_2012_in ); ] ) @@ -273,17 +270,16 @@ let embed_allocations_familiales_in (x : allocations_familiales_in) : type smic_out = { brut_horaire_out : money } let embed_smic_out (x : smic_out) : runtime_value = - Struct - ([ "Smic_out" ], [ ("brut_horaire_out", embed_money x.brut_horaire_out) ]) + Struct (["Smic_out"], ["brut_horaire_out", embed_money x.brut_horaire_out]) type smic_in = { date_courante_in : date; residence_in : collectivite } let embed_smic_in (x : smic_in) : runtime_value = Struct - ( [ "Smic_in" ], + ( ["Smic_in"], [ - ("date_courante_in", embed_date x.date_courante_in); - ("résidence_in", embed_collectivite x.residence_in); + "date_courante_in", embed_date x.date_courante_in; + "résidence_in", embed_collectivite x.residence_in; ] ) type base_mensuelle_allocations_familiales_out = { montant_out : money } @@ -291,24 +287,24 @@ type base_mensuelle_allocations_familiales_out = { montant_out : money } let embed_base_mensuelle_allocations_familiales_out (x : base_mensuelle_allocations_familiales_out) : runtime_value = Struct - ( [ "BaseMensuelleAllocationsFamiliales_out" ], - [ ("montant_out", embed_money x.montant_out) ] ) + ( ["BaseMensuelleAllocationsFamiliales_out"], + ["montant_out", embed_money x.montant_out] ) type base_mensuelle_allocations_familiales_in = { date_courante_in : date } let embed_base_mensuelle_allocations_familiales_in (x : base_mensuelle_allocations_familiales_in) : runtime_value = Struct - ( [ "BaseMensuelleAllocationsFamiliales_in" ], - [ ("date_courante_in", embed_date x.date_courante_in) ] ) + ( ["BaseMensuelleAllocationsFamiliales_in"], + ["date_courante_in", embed_date x.date_courante_in] ) type interface_allocations_familiales_out = { i_montant_verse_out : money } let embed_interface_allocations_familiales_out (x : interface_allocations_familiales_out) : runtime_value = Struct - ( [ "InterfaceAllocationsFamiliales_out" ], - [ ("i_montant_versé_out", embed_money x.i_montant_verse_out) ] ) + ( ["InterfaceAllocationsFamiliales_out"], + ["i_montant_versé_out", embed_money x.i_montant_verse_out] ) type interface_allocations_familiales_in = { i_date_courante_in : date; @@ -323,12 +319,12 @@ type interface_allocations_familiales_in = { let embed_interface_allocations_familiales_in (x : interface_allocations_familiales_in) : runtime_value = Struct - ( [ "InterfaceAllocationsFamiliales_in" ], + ( ["InterfaceAllocationsFamiliales_in"], [ - ("i_date_courante_in", embed_date x.i_date_courante_in); - ("i_enfants_in", embed_array embed_enfant_entree x.i_enfants_in); - ("i_ressources_ménage_in", embed_money x.i_ressources_menage_in); - ("i_résidence_in", embed_collectivite x.i_residence_in); + "i_date_courante_in", embed_date x.i_date_courante_in; + "i_enfants_in", embed_array embed_enfant_entree x.i_enfants_in; + "i_ressources_ménage_in", embed_money x.i_ressources_menage_in; + "i_résidence_in", embed_collectivite x.i_residence_in; ( "i_personne_charge_effective_permanente_est_parent_in", embed_bool x.i_personne_charge_effective_permanente_est_parent_in ); ( "i_personne_charge_effective_permanente_remplit_titre_I_in", @@ -343,7 +339,7 @@ let allocation_familiales_avril2008 allocation_familiales_avril2008_out = let age_minimum_alinea_1_l521_3_ : integer = log_variable_definition - [ "AllocationFamilialesAvril2008"; "âge_minimum_alinéa_1_l521_3" ] + ["AllocationFamilialesAvril2008"; "âge_minimum_alinéa_1_l521_3"] embed_integer (try integer_of_string "16" with EmptyError -> @@ -357,9 +353,7 @@ let allocation_familiales_avril2008 end_column = 37; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -370,7 +364,7 @@ let enfant_le_plus_age (enfant_le_plus_age_in : enfant_le_plus_age_in) : let enfants_ : enfant array = enfant_le_plus_age_in.enfants_in in let le_plus_age_ : enfant = log_variable_definition - [ "EnfantLePlusÂgé"; "le_plus_âgé" ] + ["EnfantLePlusÂgé"; "le_plus_âgé"] embed_enfant (try Array.fold_left @@ -398,9 +392,7 @@ let enfant_le_plus_age (enfant_le_plus_age_in : enfant_le_plus_age_in) : end_column = 21; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -410,7 +402,7 @@ let smic (smic_in : smic_in) : smic_out = let date_courante_ : date = smic_in.date_courante_in in let residence_ : collectivite = smic_in.residence_in in let brut_horaire_ : money = - log_variable_definition [ "Smic"; "brut_horaire" ] embed_money + log_variable_definition ["Smic"; "brut_horaire"] embed_money (try handle_default [| @@ -574,7 +566,7 @@ let smic (smic_in : smic_in) : smic_out = end_line = 11; end_column = 22; law_headings = - [ "Prologue"; "Montant du salaire minimum de croissance" ]; + ["Prologue"; "Montant du salaire minimum de croissance"]; })) in { brut_horaire_out = brut_horaire_ } @@ -588,7 +580,7 @@ let base_mensuelle_allocations_familiales in let montant_ : money = log_variable_definition - [ "BaseMensuelleAllocationsFamiliales"; "montant" ] + ["BaseMensuelleAllocationsFamiliales"; "montant"] embed_money (try handle_default @@ -708,7 +700,7 @@ let base_mensuelle_allocations_familiales end_line = 6; end_column = 17; law_headings = - [ "Montant de la base mensuelle des allocations familiales" ]; + ["Montant de la base mensuelle des allocations familiales"]; })) in { montant_out = montant_ } @@ -723,7 +715,7 @@ let prestations_familiales let residence_ : collectivite = prestations_familiales_in.residence_in in let age_l512_3_2_ : integer = log_variable_definition - [ "PrestationsFamiliales"; "âge_l512_3_2" ] + ["PrestationsFamiliales"; "âge_l512_3_2"] embed_integer (try integer_of_string "20" with EmptyError -> @@ -737,16 +729,14 @@ let prestations_familiales end_column = 22; law_headings = [ - "Prestations familiales"; - "Champs d'applications"; - "Prologue"; + "Prestations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let smic_dot_date_courante_ : date = try log_variable_definition - [ "PrestationsFamiliales"; "smic.date_courante" ] + ["PrestationsFamiliales"; "smic.date_courante"] embed_date date_courante_ with EmptyError -> raise @@ -758,13 +748,13 @@ let prestations_familiales end_line = 9; end_column = 23; law_headings = - [ "Prologue"; "Montant du salaire minimum de croissance" ]; + ["Prologue"; "Montant du salaire minimum de croissance"]; }) in let smic_dot_residence_ : collectivite = try log_variable_definition - [ "PrestationsFamiliales"; "smic.résidence" ] + ["PrestationsFamiliales"; "smic.résidence"] embed_collectivite residence_ with EmptyError -> raise @@ -776,14 +766,14 @@ let prestations_familiales end_line = 10; end_column = 19; law_headings = - [ "Prologue"; "Montant du salaire minimum de croissance" ]; + ["Prologue"; "Montant du salaire minimum de croissance"]; }) in let result_ : smic_out = log_end_call - [ "PrestationsFamiliales"; "smic"; "Smic" ] + ["PrestationsFamiliales"; "smic"; "Smic"] (log_begin_call - [ "PrestationsFamiliales"; "smic"; "Smic" ] + ["PrestationsFamiliales"; "smic"; "Smic"] smic { date_courante_in = smic_dot_date_courante_; @@ -793,7 +783,7 @@ let prestations_familiales let smic_dot_brut_horaire_ : money = result_.brut_horaire_out in let regime_outre_mer_l751_1_ : bool = log_variable_definition - [ "PrestationsFamiliales"; "régime_outre_mer_l751_1" ] + ["PrestationsFamiliales"; "régime_outre_mer_l751_1"] embed_bool (try try @@ -835,15 +825,13 @@ let prestations_familiales end_column = 33; law_headings = [ - "Prestations familiales"; - "Champs d'applications"; - "Prologue"; + "Prestations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let plafond_l512_3_2_ : money = log_variable_definition - [ "PrestationsFamiliales"; "plafond_l512_3_2" ] + ["PrestationsFamiliales"; "plafond_l512_3_2"] embed_money (try try @@ -885,15 +873,13 @@ let prestations_familiales end_column = 27; law_headings = [ - "Prestations familiales"; - "Champs d'applications"; - "Prologue"; + "Prestations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let conditions_hors_age_ : enfant -> bool = log_variable_definition - [ "PrestationsFamiliales"; "conditions_hors_âge" ] + ["PrestationsFamiliales"; "conditions_hors_âge"] unembeddable (try fun (param_ : enfant) -> @@ -962,15 +948,13 @@ let prestations_familiales end_column = 29; law_headings = [ - "Prestations familiales"; - "Champs d'applications"; - "Prologue"; + "Prestations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let droit_ouvert_ : enfant -> bool = log_variable_definition - [ "PrestationsFamiliales"; "droit_ouvert" ] + ["PrestationsFamiliales"; "droit_ouvert"] unembeddable (try fun (param_ : enfant) -> @@ -1095,9 +1079,7 @@ let prestations_familiales end_column = 22; law_headings = [ - "Prestations familiales"; - "Champs d'applications"; - "Prologue"; + "Prestations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -1131,7 +1113,7 @@ let allocations_familiales in let prise_en_compte_ : enfant -> prise_en_compte = log_variable_definition - [ "AllocationsFamiliales"; "prise_en_compte" ] + ["AllocationsFamiliales"; "prise_en_compte"] unembeddable (try fun (param_ : enfant) -> @@ -1164,7 +1146,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> true | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Complete () else raise EmptyError); (fun (_ : _) -> @@ -1193,7 +1175,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - true) + true) then Zero () else raise EmptyError); (fun (_ : _) -> @@ -1222,7 +1204,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Partagee () else raise EmptyError); (fun (_ : _) -> @@ -1251,7 +1233,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Complete () else raise EmptyError); (fun (_ : _) -> @@ -1280,7 +1262,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> true | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Complete () else raise EmptyError); |] @@ -1313,15 +1295,13 @@ let allocations_familiales end_column = 26; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let versement_ : enfant -> versement_allocations = log_variable_definition - [ "AllocationsFamiliales"; "versement" ] + ["AllocationsFamiliales"; "versement"] unembeddable (try fun (param_ : enfant) -> @@ -1354,7 +1334,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> true | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Normal () else raise EmptyError); (fun (_ : _) -> @@ -1383,7 +1363,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - true) + true) then AllocationVerseeAuxServicesSociaux () else raise EmptyError); (fun (_ : _) -> @@ -1412,7 +1392,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Normal () else raise EmptyError); (fun (_ : _) -> @@ -1441,7 +1421,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> false | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Normal () else raise EmptyError); (fun (_ : _) -> @@ -1470,7 +1450,7 @@ let allocations_familiales | EffectiveEtPermanente _ -> true | ServicesSociauxAllocationVerseeALaFamille _ -> false | ServicesSociauxAllocationVerseeAuxServicesSociaux _ -> - false) + false) then Normal () else raise EmptyError); |] @@ -1503,15 +1483,13 @@ let allocations_familiales end_column = 20; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let nombre_enfants_l521_1_ : integer = log_variable_definition - [ "AllocationsFamiliales"; "nombre_enfants_l521_1" ] + ["AllocationsFamiliales"; "nombre_enfants_l521_1"] embed_integer (try integer_of_string "3" with EmptyError -> @@ -1525,15 +1503,13 @@ let allocations_familiales end_column = 32; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let nombre_enfants_alinea_2_l521_3_ : integer = log_variable_definition - [ "AllocationsFamiliales"; "nombre_enfants_alinéa_2_l521_3" ] + ["AllocationsFamiliales"; "nombre_enfants_alinéa_2_l521_3"] embed_integer (try integer_of_string "3" with EmptyError -> @@ -1547,9 +1523,7 @@ let allocations_familiales end_column = 41; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -1574,7 +1548,7 @@ let allocations_familiales let bmaf_dot_date_courante_ : date = try log_variable_definition - [ "AllocationsFamiliales"; "bmaf.date_courante" ] + ["AllocationsFamiliales"; "bmaf.date_courante"] embed_date date_courante_ with EmptyError -> raise @@ -1587,16 +1561,14 @@ let allocations_familiales end_line = 5; end_column = 23; law_headings = - [ "Montant de la base mensuelle des allocations familiales" ]; + ["Montant de la base mensuelle des allocations familiales"]; }) in let result_ : base_mensuelle_allocations_familiales_out = log_end_call - [ "AllocationsFamiliales"; "bmaf"; "BaseMensuelleAllocationsFamiliales" ] + ["AllocationsFamiliales"; "bmaf"; "BaseMensuelleAllocationsFamiliales"] (log_begin_call - [ - "AllocationsFamiliales"; "bmaf"; "BaseMensuelleAllocationsFamiliales"; - ] + ["AllocationsFamiliales"; "bmaf"; "BaseMensuelleAllocationsFamiliales"] base_mensuelle_allocations_familiales { date_courante_in = bmaf_dot_date_courante_ }) in @@ -1604,7 +1576,7 @@ let allocations_familiales let prestations_familiales_dot_date_courante_ : date = try log_variable_definition - [ "AllocationsFamiliales"; "prestations_familiales.date_courante" ] + ["AllocationsFamiliales"; "prestations_familiales.date_courante"] embed_date date_courante_ with EmptyError -> raise @@ -1616,16 +1588,14 @@ let allocations_familiales end_line = 63; end_column = 23; law_headings = - [ "Prestations familiales"; "Champs d'applications"; "Prologue" ]; + ["Prestations familiales"; "Champs d'applications"; "Prologue"]; }) in let prestations_familiales_dot_prestation_courante_ : element_prestations_familiales = try log_variable_definition - [ - "AllocationsFamiliales"; "prestations_familiales.prestation_courante"; - ] + ["AllocationsFamiliales"; "prestations_familiales.prestation_courante"] embed_element_prestations_familiales (AllocationsFamiliales ()) with EmptyError -> raise @@ -1637,13 +1607,13 @@ let allocations_familiales end_line = 64; end_column = 29; law_headings = - [ "Prestations familiales"; "Champs d'applications"; "Prologue" ]; + ["Prestations familiales"; "Champs d'applications"; "Prologue"]; }) in let prestations_familiales_dot_residence_ : collectivite = try log_variable_definition - [ "AllocationsFamiliales"; "prestations_familiales.résidence" ] + ["AllocationsFamiliales"; "prestations_familiales.résidence"] embed_collectivite residence_ with EmptyError -> raise @@ -1655,7 +1625,7 @@ let allocations_familiales end_line = 65; end_column = 19; law_headings = - [ "Prestations familiales"; "Champs d'applications"; "Prologue" ]; + ["Prestations familiales"; "Champs d'applications"; "Prologue"]; }) in let result_ : prestations_familiales_out = @@ -1694,7 +1664,7 @@ let allocations_familiales let enfant_le_plus_age_dot_enfants_ : enfant array = try log_variable_definition - [ "AllocationsFamiliales"; "enfant_le_plus_âgé.enfants" ] + ["AllocationsFamiliales"; "enfant_le_plus_âgé.enfants"] (embed_array embed_enfant) enfants_a_charge_ with EmptyError -> raise @@ -1706,21 +1676,21 @@ let allocations_familiales end_line = 80; end_column = 17; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let result_ : enfant_le_plus_age_out = log_end_call - [ "AllocationsFamiliales"; "enfant_le_plus_âgé"; "EnfantLePlusÂgé" ] + ["AllocationsFamiliales"; "enfant_le_plus_âgé"; "EnfantLePlusÂgé"] (log_begin_call - [ "AllocationsFamiliales"; "enfant_le_plus_âgé"; "EnfantLePlusÂgé" ] + ["AllocationsFamiliales"; "enfant_le_plus_âgé"; "EnfantLePlusÂgé"] enfant_le_plus_age { enfants_in = enfant_le_plus_age_dot_enfants_ }) in let enfant_le_plus_age_dot_le_plus_age_ : enfant = result_.le_plus_age_out in let age_minimum_alinea_1_l521_3_ : enfant -> integer = log_variable_definition - [ "AllocationsFamiliales"; "âge_minimum_alinéa_1_l521_3" ] + ["AllocationsFamiliales"; "âge_minimum_alinéa_1_l521_3"] unembeddable (try fun (param_ : enfant) -> @@ -1777,9 +1747,7 @@ let allocations_familiales end_column = 38; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -1794,15 +1762,15 @@ let allocations_familiales array_filter (fun (enfant_ : _) -> log_end_call - [ "PrestationsFamiliales"; "droit_ouvert" ] + ["PrestationsFamiliales"; "droit_ouvert"] (log_variable_definition - [ "PrestationsFamiliales"; "droit_ouvert"; "output" ] + ["PrestationsFamiliales"; "droit_ouvert"; "output"] unembeddable (log_begin_call - [ "PrestationsFamiliales"; "droit_ouvert" ] + ["PrestationsFamiliales"; "droit_ouvert"] prestations_familiales_dot_droit_ouvert_ (log_variable_definition - [ "PrestationsFamiliales"; "droit_ouvert"; "input" ] + ["PrestationsFamiliales"; "droit_ouvert"; "input"] unembeddable enfant_)))) enfants_a_charge_ with EmptyError -> @@ -1816,15 +1784,13 @@ let allocations_familiales end_column = 61; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let est_enfant_le_plus_age_ : enfant -> bool = log_variable_definition - [ "AllocationsFamiliales"; "est_enfant_le_plus_âgé" ] + ["AllocationsFamiliales"; "est_enfant_le_plus_âgé"] unembeddable (try fun (param_ : enfant) -> @@ -1856,15 +1822,13 @@ let allocations_familiales end_column = 33; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let plafond__i_i_d521_3_ : money = log_variable_definition - [ "AllocationsFamiliales"; "plafond_II_d521_3" ] + ["AllocationsFamiliales"; "plafond_II_d521_3"] embed_money (try handle_default @@ -2007,15 +1971,13 @@ let allocations_familiales end_column = 28; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let plafond__i_d521_3_ : money = log_variable_definition - [ "AllocationsFamiliales"; "plafond_I_d521_3" ] + ["AllocationsFamiliales"; "plafond_I_d521_3"] embed_money (try handle_default @@ -2158,15 +2120,13 @@ let allocations_familiales end_column = 27; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let droit_ouvert_complement_ : bool = log_variable_definition - [ "AllocationsFamiliales"; "droit_ouvert_complément" ] + ["AllocationsFamiliales"; "droit_ouvert_complément"] embed_bool (try try @@ -2211,15 +2171,13 @@ let allocations_familiales end_column = 34; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let droit_ouvert_forfaitaire_ : enfant -> bool = log_variable_definition - [ "AllocationsFamiliales"; "droit_ouvert_forfaitaire" ] + ["AllocationsFamiliales"; "droit_ouvert_forfaitaire"] unembeddable (try fun (param_ : enfant) -> @@ -2278,7 +2236,7 @@ let allocations_familiales && param_.age = prestations_familiales_dot_age_l512_3_2_ && param_.a_deja_ouvert_droit_aux_allocations_familiales && log_end_call - [ "PrestationsFamiliales"; "conditions_hors_âge" ] + ["PrestationsFamiliales"; "conditions_hors_âge"] (log_variable_definition [ "PrestationsFamiliales"; @@ -2287,9 +2245,7 @@ let allocations_familiales ] unembeddable (log_begin_call - [ - "PrestationsFamiliales"; "conditions_hors_âge"; - ] + ["PrestationsFamiliales"; "conditions_hors_âge"] prestations_familiales_dot_conditions_hors_age_ (log_variable_definition [ @@ -2328,9 +2284,7 @@ let allocations_familiales end_column = 35; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -2363,17 +2317,13 @@ let allocations_familiales end_column = 64; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_base_troisieme_enfant_mayotte_ : money = log_variable_definition - [ - "AllocationsFamiliales"; "montant_initial_base_troisième_enfant_mayotte"; - ] + ["AllocationsFamiliales"; "montant_initial_base_troisième_enfant_mayotte"] embed_money (try handle_default @@ -2677,15 +2627,13 @@ let allocations_familiales end_column = 56; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let nombre_total_enfants_ : decimal = log_variable_definition - [ "AllocationsFamiliales"; "nombre_total_enfants" ] + ["AllocationsFamiliales"; "nombre_total_enfants"] embed_decimal (try decimal_of_integer @@ -2701,15 +2649,13 @@ let allocations_familiales end_column = 31; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let nombre_moyen_enfants_ : decimal = log_variable_definition - [ "AllocationsFamiliales"; "nombre_moyen_enfants" ] + ["AllocationsFamiliales"; "nombre_moyen_enfants"] embed_decimal (try Array.fold_left @@ -2718,17 +2664,15 @@ let allocations_familiales +& match log_end_call - [ "AllocationsFamiliales"; "prise_en_compte" ] + ["AllocationsFamiliales"; "prise_en_compte"] (log_variable_definition - [ "AllocationsFamiliales"; "prise_en_compte"; "output" ] + ["AllocationsFamiliales"; "prise_en_compte"; "output"] unembeddable (log_begin_call - [ "AllocationsFamiliales"; "prise_en_compte" ] + ["AllocationsFamiliales"; "prise_en_compte"] prise_en_compte_ (log_variable_definition - [ - "AllocationsFamiliales"; "prise_en_compte"; "input"; - ] + ["AllocationsFamiliales"; "prise_en_compte"; "input"] unembeddable enfant_))) with | Complete _ -> decimal_of_string "1." @@ -2747,15 +2691,13 @@ let allocations_familiales end_column = 31; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_base_premier_enfant_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_initial_base_premier_enfant" ] + ["AllocationsFamiliales"; "montant_initial_base_premier_enfant"] embed_money (try handle_default @@ -3148,15 +3090,13 @@ let allocations_familiales end_column = 46; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let droit_ouvert_base_ : bool = log_variable_definition - [ "AllocationsFamiliales"; "droit_ouvert_base" ] + ["AllocationsFamiliales"; "droit_ouvert_base"] embed_bool (try try @@ -3252,15 +3192,13 @@ let allocations_familiales end_column = 28; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let droit_ouvert_majoration_ : enfant -> bool = log_variable_definition - [ "AllocationsFamiliales"; "droit_ouvert_majoration" ] + ["AllocationsFamiliales"; "droit_ouvert_majoration"] unembeddable (try fun (param_ : enfant) -> @@ -3339,9 +3277,7 @@ let allocations_familiales } ((not (log_end_call - [ - "AllocationsFamiliales"; "est_enfant_le_plus_âgé"; - ] + ["AllocationsFamiliales"; "est_enfant_le_plus_âgé"] (log_variable_definition [ "AllocationsFamiliales"; @@ -3418,15 +3354,13 @@ let allocations_familiales end_column = 34; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let complement_degressif_ : money -> money = log_variable_definition - [ "AllocationsFamiliales"; "complément_dégressif" ] + ["AllocationsFamiliales"; "complément_dégressif"] unembeddable (try fun (param_ : money) -> @@ -3523,15 +3457,13 @@ let allocations_familiales end_column = 31; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_verse_forfaitaire_par_enfant_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_versé_forfaitaire_par_enfant" ] + ["AllocationsFamiliales"; "montant_versé_forfaitaire_par_enfant"] embed_money (try handle_default @@ -3620,17 +3552,13 @@ let allocations_familiales end_column = 47; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_base_troisieme_enfant_et_plus_ : money = log_variable_definition - [ - "AllocationsFamiliales"; "montant_initial_base_troisième_enfant_et_plus"; - ] + ["AllocationsFamiliales"; "montant_initial_base_troisième_enfant_et_plus"] embed_money (try handle_default @@ -3752,15 +3680,13 @@ let allocations_familiales end_column = 56; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_base_deuxieme_enfant_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_initial_base_deuxième_enfant" ] + ["AllocationsFamiliales"; "montant_initial_base_deuxième_enfant"] embed_money (try try @@ -4179,15 +4105,13 @@ let allocations_familiales end_column = 47; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let rapport_enfants_total_moyen_ : decimal = log_variable_definition - [ "AllocationsFamiliales"; "rapport_enfants_total_moyen" ] + ["AllocationsFamiliales"; "rapport_enfants_total_moyen"] embed_decimal (try if nombre_total_enfants_ = decimal_of_string "0." then @@ -4204,15 +4128,13 @@ let allocations_familiales end_column = 38; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_metropole_majoration_ : enfant -> money = log_variable_definition - [ "AllocationsFamiliales"; "montant_initial_métropole_majoration" ] + ["AllocationsFamiliales"; "montant_initial_métropole_majoration"] unembeddable (try fun (param_ : enfant) -> @@ -4238,8 +4160,7 @@ let allocations_familiales (not (log_end_call [ - "AllocationsFamiliales"; - "droit_ouvert_majoration"; + "AllocationsFamiliales"; "droit_ouvert_majoration"; ] (log_variable_definition [ @@ -4285,9 +4206,7 @@ let allocations_familiales } (ressources_menage_ >$ plafond__i_i_d521_3_ && log_end_call - [ - "AllocationsFamiliales"; "droit_ouvert_majoration"; - ] + ["AllocationsFamiliales"; "droit_ouvert_majoration"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4333,9 +4252,7 @@ let allocations_familiales ((ressources_menage_ >$ plafond__i_d521_3_ && ressources_menage_ <=$ plafond__i_i_d521_3_) && log_end_call - [ - "AllocationsFamiliales"; "droit_ouvert_majoration"; - ] + ["AllocationsFamiliales"; "droit_ouvert_majoration"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4380,9 +4297,7 @@ let allocations_familiales } (ressources_menage_ <=$ plafond__i_d521_3_ && log_end_call - [ - "AllocationsFamiliales"; "droit_ouvert_majoration"; - ] + ["AllocationsFamiliales"; "droit_ouvert_majoration"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4435,15 +4350,13 @@ let allocations_familiales end_column = 47; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_verse_forfaitaire_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_versé_forfaitaire" ] + ["AllocationsFamiliales"; "montant_versé_forfaitaire"] embed_money (try montant_verse_forfaitaire_par_enfant_ @@ -4452,7 +4365,7 @@ let allocations_familiales (fun (acc_ : integer) (enfant_ : _) -> if log_end_call - [ "AllocationsFamiliales"; "droit_ouvert_forfaitaire" ] + ["AllocationsFamiliales"; "droit_ouvert_forfaitaire"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4487,15 +4400,13 @@ let allocations_familiales end_column = 36; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_base_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_initial_base" ] + ["AllocationsFamiliales"; "montant_initial_base"] embed_money (try handle_default @@ -4567,15 +4478,13 @@ let allocations_familiales end_column = 31; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_initial_majoration_ : enfant -> money = log_variable_definition - [ "AllocationsFamiliales"; "montant_initial_majoration" ] + ["AllocationsFamiliales"; "montant_initial_majoration"] unembeddable (try fun (param_ : enfant) -> @@ -4603,7 +4512,7 @@ let allocations_familiales ]; } (log_end_call - [ "AllocationsFamiliales"; "droit_ouvert_majoration" ] + ["AllocationsFamiliales"; "droit_ouvert_majoration"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4652,7 +4561,7 @@ let allocations_familiales ]; } (log_end_call - [ "AllocationsFamiliales"; "droit_ouvert_majoration" ] + ["AllocationsFamiliales"; "droit_ouvert_majoration"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4736,15 +4645,13 @@ let allocations_familiales end_column = 37; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_verse_complement_pour_forfaitaire_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_versé_complément_pour_forfaitaire" ] + ["AllocationsFamiliales"; "montant_versé_complément_pour_forfaitaire"] embed_money (try handle_default @@ -4825,15 +4732,13 @@ let allocations_familiales end_column = 52; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_avec_garde_alternee_base_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_avec_garde_alternée_base" ] + ["AllocationsFamiliales"; "montant_avec_garde_alternée_base"] embed_money (try montant_initial_base_ *$ rapport_enfants_total_moyen_ with EmptyError -> @@ -4847,21 +4752,19 @@ let allocations_familiales end_column = 43; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_avec_garde_alternee_majoration_ : enfant -> money = log_variable_definition - [ "AllocationsFamiliales"; "montant_avec_garde_alternée_majoration" ] + ["AllocationsFamiliales"; "montant_avec_garde_alternée_majoration"] unembeddable (try fun (param_ : enfant) -> try log_end_call - [ "AllocationsFamiliales"; "montant_initial_majoration" ] + ["AllocationsFamiliales"; "montant_initial_majoration"] (log_variable_definition [ "AllocationsFamiliales"; @@ -4870,7 +4773,7 @@ let allocations_familiales ] unembeddable (log_begin_call - [ "AllocationsFamiliales"; "montant_initial_majoration" ] + ["AllocationsFamiliales"; "montant_initial_majoration"] montant_initial_majoration_ (log_variable_definition [ @@ -4882,17 +4785,15 @@ let allocations_familiales *$ match log_end_call - [ "AllocationsFamiliales"; "prise_en_compte" ] + ["AllocationsFamiliales"; "prise_en_compte"] (log_variable_definition - [ "AllocationsFamiliales"; "prise_en_compte"; "output" ] + ["AllocationsFamiliales"; "prise_en_compte"; "output"] unembeddable (log_begin_call - [ "AllocationsFamiliales"; "prise_en_compte" ] + ["AllocationsFamiliales"; "prise_en_compte"] prise_en_compte_ (log_variable_definition - [ - "AllocationsFamiliales"; "prise_en_compte"; "input"; - ] + ["AllocationsFamiliales"; "prise_en_compte"; "input"] unembeddable param_))) with | Complete _ -> decimal_of_string "1." @@ -4925,15 +4826,13 @@ let allocations_familiales end_column = 49; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_verse_base_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_versé_base" ] + ["AllocationsFamiliales"; "montant_versé_base"] embed_money (try if droit_ouvert_base_ then montant_avec_garde_alternee_base_ @@ -4949,15 +4848,13 @@ let allocations_familiales end_column = 29; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_verse_majoration_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_versé_majoration" ] + ["AllocationsFamiliales"; "montant_versé_majoration"] embed_money (try if droit_ouvert_base_ then @@ -5003,9 +4900,7 @@ let allocations_familiales end_column = 35; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -5028,9 +4923,7 @@ let allocations_familiales end_column = 58; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -5044,17 +4937,15 @@ let allocations_familiales (try if droit_ouvert_complement_ then log_end_call - [ "AllocationsFamiliales"; "complément_dégressif" ] + ["AllocationsFamiliales"; "complément_dégressif"] (log_variable_definition - [ "AllocationsFamiliales"; "complément_dégressif"; "output" ] + ["AllocationsFamiliales"; "complément_dégressif"; "output"] unembeddable (log_begin_call - [ "AllocationsFamiliales"; "complément_dégressif" ] + ["AllocationsFamiliales"; "complément_dégressif"] complement_degressif_ (log_variable_definition - [ - "AllocationsFamiliales"; "complément_dégressif"; "input"; - ] + ["AllocationsFamiliales"; "complément_dégressif"; "input"] unembeddable montant_base_complement_pour_base_et_majoration_))) else money_of_cents_string "0" @@ -5069,15 +4960,13 @@ let allocations_familiales end_column = 59; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in let montant_verse_ : money = log_variable_definition - [ "AllocationsFamiliales"; "montant_versé" ] + ["AllocationsFamiliales"; "montant_versé"] embed_money (try if droit_ouvert_base_ then @@ -5098,9 +4987,7 @@ let allocations_familiales end_column = 23; law_headings = [ - "Allocations familiales"; - "Champs d'applications"; - "Prologue"; + "Allocations familiales"; "Champs d'applications"; "Prologue"; ]; })) in @@ -5163,7 +5050,7 @@ let interface_allocations_familiales in let enfants_a_charge_ : enfant array = log_variable_definition - [ "InterfaceAllocationsFamiliales"; "enfants_à_charge" ] + ["InterfaceAllocationsFamiliales"; "enfants_à_charge"] (embed_array embed_enfant) (try Array.map @@ -5249,7 +5136,7 @@ let interface_allocations_familiales end_line = 86; end_column = 57; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let allocations_familiales_dot_personne_charge_effective_permanente_remplit_titre__i_ @@ -5291,7 +5178,7 @@ let interface_allocations_familiales end_line = 87; end_column = 62; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let allocations_familiales_dot_ressources_menage_ : money = @@ -5312,13 +5199,13 @@ let interface_allocations_familiales end_line = 88; end_column = 27; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let allocations_familiales_dot_residence_ : collectivite = try log_variable_definition - [ "InterfaceAllocationsFamiliales"; "allocations_familiales.résidence" ] + ["InterfaceAllocationsFamiliales"; "allocations_familiales.résidence"] embed_collectivite i_residence_ with EmptyError -> raise @@ -5330,7 +5217,7 @@ let interface_allocations_familiales end_line = 89; end_column = 19; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let allocations_familiales_dot_date_courante_ : date = @@ -5351,7 +5238,7 @@ let interface_allocations_familiales end_line = 92; end_column = 23; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let allocations_familiales_dot_enfants_a_charge_ : enfant array = @@ -5372,7 +5259,7 @@ let interface_allocations_familiales end_line = 95; end_column = 26; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let allocations_familiales_dot_avait_enfant_a_charge_avant_1er_janvier_2012_ : @@ -5414,7 +5301,7 @@ let interface_allocations_familiales end_line = 116; end_column = 54; law_headings = - [ "Allocations familiales"; "Champs d'applications"; "Prologue" ]; + ["Allocations familiales"; "Champs d'applications"; "Prologue"]; }) in let result_ : allocations_familiales_out = @@ -5449,7 +5336,7 @@ let interface_allocations_familiales in let i_montant_verse_ : money = log_variable_definition - [ "InterfaceAllocationsFamiliales"; "i_montant_versé" ] + ["InterfaceAllocationsFamiliales"; "i_montant_versé"] embed_money (try allocations_familiales_dot_montant_verse_ with EmptyError -> diff --git a/french_law/ocaml/law_source/unit_tests/run_tests.ml b/french_law/ocaml/law_source/unit_tests/run_tests.ml index cffe2f08..5d9a211e 100644 --- a/french_law/ocaml/law_source/unit_tests/run_tests.ml +++ b/french_law/ocaml/law_source/unit_tests/run_tests.ml @@ -4,13 +4,13 @@ let try_test msg test = try test (); Format.printf "%s %s\n" - (ANSITerminal.sprintf [ ANSITerminal.green ] "PASS") - (ANSITerminal.sprintf [ ANSITerminal.magenta ] msg) + (ANSITerminal.sprintf [ANSITerminal.green] "PASS") + (ANSITerminal.sprintf [ANSITerminal.magenta] msg) with Runtime.AssertionFailed -> failure := true; Format.printf "%s %s\n" - (ANSITerminal.sprintf [ ANSITerminal.red ] "FAIL") - (ANSITerminal.sprintf [ ANSITerminal.magenta ] msg) + (ANSITerminal.sprintf [ANSITerminal.red] "FAIL") + (ANSITerminal.sprintf [ANSITerminal.magenta] msg) let _ = try_test "Allocations familiales #1" Tests_allocations_familiales.test1;