(* Repository-mirror metadata: mirror of https://github.com/CatalaLang/catala.git,
   synced 2024-11-09 22:16:10 +03:00; 130 lines, 5.1 KiB, OCaml. *)
(* This file is part of the Catala compiler, a specification language for tax and social benefits
|
|
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux
|
|
<denis.merigoux@inria.fr>
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
|
in compliance with the License. You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
or implied. See the License for the specific language governing permissions and limitations under
|
|
the License. *)
|
|
open Utils
|
|
open Ast
|
|
|
|
(** Binding operators over Bindlib boxes. [let+ x = b in k x] maps [k] over the
    box [b] (via [Bindlib.box_apply]); [and+] pairs two boxes (via
    [Bindlib.box_pair]) so that several boxed sub-results can be combined in a
    single [let+ ... and+ ... in]. *)
let ( let+ ) boxed k = Bindlib.box_apply k boxed

let ( and+ ) left right = Bindlib.box_pair left right
|
|
|
|
(** [visitor_map t ctx e] applies [t ctx] to each immediate sub-expression of
    [e] and reassembles a node of the same constructor from the boxed results.
    It does not recurse by itself: deeper traversal happens only if [t] calls
    [visitor_map] again. Rebuilt nodes carry the position of [e], except [EVar],
    which keeps the position stored inside the variable node. Leaves ([ERaise],
    [ELit], [EOp]) are boxed unchanged. Used by the transformations below. *)
