2023-07-12 12:48:46 +03:00
|
|
|
(* This file is part of the Catala compiler, a specification language for tax
|
|
|
|
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
|
|
|
Louis Gesbert <louis.gesbert@inria.fr>
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
|
|
|
use this file except in compliance with the License. You may obtain a copy of
|
|
|
|
the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
License for the specific language governing permissions and limitations under
|
|
|
|
the License. *)
|
|
|
|
|
|
|
|
(** Small wrapper on top of the stdlib [Map] module to add some useful functions *)
|
|
|
|
|
|
|
|
(* NOTE: only defines typed module, a .mli would be completely redundant *)
|
|
|
|
|
|
|
|
include Stdlib.Map
|
|
|
|
|
|
|
|
module type OrderedType = sig
|
|
|
|
include Stdlib.Map.OrderedType
|
|
|
|
|
|
|
|
val format : Format.formatter -> t -> unit
|
|
|
|
end
|
|
|
|
|
|
|
|
module type S = sig
|
|
|
|
include Stdlib.Map.S
|
|
|
|
|
2023-08-15 17:57:52 +03:00
|
|
|
exception Not_found of key
|
|
|
|
(* Slightly more informative [Not_found] exception *)
|
|
|
|
|
|
|
|
val find : key -> 'a t -> 'a
|
2023-07-12 12:48:46 +03:00
|
|
|
val keys : 'a t -> key list
|
|
|
|
val values : 'a t -> 'a list
|
|
|
|
val of_list : (key * 'a) list -> 'a t
|
2023-11-20 18:01:06 +03:00
|
|
|
val disjoint_union : 'a t -> 'a t -> 'a t
|
2023-07-12 12:48:46 +03:00
|
|
|
|
|
|
|
val format_keys :
|
|
|
|
?pp_sep:(Format.formatter -> unit -> unit) ->
|
|
|
|
Format.formatter ->
|
|
|
|
'a t ->
|
|
|
|
unit
|
|
|
|
|
|
|
|
val format_values :
|
|
|
|
?pp_sep:(Format.formatter -> unit -> unit) ->
|
|
|
|
(Format.formatter -> 'a -> unit) ->
|
|
|
|
Format.formatter ->
|
|
|
|
'a t ->
|
|
|
|
unit
|
|
|
|
|
|
|
|
val format_bindings :
|
|
|
|
?pp_sep:(Format.formatter -> unit -> unit) ->
|
|
|
|
(Format.formatter -> (Format.formatter -> unit) -> 'a -> unit) ->
|
|
|
|
Format.formatter ->
|
|
|
|
'a t ->
|
|
|
|
unit
|
|
|
|
(** Formats the bindings of [t] in order. The user-supplied format function is
|
|
|
|
provided with a formatter for keys (can be used with ["%t"]) and the
|
|
|
|
corresponding value. *)
|
|
|
|
|
|
|
|
val format :
|
|
|
|
?pp_sep:(Format.formatter -> unit -> unit) ->
|
|
|
|
?pp_bind:(Format.formatter -> unit -> unit) ->
|
|
|
|
(Format.formatter -> 'a -> unit) ->
|
|
|
|
Format.formatter ->
|
|
|
|
'a t ->
|
|
|
|
unit
|
|
|
|
(** Formats all bindings of the map in order using the given separator
|
|
|
|
(default ["; "]) and binding indicator (default [" = "]). *)
|
|
|
|
end
|
|
|
|
|
|
|
|
module Make (Ord : OrderedType) : S with type key = Ord.t = struct
|
|
|
|
include Stdlib.Map.Make (Ord)
|
|
|
|
|
2023-08-15 17:57:52 +03:00
|
|
|
exception Not_found of key
|
|
|
|
|
|
|
|
let () =
|
|
|
|
Printexc.register_printer
|
|
|
|
@@ function
|
|
|
|
| Not_found k ->
|
|
|
|
Some (Format.asprintf "key '%a' not found in map" Ord.format k)
|
|
|
|
| _ -> None
|
|
|
|
|
|
|
|
let find k t = try find k t with Stdlib.Not_found -> raise (Not_found k)
|
2023-07-12 12:48:46 +03:00
|
|
|
let keys t = fold (fun k _ acc -> k :: acc) t [] |> List.rev
|
|
|
|
let values t = fold (fun _ v acc -> v :: acc) t [] |> List.rev
|
|
|
|
let of_list l = List.fold_left (fun m (k, v) -> add k v m) empty l
|
2023-12-01 01:53:38 +03:00
|
|
|
|
2023-11-20 18:01:06 +03:00
|
|
|
let disjoint_union t1 t2 =
|
2023-12-01 01:53:38 +03:00
|
|
|
union
|
|
|
|
(fun k _ _ ->
|
|
|
|
Format.kasprintf failwith "Maps are not disjoint: conflict on key %a"
|
2023-11-20 18:01:06 +03:00
|
|
|
Ord.format k)
|
|
|
|
t1 t2
|
2023-07-12 12:48:46 +03:00
|
|
|
|
|
|
|
let format_keys ?pp_sep ppf t =
|
|
|
|
Format.pp_print_list ?pp_sep Ord.format ppf (keys t)
|
|
|
|
|
|
|
|
let format_values ?pp_sep pp_value ppf t =
|
|
|
|
Format.pp_print_list ?pp_sep pp_value ppf (values t)
|
|
|
|
|
|
|
|
let format_bindings ?pp_sep pp_bnd ppf t =
|
|
|
|
Format.pp_print_list ?pp_sep
|
|
|
|
(fun ppf (k, v) -> pp_bnd ppf (fun ppf -> Ord.format ppf k) v)
|
|
|
|
ppf (bindings t)
|
|
|
|
|
|
|
|
let format
|
|
|
|
?(pp_sep = fun ppf () -> Format.fprintf ppf ";@ ")
|
|
|
|
?(pp_bind = fun ppf () -> Format.fprintf ppf " =@ ")
|
|
|
|
pp_value
|
|
|
|
ppf
|
|
|
|
t =
|
|
|
|
Format.pp_print_list ~pp_sep
|
|
|
|
(fun ppf (key, value) ->
|
|
|
|
Format.pp_open_hvbox ppf 2;
|
|
|
|
Ord.format ppf key;
|
|
|
|
pp_bind ppf ();
|
|
|
|
pp_value ppf value;
|
|
|
|
Format.pp_close_box ppf ())
|
|
|
|
ppf (bindings t)
|
|
|
|
end
|