mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 17:10:22 +03:00
267 lines
11 KiB
OCaml
267 lines
11 KiB
OCaml
(* This file is part of the Catala compiler, a specification language for tax
|
|
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
|
Denis Merigoux <denis.merigoux@inria.fr>
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
|
use this file except in compliance with the License. You may obtain a copy of
|
|
the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
License for the specific language governing permissions and limitations under
|
|
the License. *)
|
|
|
|
open Utils
|
|
open Shared_ast
|
|
open Ast
|
|
|
|
(** [needs_parens e] tells whether [e] must be wrapped in parentheses when it
    appears as an operand or as a function argument. Applications, conditionals,
    pattern matches, constructor injections, [error_empty] and abstractions are
    not self-delimiting: printed bare, e.g. [f (g x)] would read as [f g x]. *)
let needs_parens (e : expr Marked.pos) : bool =
  match Marked.unmark e with
  | EAbs _ | EApp _ | EIfThenElse _ | EMatch _ | EEnumInj _ | ErrorOnEmpty _ ->
    true
  | ELocation _ | EVar _ | ELit _ | EStruct _ | EStructAccess _ | EOp _
  | EDefault _ | EArray _ ->
    false
|
|
|
|
(** Prints variable [v] as its Bindlib display name followed by an underscore
    and its Bindlib unique id, so that distinct variables sharing a name remain
    distinguishable in the output. *)
let format_var (fmt : Format.formatter) (v : Var.t) : unit =
  let display_name = Bindlib.name_of v in
  let unique_id = Bindlib.uid_of v in
  Format.pp_print_string fmt (display_name ^ "_" ^ string_of_int unique_id)
|
|
|
|
(** Prints a location: a plain scope variable is printed by name, a subscope
    variable as [subscope.var]. The scope name carried by [SubScopeVar] is not
    printed. *)
let format_location (fmt : Format.formatter) (l : location) : unit =
  match l with
  | ScopeVar var -> ScopeVar.format_t fmt (Marked.unmark var)
  | SubScopeVar (_, subscope, var) ->
    SubScopeName.format_t fmt (Marked.unmark subscope);
    Format.pp_print_char fmt '.';
    ScopeVar.format_t fmt (Marked.unmark var)
|
|
|
|
(** [typ_needs_parens t] is [true] exactly when [t] is a function type, the
    only type form that is ambiguous when nested inside another type. *)
let typ_needs_parens (e : typ Marked.pos) : bool =
  match Marked.unmark e with
  | TArrow _ -> true
  | TLit _ | TStruct _ | TEnum _ | TArray _ | TAny -> false
|
|
|
|
(** Pretty-prints a scope-language type. Function types occurring as the domain
    of an arrow or as the element type of an array are parenthesized so that
    the printed type is unambiguous. *)
let rec format_typ (fmt : Format.formatter) (typ : typ Marked.pos) : unit =
  (* Prints [t], wrapped in parentheses if it is a function type. *)
  let format_typ_with_parens (fmt : Format.formatter) (t : typ Marked.pos) =
    if typ_needs_parens t then
      Format.fprintf fmt "%a%a%a" Print.punctuation "(" format_typ t
        Print.punctuation ")"
    else format_typ fmt t
  in
  match Marked.unmark typ with
  | TLit l -> Print.tlit fmt l
  | TStruct s -> StructName.format_t fmt s
  | TEnum e -> EnumName.format_t fmt e
  | TArrow (t1, t2) ->
    (* Only the domain is parenthesized: the codomain is printed bare. *)
    Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" format_typ_with_parens t1
      Print.operator "→" format_typ t2
  | TArray t1 ->
    (* The element type is unmarked; reuse the array's position. Parentheses
       keep [(a → b) array] from printing as [a → b array]. *)
    Format.fprintf fmt "@[%a@ %a@]" format_typ_with_parens
      (Marked.same_mark_as t1 typ)
      Print.base_type "array"
  | TAny -> Format.pp_print_string fmt "any"
