Move struct_or_enum type to generic TypeIdent.t

It used to be hidden away in Scopelang.Dependencies, but is useful throughout.
This commit is contained in:
Louis Gesbert 2024-08-29 16:26:51 +02:00
parent 2d9b2edc9d
commit f2ac1e39cc
14 changed files with 119 additions and 110 deletions

View File

@ -175,7 +175,7 @@ module Passes = struct
optimize:bool ->
check_invariants:bool ->
typed:ty mark ->
ty Dcalc.Ast.program * Scopelang.Dependency.TVertex.t list =
ty Dcalc.Ast.program * TypeIdent.t list =
fun options ~includes ~optimize ~check_invariants ~typed ->
let prg = scopelang options ~includes in
debug_pass_name "dcalc";
@ -233,7 +233,7 @@ module Passes = struct
~expand_ops
~renaming :
typed Lcalc.Ast.program
* Scopelang.Dependency.TVertex.t list
* TypeIdent.t list
* Renaming.context option =
let prg, type_ordering =
dcalc options ~includes ~optimize ~check_invariants ~typed
@ -290,7 +290,7 @@ module Passes = struct
| Some renaming ->
let prg, ren_ctx = Renaming.apply renaming prg in
let type_ordering =
let open Scopelang.Dependency.TVertex in
let open TypeIdent in
List.map
(function
| Struct s -> Struct (Renaming.struct_name ren_ctx s)
@ -311,7 +311,7 @@ module Passes = struct
~monomorphize_types
~expand_ops
~renaming :
Scalc.Ast.program * Scopelang.Dependency.TVertex.t list * Renaming.context
Scalc.Ast.program * TypeIdent.t list * Renaming.context
=
let prg, type_ordering, renaming_context =
lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed

View File

@ -43,7 +43,7 @@ module Passes : sig
optimize:bool ->
check_invariants:bool ->
typed:'m Shared_ast.mark ->
'm Dcalc.Ast.program * Scopelang.Dependency.TVertex.t list
'm Dcalc.Ast.program * Shared_ast.TypeIdent.t list
val lcalc :
Global.options ->
@ -57,7 +57,7 @@ module Passes : sig
expand_ops:bool ->
renaming:Shared_ast.Renaming.t option ->
Shared_ast.typed Lcalc.Ast.program
* Scopelang.Dependency.TVertex.t list
* Shared_ast.TypeIdent.t list
* Shared_ast.Renaming.context option
val scalc :
@ -73,7 +73,7 @@ module Passes : sig
expand_ops:bool ->
renaming:Shared_ast.Renaming.t option ->
Scalc.Ast.program
* Scopelang.Dependency.TVertex.t list
* Shared_ast.TypeIdent.t list
* Shared_ast.Renaming.context
end

View File

@ -283,7 +283,7 @@ let rec monomorphize_expr
| e -> e
let program (prg : typed program) :
typed program * Scopelang.Dependency.TVertex.t list =
typed program * TypeIdent.t list =
let monomorphized_instances = collect_monomorphized_instances prg in
let decl_ctx = prg.decl_ctx in
(* First we remove the polymorphic option type *)

View File

@ -18,7 +18,7 @@ open Shared_ast
open Ast
val program :
typed program -> typed program * Scopelang.Dependency.TVertex.t list
typed program -> typed program * TypeIdent.t list
(** This function performs type monomorphization in a Catala program with two
main actions: {ul
{- transforms tuples into named structs.}

View File

@ -469,7 +469,7 @@ let format_enum_embedding
(EnumConstructor.Map.bindings enum_cases)
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(type_ordering : TypeIdent.t list)
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
@ -508,13 +508,13 @@ let format_ctx
List.exists
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Enum _ -> false
| Scopelang.Dependency.TVertex.Struct s' -> s = s')
| TypeIdent.Enum _ -> false
| TypeIdent.Struct s' -> s = s')
type_ordering
in
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(fun (s, _) -> TypeIdent.Struct s)
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
@ -523,11 +523,11 @@ let format_ctx
List.iter
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
| TypeIdent.Struct s ->
let def = StructName.Map.find s ctx.ctx_structs in
if StructName.path s = [] then
Format.fprintf fmt "%a@\n" format_struct_decl (s, def)
| Scopelang.Dependency.TVertex.Enum e ->
| TypeIdent.Enum e ->
let def = EnumName.Map.find e ctx.ctx_enums in
if EnumName.path e = [] then
Format.fprintf fmt "%a@\n" format_enum_decl (e, def))
@ -737,7 +737,7 @@ let format_program
?(exec_args = true)
~(hashf : Hash.t -> Hash.full)
(p : 'm Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(type_ordering : TypeIdent.t list) : unit =
Format.pp_open_vbox fmt 0;
Format.pp_print_string fmt header;
check_and_reexport_used_modules fmt ~hashf

View File

@ -44,7 +44,7 @@ val format_program :
?exec_args:bool ->
hashf:(Hash.t -> Hash.full) ->
'm Ast.program ->
Scopelang.Dependency.TVertex.t list ->
TypeIdent.t list ->
unit
(** Usage [format_program fmt p type_dependencies_ordering]. Either one of these
may be set:

View File

@ -168,7 +168,7 @@ module To_jsoo = struct
else Format.fprintf fmt "%s_" lowercase_name
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(type_ordering : TypeIdent.t list)
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
let format_prop_or_meth fmt (struct_field_type : typ) =
@ -353,13 +353,13 @@ module To_jsoo = struct
List.exists
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Enum _ -> false
| Scopelang.Dependency.TVertex.Struct s' -> s = s')
| TypeIdent.Enum _ -> false
| TypeIdent.Struct s' -> s = s')
type_ordering
in
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(fun (s, _) -> TypeIdent.Struct s)
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
@ -368,10 +368,10 @@ module To_jsoo = struct
List.iter
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
| TypeIdent.Struct s ->
Format.fprintf fmt "%a@\n" format_struct_decl
(s, StructName.Map.find s ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
| TypeIdent.Enum e ->
Format.fprintf fmt "%a@\n" format_enum_decl
(e, EnumName.Map.find e ctx.ctx_enums))
(type_ordering @ scope_structs)
@ -428,7 +428,7 @@ module To_jsoo = struct
(fmt : Format.formatter)
(module_name : string option)
(prgm : 'm Lcalc.Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) =
(type_ordering : TypeIdent.t list) =
let fmt_lib_name fmt _ =
Format.fprintf fmt "%sLib"
(Option.fold ~none:""

View File

@ -125,7 +125,7 @@ let rec format_typ
| TClosureEnv -> Format.fprintf fmt "%sCLOSURE_ENV%t" sconst element_name
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(type_ordering : TypeIdent.t list)
(fmt : Format.formatter)
(ctx : decl_ctx) : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
@ -179,13 +179,13 @@ let format_ctx
List.exists
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Enum _ -> false
| Scopelang.Dependency.TVertex.Struct s' -> s = s')
| TypeIdent.Enum _ -> false
| TypeIdent.Struct s' -> s = s')
type_ordering
in
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(fun (s, _) -> TypeIdent.Struct s)
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
@ -194,10 +194,10 @@ let format_ctx
Format.pp_print_list
(fun fmt struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
| TypeIdent.Struct s ->
Format.fprintf fmt "%a" format_struct_decl
(s, StructName.Map.find s ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
| TypeIdent.Enum e ->
Format.fprintf fmt "%a" format_enum_decl
(e, EnumName.Map.find e ctx.ctx_enums))
fmt
@ -733,7 +733,7 @@ let format_main (fmt : Format.formatter) (p : Ast.program) =
let format_program
(fmt : Format.formatter)
(p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(type_ordering : TypeIdent.t list) : unit =
Format.pp_open_vbox fmt 0;
Format.fprintf fmt
"/* This file has been generated by the Catala compiler, do not edit! */@,\

View File

@ -21,5 +21,5 @@ open Shared_ast
val renaming : Renaming.t
val format_program :
Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit
Format.formatter -> Ast.program -> TypeIdent.t list -> unit
(** Usage [format_program fmt p type_dependencies_ordering] *)

View File

@ -414,7 +414,7 @@ and format_block ctx (fmt : Format.formatter) (b : block) : unit =
Format.pp_close_box fmt ()
let format_ctx
(type_ordering : Scopelang.Dependency.TVertex.t list)
(type_ordering : TypeIdent.t list)
(fmt : Format.formatter)
ctx : unit =
let format_struct_decl fmt (struct_name, struct_fields) =
@ -509,13 +509,13 @@ let format_ctx
List.exists
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Enum _ -> false
| Scopelang.Dependency.TVertex.Struct s' -> s = s')
| TypeIdent.Enum _ -> false
| TypeIdent.Struct s' -> s = s')
type_ordering
in
let scope_structs =
List.map
(fun (s, _) -> Scopelang.Dependency.TVertex.Struct s)
(fun (s, _) -> TypeIdent.Struct s)
(StructName.Map.bindings
(StructName.Map.filter
(fun s _ -> not (is_in_type_ordering s))
@ -524,11 +524,11 @@ let format_ctx
List.iter
(fun struct_or_enum ->
match struct_or_enum with
| Scopelang.Dependency.TVertex.Struct s ->
| TypeIdent.Struct s ->
if StructName.path s = [] then
Format.fprintf fmt "%a@,@," format_struct_decl
(s, StructName.Map.find s ctx.decl_ctx.ctx_structs)
| Scopelang.Dependency.TVertex.Enum e ->
| TypeIdent.Enum e ->
if EnumName.path e = [] then
Format.fprintf fmt "%a@,@," format_enum_decl
(e, EnumName.Map.find e ctx.decl_ctx.ctx_enums))
@ -553,7 +553,7 @@ let format_code_item ctx fmt = function
let format_program
(fmt : Format.formatter)
(p : Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
(type_ordering : TypeIdent.t list) : unit =
Format.pp_open_vbox fmt 0;
let header =
[

View File

@ -21,5 +21,5 @@ open Shared_ast
val renaming : Renaming.t
val format_program :
Format.formatter -> Ast.program -> Scopelang.Dependency.TVertex.t list -> unit
Format.formatter -> Ast.program -> TypeIdent.t list -> unit
(** Usage [format_program fmt p type_dependencies_ordering] *)

View File

@ -200,40 +200,6 @@ let check_for_cycle_in_defs (g : SDependencies.t) : unit =
let get_defs_ordering (g : SDependencies.t) : SVertex.t list =
List.rev (STopologicalTraversal.fold (fun sd acc -> sd :: acc) g [])
module TVertex = struct
type t = Struct of StructName.t | Enum of EnumName.t
let hash x =
match x with
| Struct x -> StructName.id x
| Enum x -> Hashtbl.hash (`Enum (EnumName.id x))
let compare x y =
match x, y with
| Struct x, Struct y -> StructName.compare x y
| Enum x, Enum y -> EnumName.compare x y
| Struct _, Enum _ -> 1
| Enum _, Struct _ -> -1
let equal x y =
match x, y with
| Struct x, Struct y -> StructName.compare x y = 0
| Enum x, Enum y -> EnumName.compare x y = 0
| _ -> false
let format (fmt : Format.formatter) (x : t) : unit =
match x with
| Struct x -> StructName.format fmt x
| Enum x -> EnumName.format fmt x
let get_info (x : t) =
match x with
| Struct x -> StructName.get_info x
| Enum x -> EnumName.get_info x
end
module TVertexSet = Set.Make (TVertex)
(** On the edges, the label is the expression responsible for the use of the
function *)
module TEdge = struct
@ -244,29 +210,29 @@ module TEdge = struct
end
module TDependencies =
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (TVertex) (TEdge)
Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (TypeIdent) (TEdge)
module TTopologicalTraversal = Graph.Topological.Make (TDependencies)
module TSCC = Graph.Components.Make (TDependencies)
(** Tarjan's stongly connected components algorithm, provided by OCamlGraph *)
let rec get_structs_or_enums_in_type (t : typ) : TVertexSet.t =
let rec get_structs_or_enums_in_type (t : typ) : TypeIdent.Set.t =
match Mark.remove t with
| TStruct s -> TVertexSet.singleton (TVertex.Struct s)
| TEnum e -> TVertexSet.singleton (TVertex.Enum e)
| TStruct s -> TypeIdent.Set.singleton (Struct s)
| TEnum e -> TypeIdent.Set.singleton (Enum e)
| TArrow (t1, t2) ->
TVertexSet.union
TypeIdent.Set.union
(t1
|> List.map get_structs_or_enums_in_type
|> List.fold_left TVertexSet.union TVertexSet.empty)
|> List.fold_left TypeIdent.Set.union TypeIdent.Set.empty)
(get_structs_or_enums_in_type t2)
| TClosureEnv | TLit _ | TAny -> TVertexSet.empty
| TClosureEnv | TLit _ | TAny -> TypeIdent.Set.empty
| TOption t1 | TArray t1 | TDefault t1 -> get_structs_or_enums_in_type t1
| TTuple ts ->
List.fold_left
(fun acc t -> TVertexSet.union acc (get_structs_or_enums_in_type t))
TVertexSet.empty ts
(fun acc t -> TypeIdent.Set.union acc (get_structs_or_enums_in_type t))
TypeIdent.Set.empty ts
let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
=
@ -276,16 +242,16 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
(fun s fields g ->
StructField.Map.fold
(fun _ typ g ->
let def = TVertex.Struct s in
let def = TypeIdent.Struct s in
let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in
TVertexSet.fold
TypeIdent.Set.fold
(fun used g ->
if TVertex.equal used def then
if TypeIdent.equal used def then
Message.error ~pos:(Mark.get typ)
"The type@ %a@ is@ defined@ using@ itself,@ which@ is@ \
not@ supported@ (Catala does not allow recursive types)"
TVertex.format used
TypeIdent.format used
else
let edge = TDependencies.E.create used (Mark.get typ) def in
TDependencies.add_edge_e g edge)
@ -298,16 +264,16 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
(fun e cases g ->
EnumConstructor.Map.fold
(fun _ typ g ->
let def = TVertex.Enum e in
let def = TypeIdent.Enum e in
let g = TDependencies.add_vertex g def in
let used = get_structs_or_enums_in_type typ in
TVertexSet.fold
TypeIdent.Set.fold
(fun used g ->
if TVertex.equal used def then
if TypeIdent.equal used def then
Message.error ~pos:(Mark.get typ)
"The type@ %a@ is@ defined@ using@ itself,@ which@ is@ \
not@ supported@ (Catala does not allow recursive types)"
TVertex.format used
TypeIdent.format used
else
let edge = TDependencies.E.create used (Mark.get typ) def in
TDependencies.add_edge_e g edge)
@ -317,7 +283,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t
in
g
let check_type_cycles (structs : struct_ctx) (enums : enum_ctx) : TVertex.t list
let check_type_cycles (structs : struct_ctx) (enums : enum_ctx) : TypeIdent.t list
=
let g = build_type_graph structs enums in
(* if there is a cycle, there will be an strongly connected component of
@ -330,13 +296,13 @@ let check_type_cycles (structs : struct_ctx) (enums : enum_ctx) : TVertex.t list
(List.map
(fun v ->
let var_str, var_info =
Format.asprintf "%a" TVertex.format v, TVertex.get_info v
Format.asprintf "%a" TypeIdent.format v, TypeIdent.get_info v
in
let succs = TDependencies.succ_e g v in
let _, edge_pos, succ =
List.find (fun (_, _, succ) -> List.mem succ scc) succs
in
let succ_str = Format.asprintf "%a" TVertex.format succ in
let succ_str = Format.asprintf "%a" TypeIdent.format succ in
[
"Cycle type " ^ var_str ^ ", declared:", Mark.get var_info;
( "Used here in the definition of another cycle type "

View File

@ -33,24 +33,11 @@ val build_program_dep_graph : 'm Ast.program -> SDependencies.t
val check_for_cycle_in_defs : SDependencies.t -> unit
val get_defs_ordering : SDependencies.t -> vertex list
(** {1 Type dependencies} *)
module TVertex : sig
type t = Struct of StructName.t | Enum of EnumName.t
val format : Format.formatter -> t -> unit
val get_info : t -> Uid.MarkedString.info
include Graph.Sig.COMPARABLE with type t := t
end
module TVertexSet : Set.S with type elt = TVertex.t
(** On the edges, the label is the expression responsible for the use of the
function *)
module TDependencies :
Graph.Sig.P with type V.t = TVertex.t and type E.label = Pos.t
Graph.Sig.P with type V.t = TypeIdent.t and type E.label = Pos.t
val get_structs_or_enums_in_type : typ -> TVertexSet.t
val get_structs_or_enums_in_type : typ -> TypeIdent.Set.t
val build_type_graph : struct_ctx -> enum_ctx -> TDependencies.t
val check_type_cycles : struct_ctx -> enum_ctx -> TVertex.t list
val check_type_cycles : struct_ctx -> enum_ctx -> TypeIdent.t list

View File

@ -228,6 +228,62 @@ and naked_typ =
| TAny
| TClosureEnv (** Hides an existential type needed for closure conversion *)
module TypeIdent: sig
type t =
| Struct of StructName.t
| Enum of EnumName.t
include Map.OrderedType with type t := t
val get_info : t -> Uid.MarkedString.info
val equal : t -> t -> bool
val hash : t -> int
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
end
= struct
module Ordering = struct
type t =
| Struct of StructName.t
| Enum of EnumName.t
let compare x y =
match x, y with
| Struct x, Struct y -> StructName.compare x y
| Enum x, Enum y -> EnumName.compare x y
| Struct _, Enum _ -> 1
| Enum _, Struct _ -> -1
let equal x y =
match x, y with
| Struct x, Struct y -> StructName.compare x y = 0
| Enum x, Enum y -> EnumName.compare x y = 0
| _ -> false
let format (fmt : Format.formatter) (x : t) : unit =
match x with
| Struct x -> StructName.format fmt x
| Enum x -> EnumName.format fmt x
end
include Ordering
let hash x =
match x with
| Struct x -> StructName.id x
| Enum x -> Hashtbl.hash (`Enum (EnumName.id x))
let get_info (x : t) =
match x with
| Struct x -> StructName.get_info x
| Enum x -> EnumName.get_info x
module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering)
end
(** {2 Constants and operators} *)
type date = Runtime.date