let visitor_map
    (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box)
    (ctx : 'a)
    (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
  let visit = t ctx in
  let visit_all children = Bindlib.box_list (List.map visit children) in
  let rebuild node = Pos.same_pos_as node e in
  match Pos.unmark e with
  | ERaise _ | ELit _ | EOp _ -> Bindlib.box e
  | EVar (v, pos) ->
      let+ v = Bindlib.box_var v in
      (v, pos)
  | ETuple (args, n) ->
      let+ args = visit_all args in
      rebuild (ETuple (args, n))
  | ETupleAccess (e1, i, n, ts) ->
      let+ e1 = visit e1 in
      rebuild (ETupleAccess (e1, i, n, ts))
  | EInj (e1, i, n, ts) ->
      let+ e1 = visit e1 in
      rebuild (EInj (e1, i, n, ts))
  | EMatch (arg, cases, n) ->
      let+ arg = visit arg and+ cases = visit_all cases in
      rebuild (EMatch (arg, cases, n))
  | EArray args ->
      let+ args = visit_all args in
      rebuild (EArray args)
  | EAbs ((binder, pos_binder), ts) ->
      (* Open the binder, transform the body under the fresh variables, then
         close it again over the same variables. *)
      let vars, body = Bindlib.unmbind binder in
      let+ binder = Bindlib.bind_mvar vars (visit body) in
      rebuild (EAbs ((binder, pos_binder), ts))
  | EApp (e1, args) ->
      let+ e1 = visit e1 and+ args = visit_all args in
      rebuild (EApp (e1, args))
  | EAssert e1 ->
      let+ e1 = visit e1 in
      rebuild (EAssert e1)
  | EIfThenElse (e1, e2, e3) ->
      let+ e1 = visit e1 and+ e2 = visit e2 and+ e3 = visit e3 in
      rebuild (EIfThenElse (e1, e2, e3))
  | ECatch (e1, exn, e2) ->
      let+ e1 = visit e1 and+ e2 = visit e2 in
      rebuild (ECatch (e1, exn, e2))
|
|
|
|
(** Iota-reduction on enum matches. Two rewrites are applied everywhere in [e]:
    - [match Inj_i e1 with cases] (same enum) becomes [cases_i e1], the i-th
      case applied to the payload;
    - [match e' with | C_0 x -> C_0 x | ... | C_k x -> C_k x], where every case
      rebuilds exactly the value it received, becomes [e'].
    Match cases are functions applied to the payload (see the [EApp] built in
    the first rewrite), so the identity check looks inside each case's binder. *)
let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
  let default_mark e' = Pos.mark (Pos.get_position e) e' in
  (* [is_identity_case n i case] holds when [case], the i-th case of a match on
     enum [n], is [fun x -> Inj_i x] with the injected payload being exactly the
     bound variable. Checking the payload is what makes dropping the match sound. *)
  let is_identity_case n i (case, _pos) =
    match case with
    | EAbs ((binder, _pos_binder), _ts) -> (
        let vars, body = Bindlib.unmbind binder in
        match (vars, Pos.unmark body) with
        | [| x |], EInj ((EVar (y, _), _), i', n', _ts') ->
            i = i' && Dcalc.Ast.EnumName.compare n n' = 0 && Bindlib.eq_vars x y
        | _ -> false)
    | _ -> false
  in
  match Pos.unmark e with
  | EMatch ((EInj (e1, i, n', _ts), _), cases, n) when Dcalc.Ast.EnumName.compare n n' = 0 ->
      (* Recurse with [iota_expr] itself (not [visitor_map]) so that a redex
         sitting directly at the payload or the selected case is reduced too. *)
      let+ e1 = iota_expr () e1 and+ case = iota_expr () (List.nth cases i) in
      default_mark @@ EApp (case, [ e1 ])
  | EMatch (e', cases, n) when List.for_all Fun.id (List.mapi (is_identity_case n) cases) ->
      iota_expr () e'
  | _ -> visitor_map iota_expr () e
|
|
|
|
|
|
(** Beta-reduction: an application of a literal abstraction to exactly as many
    arguments as it binds is replaced by the abstraction's body, with the
    arguments substituted for the bound variables. Applications whose arity does
    not match the binder (e.g. partial application) are left as applications,
    because [Bindlib.msubst] rejects an argument array of the wrong length.

    NOTE(review): substitution is call-by-name. An argument that raises, or that
    the body uses several times, is dropped or duplicated, which may change
    behaviour under a call-by-value semantics — confirm before enabling this pass
    (it is currently only reachable through [_beta_optimizations], which
    [optimize_program] does not call). *)
let rec beta_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
  let default_mark e' = Pos.same_pos_as e' e in
  match Pos.unmark e with
  | EApp (e1, args) ->
      (* Reduce the operator and arguments themselves (not just their children),
         so that e.g. a nested redex in operator position is exposed. *)
      let+ e1 = beta_expr () e1
      and+ args = args |> List.map (beta_expr ()) |> Bindlib.box_list in
      begin
        match Pos.unmark e1 with
        | EAbs ((binder, _pos_binder), _ts) when Bindlib.mbinder_arity binder = List.length args ->
            Bindlib.msubst binder (args |> List.map Pos.unmark |> Array.of_list)
        | _ -> default_mark @@ EApp (e1, args)
      end
  | _ -> visitor_map beta_expr () e
|
|
|
|
(** Applies [iota_expr] to the expression of every entry of [p.scopes], leaving
    the scope variables and the rest of [p] untouched. *)
let iota_optimizations (p : program) : program =
  let optimize_scope (var, e) = (var, Bindlib.unbox (iota_expr () e)) in
  { p with scopes = List.map optimize_scope p.scopes }
|
|
|
|
(** Applies [beta_expr] to the expression of every entry of [p.scopes]. Not
    currently part of [optimize_program]'s pipeline (hence the leading
    underscore). *)
let _beta_optimizations (p : program) : program =
  let optimize_scope (var, e) = (var, Bindlib.unbox (beta_expr () e)) in
  { p with scopes = List.map optimize_scope p.scopes }
|
|
|
|
(** Local simplification of conditionals: [if true then e2 else e3] becomes
    [e2] and [if false then e2 else e3] becomes [e3], everywhere in [e]. A guard
    that is a literal boolean wrapped in a single [Log] unary operator
    application is treated the same way (the log call is dropped with the
    conditional). *)
let rec peephole_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
  let default_mark e' = Pos.mark (Pos.get_position e) e' in
  match Pos.unmark e with
  | EIfThenElse (e1, e2, e3) -> (
      (* Simplify the operands themselves, not only their children: otherwise a
         conditional nested directly as guard or branch of another conditional
         would never be folded. *)
      let+ e1 = peephole_expr () e1
      and+ e2 = peephole_expr () e2
      and+ e3 = peephole_expr () e3 in
      match Pos.unmark e1 with
      | ELit (LBool true) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]) -> e2
      | ELit (LBool false) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]) -> e3
      | _ -> default_mark @@ EIfThenElse (e1, e2, e3))
  | _ -> visitor_map peephole_expr () e
|
|
|
|
(** Applies [peephole_expr] to the expression of every entry of [p.scopes],
    leaving the scope variables and the rest of [p] untouched. *)
let peephole_optimizations (p : program) : program =
  let optimize_scope (var, e) = (var, Bindlib.unbox (peephole_expr () e)) in
  { p with scopes = List.map optimize_scope p.scopes }
|
|
|
|
(** Optimization pipeline: iota-reduction first, then peephole simplification
    of conditionals. Beta-reduction ([_beta_optimizations]) is not run. *)
let optimize_program (p : program) : program =
  let after_iota = iota_optimizations p in
  peephole_optimizations after_iota
|