diff --git a/src/catala/catala_surface/desugaring.ml b/src/catala/catala_surface/desugaring.ml index a6a28e75..34893447 100644 --- a/src/catala/catala_surface/desugaring.ml +++ b/src/catala/catala_surface/desugaring.ml @@ -1,6 +1,6 @@ (* This file is part of the Catala compiler, a specification language for tax and social benefits computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing - + Denis Merigoux Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -15,13 +15,13 @@ open Catala_ast open Lambda_ast -let subscope_ident (y : string) (x : string) : string = y ^ "::" ^ x - (** The optional argument subdef allows to choose between differents uids in case the expression is a redefinition of a subvariable *) -let rec expr_to_lambda ?(subdef : Uid.t option) (scope : Uid.t) (ctxt : Name_resolution.context) - ((expr, pos) : Catala_ast.expression Pos.marked) : Lambda_ast.term = - let rec_helper = expr_to_lambda ?subdef scope ctxt in +let rec expr_to_lambda (scope : Uid.Scope.t) (def_key : Uid.ScopeDef.t option) + (ctxt : Name_resolution.context) ((expr, pos) : Catala_ast.expression Pos.marked) : + Lambda_ast.term = + let scope_ctxt = Uid.ScopeMap.find scope ctxt.scopes in + let rec_helper = expr_to_lambda scope def_key ctxt in match expr with | IfThenElse (e_if, e_then, e_else) -> ((EIfThenElse (rec_helper e_if, rec_helper e_then, rec_helper e_else), pos), TDummy) @@ -37,118 +37,44 @@ let rec expr_to_lambda ?(subdef : Uid.t option) (scope : Uid.t) (ctxt : Name_res | Number ((Int i, _), _) -> EInt i | Number ((Dec (i, f), _), _) -> EDec (i, f) | Bool b -> EBool b - | _ -> assert false + | _ -> Name_resolution.raise_unsupported_feature "literal" pos in ((untyped_term, pos), TDummy) - | Ident x -> - let uid = Name_resolution.get_var_uid scope ctxt (x, pos) in - ((EVar uid, pos), TDummy) + | Ident x -> ( + (* first we check whether this is a local var, then we resort to scope-wide variables *) + match def_key with + | Some def_key -> ( + let def_ctxt = Uid.ScopeDefMap.find def_key scope_ctxt.definitions in + match Uid.IdentMap.find_opt x def_ctxt.var_idmap with + | None -> ( + match Uid.IdentMap.find_opt x scope_ctxt.var_idmap with + | Some uid -> ((EVar (NoPrefix, uid), pos), TDummy) + | None -> + Name_resolution.raise_unknown_identifier "for a local or scope-wide variable" + (x, pos) ) + | Some uid -> ((ELocalVar uid, pos), TDummy) ) + | None -> ( + match Uid.IdentMap.find_opt x scope_ctxt.var_idmap with + | Some uid -> ((EVar (NoPrefix, uid), pos), TDummy) + | None -> Name_resolution.raise_unknown_identifier "for a scope-wide variable" (x, pos) ) + ) | Dotted (e, x) -> ( (* For now we only accept dotted identifiers of the type y.x where y is a sub-scope *) match Pos.unmark e with - | Ident y -> ( - let _, sub_uid = Name_resolution.get_subscope_uid scope ctxt (Pos.same_pos_as y e) in - match subdef with - | None -> - (* No redefinition : take the uid from the current scope *) - let ident = subscope_ident y (Pos.unmark x) in - let uid = Name_resolution.get_var_uid scope ctxt (ident, Pos.get_position e) in - ((EVar uid, pos), TDummy) - | Some uid when uid <> sub_uid -> - (* Redefinition of a var from another scope : uid from the current scope *) - let ident = subscope_ident y (Pos.unmark x) in - let uid = Name_resolution.get_var_uid scope ctxt (ident, Pos.get_position e) in - ((EVar uid, pos), TDummy) - | Some sub_uid -> - (* Redefinition of a var from the same scope, uid from the subscope *) - let uid = Name_resolution.get_var_uid sub_uid ctxt x in - ((EVar uid, pos), TDummy) ) - | _ -> assert false ) + | Ident y -> + let subscope_uid : Uid.SubScope.t = + Name_resolution.get_subscope_uid scope ctxt (Pos.same_pos_as y e) + in + let subscope_real_uid : Uid.Scope.t = + Uid.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes + in + let subscope_var_uid = Name_resolution.get_var_uid subscope_real_uid ctxt x in + ((EVar (SubScopePrefix subscope_uid, subscope_var_uid), pos), TDummy) + | _ -> + Name_resolution.raise_unsupported_feature + "left hand side of a dotted expression should be an identifier" pos ) | FunCall (f, arg) -> ((EApp (rec_helper f, [ rec_helper arg ]), pos), TDummy) - | _ -> assert false - -(** Checks that a term is well typed and annotate it *) -let rec typing (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.term) : Lambda_ast.term - = - match t with - | EVar uid -> - let typ = Name_resolution.get_uid_typ ctxt uid in - ((EVar uid, pos), typ) - | EFun (bindings, body) -> - (* Note that given the context formation process, the binder will already be present in the - context (since we are working with uids), however it's added there for the sake of clarity *) - let ctxt_data = - List.fold_left - (fun data (uid, arg_typ) -> - let uid_data : Name_resolution.uid_data = - { uid_typ = arg_typ; uid_sort = Name_resolution.IdBinder } - in - Uid.UidMap.add uid uid_data data) - ctxt.data bindings - in - - let body = typing { ctxt with data = ctxt_data } body in - let ret_typ = Lambda_ast.get_typ body in - let rec build_typ = function - | [] -> ret_typ - | (_, arg_t) :: args -> TArrow (arg_t, build_typ args) - in - let fun_typ = build_typ bindings in - ((EFun (bindings, body), pos), fun_typ) - | EApp (f, args) -> - let f = typing ctxt f in - let f_typ = Lambda_ast.get_typ f in - let args = List.map (typing ctxt) args in - let args_typ = List.map Lambda_ast.get_typ args in - let rec check_arrow_typ f_typ args_typ = - match (f_typ, args_typ) with - | typ, [] -> typ - | TArrow (arg_typ, ret_typ), fst_typ :: typs -> - if arg_typ = fst_typ then check_arrow_typ ret_typ typs else assert false - | _ -> assert false - in - let ret_typ = check_arrow_typ f_typ args_typ in - ((EApp (f, args), pos), ret_typ) - | EIfThenElse (t_if, t_then, t_else) -> - let t_if = typing ctxt t_if in - let typ_if = Lambda_ast.get_typ t_if in - let t_then = typing ctxt t_then in - let typ_then = Lambda_ast.get_typ t_then in - let t_else = typing ctxt t_else in - let typ_else = Lambda_ast.get_typ t_else in - if typ_if <> TBool then assert false - else if typ_then <> typ_else then assert false - else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then) - | EInt _ | EDec _ -> ((t, pos), TInt) - | EBool _ -> ((t, pos), TBool) - | EOp op -> - let typ = - match op with - | Binop binop -> ( - match binop with - | And | Or -> TArrow (TBool, TArrow (TBool, TBool)) - | Add | Sub | Mult | Div -> TArrow (TInt, TArrow (TInt, TInt)) - | Lt | Lte | Gt | Gte | Eq | Neq -> TArrow (TInt, TArrow (TInt, TBool)) ) - | Unop Minus -> TArrow (TInt, TInt) - | Unop Not -> TArrow (TBool, TBool) - in - ((t, pos), typ) - | EDefault t -> - let defaults = - IntMap.map - (fun (just, cons) -> - let just = typing ctxt just in - if Lambda_ast.get_typ just <> TBool then - let cons = typing ctxt cons in - (just, cons) - else assert false) - t.defaults - in - let typ_cons = IntMap.choose defaults |> snd |> snd |> Lambda_ast.get_typ in - IntMap.iter - (fun _ (_, cons) -> if Lambda_ast.get_typ cons <> typ_cons then assert false else ()) - defaults; - ((EDefault { t with defaults }, pos), typ_cons) + | _ -> Name_resolution.raise_unsupported_feature "unsupported expression" pos (* Translation from the parsed ast to the scope language *) @@ -161,111 +87,83 @@ let merge_conditions (precond : Lambda_ast.term option) (cond : Lambda_ast.term | Some cond, None | None, Some cond -> cond | None, None -> ((EBool true, default_pos), TBool) -let process_default (ctxt : Name_resolution.context) (scope : Uid.t) (def : Lambda_ast.default_term) +let process_default (ctxt : Name_resolution.context) (scope : Uid.Scope.t) + (def_key : Uid.ScopeDef.t) (def : Lambda_ast.default_term) (param_uid : Uid.LocalVar.t option) (precond : Lambda_ast.term option) (just : Catala_ast.expression Pos.marked option) - (body : Catala_ast.expression Pos.marked) (subdef : Uid.t option) : Lambda_ast.default_term = + (body : Catala_ast.expression Pos.marked) : Lambda_ast.default_term = let just = match just with - | Some cond -> - let cond = expr_to_lambda ?subdef scope ctxt cond |> typing ctxt in - if Lambda_ast.get_typ cond = TBool then Some cond else assert false + | Some cond -> Some (expr_to_lambda scope (Some def_key) ctxt cond) | None -> None in - let condition = merge_conditions precond just (Pos.get_position body) |> typing ctxt in - let body = expr_to_lambda ?subdef scope ctxt body |> typing ctxt in + let condition = merge_conditions precond just (Pos.get_position body) in + let body = expr_to_lambda scope (Some def_key) ctxt body in + (* if there's a parameter, we have to wrap the justifiction and the body in a func *) + let condition, body = + match param_uid with + | None -> (condition, body) + | Some param_uid -> + ( ((EFun ([ (param_uid, TDummy) ], condition), Pos.get_position (fst condition)), TDummy), + ((EFun ([ (param_uid, TDummy) ], body), Pos.get_position (fst body)), TDummy) ) + in Lambda_ast.add_default condition body def (* Process a definition *) -let process_def (precond : Lambda_ast.term option) (scope_uid : Uid.t) +let process_def (precond : Lambda_ast.term option) (scope_uid : Uid.Scope.t) (ctxt : Name_resolution.context) (prgm : Scope_ast.program) (def : Catala_ast.definition) : Scope_ast.program = - let scope : Scope_ast.scope = UidMap.find scope_uid prgm in + let scope : Scope_ast.scope = Uid.ScopeMap.find scope_uid prgm in + let scope_ctxt = Uid.ScopeMap.find scope_uid ctxt.scopes in let default_pos = Pos.get_position def.definition_expr in - let is_func = match def.definition_parameter with Some _ -> true | None -> false in - let scope_updated = + let param_uid (def_uid : Uid.ScopeDef.t) : Uid.LocalVar.t option = + match def.definition_parameter with + | None -> None + | Some param -> + let def_ctxt = Uid.ScopeDefMap.find def_uid scope_ctxt.definitions in + Some (Uid.IdentMap.find (Pos.unmark param) def_ctxt.var_idmap) + in + let def_key = match Pos.unmark def.definition_name with | [ x ] -> let x_uid = Name_resolution.get_var_uid scope_uid ctxt x in - let ctxt, arg_uid = - Name_resolution.add_binding ctxt scope_uid x_uid def.definition_parameter - in - let x_def = - match UidMap.find_opt x_uid scope.scope_defs with - | Some def -> def - | None -> - let typ = Name_resolution.get_uid_typ ctxt x_uid in - if is_func then Scope_ast.empty_func_def (Option.get arg_uid) default_pos typ - else Scope_ast.empty_var_def default_pos typ - in - let x_def = - Lambda_ast.map_untype - (fun t -> - match t with - | EDefault default -> - EDefault - (process_default ctxt scope_uid default precond def.definition_condition - def.definition_expr None) - | EFun ([ bind ], term) -> ( - match arg_uid with - | Some bind' -> - if fst bind <> bind' then assert false - else - let term = - Lambda_ast.map_untype - (fun t -> - match t with - | EDefault default -> - EDefault - (process_default ctxt scope_uid default precond - def.definition_condition def.definition_expr None) - | _ -> assert false) - term - in - EFun ([ bind ], term) - | None -> assert false ) - | _ -> assert false) - x_def - in - { scope with scope_defs = UidMap.add x_uid x_def scope.scope_defs } + Uid.ScopeDef.Var x_uid | [ y; x ] -> - let y_uid, subscope_uid = Name_resolution.get_subscope_uid scope_uid ctxt y in - let x_uid = Name_resolution.get_var_uid subscope_uid ctxt x in - let ctxt, arg_uid = - Name_resolution.add_binding ctxt subscope_uid x_uid def.definition_parameter + let subscope_uid : Uid.SubScope.t = Name_resolution.get_subscope_uid scope_uid ctxt y in + let subscope_real_uid : Uid.Scope.t = + Uid.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes in - let y_redefs = - match UidMap.find_opt y_uid scope.scope_sub_defs with - | Some redefs -> redefs - | None -> UidMap.empty - in - let x_redef = - match UidMap.find_opt x_uid y_redefs with - | Some redef -> redef - | None -> - let typ = Name_resolution.get_uid_typ ctxt x_uid in - if is_func then Scope_ast.empty_func_def (Option.get arg_uid) default_pos typ - else Scope_ast.empty_var_def default_pos typ - in - let x_redef = - Lambda_ast.map_untype - (fun t -> - match t with - | EDefault default -> - EDefault - (process_default ctxt scope_uid default precond def.definition_condition - def.definition_expr (Some subscope_uid)) - | _ -> assert false) - x_redef - in - let y_redefs = UidMap.add x_uid x_redef y_redefs in - { scope with scope_sub_defs = UidMap.add y_uid y_redefs scope.scope_sub_defs } + let x_uid = Name_resolution.get_var_uid subscope_real_uid ctxt x in + Uid.ScopeDef.SubScopeVar (subscope_uid, x_uid) | _ -> Errors.raise_spanned_error "Structs are not handled yet" default_pos in - UidMap.add scope_uid scope_updated prgm + let scope_updated = + let x_def = + match Uid.ScopeDefMap.find_opt def_key scope.scope_defs with + | Some def -> def + | None -> + let typ = Name_resolution.get_def_typ ctxt def_key in + Scope_ast.empty_def default_pos typ + in + let x_def = + Lambda_ast.map_untype + (fun t -> + match t with + | EDefault default -> + EDefault + (process_default ctxt scope_uid def_key default (param_uid def_key) precond + def.definition_condition def.definition_expr) + | _ -> assert false + (* should not happen *)) + x_def + in + { scope with scope_defs = Uid.ScopeDefMap.add def_key x_def scope.scope_defs } + in + Uid.ScopeMap.add scope_uid scope_updated prgm (** Process a rule from the surface language *) -let process_rule (precond : Lambda_ast.term option) (scope : Uid.t) (ctxt : Name_resolution.context) - (prgm : Scope_ast.program) (rule : Catala_ast.rule) : Scope_ast.program = +let process_rule (precond : Lambda_ast.term option) (scope : Uid.Scope.t) + (ctxt : Name_resolution.context) (prgm : Scope_ast.program) (rule : Catala_ast.rule) : + Scope_ast.program = let consequence_expr = Catala_ast.Literal (Catala_ast.Bool (Pos.unmark rule.rule_consequence)) in let def = { @@ -277,7 +175,7 @@ let process_rule (precond : Lambda_ast.term option) (scope : Uid.t) (ctxt : Name in process_def precond scope ctxt prgm def -let process_scope_use_item (cond : Lambda_ast.term option) (scope : Uid.t) +let process_scope_use_item (cond : Lambda_ast.term option) (scope : Uid.Scope.t) (ctxt : Name_resolution.context) (prgm : Scope_ast.program) (item : Catala_ast.scope_use_item Pos.marked) : Scope_ast.program = match Pos.unmark item with @@ -288,19 +186,18 @@ let process_scope_use_item (cond : Lambda_ast.term option) (scope : Uid.t) let process_scope_use (ctxt : Name_resolution.context) (prgm : Scope_ast.program) (use : Catala_ast.scope_use) : Scope_ast.program = let name = fst use.scope_use_name in - let scope_uid = Name_resolution.IdentMap.find name ctxt.scope_id_to_uid in + let scope_uid = Uid.IdentMap.find name ctxt.scope_idmap in (* Make sure the scope exists *) let prgm = - match UidMap.find_opt scope_uid prgm with + match Uid.ScopeMap.find_opt scope_uid prgm with | Some _ -> prgm - | None -> UidMap.add scope_uid (Scope_ast.empty_scope scope_uid) prgm + | None -> Uid.ScopeMap.add scope_uid (Scope_ast.empty_scope scope_uid) prgm in let cond = match use.scope_use_condition with | Some expr -> - let untyped_term = expr_to_lambda scope_uid ctxt expr in - let term = typing ctxt untyped_term in - if Lambda_ast.get_typ term = TBool then Some term else assert false + let untyped_term = expr_to_lambda scope_uid None ctxt expr in + Some untyped_term | None -> None in List.fold_left (process_scope_use_item cond scope_uid ctxt) prgm use.scope_use_items @@ -308,7 +205,7 @@ let process_scope_use (ctxt : Name_resolution.context) (prgm : Scope_ast.program (** Scopes processing *) let translate_program_to_scope (ctxt : Name_resolution.context) (prgm : Catala_ast.program) : Scope_ast.program = - let empty_prgm = UidMap.empty in + let empty_prgm = Uid.ScopeMap.empty in let processer (prgm : Scope_ast.program) (item : Catala_ast.program_item) : Scope_ast.program = match item with | CodeBlock (block, _) | MetadataBlock (block, _) -> diff --git a/src/catala/catala_surface/name_resolution.ml b/src/catala/catala_surface/name_resolution.ml new file mode 100644 index 00000000..11b7ca97 --- /dev/null +++ b/src/catala/catala_surface/name_resolution.ml @@ -0,0 +1,288 @@ +(* This file is part of the Catala compiler, a specification language for tax and social benefits + computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing + Denis Merigoux + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + in compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License + is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or implied. See the License for the specific language governing permissions and limitations under + the License. *) + +(** Builds a context that allows for mapping each name to a precise uid, taking lexical scopes into + account *) + +type ident = string + +type typ = Lambda_ast.typ + +type def_context = { var_idmap : Uid.LocalVar.t Uid.IdentMap.t } +(** Inside a definition, local variables can be introduced by functions arguments or pattern + matching *) + +type scope_context = { + var_idmap : Uid.Var.t Uid.IdentMap.t; + sub_scopes_idmap : Uid.SubScope.t Uid.IdentMap.t; + sub_scopes : Uid.Scope.t Uid.SubScopeMap.t; + definitions : def_context Uid.ScopeDefMap.t; + (** Contains the local variables in all the definitions *) +} +(** Inside a scope, we distinguish between the variables and the subscopes. *) + +type context = { + scope_idmap : Uid.Scope.t Uid.IdentMap.t; + scopes : scope_context Uid.ScopeMap.t; + var_typs : typ Uid.VarMap.t; +} + +let raise_unsupported_feature (msg : string) (pos : Pos.t) = + Errors.raise_spanned_error (Printf.sprintf "unsupported feature: %s" msg) pos + +let raise_unknown_identifier (msg : string) (ident : ident Pos.marked) = + Errors.raise_spanned_error + (Printf.sprintf "%s: unknown identifier %s" (Pos.unmark ident) msg) + (Pos.get_position ident) + +(** Get the type associated to an uid *) +let get_var_typ (ctxt : context) (uid : Uid.Var.t) : typ = Uid.VarMap.find uid ctxt.var_typs + +(** Process a subscope declaration *) +let process_subscope_decl (scope : Uid.Scope.t) (ctxt : context) + (decl : Catala_ast.scope_decl_context_scope) : context = + let name, name_pos = decl.scope_decl_context_scope_name in + let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in + let scope_ctxt = Uid.ScopeMap.find scope ctxt.scopes in + match Uid.IdentMap.find_opt subscope scope_ctxt.sub_scopes_idmap with + | Some use -> + Errors.raise_multispanned_error "subscope name already used" + [ + (Some "first use", Pos.get_position (Uid.SubScope.get_info use)); + (Some "second use", s_pos); + ] + | None -> + let sub_scope_uid = Uid.SubScope.fresh (name, name_pos) in + let original_subscope_uid = + match Uid.IdentMap.find_opt subscope ctxt.scope_idmap with + | None -> raise_unknown_identifier "for a scope" (subscope, s_pos) + | Some id -> id + in + let scope_ctxt = + { + scope_ctxt with + sub_scopes_idmap = Uid.IdentMap.add name sub_scope_uid scope_ctxt.sub_scopes_idmap; + sub_scopes = Uid.SubScopeMap.add sub_scope_uid original_subscope_uid scope_ctxt.sub_scopes; + } + in + { ctxt with scopes = Uid.ScopeMap.add scope scope_ctxt ctxt.scopes } + +let process_base_typ ((typ, typ_pos) : Catala_ast.base_typ Pos.marked) : Lambda_ast.typ = + match typ with + | Catala_ast.Condition -> Lambda_ast.TBool + | Catala_ast.Data (Catala_ast.Collection _) -> raise_unsupported_feature "collection type" typ_pos + | Catala_ast.Data (Catala_ast.Optional _) -> raise_unsupported_feature "option type" typ_pos + | Catala_ast.Data (Catala_ast.Primitive prim) -> ( + match prim with + | Catala_ast.Integer | Catala_ast.Decimal | Catala_ast.Money | Catala_ast.Date -> + Lambda_ast.TInt + | Catala_ast.Boolean -> Lambda_ast.TBool + | Catala_ast.Text -> raise_unsupported_feature "text type" typ_pos + | Catala_ast.Named _ -> raise_unsupported_feature "struct or enum types" typ_pos ) + +let process_type ((typ, typ_pos) : Catala_ast.typ Pos.marked) : Lambda_ast.typ = + match typ with + | Catala_ast.Base base_typ -> process_base_typ (base_typ, typ_pos) + | Catala_ast.Func { arg_typ; return_typ } -> + Lambda_ast.TArrow (process_base_typ arg_typ, process_base_typ return_typ) + +(** Process data declaration *) +let process_data_decl (scope : Uid.Scope.t) (ctxt : context) + (decl : Catala_ast.scope_decl_context_data) : context = + (* First check the type of the context data *) + let data_typ = process_type decl.scope_decl_context_item_typ in + let name, pos = decl.scope_decl_context_item_name in + let scope_ctxt = Uid.ScopeMap.find scope ctxt.scopes in + match Uid.IdentMap.find_opt name scope_ctxt.var_idmap with + | Some use -> + Errors.raise_multispanned_error "var name already used" + [ (Some "first use", Pos.get_position (Uid.Var.get_info use)); (Some "second use", pos) ] + | None -> + let uid = Uid.Var.fresh (name, pos) in + let scope_ctxt = + { scope_ctxt with var_idmap = Uid.IdentMap.add name uid scope_ctxt.var_idmap } + in + { + ctxt with + scopes = Uid.ScopeMap.add scope scope_ctxt ctxt.scopes; + var_typs = Uid.VarMap.add uid data_typ ctxt.var_typs; + } + +(** Process an item declaration *) +let process_item_decl (scope : Uid.Scope.t) (ctxt : context) + (decl : Catala_ast.scope_decl_context_item) : context = + match decl with + | Catala_ast.ContextData data_decl -> process_data_decl scope ctxt data_decl + | Catala_ast.ContextScope sub_decl -> process_subscope_decl scope ctxt sub_decl + +(** Adds a binding to the context *) +let add_def_local_var (ctxt : context) (scope_uid : Uid.Scope.t) (def_uid : Uid.ScopeDef.t) + (name : ident Pos.marked) : context = + let scope_ctxt = Uid.ScopeMap.find scope_uid ctxt.scopes in + let def_ctx = Uid.ScopeDefMap.find def_uid scope_ctxt.definitions in + let local_var_uid = Uid.LocalVar.fresh name in + let def_ctx = + { var_idmap = Uid.IdentMap.add (Pos.unmark name) local_var_uid def_ctx.var_idmap } + in + let scope_ctxt = + { scope_ctxt with definitions = Uid.ScopeDefMap.add def_uid def_ctx scope_ctxt.definitions } + in + { ctxt with scopes = Uid.ScopeMap.add scope_uid scope_ctxt ctxt.scopes } + +(** Process a scope declaration *) +let process_scope_decl (ctxt : context) (decl : Catala_ast.scope_decl) : context = + let name, pos = decl.scope_decl_name in + (* Checks if the name is already used *) + match Uid.IdentMap.find_opt name ctxt.scope_idmap with + | Some use -> + Errors.raise_multispanned_error "scope name already used" + [ (Some "first use", Pos.get_position (Uid.Scope.get_info use)); (Some "second use", pos) ] + | None -> + let scope_uid = Uid.Scope.fresh (name, pos) in + let ctxt = + { + ctxt with + scope_idmap = Uid.IdentMap.add name scope_uid ctxt.scope_idmap; + scopes = + Uid.ScopeMap.add scope_uid + { + var_idmap = Uid.IdentMap.empty; + sub_scopes_idmap = Uid.IdentMap.empty; + definitions = Uid.ScopeDefMap.empty; + sub_scopes = Uid.SubScopeMap.empty; + } + ctxt.scopes; + } + in + List.fold_left + (fun ctxt item -> process_item_decl scope_uid ctxt (Pos.unmark item)) + ctxt decl.scope_decl_context + +let qident_to_scope_def (ctxt : context) (scope_uid : Uid.Scope.t) + (id : Catala_ast.qident Pos.marked) : Uid.ScopeDef.t = + let scope_ctxt = Uid.ScopeMap.find scope_uid ctxt.scopes in + match Pos.unmark id with + | [ x ] -> ( + match Uid.IdentMap.find_opt (Pos.unmark x) scope_ctxt.var_idmap with + | None -> raise_unknown_identifier "for a var of the scope" x + | Some id -> Uid.ScopeDef.Var id ) + | [ s; x ] -> ( + let sub_scope_uid = + match Uid.IdentMap.find_opt (Pos.unmark s) scope_ctxt.sub_scopes_idmap with + | None -> raise_unknown_identifier "for a subscope of this scope" s + | Some id -> id + in + let real_sub_scope_uid = Uid.SubScopeMap.find sub_scope_uid scope_ctxt.sub_scopes in + let sub_scope_ctx = Uid.ScopeMap.find real_sub_scope_uid ctxt.scopes in + match Uid.IdentMap.find_opt (Pos.unmark x) sub_scope_ctx.var_idmap with + | None -> raise_unknown_identifier "for a var of this subscope" x + | Some id -> Uid.ScopeDef.SubScopeVar (sub_scope_uid, id) ) + | _ -> raise_unsupported_feature "wrong qident" (Pos.get_position id) + +let process_scope_use (ctxt : context) (use : Catala_ast.scope_use) : context = + let scope_uid = + match Uid.IdentMap.find_opt (Pos.unmark use.scope_use_name) ctxt.scope_idmap with + | None -> raise_unknown_identifier "for a scope" use.scope_use_name + | Some id -> id + in + List.fold_left + (fun ctxt use_item -> + match Pos.unmark use_item with + | Catala_ast.Definition def -> + let scope_ctxt = Uid.ScopeMap.find scope_uid ctxt.scopes in + let def_uid = qident_to_scope_def ctxt scope_uid def.definition_name in + let def_ctxt = + { + var_idmap = + ( match def.definition_parameter with + | None -> Uid.IdentMap.empty + | Some param -> Uid.IdentMap.singleton (Pos.unmark param) (Uid.LocalVar.fresh param) + ); + } + in + let scope_ctxt = + { + scope_ctxt with + definitions = Uid.ScopeDefMap.add def_uid def_ctxt scope_ctxt.definitions; + } + in + { ctxt with scopes = Uid.ScopeMap.add scope_uid scope_ctxt ctxt.scopes } + | _ -> raise_unsupported_feature "unsupported item" (Pos.get_position use_item)) + ctxt use.scope_use_items + +(** Process a code item : for now it only handles scope decls *) +let process_use_item (ctxt : context) (item : Catala_ast.code_item Pos.marked) : context = + match Pos.unmark item with + | ScopeDecl _ -> ctxt + | ScopeUse use -> process_scope_use ctxt use + | _ -> raise_unsupported_feature "item not supported" (Pos.get_position item) + +(** Process a code item : for now it only handles scope decls *) +let process_decl_item (ctxt : context) (item : Catala_ast.code_item Pos.marked) : context = + match Pos.unmark item with ScopeDecl decl -> process_scope_decl ctxt decl | _ -> ctxt + +(** Process a code block *) +let process_code_block (ctxt : context) (block : Catala_ast.code_block) + (process_item : context -> Catala_ast.code_item Pos.marked -> context) : context = + List.fold_left (fun ctxt decl -> process_item ctxt decl) ctxt block + +(** Process a program item *) +let process_program_item (ctxt : context) (item : Catala_ast.program_item) + (process_item : context -> Catala_ast.code_item Pos.marked -> context) : context = + match item with + | CodeBlock (block, _) | MetadataBlock (block, _) -> process_code_block ctxt block process_item + | _ -> ctxt + +(** Derive the context from metadata, in two passes *) +let form_context (prgm : Catala_ast.program) : context = + let empty_ctxt = + { scope_idmap = Uid.IdentMap.empty; scopes = Uid.ScopeMap.empty; var_typs = Uid.VarMap.empty } + in + let ctxt = + List.fold_left + (fun ctxt item -> process_program_item ctxt item process_decl_item) + empty_ctxt prgm.program_items + in + List.fold_left + (fun ctxt item -> process_program_item ctxt item process_use_item) + ctxt prgm.program_items + +(** Get the variable uid inside the scope given in argument *) +let get_var_uid (scope_uid : Uid.Scope.t) (ctxt : context) ((x, pos) : ident Pos.marked) : Uid.Var.t + = + let scope = Uid.ScopeMap.find scope_uid ctxt.scopes in + match Uid.IdentMap.find_opt x scope.var_idmap with + | None -> raise_unknown_identifier "for a var of this scope" (x, pos) + | Some uid -> uid + +(** Get the subscope uid inside the scope given in argument *) +let get_subscope_uid (scope_uid : Uid.Scope.t) (ctxt : context) ((y, pos) : ident Pos.marked) : + Uid.SubScope.t = + let scope = Uid.ScopeMap.find scope_uid ctxt.scopes in + match Uid.IdentMap.find_opt y scope.sub_scopes_idmap with + | None -> raise_unknown_identifier "for a subscope of this scope" (y, pos) + | Some sub_uid -> sub_uid + +(** Checks if the var_uid belongs to the scope scope_uid *) +let belongs_to (ctxt : context) (uid : Uid.Var.t) (scope_uid : Uid.Scope.t) : bool = + let scope = Uid.ScopeMap.find scope_uid ctxt.scopes in + Uid.IdentMap.exists (fun _ var_uid -> Uid.Var.compare uid var_uid = 0) scope.var_idmap + +let get_def_typ (ctxt : context) (def : Uid.ScopeDef.t) : typ = + match def with + | Uid.ScopeDef.SubScopeVar (_, x) + (* we don't need to look at the subscope prefix because [x] is already the uid referring back to + the original subscope *) + | Uid.ScopeDef.Var x -> + Uid.VarMap.find x ctxt.var_typs diff --git a/src/catala/cli.ml b/src/catala/cli.ml index ce74b8f6..b3591efc 100644 --- a/src/catala/cli.ml +++ b/src/catala/cli.ml @@ -164,8 +164,8 @@ let debug_print (s : string) = let error_print (s : string) = Printf.eprintf "%s\n" (add_prefix_to_each_line s (fun _ -> error_marker ())); - flush stdout; - flush stdout + flush stderr; + flush stderr let warning_print (s : string) = Printf.printf "%s\n" (add_prefix_to_each_line s (fun _ -> warning_marker ())); diff --git a/src/catala/driver.ml b/src/catala/driver.ml index 0e9937f3..9bd6a504 100644 --- a/src/catala/driver.ml +++ b/src/catala/driver.ml @@ -100,7 +100,7 @@ let driver (source_file : string) (debug : bool) (unstyled : bool) (wrap_weaved_ match ex_scope with | None -> Errors.raise_error "No scope was provided for execution." | Some name -> ( - match Name_resolution.IdentMap.find_opt name ctxt.scope_id_to_uid with + match Uid.IdentMap.find_opt name ctxt.scope_idmap with | None -> Errors.raise_error (Printf.sprintf "There is no scope %s inside the program." name) @@ -108,21 +108,23 @@ let driver (source_file : string) (debug : bool) (unstyled : bool) (wrap_weaved_ in let prgm = Desugaring.translate_program_to_scope ctxt program in let scope = - match Uid.UidMap.find_opt scope_uid prgm with + match Uid.ScopeMap.find_opt scope_uid prgm with | Some scope -> scope | None -> + let scope_info = Uid.Scope.get_info scope_uid in Errors.raise_spanned_error (Printf.sprintf "Scope %s does not define anything, and therefore cannot be executed" - (Uid.get_ident scope_uid)) - (Uid.get_pos scope_uid) + (Pos.unmark scope_info)) + (Pos.get_position scope_info) in let exec_ctxt = Scope_interpreter.execute_scope ctxt prgm scope in - Uid.UidMap.iter - (fun uid value -> + Lambda_interpreter.ExecContext.iter + (fun context_key value -> Cli.result_print - (Printf.sprintf "%s -> %s" (Uid.get_ident uid) - (Format_lambda.print_term ((value, Uid.get_pos uid), TDummy)))) + (Printf.sprintf "%s -> %s" + (Lambda_interpreter.ExecContextKey.format_t context_key) + (Format_lambda.print_term ((value, Pos.no_pos), TDummy)))) exec_ctxt; 0 with Errors.StructuredError (msg, pos) -> diff --git a/src/catala/lambda_calculus/format_lambda.ml b/src/catala/lambda_calculus/format_lambda.ml index 7f06ccca..8616a08e 100644 --- a/src/catala/lambda_calculus/format_lambda.ml +++ b/src/catala/lambda_calculus/format_lambda.ml @@ -12,43 +12,17 @@ or implied. See the License for the specific language governing permissions and limitations under the License. *) -module UidMap = Uid.UidMap module IdentMap = Map.Make (String) -(** Print the context in a readable manner *) -let print_context (ctxt : Name_resolution.context) : string = - let rec typ_to_string = function - | Lambda_ast.TBool -> "bool" - | Lambda_ast.TInt -> "num" - | Lambda_ast.TDummy -> "(.)" - | Lambda_ast.TArrow (t1, t2) -> - Printf.sprintf "%s -> (%s)" (typ_to_string t1) (typ_to_string t2) - in - let print_var ((var_id, var_uid) : Uid.ident * Uid.t) : string = - let data = UidMap.find var_uid ctxt.data in - let info = - match data.uid_sort with - | IdScope -> "\tscope" - | IdScopeVar None -> Printf.sprintf "\ttyp : %s\tvar" (typ_to_string data.uid_typ) - | IdScopeVar (Some _) -> Printf.sprintf "\ttyp : %s\tfun" (typ_to_string data.uid_typ) - | IdSubScope uid -> Printf.sprintf "\tsubscope : %d" uid - | IdSubScopeVar (var_uid, sub_scope_uid) -> - Printf.sprintf "\ttype : %s\tsubvar(%d, scope %d)" (typ_to_string data.uid_typ) var_uid - sub_scope_uid - | IdBinder -> Printf.sprintf "\ttyp : %s\tbinder" (typ_to_string data.uid_typ) - in - Printf.sprintf "%s (uid : %n)%s\n" var_id var_uid info - in - let print_scope ((scope_ident, scope_uid) : Uid.ident * Uid.t) : string = - Printf.sprintf "Scope %s (uid : %n):\n" scope_ident scope_uid - ^ ( (UidMap.find scope_uid ctxt.scopes).var_id_to_uid |> IdentMap.bindings |> List.map print_var - |> String.concat "" ) - ^ Printf.sprintf "\n" - in - ctxt.scope_id_to_uid |> IdentMap.bindings |> List.map print_scope |> String.concat "" - (* Printing functions for Lambda_ast.term *) +let rec format_typ (ty : Lambda_ast.typ) : string = + match ty with + | TBool -> "bool" + | TInt -> "int" + | TArrow (t1, t2) -> Format.sprintf "(%s) -> (%s)" (format_typ t1) (format_typ t2) + | TDummy -> "??" + (** Operator printer *) let print_op (op : Lambda_ast.op) : string = match op with @@ -69,14 +43,25 @@ let print_op (op : Lambda_ast.op) : string = | Unop Not -> "not" | Unop Minus -> "-" +let rec repeat_string n s = if n = 0 then "" else s ^ repeat_string (n - 1) s + +let print_prefix (prefix : Lambda_ast.var_prefix) : string = + match prefix with + | NoPrefix -> "" + | SubScopePrefix s -> Uid.SubScope.format_t s ^ "." + | CallerPrefix (i, s) -> ( + repeat_string i "CALLER." + ^ match s with None -> "" | Some s -> Uid.SubScope.format_t s ^ "." ) + (** Print Lambda_ast.term *) let rec print_term (((t, _), _) : Lambda_ast.term) : string = match t with - | EVar uid -> Printf.sprintf "%s(%d)" (Uid.get_ident uid) uid + | EVar (s, uid) -> Printf.sprintf "%s%s" (print_prefix s) (Uid.Var.format_t uid) + | ELocalVar uid -> Uid.LocalVar.format_t uid | EFun (binders, body) -> let sbody = print_term body in Printf.sprintf "fun %s -> %s" - (binders |> List.map (fun (x, _) -> Printf.sprintf "%d" x) |> String.concat " ") + (binders |> List.map (fun (uid, _) -> Uid.LocalVar.format_t uid) |> String.concat " ") sbody | EApp (f, args) -> Printf.sprintf "(%s) [%s]" (print_term f) (args |> List.map print_term |> String.concat ";") @@ -89,7 +74,11 @@ let rec print_term (((t, _), _) : Lambda_ast.term) : string = | EDefault t -> print_default_term t and print_default_term (term : Lambda_ast.default_term) : string = - term.defaults |> Lambda_ast.IntMap.bindings - |> List.map (fun (_, (cond, body)) -> - Printf.sprintf "\t%s => %s" (print_term cond) (print_term body)) - |> String.concat "\n" + ( term.defaults + |> List.mapi (fun i (cond, body) -> + Printf.sprintf "[%d]\t%s => %s" i (print_term cond) (print_term body)) + |> String.concat "\n" ) + ^ "\n" + ^ ( term.ordering + |> List.map (fun (hi, lo) -> Printf.sprintf "%d > %d" hi lo) + |> String.concat ", " ) diff --git a/src/catala/lambda_calculus/lambda_ast.ml b/src/catala/lambda_calculus/lambda_ast.ml index 6aaa000f..fe8263e4 100644 --- a/src/catala/lambda_calculus/lambda_ast.ml +++ b/src/catala/lambda_calculus/lambda_ast.ml @@ -12,10 +12,6 @@ or implied. See the License for the specific language governing permissions and limitations under the License. *) -module UidMap = Uid.UidMap -module UidSet = Uid.UidSet -module IntMap = Map.Make (Int) - (* TDummy means the term is not typed *) type typ = TBool | TInt | TArrow of typ * typ | TDummy @@ -27,14 +23,21 @@ type unop = Catala_ast.unop type op = Binop of binop | Unop of unop -type binding = Uid.t * typ +type binding = Uid.LocalVar.t * typ -(*type enum_case = uid*) +type var_prefix = + (* See [Scope_interpreter] for details about the meaning of this case. The `int` means the number + of times you have to go to the parent caller to get the variable *) + | CallerPrefix of int * Uid.SubScope.t option + | NoPrefix + | SubScopePrefix of Uid.SubScope.t type term = untyped_term Pos.marked * typ and untyped_term = - | EVar of Uid.t + | EVar of var_prefix * Uid.Var.t + (** This case is only for terms embedded in the scope language *) + | ELocalVar of Uid.LocalVar.t | EFun of binding list * term | EApp of term * term list | EIfThenElse of term * term * term @@ -46,11 +49,7 @@ and untyped_term = (* (x,y) in ordering means that default x has precedence over default y : if both are true then x would be choser over y *) -and default_term = { - defaults : (term * term) IntMap.t; - ordering : (int * int) list; - nb_defaults : int; -} +and default_term = { defaults : (term * term) list; ordering : (int * int) list } let untype (((term, _), _) : term) : untyped_term = term @@ -65,22 +64,16 @@ let map_untype2 (f : untyped_term -> untyped_term -> untyped_term) (((t1, pos), (((t2, _), _) : term) : term = ((f t1 t2, pos), typ) -let empty_default_term : default_term = { defaults = IntMap.empty; ordering = []; nb_defaults = 0 } +let empty_default_term : default_term = { defaults = []; ordering = [] } let add_default (just : term) (cons : term) (term : default_term) = - { - term with - defaults = IntMap.add term.nb_defaults (just, cons) term.defaults; - nb_defaults = term.nb_defaults + 1; - } + { term with defaults = term.defaults @ [ (just, cons) ] } (** Merge two defalts terms, taking into account that one has higher precedence than the other *) let merge_default_terms (lo_term : default_term) (hi_term : default_term) : default_term = - let n = lo_term.nb_defaults in - let n' = hi_term.nb_defaults in - let defaults = - IntMap.fold (fun k default -> IntMap.add (n + k) default) hi_term.defaults lo_term.defaults - in + let n = List.length lo_term.defaults in + let n' = List.length hi_term.defaults in + let defaults = lo_term.defaults @ hi_term.defaults in let rec add_hi_prec = function | [] -> lo_term.ordering | (k, k') :: xs -> (n + k, n + k') :: add_hi_prec xs @@ -97,25 +90,31 @@ let merge_default_terms (lo_term : default_term) (hi_term : default_term) : defa let rec gen_list i j acc = if i = j then acc else gen_list (i + 1) j (i :: acc) in let gen_list i j = gen_list i j [] in let prec' = gen_prec (gen_list 0 n) (gen_list n (n + n')) in - { defaults; ordering = prec @ prec'; nb_defaults = n + n' } + { defaults; ordering = prec @ prec' } -(** Returns the free variables of a term *) -let rec term_fv (term : term) : UidSet.t = +(** Returns the free variables (scope language variables) of a term. Used to build the dependency + graph *) +let rec term_fv (term : term) : Uid.ScopeDefSet.t = match untype term with - | EVar uid -> UidSet.singleton uid - | EFun (bindings, body) -> - let body_fv = term_fv body in - let bindings = bindings |> List.map (fun (x, _) -> x) |> UidSet.of_list in - UidSet.diff body_fv bindings - | EApp (f, args) -> List.fold_left (fun fv arg -> UidSet.union fv (term_fv arg)) (term_fv f) args + | EVar (NoPrefix, uid) -> Uid.ScopeDefSet.singleton (Uid.ScopeDef.Var uid) + | EVar (SubScopePrefix sub_uid, uid) -> + Uid.ScopeDefSet.singleton (Uid.ScopeDef.SubScopeVar (sub_uid, uid)) + | EVar (CallerPrefix _, _) -> + Uid.ScopeDefSet.empty + (* here we return an empty dependency because when calling a subscope, the variables of the + caller graph needed for it are already computed *) + | ELocalVar _ -> Uid.ScopeDefSet.empty + | EFun (_, body) -> term_fv body + | EApp (f, args) -> + List.fold_left (fun fv arg -> Uid.ScopeDefSet.union fv (term_fv arg)) (term_fv f) args | EIfThenElse (t_if, t_then, t_else) -> - UidSet.union (term_fv t_if) (UidSet.union (term_fv t_then) (term_fv t_else)) + Uid.ScopeDefSet.union (term_fv t_if) (Uid.ScopeDefSet.union (term_fv t_then) (term_fv t_else)) | EDefault default -> default_term_fv default - | _ -> UidSet.empty + | EBool _ | EInt _ | EDec _ | EOp _ -> Uid.ScopeDefSet.empty -and default_term_fv (term : default_term) : UidSet.t = - IntMap.fold - (fun _ (cond, body) -> - let fv = UidSet.union (term_fv cond) (term_fv body) in - UidSet.union fv) - term.defaults UidSet.empty +and default_term_fv (term : default_term) : Uid.ScopeDefSet.t = + List.fold_left + (fun acc (cond, body) -> + let fv = Uid.ScopeDefSet.union (term_fv cond) (term_fv body) in + Uid.ScopeDefSet.union fv acc) + Uid.ScopeDefSet.empty term.defaults diff --git a/src/catala/lambda_calculus/lambda_interpreter.ml b/src/catala/lambda_calculus/lambda_interpreter.ml index 5515133b..abdd7e06 100644 --- a/src/catala/lambda_calculus/lambda_interpreter.ml +++ b/src/catala/lambda_calculus/lambda_interpreter.ml @@ -18,24 +18,55 @@ type uid = int type scope_uid = int -module UidMap = Uid.UidMap -module UidSet = Uid.UidSet +module ExecContextKey = struct + type t = LocalVar of Uid.LocalVar.t | ScopeVar of var_prefix * Uid.Var.t -type exec_context = Lambda_ast.untyped_term UidMap.t + let compare x y = + match (x, y) with + | LocalVar x, LocalVar y -> Uid.LocalVar.compare x y + | ScopeVar (x1, x2), ScopeVar (y1, y2) -> ( + match (x1, y1) with + | NoPrefix, NoPrefix | CallerPrefix _, CallerPrefix _ -> Uid.Var.compare x2 y2 + | SubScopePrefix x1, SubScopePrefix y1 -> + let sub_comp = Uid.SubScope.compare x1 y1 in + if sub_comp = 0 then Uid.Var.compare x2 y2 else sub_comp + | _ -> compare x y ) + | _ -> compare x y -let empty_exec_ctxt = UidMap.empty + let format_t (x : t) : string = + match x with + | LocalVar x -> Uid.LocalVar.format_t x + | ScopeVar (prefix, var) -> Format_lambda.print_prefix prefix ^ Uid.Var.format_t var +end + +module ExecContext = Map.Make (ExecContextKey) + +type exec_context = Lambda_ast.untyped_term ExecContext.t + +let format_exec_context (ctx : exec_context) = + String.concat "\n" + (List.map + (fun (key, value) -> + Printf.sprintf "%s -> %s" (ExecContextKey.format_t key) + (Format_lambda.print_term ((value, Pos.no_pos), TDummy))) + (ExecContext.bindings ctx)) + +let empty_exec_ctxt = ExecContext.empty + +let raise_default_conflict (def : Uid.ScopeDef.t) (true_pos : Pos.t list) (false_pos : Pos.t list) = + let var_str = Uid.ScopeDef.format_t def in + let var_pos = + match def with + | Uid.ScopeDef.SubScopeVar (_, v) | Uid.ScopeDef.Var v -> Pos.get_position (Uid.Var.get_info v) + in -let raise_default_conflict (id : string Pos.marked) (true_pos : Pos.t list) (false_pos : Pos.t list) - = if List.length true_pos = 0 then let justifications : (string option * Pos.t) list = List.map (fun pos -> (Some "This justification is false:", pos)) false_pos in Errors.raise_multispanned_error - (Printf.sprintf "Default logic error for variable %s: no justification is true." - (Pos.unmark id)) - ( ( Some (Printf.sprintf "The error concerns this variable %s" (Pos.unmark id)), - Pos.get_position id ) + (Printf.sprintf "Default logic error for variable %s: no justification is true." var_str) + ( (Some (Printf.sprintf "The error concerns this variable %s" var_str), var_pos) :: justifications ) else let justifications : (string option * Pos.t) list = @@ -43,22 +74,30 @@ let raise_default_conflict (id : string Pos.marked) (true_pos : Pos.t list) (fal in Errors.raise_multispanned_error "Default logic conflict, multiple justifications are true but are not related by a precedence" - ( ( Some (Printf.sprintf "The conflict concerns this variable %s" (Pos.unmark id)), - Pos.get_position id ) + ( (Some (Printf.sprintf "The conflict concerns this variable %s" var_str), var_pos) :: justifications ) -let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda_ast.term) : +let rec eval_term (top_uid : Uid.ScopeDef.t) (exec_ctxt : exec_context) (term : Lambda_ast.term) : Lambda_ast.term = let (term, pos), typ = term in let evaled_term = match term with | EFun _ | EInt _ | EDec _ | EBool _ | EOp _ -> term (* already a value *) - | EVar uid -> ( - match UidMap.find_opt uid exec_ctxt with + | ELocalVar uid -> ( + let ctxt_key = ExecContextKey.LocalVar uid in + match ExecContext.find_opt ctxt_key exec_ctxt with | Some t -> t | None -> Errors.raise_spanned_error - (Printf.sprintf "Variable %s is not defined" (Uid.get_ident uid)) + (Printf.sprintf "Local Variable %s is not defined" (Uid.LocalVar.format_t uid)) + pos ) + | EVar (prefix, uid) -> ( + let ctxt_key = ExecContextKey.ScopeVar (prefix, uid) in + match ExecContext.find_opt ctxt_key exec_ctxt with + | Some t -> t + | None -> + Errors.raise_spanned_error + (Printf.sprintf "Variable %s is not defined" (Uid.Var.format_t uid)) pos ) | EApp (f, args) -> ( (* First evaluate and match the function body *) @@ -68,7 +107,9 @@ let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda_as let exec_ctxt = List.fold_left2 (fun ctxt arg (uid, _) -> - UidMap.add uid (arg |> eval_term top_uid exec_ctxt |> Lambda_ast.untype) ctxt) + ExecContext.add (ExecContextKey.LocalVar uid) + (arg |> eval_term top_uid exec_ctxt |> Lambda_ast.untype) + ctxt) exec_ctxt args bindings in eval_term top_uid exec_ctxt body |> Lambda_ast.untype @@ -82,11 +123,13 @@ let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda_as | And | Or -> let b1, b2 = match args with [ EBool b1; EBool b2 ] -> (b1, b2) | _ -> assert false + (* should not happen *) in EBool (if binop = And then b1 && b2 else b1 || b2) | _ -> ( let i1, i2 = match args with [ EInt i1; EInt i2 ] -> (i1, i2) | _ -> assert false + (* should not happen *) in match binop with | Add | Sub | Mult | Div -> @@ -97,6 +140,7 @@ let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda_as | Mult -> ( * ) | Div -> ( / ) | _ -> assert false + (* should not happen *) in EInt (op_arith i1 i2) | _ -> @@ -109,40 +153,50 @@ let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda_as | Eq -> ( = ) | Neq -> ( <> ) | _ -> assert false + (* should not happen *) in EBool (op_comp i1 i2) ) ) - | Unop Minus -> ( match args with [ EInt i ] -> EInt (-i) | _ -> assert false ) - | Unop Not -> ( match args with [ EBool b ] -> EBool (not b) | _ -> assert false ) ) + | Unop Minus -> ( + match args with + | [ EInt i ] -> EInt (-i) + | _ -> assert false (* should not happen *) ) + | Unop Not -> ( + match args with + | [ EBool b ] -> EBool (not b) + | _ -> assert false (* should not happen *) ) ) | _ -> assert false ) | EIfThenElse (t_if, t_then, t_else) -> ( match eval_term top_uid exec_ctxt t_if |> Lambda_ast.untype with | EBool b -> if b then eval_term top_uid exec_ctxt t_then else eval_term top_uid exec_ctxt t_else - | _ -> assert false ) + | _ -> assert false (* should not happen *) ) |> Lambda_ast.untype | EDefault t -> ( match eval_default_term top_uid exec_ctxt t with | Ok value -> value |> Lambda_ast.untype - | Error (true_pos, false_pos) -> - raise_default_conflict (Uid.get_ident top_uid, Uid.get_pos top_uid) true_pos false_pos ) + | Error (true_pos, false_pos) -> raise_default_conflict top_uid true_pos false_pos ) in ((evaled_term, pos), typ) (* Evaluates a default term : see the formalization for an insight about this operation *) -and eval_default_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda_ast.default_term) - : (Lambda_ast.term, Pos.t list * Pos.t list) result = +and eval_default_term (top_uid : Uid.ScopeDef.t) (exec_ctxt : exec_context) + (term : Lambda_ast.default_term) : (Lambda_ast.term, Pos.t list * Pos.t list) result = (* First filter out the term which justification are false *) - let candidates : 'a IntMap.t = - IntMap.filter - (fun _ (cond, _) -> + let defaults_numbered : (int * (term * term)) list = + List.mapi (fun (x : int) (y : term * term) -> (x, y)) term.defaults + in + let candidates : 'a list = + List.filter + (fun (_, (cond, _)) -> match eval_term top_uid exec_ctxt cond |> Lambda_ast.untype with | EBool b -> b - | _ -> assert false) - term.defaults + | _ -> assert false + (* should not happen *)) + defaults_numbered in (* Now filter out the terms that have a predecessor which justification is true *) let module ISet = Set.Make (Int) in - let key_candidates = IntMap.fold (fun x _ -> ISet.add x) candidates ISet.empty in + let key_candidates = List.fold_left (fun acc (x, _) -> ISet.add x acc) ISet.empty candidates in let chosen_one = List.fold_left (fun set (lo, hi) -> if ISet.mem lo set && ISet.mem hi set then ISet.remove hi set else set) @@ -150,14 +204,15 @@ and eval_default_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambd in match ISet.elements chosen_one with | [ x ] -> - let _, cons = IntMap.find x term.defaults in + let _, (_, cons) = List.find (fun (i, _) -> i = x) defaults_numbered in Ok (eval_term top_uid exec_ctxt cons) | xs -> let true_pos = - xs |> List.map (fun x -> IntMap.find x term.defaults |> fst |> Lambda_ast.get_pos) + xs + |> List.map (fun x -> + List.find (fun (i, _) -> i = x) defaults_numbered |> snd |> fst |> Lambda_ast.get_pos) in let false_pos : Pos.t list = - let bindings : (scope_uid * (term * term)) list = term.defaults |> IntMap.bindings in - List.map (fun (_, (cond, _)) -> Lambda_ast.get_pos cond) bindings + List.map (fun (_, (cond, _)) -> Lambda_ast.get_pos cond) defaults_numbered in Error (true_pos, false_pos) diff --git a/src/catala/lambda_calculus/lambda_typechecker.ml b/src/catala/lambda_calculus/lambda_typechecker.ml new file mode 100644 index 00000000..4f7885f0 --- /dev/null +++ b/src/catala/lambda_calculus/lambda_typechecker.ml @@ -0,0 +1,138 @@ +(* This file is part of the Catala compiler, a specification language for tax and social benefits + computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing + + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + in compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License + is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or implied. See the License for the specific language governing permissions and limitations under + the License. *) + +(** Checks that a term is well typed and annotate it *) +let rec type_term (ctxt : Name_resolution.context) (local_ctx : Lambda_ast.typ Uid.LocalVarMap.t) + (((t, pos), _) : Lambda_ast.term) : Lambda_ast.term = + match t with + | EVar (s, uid) -> + (* so here we can ignore the subscope uid because the uid of the variable already corresponds + to the uid of the variable in its original scope *) + let typ = Name_resolution.get_var_typ ctxt uid in + ((EVar (s, uid), pos), typ) + | ELocalVar uid -> + let typ = Uid.LocalVarMap.find uid local_ctx in + ((ELocalVar uid, pos), typ) + | EFun (bindings, body) -> + let local_ctx = + List.fold_left + (fun local_ctx (binding, ty) -> Uid.LocalVarMap.add binding ty local_ctx) + local_ctx bindings + in + let body = type_term ctxt local_ctx body in + let ret_typ = Lambda_ast.get_typ body in + let rec build_typ = function + | [] -> ret_typ + | (_, arg_t) :: args -> TArrow (arg_t, build_typ args) + in + let fun_typ = build_typ bindings in + ((EFun (bindings, body), pos), fun_typ) + | EApp (f, args) -> + let f = type_term ctxt local_ctx f in + let f_typ = Lambda_ast.get_typ f in + let args = List.map (type_term ctxt local_ctx) args in + let args_typ = + List.map (fun arg -> (Lambda_ast.get_typ arg, Pos.get_position (fst arg))) args + in + let rec check_arrow_typ f_typ args_typ = + match (f_typ, args_typ) with + | typ, [] -> typ + | Lambda_ast.TArrow (arg_typ, ret_typ), fst_typ :: typs -> + let fst_typ_s = Pos.unmark fst_typ in + if arg_typ = fst_typ_s then check_arrow_typ ret_typ typs + else + Errors.raise_multispanned_error "error when comparing types of function arguments" + [ + ( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)), + Pos.get_position (fst f) ); + ( Some (Printf.sprintf "got type %s" (Format_lambda.format_typ fst_typ_s)), + Pos.get_position fst_typ ); + ] + | _ -> + Errors.raise_multispanned_error "wrong number of arguments for function call" + [ + ( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)), + Pos.get_position (fst f) ); + ( Some + (Printf.sprintf "got type %s" + (String.concat " -> " + (List.map (fun (ty, _) -> Format_lambda.format_typ ty) args_typ))), + Pos.get_position (List.hd args_typ) ); + ] + in + let ret_typ = check_arrow_typ f_typ args_typ in + ((EApp (f, args), pos), ret_typ) + | EIfThenElse (t_if, t_then, t_else) -> + let t_if = type_term ctxt local_ctx t_if in + let typ_if = Lambda_ast.get_typ t_if in + let t_then = type_term ctxt local_ctx t_then in + let typ_then = Lambda_ast.get_typ t_then in + let t_else = type_term ctxt local_ctx t_else in + let typ_else = Lambda_ast.get_typ t_else in + if typ_if <> TBool then + Errors.raise_spanned_error + (Format.sprintf "expecting type bool, got type %s" (Format_lambda.format_typ typ_if)) + (Pos.get_position (fst t_if)) + else if typ_then <> typ_else then + Errors.raise_multispanned_error + "expecting same types for the true and false branches of the conditional" + [ + ( Some (Format.sprintf "the true branch has type %s" (Format_lambda.format_typ typ_then)), + Pos.get_position (fst t_then) ); + ( Some + (Format.sprintf "the false branch has type %s" (Format_lambda.format_typ typ_else)), + Pos.get_position (fst t_else) ); + ] + else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then) + | EInt _ | EDec _ -> ((t, pos), TInt) + | EBool _ -> ((t, pos), TBool) + | EOp op -> + let typ = + match op with + | Binop binop -> ( + match binop with + | And | Or -> Lambda_ast.TArrow (TBool, TArrow (TBool, TBool)) + | Add | Sub | Mult | Div -> TArrow (TInt, TArrow (TInt, TInt)) + | Lt | Lte | Gt | Gte | Eq | Neq -> TArrow (TInt, TArrow (TInt, TBool)) ) + | Unop Minus -> TArrow (TInt, TInt) + | Unop Not -> TArrow (TBool, TBool) + in + ((t, pos), typ) + | EDefault t -> + let defaults = + List.map + (fun (just, cons) -> + let just_t = type_term ctxt local_ctx just in + if Lambda_ast.get_typ just_t <> TBool then + let cons = type_term ctxt local_ctx cons in + (just_t, cons) + else + Errors.raise_spanned_error + (Format.sprintf "expected type of default condition to be bool, got %s" + (Format_lambda.format_typ (Lambda_ast.get_typ just))) + (Pos.get_position (fst just))) + t.defaults + in + let typ_cons = List.hd defaults |> snd |> snd in + List.iter + (fun (_, cons) -> + if Lambda_ast.get_typ cons <> typ_cons then + Errors.raise_spanned_error + (Format.sprintf "expected default condition to be of type %s, got type %s" + (Format_lambda.format_typ (Lambda_ast.get_typ cons)) + (Format_lambda.format_typ typ_cons)) + (Pos.get_position (fst cons)) + else ()) + defaults; + ((EDefault { t with defaults }, pos), typ_cons) diff --git a/src/catala/scope_language/format_scope.ml b/src/catala/scope_language/format_scope.ml index f96410d7..89d3491e 100644 --- a/src/catala/scope_language/format_scope.ml +++ b/src/catala/scope_language/format_scope.ml @@ -12,27 +12,21 @@ or implied. See the License for the specific language governing permissions and limitations under the License. *) -module UidMap = Uid.UidMap - (** Print a scope program *) let print_scope (scope : Scope_ast.scope) : string = - let print_defs (defs : Scope_ast.definition UidMap.t) : string = - defs |> UidMap.bindings + let print_defs (defs : Scope_ast.definition Uid.ScopeDefMap.t) : string = + defs |> Uid.ScopeDefMap.bindings |> List.map (fun (uid, term) -> - Printf.sprintf "%s:\n%s" (Uid.get_ident uid) (Format_lambda.print_term term)) + Printf.sprintf "%s:\n%s" (Uid.ScopeDef.format_t uid) (Format_lambda.print_term term)) |> String.concat "" in "___Variables Definition___\n" ^ print_defs scope.scope_defs ^ "___Subscope (Re)definition___\n" - ^ ( scope.scope_sub_defs |> UidMap.bindings - |> List.map (fun (scope_uid, defs) -> - Printf.sprintf "__%s__:\n%s" (Uid.get_ident scope_uid) (print_defs defs)) - |> String.concat "" ) ^ "\n" (** Print the whole program *) let print_program (prgm : Scope_ast.program) : string = - prgm |> UidMap.bindings + prgm |> Uid.ScopeMap.bindings |> List.map (fun (uid, scope) -> - Printf.sprintf "Scope %s:\n%s" (Uid.get_ident uid) (print_scope scope)) + Printf.sprintf "Scope %s:\n%s" (Uid.Scope.format_t uid) (print_scope scope)) |> String.concat "\n" |> Printf.sprintf "Scope program\n%s" diff --git a/src/catala/scope_language/name_resolution.ml b/src/catala/scope_language/name_resolution.ml deleted file mode 100644 index 75aae48f..00000000 --- a/src/catala/scope_language/name_resolution.ml +++ /dev/null @@ -1,277 +0,0 @@ -(* This file is part of the Catala compiler, a specification language for tax and social benefits - computation rules. Copyright (C) 2020 Inria, contributor: Nicolas Chataing - - - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - in compliance with the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software distributed under the License - is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - or implied. See the License for the specific language governing permissions and limitations under - the License. *) - -type uid = Uid.t - -type scope_uid = Uid.t - -type var_uid = Uid.t - -type sub_scope_uid = Uid.t - -module UidMap = Uid.UidMap -module UidSet = Uid.UidSet - -type ident = string - -module IdentMap = Map.Make (String) - -type typ = Lambda_ast.typ - -type sort = - | IdScope - | IdScopeVar of uid option - | IdSubScope of uid - | IdSubScopeVar of var_uid * sub_scope_uid - | IdBinder - -type uid_data = { uid_typ : typ; uid_sort : sort } - -type scope_context = { var_id_to_uid : uid IdentMap.t; uid_set : UidSet.t } - -type context = { - scope_id_to_uid : uid IdentMap.t; - scopes : scope_context UidMap.t; - data : uid_data UidMap.t; -} - -let subscope_ident (y : string) (x : string) : string = y ^ "::" ^ x - -let raise_unsupported_feature (msg : string) (pos : Pos.t) = - Errors.raise_spanned_error (Printf.sprintf "Unsupported feature: %s" msg) pos - -let raise_undefined_identifier (msg : string) (pos : Pos.t) = - Errors.raise_spanned_error (Printf.sprintf "Undefined identifier: %s" msg) pos - -let raise_unknown_identifier (msg : string) (pos : Pos.t) = - Errors.raise_spanned_error (Printf.sprintf "Unknown identifier: %s" msg) pos - -(** Get the type associated to an uid *) -let get_uid_typ (ctxt : context) (uid : uid) : typ = (UidMap.find uid ctxt.data).uid_typ - -(** Get the sort associated to an uid *) -let get_uid_sort (ctxt : context) (uid : uid) : sort = (UidMap.find uid ctxt.data).uid_sort - -(** Process a subscope declaration *) -let process_subscope_decl (scope : uid) (ctxt : context) - (decl : Catala_ast.scope_decl_context_scope) : context = - let name, decl_pos = decl.scope_decl_context_scope_name in - let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in - (* First check that the designated subscope is a scope *) - let sub_uid = - match IdentMap.find_opt subscope ctxt.scope_id_to_uid with - | None -> raise_undefined_identifier subscope s_pos - | Some uid -> ( - match get_uid_sort ctxt uid with - | IdScope -> uid - | _ -> raise_undefined_identifier "..." s_pos ) - in - let scope_ctxt = UidMap.find scope ctxt.scopes in - let subscope_ctxt = UidMap.find sub_uid ctxt.scopes in - match IdentMap.find_opt name scope_ctxt.var_id_to_uid with - | Some _ -> assert false (* Variable is already used in this scope *) - | None -> - let sub_scope_uid = Uid.fresh name decl_pos in - let scope_ctxt = - { - var_id_to_uid = IdentMap.add name sub_scope_uid scope_ctxt.var_id_to_uid; - uid_set = UidSet.add sub_scope_uid scope_ctxt.uid_set; - } - in - let ctxt = - { - ctxt with - scopes = UidMap.add scope scope_ctxt ctxt.scopes; - data = - UidMap.add sub_scope_uid - { uid_typ = Lambda_ast.TDummy; uid_sort = IdSubScope sub_uid } - ctxt.data; - } - in - (* Now duplicate all variables from the subscope *) - IdentMap.fold - (fun sub_var sub_uid ctxt -> - let fresh_varname = subscope_ident name sub_var in - (* We use the same pos as the subscope declaration *) - let fresh_uid = Uid.fresh fresh_varname decl_pos in - let scope_ctxt = UidMap.find scope ctxt.scopes in - let scope_ctxt = - { - var_id_to_uid = IdentMap.add fresh_varname fresh_uid scope_ctxt.var_id_to_uid; - uid_set = UidSet.add fresh_uid scope_ctxt.uid_set; - } - in - let sub_data = UidMap.find sub_uid ctxt.data in - (* Add a reference to the subvar *) - let data = { sub_data with uid_sort = IdSubScopeVar (sub_uid, sub_scope_uid) } in - { - ctxt with - scopes = UidMap.add scope scope_ctxt ctxt.scopes; - data = UidMap.add fresh_uid data ctxt.data; - }) - subscope_ctxt.var_id_to_uid ctxt - -let process_base_typ ((typ, typ_pos) : Catala_ast.base_typ Pos.marked) : Lambda_ast.typ = - match typ with - | Catala_ast.Condition -> Lambda_ast.TBool - | Catala_ast.Data (Catala_ast.Collection _) -> raise_unsupported_feature "collection type" typ_pos - | Catala_ast.Data (Catala_ast.Optional _) -> raise_unsupported_feature "option type" typ_pos - | Catala_ast.Data (Catala_ast.Primitive prim) -> ( - match prim with - | Catala_ast.Integer | Catala_ast.Decimal | Catala_ast.Money | Catala_ast.Date -> - Lambda_ast.TInt - | Catala_ast.Boolean -> Lambda_ast.TBool - | Catala_ast.Text -> raise_unsupported_feature "text type" typ_pos - | Catala_ast.Named _ -> raise_unsupported_feature "struct or enum types" typ_pos ) - -let process_type ((typ, typ_pos) : Catala_ast.typ Pos.marked) : Lambda_ast.typ = - match typ with - | Catala_ast.Base base_typ -> process_base_typ (base_typ, typ_pos) - | Catala_ast.Func { arg_typ; return_typ } -> - Lambda_ast.TArrow (process_base_typ arg_typ, process_base_typ return_typ) - -(** Process data declaration *) -let process_data_decl (scope : uid) (ctxt : context) (decl : Catala_ast.scope_decl_context_data) : - context = - (* First check the type of the context data *) - let lambda_typ = process_type decl.scope_decl_context_item_typ in - let name, pos = decl.scope_decl_context_item_name in - let scope_ctxt = UidMap.find scope ctxt.scopes in - match IdentMap.find_opt name scope_ctxt.var_id_to_uid with - | Some _ -> (* Variable is already used in this scope *) assert false - | None -> ( - let uid = Uid.fresh name pos in - let scope_ctxt = - { - var_id_to_uid = IdentMap.add name uid scope_ctxt.var_id_to_uid; - uid_set = UidSet.add uid scope_ctxt.uid_set; - } - in - - match lambda_typ with - | TArrow (arg_typ, _) -> - (* We now can get a fresh uid for the data *) - let arg_uid = Uid.fresh (Printf.sprintf "ARG_OF(%s)" name) pos in - let arg_data = { uid_typ = arg_typ; uid_sort = IdBinder } in - let var_data = { uid_typ = lambda_typ; uid_sort = IdScopeVar (Some arg_uid) } in - let data = ctxt.data |> UidMap.add uid var_data |> UidMap.add arg_uid arg_data in - { ctxt with scopes = UidMap.add scope scope_ctxt ctxt.scopes; data } - | _ -> - { - ctxt with - scopes = UidMap.add scope scope_ctxt ctxt.scopes; - data = UidMap.add uid { uid_typ = lambda_typ; uid_sort = IdScopeVar None } ctxt.data; - } ) - -(** Process an item declaration *) -let process_item_decl (scope : uid) (ctxt : context) (decl : Catala_ast.scope_decl_context_item) : - context = - match decl with - | Catala_ast.ContextData data_decl -> process_data_decl scope ctxt data_decl - | Catala_ast.ContextScope sub_decl -> process_subscope_decl scope ctxt sub_decl - -(** Process a scope declaration *) -let process_scope_decl (ctxt : context) (decl : Catala_ast.scope_decl) : context = - let name, pos = decl.scope_decl_name in - (* Checks if the name is already used *) - match IdentMap.find_opt name ctxt.scope_id_to_uid with - | Some _ -> assert false - | None -> - let scope_uid = Uid.fresh name pos in - let ctxt = - { - scope_id_to_uid = IdentMap.add name scope_uid ctxt.scope_id_to_uid; - data = UidMap.add scope_uid { uid_typ = Lambda_ast.TDummy; uid_sort = IdScope } ctxt.data; - scopes = - UidMap.add scope_uid - { var_id_to_uid = IdentMap.empty; uid_set = UidSet.empty } - ctxt.scopes; - } - in - List.fold_left - (fun ctxt item -> process_item_decl scope_uid ctxt (Pos.unmark item)) - ctxt decl.scope_decl_context - -(** Process a code item : for now it only handles scope decls *) -let process_code_item (ctxt : context) (item : Catala_ast.code_item) : context = - match item with ScopeDecl decl -> process_scope_decl ctxt decl | _ -> ctxt - -(** Process a code block *) -let process_code_block (ctxt : context) (block : Catala_ast.code_block) : context = - List.fold_left (fun ctxt decl -> Pos.unmark decl |> process_code_item ctxt) ctxt block - -(** Process a program item *) -let process_program_item (ctxt : context) (item : Catala_ast.program_item) : context = - match item with - | CodeBlock (block, _) | MetadataBlock (block, _) -> process_code_block ctxt block - | _ -> ctxt - -(** Derive the context from metadata *) -let form_context (prgm : Catala_ast.program) : context = - let empty_ctxt = - { scope_id_to_uid = IdentMap.empty; scopes = UidMap.empty; data = UidMap.empty } - in - List.fold_left process_program_item empty_ctxt prgm.program_items - -(** Get the variable uid inside the scope given in argument *) -let get_var_uid (scope_uid : uid) (ctxt : context) ((x, pos) : ident Pos.marked) : uid = - let scope = UidMap.find scope_uid ctxt.scopes in - match IdentMap.find_opt x scope.var_id_to_uid with - | None -> raise_undefined_identifier x pos - | Some uid -> ( - (* Checks that the uid has sort IdScopeVar or IdScopeBinder *) - match get_uid_sort ctxt uid with - | IdScopeVar _ | IdBinder | IdSubScopeVar _ -> uid - | _ -> - let err_msg = Printf.sprintf "Identifier \"%s\" should be a variable, but it isn't" x in - Errors.raise_spanned_error err_msg pos ) - -(** Get the subscope uid inside the scope given in argument *) -let get_subscope_uid (scope_uid : uid) (ctxt : context) ((y, pos) : ident Pos.marked) : uid * uid = - let scope = UidMap.find scope_uid ctxt.scopes in - match IdentMap.find_opt y scope.var_id_to_uid with - | None -> raise_unknown_identifier y pos - | Some sub_uid -> ( - match get_uid_sort ctxt sub_uid with - | IdSubScope scope_ref -> (sub_uid, scope_ref) - | _ -> - let err_msg = Printf.sprintf "Identifier \"%s\" should be a subscope, but it isn't" y in - Errors.raise_spanned_error err_msg pos ) - -(** Checks if the var_uid belongs to the scope scope_uid *) -let belongs_to (ctxt : context) (uid : var_uid) (scope_uid : scope_uid) : bool = - let scope = UidMap.find scope_uid ctxt.scopes in - UidSet.mem uid scope.uid_set - -(** Adds a binding to the context *) -let add_binding (ctxt : context) (scope_uid : Uid.t) (fun_uid : Uid.t) - (bind_name : ident Pos.marked option) : context * Uid.t option = - match bind_name with - | None -> (ctxt, None) - | Some name -> - let name = Pos.unmark name in - let scope_ctxt = UidMap.find scope_uid ctxt.scopes in - let arg_uid = - match get_uid_sort ctxt fun_uid with - | IdScopeVar (Some arg_uid) -> arg_uid - | _ -> - Errors.raise_spanned_error - (Printf.sprintf "Var %s is supposed to be a function but it isn't" - (Uid.get_ident fun_uid)) - (Uid.get_pos fun_uid) - in - let scope_ctxt = - { scope_ctxt with var_id_to_uid = IdentMap.add name arg_uid scope_ctxt.var_id_to_uid } - in - ({ ctxt with scopes = UidMap.add scope_uid scope_ctxt ctxt.scopes }, Some arg_uid) diff --git a/src/catala/scope_language/scope_ast.ml b/src/catala/scope_language/scope_ast.ml index dd1a0137..7b596101 100644 --- a/src/catala/scope_language/scope_ast.ml +++ b/src/catala/scope_language/scope_ast.ml @@ -12,21 +12,12 @@ or implied. See the License for the specific language governing permissions and limitations under the License. *) -module UidMap = Uid.UidMap - (* Scopes *) -type binder = string Pos.marked +type binder = Uid.LocalVar.t type definition = Lambda_ast.term -let empty_func_def (bind : Uid.t) (pos : Pos.t) (typ : Lambda_ast.typ) : definition = - match typ with - | TArrow (t_arg, t_ret) -> - let body_term : Lambda_ast.term = ((EDefault Lambda_ast.empty_default_term, pos), t_ret) in - ((EFun ([ (bind, t_arg) ], body_term), pos), typ) - | _ -> assert false - -let empty_var_def (pos : Pos.t) (typ : Lambda_ast.typ) : definition = +let empty_def (pos : Pos.t) (typ : Lambda_ast.typ) : definition = ((EDefault Lambda_ast.empty_default_term, pos), typ) type assertion = Lambda_ast.term @@ -40,20 +31,18 @@ type meta_assertion = | VariesWith of Lambda_ast.term * variation_typ Pos.marked option type scope = { - scope_uid : Uid.t; - scope_defs : definition UidMap.t; - scope_sub_defs : definition UidMap.t UidMap.t; + scope_uid : Uid.Scope.t; + scope_defs : definition Uid.ScopeDefMap.t; scope_assertions : assertion list; - scope_meta_assertions : meta_assertion list UidMap.t; + scope_meta_assertions : meta_assertion list; } -let empty_scope (uid : Uid.t) : scope = +let empty_scope (uid : Uid.Scope.t) : scope = { scope_uid = uid; - scope_defs = UidMap.empty; - scope_sub_defs = UidMap.empty; + scope_defs = Uid.ScopeDefMap.empty; scope_assertions = []; - scope_meta_assertions = UidMap.empty; + scope_meta_assertions = []; } -type program = scope UidMap.t +type program = scope Uid.ScopeMap.t diff --git a/src/catala/scope_language/scope_interpreter.ml b/src/catala/scope_language/scope_interpreter.ml index 5dc06c17..81a7b4fc 100644 --- a/src/catala/scope_language/scope_interpreter.ml +++ b/src/catala/scope_language/scope_interpreter.ml @@ -12,155 +12,290 @@ or implied. See the License for the specific language governing permissions and limitations under the License. *) -module G = Graph.Pack.Digraph -module UidMap = Uid.UidMap -module UidSet = Uid.UidSet +(** The vertices of the scope dependency graph are either : -(** Returns the scheduling of the scope variables, if y is a subscope and x a variable of y, then we - have two different variable y.x(internal) and y.x(result) and the ordering y.x(internal) -> y -> - y.x(result) *) -let build_scope_schedule (ctxt : Name_resolution.context) (scope : Scope_ast.scope) : G.t = - let g = G.create ~size:100 () in + - the variables of the scope ; + - the subscopes of the scope. + + Indeed, during interpretation, subscopes are executed atomically. + + In the graph, x -> y if x is used in the definition of y. *) + +module Vertex = struct + type t = Var of Uid.Var.t | SubScope of Uid.SubScope.t + + let hash x = match x with Var x -> Uid.Var.hash x | SubScope x -> Uid.SubScope.hash x + + let compare = compare + + let equal x y = + match (x, y) with + | Var x, Var y -> Uid.Var.compare x y = 0 + | SubScope x, SubScope y -> Uid.SubScope.compare x y = 0 + | _ -> false + + let format_t (x : t) : string = + match x with Var v -> Uid.Var.format_t v | SubScope v -> Uid.SubScope.format_t v +end + +(** On the edges, the label is the expression responsible for the use of the variable *) +module Edge = struct + type t = Pos.t + + let compare = compare + + let default = Pos.no_pos +end + +module ScopeDependencies = Graph.Persistent.Digraph.ConcreteBidirectionalLabeled (Vertex) (Edge) +module TopologicalTraversal = Graph.Topological.Make (ScopeDependencies) + +module SCC = Graph.Components.Make (ScopeDependencies) +(** Tarjan's stongly connected components algorithm, provided by OCamlGraph *) + +(** Outputs an error in case of cycles. *) +let check_for_cycle (g : ScopeDependencies.t) : unit = + (* if there is a cycle, there will be an strongly connected component of cardinality > 1 *) + let sccs = SCC.scc_list g in + if List.length sccs < ScopeDependencies.nb_vertex g then + let scc = List.find (fun scc -> List.length scc > 1) sccs in + Errors.raise_multispanned_error "cyclic dependency dected between variables!" + (List.flatten + (List.map + (fun v -> + let var_str, var_info = + match v with + | Vertex.Var v -> (Uid.Var.format_t v, Uid.Var.get_info v) + | Vertex.SubScope v -> (Uid.SubScope.format_t v, Uid.SubScope.get_info v) + in + let succs = ScopeDependencies.succ_e g v in + let _, edge_pos, succ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in + let succ_str = + match succ with + | Vertex.Var v -> Uid.Var.format_t v + | Vertex.SubScope v -> Uid.SubScope.format_t v + in + [ + (Some ("cycle variable " ^ var_str ^ ", declared:"), Pos.get_position var_info); + ( Some ("used here in the definition of another cycle variable " ^ succ_str ^ ":"), + edge_pos ); + ]) + scc)) + +let build_scope_dependencies (scope : Scope_ast.scope) (ctxt : Name_resolution.context) : + ScopeDependencies.t = + let g = ScopeDependencies.empty in let scope_uid = scope.scope_uid in (* Add all the vertices to the graph *) - let vertices = - UidSet.fold - (fun uid verts -> - match Name_resolution.get_uid_sort ctxt uid with - | IdScopeVar _ | IdSubScope _ -> UidMap.add uid (G.V.create uid) verts - | _ -> verts) - (UidMap.find scope_uid ctxt.scopes).uid_set UidMap.empty + let scope_ctxt = Uid.ScopeMap.find scope_uid ctxt.scopes in + let g = + Uid.IdentMap.fold + (fun _ (v : Uid.Var.t) g -> ScopeDependencies.add_vertex g (Vertex.Var v)) + scope_ctxt.var_idmap g + in + let g = + Uid.IdentMap.fold + (fun _ (v : Uid.SubScope.t) g -> ScopeDependencies.add_vertex g (Vertex.SubScope v)) + scope_ctxt.sub_scopes_idmap g + in + let g = + Uid.ScopeDefMap.fold + (fun def_key def g -> + let fv = Lambda_ast.term_fv def in + Uid.ScopeDefSet.fold + (fun fv_def g -> + match (def_key, fv_def) with + | Uid.ScopeDef.Var defined, Uid.ScopeDef.Var used -> + (* simple case *) + ScopeDependencies.add_edge g (Vertex.Var used) (Vertex.Var defined) + | Uid.ScopeDef.SubScopeVar (defined, _), Uid.ScopeDef.Var used -> + (* here we are defining the input of a subscope using a var of the scope *) + ScopeDependencies.add_edge g (Vertex.Var used) (Vertex.SubScope defined) + | Uid.ScopeDef.SubScopeVar (defined, _), Uid.ScopeDef.SubScopeVar (used, _) -> + (* here we are defining the input of a scope with the output of another subscope *) + ScopeDependencies.add_edge g (Vertex.SubScope used) (Vertex.SubScope defined) + | Uid.ScopeDef.Var defined, Uid.ScopeDef.SubScopeVar (used, _) -> + (* finally we define a scope var with the output of a subscope *) + ScopeDependencies.add_edge g (Vertex.SubScope used) (Vertex.Var defined)) + fv g) + scope.scope_defs g in - UidMap.iter (fun _ v -> G.add_vertex g v) vertices; - (* Process definitions dependencies. There are two types of dependencies : var -> var; sub_scope - -> var *) - UidMap.iter - (fun var_uid def -> - let fv = Lambda_ast.term_fv def in - UidSet.iter - (fun uid -> - if Name_resolution.belongs_to ctxt uid scope_uid then - let data = UidMap.find uid ctxt.data in - let from_uid = - match data.uid_sort with - | IdScopeVar _ -> uid - | IdSubScopeVar (_, sub_scope_uid) -> sub_scope_uid - | _ -> assert false - in - G.add_edge g (UidMap.find from_uid vertices) (UidMap.find var_uid vertices) - else ()) - fv) - scope.scope_defs; - (* Process sub-definitions dependencies. Only one kind of dependencies : var -> sub_scopes*) - UidMap.iter - (fun sub_scope_uid defs -> - UidMap.iter - (fun _ def -> - let fv = Lambda_ast.term_fv def in - UidSet.iter - (fun var_uid -> - (* Process only uid from the current scope (not the subscope) *) - if Name_resolution.belongs_to ctxt var_uid scope_uid then - G.add_edge g (UidMap.find var_uid vertices) (UidMap.find sub_scope_uid vertices) - else ()) - fv) - defs) - scope.scope_sub_defs; g -let merge_var_redefs (subscope : Scope_ast.scope) (redefs : Scope_ast.definition UidMap.t) : - Scope_ast.scope = +let rec rewrite_subscope_redef_before_call (ctxt : Name_resolution.context) + (parent_scope : Uid.Scope.t) (subscope : Scope_ast.scope) + ((redef, redef_ty) : Scope_ast.definition) : Scope_ast.definition = + let parent_scope_ctx = Uid.ScopeMap.find parent_scope ctxt.scopes in + let rec_call = rewrite_subscope_redef_before_call ctxt parent_scope subscope in + match Pos.unmark redef with + | Lambda_ast.EVar (prefix, var) -> ( + match prefix with + | Lambda_ast.NoPrefix -> + (* this is a variable of the parent scope, we add the prefix *) + ( Pos.same_pos_as (Lambda_ast.EVar (Lambda_ast.CallerPrefix (1, None), var)) redef, + redef_ty ) + | Lambda_ast.SubScopePrefix parent_sub -> + let parent_sub_real = Uid.SubScopeMap.find parent_sub parent_scope_ctx.sub_scopes in + (* two cases here *) + if parent_sub_real = subscope.scope_uid then + (* we remove the prefix since we're calling this precise subscope *) + (Pos.same_pos_as (Lambda_ast.EVar (Lambda_ast.NoPrefix, var)) redef, redef_ty) + else + (* we add the caller prefix*) + ( Pos.same_pos_as + (Lambda_ast.EVar (Lambda_ast.CallerPrefix (1, Some parent_sub), var)) + redef, + redef_ty ) + | Lambda_ast.CallerPrefix (i, grand_parent_sub) -> + (* In this tricky case, we are trying to call a subscope while being executed as a + subscope of a "grand-parent" scope. See [tests/scopes/grand_parent_scope.catala] for an + exemple. What we do in this case is that we propagate the prefix while adding 1 to the + generation counter *) + ( Pos.same_pos_as + (Lambda_ast.EVar (Lambda_ast.CallerPrefix (i + 1, grand_parent_sub), var)) + redef, + redef_ty ) ) + | Lambda_ast.EInt _ | Lambda_ast.EBool _ | Lambda_ast.EDec _ | Lambda_ast.EOp _ + | Lambda_ast.ELocalVar _ -> + (redef, redef_ty) + | Lambda_ast.EFun (bindings, body) -> + (Pos.same_pos_as (Lambda_ast.EFun (bindings, rec_call body)) redef, redef_ty) + | Lambda_ast.EApp (f, args) -> + (Pos.same_pos_as (Lambda_ast.EApp (rec_call f, List.map rec_call args)) redef, redef_ty) + | Lambda_ast.EIfThenElse (if_t, then_t, else_t) -> + ( Pos.same_pos_as + (Lambda_ast.EIfThenElse (rec_call if_t, rec_call then_t, rec_call else_t)) + redef, + redef_ty ) + | Lambda_ast.EDefault default -> + ( Pos.same_pos_as + (Lambda_ast.EDefault + { + default with + defaults = List.map (fun (x, y) -> (rec_call x, rec_call y)) default.defaults; + }) + redef, + redef_ty ) + +(** In this function, the keys of the [redefs] maps are variables of the [subscope] *) +let merge_var_redefs_before_subscope_call (ctxt : Name_resolution.context) + (parent_scope : Uid.Scope.t) (subscope : Scope_ast.scope) + (redefs : Scope_ast.definition Uid.VarMap.t) : Scope_ast.scope = let merge_defaults : Lambda_ast.term -> Lambda_ast.term -> Lambda_ast.term = Lambda_ast.map_untype2 (fun old_t new_t -> match (old_t, new_t) with | EDefault old_def, EDefault new_def -> EDefault (Lambda_ast.merge_default_terms old_def new_def) - | EFun ([ bind ], old_t), EFun (_, new_t) -> - let body = - Lambda_ast.map_untype2 - (fun old_t new_t -> - match (old_t, new_t) with - | EDefault old_def, EDefault new_def -> - EDefault (Lambda_ast.merge_default_terms old_def new_def) - | _ -> assert false) - old_t new_t - in - EFun ([ bind ], body) - | _ -> assert false) + | _ -> assert false + (* should not happen *)) in - + (* when merging redefinitions inside a subscope for execution, we need to annotate the variables + of the parent scope with the caller prefix *) { subscope with scope_defs = - UidMap.fold - (fun uid new_def sub_defs -> - match UidMap.find_opt uid sub_defs with - | None -> UidMap.add uid new_def sub_defs + Uid.VarMap.fold + (fun new_def_var new_def sub_defs -> + let new_def = rewrite_subscope_redef_before_call ctxt parent_scope subscope new_def in + match Uid.ScopeDefMap.find_opt (Uid.ScopeDef.Var new_def_var) sub_defs with + | None -> Uid.ScopeDefMap.add (Uid.ScopeDef.Var new_def_var) new_def sub_defs | Some old_def -> let def = merge_defaults old_def new_def in - UidMap.add uid def sub_defs) + Uid.ScopeDefMap.add (Uid.ScopeDef.Var new_def_var) def sub_defs) redefs subscope.scope_defs; } +let rewrite_context_before_executing_subscope (subscope : Uid.SubScope.t) + (exec_context : Lambda_interpreter.exec_context) : Lambda_interpreter.exec_context = + Lambda_interpreter.ExecContext.fold + (fun key value acc -> + match key with + | Lambda_interpreter.ExecContextKey.LocalVar _ -> + (* we can forget local vars when entering a subscope *) + acc + | Lambda_interpreter.ExecContextKey.ScopeVar (prefix, var) -> + let new_prefix = + (* note: this has to match with the behavior of [rewrite_subscope_redef_before_call] *) + match prefix with + | Lambda_ast.NoPrefix -> Lambda_ast.CallerPrefix (1, None) + | Lambda_ast.CallerPrefix (i, sub) -> Lambda_ast.CallerPrefix (i + 1, sub) + | Lambda_ast.SubScopePrefix sub -> + if sub = subscope then Lambda_ast.NoPrefix else Lambda_ast.CallerPrefix (1, Some sub) + in + Lambda_interpreter.ExecContext.add + (Lambda_interpreter.ExecContextKey.ScopeVar (new_prefix, var)) + value acc) + exec_context Lambda_interpreter.ExecContext.empty + +let rewrite_context_after_executing_subscope (subscope : Uid.SubScope.t) + (exec_context : Lambda_interpreter.exec_context) : Lambda_interpreter.exec_context = + Lambda_interpreter.ExecContext.fold + (fun key value acc -> + match key with + | Lambda_interpreter.ExecContextKey.LocalVar _ -> + (* we can forget local vars when entering a subscope *) + acc + | Lambda_interpreter.ExecContextKey.ScopeVar (prefix, var) -> ( + let new_prefix = + match prefix with + | Lambda_ast.NoPrefix -> Some (Lambda_ast.SubScopePrefix subscope) + | Lambda_ast.CallerPrefix (i, sub) -> ( + if i > 1 then Some (Lambda_ast.CallerPrefix (i - 1, sub)) + else + match sub with + | None -> Some Lambda_ast.NoPrefix + | Some sub -> Some (Lambda_ast.SubScopePrefix sub) ) + | Lambda_ast.SubScopePrefix _ -> None + (* we drop the subscope's subscopes since they can't be accessed *) + in + match new_prefix with + | None -> acc + | Some new_prefix -> + Lambda_interpreter.ExecContext.add + (Lambda_interpreter.ExecContextKey.ScopeVar (new_prefix, var)) + value acc )) + exec_context Lambda_interpreter.ExecContext.empty + let rec execute_scope ?(exec_context = Lambda_interpreter.empty_exec_ctxt) (ctxt : Name_resolution.context) (prgm : Scope_ast.program) (scope_prgm : Scope_ast.scope) : Lambda_interpreter.exec_context = - let schedule = build_scope_schedule ctxt scope_prgm in - - (* Printf.printf "Scheduling : "; *) - (* G.Topological.iter (fun v_uid -> Printf.printf "%s; " (G.V.label v_uid |> Uid.get_ident)) - schedule; *) - (* Printf.printf "\n"; *) - G.Topological.fold - (fun v_uid exec_context -> - let uid = G.V.label v_uid in - match Name_resolution.get_uid_sort ctxt uid with - | IdScopeVar _ -> ( - match UidMap.find_opt uid scope_prgm.scope_defs with + let scope_ctxt = Uid.ScopeMap.find scope_prgm.scope_uid ctxt.scopes in + let deps = build_scope_dependencies scope_prgm ctxt in + check_for_cycle deps; + TopologicalTraversal.fold + (fun v exec_context -> + match v with + | Vertex.Var var -> ( + match Uid.ScopeDefMap.find_opt (Uid.ScopeDef.Var var) scope_prgm.scope_defs with | Some def -> - UidMap.add uid - (Lambda_interpreter.eval_term uid exec_context def |> Lambda_ast.untype) + (* we evaluate a variable of the scope, no tricky business here *) + Lambda_interpreter.ExecContext.add + (Lambda_interpreter.ExecContextKey.ScopeVar (Lambda_ast.NoPrefix, var)) + ( Lambda_interpreter.eval_term (Uid.ScopeDef.Var var) exec_context def + |> Lambda_ast.untype ) exec_context - | None -> - Errors.raise_multispanned_error - (Printf.sprintf "Variable %s is undefined in scope %s" (Uid.get_ident uid) - (Uid.get_ident scope_prgm.scope_uid)) - [ (None, Uid.get_pos scope_prgm.scope_uid); (None, Uid.get_pos uid) ] ) - | IdSubScope sub_scope_ref -> - (* Merge the new definitions *) - let sub_scope_prgm = - match UidMap.find_opt sub_scope_ref prgm with - | Some sub_scope -> sub_scope - | None -> - Errors.raise_multispanned_error - (Printf.sprintf - "The subscope %s of %s has no definition inside it, and therefore cannot be \ - executed" - (Uid.get_ident scope_prgm.scope_uid) - (Uid.get_ident sub_scope_ref)) - [ (None, Uid.get_pos scope_prgm.scope_uid); (None, Uid.get_pos sub_scope_ref) ] + | None -> assert false (* should not happen *) ) + | Vertex.SubScope subscope_uid -> + (* this is the tricky case where we have to the the bookkeeping of rewriting the context + and additional defaults that we pass for the subscope for execution. See formalization + for more details *) + let subscope_real_uid = Uid.SubScopeMap.find subscope_uid scope_ctxt.sub_scopes in + let subscope = Uid.ScopeMap.find subscope_real_uid prgm in + let redefs_to_include_to_subscope = + Uid.ScopeDefMap.fold + (fun def_key def acc -> + match def_key with + | Uid.ScopeDef.Var _ -> acc + | Uid.ScopeDef.SubScopeVar (def_sub_uid, var) -> + if def_sub_uid = subscope_uid then Uid.VarMap.add var def acc else acc) + scope_prgm.scope_defs Uid.VarMap.empty in - let redefs = - match UidMap.find_opt uid scope_prgm.scope_sub_defs with - | Some defs -> defs - | None -> UidMap.empty + let subscope = + merge_var_redefs_before_subscope_call ctxt scope_prgm.scope_uid subscope + redefs_to_include_to_subscope in - let new_sub_scope_prgm = merge_var_redefs sub_scope_prgm redefs in - (* Scope_ast.print_scope new_sub_scope_prgm; *) - let out_context = execute_scope ctxt ~exec_context prgm new_sub_scope_prgm in - (* Now let's merge back the value from the output context *) - UidSet.fold - (fun var_uid exec_context -> - match Name_resolution.get_uid_sort ctxt var_uid with - | IdSubScopeVar (ref_uid, scope_ref) -> - if uid = scope_ref then - match Name_resolution.get_uid_sort ctxt ref_uid with - | IdScopeVar _ | IdSubScopeVar _ -> - let value = UidMap.find ref_uid out_context in - UidMap.add var_uid value exec_context - | _ -> exec_context - else exec_context - | _ -> exec_context) - (UidMap.find scope_prgm.scope_uid ctxt.scopes).uid_set exec_context - | _ -> assert false) - schedule exec_context + let exec_context = rewrite_context_before_executing_subscope subscope_uid exec_context in + let exec_context = execute_scope ~exec_context ctxt prgm subscope in + let exec_context = rewrite_context_after_executing_subscope subscope_uid exec_context in + exec_context) + deps exec_context diff --git a/src/catala/scope_language/uid.ml b/src/catala/scope_language/uid.ml index 731c247f..270b411c 100644 --- a/src/catala/scope_language/uid.ml +++ b/src/catala/scope_language/uid.ml @@ -12,30 +12,87 @@ or implied. See the License for the specific language governing permissions and limitations under the License. *) -type t = int +module IdentMap = Map.Make (String) -module UidSet = Set.Make (Int) -module UidMap = Map.Make (Int) +module type Id = sig + type t -type ident = string + type info -let ident_tbl = ref UidMap.empty + val fresh : info -> t -let pos_tbl = ref UidMap.empty + val get_info : t -> info -let counter = ref 0 + val compare : t -> t -> int -let fresh (id : ident) (pos : Pos.t) : t = - incr counter; - ident_tbl := UidMap.add !counter id !ident_tbl; - pos_tbl := UidMap.add !counter pos !pos_tbl; - !counter + val format_t : t -> string -let get_ident (uid : t) : ident = UidMap.find uid !ident_tbl + val hash : t -> int +end -let get_pos (uid : t) : Pos.t = UidMap.find uid !pos_tbl +module Make (X : sig + type info -let map_add_list (key : t) (item : 'a) (map : 'a list UidMap.t) = - match UidMap.find_opt key map with - | Some l -> UidMap.add key (item :: l) map - | None -> UidMap.add key [ item ] map + val format_info : info -> string +end) : Id with type info = X.info = struct + type t = { id : int; info : X.info } + + type info = X.info + + let counter = ref 0 + + let fresh (info : X.info) : t = + incr counter; + { id = !counter; info } + + let get_info (uid : t) : X.info = uid.info + + let compare (x : t) (y : t) : int = compare x.id y.id + + let format_t (x : t) : string = Printf.sprintf "%s" (X.format_info x.info) + + let hash (x : t) : int = x.id +end + +module MarkedString = struct + type info = string Pos.marked + + let format_info (s, _) = s +end + +module Scope = Make (MarkedString) +module ScopeSet = Set.Make (Scope) +module ScopeMap = Map.Make (Scope) +module Var = Make (MarkedString) +module VarSet = Set.Make (Var) +module VarMap = Map.Make (Var) +module LocalVar = Make (MarkedString) +module LocalVarSet = Set.Make (LocalVar) +module LocalVarMap = Map.Make (LocalVar) +module SubScope = Make (MarkedString) +module SubScopeSet = Set.Make (SubScope) +module SubScopeMap = Map.Make (SubScope) + +(** Inside a scope, a definition can refer either to a scope def, or a subscope def *) +module ScopeDef = struct + type t = + | Var of Var.t + | SubScopeVar of SubScope.t * Var.t + (** In this case, the [Uid.Var.t] lives inside the context of the subscope's original + declaration *) + + let compare x y = + match (x, y) with + | Var x, Var y | Var x, SubScopeVar (_, y) | SubScopeVar (_, x), Var y -> Var.compare x y + | SubScopeVar (_, x), SubScopeVar (_, y) -> SubScope.compare x y + + let format_t x = + match x with + | Var v -> Var.format_t v + | SubScopeVar (s, v) -> Printf.sprintf "%s.%s" (SubScope.format_t s) (Var.format_t v) + + let hash x = match x with Var v -> Var.hash v | SubScopeVar (_, v) -> Var.hash v +end + +module ScopeDefMap = Map.Make (ScopeDef) +module ScopeDefSet = Set.Make (ScopeDef) diff --git a/src/catala/scope_language/uid.mli b/src/catala/scope_language/uid.mli new file mode 100644 index 00000000..2207ebcb --- /dev/null +++ b/src/catala/scope_language/uid.mli @@ -0,0 +1,79 @@ +(* This file is part of the Catala compiler, a specification language for tax and social benefits + computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux + + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + in compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License + is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or implied. See the License for the specific language governing permissions and limitations under + the License. *) + +module IdentMap : Map.S with type key = String.t + +module MarkedString : sig + type info = string Pos.marked + + val format_info : 'a * 'b -> 'a +end + +module type Id = sig + type t + + type info + + val fresh : info -> t + + val get_info : t -> info + + val compare : t -> t -> int + + val format_t : t -> string + + val hash : t -> int +end + +module Scope : Id with type info = MarkedString.info + +module ScopeSet : Set.S with type elt = Scope.t + +module ScopeMap : Map.S with type key = Scope.t + +module Var : Id with type info = MarkedString.info + +module VarSet : Set.S with type elt = Var.t + +module VarMap : Map.S with type key = Var.t + +module LocalVar : Id with type info = MarkedString.info + +module LocalVarSet : Set.S with type elt = LocalVar.t + +module LocalVarMap : Map.S with type key = LocalVar.t + +module SubScope : Id with type info = MarkedString.info + +module SubScopeSet : Set.S with type elt = SubScope.t + +module SubScopeMap : Map.S with type key = SubScope.t + +module ScopeDef : sig + type t = + | Var of Var.t + | SubScopeVar of SubScope.t * Var.t + (** In this case, the [Uid.Var.t] lives inside the context of the subscope's original + declaration *) + + val compare : t -> t -> int + + val format_t : t -> string + + val hash : t -> int +end + +module ScopeDefMap : Map.S with type key = ScopeDef.t + +module ScopeDefSet : Set.S with type elt = ScopeDef.t diff --git a/tests/test_scope/grand_parent_caller.catala b/tests/test_scope/grand_parent_caller.catala new file mode 100644 index 00000000..c9b2cdf9 --- /dev/null +++ b/tests/test_scope/grand_parent_caller.catala @@ -0,0 +1,24 @@ +/* +new scope A: + param x type int + +new scope B: + param a scope A + param y type int + +new scope C: + param b scope B + param z type int + + +scope A: + def x := 0 + +scope B: + def a.x := y + def y := 1 + +scope C: + def b.y := z + def z = 2 +*/ \ No newline at end of file diff --git a/tests/test_scope/sub_scope.catala.B.out b/tests/test_scope/sub_scope.catala.B.out index 7b3901de..bc8d148b 100644 --- a/tests/test_scope/sub_scope.catala.B.out +++ b/tests/test_scope/sub_scope.catala.B.out @@ -1,8 +1,8 @@ -[RESULT] a -> 42 +[RESULT] a -> 42 [RESULT] b -> true -[RESULT] scopeA::a -> 1 -[RESULT] scopeA::a_base -> 1 -[RESULT] scopeA::b -> true -[RESULT] scopeAbis::a -> -1 -[RESULT] scopeAbis::a_base -> 1 -[RESULT] scopeAbis::b -> false +[RESULT] scopeA.a -> 1 +[RESULT] scopeA.b -> true +[RESULT] scopeA.a_base -> 1 +[RESULT] scopeAbis.a -> -1 +[RESULT] scopeAbis.b -> false +[RESULT] scopeAbis.a_base -> 1 \ No newline at end of file diff --git a/tests/test_scope/sub_sub_scope.catala.C.out b/tests/test_scope/sub_sub_scope.catala.C.out index 5a490ced..07582c1b 100644 --- a/tests/test_scope/sub_sub_scope.catala.C.out +++ b/tests/test_scope/sub_sub_scope.catala.C.out @@ -1,6 +1,4 @@ -[RESULT] a::u -> true -[RESULT] a::x -> 2 -[RESULT] b::a::u -> true -[RESULT] b::a::x -> 1 -[RESULT] b::y -> 3 -[RESULT] z -> 2 +[RESULT] z -> 2 +[RESULT] a.x -> 2 +[RESULT] a.u -> true +[RESULT] b.y -> 3 \ No newline at end of file