From feeee4016e29924a0f9f46ff7f021a851a201d43 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Thu, 13 Apr 2023 21:49:16 +0200 Subject: [PATCH 01/10] Add support for dcalc plugins previously only lcalc and scalc where available --- compiler/driver.ml | 10 +++++++++- compiler/plugin.ml | 8 +++++++- compiler/plugin.mli | 7 +++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/compiler/driver.ml b/compiler/driver.ml index 2a1c1561..3459b16c 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -289,6 +289,13 @@ let driver source_file (options : Cli.options) : int = (Shared_ast.Expr.format ~debug:options.debug prgm.decl_ctx) result) results + | `Plugin (Plugin.Dcalc p) -> + let output_file, _ = get_output_format ~ext:p.Plugin.extension () in + Cli.debug_print "Compiling program through backend \"%s\"..." + p.Plugin.name; + p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope + (Shared_ast.Program.untype prgm) + type_ordering | (`OCaml | `Interpret_Lcalc | `Python | `Lcalc | `Scalc | `Plugin _) as backend -> ( Cli.debug_print "Compiling program into lambda calculus..."; @@ -375,6 +382,7 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Writing to %s..." (Option.value ~default:"stdout" output_file); Lcalc.To_ocaml.format_program fmt prgm type_ordering + | `Plugin (Plugin.Dcalc _) -> assert false | `Plugin (Plugin.Lcalc p) -> let output_file, _ = get_output_format ~ext:p.Plugin.extension () @@ -411,7 +419,7 @@ let driver source_file (options : Cli.options) : int = with_output @@ fun fmt -> Scalc.To_python.format_program fmt prgm type_ordering - | `Plugin (Plugin.Lcalc _) -> assert false + | `Plugin (Plugin.Dcalc _ | Plugin.Lcalc _) -> assert false | `Plugin (Plugin.Scalc p) -> let output_file, _ = get_output ~ext:p.Plugin.extension () in Cli.debug_print "Compiling program through backend \"%s\"..." diff --git a/compiler/plugin.ml b/compiler/plugin.ml index e7f9437a..83a149a4 100644 --- a/compiler/plugin.ml +++ b/compiler/plugin.ml @@ -31,16 +31,22 @@ type 'ast gen = { } type t = + | Dcalc of Shared_ast.untyped Dcalc.Ast.program gen | Lcalc of Shared_ast.untyped Lcalc.Ast.program gen | Scalc of Scalc.Ast.program gen -let name = function Lcalc { name; _ } | Scalc { name; _ } -> name +let name = function + | Dcalc { name; _ } | Lcalc { name; _ } | Scalc { name; _ } -> name + let backend_plugins : (string, t) Hashtbl.t = Hashtbl.create 17 let register t = Hashtbl.replace backend_plugins (String.lowercase_ascii (name t)) t module PluginAPI = struct + let register_dcalc ~name ~extension apply = + register (Dcalc { name; extension; apply }) + let register_lcalc ~name ~extension apply = register (Lcalc { name; extension; apply }) diff --git a/compiler/plugin.mli b/compiler/plugin.mli index acfdf56a..0d69561a 100644 --- a/compiler/plugin.mli +++ b/compiler/plugin.mli @@ -33,6 +33,7 @@ type 'ast gen = { } type t = + | Dcalc of Shared_ast.untyped Dcalc.Ast.program gen | Lcalc of Shared_ast.untyped Lcalc.Ast.program gen | Scalc of Scalc.Ast.program gen @@ -48,6 +49,12 @@ val load_dir : string -> unit (** {2 plugin-facing API} *) module PluginAPI : sig + val register_dcalc : + name:string -> + extension:string -> + Shared_ast.untyped Dcalc.Ast.program plugin_apply_fun_typ -> + unit + val register_lcalc : name:string -> extension:string -> From b4a68fa39261eb1e9645d083f91e72a4244b6f48 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Fri, 14 Apr 2023 16:56:57 +0200 Subject: [PATCH 02/10] Add experimental lazy interpreter as a plugin To try it (without installing Catala): ```shell-session $ make plugins $ export CATALA_PLUGINS=_build/default/compiler/plugins $ dune exec -- catala lazy examples/aides_logement/tests/tests_calcul_apl_locatif.catala_fr -s Exemple2 ``` Keep in mind that this is a work-in-progress prototype :) --- compiler/driver.ml | 18 +- compiler/plugin.ml | 2 +- compiler/plugin.mli | 2 +- compiler/plugins/dune | 8 + compiler/plugins/json_schema.ml | 52 ++---- compiler/plugins/lazy_interp.ml | 273 ++++++++++++++++++++++++++++ compiler/shared_ast/interpreter.ml | 13 +- compiler/shared_ast/interpreter.mli | 12 ++ 8 files changed, 334 insertions(+), 46 deletions(-) create mode 100644 compiler/plugins/lazy_interp.ml diff --git a/compiler/driver.ml b/compiler/driver.ml index 3459b16c..a2260256 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -293,7 +293,11 @@ let driver source_file (options : Cli.options) : int = let output_file, _ = get_output_format ~ext:p.Plugin.extension () in Cli.debug_print "Compiling program through backend \"%s\"..." p.Plugin.name; - p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope + p.Plugin.apply ~source_file ~output_file + ~scope: + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) (Shared_ast.Program.untype prgm) type_ordering | (`OCaml | `Interpret_Lcalc | `Python | `Lcalc | `Scalc | `Plugin _) @@ -389,7 +393,11 @@ let driver source_file (options : Cli.options) : int = in Cli.debug_print "Compiling program through backend \"%s\"..." p.Plugin.name; - p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope + p.Plugin.apply ~source_file ~output_file + ~scope: + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) prgm type_ordering | (`Python | `Scalc | `Plugin (Plugin.Scalc _)) as backend -> ( let prgm = Scalc.From_lcalc.translate_program prgm in @@ -427,7 +435,11 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Writing to %s..." (Option.value ~default:"stdout" output_file); p.Plugin.apply ~source_file ~output_file - ~scope:options.ex_scope prgm type_ordering))))))); + ~scope: + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) + prgm type_ordering))))))); 0 with | Errors.StructuredError (msg, pos) -> diff --git a/compiler/plugin.ml b/compiler/plugin.ml index 83a149a4..8c5540ac 100644 --- a/compiler/plugin.ml +++ b/compiler/plugin.ml @@ -19,7 +19,7 @@ open Catala_utils type 'ast plugin_apply_fun_typ = source_file:Pos.input_file -> output_file:string option -> - scope:string option -> + scope:Shared_ast.ScopeName.t option -> 'ast -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/plugin.mli b/compiler/plugin.mli index 0d69561a..c8046201 100644 --- a/compiler/plugin.mli +++ b/compiler/plugin.mli @@ -21,7 +21,7 @@ open Catala_utils type 'ast plugin_apply_fun_typ = source_file:Pos.input_file -> output_file:string option -> - scope:string option -> + scope:Shared_ast.ScopeName.t option -> 'ast -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/plugins/dune b/compiler/plugins/dune index 531a170f..0d768020 100644 --- a/compiler/plugins/dune +++ b/compiler/plugins/dune @@ -20,6 +20,14 @@ (modules json_schema) (libraries catala.driver)) +(library + (name lazy_interpreter) + (public_name catala.plugins.lazy-interpreter) + (synopsis + "Catala plugin that implements a different, experimental interpreter, featuring lazy and partial evaluation") + (modules lazy_interp) + (libraries shared_ast catala.driver)) + (documentation (package catala) (mld_files plugins)) diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index e0f8ce02..bd7b783f 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -22,7 +22,6 @@ let extension = "_schema.json" open Catala_utils open Shared_ast -open Lcalc.Ast open Lcalc.To_ocaml module D = Dcalc.Ast @@ -47,17 +46,6 @@ module To_json = struct in Format.fprintf fmt "%s" s - let rec find_scope_def (target_name : string) : - 'm expr code_item_list -> (ScopeName.t * 'm expr scope_body) option = - function - | Nil -> None - | Cons (ScopeDef (name, body), _) - when String.equal target_name (Marked.unmark (ScopeName.get_info name)) -> - Some (name, body) - | Cons (_, next_bind) -> - let _, next_scope = Bindlib.unbind next_bind in - find_scope_def target_name next_scope - let fmt_tlit fmt (tlit : typ_lit) = match tlit with | TUnit -> Format.fprintf fmt "\"type\": \"null\",@\n\"default\": null" @@ -203,31 +191,29 @@ module To_json = struct let format_program (fmt : Format.formatter) - (scope : string) + (scope : ScopeName.t) (prgm : 'm Lcalc.Ast.program) = - match find_scope_def scope prgm.code_items with - | None -> Cli.error_print "Internal error: scope '%s' not found." scope - | Some scope_def -> - Cli.call_unstyled (fun _ -> - Format.fprintf fmt - "{@[@\n\ - \"type\": \"object\",@\n\ - \"@[definitions\": {%a@]@\n\ - },@\n\ - \"@[properties\": {@\n\ - %a@]@\n\ - }@]@\n\ - }" - (fmt_definitions prgm.decl_ctx) - scope_def - (fmt_struct_properties prgm.decl_ctx) - (snd scope_def).scope_body_input_struct) + let scope_body = Program.get_scope_body prgm scope in + Cli.call_unstyled (fun _ -> + Format.fprintf fmt + "{@[@\n\ + \"type\": \"object\",@\n\ + \"@[definitions\": {%a@]@\n\ + },@\n\ + \"@[properties\": {@\n\ + %a@]@\n\ + }@]@\n\ + }" + (fmt_definitions prgm.decl_ctx) + (scope, scope_body) + (fmt_struct_properties prgm.decl_ctx) + scope_body.scope_body_input_struct) end let apply ~(source_file : Pos.input_file) ~(output_file : string option) - ~(scope : string option) + ~(scope : Shared_ast.ScopeName.t option) (prgm : 'm Lcalc.Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) = ignore source_file; @@ -236,9 +222,9 @@ let apply | Some s -> File.with_formatter_of_opt_file output_file (fun fmt -> Cli.debug_print - "Writing JSON schema corresponding to the scope '%s' to the file \ + "Writing JSON schema corresponding to the scope '%a' to the file \ %s..." - s + ScopeName.format_t s (Option.value ~default:"stdout" output_file); To_json.format_program fmt s prgm) | None -> Cli.error_print "A scope must be specified for the plugin: %s" name diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml new file mode 100644 index 00000000..1d854242 --- /dev/null +++ b/compiler/plugins/lazy_interp.ml @@ -0,0 +1,273 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2023 Inria, contributor: + Louis Gesbert . + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Catala_utils +open Shared_ast + +(* -- Definition of the lazy interpreter -- *) + +let log fmt = Format.ifprintf Format.err_formatter (fmt ^^ "@\n") +let error e = Errors.raise_spanned_error (Expr.pos e) +let noassert = true + +type laziness_level = { + eval_struct : bool; + (* if true, evaluate members of structures, tuples, etc. *) + eval_op : bool; + (* if false, evaluate the operands but keep e.g. `3 + 4` as is *) + eval_default : bool; + (* if false, stop evaluating as soon as you can discriminate with + `EEmptyError` *) +} + +let value_level = { eval_struct = false; eval_op = true; eval_default = true } + +module Env = struct + type 'm t = + | Env of + ((dcalc, 'm mark) gexpr, ((dcalc, 'm mark) gexpr * 'm t) ref) Var.Map.t + + let find v (Env t) = Var.Map.find v t + let add v e e_env (Env t) = Env (Var.Map.add v (ref (e, e_env)) t) + let empty = Env Var.Map.empty + + let join (Env t1) (Env t2) = + Env + (Var.Map.union + (fun _ x1 x2 -> + assert (x1 == x2); + Some x1) + t1 t2) + + let print ppf (Env t) = + Format.pp_print_list ~pp_sep:Format.pp_print_space + (fun ppf (v, { contents = _e, _env }) -> Print.var_debug ppf v) + ppf (Var.Map.bindings t) +end + +let rec lazy_eval : + decl_ctx -> + 'm Env.t -> + laziness_level -> + (dcalc, 'm mark) gexpr -> + (dcalc, 'm mark) gexpr * 'm Env.t = + fun ctx env llevel e0 -> + let eval_to_value ?(eval_default = true) env e = + lazy_eval ctx env { value_level with eval_default } e + in + match e0 with + | EVar v, _ -> + if not llevel.eval_default then e0, env + else + (* Variables reducing to EEmpty should not propagate to parent EDefault + (?) *) + let v_env = + try Env.find v env + with Not_found -> + error e0 "Variable %a undefined [@[%a@]]" Print.var_debug v + Env.print env + in + let e, env1 = !v_env in + let r, env1 = lazy_eval ctx env1 llevel e in + if not (Expr.equal e r) then ( + log "@[{{%a =@ [%a]@ ==> [%a]}}@]" Print.var_debug v + (Print.expr ~debug:true ctx) + e + (Print.expr ~debug:true ctx) + r; + v_env := r, env1); + r, Env.join env env1 + | EApp { f; args }, m -> ( + if + (not llevel.eval_default) + && not (List.equal Expr.equal args [ELit LUnit, m]) + (* Applications to () encode thunked default terms *) + then e0, env + else + match eval_to_value env f with + | (EAbs { binder; _ }, _), env -> + let vars, body = Bindlib.unmbind binder in + log "@[@[{"; + let env = + Seq.fold_left2 + (fun env1 var e -> + log "@[LET %a = %a@]@ " Print.var_debug var + (Print.expr ~debug:true ctx) + e; + Env.add var e env env1) + env (Array.to_seq vars) (List.to_seq args) + in + log "@]@[IN [%a]@]" (Print.expr ~debug:true ctx) body; + let e, env = lazy_eval ctx env llevel body in + log "@]}"; + e, env + | ((EOp { op; _ }, m) as f), env -> + let env, args = + List.fold_left_map + (fun env e -> + let e, env = lazy_eval ctx env llevel e in + env, e) + env args + in + if not llevel.eval_op then (EApp { f; args }, m), env + else + let renv = ref env in + (* Dirty workaround returning env from evaluate_operator *) + let eval e = + let e, env = lazy_eval ctx !renv llevel e in + renv := env; + e + in + Interpreter.evaluate_operator eval ctx op m args, !renv + (* fixme: this forwards eempty *) + | e, _ -> error e "Invalid apply on %a" (Print.expr ctx) e) + | (EAbs _ | ELit _ | EOp _ | EEmptyError), _ -> e0, env (* these are values *) + | (EStruct _ | ETuple _ | EInj _ | EArray _), _ -> + if not llevel.eval_struct then e0, env + else + let env, e = + Expr.map_gather ~acc:env ~join:Env.join + ~f:(fun e -> + let e, env = lazy_eval ctx env llevel e in + env, Expr.box e) + e0 + in + Expr.unbox e, env + | EStructAccess { e; name; field }, _ -> ( + if not llevel.eval_default then e0, env + else + match eval_to_value env e with + | (EStruct { name = n; fields }, _), env when StructName.equal name n -> + lazy_eval ctx env llevel (StructField.Map.find field fields) + | e, _ -> error e "Invalid field access on %a" (Print.expr ctx) e) + | ETupleAccess { e; index; size }, _ -> ( + if not llevel.eval_default then e0, env + else + match eval_to_value env e with + | (ETuple es, _), env when List.length es = size -> + lazy_eval ctx env llevel (List.nth es index) + | e, _ -> error e "Invalid tuple access on %a" (Print.expr ctx) e) + | EMatch { e; name; cases }, _ -> ( + if not llevel.eval_default then e0, env + else + match eval_to_value env e with + | (EInj { name = n; cons; e }, m), env when EnumName.equal name n -> + lazy_eval ctx env llevel + (EApp { f = EnumConstructor.Map.find cons cases; args = [e] }, m) + | e, _ -> error e "Invalid match argument %a" (Print.expr ctx) e) + | EDefault { excepts; just; cons }, m -> ( + let excs = + List.filter_map + (fun e -> + match eval_to_value env e ~eval_default:false with + | (EEmptyError, _), _ -> None + | e -> Some e) + excepts + in + match excs with + | [] -> ( + match eval_to_value env just with + | (ELit (LBool true), _), _ -> lazy_eval ctx env llevel cons + | (ELit (LBool false), _), _ -> (EEmptyError, m), env + | e, _ -> error e "Invalid exception justification %a" (Print.expr ctx) e) + | [(e, env)] -> + log "@[EVAL %a@]" (Print.expr ctx) e; + lazy_eval ctx env llevel e + | _ :: _ :: _ -> + Errors.raise_multispanned_error + ((None, Expr.mark_pos m) + :: List.map (fun (e, _) -> None, Expr.pos e) excs) + "Conflicting exceptions") + | EIfThenElse { cond; etrue; efalse }, _ -> ( + match eval_to_value env cond with + | (ELit (LBool true), _), _ -> lazy_eval ctx env llevel etrue + | (ELit (LBool false), _), _ -> lazy_eval ctx env llevel efalse + | e, _ -> error e "Invalid condition %a" (Print.expr ctx) e) + | EErrorOnEmpty e, _ -> ( + match eval_to_value env e ~eval_default:false with + | ((EEmptyError, _) as e'), _ -> + (* This does _not_ match the eager semantics ! *) + error e' "This value is undefined %a" (Print.expr ctx) e + | e, env -> lazy_eval ctx env llevel e) + | EAssert e, m -> ( + if noassert then (ELit LUnit, m), env + else + match eval_to_value env e with + | (ELit (LBool true), m), env -> (ELit LUnit, m), env + | (ELit (LBool false), _), _ -> + error e "Assert failure (%a)" (Print.expr ctx) e + | _ -> error e "Invalid assertion condition %a" (Print.expr ctx) e) + | _ -> . + +let interpret_program + (prg : ('dcalc, 'm mark) gexpr program) + (scope : ScopeName.t) : ('t, 'm mark) gexpr * 'm Env.t = + let ctx = prg.decl_ctx in + let all_env, scopes = + Scope.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) + ~f:(fun (env, scopes) item v -> + match item with + | ScopeDef (name, body) -> + let e = Scope.to_expr ctx body (Scope.get_body_mark body) in + ( Env.add v (Expr.unbox e) env env, + ScopeName.Map.add name (v, body.scope_body_input_struct) scopes ) + | Topdef (_, _, e) -> Env.add v e env env, scopes) + in + let scope_v, scope_arg_struct = ScopeName.Map.find scope scopes in + let { contents = e, env } = Env.find scope_v all_env in + let e = Expr.unbox (Expr.remove_logging_calls e) in + log "====================="; + log "%a" (Print.expr ~debug:true ctx) e; + log "====================="; + let m = Marked.get_mark e in + let application_arg = + Expr.estruct scope_arg_struct + (StructField.Map.map + (function + | TArrow (ty_in, ty_out), _ -> + Expr.make_abs + [| Var.make "_" |] + (Bindlib.box EEmptyError, Expr.with_ty m ty_out) + ty_in (Expr.mark_pos m) + | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty)) + (StructName.Map.find scope_arg_struct ctx.ctx_structs)) + m + in + let e_app = Expr.eapp (Expr.box e) [application_arg] m in + lazy_eval ctx env + { value_level with eval_struct = true; eval_op = false } + (Expr.unbox e_app) + +(* -- Plugin registration -- *) + +let name = "lazy" +let extension = ".out" (* unused *) + +let apply ~source_file ~output_file ~scope prg _type_ordering = + let scope = + match scope with + | None -> Errors.raise_error "A scope must be specified" + | Some s -> s + in + ignore source_file; + (* File.with_formatter_of_opt_file output_file + * @@ fun fmt -> *) + ignore output_file; + let fmt = Format.std_formatter in + let result_expr, _env = interpret_program prg scope in + Print.expr prg.decl_ctx fmt result_expr + +let () = Driver.Plugin.register_dcalc ~name ~extension apply diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index e2e480a6..5296e951 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -191,15 +191,14 @@ let rec evaluate_operator EArray (List.map (fun e' -> - evaluate_expr ctx (Marked.same_mark_as (EApp { f; args = [e'] }) e')) + evaluate_expr (Marked.same_mark_as (EApp { f; args = [e'] }) e')) es) | Reduce, [_; default; (EArray [], _)] -> Marked.unmark default | Reduce, [f; _; (EArray (x0 :: xn), _)] -> Marked.unmark (List.fold_left (fun acc x -> - evaluate_expr ctx - (Marked.same_mark_as (EApp { f; args = [acc; x] }) f)) + evaluate_expr (Marked.same_mark_as (EApp { f; args = [acc; x] }) f)) x0 xn) | Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2) | Filter, [f; (EArray es, _)] -> @@ -207,8 +206,7 @@ let rec evaluate_operator (List.filter (fun e' -> match - evaluate_expr ctx - (Marked.same_mark_as (EApp { f; args = [e'] }) e') + evaluate_expr (Marked.same_mark_as (EApp { f; args = [e'] }) e') with | ELit (LBool b), _ -> b | _ -> @@ -221,8 +219,7 @@ let rec evaluate_operator Marked.unmark (List.fold_left (fun acc e' -> - evaluate_expr ctx - (Marked.same_mark_as (EApp { f; args = [acc; e'] }) e')) + evaluate_expr (Marked.same_mark_as (EApp { f; args = [acc; e'] }) e')) init es) | (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ -> err () | Not, [(ELit (LBool b), _)] -> ELit (LBool (o_not b)) @@ -409,7 +406,7 @@ let rec evaluate_expr : | Eq_dur_dur ) as op; _; } -> - evaluate_operator evaluate_expr ctx op m args + evaluate_operator (evaluate_expr ctx) ctx op m args | _ -> Errors.raise_spanned_error pos "function has not been reduced to a lambda at evaluation (should not \ diff --git a/compiler/shared_ast/interpreter.mli b/compiler/shared_ast/interpreter.mli index 4e5911bd..b311d837 100644 --- a/compiler/shared_ast/interpreter.mli +++ b/compiler/shared_ast/interpreter.mli @@ -20,6 +20,18 @@ open Catala_utils open Definitions +val evaluate_operator : + ((([< all ] as 'a), 'm mark) gexpr -> ('a, 'm mark) gexpr) -> + decl_ctx -> + [< dcalc | lcalc > `Monomorphic `Polymorphic `Resolved ] operator -> + 'm mark -> + ('a, 'm mark) gexpr list -> + ('a, 'm mark) gexpr +(** Evaluates the result of applying the given operator to the given arguments, + which are expected to be already reduced to values. The first argument is + used to evaluate expressions and called when reducing e.g. the [map] + operator. *) + val evaluate_expr : decl_ctx -> (([< dcalc | lcalc ] as 'a), 'm mark) gexpr -> ('a, 'm mark) gexpr (** Evaluates an expression according to the semantics of the default calculus. *) From b757b828a0decc95a10eb24302dff802c00a3ec3 Mon Sep 17 00:00:00 2001 From: adelaett <90894311+adelaett@users.noreply.github.com> Date: Mon, 17 Apr 2023 00:22:21 +0000 Subject: [PATCH 03/10] Update lock files --- flake.lock | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/flake.lock b/flake.lock index 343996da..0792da3c 100644 --- a/flake.lock +++ b/flake.lock @@ -1,12 +1,15 @@ { "nodes": { "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1676283394, - "narHash": "sha256-XX2f9c3iySLCw54rJ/CZs+ZK6IQy7GXNY4nSOyu2QG4=", + "lastModified": 1681202837, + "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", "owner": "numtide", "repo": "flake-utils", - "rev": "3db36a8b464d0c4532ba1c7dda728f4576d6d073", + "rev": "cfacdce06f30d2b68473a46042957675eebb3401", "type": "github" }, "original": { @@ -17,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1678470307, - "narHash": "sha256-OEeMUr3ueLIXyW/OaFUX5jUdimyQwMg/7e+/Q0gC/QE=", + "lastModified": 1681648924, + "narHash": "sha256-pzi3HISK8+7mpEtv08Yr80wswyHKsz+RP1CROG1Qf6s=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "0c4800d579af4ed98ecc47d464a5e7b0870c4b1f", + "rev": "f294325aed382b66c7a188482101b0f336d1d7db", "type": "github" }, "original": { @@ -36,6 +39,21 @@ "flake-utils": "flake-utils", "nixpkgs": "nixpkgs" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", From 2afb6fc20c049a7d484fce957445c07cfb69ad25 Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Fri, 7 Apr 2023 12:39:26 +0200 Subject: [PATCH 04/10] I/O plumbing necessary for this feature, missing main implem --- compiler/catala_utils/cli.ml | 24 ++- compiler/catala_utils/cli.mli | 2 + compiler/desugared/print.ml | 22 +++ compiler/desugared/print.mli | 19 +++ compiler/driver.ml | 118 +++++++++++++- compiler/scopelang/from_desugared.ml | 225 ++++++++++++++++---------- compiler/scopelang/from_desugared.mli | 8 +- compiler/shared_ast/definitions.ml | 52 ++++++ 8 files changed, 375 insertions(+), 95 deletions(-) create mode 100644 compiler/desugared/print.ml create mode 100644 compiler/desugared/print.mli diff --git a/compiler/catala_utils/cli.ml b/compiler/catala_utils/cli.ml index 4b3e2549..ec16e7e2 100644 --- a/compiler/catala_utils/cli.ml +++ b/compiler/catala_utils/cli.ml @@ -22,6 +22,7 @@ type backend_option_builtin = | `Makefile | `Html | `Interpret + | `Interpret_Lcalc | `Typecheck | `OCaml | `Python @@ -29,8 +30,8 @@ type backend_option_builtin = | `Lcalc | `Dcalc | `Scopelang - | `Proof - | `Interpret_Lcalc ] + | `Exceptions + | `Proof ] type 'a backend_option = [ backend_option_builtin | `Plugin of 'a ] @@ -55,6 +56,7 @@ let backend_option_to_string = function | `Typecheck -> "Typecheck" | `Scalc -> "Scalc" | `Lcalc -> "Lcalc" + | `Exceptions -> "Exceptions" | `Plugin s -> s let backend_option_of_string backend = @@ -72,6 +74,7 @@ let backend_option_of_string backend = | "typecheck" -> `Typecheck | "scalc" -> `Scalc | "lcalc" -> `Lcalc + | "exceptions" -> `Exceptions | s -> `Plugin s (** Source files to be compiled *) @@ -234,6 +237,12 @@ let ex_scope = & opt (some string) None & info ["s"; "scope"] ~docv:"SCOPE" ~doc:"Scope to be focused on.") +let ex_variable = + Arg.( + value + & opt (some string) None + & info ["v"; "variable"] ~docv:"VARIABLE" ~doc:"Variable to be focused on.") + let output = Arg.( value @@ -258,6 +267,7 @@ type options = { disable_counterexamples : bool; optimize : bool; ex_scope : string option; + ex_variable : string option; output_file : string option; closure_conversion : bool; print_only_law : bool; @@ -280,6 +290,7 @@ let options = disable_counterexamples optimize ex_scope + ex_variable output_file print_only_law : options = { @@ -296,6 +307,7 @@ let options = disable_counterexamples; optimize; ex_scope; + ex_variable; output_file; closure_conversion; print_only_law; @@ -318,6 +330,7 @@ let options = $ disable_counterexamples_opt $ optimize $ ex_scope + $ ex_variable $ output $ print_only_law) @@ -402,6 +415,13 @@ let info = "Prints a debugging verbatim of the statement calculus intermediate \ representation of the Catala program. Use the $(b,-s) option to \ restrict the output to a particular scope." ); + `I + ( "$(b,Exceptions)", + "Prints the exception tree for the definitions of a particular \ + variable, for debugging purposes. Use the $(b,-s) option to select \ + the scope and the $(b,-v) option to select the variable. Use \ + foo.bar to access state bar of variable foo or variable bar of \ + subscope foo." ); `I ( "$(b,pygmentize)", "This special command is a wrapper around the $(b,pygmentize) \ diff --git a/compiler/catala_utils/cli.mli b/compiler/catala_utils/cli.mli index 648cc311..ef48ac0f 100644 --- a/compiler/catala_utils/cli.mli +++ b/compiler/catala_utils/cli.mli @@ -30,6 +30,7 @@ type backend_option_builtin = | `Lcalc | `Dcalc | `Scopelang + | `Exceptions | `Proof ] type 'a backend_option = [ backend_option_builtin | `Plugin of 'a ] @@ -105,6 +106,7 @@ type options = { disable_counterexamples : bool; optimize : bool; ex_scope : string option; + ex_variable : string option; output_file : string option; closure_conversion : bool; print_only_law : bool; diff --git a/compiler/desugared/print.ml b/compiler/desugared/print.ml new file mode 100644 index 00000000..ab84ac20 --- /dev/null +++ b/compiler/desugared/print.ml @@ -0,0 +1,22 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2023 Inria, contributor: + Denis Merigoux + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Shared_ast + +let print_exceptions_graph + (var : DesugaredVarName.t) + (g : Dependency.ExceptionsDependencies.t) = + assert false diff --git a/compiler/desugared/print.mli b/compiler/desugared/print.mli new file mode 100644 index 00000000..739e2b07 --- /dev/null +++ b/compiler/desugared/print.mli @@ -0,0 +1,19 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2023 Inria, contributor: + Denis Merigoux + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +val print_exceptions_graph : + Shared_ast.DesugaredVarName.t -> Dependency.ExceptionsDependencies.t -> unit +(** Prints the exception graph of a variable to the terminal *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 2a1c1561..b4c4540f 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -140,7 +140,8 @@ let driver source_file (options : Cli.options) : int = language fmt (fun fmt -> weave_output fmt prgm) else weave_output fmt prgm) | ( `Interpret | `Interpret_Lcalc | `Typecheck | `OCaml | `Python | `Scalc - | `Lcalc | `Dcalc | `Scopelang | `Proof | `Plugin _ ) as backend -> ( + | `Lcalc | `Dcalc | `Scopelang | `Exceptions | `Proof | `Plugin _ ) as + backend -> ( Cli.debug_print "Name resolution..."; let ctxt = Desugared.Name_resolution.form_context prgm in let scope_uid = @@ -164,9 +165,108 @@ let driver source_file (options : Cli.options) : int = match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid | _ -> - Errors.raise_error "There is no scope \"%s\" inside the program." + Errors.raise_error "There is no scope \"%a\" inside the program." + (Cli.format_with_style [ANSITerminal.yellow]) name) in + (* This uid is a Desugared identifier *) + let variable_uid = + match options.ex_variable, backend with + | None, `Exceptions -> + Errors.raise_error + "Please specify a variable with the -v option to print its \ + exception tree." + | None, _ -> None + | Some name, _ -> ( + (* Sometimes the variable selected is of the form [a.b]*) + let first_part, second_part = + match + Re.( + exec_opt + (compile + @@ whole_string + @@ seq + [ + group (rep1 (compl [char '.'])); + char '.'; + group (rep1 any); + ]) + name) + with + | None -> name, None + | Some groups -> Re.Group.get groups 1, Some (Re.Group.get groups 2) + in + match + Shared_ast.IdentName.Map.find_opt first_part + (Shared_ast.ScopeName.Map.find scope_uid ctxt.scopes).var_idmap + with + | None -> + Errors.raise_error "Variable \"%a\" not found inside scope \"%a\"" + (Cli.format_with_style [ANSITerminal.yellow]) + name + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "%a" Shared_ast.ScopeName.format_t scope_uid) + | Some + (Desugared.Name_resolution.SubScope + (subscope_var_name, subscope_name)) -> ( + match second_part with + | None -> + Errors.raise_error + "Subscope \"%a\" of scope \"%a\" cannot be selected by itself, \ + please add \".\" where is a subscope variable." + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "%a" Shared_ast.SubScopeName.format_t + subscope_var_name) + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "%a" Shared_ast.ScopeName.format_t scope_uid) + | Some second_part -> ( + match + Shared_ast.IdentName.Map.find_opt second_part + (Shared_ast.ScopeName.Map.find subscope_name ctxt.scopes) + .var_idmap + with + | Some (Desugared.Name_resolution.ScopeVar v) -> + Some + (Shared_ast.DesugaredVarName.SubScopeVar (subscope_var_name, v)) + | _ -> + Errors.raise_error + "Var \"%a\" of subscope \"%a\" in scope \"%a\" does not \ + exist, please check your command line arguments." + (Cli.format_with_style [ANSITerminal.yellow]) + second_part + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "%a" Shared_ast.SubScopeName.format_t + subscope_var_name) + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "%a" Shared_ast.ScopeName.format_t scope_uid) + )) + | Some (Desugared.Name_resolution.ScopeVar v) -> + Some + (Shared_ast.DesugaredVarName.ScopeVar + ( v, + Option.map + (fun second_part -> + let var_sig = + Shared_ast.ScopeVar.Map.find v ctxt.var_typs + in + match + Shared_ast.IdentName.Map.find_opt second_part + var_sig.var_sig_states_idmap + with + | Some state -> state + | None -> + Errors.raise_error + "State \"%a\" is not found for variable \"%a\" of \ + scope \"%a\"" + (Cli.format_with_style [ANSITerminal.yellow]) + second_part + (Cli.format_with_style [ANSITerminal.yellow]) + first_part + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "%a" Shared_ast.ScopeName.format_t + scope_uid)) + second_part ))) + in Cli.debug_print "Desugaring..."; let prgm = Desugared.From_surface.translate_program ctxt prgm in Cli.debug_print "Disambiguating..."; @@ -174,8 +274,20 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Linting..."; Desugared.Linting.lint_program prgm; Cli.debug_print "Collecting rules..."; - let prgm = Scopelang.From_desugared.translate_program prgm in + let prgm, exceptions_graphs = + Scopelang.From_desugared.translate_program prgm + in match backend with + | `Exceptions -> + let variable_uid = + match variable_uid with + | Some variable_uid -> variable_uid + | None -> + Errors.raise_error + "Please provide a scope variable to analyze with the -v option." + in + Desugared.Print.print_exceptions_graph variable_uid + (Shared_ast.DesugaredVarName.Map.find variable_uid exceptions_graphs) | `Scopelang -> let _output_file, with_output = get_output_format () in with_output diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 28d7b3cf..5e8f48b3 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -180,7 +180,8 @@ type rule_tree = priorities declared between rules *) let def_map_to_tree (def_info : Desugared.Ast.ScopeDef.t) - (def : Desugared.Ast.rule RuleName.Map.t) : rule_tree list = + (def : Desugared.Ast.rule RuleName.Map.t) : + rule_tree list * Desugared.Dependency.ExceptionsDependencies.t = let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in Desugared.Dependency.check_for_exception_cycle def exc_graph; (* we start by the base cases: they are the vertices which have no @@ -207,7 +208,7 @@ let def_map_to_tree | [] -> Leaf base_case_as_rule_list | _ -> Node (List.map build_tree exceptions, base_case_as_rule_list) in - List.map build_tree base_cases + List.map build_tree base_cases, exc_graph (** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault} expression in the scope language. The [~toplevel] parameter is used to know @@ -342,9 +343,10 @@ let translate_def (typ : typ) (io : Desugared.Ast.io) ~(is_cond : bool) - ~(is_subscope_var : bool) : untyped Ast.expr boxed = + ~(is_subscope_var : bool) : + untyped Ast.expr boxed * Desugared.Dependency.ExceptionsDependencies.t = (* Here, we have to transform this list of rules into a default tree. *) - let top_list = def_map_to_tree def_info def in + let top_list, exc_graph = def_map_to_tree def_info def in let is_input = match Marked.unmark io.Desugared.Ast.io_input with | OnlyInput -> true @@ -397,34 +399,37 @@ let translate_def match params with | Some (ps, _) -> let labels, tys = List.split ps in - Expr.make_abs - (Array.of_list - (List.map (fun lbl -> Var.make (Marked.unmark lbl)) labels)) - empty_error tys (Expr.mark_pos m) - | _ -> empty_error + ( Expr.make_abs + (Array.of_list + (List.map (fun lbl -> Var.make (Marked.unmark lbl)) labels)) + empty_error tys (Expr.mark_pos m), + exc_graph ) + | _ -> empty_error, exc_graph else - rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx - (Desugared.Ast.ScopeDef.get_position def_info) - (Option.map - (fun (ps, _) -> - (List.map (fun (lbl, _) -> Var.make (Marked.unmark lbl))) ps) - params) - (match top_list, top_value with - | [], None -> - (* In this case, there are no rules to define the expression and no - default value so we put an empty rule. *) - Leaf [Desugared.Ast.empty_rule (Marked.get_mark typ) params] - | [], Some top_value -> - (* In this case, there are no rules to define the expression but a - default value so we put it. *) - Leaf [top_value] - | _, Some top_value -> - (* When there are rules + a default value, we put the rules as - exceptions to the default value *) - Node (top_list, [top_value]) - | [top_tree], None -> top_tree - | _, None -> - Node (top_list, [Desugared.Ast.empty_rule (Marked.get_mark typ) params])) + ( rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx + (Desugared.Ast.ScopeDef.get_position def_info) + (Option.map + (fun (ps, _) -> + (List.map (fun (lbl, _) -> Var.make (Marked.unmark lbl))) ps) + params) + (match top_list, top_value with + | [], None -> + (* In this case, there are no rules to define the expression and no + default value so we put an empty rule. *) + Leaf [Desugared.Ast.empty_rule (Marked.get_mark typ) params] + | [], Some top_value -> + (* In this case, there are no rules to define the expression but a + default value so we put it. *) + Leaf [top_value] + | _, Some top_value -> + (* When there are rules + a default value, we put the rules as + exceptions to the default value *) + Node (top_list, [top_value]) + | [top_tree], None -> top_tree + | _, None -> + Node + (top_list, [Desugared.Ast.empty_rule (Marked.get_mark typ) params])), + exc_graph ) let translate_rule ctx (scope : Desugared.Ast.scope) = function | Desugared.Dependency.Vertex.Var (var, state) -> ( @@ -449,10 +454,10 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function (RuleName.Map.bindings var_def)) "It is impossible to give a definition to a scope variable tagged as \ input." - | OnlyInput -> [] + | OnlyInput -> [], DesugaredVarName.Map.empty (* we do not provide any definition for an input-only variable *) | _ -> - let expr_def = + let expr_def, exc_graph = translate_def ctx (Desugared.Ast.ScopeDef.Var (var, state)) var_def var_params var_typ scope_def.Desugared.Ast.scope_def_io @@ -464,15 +469,18 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function | States states, Some state -> List.assoc state states | _ -> failwith "should not happen" in - [ - Ast.Definition - ( ( ScopelangScopeVar - (scope_var, Marked.get_mark (ScopeVar.get_info scope_var)), - Marked.get_mark (ScopeVar.get_info scope_var) ), - var_typ, - scope_def.Desugared.Ast.scope_def_io, - Expr.unbox expr_def ); - ]) + ( [ + Ast.Definition + ( ( ScopelangScopeVar + (scope_var, Marked.get_mark (ScopeVar.get_info scope_var)), + Marked.get_mark (ScopeVar.get_info scope_var) ), + var_typ, + scope_def.Desugared.Ast.scope_def_io, + Expr.unbox expr_def ); + ], + DesugaredVarName.Map.singleton + (DesugaredVarName.ScopeVar (var, state)) + exc_graph )) | Desugared.Dependency.Vertex.SubScope sub_scope_index -> (* Before calling the sub_scope, we need to include all the re-definitions of subscope parameters*) @@ -539,7 +547,7 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function | _ -> ()); (* Now that all is good, we can proceed with translating this redefinition to a proper Scopelang term. *) - let expr_def = + let expr_def, exc_graph = translate_def ctx def_key def scope_def.D.scope_def_parameters def_typ scope_def.Desugared.Ast.scope_def_io ~is_cond ~is_subscope_var:true @@ -548,40 +556,53 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in let var_pos = Desugared.Ast.ScopeDef.get_position def_key in - Ast.Definition - ( ( SubScopeVar - ( subscop_real_name, - (sub_scope_index, var_pos), - match - ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping - with - | WholeVar v -> v, var_pos - | States states -> - (* When defining a sub-scope variable, we always define - its first state in the sub-scope. *) - snd (List.hd states), var_pos ), - var_pos ), - def_typ, - scope_def.Desugared.Ast.scope_def_io, - Expr.unbox expr_def )) + ( Ast.Definition + ( ( SubScopeVar + ( subscop_real_name, + (sub_scope_index, var_pos), + match + ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping + with + | WholeVar v -> v, var_pos + | States states -> + (* When defining a sub-scope variable, we always + define its first state in the sub-scope. *) + snd (List.hd states), var_pos ), + var_pos ), + def_typ, + scope_def.Desugared.Ast.scope_def_io, + Expr.unbox expr_def ), + (exc_graph, sub_scope_var) )) sub_scope_vars_redefs_candidates in - let sub_scope_vars_redefs = + let sub_scope_vars_redefs_and_exc_graphs = List.map snd (Desugared.Ast.ScopeDefMap.bindings sub_scope_vars_redefs) in - sub_scope_vars_redefs - @ [ - Ast.Call - ( sub_scope, - sub_scope_index, - Untyped - { pos = Marked.get_mark (SubScopeName.get_info sub_scope_index) } - ); - ] + let sub_scope_vars_redefs = + List.map fst sub_scope_vars_redefs_and_exc_graphs + in + ( sub_scope_vars_redefs + @ [ + Ast.Call + ( sub_scope, + sub_scope_index, + Untyped + { + pos = Marked.get_mark (SubScopeName.get_info sub_scope_index); + } ); + ], + List.fold_left + (fun exc_graphs (new_exc_graph, subscope_var) -> + DesugaredVarName.Map.add + (DesugaredVarName.SubScopeVar (sub_scope_index, subscope_var)) + new_exc_graph exc_graphs) + DesugaredVarName.Map.empty + (List.map snd sub_scope_vars_redefs_and_exc_graphs) ) (** Translates a scope *) let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : - untyped Ast.scope_decl = + untyped Ast.scope_decl + * Desugared.Dependency.ExceptionsDependencies.t DesugaredVarName.Map.t = let scope_dependencies = Desugared.Dependency.build_scope_dependencies scope in @@ -589,8 +610,18 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : let scope_ordering = Desugared.Dependency.correct_computation_ordering scope_dependencies in - let scope_decl_rules = - List.flatten (List.map (translate_rule ctx scope) scope_ordering) + let scope_decl_rules, exceptions_graphs = + List.fold_left + (fun (scope_decl_rules, exceptions_graphs) scope_def_key -> + let new_rules, new_exceptions_graphs = + translate_rule ctx scope scope_def_key + in + ( scope_decl_rules @ new_rules, + DesugaredVarName.Map.union + (fun _ _ _ -> assert false (* there should not be key conflicts *)) + new_exceptions_graphs exceptions_graphs )) + ([], DesugaredVarName.Map.empty) + scope_ordering in (* Then, after having computed all the scopes variables, we add the assertions. TODO: the assertions should be interleaved with the @@ -641,17 +672,20 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : scope.scope_vars ScopeVar.Map.empty in let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in - { - Ast.scope_decl_name = scope.scope_uid; - Ast.scope_decl_rules; - Ast.scope_sig; - Ast.scope_mark = Untyped { pos }; - Ast.scope_options = scope.scope_options; - } + ( { + Ast.scope_decl_name = scope.scope_uid; + Ast.scope_decl_rules; + Ast.scope_sig; + Ast.scope_mark = Untyped { pos }; + Ast.scope_options = scope.scope_options; + }, + exceptions_graphs ) (** {1 API} *) -let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program = +let translate_program (pgrm : Desugared.Ast.program) : + untyped Ast.program + * Desugared.Dependency.ExceptionsDependencies.t DesugaredVarName.Map.t = (* First we give mappings to all the locations between Desugared and This involves creating a new Scopelang scope variable for every state of a Desugared variable. *) @@ -706,12 +740,25 @@ let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program = { out_str with out_struct_fields }) pgrm.Desugared.Ast.program_ctx.ctx_scopes in - { - Ast.program_topdefs = - TopdefName.Map.map - (fun (e, ty) -> Expr.unbox (translate_expr ctx e), ty) - pgrm.program_topdefs; - Ast.program_scopes = - ScopeName.Map.map (translate_scope ctx) pgrm.program_scopes; - program_ctx = { pgrm.program_ctx with ctx_scopes }; - } + let new_program_scopes, exceptions_graphs = + ScopeName.Map.fold + (fun scope_name scope (new_program_scopes, exceptions_graph) -> + let new_program_scope, new_exceptions_graphs = + translate_scope ctx scope + in + ( ScopeName.Map.add scope_name new_program_scope new_program_scopes, + DesugaredVarName.Map.union + (fun _ _ _ -> assert false (* key conflicts should not happen*)) + new_exceptions_graphs exceptions_graph )) + pgrm.program_scopes + (ScopeName.Map.empty, DesugaredVarName.Map.empty) + in + ( { + Ast.program_topdefs = + TopdefName.Map.map + (fun (e, ty) -> Expr.unbox (translate_expr ctx e), ty) + pgrm.program_topdefs; + Ast.program_scopes = new_program_scopes; + program_ctx = { pgrm.program_ctx with ctx_scopes }; + }, + exceptions_graphs ) diff --git a/compiler/scopelang/from_desugared.mli b/compiler/scopelang/from_desugared.mli index 8f2dae8c..742d6eae 100644 --- a/compiler/scopelang/from_desugared.mli +++ b/compiler/scopelang/from_desugared.mli @@ -16,4 +16,10 @@ (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) -val translate_program : Desugared.Ast.program -> Shared_ast.untyped Ast.program +val translate_program : + Desugared.Ast.program -> + Shared_ast.untyped Ast.program + * Desugared.Dependency.ExceptionsDependencies.t + Shared_ast.DesugaredVarName.Map.t +(** This functions returns the translated program as well as all the graphs of + exceptions inferred for each scope variable of the program. *) diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index a9a0d37a..115bfd5c 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -45,6 +45,58 @@ module SubScopeName = Uid.Gen () module StateName = Uid.Gen () (** {1 Abstract syntax tree} *) +module DesugaredVarName : sig + type t = + | ScopeVar of ScopeVar.t * StateName.t option + | SubScopeVar of SubScopeName.t * ScopeVar.t + + val hash : t -> int + val compare : t -> t -> int + val equal : t -> t -> bool + + module Map : Map.S with type key = t + module Set : Set.S with type elt = t +end = struct + module Ordering = struct + type t = + | ScopeVar of ScopeVar.t * StateName.t option + | SubScopeVar of SubScopeName.t * ScopeVar.t + + let hash x = + match x with + | ScopeVar (x, None) -> ScopeVar.hash x + | ScopeVar (x, Some sx) -> + Int.logxor (ScopeVar.hash x) (StateName.hash sx) + | SubScopeVar (x, y) -> Int.logxor (SubScopeName.hash x) (ScopeVar.hash y) + + let compare x y = + match x, y with + | ScopeVar (x, xst), ScopeVar (y, yst) -> ( + match ScopeVar.compare x y with + | 0 -> Option.compare StateName.compare xst yst + | n -> n) + | SubScopeVar (x, xv), SubScopeVar (y, yv) -> ( + match SubScopeName.compare x y with + | 0 -> ScopeVar.compare xv yv + | n -> n) + | ScopeVar _, _ -> -1 + | _, ScopeVar _ -> 1 + | SubScopeVar _, _ -> . + | _, SubScopeVar _ -> . + + let equal x y = + match x, y with + | ScopeVar (x, sx), ScopeVar (y, sy) -> + ScopeVar.equal x y && Option.equal StateName.equal sx sy + | SubScopeVar (x, xv), SubScopeVar (y, yv) -> + SubScopeName.equal x y && ScopeVar.equal xv yv + | (ScopeVar _ | SubScopeVar _), _ -> false + end + + include Ordering + module Map = Map.Make (Ordering) + module Set = Set.Make (Ordering) +end (** Define a common base type for the expressions in most passes of the compiler *) From 6479c3c10bcc2bf052ae9eac54958aff0001b92f Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Fri, 7 Apr 2023 16:35:09 +0200 Subject: [PATCH 05/10] Print exception tree --- compiler/desugared/dependency.ml | 45 +++++++++---- compiler/desugared/dependency.mli | 8 ++- compiler/desugared/print.ml | 95 +++++++++++++++++++++++++++- compiler/desugared/print.mli | 5 +- compiler/driver.ml | 44 ++++++------- compiler/scopelang/from_desugared.ml | 7 +- compiler/shared_ast/definitions.ml | 9 +++ 7 files changed, 174 insertions(+), 39 deletions(-) diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 97ab4935..461619e7 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -244,10 +244,14 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = (** {2 Graph declaration} *) module ExceptionVertex = struct - include RuleName.Set + type t = { rules : Pos.t RuleName.Map.t; label : LabelName.t } + + let compare x y = RuleName.Map.compare compare x.rules y.rules let hash (x : t) : int = - RuleName.Set.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0 + RuleName.Map.fold + (fun r _ acc -> Int.logxor (RuleName.hash r) acc) + x.rules 0 let equal x y = compare x y = 0 end @@ -353,9 +357,19 @@ let build_exceptions_graph in LabelName.Map.update label_of_rule (fun rule_set -> + let pos = + (* We have to overwrite the law info on tis position because the + pass at the surface AST level that fills the law info on + positions only does it for positions inside expressions, the + visitor in [surface/fill_positions.ml] does not go into the + info of [RuleName.t], etc.*) + Pos.overwrite_law_info + (snd (RuleName.get_info rule.rule_id)) + (Pos.get_law_info (Expr.pos rule.rule_just)) + in match rule_set with - | None -> Some (RuleName.Set.singleton rule_name) - | Some rule_set -> Some (RuleName.Set.add rule_name rule_set)) + | None -> Some (RuleName.Map.singleton rule_name pos) + | Some rule_set -> Some (RuleName.Map.add rule_name pos rule_set)) rule_sets) def LabelName.Map.empty in @@ -363,7 +377,7 @@ let build_exceptions_graph fst (LabelName.Map.choose (LabelName.Map.filter - (fun _ rule_set -> RuleName.Set.mem r rule_set) + (fun _ rule_set -> RuleName.Map.mem r rule_set) label_to_rule_sets)) in (* Next, we collect the exception edges between those groups of rules referred @@ -431,7 +445,8 @@ let build_exceptions_graph (* We've got the vertices and the edges, let's build the graph! *) let g = LabelName.Map.fold - (fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set) + (fun label rule_set g -> + ExceptionsDependencies.add_vertex g { rules = rule_set; label }) label_to_rule_sets ExceptionsDependencies.empty in (* then we add the edges *) @@ -439,10 +454,18 @@ let build_exceptions_graph List.fold_left (fun g edge -> let rule_group_from = - LabelName.Map.find edge.label_from label_to_rule_sets + { + ExceptionVertex.rules = + LabelName.Map.find edge.label_from label_to_rule_sets; + label = edge.label_from; + } in let rule_group_to = - LabelName.Map.find edge.label_to label_to_rule_sets + { + ExceptionVertex.rules = + LabelName.Map.find edge.label_to label_to_rule_sets; + label = edge.label_to; + } in let edge = ExceptionsDependencies.E.create rule_group_from edge.edge_positions @@ -464,14 +487,14 @@ let check_for_exception_cycle let scc = List.find (fun scc -> List.length scc > 1) sccs in let spans = List.rev_map - (fun (vs : RuleName.Set.t) -> - let v = RuleName.Set.choose vs in + (fun (vs : ExceptionVertex.t) -> + let v, _ = RuleName.Map.choose vs.rules in let rule = RuleName.Map.find v def in let pos = Marked.get_mark (RuleName.get_info rule.Ast.rule_id) in None, pos) scc in - let v = RuleName.Set.choose (List.hd scc) in + let v, _ = RuleName.Map.choose (List.hd scc).rules in Errors.raise_multispanned_error spans "Exception cycle detected when defining %a: each of these %d exceptions \ applies over the previous one, and the first applies over the last" diff --git a/compiler/desugared/dependency.mli b/compiler/desugared/dependency.mli index 588243dc..278e72b6 100644 --- a/compiler/desugared/dependency.mli +++ b/compiler/desugared/dependency.mli @@ -72,8 +72,14 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list +module ExceptionVertex : sig + type t = { rules : Pos.t RuleName.Map.t; label : LabelName.t } +end + module ExceptionsDependencies : - Graph.Sig.P with type V.t = RuleName.Set.t and type E.label = EdgeExceptions.t + Graph.Sig.P + with type V.t = ExceptionVertex.t + and type E.label = EdgeExceptions.t val build_exceptions_graph : Ast.rule RuleName.Map.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t diff --git a/compiler/desugared/print.ml b/compiler/desugared/print.ml index ab84ac20..86253fb1 100644 --- a/compiler/desugared/print.ml +++ b/compiler/desugared/print.ml @@ -15,8 +15,101 @@ the License. *) open Shared_ast +open Catala_utils + +type exception_tree = + | Leaf of Dependency.ExceptionVertex.t + | Node of exception_tree list * Dependency.ExceptionVertex.t + +open Format + +(* Credits for this printing code: Jean-Christophe Filiâtre, *) +let format_exception_tree (fmt : Format.formatter) (t : exception_tree) = + let blue s = + Format.asprintf "%a" (Cli.format_with_style [ANSITerminal.blue]) s + in + let rec print_node pref (t : exception_tree) = + let (s, w), sons = + let print_s s = + ( Format.asprintf "%a" + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" LabelName.format_t + s.Dependency.ExceptionVertex.label), + String.length + (Format.asprintf "\"%a\"" LabelName.format_t + s.Dependency.ExceptionVertex.label) ) + in + match t with Leaf s -> print_s s, [] | Node (sons, s) -> print_s s, sons + in + pp_print_string fmt s; + if sons != [] then + let pref' = pref ^ String.make (w + 1) ' ' in + match sons with + | [t'] -> + pp_print_string fmt (blue "───"); + print_node (pref' ^ " ") t' + | _ -> + pp_print_string fmt (blue "──"); + print_sons pref' "┬──" sons + and print_sons pref start = function + | [] -> assert false + | [s] -> + pp_print_string fmt (blue " └──"); + print_node (pref ^ " ") s + | s :: sons -> + pp_print_string fmt (blue start); + print_node (pref ^ "| ") s; + pp_force_newline fmt (); + pp_print_string fmt (blue pref); + print_sons pref "├──" sons + in + print_node "" t + +let build_exception_tree exc_graph = + let base_cases = + Dependency.ExceptionsDependencies.fold_vertex + (fun v base_cases -> + if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then + v :: base_cases + else base_cases) + exc_graph [] + in + let rec build_tree (base_cases : Dependency.ExceptionVertex.t) = + let exceptions = + Dependency.ExceptionsDependencies.pred exc_graph base_cases + in + match exceptions with + | [] -> Leaf base_cases + | _ -> Node (List.map build_tree exceptions, base_cases) + in + List.map build_tree base_cases let print_exceptions_graph + (scope : ScopeName.t) (var : DesugaredVarName.t) (g : Dependency.ExceptionsDependencies.t) = - assert false + Cli.result_format + "Printing the tree of exceptions for the definitions of variable %a of \ + scope %a." + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" DesugaredVarName.format var) + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" ScopeName.format_t scope); + Dependency.ExceptionsDependencies.iter_vertex + (fun ex -> + Cli.result_format "Group of definitions with label %a:\n%a" + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" LabelName.format_t + ex.Dependency.ExceptionVertex.label) + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n") + (fun fmt (_, pos) -> + Format.fprintf fmt "%s" (Pos.retrieve_loc_text pos))) + (RuleName.Map.bindings ex.Dependency.ExceptionVertex.rules)) + g; + let tree = build_exception_tree g in + Cli.result_format "The exception tree structure is as follows:\n\n%a" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n") + (fun fmt tree -> format_exception_tree fmt tree)) + tree diff --git a/compiler/desugared/print.mli b/compiler/desugared/print.mli index 739e2b07..2d9f3eaa 100644 --- a/compiler/desugared/print.mli +++ b/compiler/desugared/print.mli @@ -15,5 +15,8 @@ the License. *) val print_exceptions_graph : - Shared_ast.DesugaredVarName.t -> Dependency.ExceptionsDependencies.t -> unit + Shared_ast.ScopeName.t -> + Shared_ast.DesugaredVarName.t -> + Dependency.ExceptionsDependencies.t -> + unit (** Prints the exception graph of a variable to the terminal *) diff --git a/compiler/driver.ml b/compiler/driver.ml index b4c4540f..a071105a 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -165,9 +165,9 @@ let driver source_file (options : Cli.options) : int = match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid | _ -> - Errors.raise_error "There is no scope \"%a\" inside the program." + Errors.raise_error "There is no scope %a inside the program." (Cli.format_with_style [ANSITerminal.yellow]) - name) + ("\"" ^ name ^ "\"")) in (* This uid is a Desugared identifier *) let variable_uid = @@ -201,24 +201,25 @@ let driver source_file (options : Cli.options) : int = (Shared_ast.ScopeName.Map.find scope_uid ctxt.scopes).var_idmap with | None -> - Errors.raise_error "Variable \"%a\" not found inside scope \"%a\"" + Errors.raise_error "Variable %a not found inside scope %a" (Cli.format_with_style [ANSITerminal.yellow]) - name + ("\"" ^ name ^ "\"") (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "%a" Shared_ast.ScopeName.format_t scope_uid) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t scope_uid) | Some (Desugared.Name_resolution.SubScope (subscope_var_name, subscope_name)) -> ( match second_part with | None -> Errors.raise_error - "Subscope \"%a\" of scope \"%a\" cannot be selected by itself, \ - please add \".\" where is a subscope variable." + "Subscope %a of scope %a cannot be selected by itself, please \ + add \".\" where is a subscope variable." (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "%a" Shared_ast.SubScopeName.format_t + (Format.asprintf "\"%a\"" Shared_ast.SubScopeName.format_t subscope_var_name) (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "%a" Shared_ast.ScopeName.format_t scope_uid) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t + scope_uid) | Some second_part -> ( match Shared_ast.IdentName.Map.find_opt second_part @@ -230,16 +231,16 @@ let driver source_file (options : Cli.options) : int = (Shared_ast.DesugaredVarName.SubScopeVar (subscope_var_name, v)) | _ -> Errors.raise_error - "Var \"%a\" of subscope \"%a\" in scope \"%a\" does not \ - exist, please check your command line arguments." + "Var %a of subscope %a in scope %a does not exist, please \ + check your command line arguments." (Cli.format_with_style [ANSITerminal.yellow]) - second_part + ("\"" ^ second_part ^ "\"") (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "%a" Shared_ast.SubScopeName.format_t + (Format.asprintf "\"%a\"" Shared_ast.SubScopeName.format_t subscope_var_name) (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "%a" Shared_ast.ScopeName.format_t scope_uid) - )) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t + scope_uid))) | Some (Desugared.Name_resolution.ScopeVar v) -> Some (Shared_ast.DesugaredVarName.ScopeVar @@ -256,15 +257,14 @@ let driver source_file (options : Cli.options) : int = | Some state -> state | None -> Errors.raise_error - "State \"%a\" is not found for variable \"%a\" of \ - scope \"%a\"" + "State %a is not found for variable %a of scope %a" (Cli.format_with_style [ANSITerminal.yellow]) - second_part + ("\"" ^ second_part ^ "\"") (Cli.format_with_style [ANSITerminal.yellow]) - first_part + ("\"" ^ first_part ^ "\"") (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "%a" Shared_ast.ScopeName.format_t - scope_uid)) + (Format.asprintf "\"%a\"" + Shared_ast.ScopeName.format_t scope_uid)) second_part ))) in Cli.debug_print "Desugaring..."; @@ -286,7 +286,7 @@ let driver source_file (options : Cli.options) : int = Errors.raise_error "Please provide a scope variable to analyze with the -v option." in - Desugared.Print.print_exceptions_graph variable_uid + Desugared.Print.print_exceptions_graph scope_uid variable_uid (Shared_ast.DesugaredVarName.Map.find variable_uid exceptions_graphs) | `Scopelang -> let _output_file, with_output = get_output_format () in diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 5e8f48b3..50633d5f 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -195,14 +195,15 @@ let def_map_to_tree else base_cases) exc_graph [] in - let rec build_tree (base_cases : RuleName.Set.t) : rule_tree = + let rec build_tree (base_cases : Desugared.Dependency.ExceptionVertex.t) : + rule_tree = let exceptions = Desugared.Dependency.ExceptionsDependencies.pred exc_graph base_cases in let base_case_as_rule_list = List.map - (fun r -> RuleName.Map.find r def) - (RuleName.Set.elements base_cases) + (fun (r, _) -> RuleName.Map.find r def) + (RuleName.Map.bindings base_cases.rules) in match exceptions with | [] -> Leaf base_case_as_rule_list diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 115bfd5c..bc52cf8d 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -53,6 +53,7 @@ module DesugaredVarName : sig val hash : t -> int val compare : t -> t -> int val equal : t -> t -> bool + val format : Format.formatter -> t -> unit module Map : Map.S with type key = t module Set : Set.S with type elt = t @@ -91,6 +92,14 @@ end = struct | SubScopeVar (x, xv), SubScopeVar (y, yv) -> SubScopeName.equal x y && ScopeVar.equal xv yv | (ScopeVar _ | SubScopeVar _), _ -> false + + let format fmt x = + match x with + | ScopeVar (v, None) -> ScopeVar.format_t fmt v + | ScopeVar (v, Some st) -> + Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t st + | SubScopeVar (ss, v) -> + Format.fprintf fmt "%a.%a" SubScopeName.format_t ss ScopeVar.format_t v end include Ordering From ecccb5fb91a2e2b25cb93f89274fa2102a4937f6 Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Fri, 7 Apr 2023 16:45:45 +0200 Subject: [PATCH 06/10] Last changes --- compiler/desugared/print.ml | 4 +- .../good/groups_of_exceptions.catala_en | 41 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/compiler/desugared/print.ml b/compiler/desugared/print.ml index 86253fb1..c16cda56 100644 --- a/compiler/desugared/print.ml +++ b/compiler/desugared/print.ml @@ -23,7 +23,7 @@ type exception_tree = open Format -(* Credits for this printing code: Jean-Christophe Filiâtre, *) +(* Original credits for this printing code: Jean-Christophe Filiâtre *) let format_exception_tree (fmt : Format.formatter) (t : exception_tree) = let blue s = Format.asprintf "%a" (Cli.format_with_style [ANSITerminal.blue]) s @@ -60,6 +60,8 @@ let format_exception_tree (fmt : Format.formatter) (t : exception_tree) = pp_print_string fmt (blue start); print_node (pref ^ "| ") s; pp_force_newline fmt (); + pp_print_string fmt (blue (pref ^ " │")); + pp_force_newline fmt (); pp_print_string fmt (blue pref); print_sons pref "├──" sons in diff --git a/tests/test_exception/good/groups_of_exceptions.catala_en b/tests/test_exception/good/groups_of_exceptions.catala_en index 6571f177..ab38ac93 100644 --- a/tests/test_exception/good/groups_of_exceptions.catala_en +++ b/tests/test_exception/good/groups_of_exceptions.catala_en @@ -44,3 +44,44 @@ let scope Foo (y: integer|input) (x: integer|internal|output) = ⊢ ⟨ ⟨y = 2 ⊢ 2⟩, ⟨y = 3 ⊢ 3⟩ | false ⊢ ∅ ⟩ ⟩ | true ⊢ ⟨ ⟨y = 0 ⊢ 0⟩, ⟨y = 1 ⊢ 1⟩ | false ⊢ ∅ ⟩ ⟩ ``` + +```catala-test-inline +$ catala Exceptions -s Foo -v x +[RESULT] Printing the tree of exceptions for the definitions of variable "x" of scope "Foo". +[RESULT] Group of definitions with label "base": +┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:9.2-25: +└─┐ +9 │ label base definition x under condition + │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ + └─ Test +┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:13.2-25: +└──┐ +13 │ label base definition x under condition + │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ + └─ Test +[RESULT] Group of definitions with label "intermediate": +┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:17.2-48: +└──┐ +17 │ label intermediate exception base definition x under condition + │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ + └─ Test +┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:21.2-48: +└──┐ +21 │ label intermediate exception base definition x under condition + │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ + └─ Test +[RESULT] Group of definitions with label "exception_to_intermediate": +┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:25.2-37: +└──┐ +25 │ exception intermediate definition x under condition + │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ + └─ Test +┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:29.2-37: +└──┐ +29 │ exception intermediate definition x under condition + │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ + └─ Test +[RESULT] The exception tree structure is as follows: + +"base"───"intermediate"───"exception_to_intermediate" +``` From 39f1704d7618571c9bb83724e40c128ed7bd0b30 Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Fri, 7 Apr 2023 16:56:43 +0200 Subject: [PATCH 07/10] Last fixes --- compiler/desugared/print.ml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/desugared/print.ml b/compiler/desugared/print.ml index c16cda56..cda5f6b9 100644 --- a/compiler/desugared/print.ml +++ b/compiler/desugared/print.ml @@ -99,7 +99,7 @@ let print_exceptions_graph (Format.asprintf "\"%a\"" ScopeName.format_t scope); Dependency.ExceptionsDependencies.iter_vertex (fun ex -> - Cli.result_format "Group of definitions with label %a:\n%a" + Cli.result_format "Definitions with label %a:\n%a" (Cli.format_with_style [ANSITerminal.yellow]) (Format.asprintf "\"%a\"" LabelName.format_t ex.Dependency.ExceptionVertex.label) @@ -112,6 +112,6 @@ let print_exceptions_graph let tree = build_exception_tree g in Cli.result_format "The exception tree structure is as follows:\n\n%a" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n") + ~pp_sep:(fun fmt () -> Format.fprintf fmt "\n\n") (fun fmt tree -> format_exception_tree fmt tree)) tree From c5ba3e72fe8ce684432344eda05ba4daaff39216 Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Fri, 7 Apr 2023 17:10:02 +0200 Subject: [PATCH 08/10] Restore CI --- compiler/catala_web_interpreter.ml | 1 + examples/allocations_familiales/prologue.catala_fr | 3 ++- french_law/python/main.py | 2 +- tests/test_exception/good/groups_of_exceptions.catala_en | 6 +++--- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/compiler/catala_web_interpreter.ml b/compiler/catala_web_interpreter.ml index e565181f..a693b783 100644 --- a/compiler/catala_web_interpreter.ml +++ b/compiler/catala_web_interpreter.ml @@ -27,6 +27,7 @@ let _ = disable_counterexamples = false; optimize = false; ex_scope = Some (Js.to_string scope); + ex_variable = None; output_file = None; print_only_law = false; } diff --git a/examples/allocations_familiales/prologue.catala_fr b/examples/allocations_familiales/prologue.catala_fr index 41555e02..fc5a1481 100644 --- a/examples/allocations_familiales/prologue.catala_fr +++ b/examples/allocations_familiales/prologue.catala_fr @@ -85,7 +85,8 @@ déclaration champ d'application AllocationsFamiliales: interne enfants_à_charge_droit_ouvert_prestation_familiale contenu collection Enfant interne prise_en_compte contenu PriseEnCompte dépend de enfant contenu Enfant - résultat versement contenu VersementAllocations dépend de enfant contenu Enfant + résultat versement contenu VersementAllocations + dépend de enfant contenu Enfant résultat montant_versé contenu argent diff --git a/french_law/python/main.py b/french_law/python/main.py index 72d67132..1bf1d332 100755 --- a/french_law/python/main.py +++ b/french_law/python/main.py @@ -143,7 +143,7 @@ if __name__ == '__main__': print(timeit.timeit(benchmark_iteration_family, number=iterations)) elif action == "bench_housing": iterations = 1000 - print("Iterating {} iterations of the family benefits computation. Total time (s):".format( + print("Iterating {} iterations of the housing benefits computation. Total time (s):".format( iterations)) print(timeit.timeit(benchmark_iteration_housing, number=iterations)) elif action == "show_log": diff --git a/tests/test_exception/good/groups_of_exceptions.catala_en b/tests/test_exception/good/groups_of_exceptions.catala_en index ab38ac93..37174e41 100644 --- a/tests/test_exception/good/groups_of_exceptions.catala_en +++ b/tests/test_exception/good/groups_of_exceptions.catala_en @@ -48,7 +48,7 @@ let scope Foo (y: integer|input) (x: integer|internal|output) = ```catala-test-inline $ catala Exceptions -s Foo -v x [RESULT] Printing the tree of exceptions for the definitions of variable "x" of scope "Foo". -[RESULT] Group of definitions with label "base": +[RESULT] Definitions with label "base": ┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:9.2-25: └─┐ 9 │ label base definition x under condition @@ -59,7 +59,7 @@ $ catala Exceptions -s Foo -v x 13 │ label base definition x under condition │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ └─ Test -[RESULT] Group of definitions with label "intermediate": +[RESULT] Definitions with label "intermediate": ┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:17.2-48: └──┐ 17 │ label intermediate exception base definition x under condition @@ -70,7 +70,7 @@ $ catala Exceptions -s Foo -v x 21 │ label intermediate exception base definition x under condition │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ └─ Test -[RESULT] Group of definitions with label "exception_to_intermediate": +[RESULT] Definitions with label "exception_to_intermediate": ┌─⯈ tests/test_exception/good/groups_of_exceptions.catala_en:25.2-37: └──┐ 25 │ exception intermediate definition x under condition From 57da6225678e80bd74aba0384878eb586dbf03bf Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Tue, 18 Apr 2023 10:31:44 +0200 Subject: [PATCH 09/10] Refactoring changes after @altgr's suggestions --- compiler/desugared/ast.ml | 93 +++++----- compiler/desugared/ast.mli | 10 +- compiler/desugared/dependency.ml | 9 +- compiler/desugared/disambiguate.ml | 4 +- compiler/desugared/from_surface.ml | 18 +- compiler/desugared/linting.ml | 4 +- compiler/desugared/name_resolution.ml | 6 +- compiler/desugared/name_resolution.mli | 2 +- compiler/desugared/print.ml | 4 +- compiler/desugared/print.mli | 2 +- compiler/driver.ml | 247 +++++++++++++------------ compiler/scopelang/from_desugared.ml | 43 +++-- compiler/scopelang/from_desugared.mli | 3 +- compiler/shared_ast/definitions.ml | 61 ------ 14 files changed, 229 insertions(+), 277 deletions(-) diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 8c72e215..9adf4208 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -24,48 +24,53 @@ open Shared_ast (** Inside a scope, a definition can refer either to a scope def, or a subscope def *) module ScopeDef = struct - type t = - | Var of ScopeVar.t * StateName.t option - | SubScopeVar of SubScopeName.t * ScopeVar.t * Pos.t - (** In this case, the [ScopeVar.t] lives inside the context of the - subscope's original declaration *) + module Base = struct + type t = + | Var of ScopeVar.t * StateName.t option + | SubScopeVar of SubScopeName.t * ScopeVar.t * Pos.t + (** In this case, the [ScopeVar.t] lives inside the context of the + subscope's original declaration *) - let compare x y = - match x, y with - | Var (x, stx), Var (y, sty) -> ( - match ScopeVar.compare x y with - | 0 -> Option.compare StateName.compare stx sty - | n -> n) - | SubScopeVar (x', x, _), SubScopeVar (y', y, _) -> ( - match SubScopeName.compare x' y' with 0 -> ScopeVar.compare x y | n -> n) - | Var _, _ -> -1 - | _, Var _ -> 1 + let compare x y = + match x, y with + | Var (x, stx), Var (y, sty) -> ( + match ScopeVar.compare x y with + | 0 -> Option.compare StateName.compare stx sty + | n -> n) + | SubScopeVar (x', x, _), SubScopeVar (y', y, _) -> ( + match SubScopeName.compare x' y' with + | 0 -> ScopeVar.compare x y + | n -> n) + | Var _, _ -> -1 + | _, Var _ -> 1 - let get_position x = - match x with - | Var (x, None) -> Marked.get_mark (ScopeVar.get_info x) - | Var (_, Some sx) -> Marked.get_mark (StateName.get_info sx) - | SubScopeVar (_, _, pos) -> pos + let get_position x = + match x with + | Var (x, None) -> Marked.get_mark (ScopeVar.get_info x) + | Var (_, Some sx) -> Marked.get_mark (StateName.get_info sx) + | SubScopeVar (_, _, pos) -> pos - let format_t fmt x = - match x with - | Var (v, None) -> ScopeVar.format_t fmt v - | Var (v, Some sv) -> - Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t sv - | SubScopeVar (s, v, _) -> - Format.fprintf fmt "%a.%a" SubScopeName.format_t s ScopeVar.format_t v + let format_t fmt x = + match x with + | Var (v, None) -> ScopeVar.format_t fmt v + | Var (v, Some sv) -> + Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t sv + | SubScopeVar (s, v, _) -> + Format.fprintf fmt "%a.%a" SubScopeName.format_t s ScopeVar.format_t v - let hash x = - match x with - | Var (v, None) -> ScopeVar.hash v - | Var (v, Some sv) -> Int.logxor (ScopeVar.hash v) (StateName.hash sv) - | SubScopeVar (w, v, _) -> - Int.logxor (SubScopeName.hash w) (ScopeVar.hash v) + let hash x = + match x with + | Var (v, None) -> ScopeVar.hash v + | Var (v, Some sv) -> Int.logxor (ScopeVar.hash v) (StateName.hash sv) + | SubScopeVar (w, v, _) -> + Int.logxor (SubScopeName.hash w) (ScopeVar.hash v) + end + + include Base + module Map = Map.Make (Base) + module Set = Set.Make (Base) end -module ScopeDefMap : Map.S with type key = ScopeDef.t = Map.Make (ScopeDef) -module ScopeDefSet : Set.S with type elt = ScopeDef.t = Set.Make (ScopeDef) - (** {1 AST} *) type location = desugared glocation @@ -195,7 +200,7 @@ type scope = { scope_vars : var_or_states ScopeVar.Map.t; scope_sub_scopes : ScopeName.t SubScopeName.Map.t; scope_uid : ScopeName.t; - scope_defs : scope_def ScopeDefMap.t; + scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion list; scope_options : catala_option Marked.pos list; scope_meta_assertions : meta_assertion list; @@ -218,9 +223,9 @@ let rec locations_used e : LocationSet.t = (fun e -> LocationSet.union (locations_used e)) e LocationSet.empty -let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t = - let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) : - Pos.t ScopeDefMap.t = +let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDef.Map.t = + let add_locs (acc : Pos.t ScopeDef.Map.t) (locs : LocationSet.t) : + Pos.t ScopeDef.Map.t = LocationSet.fold (fun (loc, loc_pos) acc -> let usage = @@ -235,7 +240,9 @@ let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t = Marked.get_mark sub_index )) | ToplevelVar _ -> None in - match usage with Some u -> ScopeDefMap.add u loc_pos acc | None -> acc) + match usage with + | Some u -> ScopeDef.Map.add u loc_pos acc + | None -> acc) locs acc in RuleName.Map.fold @@ -246,14 +253,14 @@ let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t = (locations_used (Expr.unbox rule.rule_cons)) in add_locs acc locs) - def ScopeDefMap.empty + def ScopeDef.Map.empty let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a = let acc = ScopeName.Map.fold (fun _ scope acc -> let acc = - ScopeDefMap.fold + ScopeDef.Map.fold (fun _ scope_def acc -> RuleName.Map.fold (fun _ rule acc -> diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index 0357a6d4..8e8987e9 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -30,10 +30,10 @@ module ScopeDef : sig val get_position : t -> Pos.t val format_t : Format.formatter -> t -> unit val hash : t -> int -end -module ScopeDefMap : Map.S with type key = ScopeDef.t -module ScopeDefSet : Set.S with type elt = ScopeDef.t + module Map : Map.S with type key = t + module Set : Set.S with type elt = t +end (** {1 AST} *) @@ -118,7 +118,7 @@ type scope = { scope_vars : var_or_states ScopeVar.Map.t; scope_sub_scopes : ScopeName.t SubScopeName.Map.t; scope_uid : ScopeName.t; - scope_defs : scope_def ScopeDefMap.t; + scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion list; scope_options : catala_option Marked.pos list; scope_meta_assertions : meta_assertion list; @@ -133,7 +133,7 @@ type program = { (** {1 Helpers} *) val locations_used : expr -> LocationSet.t -val free_variables : rule RuleName.Map.t -> Pos.t ScopeDefMap.t +val free_variables : rule RuleName.Map.t -> Pos.t ScopeDef.Map.t val fold_exprs : f:('a -> expr -> 'a) -> init:'a -> program -> 'a (** Usage: [fold_exprs ~f ~init program] applies ~f to all the expressions diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 461619e7..742e8897 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -173,11 +173,11 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = scope.scope_sub_scopes g in let g = - Ast.ScopeDefMap.fold + Ast.ScopeDef.Map.fold (fun def_key scope_def g -> let def = scope_def.Ast.scope_def_rules in let fv = Ast.free_variables def in - Ast.ScopeDefMap.fold + Ast.ScopeDef.Map.fold (fun fv_def fv_def_pos g -> match def_key, fv_def with | ( Ast.ScopeDef.Var (v_defined, s_defined), @@ -246,7 +246,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = module ExceptionVertex = struct type t = { rules : Pos.t RuleName.Map.t; label : LabelName.t } - let compare x y = RuleName.Map.compare compare x.rules y.rules + let compare x y = + RuleName.Map.compare + (fun _ _ -> 0 (* we don't care about positions here*)) + x.rules y.rules let hash (x : t) : int = RuleName.Map.fold diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index a1ce7fc8..3bb7534b 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -45,7 +45,7 @@ let rule ctx env rule = let scope ctx env scope = let env = Typing.Env.open_scope scope.scope_uid env in let scope_defs = - ScopeDefMap.map + ScopeDef.Map.map (fun def -> let scope_def_rules = (* Note: ordering in file order might be better for error reporting ? @@ -75,7 +75,7 @@ let program prg = ScopeName.Map.fold (fun scope_name scope env -> let vars = - ScopeDefMap.fold + ScopeDef.Map.fold (fun var def vars -> match var with | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 6449aa49..e6604626 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -1019,7 +1019,7 @@ let process_def (Marked.get_mark def.definition_name) in let scope_def_ctxt = - Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts + Ast.ScopeDef.Map.find def_key scope_ctxt.scope_defs_contexts in (* We add to the name resolution context the name of the parameter variable *) let new_ctxt, param_uids = @@ -1028,7 +1028,7 @@ let process_def def in let scope_updated = - let scope_def = Ast.ScopeDefMap.find def_key scope.scope_defs in + let scope_def = Ast.ScopeDef.Map.find def_key scope.scope_defs in let rule_name = def.definition_id in let label_situation = match def.definition_label with @@ -1075,7 +1075,7 @@ let process_def in { scope with - scope_defs = Ast.ScopeDefMap.add def_key scope_def scope.scope_defs; + scope_defs = Ast.ScopeDef.Map.add def_key scope_def scope.scope_defs; } in { @@ -1204,7 +1204,7 @@ let check_unlabeled_exception (* should not happen *) in let scope_def_ctxt = - Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts + Ast.ScopeDef.Map.find def_key scope_ctxt.scope_defs_contexts in match exception_to with | Surface.Ast.NotAnException | Surface.Ast.ExceptionToLabel _ -> () @@ -1296,7 +1296,7 @@ let attribute_to_io (attr : Surface.Ast.scope_decl_context_io) : Ast.io = let init_scope_defs (ctxt : Name_resolution.context) (scope_idmap : Name_resolution.scope_var_or_subscope IdentName.Map.t) : - Ast.scope_def Ast.ScopeDefMap.t = + Ast.scope_def Ast.ScopeDef.Map.t = (* Initializing the definitions of all scopes and subscope vars, with no rules yet inside *) let add_def _ v scope_def_map = @@ -1306,7 +1306,7 @@ let init_scope_defs match v_sig.var_sig_states_list with | [] -> let def_key = Ast.ScopeDef.Var (v, None) in - Ast.ScopeDefMap.add def_key + Ast.ScopeDef.Map.add def_key { Ast.scope_def_rules = RuleName.Map.empty; Ast.scope_def_typ = v_sig.var_sig_typ; @@ -1344,7 +1344,7 @@ let init_scope_defs { io_input; io_output }); } in - Ast.ScopeDefMap.add def_key def acc, i + 1) + Ast.ScopeDef.Map.add def_key def acc, i + 1) (scope_def_map, 0) states in scope_def) @@ -1364,7 +1364,7 @@ let init_scope_defs Ast.ScopeDef.SubScopeVar (v0, v, Marked.get_mark (ScopeVar.get_info v)) in - Ast.ScopeDefMap.add def_key + Ast.ScopeDef.Map.add def_key { Ast.scope_def_rules = RuleName.Map.empty; Ast.scope_def_typ = v_sig.var_sig_typ; @@ -1375,7 +1375,7 @@ let init_scope_defs scope_def_map) sub_scope_def.Name_resolution.var_idmap scope_def_map in - IdentName.Map.fold add_def scope_idmap Ast.ScopeDefMap.empty + IdentName.Map.fold add_def scope_idmap Ast.ScopeDef.Map.empty (** Main function of this module *) let translate_program diff --git a/compiler/desugared/linting.ml b/compiler/desugared/linting.ml index 176f2b41..599751b0 100644 --- a/compiler/desugared/linting.ml +++ b/compiler/desugared/linting.ml @@ -22,7 +22,7 @@ open Catala_utils let detect_empty_definitions (p : program) : unit = ScopeName.Map.iter (fun (scope_name : ScopeName.t) scope -> - ScopeDefMap.iter + ScopeDef.Map.iter (fun scope_def_key scope_def -> if (match scope_def_key with ScopeDef.Var _ -> true | _ -> false) @@ -59,7 +59,7 @@ let detect_unused_scope_vars (p : program) : unit = in ScopeName.Map.iter (fun (scope_name : ScopeName.t) scope -> - ScopeDefMap.iter + ScopeDef.Map.iter (fun scope_def_key scope_def -> match scope_def_key with | ScopeDef.Var (v, _) diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index af5a0235..245c9805 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -39,7 +39,7 @@ type scope_var_or_subscope = type scope_context = { var_idmap : scope_var_or_subscope IdentName.Map.t; (** All variables, including scope variables and subscopes *) - scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t; + scope_defs_contexts : scope_def_context Ast.ScopeDef.Map.t; (** What is the default rule to refer to for unnamed exceptions, if any *) sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) @@ -624,7 +624,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Marked.pos) ScopeName.Map.add scope_uid { var_idmap = IdentName.Map.empty; - scope_defs_contexts = Ast.ScopeDefMap.empty; + scope_defs_contexts = Ast.ScopeDef.Map.empty; sub_scopes = ScopeName.Set.empty; } ctxt.scopes; @@ -853,7 +853,7 @@ let process_definition { s_ctxt with scope_defs_contexts = - Ast.ScopeDefMap.update def_key + Ast.ScopeDef.Map.update def_key (fun def_key_ctx -> Some (update_def_key_ctx d diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index 7a3245d3..eab8c1a9 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -39,7 +39,7 @@ type scope_var_or_subscope = type scope_context = { var_idmap : scope_var_or_subscope IdentName.Map.t; (** All variables, including scope variables and subscopes *) - scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t; + scope_defs_contexts : scope_def_context Ast.ScopeDef.Map.t; (** What is the default rule to refer to for unnamed exceptions, if any *) sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) diff --git a/compiler/desugared/print.ml b/compiler/desugared/print.ml index cda5f6b9..06a88cf8 100644 --- a/compiler/desugared/print.ml +++ b/compiler/desugared/print.ml @@ -88,13 +88,13 @@ let build_exception_tree exc_graph = let print_exceptions_graph (scope : ScopeName.t) - (var : DesugaredVarName.t) + (var : Ast.ScopeDef.t) (g : Dependency.ExceptionsDependencies.t) = Cli.result_format "Printing the tree of exceptions for the definitions of variable %a of \ scope %a." (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" DesugaredVarName.format var) + (Format.asprintf "\"%a\"" Ast.ScopeDef.format_t var) (Cli.format_with_style [ANSITerminal.yellow]) (Format.asprintf "\"%a\"" ScopeName.format_t scope); Dependency.ExceptionsDependencies.iter_vertex diff --git a/compiler/desugared/print.mli b/compiler/desugared/print.mli index 2d9f3eaa..ee23dae0 100644 --- a/compiler/desugared/print.mli +++ b/compiler/desugared/print.mli @@ -16,7 +16,7 @@ val print_exceptions_graph : Shared_ast.ScopeName.t -> - Shared_ast.DesugaredVarName.t -> + Ast.ScopeDef.t -> Dependency.ExceptionsDependencies.t -> unit (** Prints the exception graph of a variable to the terminal *) diff --git a/compiler/driver.ml b/compiler/driver.ml index a071105a..080d2dc1 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -21,6 +21,127 @@ open Catala_utils string representation. *) let extensions = [".catala_fr", "fr"; ".catala_en", "en"; ".catala_pl", "pl"] +let get_scope_uid + (options : Cli.options) + (backend : Plugin.t Cli.backend_option) + (ctxt : Desugared.Name_resolution.context) = + match options.ex_scope, backend with + | None, `Interpret -> + Errors.raise_error "No scope was provided for execution." + | None, _ -> + let _, scope = + try + Shared_ast.IdentName.Map.filter_map + (fun _ -> function + | Desugared.Name_resolution.TScope (uid, _) -> Some uid + | _ -> None) + ctxt.typedefs + |> Shared_ast.IdentName.Map.choose + with Not_found -> + Errors.raise_error "There isn't any scope inside the program." + in + scope + | Some name, _ -> ( + match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with + | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid + | _ -> + Errors.raise_error "There is no scope %a inside the program." + (Cli.format_with_style [ANSITerminal.yellow]) + ("\"" ^ name ^ "\"")) + +let get_variable_uid + (options : Cli.options) + (backend : Plugin.t Cli.backend_option) + (ctxt : Desugared.Name_resolution.context) + (scope_uid : Shared_ast.ScopeName.t) = + match options.ex_variable, backend with + | None, `Exceptions -> + Errors.raise_error + "Please specify a variable with the -v option to print its exception \ + tree." + | None, _ -> None + | Some name, _ -> ( + (* Sometimes the variable selected is of the form [a.b]*) + let first_part, second_part = + match + Re.( + exec_opt + (compile + @@ whole_string + @@ seq [group (rep1 (compl [char '.'])); char '.'; group (rep1 any)] + ) + name) + with + | None -> name, None + | Some groups -> Re.Group.get groups 1, Some (Re.Group.get groups 2) + in + match + Shared_ast.IdentName.Map.find_opt first_part + (Shared_ast.ScopeName.Map.find scope_uid ctxt.scopes).var_idmap + with + | None -> + Errors.raise_error "Variable %a not found inside scope %a" + (Cli.format_with_style [ANSITerminal.yellow]) + ("\"" ^ name ^ "\"") + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t scope_uid) + | Some + (Desugared.Name_resolution.SubScope (subscope_var_name, subscope_name)) + -> ( + match second_part with + | None -> + Errors.raise_error + "Subscope %a of scope %a cannot be selected by itself, please add \ + \".\" where is a subscope variable." + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" Shared_ast.SubScopeName.format_t + subscope_var_name) + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t scope_uid) + | Some second_part -> ( + match + Shared_ast.IdentName.Map.find_opt second_part + (Shared_ast.ScopeName.Map.find subscope_name ctxt.scopes).var_idmap + with + | Some (Desugared.Name_resolution.ScopeVar v) -> + Some + (Desugared.Ast.ScopeDef.SubScopeVar + (subscope_var_name, v, Pos.no_pos)) + | _ -> + Errors.raise_error + "Var %a of subscope %a in scope %a does not exist, please check \ + your command line arguments." + (Cli.format_with_style [ANSITerminal.yellow]) + ("\"" ^ second_part ^ "\"") + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" Shared_ast.SubScopeName.format_t + subscope_var_name) + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t scope_uid))) + | Some (Desugared.Name_resolution.ScopeVar v) -> + Some + (Desugared.Ast.ScopeDef.Var + ( v, + Option.map + (fun second_part -> + let var_sig = Shared_ast.ScopeVar.Map.find v ctxt.var_typs in + match + Shared_ast.IdentName.Map.find_opt second_part + var_sig.var_sig_states_idmap + with + | Some state -> state + | None -> + Errors.raise_error + "State %a is not found for variable %a of scope %a" + (Cli.format_with_style [ANSITerminal.yellow]) + ("\"" ^ second_part ^ "\"") + (Cli.format_with_style [ANSITerminal.yellow]) + ("\"" ^ first_part ^ "\"") + (Cli.format_with_style [ANSITerminal.yellow]) + (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t + scope_uid)) + second_part ))) + (** Entry function for the executable. Returns a negative number in case of error. Usage: [driver source_file options]*) let driver source_file (options : Cli.options) : int = @@ -144,129 +265,9 @@ let driver source_file (options : Cli.options) : int = backend -> ( Cli.debug_print "Name resolution..."; let ctxt = Desugared.Name_resolution.form_context prgm in - let scope_uid = - match options.ex_scope, backend with - | None, `Interpret -> - Errors.raise_error "No scope was provided for execution." - | None, _ -> - let _, scope = - try - Shared_ast.IdentName.Map.filter_map - (fun _ -> function - | Desugared.Name_resolution.TScope (uid, _) -> Some uid - | _ -> None) - ctxt.typedefs - |> Shared_ast.IdentName.Map.choose - with Not_found -> - Errors.raise_error "There isn't any scope inside the program." - in - scope - | Some name, _ -> ( - match Shared_ast.IdentName.Map.find_opt name ctxt.typedefs with - | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid - | _ -> - Errors.raise_error "There is no scope %a inside the program." - (Cli.format_with_style [ANSITerminal.yellow]) - ("\"" ^ name ^ "\"")) - in + let scope_uid = get_scope_uid options backend ctxt in (* This uid is a Desugared identifier *) - let variable_uid = - match options.ex_variable, backend with - | None, `Exceptions -> - Errors.raise_error - "Please specify a variable with the -v option to print its \ - exception tree." - | None, _ -> None - | Some name, _ -> ( - (* Sometimes the variable selected is of the form [a.b]*) - let first_part, second_part = - match - Re.( - exec_opt - (compile - @@ whole_string - @@ seq - [ - group (rep1 (compl [char '.'])); - char '.'; - group (rep1 any); - ]) - name) - with - | None -> name, None - | Some groups -> Re.Group.get groups 1, Some (Re.Group.get groups 2) - in - match - Shared_ast.IdentName.Map.find_opt first_part - (Shared_ast.ScopeName.Map.find scope_uid ctxt.scopes).var_idmap - with - | None -> - Errors.raise_error "Variable %a not found inside scope %a" - (Cli.format_with_style [ANSITerminal.yellow]) - ("\"" ^ name ^ "\"") - (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t scope_uid) - | Some - (Desugared.Name_resolution.SubScope - (subscope_var_name, subscope_name)) -> ( - match second_part with - | None -> - Errors.raise_error - "Subscope %a of scope %a cannot be selected by itself, please \ - add \".\" where is a subscope variable." - (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" Shared_ast.SubScopeName.format_t - subscope_var_name) - (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t - scope_uid) - | Some second_part -> ( - match - Shared_ast.IdentName.Map.find_opt second_part - (Shared_ast.ScopeName.Map.find subscope_name ctxt.scopes) - .var_idmap - with - | Some (Desugared.Name_resolution.ScopeVar v) -> - Some - (Shared_ast.DesugaredVarName.SubScopeVar (subscope_var_name, v)) - | _ -> - Errors.raise_error - "Var %a of subscope %a in scope %a does not exist, please \ - check your command line arguments." - (Cli.format_with_style [ANSITerminal.yellow]) - ("\"" ^ second_part ^ "\"") - (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" Shared_ast.SubScopeName.format_t - subscope_var_name) - (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" Shared_ast.ScopeName.format_t - scope_uid))) - | Some (Desugared.Name_resolution.ScopeVar v) -> - Some - (Shared_ast.DesugaredVarName.ScopeVar - ( v, - Option.map - (fun second_part -> - let var_sig = - Shared_ast.ScopeVar.Map.find v ctxt.var_typs - in - match - Shared_ast.IdentName.Map.find_opt second_part - var_sig.var_sig_states_idmap - with - | Some state -> state - | None -> - Errors.raise_error - "State %a is not found for variable %a of scope %a" - (Cli.format_with_style [ANSITerminal.yellow]) - ("\"" ^ second_part ^ "\"") - (Cli.format_with_style [ANSITerminal.yellow]) - ("\"" ^ first_part ^ "\"") - (Cli.format_with_style [ANSITerminal.yellow]) - (Format.asprintf "\"%a\"" - Shared_ast.ScopeName.format_t scope_uid)) - second_part ))) - in + let variable_uid = get_variable_uid options backend ctxt scope_uid in Cli.debug_print "Desugaring..."; let prgm = Desugared.From_surface.translate_program ctxt prgm in Cli.debug_print "Disambiguating..."; @@ -287,7 +288,7 @@ let driver source_file (options : Cli.options) : int = "Please provide a scope variable to analyze with the -v option." in Desugared.Print.print_exceptions_graph scope_uid variable_uid - (Shared_ast.DesugaredVarName.Map.find variable_uid exceptions_graphs) + (Desugared.Ast.ScopeDef.Map.find variable_uid exceptions_graphs) | `Scopelang -> let _output_file, with_output = get_output_format () in with_output diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 50633d5f..e23853d9 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -435,7 +435,7 @@ let translate_def let translate_rule ctx (scope : Desugared.Ast.scope) = function | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = - Desugared.Ast.ScopeDefMap.find + Desugared.Ast.ScopeDef.Map.find (Desugared.Ast.ScopeDef.Var (var, state)) scope.scope_defs in @@ -455,7 +455,7 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function (RuleName.Map.bindings var_def)) "It is impossible to give a definition to a scope variable tagged as \ input." - | OnlyInput -> [], DesugaredVarName.Map.empty + | OnlyInput -> [], Desugared.Ast.ScopeDef.Map.empty (* we do not provide any definition for an input-only variable *) | _ -> let expr_def, exc_graph = @@ -479,8 +479,8 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function scope_def.Desugared.Ast.scope_def_io, Expr.unbox expr_def ); ], - DesugaredVarName.Map.singleton - (DesugaredVarName.ScopeVar (var, state)) + Desugared.Ast.ScopeDef.Map.singleton + (Desugared.Ast.ScopeDef.Var (var, state)) exc_graph )) | Desugared.Dependency.Vertex.SubScope sub_scope_index -> (* Before calling the sub_scope, we need to include all the re-definitions @@ -489,7 +489,7 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in let sub_scope_vars_redefs_candidates = - Desugared.Ast.ScopeDefMap.filter + Desugared.Ast.ScopeDef.Map.filter (fun def_key scope_def -> match def_key with | Desugared.Ast.ScopeDef.Var _ -> false @@ -507,7 +507,7 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function scope.scope_defs in let sub_scope_vars_redefs = - Desugared.Ast.ScopeDefMap.mapi + Desugared.Ast.ScopeDef.Map.mapi (fun def_key scope_def -> let def = scope_def.Desugared.Ast.scope_def_rules in let def_typ = scope_def.scope_def_typ in @@ -573,11 +573,11 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function def_typ, scope_def.Desugared.Ast.scope_def_io, Expr.unbox expr_def ), - (exc_graph, sub_scope_var) )) + (exc_graph, sub_scope_var, var_pos) )) sub_scope_vars_redefs_candidates in let sub_scope_vars_redefs_and_exc_graphs = - List.map snd (Desugared.Ast.ScopeDefMap.bindings sub_scope_vars_redefs) + List.map snd (Desugared.Ast.ScopeDef.Map.bindings sub_scope_vars_redefs) in let sub_scope_vars_redefs = List.map fst sub_scope_vars_redefs_and_exc_graphs @@ -593,17 +593,19 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function } ); ], List.fold_left - (fun exc_graphs (new_exc_graph, subscope_var) -> - DesugaredVarName.Map.add - (DesugaredVarName.SubScopeVar (sub_scope_index, subscope_var)) + (fun exc_graphs (new_exc_graph, subscope_var, var_pos) -> + Desugared.Ast.ScopeDef.Map.add + (Desugared.Ast.ScopeDef.SubScopeVar + (sub_scope_index, subscope_var, var_pos)) new_exc_graph exc_graphs) - DesugaredVarName.Map.empty + Desugared.Ast.ScopeDef.Map.empty (List.map snd sub_scope_vars_redefs_and_exc_graphs) ) (** Translates a scope *) let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : untyped Ast.scope_decl - * Desugared.Dependency.ExceptionsDependencies.t DesugaredVarName.Map.t = + * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t + = let scope_dependencies = Desugared.Dependency.build_scope_dependencies scope in @@ -618,10 +620,10 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : translate_rule ctx scope scope_def_key in ( scope_decl_rules @ new_rules, - DesugaredVarName.Map.union + Desugared.Ast.ScopeDef.Map.union (fun _ _ _ -> assert false (* there should not be key conflicts *)) new_exceptions_graphs exceptions_graphs )) - ([], DesugaredVarName.Map.empty) + ([], Desugared.Ast.ScopeDef.Map.empty) scope_ordering in (* Then, after having computed all the scopes variables, we add the @@ -641,7 +643,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : match states with | WholeVar -> let scope_def = - Desugared.Ast.ScopeDefMap.find + Desugared.Ast.ScopeDef.Map.find (Desugared.Ast.ScopeDef.Var (var, None)) scope.scope_defs in @@ -659,7 +661,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : List.fold_left (fun acc (state : StateName.t) -> let scope_def = - Desugared.Ast.ScopeDefMap.find + Desugared.Ast.ScopeDef.Map.find (Desugared.Ast.ScopeDef.Var (var, Some state)) scope.scope_defs in @@ -686,7 +688,8 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program - * Desugared.Dependency.ExceptionsDependencies.t DesugaredVarName.Map.t = + * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t + = (* First we give mappings to all the locations between Desugared and This involves creating a new Scopelang scope variable for every state of a Desugared variable. *) @@ -748,11 +751,11 @@ let translate_program (pgrm : Desugared.Ast.program) : translate_scope ctx scope in ( ScopeName.Map.add scope_name new_program_scope new_program_scopes, - DesugaredVarName.Map.union + Desugared.Ast.ScopeDef.Map.union (fun _ _ _ -> assert false (* key conflicts should not happen*)) new_exceptions_graphs exceptions_graph )) pgrm.program_scopes - (ScopeName.Map.empty, DesugaredVarName.Map.empty) + (ScopeName.Map.empty, Desugared.Ast.ScopeDef.Map.empty) in ( { Ast.program_topdefs = diff --git a/compiler/scopelang/from_desugared.mli b/compiler/scopelang/from_desugared.mli index 742d6eae..445e0d80 100644 --- a/compiler/scopelang/from_desugared.mli +++ b/compiler/scopelang/from_desugared.mli @@ -19,7 +19,6 @@ val translate_program : Desugared.Ast.program -> Shared_ast.untyped Ast.program - * Desugared.Dependency.ExceptionsDependencies.t - Shared_ast.DesugaredVarName.Map.t + * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t (** This functions returns the translated program as well as all the graphs of exceptions inferred for each scope variable of the program. *) diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index bc52cf8d..a9a0d37a 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -45,67 +45,6 @@ module SubScopeName = Uid.Gen () module StateName = Uid.Gen () (** {1 Abstract syntax tree} *) -module DesugaredVarName : sig - type t = - | ScopeVar of ScopeVar.t * StateName.t option - | SubScopeVar of SubScopeName.t * ScopeVar.t - - val hash : t -> int - val compare : t -> t -> int - val equal : t -> t -> bool - val format : Format.formatter -> t -> unit - - module Map : Map.S with type key = t - module Set : Set.S with type elt = t -end = struct - module Ordering = struct - type t = - | ScopeVar of ScopeVar.t * StateName.t option - | SubScopeVar of SubScopeName.t * ScopeVar.t - - let hash x = - match x with - | ScopeVar (x, None) -> ScopeVar.hash x - | ScopeVar (x, Some sx) -> - Int.logxor (ScopeVar.hash x) (StateName.hash sx) - | SubScopeVar (x, y) -> Int.logxor (SubScopeName.hash x) (ScopeVar.hash y) - - let compare x y = - match x, y with - | ScopeVar (x, xst), ScopeVar (y, yst) -> ( - match ScopeVar.compare x y with - | 0 -> Option.compare StateName.compare xst yst - | n -> n) - | SubScopeVar (x, xv), SubScopeVar (y, yv) -> ( - match SubScopeName.compare x y with - | 0 -> ScopeVar.compare xv yv - | n -> n) - | ScopeVar _, _ -> -1 - | _, ScopeVar _ -> 1 - | SubScopeVar _, _ -> . - | _, SubScopeVar _ -> . - - let equal x y = - match x, y with - | ScopeVar (x, sx), ScopeVar (y, sy) -> - ScopeVar.equal x y && Option.equal StateName.equal sx sy - | SubScopeVar (x, xv), SubScopeVar (y, yv) -> - SubScopeName.equal x y && ScopeVar.equal xv yv - | (ScopeVar _ | SubScopeVar _), _ -> false - - let format fmt x = - match x with - | ScopeVar (v, None) -> ScopeVar.format_t fmt v - | ScopeVar (v, Some st) -> - Format.fprintf fmt "%a.%a" ScopeVar.format_t v StateName.format_t st - | SubScopeVar (ss, v) -> - Format.fprintf fmt "%a.%a" SubScopeName.format_t ss ScopeVar.format_t v - end - - include Ordering - module Map = Map.Make (Ordering) - module Set = Set.Make (Ordering) -end (** Define a common base type for the expressions in most passes of the compiler *) From 0266252854c8b8cf188de21b3cab12ec9735c54b Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Tue, 18 Apr 2023 11:06:58 +0200 Subject: [PATCH 10/10] Refactoring for cleaner exception graph building --- compiler/desugared/dependency.ml | 3 +- compiler/driver.ml | 7 +- compiler/scopelang/from_desugared.ml | 442 +++++++++++++++----------- compiler/scopelang/from_desugared.mli | 8 +- 4 files changed, 277 insertions(+), 183 deletions(-) diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 742e8897..541afef1 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -365,7 +365,8 @@ let build_exceptions_graph pass at the surface AST level that fills the law info on positions only does it for positions inside expressions, the visitor in [surface/fill_positions.ml] does not go into the - info of [RuleName.t], etc.*) + info of [RuleName.t], etc. Related issue: + https://github.com/CatalaLang/catala/issues/194 *) Pos.overwrite_law_info (snd (RuleName.get_info rule.rule_id)) (Pos.get_law_info (Expr.pos rule.rule_just)) diff --git a/compiler/driver.ml b/compiler/driver.ml index 080d2dc1..215ceb7b 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -275,8 +275,11 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Linting..."; Desugared.Linting.lint_program prgm; Cli.debug_print "Collecting rules..."; - let prgm, exceptions_graphs = - Scopelang.From_desugared.translate_program prgm + let exceptions_graphs = + Scopelang.From_desugared.build_exceptions_graph prgm + in + let prgm = + Scopelang.From_desugared.translate_program prgm exceptions_graphs in match backend with | `Exceptions -> diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index e23853d9..c69f1b43 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -178,12 +178,149 @@ type rule_tree = (** Transforms a flat list of rules into a tree, taking into account the priorities declared between rules *) -let def_map_to_tree +let def_to_exception_graph (def_info : Desugared.Ast.ScopeDef.t) (def : Desugared.Ast.rule RuleName.Map.t) : - rule_tree list * Desugared.Dependency.ExceptionsDependencies.t = + Desugared.Dependency.ExceptionsDependencies.t = let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in Desugared.Dependency.check_for_exception_cycle def exc_graph; + exc_graph + +let rule_to_exception_graph (scope : Desugared.Ast.scope) = function + | Desugared.Dependency.Vertex.Var (var, state) -> ( + let scope_def = + Desugared.Ast.ScopeDef.Map.find + (Desugared.Ast.ScopeDef.Var (var, state)) + scope.scope_defs + in + let var_def = scope_def.D.scope_def_rules in + match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with + | OnlyInput when not (RuleName.Map.is_empty var_def) -> + (* If the variable is tagged as input, then it shall not be redefined. *) + Errors.raise_multispanned_error + ((Some "Incriminated variable:", Marked.get_mark (ScopeVar.get_info var)) + :: List.map + (fun (rule, _) -> + ( Some "Incriminated variable definition:", + Marked.get_mark (RuleName.get_info rule) )) + (RuleName.Map.bindings var_def)) + "It is impossible to give a definition to a scope variable tagged as \ + input." + | OnlyInput -> Desugared.Ast.ScopeDef.Map.empty + (* we do not provide any definition for an input-only variable *) + | _ -> + Desugared.Ast.ScopeDef.Map.singleton + (Desugared.Ast.ScopeDef.Var (var, state)) + (def_to_exception_graph + (Desugared.Ast.ScopeDef.Var (var, state)) + var_def)) + | Desugared.Dependency.Vertex.SubScope sub_scope_index -> + (* Before calling the sub_scope, we need to include all the re-definitions + of subscope parameters*) + let sub_scope_vars_redefs_candidates = + Desugared.Ast.ScopeDef.Map.filter + (fun def_key scope_def -> + match def_key with + | Desugared.Ast.ScopeDef.Var _ -> false + | Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) -> + sub_scope_index = sub_scope_index' + (* We exclude subscope variables that have 0 re-definitions and are + not visible in the input of the subscope *) + && not + ((match + Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input + with + | Desugared.Ast.NoInput -> true + | _ -> false) + && RuleName.Map.is_empty scope_def.scope_def_rules)) + scope.scope_defs + in + let sub_scope_vars_redefs = + Desugared.Ast.ScopeDef.Map.mapi + (fun def_key scope_def -> + let def = scope_def.Desugared.Ast.scope_def_rules in + let is_cond = scope_def.scope_def_is_condition in + match def_key with + | Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *) + | Desugared.Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) -> + (* This definition redefines a variable of the correct subscope. But + we have to check that this redefinition is allowed with respect + to the io parameters of that subscope variable. *) + (match + Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input + with + | Desugared.Ast.NoInput -> + Errors.raise_multispanned_error + (( Some "Incriminated subscope:", + Marked.get_mark (SubScopeName.get_info sscope) ) + :: ( Some "Incriminated variable:", + Marked.get_mark (ScopeVar.get_info sub_scope_var) ) + :: List.map + (fun (rule, _) -> + ( Some "Incriminated subscope variable definition:", + Marked.get_mark (RuleName.get_info rule) )) + (RuleName.Map.bindings def)) + "It is impossible to give a definition to a subscope variable \ + not tagged as input or context." + | OnlyInput when RuleName.Map.is_empty def && not is_cond -> + (* If the subscope variable is tagged as input, then it shall be + defined. *) + Errors.raise_multispanned_error + [ + ( Some "Incriminated subscope:", + Marked.get_mark (SubScopeName.get_info sscope) ); + Some "Incriminated variable:", pos; + ] + "This subscope variable is a mandatory input but no definition \ + was provided." + | _ -> ()); + let exc_graph = def_to_exception_graph def_key def in + let var_pos = Desugared.Ast.ScopeDef.get_position def_key in + exc_graph, sub_scope_var, var_pos) + sub_scope_vars_redefs_candidates + in + List.fold_left + (fun exc_graphs (new_exc_graph, subscope_var, var_pos) -> + Desugared.Ast.ScopeDef.Map.add + (Desugared.Ast.ScopeDef.SubScopeVar + (sub_scope_index, subscope_var, var_pos)) + new_exc_graph exc_graphs) + Desugared.Ast.ScopeDef.Map.empty + (List.map snd (Desugared.Ast.ScopeDef.Map.bindings sub_scope_vars_redefs)) + +let scope_to_exception_graphs (scope : Desugared.Ast.scope) : + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t = + let scope_dependencies = + Desugared.Dependency.build_scope_dependencies scope + in + Desugared.Dependency.check_for_cycle scope scope_dependencies; + let scope_ordering = + Desugared.Dependency.correct_computation_ordering scope_dependencies + in + List.fold_left + (fun exceptions_graphs scope_def_key -> + let new_exceptions_graphs = rule_to_exception_graph scope scope_def_key in + Desugared.Ast.ScopeDef.Map.union + (fun _ _ _ -> assert false (* there should not be key conflicts *)) + new_exceptions_graphs exceptions_graphs) + Desugared.Ast.ScopeDef.Map.empty scope_ordering + +let build_exceptions_graph (pgrm : Desugared.Ast.program) : + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t = + ScopeName.Map.fold + (fun _ scope exceptions_graph -> + let new_exceptions_graphs = scope_to_exception_graphs scope in + Desugared.Ast.ScopeDef.Map.union + (fun _ _ _ -> assert false (* key conflicts should not happen*)) + new_exceptions_graphs exceptions_graph) + pgrm.program_scopes Desugared.Ast.ScopeDef.Map.empty + +(** Transforms a flat list of rules into a tree, taking into account the + priorities declared between rules *) +let def_map_to_tree + (def : Desugared.Ast.rule RuleName.Map.t) + (exc_graph : Desugared.Dependency.ExceptionsDependencies.t) : rule_tree list + = (* we start by the base cases: they are the vertices which have no successors *) let base_cases = @@ -209,7 +346,7 @@ let def_map_to_tree | [] -> Leaf base_case_as_rule_list | _ -> Node (List.map build_tree exceptions, base_case_as_rule_list) in - List.map build_tree base_cases, exc_graph + List.map build_tree base_cases (** From the {!type: rule_tree}, builds an {!constructor: Dcalc.EDefault} expression in the scope language. The [~toplevel] parameter is used to know @@ -337,17 +474,18 @@ let rec rule_tree_to_expr (** Translates a definition inside a scope, the resulting expression should be an {!constructor: Dcalc.EDefault} *) let translate_def + ~(is_cond : bool) + ~(is_subscope_var : bool) (ctx : ctx) (def_info : Desugared.Ast.ScopeDef.t) (def : Desugared.Ast.rule RuleName.Map.t) (params : (Uid.MarkedString.info * typ) list Marked.pos option) (typ : typ) (io : Desugared.Ast.io) - ~(is_cond : bool) - ~(is_subscope_var : bool) : - untyped Ast.expr boxed * Desugared.Dependency.ExceptionsDependencies.t = + (exc_graph : Desugared.Dependency.ExceptionsDependencies.t) : + untyped Ast.expr boxed = (* Here, we have to transform this list of rules into a default tree. *) - let top_list, exc_graph = def_map_to_tree def_info def in + let top_list = def_map_to_tree def exc_graph in let is_input = match Marked.unmark io.Desugared.Ast.io_input with | OnlyInput -> true @@ -400,39 +538,41 @@ let translate_def match params with | Some (ps, _) -> let labels, tys = List.split ps in - ( Expr.make_abs - (Array.of_list - (List.map (fun lbl -> Var.make (Marked.unmark lbl)) labels)) - empty_error tys (Expr.mark_pos m), - exc_graph ) - | _ -> empty_error, exc_graph + Expr.make_abs + (Array.of_list + (List.map (fun lbl -> Var.make (Marked.unmark lbl)) labels)) + empty_error tys (Expr.mark_pos m) + | _ -> empty_error else - ( rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx - (Desugared.Ast.ScopeDef.get_position def_info) - (Option.map - (fun (ps, _) -> - (List.map (fun (lbl, _) -> Var.make (Marked.unmark lbl))) ps) - params) - (match top_list, top_value with - | [], None -> - (* In this case, there are no rules to define the expression and no - default value so we put an empty rule. *) - Leaf [Desugared.Ast.empty_rule (Marked.get_mark typ) params] - | [], Some top_value -> - (* In this case, there are no rules to define the expression but a - default value so we put it. *) - Leaf [top_value] - | _, Some top_value -> - (* When there are rules + a default value, we put the rules as - exceptions to the default value *) - Node (top_list, [top_value]) - | [top_tree], None -> top_tree - | _, None -> - Node - (top_list, [Desugared.Ast.empty_rule (Marked.get_mark typ) params])), - exc_graph ) + rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx + (Desugared.Ast.ScopeDef.get_position def_info) + (Option.map + (fun (ps, _) -> + (List.map (fun (lbl, _) -> Var.make (Marked.unmark lbl))) ps) + params) + (match top_list, top_value with + | [], None -> + (* In this case, there are no rules to define the expression and no + default value so we put an empty rule. *) + Leaf [Desugared.Ast.empty_rule (Marked.get_mark typ) params] + | [], Some top_value -> + (* In this case, there are no rules to define the expression but a + default value so we put it. *) + Leaf [top_value] + | _, Some top_value -> + (* When there are rules + a default value, we put the rules as + exceptions to the default value *) + Node (top_list, [top_value]) + | [top_tree], None -> top_tree + | _, None -> + Node (top_list, [Desugared.Ast.empty_rule (Marked.get_mark typ) params])) -let translate_rule ctx (scope : Desugared.Ast.scope) = function +let translate_rule + ctx + (scope : Desugared.Ast.scope) + (exc_graphs : + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t) + = function | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = Desugared.Ast.ScopeDef.Map.find @@ -445,23 +585,15 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function let is_cond = scope_def.D.scope_def_is_condition in match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with | OnlyInput when not (RuleName.Map.is_empty var_def) -> - (* If the variable is tagged as input, then it shall not be redefined. *) - Errors.raise_multispanned_error - ((Some "Incriminated variable:", Marked.get_mark (ScopeVar.get_info var)) - :: List.map - (fun (rule, _) -> - ( Some "Incriminated variable definition:", - Marked.get_mark (RuleName.get_info rule) )) - (RuleName.Map.bindings var_def)) - "It is impossible to give a definition to a scope variable tagged as \ - input." - | OnlyInput -> [], Desugared.Ast.ScopeDef.Map.empty + assert false (* error already raised *) + | OnlyInput -> [] (* we do not provide any definition for an input-only variable *) | _ -> - let expr_def, exc_graph = - translate_def ctx - (Desugared.Ast.ScopeDef.Var (var, state)) - var_def var_params var_typ scope_def.Desugared.Ast.scope_def_io + let scope_def_key = Desugared.Ast.ScopeDef.Var (var, state) in + let expr_def = + translate_def ctx scope_def_key var_def var_params var_typ + scope_def.Desugared.Ast.scope_def_io + (Desugared.Ast.ScopeDef.Map.find scope_def_key exc_graphs) ~is_cond ~is_subscope_var:false in let scope_var = @@ -470,18 +602,15 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function | States states, Some state -> List.assoc state states | _ -> failwith "should not happen" in - ( [ - Ast.Definition - ( ( ScopelangScopeVar - (scope_var, Marked.get_mark (ScopeVar.get_info scope_var)), - Marked.get_mark (ScopeVar.get_info scope_var) ), - var_typ, - scope_def.Desugared.Ast.scope_def_io, - Expr.unbox expr_def ); - ], - Desugared.Ast.ScopeDef.Map.singleton - (Desugared.Ast.ScopeDef.Var (var, state)) - exc_graph )) + [ + Ast.Definition + ( ( ScopelangScopeVar + (scope_var, Marked.get_mark (ScopeVar.get_info scope_var)), + Marked.get_mark (ScopeVar.get_info scope_var) ), + var_typ, + scope_def.Desugared.Ast.scope_def_io, + Expr.unbox expr_def ); + ]) | Desugared.Dependency.Vertex.SubScope sub_scope_index -> (* Before calling the sub_scope, we need to include all the re-definitions of subscope parameters*) @@ -514,98 +643,66 @@ let translate_rule ctx (scope : Desugared.Ast.scope) = function let is_cond = scope_def.scope_def_is_condition in match def_key with | Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *) - | Desugared.Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) -> + | Desugared.Ast.ScopeDef.SubScopeVar (_, sub_scope_var, var_pos) -> (* This definition redefines a variable of the correct subscope. But we have to check that this redefinition is allowed with respect to the io parameters of that subscope variable. *) (match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with - | Desugared.Ast.NoInput -> - Errors.raise_multispanned_error - (( Some "Incriminated subscope:", - Marked.get_mark (SubScopeName.get_info sscope) ) - :: ( Some "Incriminated variable:", - Marked.get_mark (ScopeVar.get_info sub_scope_var) ) - :: List.map - (fun (rule, _) -> - ( Some "Incriminated subscope variable definition:", - Marked.get_mark (RuleName.get_info rule) )) - (RuleName.Map.bindings def)) - "It is impossible to give a definition to a subscope variable \ - not tagged as input or context." + | Desugared.Ast.NoInput -> assert false (* error already raised *) | OnlyInput when RuleName.Map.is_empty def && not is_cond -> - (* If the subscope variable is tagged as input, then it shall be - defined. *) - Errors.raise_multispanned_error - [ - ( Some "Incriminated subscope:", - Marked.get_mark (SubScopeName.get_info sscope) ); - Some "Incriminated variable:", pos; - ] - "This subscope variable is a mandatory input but no definition \ - was provided." + assert false (* error already raised *) | _ -> ()); (* Now that all is good, we can proceed with translating this redefinition to a proper Scopelang term. *) - let expr_def, exc_graph = + let expr_def = translate_def ctx def_key def scope_def.D.scope_def_parameters - def_typ scope_def.Desugared.Ast.scope_def_io ~is_cond - ~is_subscope_var:true + def_typ scope_def.Desugared.Ast.scope_def_io + (Desugared.Ast.ScopeDef.Map.find def_key exc_graphs) + ~is_cond ~is_subscope_var:true in let subscop_real_name = SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in - let var_pos = Desugared.Ast.ScopeDef.get_position def_key in - ( Ast.Definition - ( ( SubScopeVar - ( subscop_real_name, - (sub_scope_index, var_pos), - match - ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping - with - | WholeVar v -> v, var_pos - | States states -> - (* When defining a sub-scope variable, we always - define its first state in the sub-scope. *) - snd (List.hd states), var_pos ), - var_pos ), - def_typ, - scope_def.Desugared.Ast.scope_def_io, - Expr.unbox expr_def ), - (exc_graph, sub_scope_var, var_pos) )) + Ast.Definition + ( ( SubScopeVar + ( subscop_real_name, + (sub_scope_index, var_pos), + match + ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping + with + | WholeVar v -> v, var_pos + | States states -> + (* When defining a sub-scope variable, we always define + its first state in the sub-scope. *) + snd (List.hd states), var_pos ), + var_pos ), + def_typ, + scope_def.Desugared.Ast.scope_def_io, + Expr.unbox expr_def )) sub_scope_vars_redefs_candidates in - let sub_scope_vars_redefs_and_exc_graphs = + let sub_scope_vars_redefs = List.map snd (Desugared.Ast.ScopeDef.Map.bindings sub_scope_vars_redefs) in - let sub_scope_vars_redefs = - List.map fst sub_scope_vars_redefs_and_exc_graphs - in - ( sub_scope_vars_redefs - @ [ - Ast.Call - ( sub_scope, - sub_scope_index, - Untyped - { - pos = Marked.get_mark (SubScopeName.get_info sub_scope_index); - } ); - ], - List.fold_left - (fun exc_graphs (new_exc_graph, subscope_var, var_pos) -> - Desugared.Ast.ScopeDef.Map.add - (Desugared.Ast.ScopeDef.SubScopeVar - (sub_scope_index, subscope_var, var_pos)) - new_exc_graph exc_graphs) - Desugared.Ast.ScopeDef.Map.empty - (List.map snd sub_scope_vars_redefs_and_exc_graphs) ) + sub_scope_vars_redefs + @ [ + Ast.Call + ( sub_scope, + sub_scope_index, + Untyped + { pos = Marked.get_mark (SubScopeName.get_info sub_scope_index) } + ); + ] (** Translates a scope *) -let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : - untyped Ast.scope_decl - * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t - = +let translate_scope + (ctx : ctx) + (scope : Desugared.Ast.scope) + (exc_graphs : + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t) + : untyped Ast.scope_decl = let scope_dependencies = Desugared.Dependency.build_scope_dependencies scope in @@ -613,18 +710,12 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : let scope_ordering = Desugared.Dependency.correct_computation_ordering scope_dependencies in - let scope_decl_rules, exceptions_graphs = + let scope_decl_rules = List.fold_left - (fun (scope_decl_rules, exceptions_graphs) scope_def_key -> - let new_rules, new_exceptions_graphs = - translate_rule ctx scope scope_def_key - in - ( scope_decl_rules @ new_rules, - Desugared.Ast.ScopeDef.Map.union - (fun _ _ _ -> assert false (* there should not be key conflicts *)) - new_exceptions_graphs exceptions_graphs )) - ([], Desugared.Ast.ScopeDef.Map.empty) - scope_ordering + (fun scope_decl_rules scope_def_key -> + let new_rules = translate_rule ctx scope exc_graphs scope_def_key in + scope_decl_rules @ new_rules) + [] scope_ordering in (* Then, after having computed all the scopes variables, we add the assertions. TODO: the assertions should be interleaved with the @@ -675,21 +766,21 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : scope.scope_vars ScopeVar.Map.empty in let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in - ( { - Ast.scope_decl_name = scope.scope_uid; - Ast.scope_decl_rules; - Ast.scope_sig; - Ast.scope_mark = Untyped { pos }; - Ast.scope_options = scope.scope_options; - }, - exceptions_graphs ) + { + Ast.scope_decl_name = scope.scope_uid; + Ast.scope_decl_rules; + Ast.scope_sig; + Ast.scope_mark = Untyped { pos }; + Ast.scope_options = scope.scope_options; + } (** {1 API} *) -let translate_program (pgrm : Desugared.Ast.program) : - untyped Ast.program - * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t - = +let translate_program + (pgrm : Desugared.Ast.program) + (exc_graphs : + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t) + : untyped Ast.program = (* First we give mappings to all the locations between Desugared and This involves creating a new Scopelang scope variable for every state of a Desugared variable. *) @@ -744,25 +835,18 @@ let translate_program (pgrm : Desugared.Ast.program) : { out_str with out_struct_fields }) pgrm.Desugared.Ast.program_ctx.ctx_scopes in - let new_program_scopes, exceptions_graphs = + let new_program_scopes = ScopeName.Map.fold - (fun scope_name scope (new_program_scopes, exceptions_graph) -> - let new_program_scope, new_exceptions_graphs = - translate_scope ctx scope - in - ( ScopeName.Map.add scope_name new_program_scope new_program_scopes, - Desugared.Ast.ScopeDef.Map.union - (fun _ _ _ -> assert false (* key conflicts should not happen*)) - new_exceptions_graphs exceptions_graph )) - pgrm.program_scopes - (ScopeName.Map.empty, Desugared.Ast.ScopeDef.Map.empty) + (fun scope_name scope new_program_scopes -> + let new_program_scope = translate_scope ctx scope exc_graphs in + ScopeName.Map.add scope_name new_program_scope new_program_scopes) + pgrm.program_scopes ScopeName.Map.empty in - ( { - Ast.program_topdefs = - TopdefName.Map.map - (fun (e, ty) -> Expr.unbox (translate_expr ctx e), ty) - pgrm.program_topdefs; - Ast.program_scopes = new_program_scopes; - program_ctx = { pgrm.program_ctx with ctx_scopes }; - }, - exceptions_graphs ) + { + Ast.program_topdefs = + TopdefName.Map.map + (fun (e, ty) -> Expr.unbox (translate_expr ctx e), ty) + pgrm.program_topdefs; + Ast.program_scopes = new_program_scopes; + program_ctx = { pgrm.program_ctx with ctx_scopes }; + } diff --git a/compiler/scopelang/from_desugared.mli b/compiler/scopelang/from_desugared.mli index 445e0d80..601f11a7 100644 --- a/compiler/scopelang/from_desugared.mli +++ b/compiler/scopelang/from_desugared.mli @@ -16,9 +16,15 @@ (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) +val build_exceptions_graph : + Desugared.Ast.program -> + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t +(** This function builds all the exceptions dependency graphs for all variables + of all scopes. *) + val translate_program : Desugared.Ast.program -> + Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t -> Shared_ast.untyped Ast.program - * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t (** This functions returns the translated program as well as all the graphs of exceptions inferred for each scope variable of the program. *)