diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index b2ee26e0..54e1392c 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -642,14 +642,14 @@ let translate_rule ( (fun next -> Bindlib.box_apply2 (fun next merged_expr -> - ScopeLet - { - scope_let_next = next; - scope_let_typ = tau; - scope_let_expr = merged_expr; - scope_let_kind = ScopeVarDefinition; - scope_let_pos = Mark.get a; - }) + Cons + ( { + scope_let_typ = tau; + scope_let_expr = merged_expr; + scope_let_kind = ScopeVarDefinition; + scope_let_pos = Mark.get a; + }, + next )) (Bindlib.bind_var a_var next) (Expr.Box.lift merged_expr)), { @@ -691,14 +691,14 @@ let translate_rule ( (fun next -> Bindlib.box_apply2 (fun next thunked_or_nonempty_new_e -> - ScopeLet - { - scope_let_next = next; - scope_let_pos = Mark.get a_name; - scope_let_typ = input_var_typ (Mark.remove tau) a_io; - scope_let_expr = thunked_or_nonempty_new_e; - scope_let_kind = SubScopeVarDefinition; - }) + Cons + ( { + scope_let_pos = Mark.get a_name; + scope_let_typ = input_var_typ (Mark.remove tau) a_io; + scope_let_expr = thunked_or_nonempty_new_e; + scope_let_kind = SubScopeVarDefinition; + }, + next )) (Bindlib.bind_var a_var next) (Expr.Box.lift thunked_or_nonempty_new_e)), { @@ -836,14 +836,14 @@ let translate_rule let call_scope_let next = Bindlib.box_apply2 (fun next call_expr -> - ScopeLet - { - scope_let_next = next; - scope_let_pos = pos_sigma; - scope_let_kind = CallingSubScope; - scope_let_typ = result_tuple_typ; - scope_let_expr = call_expr; - }) + Cons + ( { + scope_let_pos = pos_sigma; + scope_let_kind = CallingSubScope; + scope_let_typ = result_tuple_typ; + scope_let_expr = call_expr; + }, + next )) (Bindlib.bind_var result_tuple_var next) (Expr.Box.lift call_expr) in @@ -856,17 +856,17 @@ let translate_rule in Bindlib.box_apply2 (fun next r -> - ScopeLet - { - scope_let_next = next; - scope_let_pos = pos_sigma; - scope_let_typ = var_ctx.scope_var_typ, pos_sigma; - scope_let_kind = DestructuringSubScopeResults; - scope_let_expr = - ( EStructAccess - { name = called_scope_return_struct; e = r; field }, - mark_tany m pos_sigma ); - }) + Cons + ( { + scope_let_pos = pos_sigma; + scope_let_typ = var_ctx.scope_var_typ, pos_sigma; + scope_let_kind = DestructuringSubScopeResults; + scope_let_expr = + ( EStructAccess + { name = called_scope_return_struct; e = r; field }, + mark_tany m pos_sigma ); + }, + next )) (Bindlib.bind_var v next) (Expr.Box.lift (Expr.make_var result_tuple_var (mark_tany m pos_sigma)))) @@ -892,17 +892,17 @@ let translate_rule ( (fun next -> Bindlib.box_apply2 (fun next new_e -> - ScopeLet - { - scope_let_next = next; - scope_let_pos; - scope_let_typ; - scope_let_expr = - Mark.add - (Expr.map_ty (fun _ -> scope_let_typ) (Mark.get e)) - (EAssert new_e); - scope_let_kind = Assertion; - }) + Cons + ( { + scope_let_pos; + scope_let_typ; + scope_let_expr = + Mark.add + (Expr.map_ty (fun _ -> scope_let_typ) (Mark.get e)) + (EAssert new_e); + scope_let_kind = Assertion; + }, + next )) (Bindlib.bind_var (Var.make "_") next) (Expr.Box.lift new_e)), ctx ) @@ -944,7 +944,7 @@ let translate_rules in ( scope_lets (Bindlib.box_apply - (fun return_exp -> Result return_exp) + (fun return_exp -> Last return_exp) (Expr.Box.lift return_exp)), new_ctx ) @@ -1042,18 +1042,18 @@ let translate_scope_decl in Bindlib.box_apply2 (fun next r -> - ScopeLet - { - scope_let_kind = DestructuringInputStruct; - scope_let_next = next; - scope_let_pos = pos_sigma; - scope_let_typ = - input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; - scope_let_expr = - ( EStructAccess - { name = scope_input_struct_name; e = r; field }, - mark_tany scope_mark pos_sigma ); - }) + Cons + ( { + scope_let_kind = DestructuringInputStruct; + scope_let_pos = pos_sigma; + scope_let_typ = + input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; + scope_let_expr = + ( EStructAccess + { name = scope_input_struct_name; e = r; field }, + mark_tany scope_mark pos_sigma ); + }, + next )) (Bindlib.bind_var v next) (Expr.Box.lift (Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma)))) @@ -1182,7 +1182,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = ending with the top-level scope. The decl_ctx is filled in left-to-right order, then the chained scopes aggregated from the right. *) let rec translate_defs = function - | [] -> Bindlib.box Nil + | [] -> Bindlib.box (Last ()) | def :: next -> let dvar, def = match def with diff --git a/compiler/dcalc/invariants.ml b/compiler/dcalc/invariants.ml index 0ff015a0..38f6ad18 100644 --- a/compiler/dcalc/invariants.ml +++ b/compiler/dcalc/invariants.ml @@ -22,38 +22,24 @@ type invariant_status = Fail | Pass | Ignore type invariant_expr = decl_ctx -> typed expr -> invariant_status let check_invariant (inv : string * invariant_expr) (p : typed program) : bool = - (* TODO: add a Program.fold_left_map_exprs to get rid of the mutable - reference *) - let result = ref true in let name, inv = inv in - let total = ref 0 in - let ok = ref 0 in - let p' = - Program.map_exprs p ~varf:Fun.id ~f:(fun e -> + let result, total, ok = + Program.fold_exprs p ~init:(true, 0, 0) ~f:(fun acc e _ty -> (* let currente = e in *) - let rec f e = - let r = - match inv p.decl_ctx e with - | Ignore -> true - | Fail -> - Message.raise_spanned_error (Expr.pos e) - "@[Invariant @{%s@} failed.@,%a@]" name - (Print.expr ()) e - | Pass -> - incr ok; - incr total; - true - in - Expr.map_gather e ~acc:r ~join:( && ) ~f + let rec f e (result, total, ok) = + let result, total, ok = Expr.shallow_fold f e (result, total, ok) in + match inv p.decl_ctx e with + | Ignore -> result, total, ok + | Fail -> + Message.raise_spanned_error (Expr.pos e) + "@[Invariant @{%s@} failed.@,%a@]" name + (Print.expr ()) e + | Pass -> result, total + 1, ok + 1 in - - let res, e' = f e in - result := res && !result; - e') + f e acc) in - assert (Bindlib.free_vars p' = Bindlib.empty_ctxt); - Message.emit_debug "Invariant %s checked.@ result: [%d/%d]" name !ok !total; - !result + Message.emit_debug "Invariant %s checked.@ result: [%d/%d]" name ok total; + result (* Structural invariant: no default can have as type A -> B *) let invariant_default_no_arrow () : string * invariant_expr = diff --git a/compiler/driver.ml b/compiler/driver.ml index 520c2f4e..39254aef 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -586,14 +586,12 @@ module Commands = struct let scope_uid = get_scope_uid prg.decl_ctx scope in Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt ( scope_uid, - Option.get - (Scope.fold_left ~init:None - ~f:(fun acc def _ -> - match def with - | ScopeDef (name, body) when ScopeName.equal name scope_uid -> - Some body - | _ -> acc) - prg.code_items) ); + BoundList.find + ~f:(function + | ScopeDef (name, body) when ScopeName.equal name scope_uid -> + Some body + | _ -> None) + prg.code_items ); Format.pp_print_newline fmt () | None -> let scope_uid = get_random_scope_uid prg.decl_ctx in diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 2990b6f5..a460cbea 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -286,10 +286,9 @@ let rec transform_closures_expr : new_e1 call_expr (Expr.pos e) ) | _ -> . -(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the - type *) +(* Can't reuse Scope.map because we inspect the bind variables *) let transform_closures_scope_let ctx scope_body_expr = - Scope.fold_right_lets + BoundList.fold_right ~f:(fun scope_let var_next acc -> let _free_vars, new_scope_let_expr = (transform_closures_expr @@ -298,13 +297,13 @@ let transform_closures_scope_let ctx scope_body_expr = in Bindlib.box_apply2 (fun scope_let_next scope_let_expr -> - ScopeLet - { - scope_let with - scope_let_next; - scope_let_expr; - scope_let_typ = Mark.copy scope_let.scope_let_typ TAny; - }) + Cons + ( { + scope_let with + scope_let_expr; + scope_let_typ = Mark.copy scope_let.scope_let_typ TAny; + }, + scope_let_next )) (Bindlib.bind_var var_next acc) (Expr.Box.lift new_scope_let_expr)) ~init:(fun res -> @@ -312,14 +311,12 @@ let transform_closures_scope_let ctx scope_body_expr = (* INVARIANT here: the result expr of a scope is simply a struct containing all output variables so nothing should be converted here, so no need to take into account free variables. *) - Bindlib.box_apply - (fun res -> Result res) - (Expr.Box.lift new_scope_let_expr)) + Bindlib.box_apply (fun e -> Last e) (Expr.Box.lift new_scope_let_expr)) scope_body_expr let transform_closures_program (p : 'm program) : 'm program Bindlib.box = - let _, new_code_items = - Scope.fold_map + let (), new_code_items = + BoundList.fold_map ~f:(fun toplevel_vars var code_item -> match code_item with | ScopeDef (name, body) -> @@ -346,6 +343,7 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = pos ) in ( Var.Map.add var ty toplevel_vars, + var, Bindlib.box_apply (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) @@ -361,6 +359,7 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = let _free_vars, new_expr = transform_closures_expr ctx expr in let new_binder = Expr.bind v new_expr in ( Var.Map.add var ty toplevel_vars, + var, Bindlib.box_apply (fun e -> Topdef (name, ty, e)) (Expr.Box.lift (Expr.eabs new_binder tys m)) ) @@ -373,11 +372,12 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = in let _free_vars, new_expr = transform_closures_expr ctx expr in ( Var.Map.add var ty toplevel_vars, + var, Bindlib.box_apply (fun e -> Topdef (name, (TAny, Mark.get ty), e)) (Expr.Box.lift new_expr) )) - ~varf:(fun v -> v) - Var.Map.empty p.code_items + ~last:(fun _ () -> (), Bindlib.box ()) + ~init:Var.Map.empty p.code_items in (* Now we need to further tweak [decl_ctx] because some of the user-defined types can have closures in them and these closured might have changed type. @@ -550,7 +550,7 @@ let rec hoist_closures_expr : (* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the type *) let hoist_closures_scope_let name_context scope_body_expr = - Scope.fold_right_lets + BoundList.fold_right ~f:(fun scope_let var_next (hoisted_closures, next_scope_lets) -> let new_hoisted_closures, new_scope_let_expr = (hoist_closures_expr (Bindlib.name_of var_next)) @@ -559,7 +559,7 @@ let hoist_closures_scope_let name_context scope_body_expr = ( new_hoisted_closures @ hoisted_closures, Bindlib.box_apply2 (fun scope_let_next scope_let_expr -> - ScopeLet { scope_let with scope_let_next; scope_let_expr }) + Cons ({ scope_let with scope_let_expr }, scope_let_next)) (Bindlib.bind_var var_next next_scope_lets) (Expr.Box.lift new_scope_let_expr) )) ~init:(fun res -> @@ -571,7 +571,7 @@ let hoist_closures_scope_let name_context scope_body_expr = no need to take into account free variables. *) ( hoisted_closures, Bindlib.box_apply - (fun res -> Result res) + (fun res -> Last res) (Expr.Box.lift new_scope_let_expr) )) scope_body_expr @@ -579,7 +579,7 @@ let rec hoist_closures_code_item_list (code_items : (lcalc, 'm) gexpr code_item_list) : (lcalc, 'm) gexpr code_item_list Bindlib.box = match code_items with - | Nil -> Bindlib.box Nil + | Last () -> Bindlib.box (Last ()) | Cons (code_item, next_code_items) -> let code_item_var, next_code_items = Bindlib.unbind next_code_items in let hoisted_closures, new_code_item = diff --git a/compiler/lcalc/compile_with_exceptions.ml b/compiler/lcalc/compile_with_exceptions.ml index 8e1b0607..b8f76363 100644 --- a/compiler/lcalc/compile_with_exceptions.ml +++ b/compiler/lcalc/compile_with_exceptions.ml @@ -92,59 +92,5 @@ and translate_expr (e : 'm D.expr) : 'm A.expr boxed = Expr.map ~f:translate_expr ~typ:translate_typ e | _ -> . -let translate_scope_body_expr (scope_body_expr : 'expr1 scope_body_expr) : - 'expr2 scope_body_expr Bindlib.box = - Scope.fold_right_lets - ~f:(fun scope_let var_next acc -> - Bindlib.box_apply2 - (fun scope_let_next scope_let_expr -> - ScopeLet - { - scope_let with - scope_let_next; - scope_let_expr; - scope_let_typ = translate_typ scope_let.scope_let_typ; - }) - (Bindlib.bind_var (Var.translate var_next) acc) - (Expr.Box.lift (translate_expr scope_let.scope_let_expr))) - ~init:(fun res -> - Bindlib.box_apply - (fun res -> Result res) - (Expr.Box.lift (translate_expr res))) - scope_body_expr - -let translate_code_items scopes = - let f = function - | ScopeDef (name, body) -> - let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in - let new_body_expr = translate_scope_body_expr scope_lets in - let new_body_expr = - Bindlib.bind_var (Var.translate scope_input_var) new_body_expr - in - Bindlib.box_apply - (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) - new_body_expr - | Topdef (name, typ, expr) -> - Bindlib.box_apply - (fun e -> Topdef (name, typ, e)) - (Expr.Box.lift (translate_expr expr)) - in - Scope.map ~f ~varf:Var.translate scopes - let translate_program (prg : 'm D.program) : 'm A.program = - let code_items = Bindlib.unbox (translate_code_items prg.code_items) in - let ctx_enums = - EnumName.Map.map - (EnumConstructor.Map.map translate_typ) - prg.decl_ctx.ctx_enums - in - let ctx_structs = - StructName.Map.map - (StructField.Map.map translate_typ) - prg.decl_ctx.ctx_structs - in - { - prg with - code_items; - decl_ctx = { prg.decl_ctx with ctx_enums; ctx_structs }; - } + Program.map_exprs prg ~typ:translate_typ ~varf:Var.translate ~f:translate_expr diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index a8d5be69..13a962fe 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -123,60 +123,5 @@ and translate_expr (e : 'm D.expr) : 'm A.expr boxed = Expr.map ~f:translate_expr ~typ:translate_typ e | _ -> . -let translate_scope_body_expr - (scope_body_expr : (dcalc, 'm) gexpr scope_body_expr) : - (lcalc, 'm) gexpr scope_body_expr Bindlib.box = - Scope.fold_right_lets - ~f:(fun scope_let var_next acc -> - Bindlib.box_apply2 - (fun scope_let_next scope_let_expr -> - ScopeLet - { - scope_let with - scope_let_next; - scope_let_expr; - scope_let_typ = translate_typ scope_let.scope_let_typ; - }) - (Bindlib.bind_var (Var.translate var_next) acc) - (Expr.Box.lift (translate_expr scope_let.scope_let_expr))) - ~init:(fun res -> - Bindlib.box_apply - (fun res -> Result res) - (Expr.Box.lift (translate_expr res))) - scope_body_expr - -let translate_code_items scopes = - let f = function - | ScopeDef (name, body) -> - let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in - let new_body_expr = translate_scope_body_expr scope_lets in - let new_body_expr = - Bindlib.bind_var (Var.translate scope_input_var) new_body_expr - in - Bindlib.box_apply - (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) - new_body_expr - | Topdef (name, typ, expr) -> - Bindlib.box_apply - (fun e -> Topdef (name, typ, e)) - (Expr.Box.lift (translate_expr expr)) - in - Scope.map ~f ~varf:Var.translate scopes - let translate_program (prg : 'm D.program) : 'm A.program = - let code_items = Bindlib.unbox (translate_code_items prg.code_items) in - let ctx_enums = - EnumName.Map.map - (EnumConstructor.Map.map translate_typ) - prg.decl_ctx.ctx_enums - in - let ctx_structs = - StructName.Map.map - (StructField.Map.map translate_typ) - prg.decl_ctx.ctx_structs - in - { - prg with - code_items; - decl_ctx = { prg.decl_ctx with ctx_enums; ctx_structs }; - } + Program.map_exprs prg ~typ:translate_typ ~varf:Var.translate ~f:translate_expr diff --git a/compiler/lcalc/monomorphize.ml b/compiler/lcalc/monomorphize.ml index 13c94cd3..bf2115e8 100644 --- a/compiler/lcalc/monomorphize.ml +++ b/compiler/lcalc/monomorphize.ml @@ -46,6 +46,9 @@ type monomorphized_instances = { arrays : array_instance Type.Map.t; } +let empty_instances = + { options = Type.Map.empty; tuples = Type.Map.empty; arrays = Type.Map.empty } + let collect_monomorphized_instances (prg : typed program) : monomorphized_instances = let option_instances_counter = ref 0 in @@ -157,23 +160,8 @@ let collect_monomorphized_instances (prg : typed program) : Expr.shallow_fold collect_expr e (collect_typ acc (Expr.ty e)) in let acc = - Scope.fold_left - ~init: - { - options = Type.Map.empty; - tuples = Type.Map.empty; - arrays = Type.Map.empty; - } - ~f:(fun acc item _ -> - match item with - | Topdef (_, typ, e) -> collect_typ (collect_expr e acc) typ - | ScopeDef (_, body) -> - let _, body = Bindlib.unbind body.scope_body_expr in - Scope.fold_left_lets ~init:acc - ~f:(fun acc { scope_let_typ; scope_let_expr; _ } _ -> - collect_typ (collect_expr scope_let_expr acc) scope_let_typ) - body) - prg.code_items + Scope.fold_exprs prg.code_items ~init:empty_instances ~f:(fun acc e typ -> + collect_typ (collect_expr e acc) typ) in EnumName.Map.fold (fun _ constructors acc -> @@ -301,104 +289,67 @@ let rec monomorphize_expr let program (prg : typed program) : typed program * Scopelang.Dependency.TVertex.t list = let monomorphized_instances = collect_monomorphized_instances prg in + let decl_ctx = prg.decl_ctx in (* First we remove the polymorphic option type *) - let prg = - { - prg with - decl_ctx = - { - prg.decl_ctx with - ctx_enums = - EnumName.Map.remove Expr.option_enum prg.decl_ctx.ctx_enums; - }; - } - in + let ctx_enums = EnumName.Map.remove Expr.option_enum decl_ctx.ctx_enums in + let ctx_structs = decl_ctx.ctx_structs in (* Then we replace all hardcoded types and expressions with the monomorphized instances *) - let prg = - { - prg with - decl_ctx = - { - prg.decl_ctx with - ctx_enums = - EnumName.Map.map - (EnumConstructor.Map.map - (monomorphize_typ monomorphized_instances)) - prg.decl_ctx.ctx_enums; - ctx_structs = - StructName.Map.map - (StructField.Map.map (monomorphize_typ monomorphized_instances)) - prg.decl_ctx.ctx_structs; - }; - } + let ctx_enums = + EnumName.Map.map + (EnumConstructor.Map.map (monomorphize_typ monomorphized_instances)) + ctx_enums + in + let ctx_structs = + StructName.Map.map + (StructField.Map.map (monomorphize_typ monomorphized_instances)) + ctx_structs in (* Then we augment the [decl_ctx] with the monomorphized instances *) - let prg = - { - prg with - decl_ctx = - { - prg.decl_ctx with - ctx_enums = - Type.Map.fold - (fun _ (option_instance : option_instance) (ctx_enums : enum_ctx) -> - EnumName.Map.add option_instance.name - (EnumConstructor.Map.add option_instance.none_cons - (TLit TUnit, Pos.no_pos) - (EnumConstructor.Map.singleton option_instance.some_cons - (monomorphize_typ monomorphized_instances - (option_instance.some_typ, Pos.no_pos)))) - ctx_enums) - monomorphized_instances.options prg.decl_ctx.ctx_enums; - ctx_structs = - Type.Map.fold - (fun _ (tuple_instance : tuple_instance) - (ctx_structs : struct_ctx) -> - StructName.Map.add tuple_instance.name - (List.fold_left - (fun acc (field, typ) -> - StructField.Map.add field - (monomorphize_typ monomorphized_instances - (typ, Pos.no_pos)) - acc) - StructField.Map.empty tuple_instance.fields) - ctx_structs) - monomorphized_instances.tuples - (Type.Map.fold - (fun _ (array_instance : array_instance) - (ctx_structs : struct_ctx) -> - StructName.Map.add array_instance.name - (StructField.Map.add array_instance.content_field - ( TArray - (monomorphize_typ monomorphized_instances - (array_instance.content_typ, Pos.no_pos)), - Pos.no_pos ) - (StructField.Map.singleton array_instance.len_field - (TLit TInt, Pos.no_pos))) - ctx_structs) - monomorphized_instances.arrays prg.decl_ctx.ctx_structs); - }; - } + let ctx_enums = + Type.Map.fold + (fun _ (option_instance : option_instance) (ctx_enums : enum_ctx) -> + EnumName.Map.add option_instance.name + (EnumConstructor.Map.add option_instance.none_cons + (TLit TUnit, Pos.no_pos) + (EnumConstructor.Map.singleton option_instance.some_cons + (monomorphize_typ monomorphized_instances + (option_instance.some_typ, Pos.no_pos)))) + ctx_enums) + monomorphized_instances.options ctx_enums in + let ctx_structs = + Type.Map.fold + (fun _ (tuple_instance : tuple_instance) (ctx_structs : struct_ctx) -> + StructName.Map.add tuple_instance.name + (List.fold_left + (fun acc (field, typ) -> + StructField.Map.add field + (monomorphize_typ monomorphized_instances (typ, Pos.no_pos)) + acc) + StructField.Map.empty tuple_instance.fields) + ctx_structs) + monomorphized_instances.tuples + (Type.Map.fold + (fun _ (array_instance : array_instance) (ctx_structs : struct_ctx) -> + StructName.Map.add array_instance.name + (StructField.Map.add array_instance.content_field + ( TArray + (monomorphize_typ monomorphized_instances + (array_instance.content_typ, Pos.no_pos)), + Pos.no_pos ) + (StructField.Map.singleton array_instance.len_field + (TLit TInt, Pos.no_pos))) + ctx_structs) + monomorphized_instances.arrays ctx_structs) + in + let decl_ctx = { decl_ctx with ctx_structs; ctx_enums } in let code_items = Bindlib.unbox - @@ Scope.map - ~f:(fun code_item -> - match code_item with - | Topdef (name, typ, e) -> Bindlib.box (Topdef (name, typ, e)) - | ScopeDef (name, body) -> - let s_var, scope_body = Bindlib.unbind body.scope_body_expr in - Bindlib.box_apply - (fun scope_body_expr -> - ScopeDef (name, { body with scope_body_expr })) - (Bindlib.bind_var s_var - (Scope.map_exprs_in_lets ~varf:Fun.id - ~transform_types:(monomorphize_typ monomorphized_instances) - ~f:(monomorphize_expr monomorphized_instances) - scope_body))) - ~varf:Fun.id prg.code_items + @@ Scope.map_exprs prg.code_items + ~typ:(monomorphize_typ monomorphized_instances) + ~varf:Fun.id + ~f:(monomorphize_expr monomorphized_instances) in - ( { prg with code_items }, - Scopelang.Dependency.check_type_cycles prg.decl_ctx.ctx_structs - prg.decl_ctx.ctx_enums ) + ( { prg with decl_ctx; code_items }, + Scopelang.Dependency.check_type_cycles ctx_structs ctx_enums ) diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index e549bb1f..c1e94115 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -448,8 +448,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : | ERaise exc -> Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e) | ECatch { body; exn; handler } -> - Format.fprintf fmt - "@,@[@[try@ %a@]@ with@]@ @[%a@ ->@ %a@]" + Format.fprintf fmt "@[@[try@ %a@]@ with@]@ @[%a@ ->@ %a@]" format_with_parens body format_exception (exn, Expr.pos e) format_with_parens handler @@ -569,48 +568,57 @@ let rename_vars e = (rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true ~skip_constant_binders:true ~constant_binder_name:(Some "_") e)) -let format_expr ctx fmt e = format_expr ctx fmt (rename_vars e) +let format_expr ctx fmt e = + Format.pp_open_vbox fmt 0; + format_expr ctx fmt (rename_vars e); + Format.pp_close_box fmt () -let rec format_scope_body_expr +let format_scope_body_expr (ctx : decl_ctx) (fmt : Format.formatter) (scope_lets : 'm Ast.expr scope_body_expr) : unit = - match scope_lets with - | Result e -> format_expr ctx fmt e - | ScopeLet scope_let -> - let scope_let_var, scope_let_next = - Bindlib.unbind scope_let.scope_let_next - in - Format.fprintf fmt "@[let %a: %a = %a in@]@\n%a" format_var - scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) - scope_let.scope_let_expr - (format_scope_body_expr ctx) - scope_let_next + Format.pp_open_vbox fmt 0; + let last_e = + BoundList.iter + ~f:(fun scope_let_var scope_let -> + Format.fprintf fmt "@[@[let %a: %a =@ %a@ @]in@]@," + format_var scope_let_var format_typ scope_let.scope_let_typ + (format_expr ctx) scope_let.scope_let_expr) + scope_lets + in + format_expr ctx fmt last_e; + Format.pp_close_box fmt () let format_code_items (ctx : decl_ctx) (fmt : Format.formatter) (code_items : 'm Ast.expr code_item_list) : ('m Ast.expr Var.t * 'm Ast.expr code_item) String.Map.t = - Scope.fold_left - ~f:(fun bnd item var -> - match item with - | Topdef (name, typ, e) -> - Format.fprintf fmt "@\n@\n@[let %a : %a =@\n%a@]" format_var var - format_typ typ (format_expr ctx) e; - String.Map.add (TopdefName.to_string name) (var, item) bnd - | ScopeDef (name, body) -> - let scope_input_var, scope_body_expr = - Bindlib.unbind body.scope_body_expr - in - Format.fprintf fmt "@\n@\n@[let %a (%a: %a.t) : %a.t =@\n%a@]" - format_var var format_var scope_input_var format_to_module_name - (`Sname body.scope_body_input_struct) format_to_module_name - (`Sname body.scope_body_output_struct) - (format_scope_body_expr ctx) - scope_body_expr; - String.Map.add (ScopeName.to_string name) (var, item) bnd) - ~init:String.Map.empty code_items + Format.pp_open_vbox fmt 0; + let var_bindings, () = + BoundList.fold_left + ~f:(fun bnd item var -> + match item with + | Topdef (name, typ, e) -> + Format.fprintf fmt "@,@[@[let %a : %a =@]@ %a@]@," + format_var var format_typ typ (format_expr ctx) e; + String.Map.add (TopdefName.to_string name) (var, item) bnd + | ScopeDef (name, body) -> + let scope_input_var, scope_body_expr = + Bindlib.unbind body.scope_body_expr + in + Format.fprintf fmt + "@,@[@[let %a (%a: %a.t) : %a.t =@]@ %a@]@," format_var + var format_var scope_input_var format_to_module_name + (`Sname body.scope_body_input_struct) format_to_module_name + (`Sname body.scope_body_output_struct) + (format_scope_body_expr ctx) + scope_body_expr; + String.Map.add (ScopeName.to_string name) (var, item) bnd) + ~init:String.Map.empty code_items + in + Format.pp_close_box fmt (); + var_bindings let format_scope_exec (ctx : decl_ctx) diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index 902989c4..522f847c 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -574,7 +574,7 @@ let rec translate_scope_body_expr (func_dict : ('m L.expr, A.FuncName.t) Var.Map.t) (scope_expr : 'm L.expr scope_body_expr) : A.block = match scope_expr with - | Result e -> + | Last e -> let block, new_e = translate_expr { @@ -587,8 +587,8 @@ let rec translate_scope_body_expr e in block @ [A.SReturn (Mark.remove new_e), Mark.get new_e] - | ScopeLet scope_let -> - let let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in + | Cons (scope_let, next_bnd) -> + let let_var, scope_let_next = Bindlib.unbind next_bnd in let let_var_id = A.VarName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos) in @@ -637,8 +637,8 @@ let rec translate_scope_body_expr let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = - let _, _, rev_items = - Scope.fold_left + let (_, _, rev_items), () = + BoundList.fold_left ~f:(fun (func_dict, var_dict, rev_items) code_item var -> match code_item with | ScopeDef (name, body) -> diff --git a/compiler/shared_ast/boundList.ml b/compiler/shared_ast/boundList.ml new file mode 100644 index 00000000..faeba5de --- /dev/null +++ b/compiler/shared_ast/boundList.ml @@ -0,0 +1,118 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020-2022 Inria, + contributor: Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Definitions + +type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list = + | Last of 'last + | Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder + +let rec last = function + | Last e -> e + | Cons (_, bnd) -> + let _, next = Bindlib.unbind bnd in + last next + +let rec iter ~f = function + | Last l -> l + | Cons (item, next_bind) -> + let var, next = Bindlib.unbind next_bind in + f var item; + iter ~f next + +let rec find ~f = function + | Last _ -> raise Not_found + | Cons (item, next_bind) -> ( + match f item with + | Some r -> r + | None -> + let _, next = Bindlib.unbind next_bind in + find ~f next) + +let rec fold_left ~f ~init = function + | Last l -> init, l + | Cons (item, next_bind) -> + let var, next = Bindlib.unbind next_bind in + fold_left ~f ~init:(f init item var) next + +let rec fold_right ~f ~init = function + | Last l -> init l + | Cons (item, next_bind) -> + let var_next, next = Bindlib.unbind next_bind in + let result_next = fold_right ~f ~init next in + f item var_next result_next + +let rec fold_lr ~top ~down ~bottom ~up = function + | Last l -> bottom l top + | Cons (item, next_bind) -> + let var, next = Bindlib.unbind next_bind in + let top = down var item top in + let bottom = fold_lr ~down ~up ~top ~bottom next in + up var item bottom + +let rec map ~f ~last = function + | Last l -> Bindlib.box_apply (fun l -> Last l) (last l) + | Cons (item, next_bind) -> + let var, next = Bindlib.unbind next_bind in + let var, item = f var item in + let next_bind = Bindlib.bind_var var (map ~f ~last next) in + Bindlib.box_apply2 + (fun item next_bind -> Cons (item, next_bind)) + item next_bind + +let rec fold_map ~f ~last ~init:ctx = function + | Last l -> + let ret, l = last ctx l in + ret, Bindlib.box_apply (fun l -> Last l) l + | Cons (item, next_bind) -> + let var, next = Bindlib.unbind next_bind in + let ctx, var, item = f ctx var item in + let ctx, next = fold_map ~f ~last ~init:ctx next in + let next_bind = Bindlib.bind_var var next in + ( ctx, + Bindlib.box_apply2 + (fun item next_bind -> Cons (item, next_bind)) + item next_bind ) + +let rec fold_left2 ~f ~init a b = + match a, b with + | Last l1, Last l2 -> init, (l1, l2) + | Cons (item1, next_bind1), Cons (item2, next_bind2) -> + let var, next1, next2 = Bindlib.unbind2 next_bind1 next_bind2 in + fold_left2 ~f ~init:(f init item1 item2 var) next1 next2 + | _ -> invalid_arg "fold_left2" + +let rec equal ~f ~last a b = + match a, b with + | Last l1, Last l2 -> last l1 l2 + | Cons (item1, next_bind1), Cons (item2, next_bind2) -> + f item1 item2 + && + let _, next1, next2 = Bindlib.unbind2 next_bind1 next_bind2 in + equal ~f ~last next1 next2 + | _ -> false + +let rec compare ~f ~last a b = + match a, b with + | Last l1, Last l2 -> last l1 l2 + | Cons (item1, next_bind1), Cons (item2, next_bind2) -> ( + match f item1 item2 with + | 0 -> + let _, next1, next2 = Bindlib.unbind2 next_bind1 next_bind2 in + compare ~f ~last next1 next2 + | n -> n) + | Last _, Cons _ -> -1 + | Cons _, Last _ -> 1 diff --git a/compiler/shared_ast/boundList.mli b/compiler/shared_ast/boundList.mli new file mode 100644 index 00000000..33089802 --- /dev/null +++ b/compiler/shared_ast/boundList.mli @@ -0,0 +1,92 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2020-2022 Inria, + contributor: Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +(** Bound lists are non-empty linked lists where each element is a binder onto + the next. They are useful for ordered program definitions, like nested + let-ins. + + [let a = e1 in e2] is thus represented as [Cons (e1, {a. Last e2})]. + + The following provides a few utility functions for their traversal and + manipulation. In particular, [map] functions take care of unbinding, then + properly rebinding the variables. *) + +open Definitions + +type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list = + | Last of 'last + | Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder + +val last : (_, _, 'a) t -> 'a +val iter : f:('e Var.t -> 'elt -> unit) -> ('e, 'elt, 'last) t -> 'last +val find : f:('elt -> 'a option) -> (_, 'elt, _) t -> 'a + +val fold_left : + f:('acc -> 'elt -> 'e Var.t -> 'acc) -> + init:'acc -> + ('e, 'elt, 'last) t -> + 'acc * 'last + +val fold_left2 : + f:('acc -> 'elt1 -> 'elt2 -> 'e Var.t -> 'acc) -> + init:'acc -> + ('e, 'elt1, 'last1) t -> + ('e, 'elt2, 'last2) t -> + 'acc * ('last1 * 'last2) + +val fold_right : + f:('elt -> 'e Var.t -> 'acc -> 'acc) -> + init:('last -> 'acc) -> + ('e, 'elt, 'last) t -> + 'acc + +val fold_lr : + top:'dacc -> + down:('e Var.t -> 'elt -> 'dacc -> 'dacc) -> + bottom:('last -> 'dacc -> 'uacc) -> + up:('e Var.t -> 'elt -> 'uacc -> 'uacc) -> + ('e, 'elt, 'last) t -> + 'uacc +(** Bi-directional fold: [down] accumulates downwards, starting from [top]; upon + reaching [last], [bottom] is called; then [up] accumulates on the way back + up *) + +val map : + f:('e1 Var.t -> 'elt1 -> 'e2 Var.t * 'elt2 Bindlib.box) -> + last:('last1 -> 'last2 Bindlib.box) -> + ('e1, 'elt1, 'last1) t -> + ('e2, 'elt2, 'last2) t Bindlib.box + +val fold_map : + f:('ctx -> 'e1 Var.t -> 'elt1 -> 'ctx * 'e2 Var.t * 'elt2 Bindlib.box) -> + last:('ctx -> 'last1 -> 'ret * 'last2 Bindlib.box) -> + init:'ctx -> + ('e1, 'elt1, 'last1) t -> + 'ret * ('e2, 'elt2, 'last2) t Bindlib.box + +val equal : + f:('elt -> 'elt -> bool) -> + last:('last -> 'last -> bool) -> + (('e, 'elt, 'last) t as 'l) -> + 'l -> + bool + +val compare : + f:('elt -> 'elt -> int) -> + last:('last -> 'last -> int) -> + (('e, 'elt, 'last) t as 'l) -> + 'l -> + int diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index e007471b..5e2d8184 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -615,6 +615,12 @@ type ('e, 'b) mbinder = (('a, 'm) naked_gexpr, 'b) Bindlib.mbinder Note that this structure is at the moment only relevant for [dcalc] and [lcalc], as [scopelang] has its own scope structure, as the name implies. *) +(** A linked list, but with a binder for each element into the next: + [x := let a = e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *) +type ('e, 'elt, 'last) bound_list = + | Last of 'last + | Cons of 'elt * ('e, ('e, 'elt, 'last) bound_list) binder + (** This kind annotation signals that the let-binding respects a structural invariant. These invariants concern the shape of the expression in the let-binding, and are documented below. *) @@ -632,21 +638,17 @@ type 'e scope_let = { scope_let_kind : scope_let_kind; scope_let_typ : typ; scope_let_expr : 'e; - scope_let_next : ('e, 'e scope_body_expr) binder; - (* todo ? Factorise the code_item _list type below and use it here *) scope_let_pos : Pos.t; } constraint 'e = ('a any, _) gexpr (** This type is parametrized by the expression type so it can be reused in later intermediate representations. *) +type 'e scope_body_expr = ('e, 'e scope_let, 'e) bound_list + constraint 'e = ('a any, _) gexpr (** A scope let-binding has all the information necessary to make a proper let-binding expression, plus an annotation for the kind of the let-binding that comes from the compilation of a {!module: Scopelang.Ast} statement. *) -and 'e scope_body_expr = - | Result of 'e - | ScopeLet of 'e scope_let - constraint 'e = ('a any, _) gexpr type 'e scope_body = { scope_body_input_struct : StructName.t; @@ -663,13 +665,7 @@ type 'e code_item = | ScopeDef of ScopeName.t * 'e scope_body | Topdef of TopdefName.t * typ * 'e -(** A chained list, but with a binder for each element into the next: - [x := let a - = e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *) -type 'e code_item_list = - | Nil - | Cons of 'e code_item * ('e, 'e code_item_list) binder - +type 'e code_item_list = ('e, 'e code_item, unit) bound_list type struct_ctx = typ StructField.Map.t StructName.Map.t type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 41d9f525..6ae3efcf 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -778,6 +778,8 @@ let rec free_vars : ('a, 't) gexpr -> ('a, 't) gexpr Var.Set.t = function let vs, body = Bindlib.unmbind binder in Array.fold_right Var.Set.remove vs (free_vars body) | e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty +(* Could also be done with [rebox] followed by [Bindlib.free_vars], if that + returned more than a context *) (* This function is first defined in [Print], only for dependency reasons *) let skip_wrappers : type a. (a, 'm) gexpr -> (a, 'm) gexpr = Print.skip_wrappers diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 4a3bbbda..1e596e40 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -944,7 +944,10 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list in let to_interpret = Expr.make_app (Expr.box e) - [Expr.estruct ~name:s_in ~fields:application_term mark_e] + [ + Expr.estruct ~name:s_in ~fields:application_term + (Expr.map_ty (fun (_, pos) -> TStruct s_in, pos) mark_e); + ] [TStruct s_in, Expr.pos e] (Expr.pos e) in @@ -996,7 +999,10 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list in let to_interpret = Expr.make_app (Expr.box e) - [Expr.estruct ~name:s_in ~fields:application_term mark_e] + [ + Expr.estruct ~name:s_in ~fields:application_term + (Expr.map_ty (fun (_, pos) -> TStruct s_in, pos) mark_e); + ] [TStruct s_in, Expr.pos e] (Expr.pos e) in diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index 36dd93cc..6a3c670c 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -380,8 +380,7 @@ let optimize_expr : optimize_expr { decl_ctx } e let optimize_program (p : 'm program) : 'm program = - Bindlib.unbox - (Program.map_exprs ~f:(optimize_expr p.decl_ctx) ~varf:(fun v -> v) p) + Program.map_exprs ~f:(optimize_expr p.decl_ctx) ~varf:(fun v -> v) p let test_iota_reduction_1 () = let x = Var.make "x" in diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index 75848429..dda7fc0e 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -562,17 +562,19 @@ module ExprGen (C : EXPR_PARAM) = struct Format.fprintf fmt "@[%a @[%a@]@ @]%a@ %a" punctuation "λ" (Format.pp_print_list ~pp_sep:Format.pp_print_space (fun fmt (x, tau) -> - match tau with - | TLit TUnit, _ -> punctuation fmt "("; punctuation fmt ")" - | _ -> - punctuation fmt "("; - Format.pp_open_hvbox fmt 2; - var fmt x; - punctuation fmt ":"; - Format.pp_print_space fmt (); - typ_gen None ~colors fmt tau; - Format.pp_close_box fmt (); - punctuation fmt ")")) + match tau with + | TLit TUnit, _ -> + punctuation fmt "("; + punctuation fmt ")" + | _ -> + punctuation fmt "("; + Format.pp_open_hvbox fmt 2; + var fmt x; + punctuation fmt ":"; + Format.pp_print_space fmt (); + typ_gen None ~colors fmt tau; + Format.pp_close_box fmt (); + punctuation fmt ")")) xs_tau punctuation "→" (rhs expr) body | EAppOp { op = (Map | Filter) as op; args = [arg1; arg2]; _ } -> Format.fprintf fmt "@[%a %a@ %a@]" operator op (lhs exprc) arg1 @@ -710,17 +712,16 @@ module ExprGen (C : EXPR_PARAM) = struct | EAbs { binder; tys; _ }, _ -> let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in let expr = exprb bnd_ctx in - let pp_args fmt = match tys with - | [TLit TUnit, _] -> () + let pp_args fmt = + match tys with + | [(TLit TUnit, _)] -> () | _ -> Format.pp_print_seq ~pp_sep:Format.pp_print_space var fmt (Array.to_seq xs); Format.pp_print_space fmt () in - Format.fprintf fmt "@[%a %t@ %t%a@ %a@]" punctuation - "|" pp_cons_name - pp_args - punctuation "→" (rhs expr) body + Format.fprintf fmt "@[%a %t@ %t%a@ %a@]" punctuation "|" + pp_cons_name pp_args punctuation "→" (rhs expr) body | e -> Format.fprintf fmt "@[%a %t@ %a@ %a@]" punctuation "|" pp_cons_name punctuation "→" (rhs exprc) e)) @@ -782,30 +783,22 @@ let scope_let_kind ?debug:(_debug = true) _ctx fmt k = | DestructuringSubScopeResults -> keyword fmt "sub_get" | Assertion -> keyword fmt "assert" -let[@ocamlformat "disable"] rec +let[@ocamlformat "disable"] scope_body_expr ?(debug = false) ctx fmt b : unit = - match b with - | Result e -> Format.fprintf fmt "%a %a" keyword "return" (expr ~debug ()) e - | ScopeLet - { - scope_let_kind = kind; - scope_let_typ; - scope_let_expr; - scope_let_next; - _; - } -> - let x, next = Bindlib.unbind scope_let_next in + let print_scope_let x sl = Format.fprintf fmt - "@[@[%a %a %a %a@ %a@ %a@]@ %a@;<1 -2>%a@]@,%a" + "@[@[%a %a %a %a@ %a@ %a@]@ %a@;<1 -2>%a@]@," keyword "let" - (scope_let_kind ~debug ctx) kind + (scope_let_kind ~debug ctx) sl.scope_let_kind (if debug then var_debug else var) x punctuation ":" - (typ ctx) scope_let_typ + (typ ctx) sl.scope_let_typ punctuation "=" - (expr ~debug ()) scope_let_expr + (expr ~debug ()) sl.scope_let_expr keyword "in" - (scope_body_expr ~debug ctx) next + in + let last = BoundList.iter ~f:print_scope_let b in + Format.fprintf fmt "%a %a" keyword "return" (expr ~debug ()) last let scope_body ?(debug = false) ctx fmt (n, l) : unit = let { @@ -936,16 +929,12 @@ let code_item ?(debug = false) ?name decl_ctx fmt c = "let topval" TopdefName.format n op_style ":" (typ decl_ctx) ty op_style "=" (expr ~debug ()) e -let rec code_item_list ?(debug = false) decl_ctx fmt c = - match c with - | Nil -> () - | Cons (c, b) -> - let x, cl = Bindlib.unbind b in - Format.fprintf fmt "%a @.%a" - (code_item ~debug ~name:(Format.asprintf "%a" var_debug x) decl_ctx) - c - (code_item_list ~debug decl_ctx) - cl +let code_item_list ?(debug = false) decl_ctx fmt c = + BoundList.iter c ~f:(fun x item -> + code_item ~debug + ~name:(Format.asprintf "%a" var_debug x) + decl_ctx fmt item; + Format.pp_print_newline fmt ()) let program ?(debug = false) fmt p = decl_ctx ~debug p.decl_ctx fmt p.decl_ctx; diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index c75c1b4c..6ce67cd8 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -17,16 +17,37 @@ open Definitions -let map_exprs ~f ~varf { code_items; decl_ctx; lang; module_name } = - Bindlib.box_apply - (fun code_items -> { code_items; decl_ctx; lang; module_name }) - (Scope.map_exprs ~f ~varf code_items) +let map_decl_ctx ~f ctx = + { + ctx with + ctx_enums = EnumName.Map.map (EnumConstructor.Map.map f) ctx.ctx_enums; + ctx_structs = StructName.Map.map (StructField.Map.map f) ctx.ctx_structs; + ctx_topdefs = TopdefName.Map.map f ctx.ctx_topdefs; + } -let fold_left_exprs ~f ~init { code_items; _ } = - Scope.fold_left ~f:(fun acc e _ -> f acc e) ~init code_items +let map_exprs ?typ ~f ~varf { code_items; decl_ctx; lang; module_name } = + let boxed_prg = + Bindlib.box_apply + (fun code_items -> + let decl_ctx = + match typ with None -> decl_ctx | Some f -> map_decl_ctx ~f decl_ctx + in + { code_items; decl_ctx; lang; module_name }) + (Scope.map_exprs ?typ ~f ~varf code_items) + in + assert (Bindlib.is_closed boxed_prg); + Bindlib.unbox boxed_prg -let fold_right_exprs ~f ~init { code_items; _ } = - Scope.fold_right ~f:(fun e _ acc -> f e acc) ~init code_items +let fold_left ~f ~init { code_items; _ } = + fst @@ BoundList.fold_left ~f:(fun acc e _ -> f acc e) ~init code_items + +let fold_exprs ~f ~init prg = Scope.fold_exprs ~f ~init prg.code_items + +let fold_right ~f ~init { code_items; _ } = + BoundList.fold_right + ~f:(fun e _ acc -> f e acc) + ~init:(fun () -> init) + code_items let empty_ctx = { @@ -42,56 +63,25 @@ let empty_ctx = let get_scope_body { code_items; _ } scope = match - Scope.fold_left ~init:None + BoundList.fold_left ~init:None ~f:(fun acc item _ -> match item with | ScopeDef (name, body) when ScopeName.equal scope name -> Some body | _ -> acc) code_items with - | None -> raise Not_found - | Some body -> body + | None, _ -> raise Not_found + | Some body, _ -> body let untype : 'm. ('a, 'm) gexpr program -> ('a, untyped) gexpr program = - fun prg -> Bindlib.unbox (map_exprs ~f:Expr.untype ~varf:Var.translate prg) + fun prg -> map_exprs ~f:Expr.untype ~varf:Var.translate prg -let rec find_scope name vars = function - | Nil -> raise Not_found - | Cons (ScopeDef (n, body), _) when ScopeName.equal name n -> - List.rev vars, body - | Cons (_, next_bind) -> - let var, next = Bindlib.unbind next_bind in - find_scope name (var :: vars) next - -let rec all_scopes code_item_list = - match code_item_list with - | Nil -> [] - | Cons (ScopeDef (n, _), next_bind) -> - let _var, next = Bindlib.unbind next_bind in - n :: all_scopes next - | Cons (_, next_bind) -> - let _var, next = Bindlib.unbind next_bind in - all_scopes next +let find_scope name = + BoundList.find ~f:(function + | ScopeDef (n, body) when ScopeName.equal name n -> Some body + | _ -> None) let to_expr p main_scope = - let _, main_scope_body = find_scope main_scope [] p.code_items in - let res = - Scope.unfold p.decl_ctx p.code_items - (Scope.get_body_mark main_scope_body) - (ScopeName main_scope) - in + let res = Scope.unfold p.decl_ctx p.code_items main_scope in Expr.Box.assert_closed (Expr.Box.lift res); res - -let equal p p' = - (* TODO: include toplevel definitions in this program comparison. *) - let ss = all_scopes p.code_items in - let ss' = all_scopes p'.code_items in - - List.length ss = List.length ss' - && ListLabels.for_all2 ss ss' ~f:(fun s s' -> - ScopeName.equal s s' - && - let e1 = Expr.unbox @@ to_expr p s in - let e2 = Expr.unbox @@ to_expr p' s in - Expr.equal e1 e2) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index b527ba8f..a702e2b4 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -23,16 +23,22 @@ val empty_ctx : decl_ctx (** {2 Transformations} *) +val map_decl_ctx : f:(typ -> typ) -> decl_ctx -> decl_ctx + val map_exprs : + ?typ:(typ -> typ) -> f:('expr1 -> 'expr2 boxed) -> varf:('expr1 Var.t -> 'expr2 Var.t) -> 'expr1 program -> - 'expr2 program Bindlib.box + 'expr2 program +(** If [typ] is specified, definitions in [decl_ctx] are also processed *) -val fold_left_exprs : +val fold_left : f:('a -> 'expr code_item -> 'a) -> init:'a -> 'expr program -> 'a -val fold_right_exprs : +val fold_exprs : f:('a -> 'expr -> typ -> 'a) -> init:'a -> 'expr program -> 'a + +val fold_right : f:('expr code_item -> 'a -> 'a) -> init:'a -> 'expr program -> 'a val get_scope_body : @@ -45,6 +51,4 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed corresponding to the main program and returning the main scope as a function. *) -val equal : - (('a any, _) gexpr as 'e) program -> (('a any, _) gexpr as 'e) program -> bool -(** Warning / todo: only compares program scopes at the moment *) +val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body diff --git a/compiler/shared_ast/scope.ml b/compiler/shared_ast/scope.ml index 2054520e..59bcbcb8 100644 --- a/compiler/shared_ast/scope.ml +++ b/compiler/shared_ast/scope.ml @@ -18,151 +18,80 @@ open Catala_utils open Definitions -let rec fold_left_lets ~f ~init scope_body_expr = - match scope_body_expr with - | Result _ -> init - | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - fold_left_lets ~f ~init:(f init scope_let var) next - -let rec fold_right_lets ~f ~init scope_body_expr = - match scope_body_expr with - | Result result -> init result - | ScopeLet scope_let -> - let var, next = Bindlib.unbind scope_let.scope_let_next in - let next_result = fold_right_lets ~f ~init next in - f scope_let var next_result - let map_exprs_in_lets : - ?transform_types:(typ -> typ) -> + ?typ:(typ -> typ) -> f:('expr1 -> 'expr2 boxed) -> varf:('expr1 Var.t -> 'expr2 Var.t) -> 'expr1 scope_body_expr -> 'expr2 scope_body_expr Bindlib.box = - fun ?(transform_types = Fun.id) ~f ~varf scope_body_expr -> - fold_right_lets - ~f:(fun scope_let var_next acc -> - Bindlib.box_apply2 - (fun scope_let_next scope_let_expr -> - ScopeLet + fun ?(typ = Fun.id) ~f ~varf scope_body_expr -> + let f e = Expr.Box.lift (f e) in + BoundList.map ~last:f + ~f:(fun v scope_let -> + ( varf v, + Bindlib.box_apply + (fun scope_let_expr -> { scope_let with - scope_let_next; scope_let_expr; - scope_let_typ = transform_types scope_let.scope_let_typ; + scope_let_typ = typ scope_let.scope_let_typ; }) - (Bindlib.bind_var (varf var_next) acc) - (Expr.Box.lift (f scope_let.scope_let_expr))) - ~init:(fun res -> - Bindlib.box_apply (fun res -> Result res) (Expr.Box.lift (f res))) + (f scope_let.scope_let_expr) )) scope_body_expr -let rec fold_left ~f ~init = function - | Nil -> init - | Cons (item, next_bind) -> - let var, next = Bindlib.unbind next_bind in - fold_left ~f ~init:(f init item var) next - -let rec fold_right ~f ~init = function - | Nil -> init - | Cons (item, next_bind) -> - let var_next, next = Bindlib.unbind next_bind in - let result_next = fold_right ~f ~init next in - f item var_next result_next - -let rec map ~f ~varf = function - | Nil -> Bindlib.box Nil - | Cons (item, next_bind) -> - let item = f item in - let next_bind = - let var, next = Bindlib.unbind next_bind in - Bindlib.bind_var (varf var) (map ~f ~varf next) - in - Bindlib.box_apply2 - (fun item next_bind -> Cons (item, next_bind)) - item next_bind - -let rec map_ctx ~f ~varf ctx = function - | Nil -> Bindlib.box Nil - | Cons (item, next_bind) -> - let ctx, item = f ctx item in - let next_bind = - let var, next = Bindlib.unbind next_bind in - Bindlib.bind_var (varf var) (map_ctx ~f ~varf ctx next) - in - Bindlib.box_apply2 - (fun item next_bind -> Cons (item, next_bind)) - item next_bind - -let rec fold_map ~f ~varf ctx = function - | Nil -> ctx, Bindlib.box Nil - | Cons (item, next_bind) -> - let var, next = Bindlib.unbind next_bind in - let ctx, item = f ctx var item in - let ctx, next = fold_map ~f ~varf ctx next in - let next_bind = Bindlib.bind_var (varf var) next in - ( ctx, - Bindlib.box_apply2 - (fun item next_bind -> Cons (item, next_bind)) - item next_bind ) - -let map_exprs ~f ~varf scopes = - let f = function +let map_exprs ?(typ = Fun.id) ~f ~varf scopes = + let f v = function | ScopeDef (name, body) -> let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in - let new_body_expr = map_exprs_in_lets ~f ~varf scope_lets in + let new_body_expr = map_exprs_in_lets ~typ ~f ~varf scope_lets in let new_body_expr = Bindlib.bind_var (varf scope_input_var) new_body_expr in - Bindlib.box_apply - (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) - new_body_expr - | Topdef (name, typ, expr) -> - Bindlib.box_apply - (fun e -> Topdef (name, typ, e)) - (Expr.Box.lift (f expr)) + ( varf v, + Bindlib.box_apply + (fun scope_body_expr -> + ScopeDef (name, { body with scope_body_expr })) + new_body_expr ) + | Topdef (name, ty, expr) -> + ( varf v, + Bindlib.box_apply + (fun e -> Topdef (name, typ ty, e)) + (Expr.Box.lift (f expr)) ) in - map ~f ~varf scopes + BoundList.map ~f ~last:Bindlib.box scopes -(* TODO: compute the expected body expr arrow type manually instead of [TAny] - for double-checking types ? *) -let rec get_body_expr_mark = function - | ScopeLet sl -> - let _, e = Bindlib.unbind sl.scope_let_next in - get_body_expr_mark e - | Result e -> - let m = Mark.get e in - Expr.with_ty m (Mark.add (Expr.mark_pos m) TAny) +let fold_exprs ~f ~init scopes = + let f acc def _ = + match def with + | Topdef (_, typ, e) -> f acc e typ + | ScopeDef (_, scope) -> + let _, body = Bindlib.unbind scope.scope_body_expr in + let acc, last = + BoundList.fold_left body ~init:acc ~f:(fun acc sl _ -> + f acc sl.scope_let_expr sl.scope_let_typ) + in + f acc last (TStruct scope.scope_body_output_struct, Expr.pos last) + in + fst @@ BoundList.fold_left ~f ~init scopes + +let typ body = + let pos = Mark.get (StructName.get_info body.scope_body_input_struct) in + let input_typ = Mark.add pos (TStruct body.scope_body_input_struct) in + let result_typ = Mark.add pos (TStruct body.scope_body_output_struct) in + Mark.add pos (TArrow ([input_typ], result_typ)) let get_body_mark scope_body = - let _, e = Bindlib.unbind scope_body.scope_body_expr in - get_body_expr_mark e + let m0 = + match Bindlib.unbind scope_body.scope_body_expr with + | _, Last (_, m) | _, Cons ({ scope_let_expr = _, m; _ }, _) -> m + in + Expr.with_ty m0 (typ scope_body) -let rec unfold_body_expr (ctx : decl_ctx) (scope_let : 'e scope_body_expr) = - match scope_let with - | Result e -> Expr.rebox e - | ScopeLet - { - scope_let_kind = _; - scope_let_typ; - scope_let_expr; - scope_let_next; - scope_let_pos; - } -> - let var, next = Bindlib.unbind scope_let_next in - Expr.make_let_in var scope_let_typ - (Expr.rebox scope_let_expr) - (unfold_body_expr ctx next) - scope_let_pos - -let build_typ_from_sig - (_ctx : decl_ctx) - (scope_input_struct_name : StructName.t) - (scope_return_struct_name : StructName.t) - (pos : Pos.t) : typ = - let input_typ = Mark.add pos (TStruct scope_input_struct_name) in - let result_typ = Mark.add pos (TStruct scope_return_struct_name) in - Mark.add pos (TArrow ([input_typ], result_typ)) +let unfold_body_expr (_ctx : decl_ctx) (scope_let : 'e scope_body_expr) = + BoundList.fold_right scope_let ~init:Expr.rebox ~f:(fun sl var acc -> + Expr.make_let_in var sl.scope_let_typ + (Expr.rebox sl.scope_let_expr) + acc sl.scope_let_pos) let input_type ty io = match io, ty with @@ -171,59 +100,34 @@ let input_type ty io = | (Runtime.Reentrant, iopos), (ty, tpos) -> TDefault (ty, tpos), iopos | _, ty -> ty -type 'e scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of 'e Var.t - -let to_expr (ctx : decl_ctx) (body : 'e scope_body) (mark_scope : 'm) : 'e boxed - = +let to_expr (ctx : decl_ctx) (body : 'e scope_body) : 'e boxed = let var, body_expr = Bindlib.unbind body.scope_body_expr in let body_expr = unfold_body_expr ctx body_expr in + let pos = Expr.pos body_expr in Expr.make_abs [| var |] body_expr - [TStruct body.scope_body_input_struct, Expr.mark_pos mark_scope] - (Expr.mark_pos mark_scope) + [TStruct body.scope_body_input_struct, pos] + pos -let rec unfold - (ctx : decl_ctx) - (s : 'e code_item_list) - (mark : 'm mark) - (main_scope : 'expr scope_name_or_var) : 'e boxed = - match s with - | Nil -> ( - match main_scope with - | ScopeVar v -> Expr.make_var v mark - | ScopeName _ -> failwith "should not happen") - | Cons (item, next_bind) -> - let var, next = Bindlib.unbind next_bind in - let typ, expr, pos, is_main = - match item with - | ScopeDef (name, body) -> - let pos = Mark.get (ScopeName.get_info name) in - let body_mark = get_body_mark body in - let is_main = - match main_scope with - | ScopeName n -> ScopeName.equal n name - | ScopeVar _ -> false - in - let typ = - build_typ_from_sig ctx body.scope_body_input_struct - body.scope_body_output_struct pos - in - let expr = to_expr ctx body body_mark in - typ, expr, pos, is_main - | Topdef (name, typ, expr) -> - let pos = Mark.get (TopdefName.get_info name) in - typ, Expr.rebox expr, pos, false - in - let main_scope = if is_main then ScopeVar var else main_scope in - let next = unfold ctx next mark main_scope in - Expr.make_let_in var typ expr next pos +let unfold (ctx : decl_ctx) (s : 'e code_item_list) (main_scope : ScopeName.t) : + 'e boxed = + BoundList.fold_lr s ~top:None + ~down:(fun v item main -> + match main, item with + | None, ScopeDef (name, body) when ScopeName.equal name main_scope -> + Some (Expr.make_var v (get_body_mark body)) + | r, _ -> r) + ~bottom:(fun () -> function Some v -> v | None -> raise Not_found) + ~up:(fun var item next -> + let e, typ = + match item with + | ScopeDef (_, body) -> to_expr ctx body, typ body + | Topdef (_, typ, expr) -> Expr.rebox expr, typ + in + Expr.make_let_in var typ e next (Expr.pos e)) -let rec free_vars_body_expr scope_lets = - match scope_lets with - | Result e -> Expr.free_vars e - | ScopeLet { scope_let_expr = e; scope_let_next = next; _ } -> - let v, body = Bindlib.unbind next in - Var.Set.union (Expr.free_vars e) - (Var.Set.remove v (free_vars_body_expr body)) +let free_vars_body_expr scope_lets = + BoundList.fold_right scope_lets ~init:Expr.free_vars ~f:(fun sl v acc -> + Var.Set.union (Var.Set.remove v acc) (Expr.free_vars sl.scope_let_expr)) let free_vars_item = function | ScopeDef (_, { scope_body_expr; _ }) -> @@ -231,9 +135,8 @@ let free_vars_item = function Var.Set.remove v (free_vars_body_expr body) | Topdef (_, _, expr) -> Expr.free_vars expr -let rec free_vars scopes = - match scopes with - | Nil -> Var.Set.empty - | Cons (item, next_bind) -> - let v, next = Bindlib.unbind next_bind in - Var.Set.union (Var.Set.remove v (free_vars next)) (free_vars_item item) +let free_vars scopes = + BoundList.fold_right scopes + ~init:(fun () -> Var.Set.empty) + ~f:(fun item v acc -> + Var.Set.union (Var.Set.remove v acc) (free_vars_item item)) diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli index c7475299..696e04de 100644 --- a/compiler/shared_ast/scope.mli +++ b/compiler/shared_ast/scope.mli @@ -23,28 +23,8 @@ open Definitions (** {2 Traversal functions} *) -val fold_left_lets : - f:('a -> 'e scope_let -> 'e Var.t -> 'a) -> - init:'a -> - 'e scope_body_expr -> - 'a -(** Usage: - [fold_left_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets], - where [scope_let_var] is the variable bound to the scope let in the next - scope lets to be examined. *) - -val fold_right_lets : - f:('expr1 scope_let -> 'expr1 Var.t -> 'a -> 'a) -> - init:('expr1 -> 'a) -> - 'expr1 scope_body_expr -> - 'a -(** Usage: - [fold_right_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets], - where [scope_let_var] is the variable bound to the scope let in the next - scope lets to be examined (which are before in the program order). *) - val map_exprs_in_lets : - ?transform_types:(typ -> typ) -> + ?typ:(typ -> typ) -> f:('expr1 -> 'expr2 boxed) -> varf:('expr1 Var.t -> 'expr2 Var.t) -> 'expr1 scope_body_expr -> @@ -58,48 +38,8 @@ val map_exprs_in_lets : activated, then the resulting types in the scope let left-hand-sides will be reset to [TAny]. *) -val fold_left : - f:('a -> 'expr1 code_item -> 'expr1 Var.t -> 'a) -> - init:'a -> - 'expr1 code_item_list -> - 'a -(** Usage: [fold_left ~f:(fun acc code_def code_var -> ...) ~init code_def], - where [code_var] is the variable bound to the code item in the next code - items to be examined. *) - -val fold_right : - f:('expr1 code_item -> 'expr1 Var.t -> 'a -> 'a) -> - init:'a -> - 'expr1 code_item_list -> - 'a -(** Usage: - [fold_right_scope ~f:(fun scope_def scope_var acc -> ...) ~init scope_def], - where [scope_var] is the variable bound to the scope in the next scopes to - be examined (which are before in the program order). *) - -val map : - f:('e1 code_item -> 'e2 code_item Bindlib.box) -> - varf:('e1 Var.t -> 'e2 Var.t) -> - 'e1 code_item_list -> - 'e2 code_item_list Bindlib.box - -val map_ctx : - f:('ctx -> 'e1 code_item -> 'ctx * 'e2 code_item Bindlib.box) -> - varf:('e1 Var.t -> 'e2 Var.t) -> - 'ctx -> - 'e1 code_item_list -> - 'e2 code_item_list Bindlib.box -(** Similar to [map], but a context is passed left-to-right through the given - function *) - -val fold_map : - f:('ctx -> 'e1 Var.t -> 'e1 code_item -> 'ctx * 'e2 code_item Bindlib.box) -> - varf:('e1 Var.t -> 'e2 Var.t) -> - 'ctx -> - 'e1 code_item_list -> - 'ctx * 'e2 code_item_list Bindlib.box - val map_exprs : + ?typ:(typ -> typ) -> f:('expr1 -> 'expr2 boxed) -> varf:('expr1 Var.t -> 'expr2 Var.t) -> 'expr1 code_item_list -> @@ -107,28 +47,20 @@ val map_exprs : (** This is the main map visitor for all the expressions inside all the scopes of the program. *) -val get_body_mark : (_, 'm) gexpr scope_body -> 'm mark +val fold_exprs : + f:('acc -> 'expr -> typ -> 'acc) -> init:'acc -> 'expr code_item_list -> 'acc (** {2 Conversions} *) -val to_expr : - decl_ctx -> ('a any, 'm) gexpr scope_body -> 'm mark -> ('a, 'm) boxed_gexpr +val to_expr : decl_ctx -> ('a any, 'm) gexpr scope_body -> ('a, 'm) boxed_gexpr (** Usage: [to_expr ctx body scope_position] where [scope_position] corresponds to the line of the scope declaration for instance. *) -type 'e scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of 'e Var.t - val unfold : - decl_ctx -> - ((_, 'm) gexpr as 'e) code_item_list -> - 'm mark -> - 'e scope_name_or_var -> - 'e boxed + decl_ctx -> ((_, 'm) gexpr as 'e) code_item_list -> ScopeName.t -> 'e boxed -val build_typ_from_sig : - decl_ctx -> StructName.t -> StructName.t -> Pos.t -> typ -(** [build_typ_from_sig ctx in_struct out_struct pos] builds the arrow type for - the specified scope *) +val typ : _ scope_body -> typ +(** builds the arrow type for the specified scope *) val input_type : typ -> Runtime.io_input Mark.pos -> typ (** Returns the correct input type for scope input variables: this is [typ] for diff --git a/compiler/shared_ast/shared_ast.ml b/compiler/shared_ast/shared_ast.ml index 692fd1eb..739066cc 100644 --- a/compiler/shared_ast/shared_ast.ml +++ b/compiler/shared_ast/shared_ast.ml @@ -20,6 +20,7 @@ module Qident = Qident module Type = Type module Operator = Operator module Expr = Expr +module BoundList = BoundList module Scope = Scope module Program = Program module Print = Print diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 3bcd8d6c..b7582e6b 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -944,46 +944,38 @@ let check_expr ctx ?env ?typ e = let expr ctx ?(env = Env.empty ctx) ?typ e = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) (expr_raw ctx ~env ?typ e) -let rec scope_body_expr ctx env ty_out body_expr = - match body_expr with - | A.Result e -> - let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in - let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in - Bindlib.box_apply (fun e -> A.Result e) (Expr.Box.lift e') - | A.ScopeLet - { - scope_let_kind; - scope_let_typ; - scope_let_expr = e0; - scope_let_next; - scope_let_pos; - } -> - let ty_e = ast_to_typ scope_let_typ in - let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in - wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e; - (* We could use [typecheck_expr_top_down] rather than this manual - unification, but we get better messages with this order of the [unify] - parameters, which keeps location of the type as defined instead of as - inferred. *) - let var, next = Bindlib.unbind scope_let_next in - let env = Env.add var ty_e env in - let next = scope_body_expr ctx env ty_out next in - let scope_let_next = Bindlib.bind_var (Var.translate var) next in - Bindlib.box_apply2 - (fun scope_let_expr scope_let_next -> - A.ScopeLet - { - scope_let_kind; - scope_let_typ = - (match Mark.remove scope_let_typ with - | TAny -> typ_to_ast ~flags:env.flags (ty e) - | _ -> scope_let_typ); - scope_let_expr; - scope_let_next; - scope_let_pos; - }) - (Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e)) - scope_let_next +let scope_body_expr ctx env ty_out body_expr = + let _env, ret = + BoundList.fold_map body_expr ~init:env + ~last:(fun env e -> + let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in + let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in + env, Expr.Box.lift e') + ~f:(fun env var scope -> + let e0 = scope.A.scope_let_expr in + let ty_e = ast_to_typ scope.A.scope_let_typ in + let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in + wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e; + (* We could use [typecheck_expr_top_down] rather than this manual + unification, but we get better messages with this order of the + [unify] parameters, which keeps location of the type as defined + instead of as inferred. *) + ( Env.add var ty_e env, + Var.translate var, + Bindlib.box_apply + (fun scope_let_expr -> + { + scope with + A.scope_let_typ = + (match scope.A.scope_let_typ with + | TAny, _ -> typ_to_ast ~flags:env.flags (ty e) + | ty -> ty); + A.scope_let_expr; + }) + (Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e)) + )) + in + ret let scope_body ctx env body = let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in @@ -1003,33 +995,29 @@ let scope_body ctx env body = (get_pos body.A.scope_body_output_struct) (TArrow ([ty_in], ty_out))) ) -let rec scopes ctx env = function - | A.Nil -> Bindlib.box A.Nil, env - | A.Cons (item, next_bind) -> - let var, next = Bindlib.unbind next_bind in - let env, def = +let scopes ctx env = + BoundList.fold_map ~init:env + ~last:(fun ctx () -> ctx, Bindlib.box ()) + ~f:(fun env var item -> match item with | A.ScopeDef (name, body) -> let body_e, ty_scope = scope_body ctx env body in ( Env.add var ty_scope env, + Var.translate var, Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e ) | A.Topdef (name, typ, e) -> let e' = expr_raw ctx ~env ~typ e in let (A.Custom { custom = uf; _ }) = Mark.get e' in let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in ( Env.add var uf env, + Var.translate var, Bindlib.box_apply (fun e -> A.Topdef (name, Expr.ty e', e)) - (Expr.Box.lift e') ) - in - let next', env = scopes ctx env next in - let next_bind' = Bindlib.bind_var (Var.translate var) next' in - ( Bindlib.box_apply2 (fun item next -> A.Cons (item, next)) def next_bind', - env ) + (Expr.Box.lift e') )) let program ?fail_on_any ?assume_op_types prg = let env = Env.empty ?fail_on_any ?assume_op_types prg.A.decl_ctx in - let code_items, new_env = scopes prg.A.decl_ctx env prg.A.code_items in + let new_env, code_items = scopes prg.A.decl_ctx env prg.A.code_items in { A.lang = prg.lang; A.module_name = prg.A.module_name; diff --git a/compiler/tests.ml b/compiler/tests.ml index f37dee69..daabdbc6 100644 --- a/compiler/tests.ml +++ b/compiler/tests.ml @@ -5,7 +5,6 @@ let () = ( "Iota-reduction", [ test_case "#1" `Quick Shared_ast.Optimizations.test_iota_reduction_1; - test_case "#2" `Quick - Shared_ast.Optimizations.test_iota_reduction_2; + test_case "#2" `Quick Shared_ast.Optimizations.test_iota_reduction_2; ] ); ] diff --git a/compiler/verification/conditions.ml b/compiler/verification/conditions.ml index fabf5eb0..53bff0d0 100644 --- a/compiler/verification/conditions.ml +++ b/compiler/verification/conditions.ml @@ -286,11 +286,9 @@ let rec generate_verification_conditions_scope_body_expr (scope_body_expr : 'm expr scope_body_expr) : ctx * verification_condition list * typed expr list = match scope_body_expr with - | Result _ -> ctx, [], [] - | ScopeLet scope_let -> - let scope_let_var, scope_let_next = - Bindlib.unbind scope_let.scope_let_next - in + | Last _ -> ctx, [], [] + | Cons (scope_let, scope_let_next) -> + let scope_let_var, scope_let_next = Bindlib.unbind scope_let_next in let new_ctx, vc_list, assert_list = match scope_let.scope_let_kind with | Assertion -> ( @@ -378,49 +376,53 @@ let generate_verification_conditions_code_items (decl_ctx : decl_ctx) (code_items : 'm expr code_item_list) (s : ScopeName.t option) : verification_condition list = - Scope.fold_left - ~f:(fun vcs item _ -> - match item with - | Topdef _ -> [] - | ScopeDef (name, body) -> - let is_selected_scope = - match s with - | Some s when ScopeName.equal s name -> true - | None -> true - | _ -> false - in - let new_vcs = - if is_selected_scope then - let _scope_input_var, scope_body_expr = - Bindlib.unbind body.scope_body_expr - in - let ctx = - { - current_scope_name = name; - decl = decl_ctx; - input_vars = []; - scope_variables_typs = - Var.Map.empty - (* We don't need to add the typ of the scope input var here - because it will never appear in an expression for which we - generate a verification conditions (the big struct is - destructured with a series of let bindings just after. )*); - } - in - let _, vcs, asserts = - generate_verification_conditions_scope_body_expr ctx - scope_body_expr - in - let combined_assert = - conjunction_exprs asserts - (Typed - { pos = Pos.no_pos; ty = Mark.add Pos.no_pos (TLit TBool) }) - in - List.map (fun vc -> { vc with vc_asserts = combined_assert }) vcs - else [] - in - new_vcs @ vcs) - ~init:[] code_items + let conditions, () = + BoundList.fold_left + ~f:(fun vcs item _ -> + match item with + | Topdef _ -> [] + | ScopeDef (name, body) -> + let is_selected_scope = + match s with + | Some s when ScopeName.equal s name -> true + | None -> true + | _ -> false + in + let new_vcs = + if is_selected_scope then + let _scope_input_var, scope_body_expr = + Bindlib.unbind body.scope_body_expr + in + let ctx = + { + current_scope_name = name; + decl = decl_ctx; + input_vars = []; + scope_variables_typs = + Var.Map.empty + (* We don't need to add the typ of the scope input var here + because it will never appear in an expression for which + we generate a verification conditions (the big struct is + destructured with a series of let bindings just after. + )*); + } + in + let _, vcs, asserts = + generate_verification_conditions_scope_body_expr ctx + scope_body_expr + in + let combined_assert = + conjunction_exprs asserts + (Typed + { pos = Pos.no_pos; ty = Mark.add Pos.no_pos (TLit TBool) }) + in + List.map (fun vc -> { vc with vc_asserts = combined_assert }) vcs + else [] + in + new_vcs @ vcs) + ~init:[] code_items + in + conditions let generate_verification_conditions (p : 'm program) (s : ScopeName.t option) : verification_condition list = diff --git a/tests/test_func/good/closure_conversion.catala_en b/tests/test_func/good/closure_conversion.catala_en index 426fb621..b84ea9b8 100644 --- a/tests/test_func/good/closure_conversion.catala_en +++ b/tests/test_func/good/closure_conversion.catala_en @@ -27,13 +27,13 @@ type S = { z: integer; } let topval closure_f : (closure_env, integer) → integer = λ (env: closure_env) (y: integer) → - if (from_closure_env env).0 then y else - y + if (from_closure_env env).0 then y else - y let scope S (S_in: S_in {x_in: bool}): S {z: integer} = let get x : bool = S_in.x_in in let set f : ((closure_env, integer) → integer * closure_env) = (closure_f, to_closure_env (x)) in let set z : integer = f.0 f.1 -1 in - return { S z = z; } + return { S z = z; } ``` diff --git a/tests/test_func/good/closure_return.catala_en b/tests/test_func/good/closure_return.catala_en index 6ed6568b..c0e483bb 100644 --- a/tests/test_func/good/closure_return.catala_en +++ b/tests/test_func/good/closure_return.catala_en @@ -25,7 +25,7 @@ type S = { f: ((closure_env, integer) → integer * closure_env); } let topval closure_f : (closure_env, integer) → integer = λ (env: closure_env) (y: integer) → - if (from_closure_env env).0 then y else - y + if (from_closure_env env).0 then y else - y let scope S (S_in: S_in {x_in: bool}) : S {f: ((closure_env, integer) → integer * closure_env)} @@ -34,6 +34,6 @@ let scope S let set f : ((closure_env, integer) → integer * closure_env) = (closure_f, to_closure_env (x)) in - return { S f = f; } + return { S f = f; } ``` diff --git a/tests/test_func/good/scope_call_func_struct_closure.catala_en b/tests/test_func/good/scope_call_func_struct_closure.catala_en index 041f2100..ee12a1f4 100644 --- a/tests/test_func/good/scope_call_func_struct_closure.catala_en +++ b/tests/test_func/good/scope_call_func_struct_closure.catala_en @@ -70,7 +70,7 @@ type Foo = { z: integer; } let topval closure_y : (closure_env, integer) → integer = λ (env: closure_env) (z: integer) → - (from_closure_env env).0 + z + (from_closure_env env).0 + z let scope SubFoo1 (SubFoo1_in: SubFoo1_in {x_in: integer}) : SubFoo1 { @@ -82,11 +82,11 @@ let scope SubFoo1 let set y : ((closure_env, integer) → integer * closure_env) = (closure_y, to_closure_env (x)) in - return { SubFoo1 x = x; y = y; } + return { SubFoo1 x = x; y = y; } let topval closure_y : (closure_env, integer) → integer = λ (env: closure_env) (z: integer) → let env1 : (integer * integer) = from_closure_env env in - ((env1.1 + env1.0 + z)) + ((env1.1 + env1.0 + z)) let scope SubFoo2 (SubFoo2_in: SubFoo2_in {x1_in: integer; x2_in: integer}) : SubFoo2 { @@ -99,19 +99,19 @@ let scope SubFoo2 let set y : ((closure_env, integer) → integer * closure_env) = (closure_y, to_closure_env (x2, x1)) in - return { SubFoo2 x1 = x1; y = y; } + return { SubFoo2 x1 = x1; y = y; } let topval closure_r : (closure_env, integer) → integer = λ (env: closure_env) (param0: integer) → let code_and_env : ((closure_env, integer) → integer * closure_env) = (from_closure_env env).0.y in - code_and_env.0 code_and_env.1 param0 + code_and_env.0 code_and_env.1 param0 let topval closure_r : (closure_env, integer) → integer = λ (env: closure_env) (param0: integer) → let code_and_env : ((closure_env, integer) → integer * closure_env) = (from_closure_env env).0.y in - code_and_env.0 code_and_env.1 param0 + code_and_env.0 code_and_env.1 param0 let scope Foo (Foo_in: Foo_in {b_in: ((closure_env, unit) → eoption bool * closure_env)}) @@ -153,7 +153,7 @@ let scope Foo in code_and_env.0 code_and_env.1 1 in - return { Foo z = z; } + return { Foo z = z; } ``` diff --git a/tests/test_modules/good/output/mod_def.ml b/tests/test_modules/good/output/mod_def.ml index 91e4d782..52ca81c9 100644 --- a/tests/test_modules/good/output/mod_def.ml +++ b/tests/test_modules/good/output/mod_def.ml @@ -25,9 +25,8 @@ module S_in = struct end - let s (s_in: S_in.t) : S.t = - let sr_: money = + let sr_: money = try (handle_default {filename = "tests/test_modules/good/mod_def.catala_en"; @@ -47,7 +46,7 @@ let s (s_in: S_in.t) : S.t = {filename = "tests/test_modules/good/mod_def.catala_en"; start_line=16; start_column=10; end_line=16; end_column=12; law_headings=["Test modules + inclusions 1"]})) in - let e1_: Enum1.t = + let e1_: Enum1.t = try (handle_default {filename = "tests/test_modules/good/mod_def.catala_en"; @@ -70,6 +69,7 @@ let s (s_in: S_in.t) : S.t = let half_ : integer -> decimal = fun (x_: integer) -> o_div_int_int x_ (integer_of_string "2") + let () = Runtime_ocaml.Runtime.register_module "Mod_def" [ "S", Obj.repr s; diff --git a/tests/test_name_resolution/good/let_in2.catala_en b/tests/test_name_resolution/good/let_in2.catala_en index 4e135bdc..c9253372 100644 --- a/tests/test_name_resolution/good/let_in2.catala_en +++ b/tests/test_name_resolution/good/let_in2.catala_en @@ -52,17 +52,16 @@ module S_in = struct end - let s (s_in: S_in.t) : S.t = let a_: unit -> bool = s_in.S_in.a_in in - let a_: bool = + let a_: bool = try (handle_default {filename = "tests/test_name_resolution/good/let_in2.catala_en"; start_line=7; start_column=18; end_line=7; end_column=19; law_headings=["Article"]} ([|(fun (_: unit) -> a_ ())|]) (fun (_: unit) -> true) - (fun (_: unit) -> + (fun (_: unit) -> try (handle_default {filename = "tests/test_name_resolution/good/let_in2.catala_en"; @@ -91,6 +90,7 @@ let s (s_in: S_in.t) : S.t = start_line=7; start_column=18; end_line=7; end_column=19; law_headings=["Article"]})) in {S.a = a_} + let () = Runtime_ocaml.Runtime.register_module "Let_in2" [ "S", Obj.repr s ] diff --git a/tests/test_scope/good/191_fix_record_name_confusion.catala_en b/tests/test_scope/good/191_fix_record_name_confusion.catala_en index 1d7fed3b..f3e8d1d8 100644 --- a/tests/test_scope/good/191_fix_record_name_confusion.catala_en +++ b/tests/test_scope/good/191_fix_record_name_confusion.catala_en @@ -50,7 +50,6 @@ module ScopeB_in = struct end - let scope_a (scope_a_in: ScopeA_in.t) : ScopeA.t = let a_: bool = true in {ScopeA.a = a_} @@ -60,6 +59,7 @@ let scope_b (scope_b_in: ScopeB_in.t) : ScopeB.t = let scope_a_dot_a_: bool = result_.ScopeA.a in let a_: bool = scope_a_dot_a_ in {ScopeB.a = a_} + Generating entry points for scopes: ScopeA ScopeB let entry_scopes = [ diff --git a/tests/test_scope/good/scope_call3.catala_en b/tests/test_scope/good/scope_call3.catala_en index e7970bf6..64fabdcf 100644 --- a/tests/test_scope/good/scope_call3.catala_en +++ b/tests/test_scope/good/scope_call3.catala_en @@ -40,7 +40,7 @@ $ catala Interpret -t -s HousingComputation --debug [DEBUG] Translating to default calculus... [DEBUG] Typechecking again... [DEBUG] Starting interpretation... -[LOG] ≔ HousingComputation.f: λ (x_76: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨(let result_77 : RentComputation = (#{→ RentComputation.direct} (λ (RentComputation_in_78: RentComputation_in) → let g_79 : integer → integer = #{≔ RentComputation.g} (λ (x1_80: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_80 +! 1⟩⟩ | false ⊢ ∅ ⟩) in let f_81 : integer → integer = #{≔ RentComputation.f} (λ (x1_82: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} g_79) #{≔ RentComputation.g.input0} (x1_82 +! 1)⟩⟩ | false ⊢ ∅ ⟩) in { RentComputation f = f_81; })) #{≔ RentComputation.direct.input} {RentComputation_in} in let result1_83 : RentComputation = { RentComputation f = λ (param0_84: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} result_77.f) #{≔ RentComputation.f.input0} param0_84; } in #{← RentComputation.direct} #{≔ RentComputation.direct.output} if #{☛ RentComputation.direct.output} true then result1_83 else result1_83).f x_76⟩⟩ | false ⊢ ∅ ⟩ +[LOG] ≔ HousingComputation.f: λ (x_67: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨(let result_68 : RentComputation = (#{→ RentComputation.direct} (λ (RentComputation_in_69: RentComputation_in) → let g_70 : integer → integer = #{≔ RentComputation.g} (λ (x1_71: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_71 +! 1⟩⟩ | false ⊢ ∅ ⟩) in let f_72 : integer → integer = #{≔ RentComputation.f} (λ (x1_73: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} g_70) #{≔ RentComputation.g.input0} (x1_73 +! 1)⟩⟩ | false ⊢ ∅ ⟩) in { RentComputation f = f_72; })) #{≔ RentComputation.direct.input} {RentComputation_in} in let result1_74 : RentComputation = { RentComputation f = λ (param0_75: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} result_68.f) #{≔ RentComputation.f.input0} param0_75; } in #{← RentComputation.direct} #{≔ RentComputation.direct.output} if #{☛ RentComputation.direct.output} true then result1_74 else result1_74).f x_67⟩⟩ | false ⊢ ∅ ⟩ [LOG] ☛ Definition applied: ┌─⯈ tests/test_scope/good/scope_call3.catala_en:8.14-8.20: └─┐ @@ -55,14 +55,14 @@ $ catala Interpret -t -s HousingComputation --debug │ ‾ [LOG] → RentComputation.direct [LOG] ≔ RentComputation.direct.input: {RentComputation_in} -[LOG] ≔ RentComputation.g: λ (x_85: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x_85 +! 1⟩⟩ | false ⊢ ∅ ⟩ -[LOG] ≔ RentComputation.f: λ (x_86: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_87: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_87 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_86 +! 1)⟩⟩ | false ⊢ ∅ ⟩ +[LOG] ≔ RentComputation.g: λ (x_76: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x_76 +! 1⟩⟩ | false ⊢ ∅ ⟩ +[LOG] ≔ RentComputation.f: λ (x_77: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_78: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_78 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_77 +! 1)⟩⟩ | false ⊢ ∅ ⟩ [LOG] ☛ Definition applied: ┌─⯈ tests/test_scope/good/scope_call3.catala_en:7.29-7.54: └─┐ 7 │ definition f of x equals (output of RentComputation).f of x │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -[LOG] ≔ RentComputation.direct.output: { RentComputation f = λ (param0_88: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} { RentComputation f = λ (x_89: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_90: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_90 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_89 +! 1)⟩⟩ | false ⊢ ∅ ⟩; }.f) #{≔ RentComputation.f.input0} param0_88; } +[LOG] ≔ RentComputation.direct.output: { RentComputation f = λ (param0_79: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} { RentComputation f = λ (x_80: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_81: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_81 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_80 +! 1)⟩⟩ | false ⊢ ∅ ⟩; }.f) #{≔ RentComputation.f.input0} param0_79; } [LOG] ← RentComputation.direct [LOG] → RentComputation.f [LOG] ≔ RentComputation.f.input0: 1