mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-09 22:16:10 +03:00
Added ints and operators to default calculus
This commit is contained in:
parent
18eb757232
commit
e1dfb52095
@ -17,10 +17,17 @@ module Pos = Utils.Pos
|
||||
type typ =
|
||||
| TBool
|
||||
| TUnit
|
||||
| TInt
|
||||
| TTuple of typ Pos.marked list
|
||||
| TArrow of typ Pos.marked * typ Pos.marked
|
||||
|
||||
type lit = LTrue | LFalse | LEmptyError
|
||||
type lit = LBool of bool | LEmptyError | LInt of Int64.t
|
||||
|
||||
type binop = And | Or | Add | Sub | Mult | Div | Lt | Lte | Gt | Gte | Eq | Neq
|
||||
|
||||
type unop = Not | Minus
|
||||
|
||||
type operator = Binop of binop | Unop of unop
|
||||
|
||||
type expr =
|
||||
| EVar of expr Pos.marked Bindlib.var
|
||||
@ -29,6 +36,7 @@ type expr =
|
||||
| ELit of lit
|
||||
| EAbs of Pos.t * (expr Pos.marked, expr Pos.marked) Bindlib.mbinder * typ list
|
||||
| EApp of expr Pos.marked * expr Pos.marked list
|
||||
| EOp of operator
|
||||
| EDefault of expr Pos.marked * expr Pos.marked * expr Pos.marked list
|
||||
| EIfThenElse of expr Pos.marked * expr Pos.marked * expr Pos.marked
|
||||
|
||||
|
@ -19,6 +19,42 @@ module A = Ast
|
||||
let is_empty_error (e : A.expr Pos.marked) : bool =
|
||||
match Pos.unmark e with ELit LEmptyError -> true | _ -> false
|
||||
|
||||
let evaluate_operator (op : A.operator Pos.marked) (args : A.expr Pos.marked list) :
|
||||
A.expr Pos.marked =
|
||||
Pos.same_pos_as
|
||||
( match (Pos.unmark op, List.map Pos.unmark args) with
|
||||
| A.Binop A.And, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 && b2))
|
||||
| A.Binop A.Or, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 || b2))
|
||||
| A.Binop A.Add, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt (Int64.add i1 i2))
|
||||
| A.Binop A.Sub, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt (Int64.sub i1 i2))
|
||||
| A.Binop A.Mult, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LInt (Int64.mul i1 i2))
|
||||
| A.Binop A.Div, [ ELit (LInt i1); ELit (LInt i2) ] ->
|
||||
if i2 <> Int64.zero then A.ELit (LInt (Int64.div i1 i2))
|
||||
else
|
||||
Errors.raise_multispanned_error "division by zero at runtime"
|
||||
[
|
||||
(Some "The division operator:", Pos.get_position op);
|
||||
(Some "The null denominator:", Pos.get_position (List.nth args 2));
|
||||
]
|
||||
| A.Binop A.Lt, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 < i2))
|
||||
| A.Binop A.Lte, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 <= i2))
|
||||
| A.Binop A.Gt, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 > i2))
|
||||
| A.Binop A.Gte, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 >= i2))
|
||||
| A.Binop A.Eq, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 = i2))
|
||||
| A.Binop A.Eq, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 = b2))
|
||||
| A.Binop A.Eq, [ _; _ ] -> A.ELit (LBool false) (* comparing functions return false *)
|
||||
| A.Binop A.Neq, [ ELit (LInt i1); ELit (LInt i2) ] -> A.ELit (LBool (i1 <> i2))
|
||||
| A.Binop A.Neq, [ ELit (LBool b1); ELit (LBool b2) ] -> A.ELit (LBool (b1 <> b2))
|
||||
| A.Binop A.Neq, [ _; _ ] -> A.ELit (LBool true)
|
||||
| A.Unop A.Not, [ ELit (LBool b) ] -> A.ELit (LBool (not b))
|
||||
| A.Unop A.Minus, [ ELit (LInt i) ] -> A.ELit (LInt (Int64.sub Int64.zero i))
|
||||
| _ ->
|
||||
Errors.raise_multispanned_error
|
||||
"operator applied to the wrong arguments (should not happen if the term was well-typed)"
|
||||
[ (Some "Operator:", Pos.get_position op) ]
|
||||
@@ List.mapi (fun i arg -> Some ("Argument n°" ^ string_of_int i, Pos.get_position arg)) )
|
||||
op
|
||||
|
||||
let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked =
|
||||
match Pos.unmark e with
|
||||
| EVar _ ->
|
||||
@ -38,13 +74,14 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked =
|
||||
(Format.asprintf "wrong function call, expected %d arguments, got %d"
|
||||
(Bindlib.mbinder_arity binder) (List.length args))
|
||||
(Pos.get_position e)
|
||||
| EOp op -> evaluate_operator (Pos.same_pos_as op e1) args
|
||||
| ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e
|
||||
| _ ->
|
||||
Errors.raise_spanned_error
|
||||
"function has not been reduced to a lambda at evaluation (should not happen if the \
|
||||
term was well-typed"
|
||||
(Pos.get_position e) )
|
||||
| EAbs _ | ELit _ -> e (* thse are values *)
|
||||
| EAbs _ | ELit _ | EOp _ -> e (* thse are values *)
|
||||
| ETuple es -> Pos.same_pos_as (A.ETuple (List.map evaluate_expr es)) e
|
||||
| ETupleAccess (e1, n) -> (
|
||||
let e1 = evaluate_expr e1 in
|
||||
@ -70,8 +107,8 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked =
|
||||
let just = evaluate_expr just in
|
||||
match Pos.unmark just with
|
||||
| ELit LEmptyError -> Pos.same_pos_as (A.ELit LEmptyError) e
|
||||
| ELit LTrue -> evaluate_expr cons
|
||||
| ELit LFalse -> (
|
||||
| ELit (LBool true) -> evaluate_expr cons
|
||||
| ELit (LBool false) -> (
|
||||
let subs = List.map evaluate_expr subs in
|
||||
let empty_count = List.length (List.filter is_empty_error subs) in
|
||||
match List.length subs - empty_count with
|
||||
@ -96,8 +133,8 @@ let rec evaluate_expr (e : A.expr Pos.marked) : A.expr Pos.marked =
|
||||
(Pos.get_position e) )
|
||||
| EIfThenElse (cond, et, ef) -> (
|
||||
match Pos.unmark (evaluate_expr cond) with
|
||||
| ELit LTrue -> evaluate_expr et
|
||||
| ELit LFalse -> evaluate_expr ef
|
||||
| ELit (LBool true) -> evaluate_expr et
|
||||
| ELit (LBool false) -> evaluate_expr ef
|
||||
| _ ->
|
||||
Errors.raise_spanned_error
|
||||
"expected a boolean literal for the result of this condition (should not happen if the \
|
||||
|
@ -21,6 +21,7 @@ module A = Ast
|
||||
|
||||
type typ =
|
||||
| TUnit
|
||||
| TInt
|
||||
| TBool
|
||||
| TArrow of typ Pos.marked UnionFind.elem * typ Pos.marked UnionFind.elem
|
||||
| TTuple of typ Pos.marked UnionFind.elem list
|
||||
@ -31,6 +32,7 @@ let rec format_typ (fmt : Format.formatter) (ty : typ Pos.marked UnionFind.elem)
|
||||
match Pos.unmark ty_repr with
|
||||
| TUnit -> Format.fprintf fmt "unit"
|
||||
| TBool -> Format.fprintf fmt "bool"
|
||||
| TInt -> Format.fprintf fmt "int"
|
||||
| TAny -> Format.fprintf fmt "any"
|
||||
| TTuple ts ->
|
||||
Format.fprintf fmt "(%a)"
|
||||
@ -42,7 +44,7 @@ let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFin
|
||||
let t1_repr = UnionFind.get (UnionFind.find t1) in
|
||||
let t2_repr = UnionFind.get (UnionFind.find t2) in
|
||||
match (t1_repr, t2_repr) with
|
||||
| (TUnit, _), (TUnit, _) | (TBool, _), (TBool, _) -> ()
|
||||
| (TUnit, _), (TUnit, _) | (TBool, _), (TBool, _) | (TInt, _), (TInt, _) -> ()
|
||||
| (TArrow (t11, t12), _), (TArrow (t21, t22), _) ->
|
||||
unify t11 t21;
|
||||
unify t12 t22
|
||||
@ -60,10 +62,25 @@ let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFin
|
||||
(Some (Format.asprintf "Type %a coming from expression:" format_typ t2), t2_pos);
|
||||
]
|
||||
|
||||
let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem =
|
||||
let pos = Pos.get_position op in
|
||||
let bt = UnionFind.make (TBool, pos) in
|
||||
let it = UnionFind.make (TInt, pos) in
|
||||
let any = UnionFind.make (TAny, pos) in
|
||||
let arr x y = UnionFind.make (TArrow (x, y), pos) in
|
||||
match Pos.unmark op with
|
||||
| A.Binop (A.And | A.Or) -> arr bt (arr bt bt)
|
||||
| A.Binop (A.Add | A.Sub | A.Mult | A.Div) -> arr it (arr it it)
|
||||
| A.Binop (A.Lt | A.Lte | A.Gt | A.Gte) -> arr it (arr it bt)
|
||||
| A.Binop (A.Eq | A.Neq) -> arr any (arr any bt)
|
||||
| A.Unop A.Minus -> arr it it
|
||||
| A.Unop A.Not -> arr bt bt
|
||||
|
||||
let rec ast_to_typ (ty : A.typ) : typ =
|
||||
match ty with
|
||||
| A.TUnit -> TUnit
|
||||
| A.TBool -> TBool
|
||||
| A.TInt -> TInt
|
||||
| A.TArrow (t1, t2) ->
|
||||
TArrow
|
||||
( UnionFind.make (Pos.map_under_mark ast_to_typ t1),
|
||||
@ -76,6 +93,7 @@ let rec typ_to_ast (ty : typ Pos.marked UnionFind.elem) : A.typ Pos.marked =
|
||||
match ty with
|
||||
| TUnit -> A.TUnit
|
||||
| TBool -> A.TBool
|
||||
| TInt -> A.TInt
|
||||
| TTuple ts -> A.TTuple (List.map typ_to_ast ts)
|
||||
| TArrow (t1, t2) -> A.TArrow (typ_to_ast t1, typ_to_ast t2)
|
||||
| TAny -> A.TUnit)
|
||||
@ -92,7 +110,8 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m
|
||||
| None ->
|
||||
Errors.raise_spanned_error "Variable not found in the current context"
|
||||
(Pos.get_position e) )
|
||||
| ELit (LTrue | LFalse) -> UnionFind.make (Pos.same_pos_as TBool e)
|
||||
| ELit (LBool _) -> UnionFind.make (Pos.same_pos_as TBool e)
|
||||
| ELit (LInt _) -> UnionFind.make (Pos.same_pos_as TInt e)
|
||||
| ELit LEmptyError -> UnionFind.make (Pos.same_pos_as TAny e)
|
||||
| ETuple es ->
|
||||
let ts = List.map (typecheck_expr_bottom_up env) es in
|
||||
@ -138,6 +157,7 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m
|
||||
in
|
||||
typecheck_expr_top_down env e1 t_app;
|
||||
t_app
|
||||
| EOp op -> op_type (Pos.same_pos_as op e)
|
||||
| EDefault (just, cons, subs) ->
|
||||
typecheck_expr_top_down env just (UnionFind.make (Pos.same_pos_as TBool just));
|
||||
let tcons = typecheck_expr_bottom_up env cons in
|
||||
@ -158,7 +178,8 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked)
|
||||
| None ->
|
||||
Errors.raise_spanned_error "Variable not found in the current context"
|
||||
(Pos.get_position e) )
|
||||
| ELit (LTrue | LFalse) -> unify tau (UnionFind.make (Pos.same_pos_as TBool e))
|
||||
| ELit (LBool _) -> unify tau (UnionFind.make (Pos.same_pos_as TBool e))
|
||||
| ELit (LInt _) -> unify tau (UnionFind.make (Pos.same_pos_as TInt e))
|
||||
| ELit LEmptyError -> unify tau (UnionFind.make (Pos.same_pos_as TAny e))
|
||||
| ETuple es -> (
|
||||
let tau' = UnionFind.get (UnionFind.find tau) in
|
||||
@ -216,6 +237,9 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked)
|
||||
t_args tau
|
||||
in
|
||||
unify te1 t_func
|
||||
| EOp op ->
|
||||
let op_typ = op_type (Pos.same_pos_as op e) in
|
||||
unify op_typ tau
|
||||
| EDefault (just, cons, subs) ->
|
||||
typecheck_expr_top_down env just (UnionFind.make (Pos.same_pos_as TBool just));
|
||||
typecheck_expr_top_down env cons tau;
|
||||
|
@ -57,7 +57,7 @@ type definition = Scopelang.Ast.expr Pos.marked
|
||||
|
||||
let empty_def (pos : Pos.t) : definition =
|
||||
( Scopelang.Ast.EDefault
|
||||
( (Scopelang.Ast.ELit Dcalc.Ast.LFalse, pos),
|
||||
( (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), pos),
|
||||
(Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, pos),
|
||||
[] ),
|
||||
pos )
|
||||
|
Loading…
Reference in New Issue
Block a user