Compile HandleDefaultOpt specially

This commit is contained in:
Denis Merigoux 2023-12-11 14:34:31 +01:00
parent f072694e50
commit 5b7470fd0d
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
4 changed files with 37 additions and 10 deletions

View File

@ -288,9 +288,11 @@ module Passes = struct
~keep_special_ops :
Scalc.Ast.program * Scopelang.Dependency.TVertex.t list =
let prg, type_ordering =
lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed
lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.untyped
~avoid_exceptions ~closure_conversion
in
Message.emit_debug "Retyping lambda calculus...";
let prg = Typing.program ~leave_unresolved:true prg in
debug_pass_name "scalc";
Scalc.From_lcalc.translate_program ~keep_special_ops prg, type_ordering
end

View File

@ -29,6 +29,13 @@ type 'm ctxt = {
keep_special_ops : bool;
}
let unthunk e =
match Mark.remove e with
| EAbs { binder; tys = [(TLit TUnit, _)] } ->
let _, e = Bindlib.unmbind binder in
e
| _ -> failwith "should not happen"
(* Expressions can spill out side effect, hence this function also returns a
list of statements to be prepended before the expression is evaluated *)
let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
@ -132,7 +139,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
}
in
let tmp_stmts = translate_statements ctxt expr in
( ( A.SLocalDecl ((tmp_var, Expr.pos expr), (TAny, Expr.pos expr)),
( ( A.SLocalDecl ((tmp_var, Expr.pos expr), Expr.maybe_ty (Mark.get expr)),
Expr.pos expr )
:: tmp_stmts,
(A.EVar tmp_var, Expr.pos expr) )
@ -154,6 +161,8 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
| EArray exceptions -> exceptions
| _ -> failwith "should not happen"
in
let just = unthunk just in
let cons = unthunk cons in
List.iter
(fun ex ->
Message.emit_debug "exception: %a" (Print.expr ~debug:true ()) ex)
@ -167,7 +176,14 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block =
arg_stmts @ exceptions_stmts, new_arg :: new_exceptions)
([], []) exceptions
in
assert false
let just_stmts, new_just = translate_expr ctxt just in
let new_cons = translate_statements ctxt cons in
exceptions_stmts
@ just_stmts
@ [
( A.SSpecialOp (OHandleDefaultOpt (new_exceptions, new_just, new_cons)),
Expr.pos block_expr );
]
| EApp { f = EAbs { binder; tys }, binder_mark; args } ->
(* This defines multiple local variables at the time *)
let binder_pos = Expr.mark_pos binder_mark in

View File

@ -21,4 +21,4 @@ open Shared_ast
useful if the target language after Scalc does not support nested functions
like C. *)
val translate_program :
keep_special_ops:bool -> untyped Lcalc.Ast.program -> Ast.program
keep_special_ops:bool -> typed Lcalc.Ast.program -> Ast.program

View File

@ -152,20 +152,29 @@ let rec format_statement
(naked_expr, Mark.get stmt)
| SSwitch (e_switch, enum, arms) ->
let cons = EnumName.Map.find enum decl_ctx.ctx_enums in
Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@]%a" Print.keyword "switch"
Format.fprintf fmt "@[<v 0>%a @[<hov 2>%a@]%a@,@]%a" Print.keyword "switch"
(format_expr decl_ctx ~debug)
e_switch Print.punctuation ":"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt ((case, _), (arm_block, payload_name)) ->
Format.fprintf fmt "%a %a%a@ %a @[<v 2>%a@ %a@]" Print.punctuation
"|" EnumConstructor.format case Print.punctuation ":"
format_var_name payload_name Print.punctuation ""
Format.fprintf fmt "@[<v 2>%a %a %a %a@ %a@]" Print.punctuation "|"
EnumConstructor.format case format_var_name payload_name
Print.punctuation ""
(format_block decl_ctx ~debug)
arm_block))
(List.combine (EnumConstructor.Map.bindings cons) arms)
| SSpecialOp (OHandleDefaultOpt (_exceptions, _just, _cons)) ->
Format.fprintf fmt "handle_default_opt ..."
| SSpecialOp (OHandleDefaultOpt (exceptions, just, cons)) ->
Format.fprintf fmt "@[<hov 2>%a %a%a%a@]@\n@[<hov 2>%a@ %a %a%a@\n%a@]"
Print.keyword "handle exceptions" Print.punctuation "["
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ")
(fun fmt e -> Format.fprintf fmt "%a" (format_expr decl_ctx ~debug) e))
exceptions Print.punctuation "]" Print.keyword "or if"
(format_expr decl_ctx ~debug)
just Print.keyword "then" Print.punctuation ":"
(format_block decl_ctx ~debug)
cons
and format_block
(decl_ctx : decl_ctx)