mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Field and constructor names appear in default calculus
This commit is contained in:
parent
70aa8ae2c1
commit
fb592fa735
@ -13,6 +13,7 @@
|
||||
the License. *)
|
||||
|
||||
module Pos = Utils.Pos
|
||||
module Uid = Utils.Uid
|
||||
|
||||
type typ =
|
||||
| TBool
|
||||
@ -32,10 +33,14 @@ type operator = Binop of binop | Unop of unop
|
||||
|
||||
type expr =
|
||||
| EVar of expr Bindlib.var Pos.marked
|
||||
| ETuple of expr Pos.marked list
|
||||
| ETupleAccess of expr Pos.marked * int
|
||||
| EInj of expr Pos.marked * int * typ Pos.marked list
|
||||
| EMatch of expr Pos.marked * expr Pos.marked list
|
||||
| ETuple of (expr Pos.marked * Uid.MarkedString.info option) list
|
||||
(** The [MarkedString.info] is the former struct field name*)
|
||||
| ETupleAccess of expr Pos.marked * int * Uid.MarkedString.info option
|
||||
(** The [MarkedString.info] is the former struct field name*)
|
||||
| EInj of expr Pos.marked * int * Uid.MarkedString.info * typ Pos.marked list
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| EMatch of expr Pos.marked * (expr Pos.marked * Uid.MarkedString.info) list
|
||||
(** The [MarkedString.info] is the former enum case name *)
|
||||
| ELit of lit
|
||||
| EAbs of Pos.t * (expr, expr Pos.marked) Bindlib.mbinder * typ Pos.marked list
|
||||
| EApp of expr Pos.marked * expr Pos.marked list
|
||||
|
@ -83,13 +83,13 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked =
|
||||
term was well-typed"
|
||||
(Pos.get_position e) )
|
||||
| EAbs _ | ELit _ | EOp _ -> e (* thse are values *)
|
||||
| ETuple es -> Pos.same_pos_as (A.ETuple (List.map evaluate_expr es)) e
|
||||
| ETupleAccess (e1, n) -> (
|
||||
| ETuple es -> Pos.same_pos_as (A.ETuple (List.map (fun (e', i) -> (evaluate_expr e', i)) es)) e
|
||||
| ETupleAccess (e1, n, _) -> (
|
||||
let e1 = evaluate_expr e1 in
|
||||
match Pos.unmark e1 with
|
||||
| ETuple es -> (
|
||||
match List.nth_opt es n with
|
||||
| Some e' -> e'
|
||||
| Some (e', _) -> e'
|
||||
| None ->
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf
|
||||
@ -104,14 +104,14 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked =
|
||||
if the term was well-typed)"
|
||||
n)
|
||||
(Pos.get_position e1) )
|
||||
| EInj (e1, n, ts) ->
|
||||
| EInj (e1, n, i, ts) ->
|
||||
let e1' = evaluate_expr e1 in
|
||||
Pos.same_pos_as (A.EInj (e1', n, ts)) e
|
||||
Pos.same_pos_as (A.EInj (e1', n, i, ts)) e
|
||||
| EMatch (e1, es) -> (
|
||||
let e1 = evaluate_expr e1 in
|
||||
match Pos.unmark e1 with
|
||||
| A.EInj (e1, n, _) ->
|
||||
let es_n =
|
||||
| A.EInj (e1, n, _, _) ->
|
||||
let es_n, _ =
|
||||
match List.nth_opt es n with
|
||||
| Some es_n -> es_n
|
||||
| None ->
|
||||
@ -190,7 +190,7 @@ let interpret_program (e : Ast.expr Pos.marked) : (Ast.Var.t * Ast.expr Pos.mark
|
||||
match Pos.unmark (evaluate_expr to_interpret) with
|
||||
| Ast.ETuple args ->
|
||||
let vars, _ = Bindlib.unmbind binder in
|
||||
List.map2 (fun arg var -> (var, arg)) args (Array.to_list vars)
|
||||
List.map2 (fun (arg, _) var -> (var, arg)) args (Array.to_list vars)
|
||||
| _ ->
|
||||
Errors.raise_spanned_error "The interpretation of a program should always yield a tuple"
|
||||
(Pos.get_position e) )
|
||||
|
@ -79,16 +79,26 @@ let rec format_expr (fmt : Format.formatter) (e : expr Pos.marked) : unit =
|
||||
| EVar v -> Format.fprintf fmt "%a" format_var (Pos.unmark v)
|
||||
| ETuple es ->
|
||||
Format.fprintf fmt "(%a)"
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",") format_expr)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt ",")
|
||||
(fun fmt (e, struct_field) ->
|
||||
match struct_field with
|
||||
| Some struct_field ->
|
||||
Format.fprintf fmt "@[<hov 2>\"%a\":@ %a@]" Uid.MarkedString.format_info
|
||||
struct_field format_expr e
|
||||
| None -> Format.fprintf fmt "@[%a@]" format_expr e))
|
||||
es
|
||||
| ETupleAccess (e1, n) -> Format.fprintf fmt "%a.%d" format_expr e1 n
|
||||
| EInj (e, n, ts) ->
|
||||
Format.fprintf fmt "inj[%a].%d %a"
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " *@ ") format_typ)
|
||||
ts n format_expr e
|
||||
| ETupleAccess (e1, n, i) -> (
|
||||
match i with
|
||||
| None -> Format.fprintf fmt "%a.%d" format_expr e1 n
|
||||
| Some i -> Format.fprintf fmt "%a.\"%a\"" format_expr e1 Uid.MarkedString.format_info i )
|
||||
| EInj (e, _n, i, _ts) -> Format.fprintf fmt "%a %a" Uid.MarkedString.format_info i format_expr e
|
||||
| EMatch (e, es) ->
|
||||
Format.fprintf fmt "@[<hov 2>match %a with %a@]" format_expr e
|
||||
(Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt " |@ ") format_expr)
|
||||
(Format.pp_print_list
|
||||
~pp_sep:(fun fmt () -> Format.fprintf fmt " |@ ")
|
||||
(fun fmt (e, c) ->
|
||||
Format.fprintf fmt "%a %a" Uid.MarkedString.format_info c format_expr e))
|
||||
es
|
||||
| ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e)
|
||||
| EApp ((EAbs (_, binder, taus), _), args) ->
|
||||
|
@ -127,9 +127,9 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m
|
||||
| ELit LUnit -> UnionFind.make (Pos.same_pos_as TUnit e)
|
||||
| ELit LEmptyError -> UnionFind.make (Pos.same_pos_as TAny e)
|
||||
| ETuple es ->
|
||||
let ts = List.map (typecheck_expr_bottom_up env) es in
|
||||
let ts = List.map (fun (e, _) -> typecheck_expr_bottom_up env e) es in
|
||||
UnionFind.make (Pos.same_pos_as (TTuple ts) e)
|
||||
| ETupleAccess (e1, n) -> (
|
||||
| ETupleAccess (e1, n, _) -> (
|
||||
let t1 = typecheck_expr_bottom_up env e1 in
|
||||
match Pos.unmark (UnionFind.get (UnionFind.find t1)) with
|
||||
| TTuple ts -> (
|
||||
@ -145,7 +145,7 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf "Expected a tuple, got a %a" format_typ t1)
|
||||
(Pos.get_position e1) )
|
||||
| EInj (e1, n, ts) ->
|
||||
| EInj (e1, n, _, ts) ->
|
||||
let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in
|
||||
let ts_n =
|
||||
match List.nth_opt ts n with
|
||||
@ -160,12 +160,12 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m
|
||||
typecheck_expr_top_down env e1 ts_n;
|
||||
UnionFind.make (Pos.same_pos_as (TEnum ts) e)
|
||||
| EMatch (e1, es) ->
|
||||
let enum_cases = List.map (fun e' -> UnionFind.make (Pos.same_pos_as TAny e')) es in
|
||||
let enum_cases = List.map (fun (e', _) -> UnionFind.make (Pos.same_pos_as TAny e')) es in
|
||||
let t_e1 = UnionFind.make (Pos.same_pos_as (TEnum enum_cases) e1) in
|
||||
typecheck_expr_top_down env e1 t_e1;
|
||||
let t_ret = UnionFind.make (Pos.same_pos_as TAny e) in
|
||||
List.iteri
|
||||
(fun i es' ->
|
||||
(fun i (es', _) ->
|
||||
let enum_t = List.nth enum_cases i in
|
||||
let t_es' = UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') in
|
||||
typecheck_expr_top_down env es' t_es')
|
||||
@ -229,12 +229,12 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked)
|
||||
| ETuple es -> (
|
||||
let tau' = UnionFind.get (UnionFind.find tau) in
|
||||
match Pos.unmark tau' with
|
||||
| TTuple ts -> List.iter2 (typecheck_expr_top_down env) es ts
|
||||
| TTuple ts -> List.iter2 (fun (e, _) t -> typecheck_expr_top_down env e t) es ts
|
||||
| _ ->
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf "exprected %a, got a tuple" format_typ tau)
|
||||
(Pos.get_position e) )
|
||||
| ETupleAccess (e1, n) -> (
|
||||
| ETupleAccess (e1, n, _) -> (
|
||||
let t1 = typecheck_expr_bottom_up env e1 in
|
||||
match Pos.unmark (UnionFind.get (UnionFind.find t1)) with
|
||||
| TTuple t1s -> (
|
||||
@ -250,7 +250,7 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked)
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf "exprected a tuple , got %a" format_typ tau)
|
||||
(Pos.get_position e) )
|
||||
| EInj (e1, n, ts) ->
|
||||
| EInj (e1, n, _, ts) ->
|
||||
let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in
|
||||
let ts_n =
|
||||
match List.nth_opt ts n with
|
||||
@ -265,12 +265,12 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked)
|
||||
typecheck_expr_top_down env e1 ts_n;
|
||||
unify (UnionFind.make (Pos.same_pos_as (TEnum ts) e)) tau
|
||||
| EMatch (e1, es) ->
|
||||
let enum_cases = List.map (fun e' -> UnionFind.make (Pos.same_pos_as TAny e')) es in
|
||||
let enum_cases = List.map (fun (e', _) -> UnionFind.make (Pos.same_pos_as TAny e')) es in
|
||||
let t_e1 = UnionFind.make (Pos.same_pos_as (TEnum enum_cases) e1) in
|
||||
typecheck_expr_top_down env e1 t_e1;
|
||||
let t_ret = UnionFind.make (Pos.same_pos_as TAny e) in
|
||||
List.iteri
|
||||
(fun i es' ->
|
||||
(fun i (es', _) ->
|
||||
let enum_t = List.nth enum_cases i in
|
||||
let t_es' = UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') in
|
||||
typecheck_expr_top_down env es' t_es')
|
||||
|
@ -95,6 +95,11 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
|
||||
(Pos.get_position e)
|
||||
in
|
||||
let field_d = translate_expr ctx field_e in
|
||||
let field_d =
|
||||
Bindlib.box_apply
|
||||
(fun field_d -> (field_d, Some (Ast.StructFieldName.get_info field_name)))
|
||||
field_d
|
||||
in
|
||||
(field_d :: d_fields, Ast.StructFieldMap.remove field_name e_fields))
|
||||
struct_sig ([], e_fields)
|
||||
in
|
||||
@ -121,7 +126,10 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
|
||||
(Pos.get_position e)
|
||||
in
|
||||
let e1 = translate_expr ctx e1 in
|
||||
Bindlib.box_apply (fun e1 -> Dcalc.Ast.ETupleAccess (e1, field_index)) e1
|
||||
Bindlib.box_apply
|
||||
(fun e1 ->
|
||||
Dcalc.Ast.ETupleAccess (e1, field_index, Some (Ast.StructFieldName.get_info field_name)))
|
||||
e1
|
||||
| EEnumInj (e1, constructor, enum_name) ->
|
||||
let enum_sig = Ast.EnumMap.find enum_name ctx.enums in
|
||||
let _, constructor_index =
|
||||
@ -136,7 +144,10 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
|
||||
Bindlib.box_apply
|
||||
(fun e1 ->
|
||||
Dcalc.Ast.EInj
|
||||
(e1, constructor_index, List.map (fun (_, t) -> translate_typ ctx t) enum_sig))
|
||||
( e1,
|
||||
constructor_index,
|
||||
Ast.EnumConstructor.get_info constructor,
|
||||
List.map (fun (_, t) -> translate_typ ctx t) enum_sig ))
|
||||
e1
|
||||
| EMatch (e1, enum_name, cases) ->
|
||||
let enum_sig = Ast.EnumMap.find enum_name ctx.enums in
|
||||
@ -144,15 +155,19 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : Dcalc.Ast.expr Po
|
||||
List.fold_right
|
||||
(fun (constructor, _) (d_cases, e_cases) ->
|
||||
let case_e =
|
||||
Option.value
|
||||
~default:
|
||||
(Errors.raise_spanned_error
|
||||
(Format.asprintf "The constructor %a does not belong to the enum %a"
|
||||
Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t enum_name)
|
||||
(Pos.get_position e))
|
||||
(Ast.EnumConstructorMap.find_opt constructor e_cases)
|
||||
try Ast.EnumConstructorMap.find constructor e_cases
|
||||
with Not_found ->
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf "The constructor %a does not belong to the enum %a"
|
||||
Ast.EnumConstructor.format_t constructor Ast.EnumName.format_t enum_name)
|
||||
(Pos.get_position e)
|
||||
in
|
||||
let case_d = translate_expr ctx case_e in
|
||||
let case_d =
|
||||
Bindlib.box_apply
|
||||
(fun case_d -> (case_d, Ast.EnumConstructor.get_info constructor))
|
||||
case_d
|
||||
in
|
||||
(case_d :: d_cases, Ast.EnumConstructorMap.remove constructor e_cases))
|
||||
enum_sig ([], cases)
|
||||
in
|
||||
@ -343,7 +358,7 @@ let rec translate_rule (ctx : ctx) (rule : Ast.rule) (rest : Ast.rule list) (pos
|
||||
(fun (_, tau, dvar) (acc, i) ->
|
||||
let result_access =
|
||||
Bindlib.box_apply
|
||||
(fun r -> (Dcalc.Ast.ETupleAccess (r, i), pos_sigma))
|
||||
(fun r -> (Dcalc.Ast.ETupleAccess (r, i, None), pos_sigma))
|
||||
(Dcalc.Ast.make_var (result_tuple_var, pos_sigma))
|
||||
in
|
||||
(Dcalc.Ast.make_let_in dvar (tau, pos_sigma) result_access acc, i - 1))
|
||||
@ -363,7 +378,7 @@ and translate_rules (ctx : ctx) (rules : Ast.rule list) (pos_sigma : Pos.t) :
|
||||
let scope_variables = Ast.ScopeVarMap.bindings ctx.scope_vars in
|
||||
let return_exp =
|
||||
Bindlib.box_apply
|
||||
(fun args -> (Dcalc.Ast.ETuple args, pos_sigma))
|
||||
(fun args -> (Dcalc.Ast.ETuple (List.map (fun arg -> (arg, None)) args), pos_sigma))
|
||||
(Bindlib.box_list
|
||||
(List.map
|
||||
(fun (_, (dcalc_var, _)) -> Dcalc.Ast.make_var (dcalc_var, pos_sigma))
|
||||
|
@ -15,7 +15,7 @@
|
||||
module type Info = sig
|
||||
type info
|
||||
|
||||
val format_info : info -> string
|
||||
val format_info : Format.formatter -> info -> unit
|
||||
end
|
||||
|
||||
module type Id = sig
|
||||
@ -34,12 +34,7 @@ module type Id = sig
|
||||
val hash : t -> int
|
||||
end
|
||||
|
||||
module Make (X : sig
|
||||
type info
|
||||
|
||||
val format_info : info -> string
|
||||
end)
|
||||
() : Id with type info = X.info = struct
|
||||
module Make (X : Info) () : Id with type info = X.info = struct
|
||||
type t = { id : int; info : X.info }
|
||||
|
||||
type info = X.info
|
||||
@ -55,7 +50,7 @@ end)
|
||||
let compare (x : t) (y : t) : int = compare x.id y.id
|
||||
|
||||
let format_t (fmt : Format.formatter) (x : t) : unit =
|
||||
Format.fprintf fmt "%s" (X.format_info x.info)
|
||||
Format.fprintf fmt "%a" X.format_info x.info
|
||||
|
||||
let hash (x : t) : int = x.id
|
||||
end
|
||||
@ -63,5 +58,5 @@ end
|
||||
module MarkedString = struct
|
||||
type info = string Pos.marked
|
||||
|
||||
let format_info (s, _) = s
|
||||
let format_info fmt (s, _) = Format.fprintf fmt "%s" s
|
||||
end
|
||||
|
@ -15,7 +15,7 @@
|
||||
module type Info = sig
|
||||
type info
|
||||
|
||||
val format_info : info -> string
|
||||
val format_info : Format.formatter -> info -> unit
|
||||
end
|
||||
|
||||
module MarkedString : Info with type info = string Pos.marked
|
||||
|
Loading…
Reference in New Issue
Block a user