From b329afbbdb167c1b75832dd4ea526bafcee3262b Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Mon, 21 Nov 2022 10:12:45 +0100 Subject: [PATCH] Rename all Map/Set calls accordingly This is just a bunch of `sed` calls: ```shell sed -i 's/ScopeSet/ScopeName.Set/g' compiler/**/*.ml* sed -i 's/ScopeMap/ScopeName.Map/g' compiler/**/*.ml* sed -i 's/StructMap/StructName.Map/g' compiler/**/*.ml* sed -i 's/StructSet/StructName.Set/g' compiler/**/*.ml* sed -i 's/EnumMap/EnumName.Map/g' compiler/**/*.ml* sed -i 's/EnumSet/EnumName.Set/g' compiler/**/*.ml* sed -i 's/StructFieldName/StructField/g' compiler/**/*.ml* sed -i 's/StructFieldMap/StructField.Map/g' compiler/**/*.ml* sed -i 's/StructFieldSet/StructField.Set/g' compiler/**/*.ml* sed -i 's/EnumConstructorMap/EnumConstructor.Map/g' compiler/**/*.ml* sed -i 's/EnumConstructorSet/EnumConstructor.Set/g' compiler/**/*.ml* sed -i 's/RuleMap/RuleName.Map/g' compiler/**/*.ml* sed -i 's/RuleSet/RuleName.Set/g' compiler/**/*.ml* sed -i 's/LabelMap/LabelName.Map/g' compiler/**/*.ml* sed -i 's/LabelSet/LabelName.Set/g' compiler/**/*.ml* sed -i 's/ScopeVarMap/ScopeVar.Map/g' compiler/**/*.ml* sed -i 's/ScopeVarSet/ScopeVar.Set/g' compiler/**/*.ml* sed -i 's/SubScopeNameMap/SubScopeName.Map/g' compiler/**/*.ml* sed -i 's/SubScopeNameSet/SubScopeName.Set/g' compiler/**/*.ml* ``` ... and reformat --- compiler/dcalc/from_scopelang.ml | 131 +++++++++--------- compiler/dcalc/interpreter.ml | 20 +-- compiler/dcalc/optimizations.ml | 2 +- compiler/desugared/ast.ml | 15 ++- compiler/desugared/ast.mli | 13 +- compiler/desugared/dependency.ml | 68 +++++----- compiler/desugared/dependency.mli | 4 +- compiler/desugared/from_surface.ml | 134 ++++++++++--------- compiler/desugared/name_resolution.ml | 102 +++++++------- compiler/desugared/name_resolution.mli | 19 +-- compiler/driver.ml | 3 +- compiler/lcalc/ast.ml | 14 +- compiler/lcalc/ast.mli | 2 +- compiler/lcalc/closure_conversion.ml | 6 +- compiler/lcalc/compile_with_exceptions.ml | 4 +- compiler/lcalc/compile_without_exceptions.ml | 19 +-- compiler/lcalc/optimizations.ml | 6 +- compiler/lcalc/to_ocaml.ml | 43 +++--- compiler/lcalc/to_ocaml.mli | 6 +- compiler/plugins/api_web.ml | 24 ++-- compiler/plugins/json_schema.ml | 15 ++- compiler/scalc/ast.ml | 2 +- compiler/scalc/compile_from_lambda.ml | 4 +- compiler/scalc/print.ml | 10 +- compiler/scalc/to_python.ml | 29 ++-- compiler/scopelang/ast.ml | 12 +- compiler/scopelang/ast.mli | 4 +- compiler/scopelang/dependency.ml | 24 ++-- compiler/scopelang/from_desugared.ml | 97 +++++++------- compiler/scopelang/print.ml | 20 +-- compiler/shared_ast/definitions.ml | 24 ++-- compiler/shared_ast/expr.ml | 66 ++++----- compiler/shared_ast/expr.mli | 8 +- compiler/shared_ast/print.ml | 20 +-- compiler/shared_ast/typing.ml | 69 +++++----- compiler/shared_ast/typing.mli | 2 +- compiler/utils/uid.ml | 6 +- compiler/utils/uid.mli | 6 +- compiler/verification/z3backend.real.ml | 54 ++++---- 39 files changed, 560 insertions(+), 547 deletions(-) diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 7ddffc5c..ab904e7a 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -31,25 +31,26 @@ type 'm scope_sig_ctx = { scope_sig_input_struct : StructName.t; (** Scope input *) scope_sig_output_struct : StructName.t; (** Scope output *) scope_sig_in_fields : - (StructFieldName.t * Desugared.Ast.io_input Marked.pos) ScopeVarMap.t; + (StructField.t * Desugared.Ast.io_input Marked.pos) ScopeVar.Map.t; (** Mapping between the input scope variables and the input struct fields. *) - scope_sig_out_fields : StructFieldName.t ScopeVarMap.t; + scope_sig_out_fields : StructField.t ScopeVar.Map.t; (** Mapping between the output scope variables and the output struct fields. TODO: could likely be removed now that we have it in the program ctx *) } -type 'm scope_sigs_ctx = 'm scope_sig_ctx ScopeMap.t +type 'm scope_sigs_ctx = 'm scope_sig_ctx ScopeName.Map.t type 'm ctx = { structs : struct_ctx; enums : enum_ctx; scope_name : ScopeName.t; scopes_parameters : 'm scope_sigs_ctx; - scope_vars : ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVarMap.t; + scope_vars : + ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVar.Map.t; subscope_vars : - ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVarMap.t - SubScopeMap.t; + ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVar.Map.t + SubScopeName.Map.t; local_vars : ('m Scopelang.Ast.expr, 'm Ast.expr Var.t) Var.Map.t; } @@ -63,8 +64,8 @@ let empty_ctx enums = enum_ctx; scope_name; scopes_parameters = scopes_ctx; - scope_vars = ScopeVarMap.empty; - subscope_vars = SubScopeMap.empty; + scope_vars = ScopeVar.Map.empty; + subscope_vars = SubScopeName.Map.empty; local_vars = Var.Map.empty; } @@ -171,7 +172,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : | LDuration _ ) as l) -> Expr.elit l m | EStruct { name; fields } -> - let fields = StructFieldMap.map (translate_expr ctx) fields in + let fields = StructField.Map.map (translate_expr ctx) fields in Expr.estruct name fields m | EStructAccess { e; field; name } -> Expr.estructaccess (translate_expr ctx e) field name m @@ -179,13 +180,13 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let e' = translate_expr ctx e in Expr.einj e' cons name m | EMatch { e = e1; name; cases = e_cases } -> - let enum_sig = EnumMap.find name ctx.enums in + let enum_sig = EnumName.Map.find name ctx.enums in let d_cases, remaining_e_cases = (* FIXME: these checks should probably be moved to a better place *) - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun constructor _ (d_cases, e_cases) -> let case_e = - try EnumConstructorMap.find constructor e_cases + try EnumConstructor.Map.find constructor e_cases with Not_found -> Errors.raise_spanned_error (Expr.pos e) "The constructor %a of enum %a is missing from this pattern \ @@ -193,26 +194,26 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : EnumConstructor.format_t constructor EnumName.format_t name in let case_d = translate_expr ctx case_e in - ( EnumConstructorMap.add constructor case_d d_cases, - EnumConstructorMap.remove constructor e_cases )) + ( EnumConstructor.Map.add constructor case_d d_cases, + EnumConstructor.Map.remove constructor e_cases )) enum_sig - (EnumConstructorMap.empty, e_cases) + (EnumConstructor.Map.empty, e_cases) in - if not (EnumConstructorMap.is_empty remaining_e_cases) then + if not (EnumConstructor.Map.is_empty remaining_e_cases) then Errors.raise_spanned_error (Expr.pos e) "Pattern matching is incomplete for enum %a: missing cases %a" EnumName.format_t name (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") (fun fmt (case_name, _) -> EnumConstructor.format_t fmt case_name)) - (EnumConstructorMap.bindings remaining_e_cases); + (EnumConstructor.Map.bindings remaining_e_cases); let e1 = translate_expr ctx e1 in Expr.ematch e1 name d_cases m | EScopeCall { scope; args } -> let pos = Expr.mark_pos m in - let sc_sig = ScopeMap.find scope ctx.scopes_parameters in + let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in let in_var_map = - ScopeVarMap.merge + ScopeVar.Map.merge (fun var_name str_field expr -> let expr = match str_field, expr with @@ -229,7 +230,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : [ None, pos; ( Some "Declaration of the missing input variable", - Marked.get_mark (StructFieldName.get_info fld) ); + Marked.get_mark (StructField.get_info fld) ); ] "Definition of input variable '%a' missing in this scope call" ScopeVar.format_t var_name @@ -245,9 +246,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : sc_sig.scope_sig_in_fields args in let field_map = - ScopeVarMap.fold - (fun _ (fld, e) acc -> StructFieldMap.add fld e acc) - in_var_map StructFieldMap.empty + ScopeVar.Map.fold + (fun _ (fld, e) acc -> StructField.Map.add fld e acc) + in_var_map StructField.Map.empty in let arg_struct = Expr.estruct sc_sig.scope_sig_input_struct field_map (mark_tany m pos) @@ -278,7 +279,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : -- for more information see https://github.com/CatalaLang/catala/pull/280#discussion_r898851693. *) let retrieve_in_and_out_typ_or_any var vars = - let _, typ, _ = ScopeVarMap.find (Marked.unmark var) vars in + let _, typ, _ = ScopeVar.Map.find (Marked.unmark var) vars in match typ with | TArrow (marked_input_typ, marked_output_typ) -> Marked.unmark marked_input_typ, Marked.unmark marked_output_typ @@ -289,7 +290,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : retrieve_in_and_out_typ_or_any var ctx.scope_vars | ELocation (SubScopeVar (_, sname, var)) -> ctx.subscope_vars - |> SubScopeMap.find (Marked.unmark sname) + |> SubScopeName.Map.find (Marked.unmark sname) |> retrieve_in_and_out_typ_or_any var | _ -> TAny, TAny in @@ -336,13 +337,13 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : (List.map (translate_expr ctx) excepts) (translate_expr ctx just) (translate_expr ctx cons) m | ELocation (ScopelangScopeVar a) -> - let v, _, _ = ScopeVarMap.find (Marked.unmark a) ctx.scope_vars in + let v, _, _ = ScopeVar.Map.find (Marked.unmark a) ctx.scope_vars in Expr.evar v m | ELocation (SubScopeVar (_, s, a)) -> ( try let v, _, _ = - ScopeVarMap.find (Marked.unmark a) - (SubScopeMap.find (Marked.unmark s) ctx.subscope_vars) + ScopeVar.Map.find (Marked.unmark a) + (SubScopeName.Map.find (Marked.unmark s) ctx.subscope_vars) in Expr.evar v m with Not_found -> @@ -416,7 +417,7 @@ let translate_rule { ctx with scope_vars = - ScopeVarMap.add (Marked.unmark a) + ScopeVar.Map.add (Marked.unmark a) (a_var, Marked.unmark tau, a_io) ctx.scope_vars; } ) @@ -461,22 +462,22 @@ let translate_rule { ctx with subscope_vars = - SubScopeMap.update (Marked.unmark subs_index) + SubScopeName.Map.update (Marked.unmark subs_index) (fun map -> match map with | Some map -> Some - (ScopeVarMap.add (Marked.unmark subs_var) + (ScopeVar.Map.add (Marked.unmark subs_var) (a_var, Marked.unmark tau, a_io) map) | None -> Some - (ScopeVarMap.singleton (Marked.unmark subs_var) + (ScopeVar.Map.singleton (Marked.unmark subs_var) (a_var, Marked.unmark tau, a_io))) ctx.subscope_vars; } ) | Call (subname, subindex, m) -> - let subscope_sig = ScopeMap.find subname ctx.scopes_parameters in + let subscope_sig = ScopeName.Map.find subname ctx.scopes_parameters in let all_subscope_vars = subscope_sig.scope_sig_local_vars in let all_subscope_input_vars = List.filter @@ -496,11 +497,11 @@ let translate_rule let called_scope_input_struct = subscope_sig.scope_sig_input_struct in let called_scope_return_struct = subscope_sig.scope_sig_output_struct in let subscope_vars_defined = - try SubScopeMap.find subindex ctx.subscope_vars - with Not_found -> ScopeVarMap.empty + try SubScopeName.Map.find subindex ctx.subscope_vars + with Not_found -> ScopeVar.Map.empty in let subscope_var_not_yet_defined subvar = - not (ScopeVarMap.mem subvar subscope_vars_defined) + not (ScopeVar.Map.mem subvar subscope_vars_defined) in let pos_call = Marked.get_mark (SubScopeName.get_info subindex) in let subscope_args = @@ -515,17 +516,17 @@ let translate_rule Expr.empty_thunked_term m else let a_var, _, _ = - ScopeVarMap.find subvar.scope_var_name subscope_vars_defined + ScopeVar.Map.find subvar.scope_var_name subscope_vars_defined in Expr.make_var a_var (mark_tany m pos_call) in let field = Marked.unmark - (ScopeVarMap.find subvar.scope_var_name + (ScopeVar.Map.find subvar.scope_var_name subscope_sig.scope_sig_in_fields) in - StructFieldMap.add field e acc) - StructFieldMap.empty all_subscope_input_vars + StructField.Map.add field e acc) + StructField.Map.empty all_subscope_input_vars in let subscope_struct_arg = Expr.estruct called_scope_input_struct subscope_args @@ -583,7 +584,7 @@ let translate_rule List.fold_right (fun (var_ctx, v) next -> let field = - ScopeVarMap.find var_ctx.scope_var_name + ScopeVar.Map.find var_ctx.scope_var_name subscope_sig.scope_sig_out_fields in Bindlib.box_apply2 @@ -608,13 +609,13 @@ let translate_rule { ctx with subscope_vars = - SubScopeMap.add subindex + SubScopeName.Map.add subindex (List.fold_left (fun acc (var_ctx, dvar) -> - ScopeVarMap.add var_ctx.scope_var_name + ScopeVar.Map.add var_ctx.scope_var_name (dvar, var_ctx.scope_var_typ, var_ctx.scope_var_io) acc) - ScopeVarMap.empty all_subscope_output_vars_dcalc) + ScopeVar.Map.empty all_subscope_output_vars_dcalc) ctx.subscope_vars; } ) | Assertion e -> @@ -660,15 +661,15 @@ let translate_rules in let return_exp = Expr.estruct scope_sig.scope_sig_output_struct - (ScopeVarMap.fold + (ScopeVar.Map.fold (fun var (dcalc_var, _, io) acc -> if Marked.unmark io.Desugared.Ast.io_output then - let field = ScopeVarMap.find var scope_sig.scope_sig_out_fields in - StructFieldMap.add field + let field = ScopeVar.Map.find var scope_sig.scope_sig_out_fields in + StructField.Map.add field (Expr.make_var dcalc_var (mark_tany mark pos_sigma)) acc else acc) - new_ctx.scope_vars StructFieldMap.empty) + new_ctx.scope_vars StructField.Map.empty) (mark_tany mark pos_sigma) in ( scope_lets @@ -685,7 +686,7 @@ let translate_scope_decl (sigma : 'm Scopelang.Ast.scope_decl) : 'm Ast.expr scope_body Bindlib.box * struct_ctx = let sigma_info = ScopeName.get_info sigma.scope_decl_name in - let scope_sig = ScopeMap.find sigma.scope_decl_name sctx in + let scope_sig = ScopeName.Map.find sigma.scope_decl_name sctx in let scope_variables = scope_sig.scope_sig_local_vars in let ctx = (* the context must be initialized for fresh variables for all only-input @@ -699,7 +700,7 @@ let translate_scope_decl { ctx with scope_vars = - ScopeVarMap.add scope_var.scope_var_name + ScopeVar.Map.add scope_var.scope_var_name ( scope_var_dcalc, scope_var.scope_var_typ, scope_var.scope_var_io ) @@ -721,7 +722,7 @@ let translate_scope_decl List.map (fun var_ctx -> let dcalc_x, _, _ = - ScopeVarMap.find var_ctx.scope_var_name ctx.scope_vars + ScopeVar.Map.find var_ctx.scope_var_name ctx.scope_vars in var_ctx, dcalc_x) scope_variables @@ -748,7 +749,7 @@ let translate_scope_decl (fun (var_ctx, v) next -> let field = Marked.unmark - (ScopeVarMap.find var_ctx.scope_var_name + (ScopeVar.Map.find var_ctx.scope_var_name scope_sig.scope_sig_in_fields) in Bindlib.box_apply2 @@ -774,11 +775,13 @@ let translate_scope_decl List.fold_left (fun acc (var_ctx, _) -> let var = var_ctx.scope_var_name in - let field, _ = ScopeVarMap.find var scope_sig.scope_sig_in_fields in - StructFieldMap.add field (input_var_typ var_ctx) acc) - StructFieldMap.empty scope_input_variables + let field, _ = ScopeVar.Map.find var scope_sig.scope_sig_in_fields in + StructField.Map.add field (input_var_typ var_ctx) acc) + StructField.Map.empty scope_input_variables + in + let new_struct_ctx = + StructName.Map.singleton scope_input_struct_name field_map in - let new_struct_ctx = StructMap.singleton scope_input_struct_name field_map in ( Bindlib.box_apply (fun scope_body_expr -> { @@ -798,14 +801,14 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = in let decl_ctx = prgm.program_ctx in let sctx : 'm scope_sigs_ctx = - ScopeMap.mapi + ScopeName.Map.mapi (fun scope_name scope -> let scope_dvar = Var.make (Marked.unmark (ScopeName.get_info scope.Scopelang.Ast.scope_decl_name)) in - let scope_return = ScopeMap.find scope_name decl_ctx.ctx_scopes in + let scope_return = ScopeName.Map.find scope_name decl_ctx.ctx_scopes in let scope_input_var = Var.make (Marked.unmark (ScopeName.get_info scope_name) ^ "_in") in @@ -816,7 +819,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = (ScopeName.get_info scope_name)) in let scope_sig_in_fields = - ScopeVarMap.filter_map + ScopeVar.Map.filter_map (fun dvar (_, vis) -> match Marked.unmark vis.Desugared.Ast.io_input with | NoInput -> None @@ -824,7 +827,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = let info = ScopeVar.get_info dvar in let s = Marked.unmark info ^ "_in" in Some - ( StructFieldName.fresh (s, Marked.get_mark info), + ( StructField.fresh (s, Marked.get_mark info), vis.Desugared.Ast.io_input )) scope.scope_sig in @@ -837,7 +840,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_var_typ = Marked.unmark tau; scope_var_io = vis; }) - (ScopeVarMap.bindings scope.scope_sig); + (ScopeVar.Map.bindings scope.scope_sig); scope_sig_scope_var = scope_dvar; scope_sig_input_var = scope_input_var; scope_sig_input_struct = scope_input_struct_name; @@ -852,17 +855,17 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = order, then the chained scopes aggregated from the right. *) let rec translate_scopes decl_ctx = function | scope_name :: next_scopes -> - let scope = ScopeMap.find scope_name prgm.program_scopes in + let scope = ScopeName.Map.find scope_name prgm.program_scopes in let scope_body, scope_in_struct = translate_scope_decl decl_ctx.ctx_structs decl_ctx.ctx_enums sctx scope_name scope in - let dvar = (ScopeMap.find scope_name sctx).scope_sig_scope_var in + let dvar = (ScopeName.Map.find scope_name sctx).scope_sig_scope_var in let decl_ctx = { decl_ctx with ctx_structs = - StructMap.union + StructName.Map.union (fun _ _ -> assert false (* should not happen *)) decl_ctx.ctx_structs scope_in_struct; } diff --git a/compiler/dcalc/interpreter.ml b/compiler/dcalc/interpreter.ml index 46aacf12..d3e2f712 100644 --- a/compiler/dcalc/interpreter.ml +++ b/compiler/dcalc/interpreter.ml @@ -180,7 +180,7 @@ let rec evaluate_operator ELit (LBool (StructName.equal s1 s2 - && StructFieldMap.equal + && StructField.Map.equal (fun e1 e2 -> match evaluate_operator ctx op pos [e1; e2] with | ELit (LBool b) -> b @@ -342,8 +342,8 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr = happen if the term was well-typed") | EAbs _ | ELit _ | EOp _ -> e (* these are values *) | EStruct { fields = es; name } -> - let new_es = StructFieldMap.map (evaluate_expr ctx) es in - if StructFieldMap.exists (fun _ e -> is_empty_error e) new_es then + let new_es = StructField.Map.map (evaluate_expr ctx) es in + if StructField.Map.exists (fun _ e -> is_empty_error e) new_es then Marked.same_mark_as (ELit LEmptyError) e else Marked.same_mark_as (EStruct { fields = new_es; name }) e | EStructAccess { e = e1; name = s; field } -> ( @@ -355,13 +355,13 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr = [None, Expr.pos e; None, Expr.pos e1] "Error during struct access: not the same structs (should not happen \ if the term was well-typed)"; - match StructFieldMap.find_opt field es with + match StructField.Map.find_opt field es with | Some e' -> e' | None -> Errors.raise_spanned_error (Expr.pos e1) "Invalid field access %a in struct %a (should not happen if the term \ was well-typed)" - StructFieldName.format_t field StructName.format_t s) + StructField.format_t field StructName.format_t s) | ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e | _ -> Errors.raise_spanned_error (Expr.pos e1) @@ -383,7 +383,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr = "Error during match: two different enums found (should not happen if \ the term was well-typed)"; let es_n = - match EnumConstructorMap.find_opt cons es with + match EnumConstructor.Map.find_opt cons es with | Some es_n -> es_n | None -> Errors.raise_spanned_error (Expr.pos e) @@ -499,9 +499,9 @@ let interpret_program : the types of the scope arguments. For [context] arguments, we can provide an empty thunked term. But for [input] arguments of another type, we cannot provide anything so we have to fail. *) - let taus = StructMap.find s_in ctx.ctx_structs in + let taus = StructName.Map.find s_in ctx.ctx_structs in let application_term = - StructFieldMap.map + StructField.Map.map (fun ty -> match Marked.unmark ty with | TArrow ((TLit TUnit, _), ty_in) -> @@ -523,8 +523,8 @@ let interpret_program : match Marked.unmark (evaluate_expr ctx (Expr.unbox to_interpret)) with | EStruct { fields; _ } -> List.map - (fun (fld, e) -> StructFieldName.get_info fld, e) - (StructFieldMap.bindings fields) + (fun (fld, e) -> StructField.get_info fld, e) + (StructField.Map.bindings fields) | _ -> Errors.raise_spanned_error (Expr.pos e) "The interpretation of a program should always yield a struct \ diff --git a/compiler/dcalc/optimizations.ml b/compiler/dcalc/optimizations.ml index 817b52c5..bfaadffb 100644 --- a/compiler/dcalc/optimizations.ml +++ b/compiler/dcalc/optimizations.ml @@ -80,7 +80,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) : | EMatch { e = EInj { e; name = name1; cons }, _; cases; name } when EnumName.equal name name1 -> (* iota reduction *) - EApp { f = EnumConstructorMap.find cons cases; args = [e] } + EApp { f = EnumConstructor.Map.find cons cases; args = [e] } | EApp { f = EAbs { binder; _ }, _; args } -> (* beta reduction *) Marked.unmark (Bindlib.msubst binder (List.map fst args |> Array.of_list)) diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 1a69fa9f..b40f9d46 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -178,7 +178,7 @@ type meta_assertion = | VariesWith of unit * variation_typ Marked.pos option type scope_def = { - scope_def_rules : rule RuleMap.t; + scope_def_rules : rule RuleName.Map.t; scope_def_typ : typ; scope_def_is_condition : bool; scope_def_io : io; @@ -187,15 +187,18 @@ type scope_def = { type var_or_states = WholeVar | States of StateName.t list type scope = { - scope_vars : var_or_states ScopeVarMap.t; - scope_sub_scopes : ScopeName.t SubScopeMap.t; + scope_vars : var_or_states ScopeVar.Map.t; + scope_sub_scopes : ScopeName.t SubScopeName.Map.t; scope_uid : ScopeName.t; scope_defs : scope_def ScopeDefMap.t; scope_assertions : assertion list; scope_meta_assertions : meta_assertion list; } -type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx } +type program = { + program_scopes : scope ScopeName.Map.t; + program_ctx : decl_ctx; +} let rec locations_used e : LocationSet.t = match e with @@ -208,7 +211,7 @@ let rec locations_used e : LocationSet.t = (fun e -> LocationSet.union (locations_used e)) e LocationSet.empty -let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = +let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDefMap.t = let add_locs (acc : Pos.t ScopeDefMap.t) (locs : LocationSet.t) : Pos.t ScopeDefMap.t = LocationSet.fold @@ -224,7 +227,7 @@ let free_variables (def : rule RuleMap.t) : Pos.t ScopeDefMap.t = loc_pos acc) locs acc in - RuleMap.fold + RuleName.Map.fold (fun _ rule acc -> let locs = LocationSet.union diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index f14aef54..ab4d7d2f 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -100,7 +100,7 @@ type io = { (** Characterization of the input/output status of a scope variable. *) type scope_def = { - scope_def_rules : rule RuleMap.t; + scope_def_rules : rule RuleName.Map.t; scope_def_typ : typ; scope_def_is_condition : bool; scope_def_io : io; @@ -109,17 +109,20 @@ type scope_def = { type var_or_states = WholeVar | States of StateName.t list type scope = { - scope_vars : var_or_states ScopeVarMap.t; - scope_sub_scopes : ScopeName.t SubScopeMap.t; + scope_vars : var_or_states ScopeVar.Map.t; + scope_sub_scopes : ScopeName.t SubScopeName.Map.t; scope_uid : ScopeName.t; scope_defs : scope_def ScopeDefMap.t; scope_assertions : assertion list; scope_meta_assertions : meta_assertion list; } -type program = { program_scopes : scope ScopeMap.t; program_ctx : decl_ctx } +type program = { + program_scopes : scope ScopeName.Map.t; + program_ctx : decl_ctx; +} (** {1 Helpers} *) val locations_used : expr -> LocationSet.t -val free_variables : rule RuleMap.t -> Pos.t ScopeDefMap.t +val free_variables : rule RuleName.Map.t -> Pos.t ScopeDefMap.t diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 097bb349..4da66bc7 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -143,7 +143,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = let g = ScopeDependencies.empty in (* Add all the vertices to the graph *) let g = - ScopeVarMap.fold + ScopeVar.Map.fold (fun (v : ScopeVar.t) var_or_state g -> match var_or_state with | Ast.WholeVar -> ScopeDependencies.add_vertex g (Vertex.Var (v, None)) @@ -155,7 +155,7 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = scope.scope_vars g in let g = - SubScopeMap.fold + SubScopeName.Map.fold (fun (v : SubScopeName.t) _ g -> ScopeDependencies.add_vertex g (Vertex.SubScope v)) scope.scope_sub_scopes g @@ -229,10 +229,10 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = (** {2 Graph declaration} *) module ExceptionVertex = struct - include RuleSet + include RuleName.Set let hash (x : t) : int = - RuleSet.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0 + RuleName.Set.fold (fun r acc -> Int.logxor (RuleName.hash r) acc) x 0 let equal x y = compare x y = 0 end @@ -263,7 +263,7 @@ type exception_edge = { } let build_exceptions_graph - (def : Ast.rule RuleMap.t) + (def : Ast.rule RuleName.Map.t) (def_info : Ast.ScopeDef.t) : ExceptionsDependencies.t = (* First we partition the definitions into groups bearing the same label. To handle the rules that were not labeled by the user, we create implicit @@ -275,53 +275,55 @@ let build_exceptions_graph (* When declaring [exception definition x ...], it means there is a unique rule [R] to which this can be an exception to. So we give a unique label to all the rules that are implicitly exceptions to rule [R]. *) - let exception_to_rule_implicit_labels : LabelName.t RuleMap.t = - RuleMap.fold + let exception_to_rule_implicit_labels : LabelName.t RuleName.Map.t = + RuleName.Map.fold (fun _ rule_from exception_to_rule_implicit_labels -> match rule_from.Ast.rule_exception with | Ast.ExceptionToRule (rule_to, _) -> ( - match RuleMap.find_opt rule_to exception_to_rule_implicit_labels with + match + RuleName.Map.find_opt rule_to exception_to_rule_implicit_labels + with | Some _ -> (* we already created the label *) exception_to_rule_implicit_labels | None -> - RuleMap.add rule_to + RuleName.Map.add rule_to (LabelName.fresh ( "exception_to_" ^ Marked.unmark (RuleName.get_info rule_to), Pos.no_pos )) exception_to_rule_implicit_labels) | _ -> exception_to_rule_implicit_labels) - def RuleMap.empty + def RuleName.Map.empty in (* When declaring [exception foo_l definition x ...], the rule is exception to all the rules sharing label [foo_l]. So we give a unique label to all the rules that are implicitly exceptions to rule [foo_l]. *) - let exception_to_label_implicit_labels : LabelName.t LabelMap.t = - RuleMap.fold + let exception_to_label_implicit_labels : LabelName.t LabelName.Map.t = + RuleName.Map.fold (fun _ rule_from - (exception_to_label_implicit_labels : LabelName.t LabelMap.t) -> + (exception_to_label_implicit_labels : LabelName.t LabelName.Map.t) -> match rule_from.Ast.rule_exception with | Ast.ExceptionToLabel (label_to, _) -> ( match - LabelMap.find_opt label_to exception_to_label_implicit_labels + LabelName.Map.find_opt label_to exception_to_label_implicit_labels with | Some _ -> (* we already created the label *) exception_to_label_implicit_labels | None -> - LabelMap.add label_to + LabelName.Map.add label_to (LabelName.fresh ( "exception_to_" ^ Marked.unmark (LabelName.get_info label_to), Pos.no_pos )) exception_to_label_implicit_labels) | _ -> exception_to_label_implicit_labels) - def LabelMap.empty + def LabelName.Map.empty in (* Now we have all the labels necessary to partition our rules into sets, each one corresponding to a label relating to the structure of the exception DAG. *) let label_to_rule_sets = - RuleMap.fold + RuleName.Map.fold (fun rule_name rule rule_sets -> let label_of_rule = match rule.Ast.rule_label with @@ -330,23 +332,23 @@ let build_exceptions_graph match rule.Ast.rule_exception with | BaseCase -> base_case_implicit_label | ExceptionToRule (r, _) -> - RuleMap.find r exception_to_rule_implicit_labels + RuleName.Map.find r exception_to_rule_implicit_labels | ExceptionToLabel (l', _) -> - LabelMap.find l' exception_to_label_implicit_labels) + LabelName.Map.find l' exception_to_label_implicit_labels) in - LabelMap.update label_of_rule + LabelName.Map.update label_of_rule (fun rule_set -> match rule_set with - | None -> Some (RuleSet.singleton rule_name) - | Some rule_set -> Some (RuleSet.add rule_name rule_set)) + | None -> Some (RuleName.Set.singleton rule_name) + | Some rule_set -> Some (RuleName.Set.add rule_name rule_set)) rule_sets) - def LabelMap.empty + def LabelName.Map.empty in let find_label_of_rule (r : RuleName.t) : LabelName.t = fst - (LabelMap.choose - (LabelMap.filter - (fun _ rule_set -> RuleSet.mem r rule_set) + (LabelName.Map.choose + (LabelName.Map.filter + (fun _ rule_set -> RuleName.Set.mem r rule_set) label_to_rule_sets)) in (* Next, we collect the exception edges between those groups of rules referred @@ -354,7 +356,7 @@ let build_exceptions_graph edges as they are declared at each rule but should be the same for all the rules of the same group. *) let exception_edges : exception_edge list = - RuleMap.fold + RuleName.Map.fold (fun rule_name rule exception_edges -> let label_from = find_label_of_rule rule_name in let label_to_and_pos = @@ -414,7 +416,7 @@ let build_exceptions_graph in (* We've got the vertices and the edges, let's build the graph! *) let g = - LabelMap.fold + LabelName.Map.fold (fun _label rule_set g -> ExceptionsDependencies.add_vertex g rule_set) label_to_rule_sets ExceptionsDependencies.empty in @@ -423,9 +425,11 @@ let build_exceptions_graph List.fold_left (fun g edge -> let rule_group_from = - LabelMap.find edge.label_from label_to_rule_sets + LabelName.Map.find edge.label_from label_to_rule_sets + in + let rule_group_to = + LabelName.Map.find edge.label_to label_to_rule_sets in - let rule_group_to = LabelMap.find edge.label_to label_to_rule_sets in let edge = ExceptionsDependencies.E.create rule_group_from edge.edge_positions rule_group_to @@ -445,8 +449,8 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit = let spans = List.flatten (List.map - (fun (vs : RuleSet.t) -> - let v = RuleSet.choose vs in + (fun (vs : RuleName.Set.t) -> + let v = RuleName.Set.choose vs in let var_str, var_info = Format.asprintf "%a" RuleName.format_t v, RuleName.get_info v in diff --git a/compiler/desugared/dependency.mli b/compiler/desugared/dependency.mli index 5487b786..6858aca0 100644 --- a/compiler/desugared/dependency.mli +++ b/compiler/desugared/dependency.mli @@ -72,9 +72,9 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t module EdgeExceptions : Graph.Sig.ORDERED_TYPE_DFT with type t = Pos.t list module ExceptionsDependencies : - Graph.Sig.P with type V.t = RuleSet.t and type E.label = EdgeExceptions.t + Graph.Sig.P with type V.t = RuleName.Set.t and type E.label = EdgeExceptions.t val build_exceptions_graph : - Ast.rule RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t + Ast.rule RuleName.Map.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t val check_for_exception_cycle : ExceptionsDependencies.t -> unit diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 0a98c54d..ff49420b 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -79,7 +79,7 @@ let disambiguate_constructor in match enum with | None -> - if EnumMap.cardinal possible_c_uids > 1 then + if EnumName.Map.cardinal possible_c_uids > 1 then Errors.raise_spanned_error (Marked.get_mark constructor) "This constructor name is ambiguous, it can belong to %a. Disambiguate \ @@ -88,14 +88,14 @@ let disambiguate_constructor ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") (fun fmt (s_name, _) -> Format.fprintf fmt "%a" EnumName.format_t s_name)) - (EnumMap.bindings possible_c_uids); - EnumMap.choose possible_c_uids + (EnumName.Map.bindings possible_c_uids); + EnumName.Map.choose possible_c_uids | Some enum -> ( try (* The path is fully qualified *) let e_uid = Name_resolution.get_enum ctxt enum in try - let c_uid = EnumMap.find e_uid possible_c_uids in + let c_uid = EnumName.Map.find e_uid possible_c_uids in e_uid, c_uid with Not_found -> Errors.raise_spanned_error pos "Enum %s does not contain case %s" @@ -114,7 +114,7 @@ let rec translate_expr (inside_definition_of : Ast.ScopeDef.t Marked.pos option) (ctxt : Name_resolution.context) (expr : Surface.Ast.expression Marked.pos) : Ast.expr boxed = - let scope_ctxt = ScopeMap.find scope ctxt.scopes in + let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in let rec_helper = translate_expr scope inside_definition_of ctxt in let pos = Marked.get_mark expr in let emark = Untyped { pos } in @@ -130,7 +130,7 @@ let rec translate_expr disambiguate_constructor ctxt constructors pos_pattern in let cases = - EnumConstructorMap.mapi + EnumConstructor.Map.mapi (fun c_uid' tau -> if EnumConstructor.compare c_uid c_uid' <> 0 then let nop_var = Var.make "_" in @@ -143,7 +143,7 @@ let rec translate_expr in let e2 = translate_expr scope inside_definition_of ctxt e2 in Expr.make_abs [| binding_var |] e2 [tau] pos) - (EnumMap.find enum_uid ctxt.enums) + (EnumName.Map.find enum_uid ctxt.enums) in Expr.ematch (translate_expr scope inside_definition_of ctxt e1_sub) @@ -212,7 +212,7 @@ let rec translate_expr desambiguate. In general, only the last state can be referenced. Except if defining a state of the same variable, then it references the previous state in the chain. *) - let x_sig = ScopeVarMap.find uid ctxt.var_typs in + let x_sig = ScopeVar.Map.find uid ctxt.var_typs in let x_state = match x_sig.var_sig_states_list with | [] -> None @@ -281,7 +281,7 @@ let rec translate_expr match c with | None -> (* No constructor name was specified *) - if StructMap.cardinal x_possible_structs > 1 then + if StructName.Map.cardinal x_possible_structs > 1 then Errors.raise_spanned_error (Marked.get_mark x) "This struct field name is ambiguous, it can belong to %a. \ Disambiguate it by prefixing it with the struct name." @@ -289,15 +289,15 @@ let rec translate_expr ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") (fun fmt (s_name, _) -> Format.fprintf fmt "%a" StructName.format_t s_name)) - (StructMap.bindings x_possible_structs) + (StructName.Map.bindings x_possible_structs) else - let s_uid, f_uid = StructMap.choose x_possible_structs in + let s_uid, f_uid = StructName.Map.choose x_possible_structs in Expr.estructaccess e f_uid s_uid emark | Some c_name -> ( try let c_uid = Name_resolution.get_struct ctxt c_name in try - let f_uid = StructMap.find c_uid x_possible_structs in + let f_uid = StructName.Map.find c_uid x_possible_structs in Expr.estructaccess e f_uid c_uid emark with Not_found -> Errors.raise_spanned_error pos "Struct %s does not contain field %s" @@ -308,7 +308,7 @@ let rec translate_expr | FunCall (f, arg) -> Expr.eapp (rec_helper f) [rec_helper arg] emark | ScopeCall (sc_name, fields) -> let called_scope = Name_resolution.get_scope ctxt sc_name in - let scope_def = ScopeMap.find called_scope ctxt.scopes in + let scope_def = ScopeName.Map.find called_scope ctxt.scopes in let in_struct = List.fold_left (fun acc (fld_id, e) -> @@ -330,7 +330,7 @@ let rec translate_expr "Scope %a has no input variable %a" ScopeName.format_t called_scope Print.lit_style (Marked.unmark fld_id) in - ScopeVarMap.update var + ScopeVar.Map.update var (function | None -> Some (rec_helper e) | Some _ -> @@ -338,7 +338,7 @@ let rec translate_expr "Duplicate definition of scope input variable '%a'" ScopeVar.format_t var) acc) - ScopeVarMap.empty fields + ScopeVar.Map.empty fields in Expr.escopecall called_scope in_struct emark | LetIn (x, e1, e2) -> @@ -366,7 +366,7 @@ let rec translate_expr (fun s_fields (f_name, f_e) -> let f_uid = try - StructMap.find s_uid + StructName.Map.find s_uid (Name_resolution.IdentMap.find (Marked.unmark f_name) ctxt.field_idmap) with Not_found -> @@ -374,24 +374,23 @@ let rec translate_expr "This identifier should refer to a field of struct %s" (Marked.unmark s_name) in - (match StructFieldMap.find_opt f_uid s_fields with + (match StructField.Map.find_opt f_uid s_fields with | None -> () | Some e_field -> Errors.raise_multispanned_error [None, Marked.get_mark f_e; None, Expr.pos e_field] - "The field %a has been defined twice:" StructFieldName.format_t - f_uid); + "The field %a has been defined twice:" StructField.format_t f_uid); let f_e = translate_expr scope inside_definition_of ctxt f_e in - StructFieldMap.add f_uid f_e s_fields) - StructFieldMap.empty fields + StructField.Map.add f_uid f_e s_fields) + StructField.Map.empty fields in - let expected_s_fields = StructMap.find s_uid ctxt.structs in - StructFieldMap.iter + let expected_s_fields = StructName.Map.find s_uid ctxt.structs in + StructField.Map.iter (fun expected_f _ -> - if not (StructFieldMap.mem expected_f s_fields) then + if not (StructField.Map.mem expected_f s_fields) then Errors.raise_spanned_error pos "Missing field for structure %a: \"%a\"" StructName.format_t s_uid - StructFieldName.format_t expected_f) + StructField.format_t expected_f) expected_s_fields; Expr.estruct s_uid s_fields emark @@ -409,7 +408,7 @@ let rec translate_expr | None -> if (* No constructor name was specified *) - EnumMap.cardinal possible_c_uids > 1 + EnumName.Map.cardinal possible_c_uids > 1 then Errors.raise_spanned_error pos_constructor "This constructor name is ambiguous, it can belong to %a. \ @@ -418,9 +417,9 @@ let rec translate_expr ~pp_sep:(fun fmt () -> Format.fprintf fmt " or ") (fun fmt (s_name, _) -> Format.fprintf fmt "%a" EnumName.format_t s_name)) - (EnumMap.bindings possible_c_uids) + (EnumName.Map.bindings possible_c_uids) else - let e_uid, c_uid = EnumMap.choose possible_c_uids in + let e_uid, c_uid = EnumName.Map.choose possible_c_uids in let payload = Option.map (translate_expr scope inside_definition_of ctxt) payload in @@ -434,7 +433,7 @@ let rec translate_expr (* The path has been fully qualified *) let e_uid = Name_resolution.get_enum ctxt enum in try - let c_uid = EnumMap.find e_uid possible_c_uids in + let c_uid = EnumName.Map.find e_uid possible_c_uids in let payload = Option.map (translate_expr scope inside_definition_of ctxt) payload in @@ -468,13 +467,13 @@ let rec translate_expr (Marked.get_mark pattern) in let cases = - EnumConstructorMap.mapi + EnumConstructor.Map.mapi (fun c_uid' tau -> let nop_var = Var.make "_" in Expr.make_abs [| nop_var |] (Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark) [tau] pos) - (EnumMap.find enum_uid ctxt.enums) + (EnumName.Map.find enum_uid ctxt.enums) in Expr.ematch (translate_expr scope inside_definition_of ctxt e1) @@ -743,7 +742,7 @@ and disambiguate_match_and_build_expression (inside_definition_of : Ast.ScopeDef.t Marked.pos option) (ctxt : Name_resolution.context) (cases : Surface.Ast.match_case Marked.pos list) : - Ast.expr boxed EnumConstructorMap.t * EnumName.t = + Ast.expr boxed EnumConstructor.Map.t * EnumName.t = let create_var = function | None -> ctxt, Var.make "_" | Some param -> @@ -758,8 +757,8 @@ and disambiguate_match_and_build_expression e_binder = Expr.eabs e_binder [ - EnumConstructorMap.find c_uid - (EnumMap.find e_uid ctxt.Name_resolution.enums); + EnumConstructor.Map.find c_uid + (EnumName.Map.find e_uid ctxt.Name_resolution.enums); ] (Marked.get_mark case_body) in @@ -785,7 +784,7 @@ and disambiguate_match_and_build_expression case were matching constructors of enumeration %a" EnumName.format_t e_uid EnumName.format_t e_uid' in - (match EnumConstructorMap.find_opt c_uid cases_d with + (match EnumConstructor.Map.find_opt c_uid cases_d with | None -> () | Some e_case -> Errors.raise_multispanned_error @@ -799,7 +798,9 @@ and disambiguate_match_and_build_expression in let e_binder = Expr.bind [| param_var |] case_body in let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in - EnumConstructorMap.add c_uid case_expr cases_d, Some e_uid, curr_index + 1 + ( EnumConstructor.Map.add c_uid case_expr cases_d, + Some e_uid, + curr_index + 1 ) | Surface.Ast.WildCard match_case_expr -> ( let nb_cases = List.length cases in let raise_wildcard_not_last_case_err () = @@ -821,13 +822,13 @@ and disambiguate_match_and_build_expression | Some e_uid -> if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); let missing_constructors = - EnumMap.find e_uid ctxt.Name_resolution.enums - |> EnumConstructorMap.filter_map (fun c_uid _ -> - match EnumConstructorMap.find_opt c_uid cases_d with + EnumName.Map.find e_uid ctxt.Name_resolution.enums + |> EnumConstructor.Map.filter_map (fun c_uid _ -> + match EnumConstructor.Map.find_opt c_uid cases_d with | Some _ -> None | None -> Some c_uid) in - if EnumConstructorMap.is_empty missing_constructors then + if EnumConstructor.Map.is_empty missing_constructors then Errors.format_spanned_warning case_pos "Unreachable match case, all constructors of the enumeration %a \ are already specified" @@ -851,19 +852,19 @@ and disambiguate_match_and_build_expression let e_binder = Expr.bind [| payload_var |] case_body in (* For each missing cases, binds the wildcard payload. *) - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun c_uid _ (cases_d, e_uid_opt, curr_index) -> let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in - ( EnumConstructorMap.add c_uid case_expr cases_d, + ( EnumConstructor.Map.add c_uid case_expr cases_d, e_uid_opt, curr_index + 1 )) missing_constructors (cases_d, Some e_uid, curr_index)) in let naked_expr, e_name, _ = - List.fold_left bind_match_cases (EnumConstructorMap.empty, None, 0) cases + List.fold_left bind_match_cases (EnumConstructor.Map.empty, None, 0) cases in naked_expr, Option.get e_name [@@ocamlformat "wrap-comments=false"] @@ -934,8 +935,8 @@ let process_def (ctxt : Name_resolution.context) (prgm : Ast.program) (def : Surface.Ast.definition) : Ast.program = - let scope : Ast.scope = ScopeMap.find scope_uid prgm.program_scopes in - let scope_ctxt = ScopeMap.find scope_uid ctxt.scopes in + let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in + let scope_ctxt = ScopeName.Map.find scope_uid ctxt.scopes in let def_key = Name_resolution.get_def_key (Marked.unmark def.definition_name) @@ -994,7 +995,7 @@ let process_def { scope_def with scope_def_rules = - RuleMap.add rule_name + RuleName.Map.add rule_name (process_default new_ctxt scope_uid (def_key, Marked.get_mark def.definition_name) rule_name param_uid precond exception_situation label_situation @@ -1009,7 +1010,8 @@ let process_def in { prgm with - program_scopes = ScopeMap.add scope_uid scope_updated prgm.program_scopes; + program_scopes = + ScopeName.Map.add scope_uid scope_updated prgm.program_scopes; } (** Translates a {!type: Surface.Ast.rule} from the surface language *) @@ -1029,7 +1031,7 @@ let process_assert (ctxt : Name_resolution.context) (prgm : Ast.program) (ass : Surface.Ast.assertion) : Ast.program = - let scope : Ast.scope = ScopeMap.find scope_uid prgm.program_scopes in + let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let ass = translate_expr scope_uid None ctxt (match ass.Surface.Ast.assertion_condition with @@ -1055,7 +1057,7 @@ let process_assert in { prgm with - program_scopes = ScopeMap.add scope_uid new_scope prgm.program_scopes; + program_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_scopes; } (** Translates a surface definition, rule or assertion *) @@ -1080,7 +1082,7 @@ let check_unlabeled_exception (scope : ScopeName.t) (ctxt : Name_resolution.context) (item : Surface.Ast.scope_use_item Marked.pos) : unit = - let scope_ctxt = ScopeMap.find scope ctxt.scopes in + let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in match Marked.unmark item with | Surface.Ast.Rule _ | Surface.Ast.Definition _ -> ( let def_key, exception_to = @@ -1129,7 +1131,7 @@ let process_scope_use let scope_uid = Name_resolution.get_scope ctxt use.scope_use_name in (* Make sure the scope exists *) let prgm = - match ScopeMap.find_opt scope_uid prgm.program_scopes with + match ScopeName.Map.find_opt scope_uid prgm.program_scopes with | Some _ -> prgm | None -> assert false (* should not happen *) @@ -1163,13 +1165,13 @@ let init_scope_defs let add_def _ v scope_def_map = match v with | Name_resolution.ScopeVar v -> ( - let v_sig = ScopeVarMap.find v ctxt.Name_resolution.var_typs in + let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in match v_sig.var_sig_states_list with | [] -> let def_key = Ast.ScopeDef.Var (v, None) in Ast.ScopeDefMap.add def_key { - Ast.scope_def_rules = RuleMap.empty; + Ast.scope_def_rules = RuleName.Map.empty; Ast.scope_def_typ = v_sig.var_sig_typ; Ast.scope_def_is_condition = v_sig.var_sig_is_condition; Ast.scope_def_io = attribute_to_io v_sig.var_sig_io; @@ -1182,7 +1184,7 @@ let init_scope_defs let def_key = Ast.ScopeDef.Var (v, Some state) in let def = { - Ast.scope_def_rules = RuleMap.empty; + Ast.scope_def_rules = RuleName.Map.empty; Ast.scope_def_typ = v_sig.var_sig_typ; Ast.scope_def_is_condition = v_sig.var_sig_is_condition; Ast.scope_def_io = @@ -1209,7 +1211,7 @@ let init_scope_defs scope_def) | Name_resolution.SubScope (v0, subscope_uid) -> let sub_scope_def = - ScopeMap.find subscope_uid ctxt.Name_resolution.scopes + ScopeName.Map.find subscope_uid ctxt.Name_resolution.scopes in Name_resolution.IdentMap.fold (fun _ v scope_def_map -> @@ -1218,14 +1220,14 @@ let init_scope_defs | Name_resolution.ScopeVar v -> (* TODO: shouldn't we ignore internal variables too at this point ? *) - let v_sig = ScopeVarMap.find v ctxt.Name_resolution.var_typs in + let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in let def_key = Ast.ScopeDef.SubScopeVar (v0, v, Marked.get_mark (ScopeVar.get_info v)) in Ast.ScopeDefMap.add def_key { - Ast.scope_def_rules = RuleMap.empty; + Ast.scope_def_rules = RuleName.Map.empty; Ast.scope_def_typ = v_sig.var_sig_typ; Ast.scope_def_is_condition = v_sig.var_sig_is_condition; Ast.scope_def_io = attribute_to_io v_sig.var_sig_io; @@ -1241,7 +1243,7 @@ let translate_program (prgm : Surface.Ast.program) : Ast.program = let empty_prgm = let program_scopes = - ScopeMap.mapi + ScopeName.Map.mapi (fun s_uid s_context -> let scope_vars = Name_resolution.IdentMap.fold @@ -1249,11 +1251,11 @@ let translate_program match v with | Name_resolution.SubScope _ -> acc | Name_resolution.ScopeVar v -> ( - let v_sig = ScopeVarMap.find v ctxt.var_typs in + let v_sig = ScopeVar.Map.find v ctxt.var_typs in match v_sig.var_sig_states_list with - | [] -> ScopeVarMap.add v Ast.WholeVar acc - | states -> ScopeVarMap.add v (Ast.States states) acc)) - s_context.Name_resolution.var_idmap ScopeVarMap.empty + | [] -> ScopeVar.Map.add v Ast.WholeVar acc + | states -> ScopeVar.Map.add v (Ast.States states) acc)) + s_context.Name_resolution.var_idmap ScopeVar.Map.empty in let scope_sub_scopes = Name_resolution.IdentMap.fold @@ -1261,8 +1263,8 @@ let translate_program match v with | Name_resolution.ScopeVar _ -> acc | Name_resolution.SubScope (sub_var, sub_scope) -> - SubScopeMap.add sub_var sub_scope acc) - s_context.Name_resolution.var_idmap SubScopeMap.empty + SubScopeName.Map.add sub_var sub_scope acc) + s_context.Name_resolution.var_idmap SubScopeName.Map.empty in { Ast.scope_vars; @@ -1284,9 +1286,9 @@ let translate_program (fun _ def acc -> match def with | Name_resolution.TScope (scope, scope_out_struct) -> - ScopeMap.add scope scope_out_struct acc + ScopeName.Map.add scope scope_out_struct acc | _ -> acc) - ctxt.Name_resolution.typedefs ScopeMap.empty; + ctxt.Name_resolution.typedefs ScopeName.Map.empty; }; Ast.program_scopes; } diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index e7e77c37..64be33bc 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -45,15 +45,15 @@ type scope_context = { (** All variables, including scope variables and subscopes *) scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t; (** What is the default rule to refer to for unnamed exceptions, if any *) - sub_scopes : ScopeSet.t; + sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) } (** Inside a scope, we distinguish between the variables and the subscopes. *) -type struct_context = typ StructFieldMap.t +type struct_context = typ StructField.Map.t (** Types of the fields of a struct *) -type enum_context = typ EnumConstructorMap.t +type enum_context = typ EnumConstructor.Map.t (** Types of the payloads of the cases of an enum *) type var_sig = { @@ -78,16 +78,17 @@ type context = { arguments or pattern matching *) typedefs : typedef IdentMap.t; (** Gathers the names of the scopes, structs and enums *) - field_idmap : StructFieldName.t StructMap.t IdentMap.t; + field_idmap : StructField.t StructName.Map.t IdentMap.t; (** The names of the struct fields. Names of fields can be shared between different structs *) - constructor_idmap : EnumConstructor.t EnumMap.t IdentMap.t; + constructor_idmap : EnumConstructor.t EnumName.Map.t IdentMap.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) - scopes : scope_context ScopeMap.t; (** For each scope, its context *) - structs : struct_context StructMap.t; (** For each struct, its context *) - enums : enum_context EnumMap.t; (** For each enum, its context *) - var_typs : var_sig ScopeVarMap.t; + scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) + structs : struct_context StructName.Map.t; + (** For each struct, its context *) + enums : enum_context EnumName.Map.t; (** For each enum, its context *) + var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) } (** Main context used throughout {!module: Surface.Desugaring} *) @@ -109,21 +110,21 @@ let raise_unknown_identifier (msg : string) (ident : ident Marked.pos) = (** Gets the type associated to an uid *) let get_var_typ (ctxt : context) (uid : ScopeVar.t) : typ = - (ScopeVarMap.find uid ctxt.var_typs).var_sig_typ + (ScopeVar.Map.find uid ctxt.var_typs).var_sig_typ let is_var_cond (ctxt : context) (uid : ScopeVar.t) : bool = - (ScopeVarMap.find uid ctxt.var_typs).var_sig_is_condition + (ScopeVar.Map.find uid ctxt.var_typs).var_sig_is_condition let get_var_io (ctxt : context) (uid : ScopeVar.t) : Surface.Ast.scope_decl_context_io = - (ScopeVarMap.find uid ctxt.var_typs).var_sig_io + (ScopeVar.Map.find uid ctxt.var_typs).var_sig_io (** Get the variable uid inside the scope given in argument *) let get_var_uid (scope_uid : ScopeName.t) (ctxt : context) ((x, pos) : ident Marked.pos) : ScopeVar.t = - let scope = ScopeMap.find scope_uid ctxt.scopes in + let scope = ScopeName.Map.find scope_uid ctxt.scopes in match IdentMap.find_opt x scope.var_idmap with | Some (ScopeVar uid) -> uid | _ -> @@ -136,7 +137,7 @@ let get_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) ((y, pos) : ident Marked.pos) : SubScopeName.t = - let scope = ScopeMap.find scope_uid ctxt.scopes in + let scope = ScopeName.Map.find scope_uid ctxt.scopes in match IdentMap.find_opt y scope.var_idmap with | Some (SubScope (sub_uid, _sub_id)) -> sub_uid | _ -> raise_unknown_identifier "for a subscope of this scope" (y, pos) @@ -145,7 +146,7 @@ let get_subscope_uid subscopes of [scope_uid]. *) let is_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) (y : ident) : bool = - let scope = ScopeMap.find scope_uid ctxt.scopes in + let scope = ScopeName.Map.find scope_uid ctxt.scopes in match IdentMap.find_opt y scope.var_idmap with | Some (SubScope _) -> true | _ -> false @@ -153,7 +154,7 @@ let is_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) (y : ident) : (** Checks if the var_uid belongs to the scope scope_uid *) let belongs_to (ctxt : context) (uid : ScopeVar.t) (scope_uid : ScopeName.t) : bool = - let scope = ScopeMap.find scope_uid ctxt.scopes in + let scope = ScopeName.Map.find scope_uid ctxt.scopes in IdentMap.exists (fun _ -> function | ScopeVar var_uid -> ScopeVar.equal uid var_uid @@ -242,7 +243,7 @@ let process_subscope_decl (decl : Surface.Ast.scope_decl_context_scope) : context = let name, name_pos = decl.scope_decl_context_scope_name in let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in - let scope_ctxt = ScopeMap.find scope ctxt.scopes in + let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in match IdentMap.find_opt subscope scope_ctxt.var_idmap with | Some use -> let info = @@ -267,10 +268,11 @@ let process_subscope_decl IdentMap.add name (SubScope (sub_scope_uid, original_subscope_uid)) scope_ctxt.var_idmap; - sub_scopes = ScopeSet.add original_subscope_uid scope_ctxt.sub_scopes; + sub_scopes = + ScopeName.Set.add original_subscope_uid scope_ctxt.sub_scopes; } in - { ctxt with scopes = ScopeMap.add scope scope_ctxt ctxt.scopes } + { ctxt with scopes = ScopeName.Map.add scope scope_ctxt ctxt.scopes } let is_type_cond ((typ, _) : Surface.Ast.typ) = match typ with @@ -329,7 +331,7 @@ let process_data_decl let data_typ = process_type ctxt decl.scope_decl_context_item_typ in let is_cond = is_type_cond decl.scope_decl_context_item_typ in let name, pos = decl.scope_decl_context_item_name in - let scope_ctxt = ScopeMap.find scope ctxt.scopes in + let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in match IdentMap.find_opt name scope_ctxt.var_idmap with | Some use -> let info = @@ -360,9 +362,9 @@ let process_data_decl in { ctxt with - scopes = ScopeMap.add scope scope_ctxt ctxt.scopes; + scopes = ScopeName.Map.add scope scope_ctxt ctxt.scopes; var_typs = - ScopeVarMap.add uid + ScopeVar.Map.add uid { var_sig_typ = data_typ; var_sig_is_condition = is_cond; @@ -397,9 +399,7 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : (Marked.unmark sdecl.struct_decl_name); List.fold_left (fun ctxt (fdecl, _) -> - let f_uid = - StructFieldName.fresh fdecl.Surface.Ast.struct_decl_field_name - in + let f_uid = StructField.fresh fdecl.Surface.Ast.struct_decl_field_name in let ctxt = { ctxt with @@ -408,24 +408,24 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : (Marked.unmark fdecl.Surface.Ast.struct_decl_field_name) (fun uids -> match uids with - | None -> Some (StructMap.singleton s_uid f_uid) - | Some uids -> Some (StructMap.add s_uid f_uid uids)) + | None -> Some (StructName.Map.singleton s_uid f_uid) + | Some uids -> Some (StructName.Map.add s_uid f_uid uids)) ctxt.field_idmap; } in { ctxt with structs = - StructMap.update s_uid + StructName.Map.update s_uid (fun fields -> match fields with | None -> Some - (StructFieldMap.singleton f_uid + (StructField.Map.singleton f_uid (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) | Some fields -> Some - (StructFieldMap.add f_uid + (StructField.Map.add f_uid (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) fields)) ctxt.structs; @@ -453,15 +453,15 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context (Marked.unmark cdecl.Surface.Ast.enum_decl_case_name) (fun uids -> match uids with - | None -> Some (EnumMap.singleton e_uid c_uid) - | Some uids -> Some (EnumMap.add e_uid c_uid uids)) + | None -> Some (EnumName.Map.singleton e_uid c_uid) + | Some uids -> Some (EnumName.Map.add e_uid c_uid uids)) ctxt.constructor_idmap; } in { ctxt with enums = - EnumMap.update e_uid + EnumName.Map.update e_uid (fun cases -> let typ = match cdecl.Surface.Ast.enum_decl_case_typ with @@ -469,8 +469,8 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context | Some typ -> process_type ctxt typ in match cases with - | None -> Some (EnumConstructorMap.singleton c_uid typ) - | Some fields -> Some (EnumConstructorMap.add c_uid typ fields)) + | None -> Some (EnumConstructor.Map.singleton c_uid typ) + | Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) ctxt.enums; }) ctxt edecl.enum_decl_cases @@ -522,9 +522,9 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : { ctxt with structs = - StructMap.add + StructName.Map.add (get_struct ctxt decl.scope_decl_name) - StructFieldMap.empty ctxt.structs; + StructField.Map.empty ctxt.structs; } else let ctxt = @@ -535,7 +535,7 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : } in let out_struct_fields = - let sco = ScopeMap.find scope_uid ctxt.scopes in + let sco = ScopeName.Map.find scope_uid ctxt.scopes in let str = get_struct ctxt decl.scope_decl_name in IdentMap.fold (fun id var svmap -> @@ -544,11 +544,11 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : | ScopeVar v -> ( try let field = - StructMap.find str (IdentMap.find id ctxt.field_idmap) + StructName.Map.find str (IdentMap.find id ctxt.field_idmap) in - ScopeVarMap.add v field svmap + ScopeVar.Map.add v field svmap with Not_found -> svmap)) - sco.var_idmap ScopeVarMap.empty + sco.var_idmap ScopeVar.Map.empty in let typedefs = IdentMap.update @@ -597,15 +597,15 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Marked.pos) ( scope_uid, { out_struct_name = out_struct_uid; - out_struct_fields = ScopeVarMap.empty; + out_struct_fields = ScopeVar.Map.empty; } )) ctxt.typedefs; scopes = - ScopeMap.add scope_uid + ScopeName.Map.add scope_uid { var_idmap = IdentMap.empty; scope_defs_contexts = Ast.ScopeDefMap.empty; - sub_scopes = ScopeSet.empty; + sub_scopes = ScopeName.Set.empty; } ctxt.scopes; } @@ -679,11 +679,11 @@ let get_def_key (scope_uid : ScopeName.t) (ctxt : context) (pos : Pos.t) : Ast.ScopeDef.t = - let scope_ctxt = ScopeMap.find scope_uid ctxt.scopes in + let scope_ctxt = ScopeName.Map.find scope_uid ctxt.scopes in match name with | [x] -> let x_uid = get_var_uid scope_uid ctxt x in - let var_sig = ScopeVarMap.find x_uid ctxt.var_typs in + let var_sig = ScopeVar.Map.find x_uid ctxt.var_typs in Ast.ScopeDef.Var ( x_uid, match state with @@ -807,7 +807,7 @@ let process_definition { ctxt with scopes = - ScopeMap.update s_name + ScopeName.Map.update s_name (fun (s_ctxt : scope_context option) -> let def_key = get_def_key @@ -875,11 +875,11 @@ let form_context (prgm : Surface.Ast.program) : context = { local_var_idmap = IdentMap.empty; typedefs = IdentMap.empty; - scopes = ScopeMap.empty; - var_typs = ScopeVarMap.empty; - structs = StructMap.empty; + scopes = ScopeName.Map.empty; + var_typs = ScopeVar.Map.empty; + structs = StructName.Map.empty; field_idmap = IdentMap.empty; - enums = EnumMap.empty; + enums = EnumName.Map.empty; constructor_idmap = IdentMap.empty; } in diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index d9f2e8c5..f7c419a5 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -45,15 +45,15 @@ type scope_context = { (** All variables, including scope variables and subscopes *) scope_defs_contexts : scope_def_context Ast.ScopeDefMap.t; (** What is the default rule to refer to for unnamed exceptions, if any *) - sub_scopes : ScopeSet.t; + sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) } (** Inside a scope, we distinguish between the variables and the subscopes. *) -type struct_context = typ StructFieldMap.t +type struct_context = typ StructField.Map.t (** Types of the fields of a struct *) -type enum_context = typ EnumConstructorMap.t +type enum_context = typ EnumConstructor.Map.t (** Types of the payloads of the cases of an enum *) type var_sig = { @@ -78,16 +78,17 @@ type context = { arguments or pattern matching *) typedefs : typedef IdentMap.t; (** Gathers the names of the scopes, structs and enums *) - field_idmap : StructFieldName.t StructMap.t IdentMap.t; + field_idmap : StructField.t StructName.Map.t IdentMap.t; (** The names of the struct fields. Names of fields can be shared between different structs *) - constructor_idmap : EnumConstructor.t EnumMap.t IdentMap.t; + constructor_idmap : EnumConstructor.t EnumName.Map.t IdentMap.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) - scopes : scope_context ScopeMap.t; (** For each scope, its context *) - structs : struct_context StructMap.t; (** For each struct, its context *) - enums : enum_context EnumMap.t; (** For each enum, its context *) - var_typs : var_sig ScopeVarMap.t; + scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) + structs : struct_context StructName.Map.t; + (** For each struct, its context *) + enums : enum_context EnumName.Map.t; (** For each enum, its context *) + var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) } (** Main context used throughout {!module: Surface.Desugaring} *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 96eaa7e0..cfa96a6f 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -182,7 +182,8 @@ let driver source_file (options : Cli.options) : int = if Option.is_some options.ex_scope then Format.fprintf fmt "%a\n" (Scopelang.Print.scope prgm.program_ctx ~debug:options.debug) - (scope_uid, Shared_ast.ScopeMap.find scope_uid prgm.program_scopes) + ( scope_uid, + Shared_ast.ScopeName.Map.find scope_uid prgm.program_scopes ) else Format.fprintf fmt "%a\n" (Scopelang.Print.program ~debug:options.debug) diff --git a/compiler/lcalc/ast.ml b/compiler/lcalc/ast.ml index 6de0a21b..f4ae7597 100644 --- a/compiler/lcalc/ast.ml +++ b/compiler/lcalc/ast.ml @@ -28,10 +28,10 @@ let option_enum : EnumName.t = EnumName.fresh ("eoption", Pos.no_pos) let none_constr : EnumConstructor.t = EnumConstructor.fresh ("ENone", Pos.no_pos) let some_constr : EnumConstructor.t = EnumConstructor.fresh ("ESome", Pos.no_pos) -let option_enum_config : typ EnumConstructorMap.t = - EnumConstructorMap.empty - |> EnumConstructorMap.add none_constr (TLit TUnit, Pos.no_pos) - |> EnumConstructorMap.add some_constr (TAny, Pos.no_pos) +let option_enum_config : typ EnumConstructor.Map.t = + EnumConstructor.Map.empty + |> EnumConstructor.Map.add none_constr (TLit TUnit, Pos.no_pos) + |> EnumConstructor.Map.add some_constr (TAny, Pos.no_pos) (* FIXME: proper typing in all the constructors below *) @@ -49,9 +49,9 @@ let make_some e = let make_matchopt_with_abs_arms arg e_none e_some = let m = Marked.get_mark arg in let cases = - EnumConstructorMap.empty - |> EnumConstructorMap.add none_constr e_none - |> EnumConstructorMap.add some_constr e_some + EnumConstructor.Map.empty + |> EnumConstructor.Map.add none_constr e_none + |> EnumConstructor.Map.add some_constr e_some in Expr.ematch arg option_enum cases m diff --git a/compiler/lcalc/ast.mli b/compiler/lcalc/ast.mli index 58cced2e..3fb0583d 100644 --- a/compiler/lcalc/ast.mli +++ b/compiler/lcalc/ast.mli @@ -32,7 +32,7 @@ type 'm program = 'm expr Shared_ast.program val option_enum : EnumName.t val none_constr : EnumConstructor.t val some_constr : EnumConstructor.t -val option_enum_config : typ EnumConstructorMap.t +val option_enum_config : typ EnumConstructor.Map.t val make_none : 'm mark -> 'm expr boxed val make_some : 'm expr boxed -> 'm expr boxed diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 9ca8b0b7..554bb2aa 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -44,7 +44,7 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed = (* We do not close the clotures inside the arms of the match expression, since they get a special treatment at compilation to Scalc. *) let free_vars, new_cases = - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun cons e1 (free_vars, new_cases) -> match Marked.unmark e1 with | EAbs { binder; tys } -> @@ -52,12 +52,12 @@ let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed = let new_free_vars, new_body = aux body in let new_binder = Expr.bind vars new_body in ( Var.Set.union free_vars new_free_vars, - EnumConstructorMap.add cons + EnumConstructor.Map.add cons (Expr.eabs new_binder tys (Marked.get_mark e1)) new_cases ) | _ -> failwith "should not happen") cases - (free_vars, EnumConstructorMap.empty) + (free_vars, EnumConstructor.Map.empty) in free_vars, Expr.ematch new_e name new_cases m | EApp { f = EAbs { binder; tys }, e1_pos; args } -> diff --git a/compiler/lcalc/compile_with_exceptions.ml b/compiler/lcalc/compile_with_exceptions.ml index ccaff511..aefa6668 100644 --- a/compiler/lcalc/compile_with_exceptions.ml +++ b/compiler/lcalc/compile_with_exceptions.ml @@ -58,13 +58,13 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed = match Marked.unmark e with | EVar v -> Expr.make_var (Var.Map.find v ctx) m | EStruct { name; fields } -> - Expr.estruct name (StructFieldMap.map (translate_expr ctx) fields) m + Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m | EStructAccess { name; e; field } -> Expr.estructaccess (translate_expr ctx e) field name m | EInj { name; e; cons } -> Expr.einj (translate_expr ctx e) cons name m | EMatch { name; e; cases } -> Expr.ematch (translate_expr ctx e) name - (EnumConstructorMap.map (translate_expr ctx) cases) + (EnumConstructor.Map.map (translate_expr ctx) cases) m | EArray es -> Expr.earray (List.map (translate_expr ctx) es) m | ELit diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index b497fec3..daf67af0 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -255,11 +255,12 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) : e', hoists | EStruct { name; fields } -> let fields', h_fields = - StructFieldMap.fold + StructField.Map.fold (fun field e (fields, hoists) -> let e, h = translate_and_hoist ctx e in - StructFieldMap.add field e fields, h :: hoists) - fields (StructFieldMap.empty, []) + StructField.Map.add field e fields, h :: hoists) + fields + (StructField.Map.empty, []) in let hoists = disjoint_union_maps (Expr.pos e) h_fields in Expr.estruct name fields' mark, hoists @@ -274,12 +275,12 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) : | EMatch { name; e = e1; cases } -> let e1', h1 = translate_and_hoist ctx e1 in let cases', h_cases = - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun cons e (cases, hoists) -> let e', h = translate_and_hoist ctx e in - EnumConstructorMap.add cons e' cases, h :: hoists) + EnumConstructor.Map.add cons e' cases, h :: hoists) cases - (EnumConstructorMap.empty, []) + (EnumConstructor.Map.empty, []) in let hoists = disjoint_union_maps (Expr.pos e) (h1 :: h_cases) in let e' = Expr.ematch e1' name cases' mark in @@ -537,7 +538,7 @@ let translate_program (prgm : 'm D.program) : 'm A.program = prgm.decl_ctx with ctx_enums = prgm.decl_ctx.ctx_enums - |> EnumMap.add A.option_enum A.option_enum_config; + |> EnumName.Map.add A.option_enum A.option_enum_config; } in let decl_ctx = @@ -545,9 +546,9 @@ let translate_program (prgm : 'm D.program) : 'm A.program = decl_ctx with ctx_structs = prgm.decl_ctx.ctx_structs - |> StructMap.mapi (fun n str -> + |> StructName.Map.mapi (fun n str -> if List.mem n inputs_structs then - StructFieldMap.map translate_typ str + StructField.Map.map translate_typ str (* Cli.debug_print @@ Format.asprintf "Input type: %a" (Print.typ decl_ctx) tau; Cli.debug_print @@ Format.asprintf "Output type: %a" (Print.typ decl_ctx) (translate_typ diff --git a/compiler/lcalc/optimizations.ml b/compiler/lcalc/optimizations.ml index ddc0d200..556e89cd 100644 --- a/compiler/lcalc/optimizations.ml +++ b/compiler/lcalc/optimizations.ml @@ -27,16 +27,16 @@ let rec iota_expr (e : 'm expr) : 'm expr boxed = | EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n } when EnumName.equal n n' -> let e1 = visitor_map iota_expr e' in - let case = visitor_map iota_expr (EnumConstructorMap.find cons cases) in + let case = visitor_map iota_expr (EnumConstructor.Map.find cons cases) in Expr.eapp case [e1] m | EMatch { e = e'; cases; name = n } when cases - |> EnumConstructorMap.mapi (fun i case -> + |> EnumConstructor.Map.mapi (fun i case -> match Marked.unmark case with | EInj { cons = i'; name = n'; _ } -> EnumConstructor.equal i i' && EnumName.equal n n' | _ -> false) - |> EnumConstructorMap.for_all (fun _ b -> b) -> + |> EnumConstructor.Map.for_all (fun _ b -> b) -> visitor_map iota_expr e' | _ -> visitor_map iota_expr e diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 7a3c8520..3388c271 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -20,16 +20,16 @@ open Ast open String_common module D = Dcalc.Ast -let find_struct (s : StructName.t) (ctx : decl_ctx) : typ StructFieldMap.t = - try StructMap.find s ctx.ctx_structs +let find_struct (s : StructName.t) (ctx : decl_ctx) : typ StructField.Map.t = + try StructName.Map.find s ctx.ctx_structs with Not_found -> let s_name, pos = StructName.get_info s in Errors.raise_spanned_error pos "Internal Error: Structure %s was not found in the current environment." s_name -let find_enum (en : EnumName.t) (ctx : decl_ctx) : typ EnumConstructorMap.t = - try EnumMap.find en ctx.ctx_enums +let find_enum (en : EnumName.t) (ctx : decl_ctx) : typ EnumConstructor.Map.t = + try EnumName.Map.find en ctx.ctx_enums with Not_found -> let en_name, pos = EnumName.get_info en in Errors.raise_spanned_error pos @@ -162,13 +162,12 @@ let format_to_module_name let format_struct_field_name (fmt : Format.formatter) - ((sname_opt, v) : StructName.t option * StructFieldName.t) : unit = + ((sname_opt, v) : StructName.t option * StructField.t) : unit = (match sname_opt with | Some sname -> Format.fprintf fmt "%a.%s" format_to_module_name (`Sname sname) | None -> Format.fprintf fmt "%s") - (avoid_keywords - (to_ascii (Format.asprintf "%a" StructFieldName.format_t v))) + (avoid_keywords (to_ascii (Format.asprintf "%a" StructField.format_t v))) let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = Format.fprintf fmt "%s" @@ -284,7 +283,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) es | EStruct { name = s; fields = es } -> - if StructFieldMap.is_empty es then Format.fprintf fmt "()" + if StructField.Map.is_empty es then Format.fprintf fmt "()" else Format.fprintf fmt "{@[%a@]}" (Format.pp_print_list @@ -292,7 +291,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : (fun fmt (struct_field, e) -> Format.fprintf fmt "@[%a =@ %a@]" format_struct_field_name (Some s, struct_field) format_with_parens e)) - (StructFieldMap.bindings es) + (StructField.Map.bindings es) | EArray es -> Format.fprintf fmt "@[[|%a|]@]" (Format.pp_print_list @@ -331,7 +330,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : | _ -> assert false (* should not happen *)) e)) - (EnumConstructorMap.bindings cases) + (EnumConstructor.Map.bindings cases) | ELit l -> Format.fprintf fmt "%a" format_lit (Marked.mark (Expr.pos e) l) | EApp { f = EAbs { binder; tys }, _; args } -> let xs, body = Bindlib.unmbind binder in @@ -444,8 +443,8 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : let format_struct_embedding (fmt : Format.formatter) - ((struct_name, struct_fields) : StructName.t * typ StructFieldMap.t) = - if StructFieldMap.is_empty struct_fields then + ((struct_name, struct_fields) : StructName.t * typ StructField.Map.t) = + if StructField.Map.is_empty struct_fields then Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" format_struct_name struct_name format_to_module_name (`Sname struct_name) else @@ -458,16 +457,16 @@ let format_struct_embedding (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") (fun _fmt (struct_field, struct_field_type) -> - Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructFieldName.format_t + Format.fprintf fmt "(\"%a\",@ %a@ x.%a)" StructField.format_t struct_field typ_embedding_name struct_field_type format_struct_field_name (Some struct_name, struct_field))) - (StructFieldMap.bindings struct_fields) + (StructField.Map.bindings struct_fields) let format_enum_embedding (fmt : Format.formatter) - ((enum_name, enum_cases) : EnumName.t * typ EnumConstructorMap.t) = - if EnumConstructorMap.is_empty enum_cases then + ((enum_name, enum_cases) : EnumName.t * typ EnumConstructor.Map.t) = + if EnumConstructor.Map.is_empty enum_cases then Format.fprintf fmt "let embed_%a (_: %a.t) : runtime_value = Unit@\n@\n" format_to_module_name (`Ename enum_name) format_enum_name enum_name else @@ -483,14 +482,14 @@ let format_enum_embedding Format.fprintf fmt "@[| %a x ->@ (\"%a\", %a x)@]" format_enum_cons_name enum_cons EnumConstructor.format_t enum_cons typ_embedding_name enum_cons_type)) - (EnumConstructorMap.bindings enum_cases) + (EnumConstructor.Map.bindings enum_cases) let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) (ctx : decl_ctx) : unit = let format_struct_decl fmt (struct_name, struct_fields) = - if StructFieldMap.is_empty struct_fields then + if StructField.Map.is_empty struct_fields then Format.fprintf fmt "@[module %a = struct@\n@[type t = unit@]@]@\nend@\n" format_to_module_name (`Sname struct_name) @@ -505,7 +504,7 @@ let format_ctx (fun _fmt (struct_field, struct_field_type) -> Format.fprintf fmt "@[%a:@ %a@]" format_struct_field_name (None, struct_field) format_typ struct_field_type)) - (StructFieldMap.bindings struct_fields); + (StructField.Map.bindings struct_fields); if !Cli.trace_flag then format_struct_embedding fmt (struct_name, struct_fields) in @@ -518,7 +517,7 @@ let format_ctx (fun _fmt (enum_cons, enum_cons_type) -> Format.fprintf fmt "@[| %a@ of@ %a@]" format_enum_cons_name enum_cons format_typ enum_cons_type)) - (EnumConstructorMap.bindings enum_cons); + (EnumConstructor.Map.bindings enum_cons); if !Cli.trace_flag then format_enum_embedding fmt (enum_name, enum_cons) in let is_in_type_ordering s = @@ -532,8 +531,8 @@ let format_ctx let scope_structs = List.map (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) - (StructMap.bindings - (StructMap.filter + (StructName.Map.bindings + (StructName.Map.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) in diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index c060b299..4a5bc73f 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -19,8 +19,8 @@ open Shared_ast (** Formats a lambda calculus program into a valid OCaml program *) val avoid_keywords : string -> string -val find_struct : StructName.t -> decl_ctx -> typ StructFieldMap.t -val find_enum : EnumName.t -> decl_ctx -> typ EnumConstructorMap.t +val find_struct : StructName.t -> decl_ctx -> typ StructField.Map.t +val find_enum : EnumName.t -> decl_ctx -> typ EnumConstructor.Map.t val typ_needs_parens : typ -> bool (* val needs_parens : 'm expr -> bool *) @@ -29,7 +29,7 @@ val format_enum_cons_name : Format.formatter -> EnumConstructor.t -> unit val format_struct_name : Format.formatter -> StructName.t -> unit val format_struct_field_name : - Format.formatter -> StructName.t option * StructFieldName.t -> unit + Format.formatter -> StructName.t option * StructField.t -> unit val format_to_module_name : Format.formatter -> [< `Ename of EnumName.t | `Sname of StructName.t ] -> unit diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index d2ec5a57..33cbd2e4 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -40,9 +40,9 @@ module To_jsoo = struct let format_struct_field_name_camel_case (fmt : Format.formatter) - (v : StructFieldName.t) : unit = + (v : StructField.t) : unit = let s = - Format.asprintf "%a" StructFieldName.format_t v + Format.asprintf "%a" StructField.format_t v |> to_ascii |> to_snake_case |> avoid_keywords @@ -166,7 +166,7 @@ module To_jsoo = struct format_struct_field_name_camel_case struct_field format_typ_to_jsoo struct_field_type fmt_struct_name () format_struct_field_name (None, struct_field))) - (StructFieldMap.bindings struct_fields) + (StructField.Map.bindings struct_fields) in let fmt_of_jsoo fmt _ = Format.fprintf fmt "%a" @@ -186,7 +186,7 @@ module To_jsoo = struct format_struct_field_name (None, struct_field) format_typ_of_jsoo struct_field_type fmt_struct_name () format_struct_field_name_camel_case struct_field)) - (StructFieldMap.bindings struct_fields) + (StructField.Map.bindings struct_fields) in let fmt_conv_funs fmt _ = Format.fprintf fmt @@ -203,7 +203,7 @@ module To_jsoo = struct () fmt_struct_name () fmt_module_struct_name () fmt_of_jsoo () in - if StructFieldMap.is_empty struct_fields then + if StructField.Map.is_empty struct_fields then Format.fprintf fmt "class type %a =@ object end@\n\ let %a_to_jsoo (_ : %a.t) : %a Js.t = object%%js end@\n\ @@ -220,10 +220,10 @@ module To_jsoo = struct Format.fprintf fmt "@[method %a:@ %a %a@]" format_struct_field_name_camel_case struct_field format_typ struct_field_type format_prop_or_meth struct_field_type)) - (StructFieldMap.bindings struct_fields) + (StructField.Map.bindings struct_fields) fmt_conv_funs () in - let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructorMap.t)) + let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructor.Map.t)) = let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_module_enum_name fmt _ = @@ -247,7 +247,7 @@ module To_jsoo = struct end@]" format_enum_cons_name cname format_enum_cons_name cname format_typ_to_jsoo typ)) - (EnumConstructorMap.bindings enum_cons) + (EnumConstructor.Map.bindings enum_cons) in let fmt_of_jsoo fmt _ = Format.fprintf fmt @@ -273,7 +273,7 @@ module To_jsoo = struct format_enum_cons_name cname fmt_module_enum_name () format_enum_cons_name cname format_typ_of_jsoo typ fmt_enum_name ())) - (EnumConstructorMap.bindings enum_cons) + (EnumConstructor.Map.bindings enum_cons) fmt_module_enum_name () in @@ -302,7 +302,7 @@ module To_jsoo = struct ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (enum_cons, _) -> Format.fprintf fmt "- \"%a\"" format_enum_cons_name enum_cons)) - (EnumConstructorMap.bindings enum_cons) + (EnumConstructor.Map.bindings enum_cons) fmt_conv_funs () in let is_in_type_ordering s = @@ -316,8 +316,8 @@ module To_jsoo = struct let scope_structs = List.map (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) - (StructMap.bindings - (StructMap.filter + (StructName.Map.bindings + (StructName.Map.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) in diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index bec655e1..89a97319 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -38,9 +38,9 @@ module To_json = struct let format_struct_field_name_camel_case (fmt : Format.formatter) - (v : StructFieldName.t) : unit = + (v : StructField.t) : unit = let s = - Format.asprintf "%a" StructFieldName.format_t v + Format.asprintf "%a" StructField.format_t v |> to_ascii |> to_snake_case |> avoid_keywords @@ -97,7 +97,7 @@ module To_json = struct (fun fmt (field_name, field_type) -> Format.fprintf fmt "@[\"%a\": {@\n%a@]@\n}" format_struct_field_name_camel_case field_name fmt_type field_type)) - (StructFieldMap.bindings (find_struct sname ctx)) + (StructField.Map.bindings (find_struct sname ctx)) let fmt_definitions (ctx : decl_ctx) @@ -119,12 +119,13 @@ module To_json = struct | TEnum e -> List.fold_left collect (t :: acc) (List.map snd - (EnumConstructorMap.bindings (EnumMap.find e ctx.ctx_enums))) + (EnumConstructor.Map.bindings + (EnumName.Map.find e ctx.ctx_enums))) | TArray t -> collect acc t | _ -> acc in find_struct input_struct ctx - |> StructFieldMap.bindings + |> StructField.Map.bindings |> List.fold_left (fun acc (_, field_typ) -> collect acc field_typ) [] |> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t')) in @@ -148,7 +149,7 @@ module To_json = struct Format.fprintf fmt "@[{@\n\"type\": \"string\",@\n\"enum\": [\"%a\"]@]@\n}" format_enum_cons_name enum_cons)) - (EnumConstructorMap.bindings enum_def) + (EnumConstructor.Map.bindings enum_def) (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") (fun fmt (enum_cons, payload_type) -> @@ -170,7 +171,7 @@ module To_json = struct }@]@\n\ }" format_enum_cons_name enum_cons fmt_type payload_type)) - (EnumConstructorMap.bindings enum_def) + (EnumConstructor.Map.bindings enum_def) in Format.fprintf fmt "@\n%a" diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index cb8b79bd..44e00999 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -31,7 +31,7 @@ and naked_expr = | EVar of LocalName.t | EFunc of TopLevelName.t | EStruct of expr list * StructName.t - | EStructFieldAccess of expr * StructFieldName.t * StructName.t + | EStructFieldAccess of expr * StructField.t * StructName.t | EInj of expr * EnumConstructor.t * EnumName.t | EArray of expr list | ELit of L.lit diff --git a/compiler/scalc/compile_from_lambda.ml b/compiler/scalc/compile_from_lambda.ml index 25730f53..dc60a137 100644 --- a/compiler/scalc/compile_from_lambda.ml +++ b/compiler/scalc/compile_from_lambda.ml @@ -48,7 +48,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr = [], (local_var, Expr.pos expr) | EStruct { fields; name } -> let args_stmts, new_args = - StructFieldMap.fold + StructField.Map.fold (fun _ arg (args_stmts, new_args) -> let arg_stmts, new_arg = translate_expr ctxt arg in arg_stmts @ args_stmts, new_arg :: new_args) @@ -207,7 +207,7 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = | EMatch { e = e1; cases; name } -> let e1_stmts, new_e1 = translate_expr ctxt e1 in let new_cases = - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun _ arg new_args -> match Marked.unmark arg with | EAbs { binder; _ } -> diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 6a652572..17588ae4 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -46,10 +46,10 @@ let rec format_expr ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, (struct_field, _)) -> Format.fprintf fmt "%a%a%a%a %a" Print.punctuation "\"" - StructFieldName.format_t struct_field Print.punctuation "\"" + StructField.format_t struct_field Print.punctuation "\"" Print.punctuation ":" format_expr e)) (List.combine es - (StructFieldMap.bindings (StructMap.find s decl_ctx.ctx_structs))) + (StructField.Map.bindings (StructName.Map.find s decl_ctx.ctx_structs))) Print.punctuation "}" | EArray es -> Format.fprintf fmt "@[%a%a%a@]" Print.punctuation "[" @@ -59,8 +59,7 @@ let rec format_expr es Print.punctuation "]" | EStructFieldAccess (e1, field, _) -> Format.fprintf fmt "%a%a%a%a%a" format_expr e1 Print.punctuation "." - Print.punctuation "\"" StructFieldName.format_t field Print.punctuation - "\"" + Print.punctuation "\"" StructField.format_t field Print.punctuation "\"" | EInj (e, cons, _) -> Format.fprintf fmt "@[%a@ %a@]" Print.enum_constructor cons format_expr e @@ -153,7 +152,8 @@ let rec format_statement (format_block decl_ctx ~debug) arm_block)) (List.combine - (EnumConstructorMap.bindings (EnumMap.find enum decl_ctx.ctx_enums)) + (EnumConstructor.Map.bindings + (EnumName.Map.find enum decl_ctx.ctx_enums)) arms) and format_block diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index ebb0e876..a4f9a69c 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -127,11 +127,10 @@ let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = (avoid_keywords (to_camel_case (to_ascii (Format.asprintf "%a" StructName.format_t v)))) -let format_struct_field_name (fmt : Format.formatter) (v : StructFieldName.t) : - unit = +let format_struct_field_name (fmt : Format.formatter) (v : StructField.t) : unit + = Format.fprintf fmt "%s" - (avoid_keywords - (to_ascii (Format.asprintf "%a" StructFieldName.format_t v))) + (avoid_keywords (to_ascii (Format.asprintf "%a" StructField.format_t v))) let format_enum_name (fmt : Format.formatter) (v : EnumName.t) : unit = Format.fprintf fmt "%s" @@ -272,7 +271,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : Format.fprintf fmt "%a = %a" format_struct_field_name struct_field (format_expression ctx) e)) (List.combine es - (StructFieldMap.bindings (StructMap.find s ctx.ctx_structs))) + (StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs))) | EStructFieldAccess (e1, field, _) -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field @@ -401,7 +400,7 @@ let rec format_statement List.map2 (fun (x, y) (cons, _) -> x, y, cons) cases - (EnumConstructorMap.bindings (EnumMap.find e_name ctx.ctx_enums)) + (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums)) in let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var @@ -443,7 +442,7 @@ let format_ctx (fmt : Format.formatter) (ctx : decl_ctx) : unit = let format_struct_decl fmt (struct_name, struct_fields) = - let fields = StructFieldMap.bindings struct_fields in + let fields = StructField.Map.bindings struct_fields in Format.fprintf fmt "class %a:@\n\ \ def __init__(self, %a) -> None:@\n\ @@ -467,7 +466,7 @@ let format_ctx Format.fprintf fmt "%a: %a" format_struct_field_name struct_field format_typ struct_field_type)) fields - (if StructFieldMap.is_empty struct_fields then fun fmt _ -> + (if StructField.Map.is_empty struct_fields then fun fmt _ -> Format.fprintf fmt " pass" else Format.pp_print_list @@ -476,7 +475,7 @@ let format_ctx Format.fprintf fmt " self.%a = %a" format_struct_field_name struct_field format_struct_field_name struct_field)) fields format_struct_name struct_name - (if not (StructFieldMap.is_empty struct_fields) then + (if not (StructField.Map.is_empty struct_fields) then Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " and@ ") (fun fmt (struct_field, _) -> @@ -496,7 +495,7 @@ let format_ctx fields in let format_enum_decl fmt (enum_name, enum_cons) = - if EnumConstructorMap.is_empty enum_cons then + if EnumConstructor.Map.is_empty enum_cons then failwith "no constructors in the enum" else Format.fprintf fmt @@ -529,7 +528,7 @@ let format_ctx Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i)) (List.mapi (fun i (x, y) -> i, x, y) - (EnumConstructorMap.bindings enum_cons)) + (EnumConstructor.Map.bindings enum_cons)) format_enum_name enum_name format_enum_name enum_name format_enum_name enum_name in @@ -545,8 +544,8 @@ let format_ctx let scope_structs = List.map (fun (s, _) -> Scopelang.Dependency.TVertex.Struct s) - (StructMap.bindings - (StructMap.filter + (StructName.Map.bindings + (StructName.Map.filter (fun s _ -> not (is_in_type_ordering s)) ctx.ctx_structs)) in @@ -555,10 +554,10 @@ let format_ctx match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> Format.fprintf fmt "%a@\n@\n" format_struct_decl - (s, StructMap.find s ctx.ctx_structs) + (s, StructName.Map.find s ctx.ctx_structs) | Scopelang.Dependency.TVertex.Enum e -> Format.fprintf fmt "%a@\n@\n" format_enum_decl - (e, EnumMap.find e ctx.ctx_enums)) + (e, EnumName.Map.find e ctx.ctx_enums)) (type_ordering @ scope_structs) let format_program diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index d41b2308..277457ef 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -46,13 +46,13 @@ type 'm rule = type 'm scope_decl = { scope_decl_name : ScopeName.t; - scope_sig : (typ * Desugared.Ast.io) ScopeVarMap.t; + scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t; scope_decl_rules : 'm rule list; scope_mark : 'm mark; } type 'm program = { - program_scopes : 'm scope_decl ScopeMap.t; + program_scopes : 'm scope_decl ScopeName.Map.t; program_ctx : decl_ctx; } @@ -70,17 +70,17 @@ let type_rule decl_ctx env = function let type_program (prg : 'm program) : typed program = let typing_env = - ScopeMap.fold + ScopeName.Map.fold (fun scope_name scope_decl -> - let vars = ScopeVarMap.map fst scope_decl.scope_sig in + let vars = ScopeVar.Map.map fst scope_decl.scope_sig in Typing.Env.add_scope scope_name ~vars) prg.program_scopes Typing.Env.empty in let program_scopes = - ScopeMap.map + ScopeName.Map.map (fun scope_decl -> let typing_env = - ScopeVarMap.fold + ScopeVar.Map.fold (fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env) scope_decl.scope_sig typing_env in diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index 3e1c1dff..1208cc21 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -38,13 +38,13 @@ type 'm rule = type 'm scope_decl = { scope_decl_name : ScopeName.t; - scope_sig : (typ * Desugared.Ast.io) ScopeVarMap.t; + scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t; scope_decl_rules : 'm rule list; scope_mark : 'm mark; } type 'm program = { - program_scopes : 'm scope_decl ScopeMap.t; + program_scopes : 'm scope_decl ScopeName.Map.t; program_ctx : decl_ctx; } diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index f710e1ea..9b7510ea 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -41,12 +41,12 @@ module SSCC = Graph.Components.Make (SDependencies) let rec expr_used_scopes e = let recurse_subterms e = Expr.shallow_fold - (fun e -> ScopeMap.union (fun _ x _ -> Some x) (expr_used_scopes e)) - e ScopeMap.empty + (fun e -> ScopeName.Map.union (fun _ x _ -> Some x) (expr_used_scopes e)) + e ScopeName.Map.empty in match e with | (EScopeCall { scope; _ }, m) as e -> - ScopeMap.add scope (Expr.mark_pos m) (recurse_subterms e) + ScopeName.Map.add scope (Expr.mark_pos m) (recurse_subterms e) | EAbs { binder; _ }, _ -> let _, body = Bindlib.unmbind binder in expr_used_scopes body @@ -58,28 +58,28 @@ let rule_used_scopes = function walking through all exprs again *) expr_used_scopes e | Ast.Call (subscope, subindex, _) -> - ScopeMap.singleton subscope + ScopeName.Map.singleton subscope (Marked.get_mark (SubScopeName.get_info subindex)) let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = let g = SDependencies.empty in let g = - ScopeMap.fold + ScopeName.Map.fold (fun v _ g -> SDependencies.add_vertex g v) prgm.program_scopes g in - ScopeMap.fold + ScopeName.Map.fold (fun scope_name scope g -> List.fold_left (fun g rule -> let used_scopes = rule_used_scopes rule in - if ScopeMap.mem scope_name used_scopes then + if ScopeName.Map.mem scope_name used_scopes then Errors.raise_spanned_error (Marked.get_mark (ScopeName.get_info scope.Ast.scope_decl_name)) "The scope %a is calling into itself as a subscope, which is \ forbidden since Catala does not provide recursion" ScopeName.format_t scope.Ast.scope_decl_name; - ScopeMap.fold + ScopeName.Map.fold (fun used_scope pos g -> let edge = SDependencies.E.create used_scope pos scope_name in SDependencies.add_edge_e g edge) @@ -190,9 +190,9 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t = let g = TDependencies.empty in let g = - StructMap.fold + StructName.Map.fold (fun s fields g -> - StructFieldMap.fold + StructField.Map.fold (fun _ typ g -> let def = TVertex.Struct s in let g = TDependencies.add_vertex g def in @@ -214,9 +214,9 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t structs g in let g = - EnumMap.fold + EnumName.Map.fold (fun e cases g -> - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun _ typ g -> let def = TVertex.Enum e in let g = TDependencies.add_vertex g def in diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 87f6c8c0..9577db60 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -26,7 +26,7 @@ type target_scope_vars = | States of (StateName.t * ScopeVar.t) list type ctx = { - scope_var_mapping : target_scope_vars ScopeVarMap.t; + scope_var_mapping : target_scope_vars ScopeVar.Map.t; var_mapping : (Desugared.Ast.expr, untyped Ast.expr Var.t) Var.Map.t; } @@ -46,7 +46,7 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : (* When referring to a subscope variable in an expression, we are referring to the output, hence we take the last state. *) let new_s_var = - match ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping with + match ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping with | WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var | States states -> Marked.same_mark_as (snd (List.hd (List.rev states))) s_var @@ -56,7 +56,7 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : Expr.elocation (ScopelangScopeVar (match - ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping + ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping with | WholeVar new_s_var -> Marked.same_mark_as new_s_var s_var | States _ -> failwith "should not happen")) @@ -65,27 +65,27 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : Expr.elocation (ScopelangScopeVar (match - ScopeVarMap.find (Marked.unmark s_var) ctx.scope_var_mapping + ScopeVar.Map.find (Marked.unmark s_var) ctx.scope_var_mapping with | WholeVar _ -> failwith "should not happen" | States states -> Marked.same_mark_as (List.assoc state states) s_var)) m | EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m | EStruct { name; fields } -> - Expr.estruct name (StructFieldMap.map (translate_expr ctx) fields) m + Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m | EStructAccess { e; field; name } -> Expr.estructaccess (translate_expr ctx e) field name m | EInj { e; cons; name } -> Expr.einj (translate_expr ctx e) cons name m | EMatch { e; name; cases } -> Expr.ematch (translate_expr ctx e) name - (EnumConstructorMap.map (translate_expr ctx) cases) + (EnumConstructor.Map.map (translate_expr ctx) cases) m | EScopeCall { scope; args } -> Expr.escopecall scope - (ScopeVarMap.fold + (ScopeVar.Map.fold (fun v e args' -> let v' = - match ScopeVarMap.find v ctx.scope_var_mapping with + match ScopeVar.Map.find v ctx.scope_var_mapping with | WholeVar v' -> v' | States ((_, v') :: _) -> (* When there are multiple states, the input is always the first @@ -93,8 +93,8 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : v' | States [] -> assert false in - ScopeVarMap.add v' (translate_expr ctx e) args') - args ScopeVarMap.empty) + ScopeVar.Map.add v' (translate_expr ctx e) args') + args ScopeVar.Map.empty) m | ELit (( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ @@ -139,7 +139,7 @@ type rule_tree = priorities declared between rules *) let def_map_to_tree (def_info : Desugared.Ast.ScopeDef.t) - (def : Desugared.Ast.rule RuleMap.t) : rule_tree list = + (def : Desugared.Ast.rule RuleName.Map.t) : rule_tree list = let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in Desugared.Dependency.check_for_exception_cycle exc_graph; (* we start by the base cases: they are the vertices which have no @@ -153,12 +153,14 @@ let def_map_to_tree else base_cases) exc_graph [] in - let rec build_tree (base_cases : RuleSet.t) : rule_tree = + let rec build_tree (base_cases : RuleName.Set.t) : rule_tree = let exceptions = Desugared.Dependency.ExceptionsDependencies.pred exc_graph base_cases in let base_case_as_rule_list = - List.map (fun r -> RuleMap.find r def) (RuleSet.elements base_cases) + List.map + (fun r -> RuleName.Map.find r def) + (RuleName.Set.elements base_cases) in match exceptions with | [] -> Leaf base_case_as_rule_list @@ -278,7 +280,7 @@ let rec rule_tree_to_expr let translate_def (ctx : ctx) (def_info : Desugared.Ast.ScopeDef.t) - (def : Desugared.Ast.rule RuleMap.t) + (def : Desugared.Ast.rule RuleName.Map.t) (typ : typ) (io : Desugared.Ast.io) ~(is_cond : bool) @@ -290,9 +292,9 @@ let translate_def let is_rule_func _ (r : Desugared.Ast.rule) : bool = Option.is_some r.Desugared.Ast.rule_parameter in - let all_rules_func = RuleMap.for_all is_rule_func def in + let all_rules_func = RuleName.Map.for_all is_rule_func def in let all_rules_not_func = - RuleMap.for_all (fun n r -> not (is_rule_func n r)) def + RuleName.Map.for_all (fun n r -> not (is_rule_func n r)) def in let is_def_func_param_typ : typ option = if is_def_func && all_rules_func then @@ -310,13 +312,13 @@ let translate_def (fun (_, r) -> ( Some "This definition is a function:", Expr.pos r.Desugared.Ast.rule_cons )) - (RuleMap.bindings (RuleMap.filter is_rule_func def)) + (RuleName.Map.bindings (RuleName.Map.filter is_rule_func def)) @ List.map (fun (_, r) -> ( Some "This definition is not a function:", Expr.pos r.Desugared.Ast.rule_cons )) - (RuleMap.bindings - (RuleMap.filter (fun n r -> not (is_rule_func n r)) def)) + (RuleName.Map.bindings + (RuleName.Map.filter (fun n r -> not (is_rule_func n r)) def)) in Errors.raise_multispanned_error spans "some definitions of the same variable are functions while others \ @@ -340,7 +342,7 @@ let translate_def else None in if - RuleMap.cardinal def = 0 + RuleName.Map.cardinal def = 0 && is_subscope_var (* Here we have a special case for the empty definitions. Indeed, we could use the code for the regular case below that would create a convoluted @@ -421,7 +423,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : match Marked.unmark scope_def.Desugared.Ast.scope_def_io.io_input with - | OnlyInput when not (RuleMap.is_empty var_def) -> + | OnlyInput when not (RuleName.Map.is_empty var_def) -> (* If the variable is tagged as input, then it shall not be redefined. *) Errors.raise_multispanned_error @@ -431,7 +433,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : (fun (rule, _) -> ( Some "Incriminated variable definition:", Marked.get_mark (RuleName.get_info rule) )) - (RuleMap.bindings var_def)) + (RuleName.Map.bindings var_def)) "It is impossible to give a definition to a scope variable \ tagged as input." | OnlyInput -> @@ -445,7 +447,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : ~is_subscope_var:false in let scope_var = - match ScopeVarMap.find var ctx.scope_var_mapping, state with + match ScopeVar.Map.find var ctx.scope_var_mapping, state with | WholeVar v, None -> v | States states, Some state -> List.assoc state states | _ -> failwith "should not happen" @@ -464,7 +466,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : (* Before calling the sub_scope, we need to include all the re-definitions of subscope parameters*) let sub_scope = - SubScopeMap.find sub_scope_index scope.scope_sub_scopes + SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in let sub_scope_vars_redefs_candidates = Desugared.Ast.ScopeDefMap.filter @@ -483,7 +485,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : with | Desugared.Ast.NoInput -> true | _ -> false) - && RuleMap.is_empty scope_def.scope_def_rules)) + && RuleName.Map.is_empty scope_def.scope_def_rules)) scope.scope_defs in let sub_scope_vars_redefs = @@ -517,10 +519,11 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : ( Some "Incriminated subscope variable definition:", Marked.get_mark (RuleName.get_info rule) )) - (RuleMap.bindings def)) + (RuleName.Map.bindings def)) "It is impossible to give a definition to a subscope \ variable not tagged as input or context." - | OnlyInput when RuleMap.is_empty def && not is_cond -> + | OnlyInput when RuleName.Map.is_empty def && not is_cond + -> (* If the subscope variable is tagged as input, then it shall be defined. *) Errors.raise_multispanned_error @@ -540,7 +543,8 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : ~is_subscope_var:true in let subscop_real_name = - SubScopeMap.find sub_scope_index scope.scope_sub_scopes + SubScopeName.Map.find sub_scope_index + scope.scope_sub_scopes in let var_pos = Desugared.Ast.ScopeDef.get_position def_key @@ -550,7 +554,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : ( subscop_real_name, (sub_scope_index, var_pos), match - ScopeVarMap.find sub_scope_var + ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping with | WholeVar v -> v, var_pos @@ -595,7 +599,7 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : scope.Desugared.Ast.scope_assertions in let scope_sig = - ScopeVarMap.fold + ScopeVar.Map.fold (fun var (states : Desugared.Ast.var_or_states) acc -> match states with | WholeVar -> @@ -605,8 +609,8 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : scope.scope_defs in let typ = scope_def.scope_def_typ in - ScopeVarMap.add - (match ScopeVarMap.find var ctx.scope_var_mapping with + ScopeVar.Map.add + (match ScopeVar.Map.find var ctx.scope_var_mapping with | WholeVar v -> v | States _ -> failwith "should not happen") (typ, scope_def.scope_def_io) @@ -622,14 +626,14 @@ let translate_scope (ctx : ctx) (scope : Desugared.Ast.scope) : (Desugared.Ast.ScopeDef.Var (var, Some state)) scope.scope_defs in - ScopeVarMap.add - (match ScopeVarMap.find var ctx.scope_var_mapping with + ScopeVar.Map.add + (match ScopeVar.Map.find var ctx.scope_var_mapping with | WholeVar _ -> failwith "should not happen" | States states' -> List.assoc state states') (scope_def.scope_def_typ, scope_def.scope_def_io) acc) acc states) - scope.scope_vars ScopeVarMap.empty + scope.scope_vars ScopeVar.Map.empty in let pos = Marked.get_mark (ScopeName.get_info scope.scope_uid) in { @@ -648,16 +652,16 @@ let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program = let ctx = (* Todo: since we rename all scope vars at this point, it would be better to have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) - ScopeMap.fold + ScopeName.Map.fold (fun _scope scope_decl ctx -> - ScopeVarMap.fold + ScopeVar.Map.fold (fun scope_var (states : Desugared.Ast.var_or_states) ctx -> match states with | Desugared.Ast.WholeVar -> { ctx with scope_var_mapping = - ScopeVarMap.add scope_var + ScopeVar.Map.add scope_var (WholeVar (ScopeVar.fresh (ScopeVar.get_info scope_var))) ctx.scope_var_mapping; } @@ -665,7 +669,7 @@ let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program = { ctx with scope_var_mapping = - ScopeVarMap.add scope_var + ScopeVar.Map.add scope_var (States (List.map (fun state -> @@ -683,26 +687,27 @@ let translate_program (pgrm : Desugared.Ast.program) : untyped Ast.program = }) scope_decl.Desugared.Ast.scope_vars ctx) pgrm.Desugared.Ast.program_scopes - { scope_var_mapping = ScopeVarMap.empty; var_mapping = Var.Map.empty } + { scope_var_mapping = ScopeVar.Map.empty; var_mapping = Var.Map.empty } in let ctx_scopes = - ScopeMap.map + ScopeName.Map.map (fun out_str -> let out_struct_fields = - ScopeVarMap.fold + ScopeVar.Map.fold (fun var fld out_map -> let var' = - match ScopeVarMap.find var ctx.scope_var_mapping with + match ScopeVar.Map.find var ctx.scope_var_mapping with | WholeVar v -> v | States l -> snd (List.hd (List.rev l)) in - ScopeVarMap.add var' fld out_map) - out_str.out_struct_fields ScopeVarMap.empty + ScopeVar.Map.add var' fld out_map) + out_str.out_struct_fields ScopeVar.Map.empty in { out_str with out_struct_fields }) pgrm.Desugared.Ast.program_ctx.ctx_scopes in { - Ast.program_scopes = ScopeMap.map (translate_scope ctx) pgrm.program_scopes; + Ast.program_scopes = + ScopeName.Map.map (translate_scope ctx) pgrm.program_scopes; program_ctx = { pgrm.program_ctx with ctx_scopes }; } diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index fcc4e6f6..6bd26bc9 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -22,22 +22,22 @@ let struc ctx (fmt : Format.formatter) (name : StructName.t) - (fields : typ StructFieldMap.t) : unit = + (fields : typ StructField.Map.t) : unit = Format.fprintf fmt "%a %a %a %a@\n@[ %a@]@\n%a" Print.keyword "struct" StructName.format_t name Print.punctuation "=" Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (field_name, typ) -> - Format.fprintf fmt "%a%a %a" StructFieldName.format_t field_name + Format.fprintf fmt "%a%a %a" StructField.format_t field_name Print.punctuation ":" (Print.typ ctx) typ)) - (StructFieldMap.bindings fields) + (StructField.Map.bindings fields) Print.punctuation "}" let enum ctx (fmt : Format.formatter) (name : EnumName.t) - (cases : typ EnumConstructorMap.t) : unit = + (cases : typ EnumConstructor.Map.t) : unit = Format.fprintf fmt "%a %a %a @\n@[ %a@]" Print.keyword "enum" EnumName.format_t name Print.punctuation "=" (Format.pp_print_list @@ -46,7 +46,7 @@ let enum Format.fprintf fmt "%a %a%a %a" Print.punctuation "|" EnumConstructor.format_t field_name Print.punctuation ":" (Print.typ ctx) typ)) - (EnumConstructorMap.bindings cases) + (EnumConstructor.Map.bindings cases) let scope ?(debug = false) ctx fmt (name, decl) = Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@]@\n@[ %a@]" @@ -65,7 +65,7 @@ let scope ?(debug = false) ctx fmt (name, decl) = "output" else fun fmt () -> Format.fprintf fmt "@<0>") () Print.punctuation ")")) - (ScopeVarMap.bindings decl.scope_sig) + (ScopeVar.Map.bindings decl.scope_sig) Print.punctuation "=" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";") @@ -81,7 +81,7 @@ let scope ?(debug = false) ctx fmt (name, decl) = | ScopelangScopeVar v -> ( match Marked.unmark - (snd (ScopeVarMap.find (Marked.unmark v) decl.scope_sig)) + (snd (ScopeVar.Map.find (Marked.unmark v) decl.scope_sig)) .io_input with | Reentrant -> @@ -106,16 +106,16 @@ let program ?(debug : bool = false) (fmt : Format.formatter) (p : 'm program) : Format.pp_print_cut fmt () in Format.pp_open_vbox fmt 0; - StructMap.iter + StructName.Map.iter (fun n s -> struc ctx fmt n s; pp_sep fmt ()) ctx.ctx_structs; - EnumMap.iter + EnumName.Map.iter (fun n e -> enum ctx fmt n e; pp_sep fmt ()) ctx.ctx_enums; Format.pp_print_list ~pp_sep (scope ~debug ctx) fmt - (ScopeMap.bindings p.program_scopes); + (ScopeName.Map.bindings p.program_scopes); Format.pp_close_box fmt () diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index c1b88e3e..0beb6e01 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -22,29 +22,21 @@ open Utils module Runtime = Runtime_ocaml.Runtime - module ScopeName = Uid.Gen () - module StructName = Uid.Gen () - module StructField = Uid.Gen () - module EnumName = Uid.Gen () - module EnumConstructor = Uid.Gen () (** Only used by surface *) module RuleName = Uid.Gen () - module LabelName = Uid.Gen () (** Only used by desugared/scopelang *) module ScopeVar = Uid.Gen () - module SubScopeName = Uid.Gen () - module StateName = Uid.Gen () (** {1 Abstract syntax tree} *) @@ -194,13 +186,13 @@ and ('a, 't) naked_gexpr = -> ('a any, 't) naked_gexpr | EStruct : { name : StructName.t; - fields : ('a, 't) gexpr StructFieldMap.t; + fields : ('a, 't) gexpr StructField.Map.t; } -> ('a any, 't) naked_gexpr | EStructAccess : { name : StructName.t; e : ('a, 't) gexpr; - field : StructFieldName.t; + field : StructField.t; } -> ('a any, 't) naked_gexpr | EInj : { @@ -212,7 +204,7 @@ and ('a, 't) naked_gexpr = | EMatch : { name : EnumName.t; e : ('a, 't) gexpr; - cases : ('a, 't) gexpr EnumConstructorMap.t; + cases : ('a, 't) gexpr EnumConstructor.Map.t; } -> ('a any, 't) naked_gexpr (* Early stages *) @@ -221,7 +213,7 @@ and ('a, 't) naked_gexpr = -> (([< desugared | scopelang ] as 'a), 't) naked_gexpr | EScopeCall : { scope : ScopeName.t; - args : ('a, 't) gexpr ScopeVarMap.t; + args : ('a, 't) gexpr ScopeVar.Map.t; } -> (([< desugared | scopelang ] as 'a), 't) naked_gexpr (* Lambda-like *) @@ -348,18 +340,18 @@ and 'e scopes = | ScopeDef of 'e scope_def constraint 'e = (_ any, _ mark) gexpr -type struct_ctx = typ StructFieldMap.t StructMap.t -type enum_ctx = typ EnumConstructorMap.t EnumMap.t +type struct_ctx = typ StructField.Map.t StructName.Map.t +type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t type scope_out_struct = { out_struct_name : StructName.t; - out_struct_fields : StructFieldName.t ScopeVarMap.t; + out_struct_fields : StructField.t ScopeVar.Map.t; } type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx; - ctx_scopes : scope_out_struct ScopeMap.t; + ctx_scopes : scope_out_struct ScopeName.Map.t; } type 'e program = { decl_ctx : decl_ctx; scopes : 'e scopes } diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index f3c62d40..dec67ebc 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -57,15 +57,15 @@ module Box = struct fun em -> B.box_apply (fun e -> Marked.mark (Marked.get_mark em) e) (Marked.unmark em) - module LiftStruct = Bindlib.Lift (StructFieldMap) + module LiftStruct = Bindlib.Lift (StructField.Map) let lift_struct = LiftStruct.lift_box - module LiftEnum = Bindlib.Lift (EnumConstructorMap) + module LiftEnum = Bindlib.Lift (EnumConstructor.Map) let lift_enum = LiftEnum.lift_box - module LiftScopeVars = Bindlib.Lift (ScopeVarMap) + module LiftScopeVars = Bindlib.Lift (ScopeVar.Map) let lift_scope_vars = LiftScopeVars.lift_box end @@ -108,11 +108,11 @@ let ecatch body exn handler = let elocation loc = Box.app0 @@ ELocation loc -let estruct name (fields : ('a, 't) boxed_gexpr StructFieldMap.t) mark = +let estruct name (fields : ('a, 't) boxed_gexpr StructField.Map.t) mark = Marked.mark mark @@ Bindlib.box_apply (fun fields -> EStruct { name; fields }) - (Box.lift_struct (StructFieldMap.map Box.lift fields)) + (Box.lift_struct (StructField.Map.map Box.lift fields)) let estructaccess e field name = Box.app1 e @@ fun e -> EStructAccess { name; e; field } @@ -124,13 +124,13 @@ let ematch e name cases mark = @@ Bindlib.box_apply2 (fun e cases -> EMatch { name; e; cases }) (Box.lift e) - (Box.lift_enum (EnumConstructorMap.map Box.lift cases)) + (Box.lift_enum (EnumConstructor.Map.map Box.lift cases)) let escopecall scope args mark = Marked.mark mark @@ Bindlib.box_apply (fun args -> EScopeCall { scope; args }) - (Box.lift_scope_vars (ScopeVarMap.map Box.lift args)) + (Box.lift_scope_vars (ScopeVar.Map.map Box.lift args)) (* - Manipulation of marks - *) @@ -230,14 +230,14 @@ let map | ERaise exn -> eraise exn m | ELocation loc -> elocation loc m | EStruct { name; fields } -> - let fields = StructFieldMap.map f fields in + let fields = StructField.Map.map f fields in estruct name fields m | EStructAccess { e; field; name } -> estructaccess (f e) field name m | EMatch { e; name; cases } -> - let cases = EnumConstructorMap.map f cases in + let cases = EnumConstructor.Map.map f cases in ematch (f e) name cases m | EScopeCall { scope; args } -> - let fields = ScopeVarMap.map f args in + let fields = ScopeVar.Map.map f args in escopecall scope fields m let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e) @@ -266,11 +266,11 @@ let shallow_fold | EDefault { excepts; just; cons } -> acc |> lfold excepts |> f just |> f cons | EErrorOnEmpty e -> acc |> f e | ECatch { body; handler; _ } -> acc |> f body |> f handler - | EStruct { fields; _ } -> acc |> StructFieldMap.fold (fun _ -> f) fields + | EStruct { fields; _ } -> acc |> StructField.Map.fold (fun _ -> f) fields | EStructAccess { e; _ } -> acc |> f e | EMatch { e; cases; _ } -> - acc |> f e |> EnumConstructorMap.fold (fun _ -> f) cases - | EScopeCall { args; _ } -> acc |> ScopeVarMap.fold (fun _ -> f) args + acc |> f e |> EnumConstructor.Map.fold (fun _ -> f) cases + | EScopeCall { args; _ } -> acc |> ScopeVar.Map.fold (fun _ -> f) args (* Like [map], but also allows to gather a result bottom-up. *) let map_gather @@ -339,12 +339,12 @@ let map_gather | ELocation loc -> acc, elocation loc m | EStruct { name; fields } -> let acc, fields = - StructFieldMap.fold + StructField.Map.fold (fun cons e (acc, fields) -> let acc1, e = f e in - join acc acc1, StructFieldMap.add cons e fields) + join acc acc1, StructField.Map.add cons e fields) fields - (acc, StructFieldMap.empty) + (acc, StructField.Map.empty) in acc, estruct name fields m | EStructAccess { e; field; name } -> @@ -353,21 +353,21 @@ let map_gather | EMatch { e; name; cases } -> let acc, e = f e in let acc, cases = - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun cons e (acc, cases) -> let acc1, e = f e in - join acc acc1, EnumConstructorMap.add cons e cases) + join acc acc1, EnumConstructor.Map.add cons e cases) cases - (acc, EnumConstructorMap.empty) + (acc, EnumConstructor.Map.empty) in acc, ematch e name cases m | EScopeCall { scope; args } -> let acc, args = - ScopeVarMap.fold + ScopeVar.Map.fold (fun var e (acc, args) -> let acc1, e = f e in - join acc acc1, ScopeVarMap.add var e args) - args (acc, ScopeVarMap.empty) + join acc acc1, ScopeVar.Map.add var e args) + args (acc, ScopeVar.Map.empty) in acc, escopecall scope args m @@ -688,10 +688,10 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = equal_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2) | ( EStruct { name = s1; fields = fields1 }, EStruct { name = s2; fields = fields2 } ) -> - StructName.equal s1 s2 && StructFieldMap.equal equal fields1 fields2 + StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2 | ( EStructAccess { e = e1; field = f1; name = s1 }, EStructAccess { e = e2; field = f2; name = s2 } ) -> - StructName.equal s1 s2 && StructFieldName.equal f1 f2 && equal e1 e2 + StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2 | EInj { e = e1; cons = c1; name = n1 }, EInj { e = e2; cons = c2; name = n2 } -> EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2 @@ -699,10 +699,10 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = EMatch { e = e2; name = n2; cases = cases2 } ) -> EnumName.equal n1 n2 && equal e1 e2 - && EnumConstructorMap.equal equal cases1 cases2 + && EnumConstructor.Map.equal equal cases1 cases2 | ( EScopeCall { scope = s1; args = fields1 }, EScopeCall { scope = s2; args = fields2 } ) -> - ScopeName.equal s1 s2 && ScopeVarMap.equal equal fields1 fields2 + ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2 | ( ( EVar _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _ | EAbs _ | EApp _ | EAssert _ | EOp _ | EDefault _ | EIfThenElse _ | EErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _ | EStruct _ | EStructAccess _ | EInj _ @@ -740,19 +740,19 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = compare_location (Marked.mark Pos.no_pos l1) (Marked.mark Pos.no_pos l2) | EStruct {name=name1; fields= field_map1}, EStruct {name=name2; fields= field_map2} -> StructName.compare name1 name2 @@< fun () -> - StructFieldMap.compare compare field_map1 field_map2 + StructField.Map.compare compare field_map1 field_map2 | EStructAccess {e=e1; field= field_name1; name= struct_name1}, EStructAccess {e=e2; field= field_name2; name= struct_name2} -> compare e1 e2 @@< fun () -> - StructFieldName.compare field_name1 field_name2 @@< fun () -> + StructField.compare field_name1 field_name2 @@< fun () -> StructName.compare struct_name1 struct_name2 | EMatch {e=e1; name= name1;cases= emap1}, EMatch {e=e2; name= name2;cases= emap2} -> EnumName.compare name1 name2 @@< fun () -> compare e1 e2 @@< fun () -> - EnumConstructorMap.compare compare emap1 emap2 + EnumConstructor.Map.compare compare emap1 emap2 | EScopeCall {scope=name1; args= field_map1}, EScopeCall {scope=name2; args= field_map2} -> ScopeName.compare name1 name2 @@< fun () -> - ScopeVarMap.compare compare field_map1 field_map2 + ScopeVar.Map.compare compare field_map1 field_map2 | ETuple es1, ETuple es2 -> List.compare compare es1 es2 | ETupleAccess {e=e1; index= n1; size=s1}, ETupleAccess {e=e2; index= n2; size=s2} -> @@ -841,12 +841,12 @@ let rec size : type a. (a, 't) gexpr -> int = | ECatch { body; handler; _ } -> 1 + size body + size handler | ELocation _ -> 1 | EStruct { fields; _ } -> - StructFieldMap.fold (fun _ e acc -> acc + 1 + size e) fields 0 + StructField.Map.fold (fun _ e acc -> acc + 1 + size e) fields 0 | EStructAccess { e; _ } -> 1 + size e | EMatch { e; cases; _ } -> - EnumConstructorMap.fold (fun _ e acc -> acc + 1 + size e) cases (size e) + EnumConstructor.Map.fold (fun _ e acc -> acc + 1 + size e) cases (size e) | EScopeCall { args; _ } -> - ScopeVarMap.fold (fun _ e acc -> acc + 1 + size e) args 1 + ScopeVar.Map.fold (fun _ e acc -> acc + 1 + size e) args 1 (* - Expression building helpers - *) diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 5321edda..b51bc230 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -101,13 +101,13 @@ val elocation : val estruct : StructName.t -> - ('a any, 't) boxed_gexpr StructFieldMap.t -> + ('a any, 't) boxed_gexpr StructField.Map.t -> 't -> ('a, 't) boxed_gexpr val estructaccess : ('a any, 't) boxed_gexpr -> - StructFieldName.t -> + StructField.t -> StructName.t -> 't -> ('a, 't) boxed_gexpr @@ -122,13 +122,13 @@ val einj : val ematch : ('a any, 't) boxed_gexpr -> EnumName.t -> - ('a, 't) boxed_gexpr EnumConstructorMap.t -> + ('a, 't) boxed_gexpr EnumConstructor.Map.t -> 't -> ('a, 't) boxed_gexpr val escopecall : ScopeName.t -> - (([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVarMap.t -> + (([< desugared | scopelang ] as 'a), 't) boxed_gexpr ScopeVar.Map.t -> 't -> ('a, 't) boxed_gexpr diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index eb15729e..b5bd7a09 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -94,9 +94,9 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit = ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";") (fun fmt (field, mty) -> Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" - StructFieldName.format_t field punctuation "\"" punctuation ":" - typ mty)) - (StructFieldMap.bindings (StructMap.find s ctx.ctx_structs)) + StructField.format_t field punctuation "\"" punctuation ":" typ + mty)) + (StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs)) punctuation "}") | TEnum e -> ( match ctx with @@ -109,7 +109,7 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit = (fun fmt (case, mty) -> Format.fprintf fmt "%a%a@ %a" enum_constructor case punctuation ":" typ mty)) - (EnumConstructorMap.bindings (EnumMap.find e ctx.ctx_enums)) + (EnumConstructor.Map.bindings (EnumName.Map.find e ctx.ctx_enums)) punctuation "]") | TOption t -> Format.fprintf fmt "@[%a@ %a@]" base_type "option" typ t | TArrow (t1, t2) -> @@ -330,13 +330,13 @@ let rec expr_aux : ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";") (fun fmt (field_name, field_expr) -> Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" - StructFieldName.format_t field_name punctuation "\"" punctuation - "=" expr field_expr)) - (StructFieldMap.bindings fields) + StructField.format_t field_name punctuation "\"" punctuation "=" + expr field_expr)) + (StructField.Map.bindings fields) punctuation "}" | EStructAccess { e; field; _ } -> Format.fprintf fmt "%a%a%a%a%a" expr e punctuation "." punctuation "\"" - StructFieldName.format_t field punctuation "\"" + StructField.format_t field punctuation "\"" | EInj { e; cons; _ } -> Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e | EMatch { e; cases; _ } -> @@ -347,7 +347,7 @@ let rec expr_aux : (fun fmt (cons_name, case_expr) -> Format.fprintf fmt "@[%a %a@ %a@ %a@]" punctuation "|" enum_constructor cons_name punctuation "→" expr case_expr)) - (EnumConstructorMap.bindings cases) + (EnumConstructor.Map.bindings cases) | EScopeCall { scope; args } -> Format.pp_open_hovbox fmt 2; ScopeName.format_t fmt scope; @@ -362,7 +362,7 @@ let rec expr_aux : Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\"" ScopeVar.format_t field_name punctuation "\"" punctuation "=" expr field_expr) fmt - (ScopeVarMap.bindings args); + (ScopeVar.Map.bindings args); Format.pp_close_box fmt (); punctuation fmt "}"; Format.pp_close_box fmt () diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 3ede18c4..c1202054 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -284,32 +284,32 @@ let op_type (op : A.operator Marked.pos) : unionfind_typ = module Env = struct type 'e t = { vars : ('e, unionfind_typ) Var.Map.t; - scope_vars : A.typ A.ScopeVarMap.t; - scopes : A.typ A.ScopeVarMap.t A.ScopeMap.t; + scope_vars : A.typ A.ScopeVar.Map.t; + scopes : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t; } let empty = { vars = Var.Map.empty; - scope_vars = A.ScopeVarMap.empty; - scopes = A.ScopeMap.empty; + scope_vars = A.ScopeVar.Map.empty; + scopes = A.ScopeName.Map.empty; } let get t v = Var.Map.find_opt v t.vars - let get_scope_var t sv = A.ScopeVarMap.find_opt sv t.scope_vars + let get_scope_var t sv = A.ScopeVar.Map.find_opt sv t.scope_vars let get_subscope_out_var t scope var = - Option.bind (A.ScopeMap.find_opt scope t.scopes) (fun vmap -> - A.ScopeVarMap.find_opt var vmap) + Option.bind (A.ScopeName.Map.find_opt scope t.scopes) (fun vmap -> + A.ScopeVar.Map.find_opt var vmap) let add v tau t = { t with vars = Var.Map.add v tau t.vars } let add_var v typ t = add v (ast_to_typ typ) t let add_scope_var v typ t = - { t with scope_vars = A.ScopeVarMap.add v typ t.scope_vars } + { t with scope_vars = A.ScopeVar.Map.add v typ t.scope_vars } let add_scope scope_name ~vars t = - { t with scopes = A.ScopeMap.add scope_name vars t.scopes } + { t with scopes = A.ScopeName.Map.add scope_name vars t.scopes } end let add_pos e ty = Marked.mark (Expr.pos e) ty @@ -373,33 +373,32 @@ and typecheck_expr_top_down : Expr.elocation loc (uf_mark (ast_to_typ ty)) | A.EStruct { name; fields } -> let mark = ty_mark (TStruct name) in - let str = A.StructMap.find name ctx.A.ctx_structs in + let str = A.StructName.Map.find name ctx.A.ctx_structs in let _check_fields : unit = let missing_fields, extra_fields = - A.StructFieldMap.fold + A.StructField.Map.fold (fun fld x (remaining, extra) -> - if A.StructFieldMap.mem fld remaining then - A.StructFieldMap.remove fld remaining, extra - else remaining, A.StructFieldMap.add fld x extra) + if A.StructField.Map.mem fld remaining then + A.StructField.Map.remove fld remaining, extra + else remaining, A.StructField.Map.add fld x extra) fields - (str, A.StructFieldMap.empty) + (str, A.StructField.Map.empty) in let errs = List.map (fun (f, ty) -> - ( Some - (Format.asprintf "Missing field %a" A.StructFieldName.format_t f), + ( Some (Format.asprintf "Missing field %a" A.StructField.format_t f), Marked.get_mark ty )) - (A.StructFieldMap.bindings missing_fields) + (A.StructField.Map.bindings missing_fields) @ List.map (fun (f, ef) -> - let dup = A.StructFieldMap.mem f str in + let dup = A.StructField.Map.mem f str in ( Some (Format.asprintf "%s field %a" (if dup then "Duplicate" else "Unknown") - A.StructFieldName.format_t f), + A.StructField.format_t f), Expr.pos ef )) - (A.StructFieldMap.bindings extra_fields) + (A.StructField.Map.bindings extra_fields) in if errs <> [] then Errors.raise_multispanned_error errs @@ -407,9 +406,9 @@ and typecheck_expr_top_down : name in let fields' = - A.StructFieldMap.mapi + A.StructField.Map.mapi (fun f_name f_e -> - let f_ty = A.StructFieldMap.find f_name str in + let f_ty = A.StructField.Map.find f_name str in typecheck_expr_top_down ctx env (ast_to_typ f_ty) f_e) fields in @@ -417,12 +416,12 @@ and typecheck_expr_top_down : | A.EStructAccess { e = e_struct; name; field } -> let fld_ty = let str = - try A.StructMap.find name ctx.A.ctx_structs + try A.StructName.Map.find name ctx.A.ctx_structs with Not_found -> Errors.raise_spanned_error pos_e "No structure %a found" A.StructName.format_t name in - try A.StructFieldMap.find field str + try A.StructField.Map.find field str with Not_found -> Errors.raise_multispanned_error [ @@ -431,7 +430,7 @@ and typecheck_expr_top_down : Marked.get_mark (A.StructName.get_info name) ); ] "Structure %a doesn't define a field %a" A.StructName.format_t name - A.StructFieldName.format_t field + A.StructField.format_t field in let mark = uf_mark (ast_to_typ fld_ty) in let e_struct' = @@ -443,20 +442,20 @@ and typecheck_expr_top_down : let e_enum' = typecheck_expr_top_down ctx env (ast_to_typ - (A.EnumConstructorMap.find cons - (A.EnumMap.find name ctx.A.ctx_enums))) + (A.EnumConstructor.Map.find cons + (A.EnumName.Map.find name ctx.A.ctx_enums))) e_enum in Expr.einj e_enum' cons name mark | A.EMatch { e = e1; name; cases } -> - let cases_ty = A.EnumMap.find name ctx.A.ctx_enums in + let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in let mark = uf_mark t_ret in let e1' = typecheck_expr_top_down ctx env (unionfind (TEnum name)) e1 in let cases' = - A.EnumConstructorMap.mapi + A.EnumConstructor.Map.mapi (fun c_name e -> - let c_ty = A.EnumConstructorMap.find c_name cases_ty in + let c_ty = A.EnumConstructor.Map.find c_name cases_ty in let e_ty = unionfind ~pos:e (TArrow (ast_to_typ c_ty, t_ret)) in typecheck_expr_top_down ctx env e_ty e) cases @@ -464,15 +463,15 @@ and typecheck_expr_top_down : Expr.ematch e1' name cases' mark | A.EScopeCall { scope; args } -> let scope_out_struct = - (A.ScopeMap.find scope ctx.ctx_scopes).out_struct_name + (A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name in let mark = uf_mark (unionfind (TStruct scope_out_struct)) in - let vars = A.ScopeMap.find scope env.scopes in + let vars = A.ScopeName.Map.find scope env.scopes in let args' = - A.ScopeVarMap.mapi + A.ScopeVar.Map.mapi (fun name -> typecheck_expr_top_down ctx env - (ast_to_typ (A.ScopeVarMap.find name vars))) + (ast_to_typ (A.ScopeVar.Map.find name vars))) args in Expr.escopecall scope args' mark diff --git a/compiler/shared_ast/typing.mli b/compiler/shared_ast/typing.mli index 1329362e..1c9485fd 100644 --- a/compiler/shared_ast/typing.mli +++ b/compiler/shared_ast/typing.mli @@ -25,7 +25,7 @@ module Env : sig val empty : 'e t val add_var : 'e Var.t -> typ -> 'e t -> 'e t val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t - val add_scope : ScopeName.t -> vars:typ ScopeVarMap.t -> 'e t -> 'e t + val add_scope : ScopeName.t -> vars:typ ScopeVar.Map.t -> 'e t -> 'e t end val expr : diff --git a/compiler/utils/uid.ml b/compiler/utils/uid.ml index 3b120c79..6704930b 100644 --- a/compiler/utils/uid.ml +++ b/compiler/utils/uid.ml @@ -34,16 +34,18 @@ module type Id = sig val format_t : Format.formatter -> t -> unit val hash : t -> int - module Set: Set.S with type elt = t - module Map: Map.S with type key = t + module Set : Set.S with type elt = t + module Map : Map.S with type key = t end module Make (X : Info) () : Id with type info = X.info = struct module Ordering = struct type t = { id : int; info : X.info } + let compare (x : t) (y : t) : int = compare x.id y.id let equal x y = Int.equal x.id y.id end + include Ordering type info = X.info diff --git a/compiler/utils/uid.mli b/compiler/utils/uid.mli index f58d5dcf..6f698943 100644 --- a/compiler/utils/uid.mli +++ b/compiler/utils/uid.mli @@ -49,8 +49,8 @@ module type Id = sig val format_t : Format.formatter -> t -> unit val hash : t -> int - module Set: Set.S with type elt = t - module Map: Map.S with type key = t + module Set : Set.S with type elt = t + module Map : Map.S with type key = t end (** This is the generative functor that ensures that two modules resulting from @@ -58,5 +58,5 @@ end OCaml typechecker. Prevents mixing up different sorts of identifiers. *) module Make (X : Info) () : Id with type info = X.info -(** Shortcut for creating a kind of uids over marked strings *) module Gen () : Id with type info = MarkedString.info +(** Shortcut for creating a kind of uids over marked strings *) diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index 7e532490..bb2f9a04 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -35,13 +35,13 @@ type context = { (* A map from strings, corresponding to Z3 symbol names, to the Catala variable they represent. Used when to pretty-print Z3 models when a counterexample is generated *) - ctx_z3datatypes : Sort.sort EnumMap.t; + ctx_z3datatypes : Sort.sort EnumName.Map.t; (* A map from Catala enumeration names to the corresponding Z3 sort, from which we can retrieve constructors and accessors *) ctx_z3matchsubsts : (typed expr, Expr.expr) Var.Map.t; (* A map from Catala temporary variables, generated when translating a match, to the corresponding enum accessor call as a Z3 expression *) - ctx_z3structs : Sort.sort StructMap.t; + ctx_z3structs : Sort.sort StructName.Map.t; (* A map from Catala struct names to the corresponding Z3 sort, from which we can retrieve the constructor and the accessors *) ctx_z3unit : Sort.sort * Expr.expr; @@ -80,7 +80,7 @@ let add_z3var (name : string) (v : typed expr Var.t) (ty : typ) (ctx : context) corresponding Z3 datatype [sort] to the context **) let add_z3enum (enum : EnumName.t) (sort : Sort.sort) (ctx : context) : context = - { ctx with ctx_z3datatypes = EnumMap.add enum sort ctx.ctx_z3datatypes } + { ctx with ctx_z3datatypes = EnumName.Map.add enum sort ctx.ctx_z3datatypes } (** [add_z3matchsubst] adds the mapping between temporary variable [v] and the Z3 expression [e] representing an accessor application to the context **) @@ -92,7 +92,7 @@ let add_z3matchsubst (v : typed expr Var.t) (e : Expr.expr) (ctx : context) : corresponding Z3 datatype [sort] to the context **) let add_z3struct (s : StructName.t) (sort : Sort.sort) (ctx : context) : context = - { ctx with ctx_z3structs = StructMap.add s sort ctx.ctx_z3structs } + { ctx with ctx_z3structs = StructName.Map.add s sort ctx.ctx_z3structs } let add_z3constraint (e : Expr.expr) (ctx : context) : context = { ctx with ctx_z3constraints = e :: ctx.ctx_z3constraints } @@ -161,16 +161,16 @@ let rec print_z3model_expr (ctx : context) (ty : typ) (e : Expr.expr) : string = match Marked.unmark ty with | TLit ty -> print_lit ty | TStruct name -> - let s = StructMap.find name ctx.ctx_decl.ctx_structs in - let get_fieldname (fn : StructFieldName.t) : string = - Marked.unmark (StructFieldName.get_info fn) + let s = StructName.Map.find name ctx.ctx_decl.ctx_structs in + let get_fieldname (fn : StructField.t) : string = + Marked.unmark (StructField.get_info fn) in let fields = List.map2 (fun (fn, ty) e -> Format.asprintf "-- %s : %s" (get_fieldname fn) (print_z3model_expr ctx ty e)) - (StructFieldMap.bindings s) + (StructField.Map.bindings s) (Expr.get_args e) in @@ -187,13 +187,13 @@ let rec print_z3model_expr (ctx : context) (ty : typ) (e : Expr.expr) : string = let fd = Expr.get_func_decl e in let fd_name = Symbol.to_string (FuncDecl.get_name fd) in - let enum_ctrs = EnumMap.find name ctx.ctx_decl.ctx_enums in + let enum_ctrs = EnumName.Map.find name ctx.ctx_decl.ctx_enums in let case = List.find (fun (ctr, _) -> (* FIXME: don't match on strings *) String.equal fd_name (Marked.unmark (EnumConstructor.get_info ctr))) - (EnumConstructorMap.bindings enum_ctrs) + (EnumConstructor.Map.bindings enum_ctrs) in Format.asprintf "%s (%s)" fd_name (print_z3model_expr ctx (snd case) e') @@ -310,12 +310,12 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : [Sort.get_id arg_z3_ty] ) in - match EnumMap.find_opt enum ctx.ctx_z3datatypes with + match EnumName.Map.find_opt enum ctx.ctx_z3datatypes with | Some e -> ctx, e | None -> - let ctrs = EnumMap.find enum ctx.ctx_decl.ctx_enums in + let ctrs = EnumName.Map.find enum ctx.ctx_decl.ctx_enums in let ctx, z3_ctrs = - EnumConstructorMap.fold + EnumConstructor.Map.fold (fun ctr ty (ctx, ctrs) -> let ctx, ctr = create_constructor ctr ty ctx in ctx, ctr :: ctrs) @@ -334,20 +334,20 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : context *) and find_or_create_struct (ctx : context) (s : StructName.t) : context * Sort.sort = - match StructMap.find_opt s ctx.ctx_z3structs with + match StructName.Map.find_opt s ctx.ctx_z3structs with | Some s -> ctx, s | None -> let s_name = Marked.unmark (StructName.get_info s) in - let fields = StructMap.find s ctx.ctx_decl.ctx_structs in + let fields = StructName.Map.find s ctx.ctx_decl.ctx_structs in let z3_fieldnames = List.map (fun f -> - Marked.unmark (StructFieldName.get_info (fst f)) + Marked.unmark (StructField.get_info (fst f)) |> Symbol.mk_string ctx.ctx_z3) - (StructFieldMap.bindings fields) + (StructField.Map.bindings fields) in let ctx, z3_fieldtypes_rev = - StructFieldMap.fold + StructField.Map.fold (fun _ ty (ctx, ftypes) -> let ctx, ftype = translate_typ ctx (Marked.unmark ty) in ctx, ftype :: ftypes) @@ -709,14 +709,12 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = let idx_mappings = List.combine (List.map fst - (StructFieldMap.bindings - (StructMap.find name ctx.ctx_decl.ctx_structs))) + (StructField.Map.bindings + (StructName.Map.find name ctx.ctx_decl.ctx_structs))) accessors in let _, accessor = - List.find - (fun (field1, _) -> StructFieldName.equal field field1) - idx_mappings + List.find (fun (field1, _) -> StructField.equal field field1) idx_mappings in let ctx, s = translate_expr ctx e in ctx, Expr.mk_app ctx.ctx_z3 accessor [s] @@ -730,8 +728,8 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = let idx_mappings = List.combine (List.map fst - (EnumConstructorMap.bindings - (EnumMap.find name ctx.ctx_decl.ctx_enums))) + (EnumConstructor.Map.bindings + (EnumName.Map.find name ctx.ctx_decl.ctx_enums))) ctrs in let _, ctr = @@ -760,7 +758,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = (translate_match_arm z3_arg) ctx (List.combine - (List.map snd (EnumConstructorMap.bindings cases)) + (List.map snd (EnumConstructor.Map.bindings cases)) (Datatype.get_accessors z3_enum)) in let z3_arms = @@ -873,9 +871,9 @@ module Backend = struct ctx_decl = decl_ctx; ctx_funcdecl = Var.Map.empty; ctx_z3vars = StringMap.empty; - ctx_z3datatypes = EnumMap.empty; + ctx_z3datatypes = EnumName.Map.empty; ctx_z3matchsubsts = Var.Map.empty; - ctx_z3structs = StructMap.empty; + ctx_z3structs = StructName.Map.empty; ctx_z3unit = z3unit; ctx_z3constraints = []; }