diff --git a/compiler/dcalc/dune b/compiler/dcalc/dune index 9b105b24..c0875ed2 100644 --- a/compiler/dcalc/dune +++ b/compiler/dcalc/dune @@ -1,7 +1,15 @@ (library (name dcalc) (public_name catala.dcalc) - (libraries bindlib unionFind utils re ubase catala.runtime_ocaml shared_ast) + (libraries + bindlib + unionFind + utils + re + ubase + catala.runtime_ocaml + shared_ast + scopelang) (preprocess (pps visitors.ppx))) diff --git a/compiler/scopelang/scope_to_dcalc.ml b/compiler/dcalc/from_scopelang.ml similarity index 93% rename from compiler/scopelang/scope_to_dcalc.ml rename to compiler/dcalc/from_scopelang.ml index b2184694..7ddffc5c 100644 --- a/compiler/scopelang/scope_to_dcalc.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -20,19 +20,18 @@ open Shared_ast type scope_var_ctx = { scope_var_name : ScopeVar.t; scope_var_typ : naked_typ; - scope_var_io : Ast.io; + scope_var_io : Desugared.Ast.io; } type 'm scope_sig_ctx = { scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) - scope_sig_scope_var : 'm Dcalc.Ast.expr Var.t; - (** Var representing the scope *) - scope_sig_input_var : 'm Dcalc.Ast.expr Var.t; + scope_sig_scope_var : 'm Ast.expr Var.t; (** Var representing the scope *) + scope_sig_input_var : 'm Ast.expr Var.t; (** Var representing the scope input inside the scope func *) scope_sig_input_struct : StructName.t; (** Scope input *) scope_sig_output_struct : StructName.t; (** Scope output *) scope_sig_in_fields : - (StructFieldName.t * Ast.io_input Marked.pos) ScopeVarMap.t; + (StructFieldName.t * Desugared.Ast.io_input Marked.pos) ScopeVarMap.t; (** Mapping between the input scope variables and the input struct fields. *) scope_sig_out_fields : StructFieldName.t ScopeVarMap.t; (** Mapping between the output scope variables and the output struct @@ -47,10 +46,11 @@ type 'm ctx = { enums : enum_ctx; scope_name : ScopeName.t; scopes_parameters : 'm scope_sigs_ctx; - scope_vars : ('m Dcalc.Ast.expr Var.t * naked_typ * Ast.io) ScopeVarMap.t; + scope_vars : ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVarMap.t; subscope_vars : - ('m Dcalc.Ast.expr Var.t * naked_typ * Ast.io) ScopeVarMap.t SubScopeMap.t; - local_vars : ('m Ast.expr, 'm Dcalc.Ast.expr Var.t) Var.Map.t; + ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVarMap.t + SubScopeMap.t; + local_vars : ('m Scopelang.Ast.expr, 'm Ast.expr Var.t) Var.Map.t; } let empty_ctx @@ -99,9 +99,9 @@ let merge_defaults caller callee = body let tag_with_log_entry - (e : 'm Dcalc.Ast.expr boxed) + (e : 'm Ast.expr boxed) (l : log_entry) - (markings : Utils.Uid.MarkedString.info list) : 'm Dcalc.Ast.expr boxed = + (markings : Utils.Uid.MarkedString.info list) : 'm Ast.expr boxed = let m = mark_tany (Marked.get_mark e) (Expr.pos e) in Expr.eapp (Expr.eop (Unop (Log (l, markings))) m) [e] m @@ -112,10 +112,10 @@ let tag_with_log_entry NOTE: the choice of the exception that will be triggered and show in the trace is arbitrary (but deterministic). *) -let collapse_similar_outcomes (type m) (excepts : m Ast.expr list) : - m Ast.expr list = +let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) : + m Scopelang.Ast.expr list = let module ExprMap = Map.Make (struct - type t = m Ast.expr + type t = m Scopelang.Ast.expr let compare = Expr.compare end) in @@ -156,12 +156,13 @@ let thunk_scope_arg io_in e = let silent_var = Var.make "_" in let pos = Marked.get_mark io_in in match Marked.unmark io_in with - | Ast.NoInput -> invalid_arg "thunk_scope_arg" - | Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e) - | Ast.Reentrant -> Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos + | Desugared.Ast.NoInput -> invalid_arg "thunk_scope_arg" + | Desugared.Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e) + | Desugared.Ast.Reentrant -> + Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos -let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) : - 'm Dcalc.Ast.expr boxed = +let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : + 'm Ast.expr boxed = let m = Marked.get_mark e in match Marked.unmark e with | EVar v -> Expr.evar (Var.Map.find v ctx.local_vars) m @@ -215,7 +216,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) : (fun var_name str_field expr -> let expr = match str_field, expr with - | Some (_, (Ast.Reentrant, _)), None -> + | Some (_, (Desugared.Ast.Reentrant, _)), None -> Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos))) | _ -> expr in @@ -372,10 +373,10 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) : come later in the chain of let-bindings. *) let translate_rule (ctx : 'm ctx) - (rule : 'm Ast.rule) + (rule : 'm Scopelang.Ast.rule) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) : - ('m Dcalc.Ast.expr scope_body_expr Bindlib.box -> - 'm Dcalc.Ast.expr scope_body_expr Bindlib.box) + ('m Ast.expr scope_body_expr Bindlib.box -> + 'm Ast.expr scope_body_expr Bindlib.box) * 'm ctx = match rule with | Definition ((ScopelangScopeVar a, var_def_pos), tau, a_io, e) -> @@ -436,7 +437,9 @@ let translate_rule (VarDef (Marked.unmark tau)) [sigma_name, pos_sigma; a_name] in - let thunked_or_nonempty_new_e = thunk_scope_arg a_io.Ast.io_input new_e in + let thunked_or_nonempty_new_e = + thunk_scope_arg a_io.Desugared.Ast.io_input new_e + in ( (fun next -> Bindlib.box_apply2 (fun next thunked_or_nonempty_new_e -> @@ -478,14 +481,15 @@ let translate_rule let all_subscope_input_vars = List.filter (fun var_ctx -> - match Marked.unmark var_ctx.scope_var_io.Ast.io_input with + match Marked.unmark var_ctx.scope_var_io.Desugared.Ast.io_input with | NoInput -> false | _ -> true) all_subscope_vars in let all_subscope_output_vars = List.filter - (fun var_ctx -> Marked.unmark var_ctx.scope_var_io.Ast.io_output) + (fun var_ctx -> + Marked.unmark var_ctx.scope_var_io.Desugared.Ast.io_output) all_subscope_vars in let scope_dcalc_var = subscope_sig.scope_sig_scope_var in @@ -639,11 +643,11 @@ let translate_rule let translate_rules (ctx : 'm ctx) - (rules : 'm Ast.rule list) + (rules : 'm Scopelang.Ast.rule list) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) (mark : 'm mark) (scope_sig : 'm scope_sig_ctx) : - 'm Dcalc.Ast.expr scope_body_expr Bindlib.box * 'm ctx = + 'm Ast.expr scope_body_expr Bindlib.box * 'm ctx = let scope_lets, new_ctx = List.fold_left (fun (scope_lets, ctx) rule -> @@ -658,7 +662,7 @@ let translate_rules Expr.estruct scope_sig.scope_sig_output_struct (ScopeVarMap.fold (fun var (dcalc_var, _, io) acc -> - if Marked.unmark io.Ast.io_output then + if Marked.unmark io.Desugared.Ast.io_output then let field = ScopeVarMap.find var scope_sig.scope_sig_out_fields in StructFieldMap.add field (Expr.make_var dcalc_var (mark_tany mark pos_sigma)) @@ -678,8 +682,8 @@ let translate_scope_decl (enum_ctx : enum_ctx) (sctx : 'm scope_sigs_ctx) (scope_name : ScopeName.t) - (sigma : 'm Ast.scope_decl) : - 'm Dcalc.Ast.expr scope_body Bindlib.box * struct_ctx = + (sigma : 'm Scopelang.Ast.scope_decl) : + 'm Ast.expr scope_body Bindlib.box * struct_ctx = let sigma_info = ScopeName.get_info sigma.scope_decl_name in let scope_sig = ScopeMap.find sigma.scope_decl_name sctx in let scope_variables = scope_sig.scope_sig_local_vars in @@ -786,17 +790,20 @@ let translate_scope_decl (input_destructurings rules_with_return_expr)), new_struct_ctx ) -let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program = - let scope_dependencies = Dependency.build_program_dep_graph prgm in - Dependency.check_for_cycle_in_scope scope_dependencies; - let scope_ordering = Dependency.get_scope_ordering scope_dependencies in +let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = + let scope_dependencies = Scopelang.Dependency.build_program_dep_graph prgm in + Scopelang.Dependency.check_for_cycle_in_scope scope_dependencies; + let scope_ordering = + Scopelang.Dependency.get_scope_ordering scope_dependencies + in let decl_ctx = prgm.program_ctx in let sctx : 'm scope_sigs_ctx = ScopeMap.mapi (fun scope_name scope -> let scope_dvar = Var.make - (Marked.unmark (ScopeName.get_info scope.Ast.scope_decl_name)) + (Marked.unmark + (ScopeName.get_info scope.Scopelang.Ast.scope_decl_name)) in let scope_return = ScopeMap.find scope_name decl_ctx.ctx_scopes in let scope_input_var = @@ -811,14 +818,14 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program = let scope_sig_in_fields = ScopeVarMap.filter_map (fun dvar (_, vis) -> - match Marked.unmark vis.Ast.io_input with + match Marked.unmark vis.Desugared.Ast.io_input with | NoInput -> None | OnlyInput | Reentrant -> let info = ScopeVar.get_info dvar in let s = Marked.unmark info ^ "_in" in Some ( StructFieldName.fresh (s, Marked.get_mark info), - vis.Ast.io_input )) + vis.Desugared.Ast.io_input )) scope.scope_sig in { diff --git a/compiler/scopelang/scope_to_dcalc.mli b/compiler/dcalc/from_scopelang.mli similarity index 92% rename from compiler/scopelang/scope_to_dcalc.mli rename to compiler/dcalc/from_scopelang.mli index af8eb11a..1d60ac9a 100644 --- a/compiler/scopelang/scope_to_dcalc.mli +++ b/compiler/dcalc/from_scopelang.mli @@ -16,4 +16,4 @@ (** Scope language to default calculus translator *) -val translate_program : 'm Ast.program -> 'm Dcalc.Ast.program +val translate_program : 'm Scopelang.Ast.program -> 'm Ast.program diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 253f136b..1a69fa9f 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -21,20 +21,6 @@ open Shared_ast (** {1 Names, Maps and Keys} *) -module IdentMap : Map.S with type key = String.t = Map.Make (String) - -module RuleName : Uid.Id with type info = Uid.MarkedString.info = - Uid.Make (Uid.MarkedString) () - -module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName) -module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName) - -module LabelName : Uid.Id with type info = Uid.MarkedString.info = - Uid.Make (Uid.MarkedString) () - -module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName) -module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName) - (** Inside a scope, a definition can refer either to a scope def, or a subscope def *) module ScopeDef = struct @@ -103,6 +89,9 @@ module ExprMap = Map.Make (struct let compare = Expr.compare end) +type io_input = NoInput | OnlyInput | Reentrant +type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos } + type exception_situation = | BaseCase | ExceptionToLabel of LabelName.t Marked.pos @@ -192,7 +181,7 @@ type scope_def = { scope_def_rules : rule RuleMap.t; scope_def_typ : typ; scope_def_is_condition : bool; - scope_def_io : Scopelang.Ast.io; + scope_def_io : io; } type var_or_states = WholeVar | States of StateName.t list diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index 713cff92..f14aef54 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -19,16 +19,6 @@ open Utils open Shared_ast -(** {1 Names, Maps and Keys} *) - -module IdentMap : Map.S with type key = String.t -module RuleName : Uid.Id with type info = Uid.MarkedString.info -module RuleMap : Map.S with type key = RuleName.t -module RuleSet : Set.S with type elt = RuleName.t -module LabelName : Uid.Id with type info = Uid.MarkedString.info -module LabelMap : Map.S with type key = LabelName.t -module LabelSet : Set.S with type elt = LabelName.t - (** Inside a scope, a definition can refer either to a scope def, or a subscope def *) module ScopeDef : sig @@ -88,11 +78,32 @@ type meta_assertion = | FixedBy of reference_typ Marked.pos | VariesWith of unit * variation_typ Marked.pos option +(** This type characterizes the three levels of visibility for a given scope + variable with regards to the scope's input and possible redefinitions inside + the scope.. *) +type io_input = + | NoInput + (** For an internal variable defined only in the scope, and does not + appear in the input. *) + | OnlyInput + (** For variables that should not be redefined in the scope, because they + appear in the input. *) + | Reentrant + (** For variables defined in the scope that can also be redefined by the + caller as they appear in the input. *) + +type io = { + io_output : bool Marked.pos; + (** [true] is present in the output of the scope. *) + io_input : io_input Marked.pos; +} +(** Characterization of the input/output status of a scope variable. *) + type scope_def = { scope_def_rules : rule RuleMap.t; scope_def_typ : typ; scope_def_is_condition : bool; - scope_def_io : Scopelang.Ast.io; + scope_def_io : io; } type var_or_states = WholeVar | States of StateName.t list diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 521ddb4d..097bb349 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -229,10 +229,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = (** {2 Graph declaration} *) module ExceptionVertex = struct - include Ast.RuleSet + include RuleSet let hash (x : t) : int = - Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0 + RuleSet.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0 let equal x y = compare x y = 0 end @@ -257,13 +257,13 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies) (** {2 Graph computations} *) type exception_edge = { - label_from : Ast.LabelName.t; - label_to : Ast.LabelName.t; + label_from : LabelName.t; + label_to : LabelName.t; edge_positions : Pos.t list; } let build_exceptions_graph - (def : Ast.rule Ast.RuleMap.t) + (def : Ast.rule RuleMap.t) (def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t = (* First we partition the definitions into groups bearing the same label. To handle the rules that were not labeled by the user, we create implicit @@ -271,63 +271,57 @@ let build_exceptions_graph (* All the rules of the form [definition x ...] are base case with no explicit label, so they should share this implicit label. *) - let base_case_implicit_label = - Ast.LabelName.fresh ("base_case", Pos.no_pos) - in + let base_case_implicit_label = LabelName.fresh ("base_case", Pos.no_pos) in (* When declaring [exception definition x ...], it means there is a unique rule [R] to which this can be an exception to. So we give a unique label to all the rules that are implicitly exceptions to rule [R]. *) - let exception_to_rule_implicit_labels : Ast.LabelName.t Ast.RuleMap.t = - Ast.RuleMap.fold + let exception_to_rule_implicit_labels : LabelName.t RuleMap.t = + RuleMap.fold (fun _ rule_from exception_to_rule_implicit_labels -> match rule_from.Ast.rule_exception with | Ast.ExceptionToRule (rule_to, _) -> ( - match - Ast.RuleMap.find_opt rule_to exception_to_rule_implicit_labels - with + match RuleMap.find_opt rule_to exception_to_rule_implicit_labels with | Some _ -> (* we already created the label *) exception_to_rule_implicit_labels | None -> - Ast.RuleMap.add rule_to - (Ast.LabelName.fresh - ( "exception_to_" - ^ Marked.unmark (Ast.RuleName.get_info rule_to), + RuleMap.add rule_to + (LabelName.fresh + ( "exception_to_" ^ Marked.unmark (RuleName.get_info rule_to), Pos.no_pos )) exception_to_rule_implicit_labels) | _ -> exception_to_rule_implicit_labels) - def Ast.RuleMap.empty + def RuleMap.empty in (* When declaring [exception foo_l definition x ...], the rule is exception to all the rules sharing label [foo_l]. So we give a unique label to all the rules that are implicitly exceptions to rule [foo_l]. *) - let exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t = - Ast.RuleMap.fold + let exception_to_label_implicit_labels : LabelName.t LabelMap.t = + RuleMap.fold (fun _ rule_from - (exception_to_label_implicit_labels : Ast.LabelName.t Ast.LabelMap.t) -> + (exception_to_label_implicit_labels : LabelName.t LabelMap.t) -> match rule_from.Ast.rule_exception with | Ast.ExceptionToLabel (label_to, _) -> ( match - Ast.LabelMap.find_opt label_to exception_to_label_implicit_labels + LabelMap.find_opt label_to exception_to_label_implicit_labels with | Some _ -> (* we already created the label *) exception_to_label_implicit_labels | None -> - Ast.LabelMap.add label_to - (Ast.LabelName.fresh - ( "exception_to_" - ^ Marked.unmark (Ast.LabelName.get_info label_to), + LabelMap.add label_to + (LabelName.fresh + ( "exception_to_" ^ Marked.unmark (LabelName.get_info label_to), Pos.no_pos )) exception_to_label_implicit_labels) | _ -> exception_to_label_implicit_labels) - def Ast.LabelMap.empty + def LabelMap.empty in (* Now we have all the labels necessary to partition our rules into sets, each one corresponding to a label relating to the structure of the exception DAG. *) let label_to_rule_sets = - Ast.RuleMap.fold + RuleMap.fold (fun rule_name rule rule_sets -> let label_of_rule = match rule.Ast.rule_label with @@ -336,23 +330,23 @@ let build_exceptions_graph match rule.Ast.rule_exception with | BaseCase -> base_case_implicit_label | ExceptionToRule (r, _) -> - Ast.RuleMap.find r exception_to_rule_implicit_labels + RuleMap.find r exception_to_rule_implicit_labels | ExceptionToLabel (l', _) -> - Ast.LabelMap.find l' exception_to_label_implicit_labels) + LabelMap.find l' exception_to_label_implicit_labels) in - Ast.LabelMap.update label_of_rule + LabelMap.update label_of_rule (fun rule_set -> match rule_set with - | None -> Some (Ast.RuleSet.singleton rule_name) - | Some rule_set -> Some (Ast.RuleSet.add rule_name rule_set)) + | None -> Some (RuleSet.singleton rule_name) + | Some rule_set -> Some (RuleSet.add rule_name rule_set)) rule_sets) - def Ast.LabelMap.empty + def LabelMap.empty in - let find_label_of_rule (r : Ast.RuleName.t) : Ast.LabelName.t = + let find_label_of_rule (r : RuleName.t) : LabelName.t = fst - (Ast.LabelMap.choose - (Ast.LabelMap.filter - (fun _ rule_set -> Ast.RuleSet.mem r rule_set) + (LabelMap.choose + (LabelMap.filter + (fun _ rule_set -> RuleSet.mem r rule_set) label_to_rule_sets)) in (* Next, we collect the exception edges between those groups of rules referred @@ -360,7 +354,7 @@ let build_exceptions_graph edges as they are declared at each rule but should be the same for all the rules of the same group. *) let exception_edges : exception_edge list = - Ast.RuleMap.fold + RuleMap.fold (fun rule_name rule exception_edges -> let label_from = find_label_of_rule rule_name in let label_to_and_pos = @@ -374,16 +368,16 @@ let build_exceptions_graph | Some (label_to, edge_pos) -> ( let other_edges_originating_from_same_label = List.filter - (fun edge -> Ast.LabelName.compare edge.label_from label_from = 0) + (fun edge -> LabelName.compare edge.label_from label_from = 0) exception_edges in (* We check the consistency*) - if Ast.LabelName.compare label_from label_to = 0 then + if LabelName.compare label_from label_to = 0 then Errors.raise_spanned_error edge_pos "Cannot define rule as an exception to itself"; List.iter (fun edge -> - if Ast.LabelName.compare edge.label_to label_to <> 0 then + if LabelName.compare edge.label_to label_to <> 0 then Errors.raise_multispanned_error (( Some "This declaration contradicts another exception \ @@ -401,8 +395,8 @@ let build_exceptions_graph let existing_edge = List.find_opt (fun edge -> - Ast.LabelName.compare edge.label_from label_from = 0 - && Ast.LabelName.compare edge.label_to label_to = 0) + LabelName.compare edge.label_from label_from = 0 + && LabelName.compare edge.label_to label_to = 0) exception_edges in match existing_edge with @@ -420,7 +414,7 @@ let build_exceptions_graph in (* We've got the vertices and the edges, let's build the graph! *) let g = - Ast.LabelMap.fold + LabelMap.fold (fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set) label_to_rule_sets ExceptionsDependencies.empty in @@ -429,11 +423,9 @@ let build_exceptions_graph List.fold_left (fun g edge -> let rule_group_from = - Ast.LabelMap.find edge.label_from label_to_rule_sets - in - let rule_group_to = - Ast.LabelMap.find edge.label_to label_to_rule_sets + LabelMap.find edge.label_from label_to_rule_sets in + let rule_group_to = LabelMap.find edge.label_to label_to_rule_sets in let edge = ExceptionsDependencies.E.create rule_group_from edge.edge_positions rule_group_to @@ -453,11 +445,10 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit = let spans = List.flatten (List.map - (fun (vs : Ast.RuleSet.t) -> - let v = Ast.RuleSet.choose vs in + (fun (vs : RuleSet.t) -> + let v = RuleSet.choose vs in let var_str, var_info = - ( Format.asprintf "%a" Ast.RuleName.format_t v, - Ast.RuleName.get_info v ) + Format.asprintf "%a" RuleName.format_t v, RuleName.get_info v in let succs = ExceptionsDependencies.succ_e g vs in let _, edge_pos, _ = diff --git a/compiler/desugared/dependency.mli b/compiler/desugared/dependency.mli index 6e18b56a..5487b786 100644 --- a/compiler/desugared/dependency.mli +++ b/compiler/desugared/dependency.mli @@ -18,6 +18,7 @@ OCamlgraph} *) open Utils +open Shared_ast (** {1 Scope variables dependency graph} *) @@ -71,9 +72,9 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list module ExceptionsDependencies : - Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = EdgeExceptions.t + Graph.Sig.P with type V.t = RuleSet.t and type E.label = EdgeExceptions.t val build_exceptions_graph : - Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t + Ast.rule RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t val check_for_exception_cycle : ExceptionsDependencies.t -> unit diff --git a/compiler/desugared/dune b/compiler/desugared/dune index 9d5fb83f..0b066eb9 100644 --- a/compiler/desugared/dune +++ b/compiler/desugared/dune @@ -1,7 +1,7 @@ (library (name desugared) (public_name catala.desugared) - (libraries utils dcalc scopelang ocamlgraph)) + (libraries ocamlgraph utils shared_ast surface)) (documentation (package catala) diff --git a/compiler/surface/desugaring.ml b/compiler/desugared/from_surface.ml similarity index 80% rename from compiler/surface/desugaring.ml rename to compiler/desugared/from_surface.ml index 62b5c2f8..0a98c54d 100644 --- a/compiler/surface/desugaring.ml +++ b/compiler/desugared/from_surface.ml @@ -16,7 +16,7 @@ the License. *) open Utils -module SurfacePrint = Print +module SurfacePrint = Surface.Print open Shared_ast module Runtime = Runtime_ocaml.Runtime @@ -27,7 +27,7 @@ module Runtime = Runtime_ocaml.Runtime (** {1 Translating expressions} *) -let translate_op_kind (k : Ast.op_kind) : op_kind = +let translate_op_kind (k : Surface.Ast.op_kind) : op_kind = match k with | KInt -> KInt | KDec -> KRat @@ -35,7 +35,7 @@ let translate_op_kind (k : Ast.op_kind) : op_kind = | KDate -> KDate | KDuration -> KDuration -let translate_binop (op : Ast.binop) : binop = +let translate_binop (op : Surface.Ast.binop) : binop = match op with | And -> And | Or -> Or @@ -52,7 +52,7 @@ let translate_binop (op : Ast.binop) : binop = | Neq -> Neq | Concat -> Concat -let translate_unop (op : Ast.unop) : unop = +let translate_unop (op : Surface.Ast.unop) : unop = match op with Not -> Not | Minus l -> Minus (translate_op_kind l) let disambiguate_constructor @@ -68,7 +68,7 @@ let disambiguate_constructor in let possible_c_uids = try - Desugared.Ast.IdentMap.find + Name_resolution.IdentMap.find (Marked.unmark constructor) ctxt.constructor_idmap with Not_found -> @@ -111,16 +111,16 @@ let disambiguate_constructor disambiguate the scope and subscopes variables than occur in the expression *) let rec translate_expr (scope : ScopeName.t) - (inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option) + (inside_definition_of : Ast.ScopeDef.t Marked.pos option) (ctxt : Name_resolution.context) - (expr : Ast.expression Marked.pos) : Desugared.Ast.expr boxed = + (expr : Surface.Ast.expression Marked.pos) : Ast.expr boxed = let scope_ctxt = ScopeMap.find scope ctxt.scopes in let rec_helper = translate_expr scope inside_definition_of ctxt in let pos = Marked.get_mark expr in let emark = Untyped { pos } in match Marked.unmark expr with | Binop - ( (Ast.And, _pos_op), + ( (Surface.Ast.And, _pos_op), ( TestMatchCase (e1_sub, ((constructors, Some binding), pos_pattern)), _pos_e1 ), e2 ) -> @@ -204,9 +204,9 @@ let rec translate_expr | Ident x -> ( (* first we check whether this is a local var, then we resort to scope-wide variables *) - match Desugared.Ast.IdentMap.find_opt x ctxt.local_var_idmap with + match Name_resolution.IdentMap.find_opt x ctxt.local_var_idmap with | None -> ( - match Desugared.Ast.IdentMap.find_opt x scope_ctxt.var_idmap with + match Name_resolution.IdentMap.find_opt x scope_ctxt.var_idmap with | Some (ScopeVar uid) -> (* If the referenced variable has states, then here are the rules to desambiguate. In general, only the last state can be referenced. @@ -258,7 +258,7 @@ let rec translate_expr | Ident y when Name_resolution.is_subscope_uid scope ctxt y -> (* In this case, y.x is a subscope variable *) let subscope_uid, subscope_real_uid = - match Desugared.Ast.IdentMap.find y scope_ctxt.var_idmap with + match Name_resolution.IdentMap.find y scope_ctxt.var_idmap with | SubScope (sub, sc) -> sub, sc | ScopeVar _ -> assert false in @@ -273,7 +273,7 @@ let rec translate_expr (* In this case e.x is the struct field x access of expression e *) let e = translate_expr scope inside_definition_of ctxt e in let x_possible_structs = - try Desugared.Ast.IdentMap.find (Marked.unmark x) ctxt.field_idmap + try Name_resolution.IdentMap.find (Marked.unmark x) ctxt.field_idmap with Not_found -> Errors.raise_spanned_error (Marked.get_mark x) "Unknown subscope or struct field name" @@ -314,7 +314,7 @@ let rec translate_expr (fun acc (fld_id, e) -> let var = match - Desugared.Ast.IdentMap.find_opt (Marked.unmark fld_id) + Name_resolution.IdentMap.find_opt (Marked.unmark fld_id) scope_def.var_idmap with | Some (ScopeVar v) -> v @@ -353,7 +353,7 @@ let rec translate_expr | StructLit (s_name, fields) -> let s_uid = match - Desugared.Ast.IdentMap.find_opt (Marked.unmark s_name) ctxt.typedefs + Name_resolution.IdentMap.find_opt (Marked.unmark s_name) ctxt.typedefs with | Some (Name_resolution.TStruct s_uid) -> s_uid | _ -> @@ -367,7 +367,7 @@ let rec translate_expr let f_uid = try StructMap.find s_uid - (Desugared.Ast.IdentMap.find (Marked.unmark f_name) + (Name_resolution.IdentMap.find (Marked.unmark f_name) ctxt.field_idmap) with Not_found -> Errors.raise_spanned_error (Marked.get_mark f_name) @@ -397,7 +397,7 @@ let rec translate_expr Expr.estruct s_uid s_fields emark | EnumInject (enum, (constructor, pos_constructor), payload) -> ( let possible_c_uids = - try Desugared.Ast.IdentMap.find constructor ctxt.constructor_idmap + try Name_resolution.IdentMap.find constructor ctxt.constructor_idmap with Not_found -> Errors.raise_spanned_error pos_constructor "The name of this constructor has not been defined before, maybe it \ @@ -481,7 +481,7 @@ let rec translate_expr enum_uid cases emark | ArrayLit es -> Expr.earray (List.map rec_helper es) emark | CollectionOp - ( (((Ast.Filter | Ast.Map) as op'), _pos_op'), + ( (((Surface.Ast.Filter | Surface.Ast.Map) as op'), _pos_op'), param', collection, predicate ) -> @@ -498,13 +498,14 @@ let rec translate_expr Expr.eapp (Expr.eop (match op' with - | Ast.Map -> Binop Map - | Ast.Filter -> Binop Filter + | Surface.Ast.Map -> Binop Map + | Surface.Ast.Filter -> Binop Filter | _ -> assert false (* should not happen *)) emark) [f_pred; collection] emark | CollectionOp - ( ( Ast.Aggregate (Ast.AggregateArgExtremum (max_or_min, pred_typ, init)), + ( ( Surface.Ast.Aggregate + (Surface.Ast.AggregateArgExtremum (max_or_min, pred_typ, init)), pos_op' ), param', collection, @@ -516,11 +517,11 @@ let rec translate_expr in let op_kind = match pred_typ with - | Ast.Integer -> KInt - | Ast.Decimal -> KRat - | Ast.Money -> KMoney - | Ast.Duration -> KDuration - | Ast.Date -> KDate + | Surface.Ast.Integer -> KInt + | Surface.Ast.Decimal -> KRat + | Surface.Ast.Money -> KMoney + | Surface.Ast.Duration -> KDuration + | Surface.Ast.Date -> KDate | _ -> Errors.raise_spanned_error pos "It is impossible to compute the arg-%s of two values of type %a" @@ -568,26 +569,28 @@ let rec translate_expr let mark = Untyped { pos = Marked.get_mark op' } in let init = match Marked.unmark op' with - | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> + | Surface.Ast.Map | Surface.Ast.Filter + | Surface.Ast.Aggregate (Surface.Ast.AggregateArgExtremum _) -> assert false (* should not happen *) - | Ast.Exists -> Expr.elit (LBool false) mark - | Ast.Forall -> Expr.elit (LBool true) mark - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> + | Surface.Ast.Exists -> Expr.elit (LBool false) mark + | Surface.Ast.Forall -> Expr.elit (LBool true) mark + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Integer) -> Expr.elit (LInt (Runtime.integer_of_int 0)) mark - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Decimal) -> Expr.elit (LRat (Runtime.decimal_of_string "0")) mark - | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Money) -> Expr.elit (LMoney (Runtime.money_of_cents_integer (Runtime.integer_of_int 0))) mark - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Duration) -> Expr.elit (LDuration (Runtime.duration_of_numbers 0 0 0)) mark - | Ast.Aggregate (Ast.AggregateSum t) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum t) -> Errors.raise_spanned_error pos "It is impossible to sum two values of type %a together" SurfacePrint.format_primitive_typ t - | Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init - | Ast.Aggregate Ast.AggregateCount -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateExtremum (_, _, init)) -> + rec_helper init + | Surface.Ast.Aggregate Surface.Ast.AggregateCount -> Expr.elit (LInt (Runtime.integer_of_int 0)) mark in let acc_var = Var.make "acc" in @@ -613,25 +616,30 @@ let rec translate_expr pos in match Marked.unmark op' with - | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> + | Surface.Ast.Map | Surface.Ast.Filter + | Surface.Ast.Aggregate (Surface.Ast.AggregateArgExtremum _) -> assert false (* should not happen *) - | Ast.Exists -> make_body Or - | Ast.Forall -> make_body And - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> make_body (Add KInt) - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> make_body (Add KRat) - | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> make_body (Add KMoney) - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> + | Surface.Ast.Exists -> make_body Or + | Surface.Ast.Forall -> make_body And + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Integer) -> + make_body (Add KInt) + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Decimal) -> + make_body (Add KRat) + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Money) -> + make_body (Add KMoney) + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Duration) -> make_body (Add KDuration) - | Ast.Aggregate (Ast.AggregateSum _) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum _) -> assert false (* should not happen *) - | Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateExtremum (max_or_min, t, _)) + -> let op_kind, typ = match t with - | Ast.Integer -> KInt, (TLit TInt, pos) - | Ast.Decimal -> KRat, (TLit TRat, pos) - | Ast.Money -> KMoney, (TLit TMoney, pos) - | Ast.Duration -> KDuration, (TLit TDuration, pos) - | Ast.Date -> KDate, (TLit TDate, pos) + | Surface.Ast.Integer -> KInt, (TLit TInt, pos) + | Surface.Ast.Decimal -> KRat, (TLit TRat, pos) + | Surface.Ast.Money -> KMoney, (TLit TMoney, pos) + | Surface.Ast.Duration -> KDuration, (TLit TDuration, pos) + | Surface.Ast.Date -> KDate, (TLit TDate, pos) | _ -> Errors.raise_spanned_error pos "It is impossible to compute the %s of two values of type %a" @@ -640,7 +648,7 @@ let rec translate_expr in let cmp_op = if max_or_min then Gt op_kind else Lt op_kind in make_extr_body cmp_op typ - | Ast.Aggregate Ast.AggregateCount -> + | Surface.Ast.Aggregate Surface.Ast.AggregateCount -> let predicate = translate_expr scope inside_definition_of ctxt predicate in @@ -670,26 +678,31 @@ let rec translate_expr emark in match Marked.unmark op' with - | Ast.Map | Ast.Filter | Ast.Aggregate (Ast.AggregateArgExtremum _) -> + | Surface.Ast.Map | Surface.Ast.Filter + | Surface.Ast.Aggregate (Surface.Ast.AggregateArgExtremum _) -> assert false (* should not happen *) - | Ast.Exists -> make_f TBool - | Ast.Forall -> make_f TBool - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) -> + | Surface.Ast.Exists -> make_f TBool + | Surface.Ast.Forall -> make_f TBool + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Integer) + | Surface.Ast.Aggregate + (Surface.Ast.AggregateExtremum (_, Surface.Ast.Integer, _)) -> make_f TInt - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Decimal) + | Surface.Ast.Aggregate + (Surface.Ast.AggregateExtremum (_, Surface.Ast.Decimal, _)) -> make_f TRat - | Ast.Aggregate (Ast.AggregateSum Ast.Money) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Money) + | Surface.Ast.Aggregate + (Surface.Ast.AggregateExtremum (_, Surface.Ast.Money, _)) -> make_f TMoney - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) - | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Duration) + | Surface.Ast.Aggregate + (Surface.Ast.AggregateExtremum (_, Surface.Ast.Duration, _)) -> make_f TDuration - | Ast.Aggregate (Ast.AggregateSum _) - | Ast.Aggregate (Ast.AggregateExtremum _) -> + | Surface.Ast.Aggregate (Surface.Ast.AggregateSum _) + | Surface.Ast.Aggregate (Surface.Ast.AggregateExtremum _) -> assert false (* should not happen *) - | Ast.Aggregate Ast.AggregateCount -> make_f TInt + | Surface.Ast.Aggregate Surface.Ast.AggregateCount -> make_f TInt in Expr.eapp (Expr.eop (Ternop Fold) emark) [f; init; collection] emark | MemCollection (member, collection) -> @@ -727,10 +740,10 @@ let rec translate_expr and disambiguate_match_and_build_expression (scope : ScopeName.t) - (inside_definition_of : Desugared.Ast.ScopeDef.t Marked.pos option) + (inside_definition_of : Ast.ScopeDef.t Marked.pos option) (ctxt : Name_resolution.context) - (cases : Ast.match_case Marked.pos list) : - Desugared.Ast.expr boxed EnumConstructorMap.t * EnumName.t = + (cases : Surface.Ast.match_case Marked.pos list) : + Ast.expr boxed EnumConstructorMap.t * EnumName.t = let create_var = function | None -> ctxt, Var.make "_" | Some param -> @@ -752,11 +765,13 @@ and disambiguate_match_and_build_expression in let bind_match_cases (cases_d, e_uid, curr_index) (case, case_pos) = match case with - | Ast.MatchCase case -> - let constructor, binding = Marked.unmark case.Ast.match_case_pattern in + | Surface.Ast.MatchCase case -> + let constructor, binding = + Marked.unmark case.Surface.Ast.match_case_pattern + in let e_uid', c_uid = disambiguate_constructor ctxt constructor - (Marked.get_mark case.Ast.match_case_pattern) + (Marked.get_mark case.Surface.Ast.match_case_pattern) in let e_uid = match e_uid with @@ -765,7 +780,7 @@ and disambiguate_match_and_build_expression if e_uid = e_uid' then e_uid else Errors.raise_spanned_error - (Marked.get_mark case.Ast.match_case_pattern) + (Marked.get_mark case.Surface.Ast.match_case_pattern) "This case matches a constructor of enumeration %a but previous \ case were matching constructors of enumeration %a" EnumName.format_t e_uid EnumName.format_t e_uid' @@ -779,12 +794,13 @@ and disambiguate_match_and_build_expression c_uid); let ctxt, param_var = create_var (Option.map Marked.unmark binding) in let case_body = - translate_expr scope inside_definition_of ctxt case.Ast.match_case_expr + translate_expr scope inside_definition_of ctxt + case.Surface.Ast.match_case_expr in let e_binder = Expr.bind [| param_var |] case_body in let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in EnumConstructorMap.add c_uid case_expr cases_d, Some e_uid, curr_index + 1 - | Ast.WildCard match_case_expr -> ( + | Surface.Ast.WildCard match_case_expr -> ( let nb_cases = List.length cases in let raise_wildcard_not_last_case_err () = Errors.raise_multispanned_error @@ -858,9 +874,9 @@ and disambiguate_match_and_build_expression this precondition has to be appended to the justifications of each definition in the subscope use. This is what this function does. *) let merge_conditions - (precond : Desugared.Ast.expr boxed option) - (cond : Desugared.Ast.expr boxed option) - (default_pos : Pos.t) : Desugared.Ast.expr boxed = + (precond : Ast.expr boxed option) + (cond : Ast.expr boxed option) + (default_pos : Pos.t) : Ast.expr boxed = match precond, cond with | Some precond, Some cond -> let op_term = Expr.eop (Binop And) (Marked.get_mark cond) in @@ -870,18 +886,18 @@ let merge_conditions | None, None -> Expr.elit (LBool true) (Untyped { pos = default_pos }) (** Translates a surface definition into condition into a desugared {!type: - Desugared.Ast.rule} *) + Ast.rule} *) let process_default (ctxt : Name_resolution.context) (scope : ScopeName.t) - (def_key : Desugared.Ast.ScopeDef.t Marked.pos) - (rule_id : Desugared.Ast.RuleName.t) - (param_uid : Desugared.Ast.expr Var.t Marked.pos option) - (precond : Desugared.Ast.expr boxed option) - (exception_situation : Desugared.Ast.exception_situation) - (label_situation : Desugared.Ast.label_situation) - (just : Ast.expression Marked.pos option) - (cons : Ast.expression Marked.pos) : Desugared.Ast.rule = + (def_key : Ast.ScopeDef.t Marked.pos) + (rule_id : RuleName.t) + (param_uid : Ast.expr Var.t Marked.pos option) + (precond : Ast.expr boxed option) + (exception_situation : Ast.exception_situation) + (label_situation : Ast.label_situation) + (just : Surface.Ast.expression Marked.pos option) + (cons : Surface.Ast.expression Marked.pos) : Ast.rule = let just = match just with | Some just -> Some (translate_expr scope (Some def_key) ctxt just) @@ -913,14 +929,12 @@ let process_default (** Wrapper around {!val: process_default} that performs some name disambiguation *) let process_def - (precond : Desugared.Ast.expr boxed option) + (precond : Ast.expr boxed option) (scope_uid : ScopeName.t) (ctxt : Name_resolution.context) - (prgm : Desugared.Ast.program) - (def : Ast.definition) : Desugared.Ast.program = - let scope : Desugared.Ast.scope = - ScopeMap.find scope_uid prgm.program_scopes - in + (prgm : Ast.program) + (def : Surface.Ast.definition) : Ast.program = + let scope : Ast.scope = ScopeMap.find scope_uid prgm.program_scopes in let scope_ctxt = ScopeMap.find scope_uid ctxt.scopes in let def_key = Name_resolution.get_def_key @@ -929,7 +943,7 @@ let process_def (Marked.get_mark def.definition_name) in let scope_def_ctxt = - Desugared.Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts + Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts in (* We add to the name resolution context the name of the parameter variable *) let param_uid, new_ctxt = @@ -942,19 +956,19 @@ let process_def Some (Marked.same_mark_as param_var param), ctxt in let scope_updated = - let scope_def = Desugared.Ast.ScopeDefMap.find def_key scope.scope_defs in + let scope_def = Ast.ScopeDefMap.find def_key scope.scope_defs in let rule_name = def.definition_id in let label_situation = match def.definition_label with | Some (label_str, label_pos) -> - Desugared.Ast.ExplicitlyLabeled - ( Desugared.Ast.IdentMap.find label_str scope_def_ctxt.label_idmap, + Ast.ExplicitlyLabeled + ( Name_resolution.IdentMap.find label_str scope_def_ctxt.label_idmap, label_pos ) - | None -> Desugared.Ast.Unlabeled + | None -> Ast.Unlabeled in let exception_situation = - match def.Ast.definition_exception_to with - | NotAnException -> Desugared.Ast.BaseCase + match def.Surface.Ast.definition_exception_to with + | NotAnException -> Ast.BaseCase | UnlabeledException -> ( match scope_def_ctxt.default_exception_rulename with | None | Some (Name_resolution.Ambiguous _) -> @@ -966,7 +980,7 @@ let process_def | ExceptionToLabel label_str -> ( try let label_id = - Desugared.Ast.IdentMap.find (Marked.unmark label_str) + Name_resolution.IdentMap.find (Marked.unmark label_str) scope_def_ctxt.label_idmap in ExceptionToLabel (label_id, Marked.get_mark label_str) @@ -974,13 +988,13 @@ let process_def Errors.raise_spanned_error (Marked.get_mark label_str) "Unknown label for the scope variable %a: \"%s\"" - Desugared.Ast.ScopeDef.format_t def_key (Marked.unmark label_str)) + Ast.ScopeDef.format_t def_key (Marked.unmark label_str)) in let scope_def = { scope_def with scope_def_rules = - Desugared.Ast.RuleMap.add rule_name + RuleMap.add rule_name (process_default new_ctxt scope_uid (def_key, Marked.get_mark def.definition_name) rule_name param_uid precond exception_situation label_situation @@ -990,8 +1004,7 @@ let process_def in { scope with - scope_defs = - Desugared.Ast.ScopeDefMap.add def_key scope_def scope.scope_defs; + scope_defs = Ast.ScopeDefMap.add def_key scope_def scope.scope_defs; } in { @@ -1001,33 +1014,32 @@ let process_def (** Translates a {!type: Surface.Ast.rule} from the surface language *) let process_rule - (precond : Desugared.Ast.expr boxed option) + (precond : Ast.expr boxed option) (scope : ScopeName.t) (ctxt : Name_resolution.context) - (prgm : Desugared.Ast.program) - (rule : Ast.rule) : Desugared.Ast.program = - let def = Ast.rule_to_def rule in + (prgm : Ast.program) + (rule : Surface.Ast.rule) : Ast.program = + let def = Surface.Ast.rule_to_def rule in process_def precond scope ctxt prgm def (** Translates assertions *) let process_assert - (precond : Desugared.Ast.expr boxed option) + (precond : Ast.expr boxed option) (scope_uid : ScopeName.t) (ctxt : Name_resolution.context) - (prgm : Desugared.Ast.program) - (ass : Ast.assertion) : Desugared.Ast.program = - let scope : Desugared.Ast.scope = - ScopeMap.find scope_uid prgm.program_scopes - in + (prgm : Ast.program) + (ass : Surface.Ast.assertion) : Ast.program = + let scope : Ast.scope = ScopeMap.find scope_uid prgm.program_scopes in let ass = translate_expr scope_uid None ctxt - (match ass.Ast.assertion_condition with - | None -> ass.Ast.assertion_content + (match ass.Surface.Ast.assertion_condition with + | None -> ass.Surface.Ast.assertion_content | Some cond -> - ( Ast.IfThenElse + ( Surface.Ast.IfThenElse ( cond, - ass.Ast.assertion_content, - Marked.same_mark_as (Ast.Literal (Ast.LBool true)) cond ), + ass.Surface.Ast.assertion_content, + Marked.same_mark_as (Surface.Ast.Literal (Surface.Ast.LBool true)) + cond ), Marked.get_mark cond )) in let ass = @@ -1048,16 +1060,16 @@ let process_assert (** Translates a surface definition, rule or assertion *) let process_scope_use_item - (precond : Ast.expression Marked.pos option) + (precond : Surface.Ast.expression Marked.pos option) (scope : ScopeName.t) (ctxt : Name_resolution.context) - (prgm : Desugared.Ast.program) - (item : Ast.scope_use_item Marked.pos) : Desugared.Ast.program = + (prgm : Ast.program) + (item : Surface.Ast.scope_use_item Marked.pos) : Ast.program = let precond = Option.map (translate_expr scope None ctxt) precond in match Marked.unmark item with - | Ast.Rule rule -> process_rule precond scope ctxt prgm rule - | Ast.Definition def -> process_def precond scope ctxt prgm def - | Ast.Assertion ass -> process_assert precond scope ctxt prgm ass + | Surface.Ast.Rule rule -> process_rule precond scope ctxt prgm rule + | Surface.Ast.Definition def -> process_def precond scope ctxt prgm def + | Surface.Ast.Assertion ass -> process_assert precond scope ctxt prgm ass | _ -> prgm (** {1 Translating top-level items} *) @@ -1067,19 +1079,19 @@ let process_scope_use_item let check_unlabeled_exception (scope : ScopeName.t) (ctxt : Name_resolution.context) - (item : Ast.scope_use_item Marked.pos) : unit = + (item : Surface.Ast.scope_use_item Marked.pos) : unit = let scope_ctxt = ScopeMap.find scope ctxt.scopes in match Marked.unmark item with - | Ast.Rule _ | Ast.Definition _ -> ( + | Surface.Ast.Rule _ | Surface.Ast.Definition _ -> ( let def_key, exception_to = match Marked.unmark item with - | Ast.Rule rule -> + | Surface.Ast.Rule rule -> ( Name_resolution.get_def_key (Marked.unmark rule.rule_name) rule.rule_state scope ctxt (Marked.get_mark rule.rule_name), rule.rule_exception_to ) - | Ast.Definition def -> + | Surface.Ast.Definition def -> ( Name_resolution.get_def_key (Marked.unmark def.definition_name) def.definition_state scope ctxt @@ -1089,13 +1101,13 @@ let check_unlabeled_exception (* should not happen *) in let scope_def_ctxt = - Desugared.Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts + Ast.ScopeDefMap.find def_key scope_ctxt.scope_defs_contexts in match exception_to with - | Ast.NotAnException | Ast.ExceptionToLabel _ -> () + | Surface.Ast.NotAnException | Surface.Ast.ExceptionToLabel _ -> () (* If this is an unlabeled exception, we check that it has a unique default definition *) - | Ast.UnlabeledException -> ( + | Surface.Ast.UnlabeledException -> ( match scope_def_ctxt.default_exception_rulename with | None -> Errors.raise_spanned_error (Marked.get_mark item) @@ -1112,8 +1124,8 @@ let check_unlabeled_exception (** Translates a surface scope use, which is a bunch of definitions *) let process_scope_use (ctxt : Name_resolution.context) - (prgm : Desugared.Ast.program) - (use : Ast.scope_use) : Desugared.Ast.program = + (prgm : Ast.program) + (use : Surface.Ast.scope_use) : Ast.program = let scope_uid = Name_resolution.get_scope ctxt use.scope_use_name in (* Make sure the scope exists *) let prgm = @@ -1128,24 +1140,24 @@ let process_scope_use (process_scope_use_item precond scope_uid ctxt) prgm use.scope_use_items -let attribute_to_io (attr : Ast.scope_decl_context_io) : Scopelang.Ast.io = +let attribute_to_io (attr : Surface.Ast.scope_decl_context_io) : Ast.io = { - Scopelang.Ast.io_output = attr.scope_decl_context_io_output; - Scopelang.Ast.io_input = + Ast.io_output = attr.scope_decl_context_io_output; + Ast.io_input = Marked.map_under_mark (fun io -> match io with - | Ast.Input -> Scopelang.Ast.OnlyInput - | Ast.Internal -> Scopelang.Ast.NoInput - | Ast.Context -> Scopelang.Ast.Reentrant) + | Surface.Ast.Input -> Ast.OnlyInput + | Surface.Ast.Internal -> Ast.NoInput + | Surface.Ast.Context -> Ast.Reentrant) attr.scope_decl_context_io_input; } let init_scope_defs (ctxt : Name_resolution.context) (scope_idmap : - Name_resolution.scope_var_or_subscope Desugared.Ast.IdentMap.t) : - Desugared.Ast.scope_def Desugared.Ast.ScopeDefMap.t = + Name_resolution.scope_var_or_subscope Name_resolution.IdentMap.t) : + Ast.scope_def Ast.ScopeDefMap.t = (* Initializing the definitions of all scopes and subscope vars, with no rules yet inside *) let add_def _ v scope_def_map = @@ -1154,27 +1166,26 @@ let init_scope_defs let v_sig = ScopeVarMap.find v ctxt.Name_resolution.var_typs in match v_sig.var_sig_states_list with | [] -> - let def_key = Desugared.Ast.ScopeDef.Var (v, None) in - Desugared.Ast.ScopeDefMap.add def_key + let def_key = Ast.ScopeDef.Var (v, None) in + Ast.ScopeDefMap.add def_key { - Desugared.Ast.scope_def_rules = Desugared.Ast.RuleMap.empty; - Desugared.Ast.scope_def_typ = v_sig.var_sig_typ; - Desugared.Ast.scope_def_is_condition = v_sig.var_sig_is_condition; - Desugared.Ast.scope_def_io = attribute_to_io v_sig.var_sig_io; + Ast.scope_def_rules = RuleMap.empty; + Ast.scope_def_typ = v_sig.var_sig_typ; + Ast.scope_def_is_condition = v_sig.var_sig_is_condition; + Ast.scope_def_io = attribute_to_io v_sig.var_sig_io; } scope_def_map | states -> let scope_def, _ = List.fold_left (fun (acc, i) state -> - let def_key = Desugared.Ast.ScopeDef.Var (v, Some state) in + let def_key = Ast.ScopeDef.Var (v, Some state) in let def = { - Desugared.Ast.scope_def_rules = Desugared.Ast.RuleMap.empty; - Desugared.Ast.scope_def_typ = v_sig.var_sig_typ; - Desugared.Ast.scope_def_is_condition = - v_sig.var_sig_is_condition; - Desugared.Ast.scope_def_io = + Ast.scope_def_rules = RuleMap.empty; + Ast.scope_def_typ = v_sig.var_sig_typ; + Ast.scope_def_is_condition = v_sig.var_sig_is_condition; + Ast.scope_def_io = (* The first state should have the input I/O of the original variable, and the last state should have the output I/O of the original variable. All intermediate states shall @@ -1183,8 +1194,7 @@ let init_scope_defs let io_input = if i = 0 then original_io.io_input else - ( Scopelang.Ast.NoInput, - Marked.get_mark (StateName.get_info state) ) + Ast.NoInput, Marked.get_mark (StateName.get_info state) in let io_output = if i = List.length states - 1 then original_io.io_output @@ -1193,7 +1203,7 @@ let init_scope_defs { io_input; io_output }); } in - Desugared.Ast.ScopeDefMap.add def_key def acc, i + 1) + Ast.ScopeDefMap.add def_key def acc, i + 1) (scope_def_map, 0) states in scope_def) @@ -1201,7 +1211,7 @@ let init_scope_defs let sub_scope_def = ScopeMap.find subscope_uid ctxt.Name_resolution.scopes in - Desugared.Ast.IdentMap.fold + Name_resolution.IdentMap.fold (fun _ v scope_def_map -> match v with | Name_resolution.SubScope _ -> scope_def_map @@ -1210,45 +1220,43 @@ let init_scope_defs ? *) let v_sig = ScopeVarMap.find v ctxt.Name_resolution.var_typs in let def_key = - Desugared.Ast.ScopeDef.SubScopeVar + Ast.ScopeDef.SubScopeVar (v0, v, Marked.get_mark (ScopeVar.get_info v)) in - Desugared.Ast.ScopeDefMap.add def_key + Ast.ScopeDefMap.add def_key { - Desugared.Ast.scope_def_rules = Desugared.Ast.RuleMap.empty; - Desugared.Ast.scope_def_typ = v_sig.var_sig_typ; - Desugared.Ast.scope_def_is_condition = - v_sig.var_sig_is_condition; - Desugared.Ast.scope_def_io = attribute_to_io v_sig.var_sig_io; + Ast.scope_def_rules = RuleMap.empty; + Ast.scope_def_typ = v_sig.var_sig_typ; + Ast.scope_def_is_condition = v_sig.var_sig_is_condition; + Ast.scope_def_io = attribute_to_io v_sig.var_sig_io; } scope_def_map) sub_scope_def.Name_resolution.var_idmap scope_def_map in - Desugared.Ast.IdentMap.fold add_def scope_idmap - Desugared.Ast.ScopeDefMap.empty + Name_resolution.IdentMap.fold add_def scope_idmap Ast.ScopeDefMap.empty (** Main function of this module *) -let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : - Desugared.Ast.program = +let translate_program + (ctxt : Name_resolution.context) + (prgm : Surface.Ast.program) : Ast.program = let empty_prgm = let program_scopes = ScopeMap.mapi (fun s_uid s_context -> let scope_vars = - Desugared.Ast.IdentMap.fold + Name_resolution.IdentMap.fold (fun _ v acc -> match v with | Name_resolution.SubScope _ -> acc | Name_resolution.ScopeVar v -> ( let v_sig = ScopeVarMap.find v ctxt.var_typs in match v_sig.var_sig_states_list with - | [] -> ScopeVarMap.add v Desugared.Ast.WholeVar acc - | states -> - ScopeVarMap.add v (Desugared.Ast.States states) acc)) + | [] -> ScopeVarMap.add v Ast.WholeVar acc + | states -> ScopeVarMap.add v (Ast.States states) acc)) s_context.Name_resolution.var_idmap ScopeVarMap.empty in let scope_sub_scopes = - Desugared.Ast.IdentMap.fold + Name_resolution.IdentMap.fold (fun _ v acc -> match v with | Name_resolution.ScopeVar _ -> acc @@ -1257,7 +1265,7 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : s_context.Name_resolution.var_idmap SubScopeMap.empty in { - Desugared.Ast.scope_vars; + Ast.scope_vars; scope_sub_scopes; scope_defs = init_scope_defs ctxt s_context.var_idmap; scope_assertions = []; @@ -1267,12 +1275,12 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : ctxt.Name_resolution.scopes in { - Desugared.Ast.program_ctx = + Ast.program_ctx = { ctx_structs = ctxt.Name_resolution.structs; ctx_enums = ctxt.Name_resolution.enums; ctx_scopes = - Desugared.Ast.IdentMap.fold + Name_resolution.IdentMap.fold (fun _ def acc -> match def with | Name_resolution.TScope (scope, scope_out_struct) -> @@ -1280,12 +1288,12 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : | _ -> acc) ctxt.Name_resolution.typedefs ScopeMap.empty; }; - Desugared.Ast.program_scopes; + Ast.program_scopes; } in let rec processer_structure - (prgm : Desugared.Ast.program) - (item : Ast.law_structure) : Desugared.Ast.program = + (prgm : Ast.program) + (item : Surface.Ast.law_structure) : Ast.program = match item with | LawHeading (_, children) -> List.fold_left @@ -1295,7 +1303,7 @@ let desugar_program (ctxt : Name_resolution.context) (prgm : Ast.program) : List.fold_left (fun prgm item -> match Marked.unmark item with - | Ast.ScopeUse use -> process_scope_use ctxt prgm use + | Surface.Ast.ScopeUse use -> process_scope_use ctxt prgm use | _ -> prgm) prgm block | LawInclude _ | LawText _ -> prgm diff --git a/compiler/surface/desugaring.mli b/compiler/desugared/from_surface.mli similarity index 91% rename from compiler/surface/desugaring.mli rename to compiler/desugared/from_surface.mli index 28f81ce7..e6d6f050 100644 --- a/compiler/surface/desugaring.mli +++ b/compiler/desugared/from_surface.mli @@ -20,6 +20,6 @@ - Removes syntactic sugars - Separate code from legislation *) -val desugar_program : - Name_resolution.context -> Ast.program -> Desugared.Ast.program +val translate_program : + Name_resolution.context -> Surface.Ast.program -> Ast.program (** Main function of this module *) diff --git a/compiler/surface/name_resolution.ml b/compiler/desugared/name_resolution.ml similarity index 78% rename from compiler/surface/name_resolution.ml rename to compiler/desugared/name_resolution.ml index fdd464a6..284e8de9 100644 --- a/compiler/surface/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -25,13 +25,15 @@ open Shared_ast type ident = string +module IdentMap : Map.S with type key = String.t = Map.Make (String) + type unique_rulename = | Ambiguous of Pos.t list - | Unique of Desugared.Ast.RuleName.t Marked.pos + | Unique of RuleName.t Marked.pos type scope_def_context = { default_exception_rulename : unique_rulename option; - label_idmap : Desugared.Ast.LabelName.t Desugared.Ast.IdentMap.t; + label_idmap : LabelName.t IdentMap.t; } type scope_var_or_subscope = @@ -39,9 +41,9 @@ type scope_var_or_subscope = | SubScope of SubScopeName.t * ScopeName.t type scope_context = { - var_idmap : scope_var_or_subscope Desugared.Ast.IdentMap.t; + var_idmap : scope_var_or_subscope IdentMap.t; (** All variables, including scope variables and subscopes *) - scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t; + scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t; (** What is the default rule to refer to for unnamed exceptions, if any *) sub_scopes : ScopeSet.t; (** Other scopes referred to by this scope. Used for dependency analysis *) @@ -57,8 +59,8 @@ type enum_context = typ EnumConstructorMap.t type var_sig = { var_sig_typ : typ; var_sig_is_condition : bool; - var_sig_io : Ast.scope_decl_context_io; - var_sig_states_idmap : StateName.t Desugared.Ast.IdentMap.t; + var_sig_io : Surface.Ast.scope_decl_context_io; + var_sig_states_idmap : StateName.t IdentMap.t; var_sig_states_list : StateName.t list; } @@ -71,15 +73,15 @@ type typedef = (** Implicitly defined output struct *) type context = { - local_var_idmap : Desugared.Ast.expr Var.t Desugared.Ast.IdentMap.t; + local_var_idmap : Ast.expr Var.t IdentMap.t; (** Inside a definition, local variables can be introduced by functions arguments or pattern matching *) - typedefs : typedef Desugared.Ast.IdentMap.t; + typedefs : typedef IdentMap.t; (** Gathers the names of the scopes, structs and enums *) - field_idmap : StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t; + field_idmap : StructFieldName.t StructMap.t IdentMap.t; (** The names of the struct fields. Names of fields can be shared between different structs *) - constructor_idmap : EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t; + constructor_idmap : EnumConstructor.t EnumMap.t IdentMap.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) scopes : scope_context ScopeMap.t; (** For each scope, its context *) @@ -112,7 +114,8 @@ let get_var_typ (ctxt : context) (uid : ScopeVar.t) : typ = let is_var_cond (ctxt : context) (uid : ScopeVar.t) : bool = (ScopeVarMap.find uid ctxt.var_typs).var_sig_is_condition -let get_var_io (ctxt : context) (uid : ScopeVar.t) : Ast.scope_decl_context_io = +let get_var_io (ctxt : context) (uid : ScopeVar.t) : + Surface.Ast.scope_decl_context_io = (ScopeVarMap.find uid ctxt.var_typs).var_sig_io (** Get the variable uid inside the scope given in argument *) @@ -121,7 +124,7 @@ let get_var_uid (ctxt : context) ((x, pos) : ident Marked.pos) : ScopeVar.t = let scope = ScopeMap.find scope_uid ctxt.scopes in - match Desugared.Ast.IdentMap.find_opt x scope.var_idmap with + match IdentMap.find_opt x scope.var_idmap with | Some (ScopeVar uid) -> uid | _ -> raise_unknown_identifier @@ -134,7 +137,7 @@ let get_subscope_uid (ctxt : context) ((y, pos) : ident Marked.pos) : SubScopeName.t = let scope = ScopeMap.find scope_uid ctxt.scopes in - match Desugared.Ast.IdentMap.find_opt y scope.var_idmap with + match IdentMap.find_opt y scope.var_idmap with | Some (SubScope (sub_uid, _sub_id)) -> sub_uid | _ -> raise_unknown_identifier "for a subscope of this scope" (y, pos) @@ -143,7 +146,7 @@ let get_subscope_uid let is_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) (y : ident) : bool = let scope = ScopeMap.find scope_uid ctxt.scopes in - match Desugared.Ast.IdentMap.find_opt y scope.var_idmap with + match IdentMap.find_opt y scope.var_idmap with | Some (SubScope _) -> true | _ -> false @@ -151,31 +154,31 @@ let is_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) (y : ident) : let belongs_to (ctxt : context) (uid : ScopeVar.t) (scope_uid : ScopeName.t) : bool = let scope = ScopeMap.find scope_uid ctxt.scopes in - Desugared.Ast.IdentMap.exists + IdentMap.exists (fun _ -> function | ScopeVar var_uid -> ScopeVar.equal uid var_uid | _ -> false) scope.var_idmap (** Retrieves the type of a scope definition from the context *) -let get_def_typ (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : typ = +let get_def_typ (ctxt : context) (def : Ast.ScopeDef.t) : typ = match def with - | Desugared.Ast.ScopeDef.SubScopeVar (_, x, _) + | Ast.ScopeDef.SubScopeVar (_, x, _) (* we don't need to look at the subscope prefix because [x] is already the uid referring back to the original subscope *) - | Desugared.Ast.ScopeDef.Var (x, _) -> + | Ast.ScopeDef.Var (x, _) -> get_var_typ ctxt x -let is_def_cond (ctxt : context) (def : Desugared.Ast.ScopeDef.t) : bool = +let is_def_cond (ctxt : context) (def : Ast.ScopeDef.t) : bool = match def with - | Desugared.Ast.ScopeDef.SubScopeVar (_, x, _) + | Ast.ScopeDef.SubScopeVar (_, x, _) (* we don't need to look at the subscope prefix because [x] is already the uid referring back to the original subscope *) - | Desugared.Ast.ScopeDef.Var (x, _) -> + | Ast.ScopeDef.Var (x, _) -> is_var_cond ctxt x let get_enum ctxt id = - match Desugared.Ast.IdentMap.find (Marked.unmark id) ctxt.typedefs with + match IdentMap.find (Marked.unmark id) ctxt.typedefs with | TEnum id -> id | TStruct sid -> Errors.raise_multispanned_error @@ -196,7 +199,7 @@ let get_enum ctxt id = (Marked.unmark id) let get_struct ctxt id = - match Desugared.Ast.IdentMap.find (Marked.unmark id) ctxt.typedefs with + match IdentMap.find (Marked.unmark id) ctxt.typedefs with | TStruct id | TScope (_, { out_struct_name = id; _ }) -> id | TEnum eid -> Errors.raise_multispanned_error @@ -210,7 +213,7 @@ let get_struct ctxt id = (Marked.unmark id) let get_scope ctxt id = - match Desugared.Ast.IdentMap.find (Marked.unmark id) ctxt.typedefs with + match IdentMap.find (Marked.unmark id) ctxt.typedefs with | TScope (id, _) -> id | TEnum eid -> Errors.raise_multispanned_error @@ -236,11 +239,11 @@ let get_scope ctxt id = let process_subscope_decl (scope : ScopeName.t) (ctxt : context) - (decl : Ast.scope_decl_context_scope) : context = + (decl : Surface.Ast.scope_decl_context_scope) : context = let name, name_pos = decl.scope_decl_context_scope_name in let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in let scope_ctxt = ScopeMap.find scope ctxt.scopes in - match Desugared.Ast.IdentMap.find_opt subscope scope_ctxt.var_idmap with + match IdentMap.find_opt subscope scope_ctxt.var_idmap with | Some use -> let info = match use with @@ -261,7 +264,7 @@ let process_subscope_decl { scope_ctxt with var_idmap = - Desugared.Ast.IdentMap.add name + IdentMap.add name (SubScope (sub_scope_uid, original_subscope_uid)) scope_ctxt.var_idmap; sub_scopes = ScopeSet.add original_subscope_uid scope_ctxt.sub_scopes; @@ -269,34 +272,35 @@ let process_subscope_decl in { ctxt with scopes = ScopeMap.add scope scope_ctxt ctxt.scopes } -let is_type_cond ((typ, _) : Ast.typ) = +let is_type_cond ((typ, _) : Surface.Ast.typ) = match typ with - | Ast.Base Ast.Condition - | Ast.Func { arg_typ = _; return_typ = Ast.Condition, _ } -> + | Surface.Ast.Base Surface.Ast.Condition + | Surface.Ast.Func { arg_typ = _; return_typ = Surface.Ast.Condition, _ } -> true | _ -> false (** Process a basic type (all types except function types) *) let rec process_base_typ (ctxt : context) - ((typ, typ_pos) : Ast.base_typ Marked.pos) : typ = + ((typ, typ_pos) : Surface.Ast.base_typ Marked.pos) : typ = match typ with - | Ast.Condition -> TLit TBool, typ_pos - | Ast.Data (Ast.Collection t) -> + | Surface.Ast.Condition -> TLit TBool, typ_pos + | Surface.Ast.Data (Surface.Ast.Collection t) -> ( TArray - (process_base_typ ctxt (Ast.Data (Marked.unmark t), Marked.get_mark t)), + (process_base_typ ctxt + (Surface.Ast.Data (Marked.unmark t), Marked.get_mark t)), typ_pos ) - | Ast.Data (Ast.Primitive prim) -> ( + | Surface.Ast.Data (Surface.Ast.Primitive prim) -> ( match prim with - | Ast.Integer -> TLit TInt, typ_pos - | Ast.Decimal -> TLit TRat, typ_pos - | Ast.Money -> TLit TMoney, typ_pos - | Ast.Duration -> TLit TDuration, typ_pos - | Ast.Date -> TLit TDate, typ_pos - | Ast.Boolean -> TLit TBool, typ_pos - | Ast.Text -> raise_unsupported_feature "text type" typ_pos - | Ast.Named ident -> ( - match Desugared.Ast.IdentMap.find_opt ident ctxt.typedefs with + | Surface.Ast.Integer -> TLit TInt, typ_pos + | Surface.Ast.Decimal -> TLit TRat, typ_pos + | Surface.Ast.Money -> TLit TMoney, typ_pos + | Surface.Ast.Duration -> TLit TDuration, typ_pos + | Surface.Ast.Date -> TLit TDate, typ_pos + | Surface.Ast.Boolean -> TLit TBool, typ_pos + | Surface.Ast.Text -> raise_unsupported_feature "text type" typ_pos + | Surface.Ast.Named ident -> ( + match IdentMap.find_opt ident ctxt.typedefs with | Some (TStruct s_uid) -> TStruct s_uid, typ_pos | Some (TEnum e_uid) -> TEnum e_uid, typ_pos | Some (TScope (_, scope_str)) -> @@ -308,10 +312,11 @@ let rec process_base_typ ident)) (** Process a type (function or not) *) -let process_type (ctxt : context) ((naked_typ, typ_pos) : Ast.typ) : typ = +let process_type (ctxt : context) ((naked_typ, typ_pos) : Surface.Ast.typ) : typ + = match naked_typ with - | Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos) - | Ast.Func { arg_typ; return_typ } -> + | Surface.Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos) + | Surface.Ast.Func { arg_typ; return_typ } -> ( TArrow (process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ), typ_pos ) @@ -319,13 +324,13 @@ let process_type (ctxt : context) ((naked_typ, typ_pos) : Ast.typ) : typ = let process_data_decl (scope : ScopeName.t) (ctxt : context) - (decl : Ast.scope_decl_context_data) : context = + (decl : Surface.Ast.scope_decl_context_data) : context = (* First check the type of the context data *) let data_typ = process_type ctxt decl.scope_decl_context_item_typ in let is_cond = is_type_cond decl.scope_decl_context_item_typ in let name, pos = decl.scope_decl_context_item_name in let scope_ctxt = ScopeMap.find scope ctxt.scopes in - match Desugared.Ast.IdentMap.find_opt name scope_ctxt.var_idmap with + match IdentMap.find_opt name scope_ctxt.var_idmap with | Some use -> let info = match use with @@ -342,19 +347,16 @@ let process_data_decl let scope_ctxt = { scope_ctxt with - var_idmap = - Desugared.Ast.IdentMap.add name (ScopeVar uid) scope_ctxt.var_idmap; + var_idmap = IdentMap.add name (ScopeVar uid) scope_ctxt.var_idmap; } in let states_idmap, states_list = List.fold_right (fun state_id (states_idmap, states_list) -> let state_uid = StateName.fresh state_id in - ( Desugared.Ast.IdentMap.add (Marked.unmark state_id) state_uid - states_idmap, + ( IdentMap.add (Marked.unmark state_id) state_uid states_idmap, state_uid :: states_list )) - decl.scope_decl_context_item_states - (Desugared.Ast.IdentMap.empty, []) + decl.scope_decl_context_item_states (IdentMap.empty, []) in { ctxt with @@ -372,20 +374,20 @@ let process_data_decl } (** Adds a binding to the context *) -let add_def_local_var (ctxt : context) (name : ident) : - context * Desugared.Ast.expr Var.t = +let add_def_local_var (ctxt : context) (name : ident) : context * Ast.expr Var.t + = let local_var_uid = Var.make name in let ctxt = { ctxt with - local_var_idmap = - Desugared.Ast.IdentMap.add name local_var_uid ctxt.local_var_idmap; + local_var_idmap = IdentMap.add name local_var_uid ctxt.local_var_idmap; } in ctxt, local_var_uid (** Process a struct declaration *) -let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = +let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : + context = let s_uid = get_struct ctxt sdecl.struct_decl_name in if sdecl.struct_decl_fields = [] then Errors.raise_spanned_error @@ -395,13 +397,15 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = (Marked.unmark sdecl.struct_decl_name); List.fold_left (fun ctxt (fdecl, _) -> - let f_uid = StructFieldName.fresh fdecl.Ast.struct_decl_field_name in + let f_uid = + StructFieldName.fresh fdecl.Surface.Ast.struct_decl_field_name + in let ctxt = { ctxt with field_idmap = - Desugared.Ast.IdentMap.update - (Marked.unmark fdecl.Ast.struct_decl_field_name) + IdentMap.update + (Marked.unmark fdecl.Surface.Ast.struct_decl_field_name) (fun uids -> match uids with | None -> Some (StructMap.singleton s_uid f_uid) @@ -418,18 +422,19 @@ let process_struct_decl (ctxt : context) (sdecl : Ast.struct_decl) : context = | None -> Some (StructFieldMap.singleton f_uid - (process_type ctxt fdecl.Ast.struct_decl_field_typ)) + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) | Some fields -> Some (StructFieldMap.add f_uid - (process_type ctxt fdecl.Ast.struct_decl_field_typ) + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) fields)) ctxt.structs; }) ctxt sdecl.struct_decl_fields (** Process an enum declaration *) -let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = +let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context + = let e_uid = get_enum ctxt edecl.enum_decl_name in if List.length edecl.enum_decl_cases = 0 then Errors.raise_spanned_error @@ -439,13 +444,13 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = (Marked.unmark edecl.enum_decl_name); List.fold_left (fun ctxt (cdecl, cdecl_pos) -> - let c_uid = EnumConstructor.fresh cdecl.Ast.enum_decl_case_name in + let c_uid = EnumConstructor.fresh cdecl.Surface.Ast.enum_decl_case_name in let ctxt = { ctxt with constructor_idmap = - Desugared.Ast.IdentMap.update - (Marked.unmark cdecl.Ast.enum_decl_case_name) + IdentMap.update + (Marked.unmark cdecl.Surface.Ast.enum_decl_case_name) (fun uids -> match uids with | None -> Some (EnumMap.singleton e_uid c_uid) @@ -459,7 +464,7 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = EnumMap.update e_uid (fun cases -> let typ = - match cdecl.Ast.enum_decl_case_typ with + match cdecl.Surface.Ast.enum_decl_case_typ with | None -> TLit TUnit, cdecl_pos | Some typ -> process_type ctxt typ in @@ -474,13 +479,15 @@ let process_enum_decl (ctxt : context) (edecl : Ast.enum_decl) : context = let process_item_decl (scope : ScopeName.t) (ctxt : context) - (decl : Ast.scope_decl_context_item) : context = + (decl : Surface.Ast.scope_decl_context_item) : context = match decl with - | Ast.ContextData data_decl -> process_data_decl scope ctxt data_decl - | Ast.ContextScope sub_decl -> process_subscope_decl scope ctxt sub_decl + | Surface.Ast.ContextData data_decl -> process_data_decl scope ctxt data_decl + | Surface.Ast.ContextScope sub_decl -> + process_subscope_decl scope ctxt sub_decl (** Process a scope declaration *) -let process_scope_decl (ctxt : context) (decl : Ast.scope_decl) : context = +let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : + context = let scope_uid = get_scope ctxt decl.scope_decl_name in let ctxt = List.fold_left @@ -492,7 +499,7 @@ let process_scope_decl (ctxt : context) (decl : Ast.scope_decl) : context = List.fold_right (fun item acc -> match Marked.unmark item with - | Ast.ContextData + | Surface.Ast.ContextData ({ scope_decl_context_item_attribute = { scope_decl_context_io_output = true, _; _ }; @@ -500,8 +507,10 @@ let process_scope_decl (ctxt : context) (decl : Ast.scope_decl) : context = } as data) -> Marked.mark (Marked.get_mark item) { - Ast.struct_decl_field_name = data.scope_decl_context_item_name; - Ast.struct_decl_field_typ = data.scope_decl_context_item_typ; + Surface.Ast.struct_decl_field_name = + data.scope_decl_context_item_name; + Surface.Ast.struct_decl_field_typ = + data.scope_decl_context_item_typ; } :: acc | _ -> acc) @@ -528,22 +537,21 @@ let process_scope_decl (ctxt : context) (decl : Ast.scope_decl) : context = let out_struct_fields = let sco = ScopeMap.find scope_uid ctxt.scopes in let str = get_struct ctxt decl.scope_decl_name in - Desugared.Ast.IdentMap.fold + IdentMap.fold (fun id var svmap -> match var with | SubScope _ -> svmap | ScopeVar v -> ( try let field = - StructMap.find str - (Desugared.Ast.IdentMap.find id ctxt.field_idmap) + StructMap.find str (IdentMap.find id ctxt.field_idmap) in ScopeVarMap.add v field svmap with Not_found -> svmap)) sco.var_idmap ScopeVarMap.empty in let typedefs = - Desugared.Ast.IdentMap.update + IdentMap.update (Marked.unmark decl.scope_decl_name) (function | Some (TScope (scope, { out_struct_name; _ })) -> @@ -559,8 +567,8 @@ let typedef_info = function | TScope (s, _) -> ScopeName.get_info s (** Process the names of all declaration items *) -let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : - context = +let process_name_item (ctxt : context) (item : Surface.Ast.code_item Marked.pos) + : context = let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg = Errors.raise_multispanned_error [ @@ -578,13 +586,13 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : Option.iter (fun use -> raise_already_defined_error (typedef_info use) name pos "scope") - (Desugared.Ast.IdentMap.find_opt name ctxt.typedefs); + (IdentMap.find_opt name ctxt.typedefs); let scope_uid = ScopeName.fresh (name, pos) in let out_struct_uid = StructName.fresh (name, pos) in { ctxt with typedefs = - Desugared.Ast.IdentMap.add name + IdentMap.add name (TScope ( scope_uid, { @@ -595,8 +603,8 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : scopes = ScopeMap.add scope_uid { - var_idmap = Desugared.Ast.IdentMap.empty; - scope_defs_contexts = Desugared.Ast.ScopeDefMap.empty; + var_idmap = IdentMap.empty; + scope_defs_contexts = Ast.ScopeDefMap.empty; sub_scopes = ScopeSet.empty; } ctxt.scopes; @@ -606,12 +614,12 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : Option.iter (fun use -> raise_already_defined_error (typedef_info use) name pos "struct") - (Desugared.Ast.IdentMap.find_opt name ctxt.typedefs); + (IdentMap.find_opt name ctxt.typedefs); let s_uid = StructName.fresh sdecl.struct_decl_name in { ctxt with typedefs = - Desugared.Ast.IdentMap.add + IdentMap.add (Marked.unmark sdecl.struct_decl_name) (TStruct s_uid) ctxt.typedefs; } @@ -620,20 +628,20 @@ let process_name_item (ctxt : context) (item : Ast.code_item Marked.pos) : Option.iter (fun use -> raise_already_defined_error (typedef_info use) name pos "enum") - (Desugared.Ast.IdentMap.find_opt name ctxt.typedefs); + (IdentMap.find_opt name ctxt.typedefs); let e_uid = EnumName.fresh edecl.enum_decl_name in { ctxt with typedefs = - Desugared.Ast.IdentMap.add + IdentMap.add (Marked.unmark edecl.enum_decl_name) (TEnum e_uid) ctxt.typedefs; } | ScopeUse _ -> ctxt (** Process a code item that is a declaration *) -let process_decl_item (ctxt : context) (item : Ast.code_item Marked.pos) : - context = +let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Marked.pos) + : context = match Marked.unmark item with | ScopeDecl decl -> process_scope_decl ctxt decl | StructDecl sdecl -> process_struct_decl ctxt sdecl @@ -643,44 +651,46 @@ let process_decl_item (ctxt : context) (item : Ast.code_item Marked.pos) : (** Process a code block *) let process_code_block (ctxt : context) - (block : Ast.code_block) - (process_item : context -> Ast.code_item Marked.pos -> context) : context = + (block : Surface.Ast.code_block) + (process_item : context -> Surface.Ast.code_item Marked.pos -> context) : + context = List.fold_left (fun ctxt decl -> process_item ctxt decl) ctxt block (** Process a law structure, only considering the code blocks *) let rec process_law_structure (ctxt : context) - (s : Ast.law_structure) - (process_item : context -> Ast.code_item Marked.pos -> context) : context = + (s : Surface.Ast.law_structure) + (process_item : context -> Surface.Ast.code_item Marked.pos -> context) : + context = match s with - | Ast.LawHeading (_, children) -> + | Surface.Ast.LawHeading (_, children) -> List.fold_left (fun ctxt child -> process_law_structure ctxt child process_item) ctxt children - | Ast.CodeBlock (block, _, _) -> process_code_block ctxt block process_item - | Ast.LawInclude _ | Ast.LawText _ -> ctxt + | Surface.Ast.CodeBlock (block, _, _) -> + process_code_block ctxt block process_item + | Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt (** {1 Scope uses pass} *) let get_def_key - (name : Ast.qident) - (state : Ast.ident Marked.pos option) + (name : Surface.Ast.qident) + (state : Surface.Ast.ident Marked.pos option) (scope_uid : ScopeName.t) (ctxt : context) - (pos : Pos.t) : Desugared.Ast.ScopeDef.t = + (pos : Pos.t) : Ast.ScopeDef.t = let scope_ctxt = ScopeMap.find scope_uid ctxt.scopes in match name with | [x] -> let x_uid = get_var_uid scope_uid ctxt x in let var_sig = ScopeVarMap.find x_uid ctxt.var_typs in - Desugared.Ast.ScopeDef.Var + Ast.ScopeDef.Var ( x_uid, match state with | Some state -> ( try Some - (Desugared.Ast.IdentMap.find (Marked.unmark state) - var_sig.var_sig_states_idmap) + (IdentMap.find (Marked.unmark state) var_sig.var_sig_states_idmap) with Not_found -> Errors.raise_multispanned_error [ @@ -691,8 +701,7 @@ let get_def_key "This identifier is not a state declared for variable %a." ScopeVar.format_t x_uid) | None -> - if not (Desugared.Ast.IdentMap.is_empty var_sig.var_sig_states_idmap) - then + if not (IdentMap.is_empty var_sig.var_sig_states_idmap) then Errors.raise_multispanned_error [ None, Marked.get_mark x; @@ -705,9 +714,7 @@ let get_def_key else None ) | [y; x] -> let (subscope_uid, subscope_real_uid) : SubScopeName.t * ScopeName.t = - match - Desugared.Ast.IdentMap.find_opt (Marked.unmark y) scope_ctxt.var_idmap - with + match IdentMap.find_opt (Marked.unmark y) scope_ctxt.var_idmap with | Some (SubScope (v, u)) -> v, u | Some _ -> Errors.raise_spanned_error pos @@ -718,7 +725,7 @@ let get_def_key Print.lit_style (Marked.unmark y) in let x_uid = get_var_uid subscope_real_uid ctxt x in - Desugared.Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid, pos) + Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid, pos) | _ -> Errors.raise_spanned_error pos "This line is defining a quantity that is neither a scope variable nor a \ @@ -728,7 +735,7 @@ let get_def_key let process_definition (ctxt : context) (s_name : ScopeName.t) - (d : Ast.definition) : context = + (d : Surface.Ast.definition) : context = (* We update the definition context inside the big context *) { ctxt with @@ -748,7 +755,7 @@ let process_definition { s_ctxt with scope_defs_contexts = - Desugared.Ast.ScopeDefMap.update def_key + Ast.ScopeDefMap.update def_key (fun def_key_ctx -> let def_key_ctx : scope_def_context = Option.fold @@ -757,7 +764,7 @@ let process_definition (* Here, this is the first time we encounter a definition for this definition key *) default_exception_rulename = None; - label_idmap = Desugared.Ast.IdentMap.empty; + label_idmap = IdentMap.empty; } ~some:(fun x -> x) def_key_ctx @@ -765,16 +772,15 @@ let process_definition (* First, we update the def key context with information about the definition's label*) let def_key_ctx = - match d.Ast.definition_label with + match d.Surface.Ast.definition_label with | None -> def_key_ctx | Some label -> let new_label_idmap = - Desugared.Ast.IdentMap.update (Marked.unmark label) + IdentMap.update (Marked.unmark label) (fun existing_label -> match existing_label with | Some existing_label -> Some existing_label - | None -> - Some (Desugared.Ast.LabelName.fresh label)) + | None -> Some (LabelName.fresh label)) def_key_ctx.label_idmap in { def_key_ctx with label_idmap = new_label_idmap } @@ -782,7 +788,7 @@ let process_definition (* And second, we update the map of default rulenames for unlabeled exceptions *) let def_key_ctx = - match d.Ast.definition_exception_to with + match d.Surface.Ast.definition_exception_to with (* If this definition is an exception, it cannot be a default definition *) | UnlabeledException | ExceptionToLabel _ -> def_key_ctx @@ -806,7 +812,7 @@ let process_definition } (* No definition has been set yet for this key *) | None -> ( - match d.Ast.definition_label with + match d.Surface.Ast.definition_label with (* This default definition has a label. This is not allowed for unlabeled exceptions *) | Some _ -> @@ -838,31 +844,34 @@ let process_definition let process_scope_use_item (s_name : ScopeName.t) (ctxt : context) - (sitem : Ast.scope_use_item Marked.pos) : context = + (sitem : Surface.Ast.scope_use_item Marked.pos) : context = match Marked.unmark sitem with - | Rule r -> process_definition ctxt s_name (Ast.rule_to_def r) + | Rule r -> process_definition ctxt s_name (Surface.Ast.rule_to_def r) | Definition d -> process_definition ctxt s_name d | _ -> ctxt -let process_scope_use (ctxt : context) (suse : Ast.scope_use) : context = +let process_scope_use (ctxt : context) (suse : Surface.Ast.scope_use) : context + = let s_name = match - Desugared.Ast.IdentMap.find_opt - (Marked.unmark suse.Ast.scope_use_name) + IdentMap.find_opt + (Marked.unmark suse.Surface.Ast.scope_use_name) ctxt.typedefs with | Some (TScope (sn, _)) -> sn | _ -> Errors.raise_spanned_error - (Marked.get_mark suse.Ast.scope_use_name) + (Marked.get_mark suse.Surface.Ast.scope_use_name) "\"%a\": this scope has not been declared anywhere, is it a typo?" (Utils.Cli.format_with_style [ANSITerminal.yellow]) - (Marked.unmark suse.Ast.scope_use_name) + (Marked.unmark suse.Surface.Ast.scope_use_name) in - List.fold_left (process_scope_use_item s_name) ctxt suse.Ast.scope_use_items + List.fold_left + (process_scope_use_item s_name) + ctxt suse.Surface.Ast.scope_use_items -let process_use_item (ctxt : context) (item : Ast.code_item Marked.pos) : - context = +let process_use_item (ctxt : context) (item : Surface.Ast.code_item Marked.pos) + : context = match Marked.unmark item with | ScopeDecl _ | StructDecl _ | EnumDecl _ -> ctxt | ScopeUse suse -> process_scope_use ctxt suse @@ -870,17 +879,17 @@ let process_use_item (ctxt : context) (item : Ast.code_item Marked.pos) : (** {1 API} *) (** Derive the context from metadata, in one pass over the declarations *) -let form_context (prgm : Ast.program) : context = +let form_context (prgm : Surface.Ast.program) : context = let empty_ctxt = { - local_var_idmap = Desugared.Ast.IdentMap.empty; - typedefs = Desugared.Ast.IdentMap.empty; + local_var_idmap = IdentMap.empty; + typedefs = IdentMap.empty; scopes = ScopeMap.empty; var_typs = ScopeVarMap.empty; structs = StructMap.empty; - field_idmap = Desugared.Ast.IdentMap.empty; + field_idmap = IdentMap.empty; enums = EnumMap.empty; - constructor_idmap = Desugared.Ast.IdentMap.empty; + constructor_idmap = IdentMap.empty; } in let ctxt = diff --git a/compiler/surface/name_resolution.mli b/compiler/desugared/name_resolution.mli similarity index 82% rename from compiler/surface/name_resolution.mli rename to compiler/desugared/name_resolution.mli index 067c729a..d9f2e8c5 100644 --- a/compiler/surface/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -25,13 +25,15 @@ open Shared_ast type ident = string +module IdentMap : Map.S with type key = String.t + type unique_rulename = | Ambiguous of Pos.t list - | Unique of Desugared.Ast.RuleName.t Marked.pos + | Unique of RuleName.t Marked.pos type scope_def_context = { default_exception_rulename : unique_rulename option; - label_idmap : Desugared.Ast.LabelName.t Desugared.Ast.IdentMap.t; + label_idmap : LabelName.t IdentMap.t; } type scope_var_or_subscope = @@ -39,9 +41,9 @@ type scope_var_or_subscope = | SubScope of SubScopeName.t * ScopeName.t type scope_context = { - var_idmap : scope_var_or_subscope Desugared.Ast.IdentMap.t; + var_idmap : scope_var_or_subscope IdentMap.t; (** All variables, including scope variables and subscopes *) - scope_defs_contexts : scope_def_context Desugared.Ast.ScopeDefMap.t; + scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t; (** What is the default rule to refer to for unnamed exceptions, if any *) sub_scopes : ScopeSet.t; (** Other scopes referred to by this scope. Used for dependency analysis *) @@ -57,8 +59,8 @@ type enum_context = typ EnumConstructorMap.t type var_sig = { var_sig_typ : typ; var_sig_is_condition : bool; - var_sig_io : Ast.scope_decl_context_io; - var_sig_states_idmap : StateName.t Desugared.Ast.IdentMap.t; + var_sig_io : Surface.Ast.scope_decl_context_io; + var_sig_states_idmap : StateName.t IdentMap.t; var_sig_states_list : StateName.t list; } @@ -71,15 +73,15 @@ type typedef = (** Implicitly defined output struct *) type context = { - local_var_idmap : Desugared.Ast.expr Var.t Desugared.Ast.IdentMap.t; + local_var_idmap : Ast.expr Var.t IdentMap.t; (** Inside a definition, local variables can be introduced by functions arguments or pattern matching *) - typedefs : typedef Desugared.Ast.IdentMap.t; + typedefs : typedef IdentMap.t; (** Gathers the names of the scopes, structs and enums *) - field_idmap : StructFieldName.t StructMap.t Desugared.Ast.IdentMap.t; + field_idmap : StructFieldName.t StructMap.t IdentMap.t; (** The names of the struct fields. Names of fields can be shared between different structs *) - constructor_idmap : EnumConstructor.t EnumMap.t Desugared.Ast.IdentMap.t; + constructor_idmap : EnumConstructor.t EnumMap.t IdentMap.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) scopes : scope_context ScopeMap.t; (** For each scope, its context *) @@ -104,7 +106,7 @@ val get_var_typ : context -> ScopeVar.t -> typ (** Gets the type associated to an uid *) val is_var_cond : context -> ScopeVar.t -> bool -val get_var_io : context -> ScopeVar.t -> Ast.scope_decl_context_io +val get_var_io : context -> ScopeVar.t -> Surface.Ast.scope_decl_context_io val get_var_uid : ScopeName.t -> context -> ident Marked.pos -> ScopeVar.t (** Get the variable uid inside the scope given in argument *) @@ -120,22 +122,22 @@ val is_subscope_uid : ScopeName.t -> context -> ident -> bool val belongs_to : context -> ScopeVar.t -> ScopeName.t -> bool (** Checks if the var_uid belongs to the scope scope_uid *) -val get_def_typ : context -> Desugared.Ast.ScopeDef.t -> typ +val get_def_typ : context -> Ast.ScopeDef.t -> typ (** Retrieves the type of a scope definition from the context *) -val is_def_cond : context -> Desugared.Ast.ScopeDef.t -> bool -val is_type_cond : Ast.typ -> bool +val is_def_cond : context -> Ast.ScopeDef.t -> bool +val is_type_cond : Surface.Ast.typ -> bool -val add_def_local_var : context -> ident -> context * Desugared.Ast.expr Var.t +val add_def_local_var : context -> ident -> context * Ast.expr Var.t (** Adds a binding to the context *) val get_def_key : - Ast.qident -> - Ast.ident Marked.pos option -> + Surface.Ast.qident -> + Surface.Ast.ident Marked.pos option -> ScopeName.t -> context -> Pos.t -> - Desugared.Ast.ScopeDef.t + Ast.ScopeDef.t (** Usage: [get_def_key var_name var_state scope_uid ctxt pos]*) val get_enum : context -> ident Marked.pos -> EnumName.t @@ -152,5 +154,5 @@ val get_scope : context -> ident Marked.pos -> ScopeName.t (** {1 API} *) -val form_context : Ast.program -> context +val form_context : Surface.Ast.program -> context (** Derive the context from metadata, in one pass over the declarations *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 8e2b7df8..96eaa7e0 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -143,7 +143,7 @@ let driver source_file (options : Cli.options) : int = | ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Dcalc | `Scopelang | `Proof | `Plugin _ ) as backend -> ( Cli.debug_print "Name resolution..."; - let ctxt = Surface.Name_resolution.form_context prgm in + let ctxt = Desugared.Name_resolution.form_context prgm in let scope_uid = match options.ex_scope, backend with | None, `Interpret -> @@ -151,27 +151,29 @@ let driver source_file (options : Cli.options) : int = | None, _ -> let _, scope = try - Desugared.Ast.IdentMap.filter_map + Desugared.Name_resolution.IdentMap.filter_map (fun _ -> function - | Surface.Name_resolution.TScope (uid, _) -> Some uid + | Desugared.Name_resolution.TScope (uid, _) -> Some uid | _ -> None) ctxt.typedefs - |> Desugared.Ast.IdentMap.choose + |> Desugared.Name_resolution.IdentMap.choose with Not_found -> Errors.raise_error "There isn't any scope inside the program." in scope | Some name, _ -> ( - match Desugared.Ast.IdentMap.find_opt name ctxt.typedefs with - | Some (Surface.Name_resolution.TScope (uid, _)) -> uid + match + Desugared.Name_resolution.IdentMap.find_opt name ctxt.typedefs + with + | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid | _ -> Errors.raise_error "There is no scope \"%s\" inside the program." name) in Cli.debug_print "Desugaring..."; - let prgm = Surface.Desugaring.desugar_program ctxt prgm in + let prgm = Desugared.From_surface.translate_program ctxt prgm in Cli.debug_print "Collecting rules..."; - let prgm = Desugared.Desugared_to_scope.translate_program prgm in + let prgm = Scopelang.From_desugared.translate_program prgm in match backend with | `Scopelang -> let _output_file, with_output = get_output_format () in @@ -194,7 +196,7 @@ let driver source_file (options : Cli.options) : int = in let prgm = Scopelang.Ast.type_program prgm in Cli.debug_print "Translating to default calculus..."; - let prgm = Scopelang.Scope_to_dcalc.translate_program prgm in + let prgm = Dcalc.From_scopelang.translate_program prgm in let prgm = if options.optimize then begin Cli.debug_print "Optimizing default calculus..."; diff --git a/compiler/lcalc/from_dcalc.ml b/compiler/lcalc/from_dcalc.ml new file mode 100644 index 00000000..ec7601c3 --- /dev/null +++ b/compiler/lcalc/from_dcalc.ml @@ -0,0 +1,21 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020 Inria, contributor: + Denis Merigoux + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +let translate_program_with_exceptions = + Compile_with_exceptions.translate_program + +let translate_program_without_exceptions = + Compile_without_exceptions.translate_program diff --git a/compiler/lcalc/from_dcalc.mli b/compiler/lcalc/from_dcalc.mli new file mode 100644 index 00000000..493115f3 --- /dev/null +++ b/compiler/lcalc/from_dcalc.mli @@ -0,0 +1,26 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020 Inria, contributor: + Denis Merigoux + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +val translate_program_with_exceptions : 'm Dcalc.Ast.program -> 'm Ast.program +(** Translation from the default calculus to the lambda calculus. This + translation uses exceptions to handle empty default terms. *) + +val translate_program_without_exceptions : + 'm Dcalc.Ast.program -> 'm Ast.program +(** Translation from the default calculus to the lambda calculus. This + translation uses an option monad to handle empty defaults terms. This + transformation is one piece to permit to compile toward legacy languages + that does not contains exceptions. *) diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 72723c70..d41b2308 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -39,17 +39,14 @@ let rec locations_used (e : 'm expr) : LocationSet.t = (fun e -> LocationSet.union (locations_used e)) e LocationSet.empty -type io_input = NoInput | OnlyInput | Reentrant -type io = { io_output : bool Marked.pos; io_input : io_input Marked.pos } - type 'm rule = - | Definition of location Marked.pos * typ * io * 'm expr + | Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr | Assertion of 'm expr | Call of ScopeName.t * SubScopeName.t * 'm mark type 'm scope_decl = { scope_decl_name : ScopeName.t; - scope_sig : (typ * io) ScopeVarMap.t; + scope_sig : (typ * Desugared.Ast.io) ScopeVarMap.t; scope_decl_rules : 'm rule list; scope_mark : 'm mark; } diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index 2586bedd..3e1c1dff 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -31,35 +31,14 @@ type 'm expr = (scopelang, 'm mark) gexpr val locations_used : 'm expr -> LocationSet.t -(** This type characterizes the three levels of visibility for a given scope - variable with regards to the scope's input and possible redefinitions inside - the scope.. *) -type io_input = - | NoInput - (** For an internal variable defined only in the scope, and does not - appear in the input. *) - | OnlyInput - (** For variables that should not be redefined in the scope, because they - appear in the input. *) - | Reentrant - (** For variables defined in the scope that can also be redefined by the - caller as they appear in the input. *) - -type io = { - io_output : bool Marked.pos; - (** [true] is present in the output of the scope. *) - io_input : io_input Marked.pos; -} -(** Characterization of the input/output status of a scope variable. *) - type 'm rule = - | Definition of location Marked.pos * typ * io * 'm expr + | Definition of location Marked.pos * typ * Desugared.Ast.io * 'm expr | Assertion of 'm expr | Call of ScopeName.t * SubScopeName.t * 'm mark type 'm scope_decl = { scope_decl_name : ScopeName.t; - scope_sig : (typ * io) ScopeVarMap.t; + scope_sig : (typ * Desugared.Ast.io) ScopeVarMap.t; scope_decl_rules : 'm rule list; scope_mark : 'm mark; } diff --git a/compiler/scopelang/dune b/compiler/scopelang/dune index ab63006c..43607aab 100644 --- a/compiler/scopelang/dune +++ b/compiler/scopelang/dune @@ -1,7 +1,7 @@ (library (name scopelang) (public_name catala.scopelang) - (libraries utils dcalc ocamlgraph) + (libraries utils ocamlgraph desugared) (flags (:standard -short-paths))) diff --git a/compiler/desugared/desugared_to_scope.ml b/compiler/scopelang/from_desugared.ml similarity index 77% rename from compiler/desugared/desugared_to_scope.ml rename to compiler/scopelang/from_desugared.ml index cb3e2400..87f6c8c0 100644 --- a/compiler/desugared/desugared_to_scope.ml +++ b/compiler/scopelang/from_desugared.ml @@ -27,20 +27,19 @@ type target_scope_vars = type ctx = { scope_var_mapping : target_scope_vars ScopeVarMap.t; - var_mapping : (Ast.expr, untyped Scopelang.Ast.expr Var.t) Var.Map.t; + var_mapping : (Desugared.Ast.expr, untyped Ast.expr Var.t) Var.Map.t; } let tag_with_log_entry - (e : untyped Scopelang.Ast.expr boxed) + (e : untyped Ast.expr boxed) (l : log_entry) - (markings : Utils.Uid.MarkedString.info list) : - untyped Scopelang.Ast.expr boxed = + (markings : Utils.Uid.MarkedString.info list) : untyped Ast.expr boxed = Expr.eapp (Expr.eop (Unop (Log (l, markings))) (Marked.get_mark e)) [e] (Marked.get_mark e) -let rec translate_expr (ctx : ctx) (e : Ast.expr) : - untyped Scopelang.Ast.expr boxed = +let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : + untyped Ast.expr boxed = let m = Marked.get_mark e in match Marked.unmark e with | ELocation (SubScopeVar (s_name, ss_name, s_var)) -> @@ -130,36 +129,36 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr) : (** Intermediate representation for the exception tree of rules for a particular scope definition. *) type rule_tree = - | Leaf of Ast.rule list + | Leaf of Desugared.Ast.rule list (** Rules defining a base case piecewise. List is non-empty. *) - | Node of rule_tree list * Ast.rule list + | Node of rule_tree list * Desugared.Ast.rule list (** [Node (exceptions, base_case)] is a list of exceptions to a non-empty list of rules defining a base case piecewise. *) (** Transforms a flat list of rules into a tree, taking into account the priorities declared between rules *) -let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) : - rule_tree list = - let exc_graph = Dependency.build_exceptions_graph def def_info in - Dependency.check_for_exception_cycle exc_graph; +let def_map_to_tree + (def_info : Desugared.Ast.ScopeDef.t) + (def : Desugared.Ast.rule RuleMap.t) : rule_tree list = + let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in + Desugared.Dependency.check_for_exception_cycle exc_graph; (* we start by the base cases: they are the vertices which have no successors *) let base_cases = - Dependency.ExceptionsDependencies.fold_vertex + Desugared.Dependency.ExceptionsDependencies.fold_vertex (fun v base_cases -> - if Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 then - v :: base_cases + if + Desugared.Dependency.ExceptionsDependencies.out_degree exc_graph v = 0 + then v :: base_cases else base_cases) exc_graph [] in - let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree = + let rec build_tree (base_cases : RuleSet.t) : rule_tree = let exceptions = - Dependency.ExceptionsDependencies.pred exc_graph base_cases + Desugared.Dependency.ExceptionsDependencies.pred exc_graph base_cases in let base_case_as_rule_list = - List.map - (fun r -> Ast.RuleMap.find r def) - (Ast.RuleSet.elements base_cases) + List.map (fun r -> RuleMap.find r def) (RuleSet.elements base_cases) in match exceptions with | [] -> Leaf base_case_as_rule_list @@ -174,8 +173,8 @@ let rec rule_tree_to_expr ~(toplevel : bool) (ctx : ctx) (def_pos : Pos.t) - (is_func : Ast.expr Var.t option) - (tree : rule_tree) : untyped Scopelang.Ast.expr boxed = + (is_func : Desugared.Ast.expr Var.t option) + (tree : rule_tree) : untyped Ast.expr boxed = let emark = Untyped { pos = def_pos } in let exceptions, base_rules = match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r @@ -183,9 +182,10 @@ let rec rule_tree_to_expr (* because each rule has its own variable parameter and we want to convert the whole rule tree into a function, we need to perform some alpha-renaming of all the expressions *) - let substitute_parameter (e : Ast.expr boxed) (rule : Ast.rule) : - Ast.expr boxed = - match is_func, rule.Ast.rule_parameter with + let substitute_parameter + (e : Desugared.Ast.expr boxed) + (rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed = + match is_func, rule.Desugared.Ast.rule_parameter with | Some new_param, Some (old_param, _) -> let binder = Bindlib.bind_var old_param (Marked.unmark e) in Marked.mark (Marked.get_mark e) @@ -217,16 +217,16 @@ let rec rule_tree_to_expr in let base_just_list = List.map - (fun rule -> substitute_parameter rule.Ast.rule_just rule) + (fun rule -> substitute_parameter rule.Desugared.Ast.rule_just rule) base_rules in let base_cons_list = List.map - (fun rule -> substitute_parameter rule.Ast.rule_cons rule) + (fun rule -> substitute_parameter rule.Desugared.Ast.rule_cons rule) base_rules in - let translate_and_unbox_list (list : Ast.expr boxed list) : - untyped Scopelang.Ast.expr boxed list = + let translate_and_unbox_list (list : Desugared.Ast.expr boxed list) : + untyped Ast.expr boxed list = List.map (fun e -> (* There are two levels of boxing here, the outermost is introduced by @@ -258,7 +258,7 @@ let rec rule_tree_to_expr (Expr.elit (LBool true) emark) default_containing_base_cases emark in - match is_func, (List.hd base_rules).Ast.rule_parameter with + match is_func, (List.hd base_rules).Desugared.Ast.rule_parameter with | None, None -> default | Some new_param, Some (_, typ) -> if toplevel then @@ -277,22 +277,22 @@ let rec rule_tree_to_expr an {!constructor: Dcalc.EDefault} *) let translate_def (ctx : ctx) - (def_info : Ast.ScopeDef.t) - (def : Ast.rule Ast.RuleMap.t) + (def_info : Desugared.Ast.ScopeDef.t) + (def : Desugared.Ast.rule RuleMap.t) (typ : typ) - (io : Scopelang.Ast.io) + (io : Desugared.Ast.io) ~(is_cond : bool) - ~(is_subscope_var : bool) : untyped Scopelang.Ast.expr boxed = + ~(is_subscope_var : bool) : untyped Ast.expr boxed = (* Here, we have to transform this list of rules into a default tree. *) let is_def_func = match Marked.unmark typ with TArrow (_, _) -> true | _ -> false in - let is_rule_func _ (r : Ast.rule) : bool = - Option.is_some r.Ast.rule_parameter + let is_rule_func _ (r : Desugared.Ast.rule) : bool = + Option.is_some r.Desugared.Ast.rule_parameter in - let all_rules_func = Ast.RuleMap.for_all is_rule_func def in + let all_rules_func = RuleMap.for_all is_rule_func def in let all_rules_not_func = - Ast.RuleMap.for_all (fun n r -> not (is_rule_func n r)) def + RuleMap.for_all (fun n r -> not (is_rule_func n r)) def in let is_def_func_param_typ : typ option = if is_def_func && all_rules_func then @@ -302,20 +302,21 @@ let translate_def Errors.raise_spanned_error (Marked.get_mark typ) "The definitions of %a are function but it doesn't have a function \ type" - Ast.ScopeDef.format_t def_info + Desugared.Ast.ScopeDef.format_t def_info else if (not is_def_func) && all_rules_not_func then None else let spans = List.map (fun (_, r) -> - Some "This definition is a function:", Expr.pos r.Ast.rule_cons) - (Ast.RuleMap.bindings (Ast.RuleMap.filter is_rule_func def)) + ( Some "This definition is a function:", + Expr.pos r.Desugared.Ast.rule_cons )) + (RuleMap.bindings (RuleMap.filter is_rule_func def)) @ List.map (fun (_, r) -> ( Some "This definition is not a function:", - Expr.pos r.Ast.rule_cons )) - (Ast.RuleMap.bindings - (Ast.RuleMap.filter (fun n r -> not (is_rule_func n r)) def)) + Expr.pos r.Desugared.Ast.rule_cons )) + (RuleMap.bindings + (RuleMap.filter (fun n r -> not (is_rule_func n r)) def)) in Errors.raise_multispanned_error spans "some definitions of the same variable are functions while others \ @@ -323,7 +324,7 @@ let translate_def in let top_list = def_map_to_tree def_info def in let is_input = - match Marked.unmark io.Scopelang.Ast.io_input with + match Marked.unmark io.Desugared.Ast.io_input with | OnlyInput -> true | _ -> false in @@ -333,13 +334,13 @@ let translate_def where the condition is declared. Except when the variable is an input, where we want the [false] to be added at each caller parent scope. *) Some - (Ast.always_false_rule - (Ast.ScopeDef.get_position def_info) + (Desugared.Ast.always_false_rule + (Desugared.Ast.ScopeDef.get_position def_info) is_def_func_param_typ) else None in if - Ast.RuleMap.cardinal def = 0 + RuleMap.cardinal def = 0 && is_subscope_var (* Here we have a special case for the empty definitions. Indeed, we could use the code for the regular case below that would create a convoluted @@ -364,16 +365,18 @@ let translate_def will not be provided by the calee scope, it has to be placed in the caller. *) then - Expr.elit LEmptyError (Untyped { pos = Ast.ScopeDef.get_position def_info }) + Expr.elit LEmptyError + (Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info }) else rule_tree_to_expr ~toplevel:true ctx - (Ast.ScopeDef.get_position def_info) + (Desugared.Ast.ScopeDef.get_position def_info) (Option.map (fun _ -> Var.make "param") is_def_func_param_typ) (match top_list, top_value with | [], None -> (* In this case, there are no rules to define the expression and no default value so we put an empty rule. *) - Leaf [Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ] + Leaf + [Desugared.Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ] | [], Some top_value -> (* In this case, there are no rules to define the expression but a default value so we put it. *) @@ -386,32 +389,39 @@ let translate_def | _, None -> Node ( top_list, - [Ast.empty_rule (Marked.get_mark typ) is_def_func_param_typ] )) + [ + Desugared.Ast.empty_rule (Marked.get_mark typ) + is_def_func_param_typ; + ] )) (** Translates a scope *) -let translate_scope (ctx : ctx) (scope : Ast.scope) : - untyped Scopelang.Ast.scope_decl = - let scope_dependencies = Dependency.build_scope_dependencies scope in - Dependency.check_for_cycle scope scope_dependencies; +let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : + untyped Ast.scope_decl = + let scope_dependencies = + Desugared.Dependency.build_scope_dependencies scope + in + Desugared.Dependency.check_for_cycle scope scope_dependencies; let scope_ordering = - Dependency.correct_computation_ordering scope_dependencies + Desugared.Dependency.correct_computation_ordering scope_dependencies in let scope_decl_rules = List.flatten (List.map (fun vertex -> match vertex with - | Dependency.Vertex.Var (var, state) -> ( + | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = - Ast.ScopeDefMap.find - (Ast.ScopeDef.Var (var, state)) + Desugared.Ast.ScopeDefMap.find + (Desugared.Ast.ScopeDef.Var (var, state)) scope.scope_defs in let var_def = scope_def.scope_def_rules in let var_typ = scope_def.scope_def_typ in let is_cond = scope_def.scope_def_is_condition in - match Marked.unmark scope_def.Ast.scope_def_io.io_input with - | OnlyInput when not (Ast.RuleMap.is_empty var_def) -> + match + Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input + with + | OnlyInput when not (RuleMap.is_empty var_def) -> (* If the variable is tagged as input, then it shall not be redefined. *) Errors.raise_multispanned_error @@ -420,8 +430,8 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : :: List.map (fun (rule, _) -> ( Some "Incriminated variable definition:", - Marked.get_mark (Ast.RuleName.get_info rule) )) - (Ast.RuleMap.bindings var_def)) + Marked.get_mark (RuleName.get_info rule) )) + (RuleMap.bindings var_def)) "It is impossible to give a definition to a scope variable \ tagged as input." | OnlyInput -> @@ -430,8 +440,8 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : | _ -> let expr_def = translate_def ctx - (Ast.ScopeDef.Var (var, state)) - var_def var_typ scope_def.Ast.scope_def_io ~is_cond + (Desugared.Ast.ScopeDef.Var (var, state)) + var_def var_typ scope_def.Desugared.Ast.scope_def_io ~is_cond ~is_subscope_var:false in let scope_var = @@ -441,56 +451,61 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : | _ -> failwith "should not happen" in [ - Scopelang.Ast.Definition + Ast.Definition ( ( ScopelangScopeVar ( scope_var, Marked.get_mark (ScopeVar.get_info scope_var) ), Marked.get_mark (ScopeVar.get_info scope_var) ), var_typ, - scope_def.Ast.scope_def_io, + scope_def.Desugared.Ast.scope_def_io, Expr.unbox expr_def ); ]) - | Dependency.Vertex.SubScope sub_scope_index -> + | Desugared.Dependency.Vertex.SubScope sub_scope_index -> (* Before calling the sub_scope, we need to include all the re-definitions of subscope parameters*) let sub_scope = SubScopeMap.find sub_scope_index scope.scope_sub_scopes in let sub_scope_vars_redefs_candidates = - Ast.ScopeDefMap.filter + Desugared.Ast.ScopeDefMap.filter (fun def_key scope_def -> match def_key with - | Ast.ScopeDef.Var _ -> false - | Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) -> + | Desugared.Ast.ScopeDef.Var _ -> false + | Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) + -> sub_scope_index = sub_scope_index' (* We exclude subscope variables that have 0 re-definitions and are not visible in the input of the subscope *) && not ((match - Marked.unmark scope_def.Ast.scope_def_io.io_input + Marked.unmark + scope_def.Desugared.Ast.scope_def_io.io_input with - | Scopelang.Ast.NoInput -> true + | Desugared.Ast.NoInput -> true | _ -> false) - && Ast.RuleMap.is_empty scope_def.scope_def_rules)) + && RuleMap.is_empty scope_def.scope_def_rules)) scope.scope_defs in let sub_scope_vars_redefs = - Ast.ScopeDefMap.mapi + Desugared.Ast.ScopeDefMap.mapi (fun def_key scope_def -> - let def = scope_def.Ast.scope_def_rules in + let def = scope_def.Desugared.Ast.scope_def_rules in let def_typ = scope_def.scope_def_typ in let is_cond = scope_def.scope_def_is_condition in match def_key with - | Ast.ScopeDef.Var _ -> assert false (* should not happen *) - | Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) -> + | Desugared.Ast.ScopeDef.Var _ -> + assert false (* should not happen *) + | Desugared.Ast.ScopeDef.SubScopeVar + (sscope, sub_scope_var, pos) -> (* This definition redefines a variable of the correct subscope. But we have to check that this redefinition is allowed with respect to the io parameters of that subscope variable. *) (match - Marked.unmark scope_def.Ast.scope_def_io.io_input + Marked.unmark + scope_def.Desugared.Ast.scope_def_io.io_input with - | Scopelang.Ast.NoInput -> + | Desugared.Ast.NoInput -> Errors.raise_multispanned_error (( Some "Incriminated subscope:", Marked.get_mark (SubScopeName.get_info sscope) ) @@ -501,11 +516,11 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : (fun (rule, _) -> ( Some "Incriminated subscope variable definition:", - Marked.get_mark (Ast.RuleName.get_info rule) )) - (Ast.RuleMap.bindings def)) + Marked.get_mark (RuleName.get_info rule) )) + (RuleMap.bindings def)) "It is impossible to give a definition to a subscope \ variable not tagged as input or context." - | OnlyInput when Ast.RuleMap.is_empty def && not is_cond -> + | OnlyInput when RuleMap.is_empty def && not is_cond -> (* If the subscope variable is tagged as input, then it shall be defined. *) Errors.raise_multispanned_error @@ -521,14 +536,16 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : this redefinition to a proper Scopelang term. *) let expr_def = translate_def ctx def_key def def_typ - scope_def.Ast.scope_def_io ~is_cond + scope_def.Desugared.Ast.scope_def_io ~is_cond ~is_subscope_var:true in let subscop_real_name = SubScopeMap.find sub_scope_index scope.scope_sub_scopes in - let var_pos = Ast.ScopeDef.get_position def_key in - Scopelang.Ast.Definition + let var_pos = + Desugared.Ast.ScopeDef.get_position def_key + in + Ast.Definition ( ( SubScopeVar ( subscop_real_name, (sub_scope_index, var_pos), @@ -544,16 +561,17 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : snd (List.hd states), var_pos ), var_pos ), def_typ, - scope_def.Ast.scope_def_io, + scope_def.Desugared.Ast.scope_def_io, Expr.unbox expr_def )) sub_scope_vars_redefs_candidates in let sub_scope_vars_redefs = - List.map snd (Ast.ScopeDefMap.bindings sub_scope_vars_redefs) + List.map snd + (Desugared.Ast.ScopeDefMap.bindings sub_scope_vars_redefs) in sub_scope_vars_redefs @ [ - Scopelang.Ast.Call + Ast.Call ( sub_scope, sub_scope_index, Untyped @@ -573,16 +591,18 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : @ List.map (fun e -> let scope_e = translate_expr ctx (Expr.unbox e) in - Scopelang.Ast.Assertion (Expr.unbox scope_e)) - scope.Ast.scope_assertions + Ast.Assertion (Expr.unbox scope_e)) + scope.Desugared.Ast.scope_assertions in let scope_sig = ScopeVarMap.fold - (fun var (states : Ast.var_or_states) acc -> + (fun var (states : Desugared.Ast.var_or_states) acc -> match states with | WholeVar -> let scope_def = - Ast.ScopeDefMap.find (Ast.ScopeDef.Var (var, None)) scope.scope_defs + Desugared.Ast.ScopeDefMap.find + (Desugared.Ast.ScopeDef.Var (var, None)) + scope.scope_defs in let typ = scope_def.scope_def_typ in ScopeVarMap.add @@ -593,13 +613,13 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : acc | States states -> (* What happens in the case of variables with multiple states is - interesting. We need to create as many Scopelang.Var entries in the - scope signature as there are states. *) + interesting. We need to create as many Var entries in the scope + signature as there are states. *) List.fold_left (fun acc (state : StateName.t) -> let scope_def = - Ast.ScopeDefMap.find - (Ast.ScopeDef.Var (var, Some state)) + Desugared.Ast.ScopeDefMap.find + (Desugared.Ast.ScopeDef.Var (var, Some state)) scope.scope_defs in ScopeVarMap.add @@ -613,27 +633,27 @@ let translate_scope (ctx : ctx) (scope : Ast.scope) : in let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in { - Scopelang.Ast.scope_decl_name = scope.scope_uid; - Scopelang.Ast.scope_decl_rules; - Scopelang.Ast.scope_sig; - Scopelang.Ast.scope_mark = Untyped { pos }; + Ast.scope_decl_name = scope.scope_uid; + Ast.scope_decl_rules; + Ast.scope_sig; + Ast.scope_mark = Untyped { pos }; } (** {1 API} *) -let translate_program (pgrm : Ast.program) : untyped Scopelang.Ast.program = - (* First we give mappings to all the locations between Desugared and - Scopelang. This involves creating a new Scopelang scope variable for every - state of a Desugared variable. *) +let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program = + (* First we give mappings to all the locations between Desugared and This + involves creating a new Scopelang scope variable for every state of a + Desugared variable. *) let ctx = (* Todo: since we rename all scope vars at this point, it would be better to have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) ScopeMap.fold (fun _scope scope_decl ctx -> ScopeVarMap.fold - (fun scope_var (states : Ast.var_or_states) ctx -> + (fun scope_var (states : Desugared.Ast.var_or_states) ctx -> match states with - | Ast.WholeVar -> + | Desugared.Ast.WholeVar -> { ctx with scope_var_mapping = @@ -661,8 +681,8 @@ let translate_program (pgrm : Ast.program) : untyped Scopelang.Ast.program = states)) ctx.scope_var_mapping; }) - scope_decl.Ast.scope_vars ctx) - pgrm.Ast.program_scopes + scope_decl.Desugared.Ast.scope_vars ctx) + pgrm.Desugared.Ast.program_scopes { scope_var_mapping = ScopeVarMap.empty; var_mapping = Var.Map.empty } in let ctx_scopes = @@ -680,10 +700,9 @@ let translate_program (pgrm : Ast.program) : untyped Scopelang.Ast.program = out_str.out_struct_fields ScopeVarMap.empty in { out_str with out_struct_fields }) - pgrm.Ast.program_ctx.ctx_scopes + pgrm.Desugared.Ast.program_ctx.ctx_scopes in { - Scopelang.Ast.program_scopes = - ScopeMap.map (translate_scope ctx) pgrm.program_scopes; + Ast.program_scopes = ScopeMap.map (translate_scope ctx) pgrm.program_scopes; program_ctx = { pgrm.program_ctx with ctx_scopes }; } diff --git a/compiler/desugared/desugared_to_scope.mli b/compiler/scopelang/from_desugared.mli similarity index 91% rename from compiler/desugared/desugared_to_scope.mli rename to compiler/scopelang/from_desugared.mli index b5314e7c..8f2dae8c 100644 --- a/compiler/desugared/desugared_to_scope.mli +++ b/compiler/scopelang/from_desugared.mli @@ -16,4 +16,4 @@ (** Translation from {!module: Desugared.Ast} to {!module: Scopelang.Ast} *) -val translate_program : Ast.program -> Shared_ast.untyped Scopelang.Ast.program +val translate_program : Desugared.Ast.program -> Shared_ast.untyped Ast.program diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index f64d908d..fcc4e6f6 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -56,11 +56,11 @@ let scope ?(debug = false) ctx fmt (name, decl) = Format.fprintf fmt "%a%a%a %a%a%a%a%a" Print.punctuation "(" ScopeVar.format_t scope_var Print.punctuation ":" (Print.typ ctx) typ Print.punctuation "|" Print.keyword - (match Marked.unmark vis.io_input with + (match Marked.unmark vis.Desugared.Ast.io_input with | NoInput -> "internal" | OnlyInput -> "input" | Reentrant -> "context") - (if Marked.unmark vis.io_output then fun fmt () -> + (if Marked.unmark vis.Desugared.Ast.io_output then fun fmt () -> Format.fprintf fmt "%a@,%a" Print.punctuation "|" Print.keyword "output" else fun fmt () -> Format.fprintf fmt "@<0>") diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 0653ef06..8169c627 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -45,6 +45,20 @@ module EnumConstructor : Uid.Id with type info = Uid.MarkedString.info = module EnumMap : Map.S with type key = EnumName.t = Map.Make (EnumName) +(** Only used by surface *) + +module RuleName : Uid.Id with type info = Uid.MarkedString.info = + Uid.Make (Uid.MarkedString) () + +module RuleMap : Map.S with type key = RuleName.t = Map.Make (RuleName) +module RuleSet : Set.S with type elt = RuleName.t = Set.Make (RuleName) + +module LabelName : Uid.Id with type info = Uid.MarkedString.info = + Uid.Make (Uid.MarkedString) () + +module LabelMap : Map.S with type key = LabelName.t = Map.Make (LabelName) +module LabelSet : Set.S with type elt = LabelName.t = Set.Make (LabelName) + (** Only used by desugared/scopelang *) module ScopeVar : Uid.Id with type info = Uid.MarkedString.info = diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index 6ca03fd6..4f2af2e3 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -493,7 +493,7 @@ type rule = { rule_parameter : ident Marked.pos option; rule_condition : expression Marked.pos option; rule_name : qident Marked.pos; - rule_id : Desugared.Ast.RuleName.t; [@opaque] + rule_id : Shared_ast.RuleName.t; [@opaque] rule_consequence : (bool[@opaque]) Marked.pos; rule_state : ident Marked.pos option; } @@ -517,7 +517,7 @@ type definition = { definition_name : qident Marked.pos; definition_parameter : ident Marked.pos option; definition_condition : expression Marked.pos option; - definition_id : Desugared.Ast.RuleName.t; [@opaque] + definition_id : Shared_ast.RuleName.t; [@opaque] definition_expr : expression Marked.pos; definition_state : ident Marked.pos option; } diff --git a/compiler/surface/dune b/compiler/surface/dune index 2184fb8a..d8724a3b 100644 --- a/compiler/surface/dune +++ b/compiler/surface/dune @@ -6,11 +6,10 @@ menhirLib sedlex re - desugared - scopelang zarith zarith_stubs_js - dates_calc) + dates_calc + shared_ast) (preprocess (pps sedlex.ppx visitors.ppx))) diff --git a/compiler/surface/parser.mly b/compiler/surface/parser.mly index 2d2299de..2925c593 100644 --- a/compiler/surface/parser.mly +++ b/compiler/surface/parser.mly @@ -392,7 +392,7 @@ rule: rule_parameter = param_applied; rule_condition = cond; rule_name = name; - rule_id = Desugared.Ast.RuleName.fresh + rule_id = Shared_ast.RuleName.fresh (String.concat "." (List.map (fun i -> Marked.unmark i) (Marked.unmark name)), Pos.from_lpos $sloc); rule_consequence = cons; @@ -429,7 +429,7 @@ definition: definition_parameter = param; definition_condition = cond; definition_id = - Desugared.Ast.RuleName.fresh + Shared_ast.RuleName.fresh (String.concat "." (List.map (fun i -> Marked.unmark i) (Marked.unmark name)), Pos.from_lpos $sloc); definition_expr = e;