|
|
|
|
(** [format_expr ?debug fmt e] pretty-prints the scope-language expression [e]
    on [fmt].

    @param debug
      when [true], applications of the [Log] unary operator are printed
      explicitly; when [false] (the default) they are elided and only their
      argument is printed. The flag is propagated to every sub-expression. *)
let rec format_expr
    ?(debug : bool = false)
    (fmt : Format.formatter)
    (e : expr Marked.pos) : unit =
  let format_expr = format_expr ~debug in
  (* Prints [e], parenthesized when [needs_parens] deems it ambiguous in
     operand or argument position. *)
  let format_with_parens (fmt : Format.formatter) (e : expr Marked.pos) =
    if needs_parens e then Format.fprintf fmt "(%a)" format_expr e
    else format_expr fmt e
  in
  match Marked.unmark e with
  | ELocation l -> format_location fmt l
  | EVar v -> format_var fmt v
  | ELit l -> Print.lit fmt l
  | EStruct (name, fields) ->
    Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" StructName.format_t name
      Print.punctuation "{"
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";")
         (fun fmt (field_name, field_expr) ->
           Format.fprintf fmt "%a%a%a%a@ %a" Print.punctuation "\""
             StructFieldName.format_t field_name Print.punctuation "\""
             Print.punctuation "=" format_expr field_expr))
      (StructFieldMap.bindings fields)
      Print.punctuation "}"
  | EStructAccess (e1, field, _) ->
    (* Parenthesize the record so that [(f x)."fld"] is not read as
       [f x."fld"]. *)
    Format.fprintf fmt "%a%a%a%a%a" format_with_parens e1 Print.punctuation "."
      Print.punctuation "\"" StructFieldName.format_t field Print.punctuation
      "\""
  | EEnumInj (e1, cons, _) ->
    Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor cons
      format_with_parens e1
  | EMatch (e1, _, cases) ->
    Format.fprintf fmt "@[<hov 0>%a@ @[<hov 2>%a@]@ %a@ %a@]" Print.keyword
      "match" format_expr e1 Print.keyword "with"
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
         (fun fmt (cons_name, case_expr) ->
           Format.fprintf fmt "@[<hov 2>%a %a@ %a@ %a@]" Print.punctuation "|"
             Print.enum_constructor cons_name Print.punctuation "→" format_expr
             case_expr))
      (EnumConstructorMap.bindings cases)
  | EApp ((EAbs (binder, taus), _), args)
    when List.compare_lengths taus args = 0 ->
    (* A directly applied abstraction is displayed as a chain of [let]
       bindings. The guard makes the [List.map2] over [args] safe: an
       application with a mismatched number of arguments falls through to the
       generic [EApp] case below instead of raising [Invalid_argument]. *)
    let xs, body = Bindlib.unmbind binder in
    let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
    let xs_tau_arg = List.map2 (fun (x, tau) arg -> x, tau, arg) xs_tau args in
    Format.fprintf fmt "@[%a%a@]"
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
         (fun fmt (x, tau, arg) ->
           Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@ %a@\n@]"
             Print.keyword "let" format_var x Print.punctuation ":" format_typ
             tau Print.punctuation "=" format_expr arg Print.keyword "in"))
      xs_tau_arg format_expr body
  | EAbs (binder, taus) ->
    let xs, body = Bindlib.unmbind binder in
    let xs_tau = List.map2 (fun x tau -> x, tau) (Array.to_list xs) taus in
    Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@]" Print.punctuation "λ"
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt " ")
         (fun fmt (x, tau) ->
           Format.fprintf fmt "@[%a%a%a@ %a%a@]" Print.punctuation "("
             format_var x Print.punctuation ":" format_typ tau Print.punctuation
             ")"))
      xs_tau Print.punctuation "→" format_expr body
  | EApp ((EOp (Binop op), _), [arg1; arg2]) ->
    Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 Print.binop op
      format_with_parens arg2
  | EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug ->
    (* Outside debug mode, logging wrappers are transparent. *)
    format_expr fmt arg1
  | EApp ((EOp (Unop op), _), [arg1]) ->
    Format.fprintf fmt "@[%a@ %a@]" Print.unop op format_with_parens arg1
  | EApp (f, args) ->
    (* The head is parenthesized too, e.g. an abstraction applied with the
       wrong arity prints as [(λ ...) a b]. *)
    Format.fprintf fmt "@[%a@ %a@]" format_with_parens f
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
         format_with_parens)
      args
  | EIfThenElse (e1, e2, e3) ->
    Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" Print.keyword "if"
      format_expr e1 Print.keyword "then" format_expr e2 Print.keyword "else"
      format_expr e3
  | EOp (Ternop op) -> Print.ternop fmt op
  | EOp (Binop op) -> Print.binop fmt op
  | EOp (Unop op) -> Print.unop fmt op
  | EDefault (excepts, just, cons) -> (
    match excepts with
    | [] ->
      Format.fprintf fmt "@[%a%a %a@ %a%a@]" Print.punctuation "⟨" format_expr
        just Print.punctuation "⊢" format_expr cons Print.punctuation "⟩"
    | _ :: _ ->
      Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a %a@ %a%a@]" Print.punctuation
        "⟨"
        (Format.pp_print_list
           ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
           format_expr)
        excepts Print.punctuation "|" format_expr just Print.punctuation "⊢"
        format_expr cons Print.punctuation "⟩")
  | ErrorOnEmpty e' ->
    Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.keyword "error_empty"
      format_with_parens e'
  | EArray es ->
    (* Same [;] + break separator as struct fields, so long arrays can wrap. *)
    Format.fprintf fmt "@[<hov 2>%a%a%a@]" Print.punctuation "["
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";")
         (fun fmt e -> Format.fprintf fmt "@[%a@]" format_expr e))
      es Print.punctuation "]"
