mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-09 22:16:10 +03:00
Handle additional scopelang cases in helper functions
This commit is contained in:
parent
49e37c71b4
commit
a6702808ef
@ -65,10 +65,24 @@ let eifthenelse e1 e2 e3 mark =
|
||||
let eerroronempty e1 mark =
|
||||
Bindlib.box_apply (fun e1 -> ErrorOnEmpty e1, mark) e1
|
||||
|
||||
let eraise e1 pos = Bindlib.box (ERaise e1, pos)
|
||||
let eraise e1 mark = Bindlib.box (ERaise e1, mark)
|
||||
|
||||
let ecatch e1 exn e2 pos =
|
||||
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), pos) e1 e2
|
||||
let ecatch e1 exn e2 mark =
|
||||
Bindlib.box_apply2 (fun e1 e2 -> ECatch (e1, exn, e2), mark) e1 e2
|
||||
|
||||
let elocation loc mark = Bindlib.box (ELocation loc, mark)
|
||||
|
||||
let estruct name fields mark =
|
||||
Bindlib.box_apply (fun es -> EStruct (name, es), mark) fields
|
||||
|
||||
let estructaccess e1 field struc mark =
|
||||
Bindlib.box_apply (fun e1 -> EStructAccess (e1, field, struc), mark) e1
|
||||
|
||||
let eenuminj e1 cons enum mark =
|
||||
Bindlib.box_apply (fun e1 -> EEnumInj (e1, cons, enum), mark) e1
|
||||
|
||||
let ematchs e1 enum cases mark =
|
||||
Bindlib.box_apply2 (fun e1 cases -> EMatchS (e1, enum, cases), mark) e1 cases
|
||||
|
||||
(* - Manipulation of marks - *)
|
||||
|
||||
@ -129,6 +143,7 @@ let fold_marks
|
||||
|
||||
(* - Traversal functions - *)
|
||||
|
||||
(* shallow map *)
|
||||
let map
|
||||
(type a)
|
||||
(ctx : 'ctx)
|
||||
@ -156,8 +171,28 @@ let map
|
||||
| EDefault (excepts, just, cons) ->
|
||||
edefault (List.map (f ctx) excepts) ((f ctx) just) ((f ctx) cons) m
|
||||
| ErrorOnEmpty e1 -> eerroronempty ((f ctx) e1) m
|
||||
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) (Marked.get_mark e)
|
||||
| ERaise exn -> eraise exn (Marked.get_mark e)
|
||||
| ECatch (e1, exn, e2) -> ecatch (f ctx e1) exn (f ctx e2) m
|
||||
| ERaise exn -> eraise exn m
|
||||
| ELocation loc -> elocation loc m
|
||||
| EStruct (name, fields) ->
|
||||
let fields =
|
||||
StructFieldMap.fold
|
||||
(fun fld e -> Bindlib.box_apply2 (StructFieldMap.add fld) (f ctx e))
|
||||
fields
|
||||
(Bindlib.box StructFieldMap.empty)
|
||||
in
|
||||
estruct name fields m
|
||||
| EStructAccess (e1, field, struc) -> estructaccess (f ctx e1) field struc m
|
||||
| EEnumInj (e1, cons, enum) -> eenuminj (f ctx e1) cons enum m
|
||||
| EMatchS (e1, enum, cases) ->
|
||||
let cases =
|
||||
EnumConstructorMap.fold
|
||||
(fun cstr e ->
|
||||
Bindlib.box_apply2 (EnumConstructorMap.add cstr) (f ctx e))
|
||||
cases
|
||||
(Bindlib.box EnumConstructorMap.empty)
|
||||
in
|
||||
ematchs (f ctx e1) enum cases m
|
||||
|
||||
let rec map_top_down ~f e = map () ~f:(fun () -> map_top_down ~f) (f e)
|
||||
|
||||
@ -313,9 +348,21 @@ and equal : type a. (a, 't) gexpr marked -> (a, 't) gexpr marked -> bool =
|
||||
| ERaise ex1, ERaise ex2 -> equal_except ex1 ex2
|
||||
| ECatch (etry1, ex1, ewith1), ECatch (etry2, ex2, ewith2) ->
|
||||
equal etry1 etry2 && equal_except ex1 ex2 && equal ewith1 ewith2
|
||||
| ELocation _, ELocation _ -> true
|
||||
| EStruct (s1, fields1), EStruct (s2, fields2) ->
|
||||
StructName.equal s1 s2 && StructFieldMap.equal equal fields1 fields2
|
||||
| EStructAccess (e1, f1, s1), EStructAccess (e2, f2, s2) ->
|
||||
StructName.equal s1 s2 && StructFieldName.equal f1 f2 && equal e1 e2
|
||||
| EEnumInj (e1, c1, n1), EEnumInj (e2, c2, n2) ->
|
||||
EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2
|
||||
| EMatchS (e1, n1, cases1), EMatchS (e2, n2, cases2) ->
|
||||
EnumName.equal n1 n2
|
||||
&& equal e1 e2
|
||||
&& EnumConstructorMap.equal equal cases1 cases2
|
||||
| ( ( EVar _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | EArray _
|
||||
| ELit _ | EAbs _ | EApp _ | EAssert _ | EOp _ | EDefault _
|
||||
| EIfThenElse _ | ErrorOnEmpty _ | ERaise _ | ECatch _ ),
|
||||
| EIfThenElse _ | ErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _
|
||||
| EStruct _ | EStructAccess _ | EEnumInj _ | EMatchS _ ),
|
||||
_ ) ->
|
||||
false
|
||||
|
||||
@ -348,6 +395,16 @@ let rec free_vars : type a. (a, 't) gexpr marked -> (a, 't) gexpr Var.Set.t =
|
||||
| EAbs (binder, _) ->
|
||||
let vs, body = Bindlib.unmbind binder in
|
||||
Array.fold_right Var.Set.remove vs (free_vars body)
|
||||
| ELocation _ -> Var.Set.empty
|
||||
| EStruct (_, fields) ->
|
||||
StructFieldMap.fold
|
||||
(fun _ e -> Var.Set.union (free_vars e))
|
||||
fields Var.Set.empty
|
||||
| EStructAccess (e1, _, _) -> free_vars e1
|
||||
| EEnumInj (e1, _, _) -> free_vars e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
free_vars e1
|
||||
|> EnumConstructorMap.fold (fun _ e -> Var.Set.union (free_vars e)) cases
|
||||
|
||||
let remove_logging_calls e =
|
||||
let rec f () e =
|
||||
@ -384,3 +441,10 @@ let rec size : type a. (a, 't) gexpr marked -> int =
|
||||
exceptions
|
||||
| ERaise _ -> 1
|
||||
| ECatch (etry, _, ewith) -> 1 + size etry + size ewith
|
||||
| ELocation _ -> 1
|
||||
| EStruct (_, fields) ->
|
||||
StructFieldMap.fold (fun _ e acc -> acc + 1 + size e) fields 0
|
||||
| EStructAccess (e1, _, _) -> 1 + size e1
|
||||
| EEnumInj (e1, _, _) -> 1 + size e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
EnumConstructorMap.fold (fun _ e acc -> acc + 1 + size e) cases (size e1)
|
||||
|
@ -59,6 +59,13 @@ let tlit (fmt : Format.formatter) (l : typ_lit) : unit =
|
||||
| TDuration -> "duration"
|
||||
| TDate -> "date")
|
||||
|
||||
let location (fmt : Format.formatter) (l : location) : unit =
|
||||
match l with
|
||||
| ScopeVar v -> Format.fprintf fmt "%a" ScopeVar.format_t (Marked.unmark v)
|
||||
| SubScopeVar (_, subindex, subvar) ->
|
||||
Format.fprintf fmt "%a.%a" SubScopeName.format_t (Marked.unmark subindex)
|
||||
ScopeVar.format_t (Marked.unmark subvar)
|
||||
|
||||
let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit =
|
||||
Format.fprintf fmt "%a"
|
||||
(Utils.Cli.format_with_style [ANSITerminal.magenta])
|
||||
@ -109,7 +116,6 @@ let rec typ (ctx : decl_ctx) (fmt : Format.formatter) (ty : typ) : unit =
|
||||
(Marked.unmark t1)
|
||||
| TAny -> base_type fmt "any"
|
||||
|
||||
(* (EmileRolley) NOTE: seems to be factorizable with Print.lit. *)
|
||||
let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
|
||||
match l with
|
||||
| LBool b -> lit_style fmt (string_of_bool b)
|
||||
@ -338,3 +344,29 @@ let rec expr :
|
||||
with_parens e1 keyword "with" except exn with_parens e2
|
||||
| ERaise exn ->
|
||||
Format.fprintf fmt "@[<hov 2>%a@ %a@]" keyword "raise" except exn
|
||||
| ELocation loc -> location fmt loc
|
||||
| EStruct (name, fields) ->
|
||||
Format.fprintf fmt " @[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
|
||||
punctuation "{"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " punctuation ";")
|
||||
(fun fmt (field_name, field_expr) ->
|
||||
Format.fprintf fmt "%a%a%a%a@ %a" punctuation "\""
|
||||
StructFieldName.format_t field_name punctuation "\"" punctuation
|
||||
"=" expr field_expr))
|
||||
(StructFieldMap.bindings fields)
|
||||
punctuation "}"
|
||||
| EStructAccess (e1, field, _) ->
|
||||
Format.fprintf fmt "%a%a%a%a%a" expr e1 punctuation "." punctuation "\""
|
||||
StructFieldName.format_t field punctuation "\""
|
||||
| EEnumInj (e1, cons, _) ->
|
||||
Format.fprintf fmt "%a@ %a" EnumConstructor.format_t cons expr e1
|
||||
| EMatchS (e1, _, cases) ->
|
||||
Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" keyword "match"
|
||||
expr e1 keyword "with"
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
|
||||
(fun fmt (cons_name, case_expr) ->
|
||||
Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" punctuation "|"
|
||||
enum_constructor cons_name punctuation "→" expr case_expr))
|
||||
(EnumConstructorMap.bindings cases)
|
||||
|
Loading…
Reference in New Issue
Block a user