|
|
|
|
(** Prints a structure declaration as [type Name = {], followed by one
    [field: typ] line per field indented by two spaces, and a closing [}] on
    its own line. *)
let format_struct
    (fmt : Format.formatter)
    ((name, fields) : StructName.t * (StructFieldName.t * typ Marked.pos) list)
    : unit =
  (* Every field is introduced by a cut. Inside the vertical box, each cut
     becomes a newline indented by the box offset (2), so all fields line up.
     The previous [" %a"] layout put the first field at column 1 and the others
     at column 2. *)
  Format.fprintf fmt "@[<v 2>%a %a %a %a%a@]@\n%a" Print.keyword "type"
    StructName.format_t name Print.punctuation "=" Print.punctuation "{"
    (Format.pp_print_list
       ~pp_sep:(fun _ () -> ())
       (fun fmt (field_name, typ) ->
         Format.fprintf fmt "@,%a%a %a" StructFieldName.format_t field_name
           Print.punctuation ":" format_typ typ))
    fields Print.punctuation "}"
|
|
|
|
(** Prints an enumeration declaration as [type Name =], followed by one
    [| Constructor: typ] line per case indented by two spaces. *)
let format_enum
    (fmt : Format.formatter)
    ((name, cases) : EnumName.t * (EnumConstructor.t * typ Marked.pos) list) :
    unit =
  (* Each case is introduced by a cut, which the vertical box turns into a
     newline indented by 2. All cases therefore line up, and no trailing space
     is left after [=]. The previous [" %a"] layout misaligned the first
     case. *)
  Format.fprintf fmt "@[<v 2>%a %a %a%a@]" Print.keyword "type"
    EnumName.format_t name Print.punctuation "="
    (Format.pp_print_list
       ~pp_sep:(fun _ () -> ())
       (fun fmt (cons_name, typ) ->
         Format.fprintf fmt "@,%a %a%a %a" Print.punctuation "|"
           EnumConstructor.format_t cons_name Print.punctuation ":" format_typ
           typ))
    cases
|
|
|
|
(** Prints a scope declaration.

    The header is [let scope Name] followed by one [(var: typ|kind[|output])]
    entry per signature variable and [=]. Here [kind] is [internal], [input] or
    [context], according to the variable's [io_input]. The rules follow, one
    per line, separated by [;].

    @param debug
      forwarded to every expression printed in the scope, including
      definitions of subscope variables. *)
let format_scope
    ?(debug : bool = false)
    (fmt : Format.formatter)
    ((name, decl) : ScopeName.t * scope_decl) : unit =
  (* Shadow once so that no call site below can forget the debug flag.
     Previously, the [SubScopeVar] branch called [format_expr] without it. *)
  let format_expr = format_expr ~debug in
  (* The two leading spaces align the first rule with the following ones,
     which the vertical box indents by 2. *)
  Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@]@\n@[<v 2>  %a@]"
    Print.keyword "let" Print.keyword "scope" ScopeName.format_t name
    (Format.pp_print_list
       ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
       (fun fmt (scope_var, (typ, vis)) ->
         Format.fprintf fmt "%a%a%a %a%a%a%a%a" Print.punctuation "("
           ScopeVar.format_t scope_var Print.punctuation ":" format_typ typ
           Print.punctuation "|" Print.keyword
           (match Marked.unmark vis.io_input with
           | NoInput -> "internal"
           | OnlyInput -> "input"
           | Reentrant -> "context")
           (fun fmt () ->
             (* The [|output] suffix appears only for output variables. *)
             if Marked.unmark vis.io_output then
               Format.fprintf fmt "%a@,%a" Print.punctuation "|" Print.keyword
                 "output")
           () Print.punctuation ")"))
    (ScopeVarMap.bindings decl.scope_sig)
    Print.punctuation "="
    (Format.pp_print_list
       ~pp_sep:(fun fmt () -> Format.fprintf fmt "%a@ " Print.punctuation ";")
       (fun fmt rule ->
         match rule with
         | Definition (loc, typ, _, e) ->
           Format.fprintf fmt "@[<hov 2>%a %a %a %a %a@ %a@]" Print.keyword
             "let" format_location (Marked.unmark loc) Print.punctuation ":"
             format_typ typ Print.punctuation "="
             (fun fmt e ->
               match Marked.unmark loc with
               | SubScopeVar _ -> format_expr fmt e
               | ScopeVar v -> (
                 (* NOTE(review): [ScopeVarMap.find] raises if [v] is missing
                    from the scope signature; presumably an invariant of
                    well-formed programs. *)
                 match
                   Marked.unmark
                     (snd (ScopeVarMap.find (Marked.unmark v) decl.scope_sig))
                       .io_input
                 with
                 | Reentrant ->
                   Format.fprintf fmt "%a@ %a" Print.operator
                     "reentrant or by default" format_expr e
                 | NoInput | OnlyInput -> format_expr fmt e))
             e
         | Assertion e ->
           Format.fprintf fmt "%a %a" Print.keyword "assert" format_expr e
         | Call (scope_name, subscope_name) ->
           Format.fprintf fmt "%a %a%a%a%a" Print.keyword "call"
             ScopeName.format_t scope_name Print.punctuation "["
             SubScopeName.format_t subscope_name Print.punctuation "]"))
    decl.scope_decl_rules
|
|
|
|
(** Prints a whole program: all structure declarations, then all enumeration
    declarations, then all scopes. Consecutive declarations are separated by a
    blank line; an empty section emits no separator.

    @param debug forwarded to [format_scope]. *)
let format_program
    ?(debug : bool = false)
    (fmt : Format.formatter)
    (p : program) : unit =
  (* Uses [@\n] rather than a raw ['\n']. [Format] does not interpret newline
     characters inside strings, so raw ones desynchronize its column tracking
     and skew the indentation of the next box. *)
  let blank_line fmt () = Format.fprintf fmt "@\n@\n" in
  (* Separator printed after a section only when that section is non-empty. *)
  let section_sep section_is_empty fmt () =
    if not section_is_empty then blank_line fmt ()
  in
  Format.fprintf fmt "%a%a%a%a%a"
    (Format.pp_print_list ~pp_sep:blank_line format_struct)
    (StructMap.bindings p.program_structs)
    (section_sep (StructMap.is_empty p.program_structs))
    ()
    (Format.pp_print_list ~pp_sep:blank_line format_enum)
    (EnumMap.bindings p.program_enums)
    (section_sep (EnumMap.is_empty p.program_enums))
    ()
    (Format.pp_print_list ~pp_sep:blank_line (format_scope ~debug))
    (ScopeMap.bindings p.program_scopes)
|