diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 98ea3436..8716f3c6 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -103,7 +103,7 @@ type rule = { rule_id : RuleName.t; rule_just : expr boxed; rule_cons : expr boxed; - rule_parameter : (expr Var.t * typ) option; + rule_parameter : (expr Var.t * typ) list option; rule_exception : exception_situation; rule_label : label_situation; } @@ -124,45 +124,46 @@ module Rule = struct let c2 = Expr.unbox r2.rule_cons in Expr.compare c1 c2 | n -> n) - | Some (v1, t1), Some (v2, t2) -> ( - match Type.compare t1 t2 with - | 0 -> ( - let open Bindlib in - let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in - let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_just)) in - let _, j1, j2 = unbind2 b1 b2 in - match Expr.compare j1 j2 with - | 0 -> - let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_cons)) in - let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_cons)) in - let _, c1, c2 = unbind2 b1 b2 in - Expr.compare c1 c2 - | n -> n) - | n -> n) + | Some l1, Some l2 -> + ListLabels.compare l1 l2 ~cmp:(fun (v1, t1) (v2, t2) -> + match Type.compare t1 t2 with + | 0 -> ( + let open Bindlib in + let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in + let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_just)) in + let _, j1, j2 = unbind2 b1 b2 in + match Expr.compare j1 j2 with + | 0 -> + let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_cons)) in + let b2 = unbox (bind_var v2 (Expr.Box.lift r2.rule_cons)) in + let _, c1, c2 = unbind2 b1 b2 in + Expr.compare c1 c2 + | n -> n) + | n -> n) | None, Some _ -> -1 | Some _, None -> 1 end -let empty_rule (pos : Pos.t) (have_parameter : typ option) : rule = +let empty_rule (pos : Pos.t) (have_parameter : typ list option) : rule = { rule_just = Expr.box (ELit (LBool false), Untyped { pos }); rule_cons = Expr.box (ELit LEmptyError, Untyped { pos }); rule_parameter = (match have_parameter with - | Some typ -> Some (Var.make "dummy", typ) + | Some typs -> Some (List.map (fun typ -> Var.make "dummy", typ) typs) | None -> None); rule_exception = BaseCase; rule_id = RuleName.fresh ("empty", pos); rule_label = Unlabeled; } -let always_false_rule (pos : Pos.t) (have_parameter : typ option) : rule = +let always_false_rule (pos : Pos.t) (have_parameter : typ list option) : rule = { rule_just = Expr.box (ELit (LBool true), Untyped { pos }); rule_cons = Expr.box (ELit (LBool false), Untyped { pos }); rule_parameter = (match have_parameter with - | Some typ -> Some (Var.make "dummy", typ) + | Some typs -> Some (List.map (fun typ -> Var.make "dummy", typ) typs) | None -> None); rule_exception = BaseCase; rule_id = RuleName.fresh ("always_false", pos); diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index 988fba0b..48b33f6c 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -60,15 +60,15 @@ type rule = { rule_id : RuleName.t; rule_just : expr boxed; rule_cons : expr boxed; - rule_parameter : (expr Var.t * typ) option; + rule_parameter : (expr Var.t * typ) list option; rule_exception : exception_situation; rule_label : label_situation; } module Rule : Set.OrderedType with type t = rule -val empty_rule : Pos.t -> typ option -> rule -val always_false_rule : Pos.t -> typ option -> rule +val empty_rule : Pos.t -> typ list option -> rule +val always_false_rule : Pos.t -> typ list option -> rule type assertion = expr boxed type variation_typ = Increasing | Decreasing diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index 8ff615b0..44ea094c 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -29,7 +29,9 @@ let rule ctx env rule = let env = match rule.rule_parameter with | None -> env - | Some (v, ty) -> Typing.Env.add_var v ty env + | Some l -> + let vs, tys = List.split l in + ListLabels.fold_right2 vs tys ~init:env ~f:Typing.Env.add_var in (* Note: we could use the known rule type here to direct typing. We choose not to because it shouldn't be needed for disambiguation, and we prefer to diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index ca71424b..ce7ca1ef 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -938,7 +938,8 @@ let process_default Name_resolution.get_def_typ ctxt (Marked.unmark def_key) in match Marked.unmark def_key_typ, param_uid with - | TArrow (t_in, _), Some param_uid -> Some (Marked.unmark param_uid, t_in) + | TArrow (t_ins, _), Some param_uid -> + Some (List.map (fun t_in -> Marked.unmark param_uid, t_in) t_ins) | TArrow _, None -> Errors.raise_spanned_error (Expr.pos cons) "This definition has a function type but the parameter is missing" @@ -1203,11 +1204,7 @@ let process_topdef body arg_types (Marked.get_mark def.S.topdef_name) in - let typ = - List.fold_right - (fun argty retty -> TArrow (argty, retty), ty_pos) - arg_types body_type - in + let typ = TArrow (arg_types, body_type), ty_pos in { prgm with Ast.program_topdefs = diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index 7ea80982..53029fb4 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -319,7 +319,9 @@ let process_type (ctxt : context) ((naked_typ, typ_pos) : Surface.Ast.typ) : typ match naked_typ with | Surface.Ast.Base base_typ -> process_base_typ ctxt (base_typ, typ_pos) | Surface.Ast.Func { arg_typ; return_typ } -> - ( TArrow (process_base_typ ctxt arg_typ, process_base_typ ctxt return_typ), + (* TODO Louis: /!\ There is only one argument in the surface syntax for + function now. *) + ( TArrow ([process_base_typ ctxt arg_typ], process_base_typ ctxt return_typ), typ_pos ) (** Process data declaration *) diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index caf91ba4..8731a4b4 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -78,8 +78,11 @@ module To_jsoo = struct Format.fprintf fmt "@[%a@ Js.js_array Js.t@]" format_typ_with_parens t1 | TAny -> Format.fprintf fmt "Js.Unsafe.any Js.t" | TArrow (t1, t2) -> - Format.fprintf fmt "(@[%a, @ %a@]) Js.meth_callback" - format_typ_with_parens t1 format_typ_with_parens t2 + Format.fprintf fmt "(@[unit, @ %a -> %a@]) Js.meth_callback" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.pp_print_string fmt " -> ") + format_typ_with_parens) + t1 format_typ_with_parens t2 let rec format_typ_to_jsoo fmt typ = match Marked.unmark typ with @@ -153,13 +156,21 @@ module To_jsoo = struct (fun fmt (struct_field, struct_field_type) -> match Marked.unmark struct_field_type with | TArrow (t1, t2) -> + let args_names = + ListLabels.mapi t1 ~f:(fun i _ -> + "function_input" ^ string_of_int i) + in Format.fprintf fmt "@[method %a =@ Js.wrap_meth_callback@ @[(@,\ - fun input ->@ %a (%a.%a (%a input)))@]@]" - format_struct_field_name_camel_case struct_field + fun _ %a ->@ %a (%a.%a %a))@]@]" + (Format.pp_print_list (fun fmt -> Format.pp_print_string fmt)) + args_names format_struct_field_name_camel_case struct_field format_typ_to_jsoo t2 fmt_struct_name () format_struct_field_name (None, struct_field) - format_typ_of_jsoo t1 + (Format.pp_print_list (fun fmt (i, ti) -> + Format.fprintf fmt "@[(%a@ %a)@]" + Format.pp_print_string i format_typ_to_jsoo ti)) + (List.combine args_names t1) | _ -> Format.fprintf fmt "@[val %a =@ %a %a.%a@]" format_struct_field_name_camel_case struct_field diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index d4187f93..10545c76 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -239,7 +239,9 @@ let rec get_structs_or_enums_in_type (t : typ) : TVertexSet.t = | TEnum e -> TVertexSet.singleton (TVertex.Enum e) | TArrow (t1, t2) -> TVertexSet.union - (get_structs_or_enums_in_type t1) + (t1 + |> List.map get_structs_or_enums_in_type + |> List.fold_left TVertexSet.union TVertexSet.empty) (get_structs_or_enums_in_type t2) | TLit _ | TAny -> TVertexSet.empty | TOption t1 | TArray t1 -> get_structs_or_enums_in_type t1 diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index fcec0c44..dffb6b66 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -212,26 +212,29 @@ let rec rule_tree_to_expr ~(is_reentrant_var : bool) (ctx : ctx) (def_pos : Pos.t) - (is_func : Desugared.Ast.expr Var.t option) + (is_func : Desugared.Ast.expr Var.t list option) (tree : rule_tree) : untyped Ast.expr boxed = let emark = Untyped { pos = def_pos } in let exceptions, base_rules = match tree with Leaf r -> [], r | Node (exceptions, r) -> exceptions, r in - (* because each rule has its own variable parameter and we want to convert the - whole rule tree into a function, we need to perform some alpha-renaming of - all the expressions *) + (* because each rule has its own variables parameters and we want to convert + the whole rule tree into a function, we need to perform some alpha-renaming + of all the expressions *) let substitute_parameter (e : Desugared.Ast.expr boxed) (rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed = match is_func, rule.Desugared.Ast.rule_parameter with - | Some new_param, Some (old_param, _) -> - let binder = Bindlib.bind_var old_param (Marked.unmark e) in + | Some new_params, Some old_params_with_types -> + let old_params, _ = List.split old_params_with_types in + let old_params = Array.of_list old_params in + let new_params = Array.of_list new_params in + let binder = Bindlib.bind_mvar old_params (Marked.unmark e) in Marked.mark (Marked.get_mark e) @@ Bindlib.box_apply2 - (fun binder new_param -> Bindlib.subst binder new_param) + (fun binder new_param -> Bindlib.msubst binder new_param) binder - (Bindlib.box_var new_param) + (new_params |> Array.map Bindlib.box_var |> Bindlib.box_array) | None, None -> e | _ -> assert false (* should not happen *) @@ -239,20 +242,22 @@ let rec rule_tree_to_expr let ctx = match is_func with | None -> ctx - | Some new_param -> ( - match Var.Map.find_opt new_param ctx.var_mapping with - | None -> - let new_param_scope = Var.make (Bindlib.name_of new_param) in - { - ctx with - var_mapping = Var.Map.add new_param new_param_scope ctx.var_mapping; - } - | Some _ -> - (* We only create a mapping if none exists because [rule_tree_to_expr] - is called recursively on the exceptions of the tree and we don't want - to create a new Scopelang variable for the parameter at each tree - level. *) - ctx) + | Some new_params -> + ListLabels.fold_left new_params ~init:ctx ~f:(fun ctx new_param -> + match Var.Map.find_opt new_param ctx.var_mapping with + | None -> + let new_param_scope = Var.make (Bindlib.name_of new_param) in + { + ctx with + var_mapping = + Var.Map.add new_param new_param_scope ctx.var_mapping; + } + | Some _ -> + (* We only create a mapping if none exists because + [rule_tree_to_expr] is called recursively on the exceptions of + the tree and we don't want to create a new Scopelang variable for + the parameter at each tree level. *) + ctx) in let base_just_list = List.map @@ -301,7 +306,8 @@ let rec rule_tree_to_expr in match is_func, (List.hd base_rules).Desugared.Ast.rule_parameter with | None, None -> default - | Some new_param, Some (_, typ) -> + | Some new_params, Some ls -> + let _, tys = List.split ls in if toplevel then (* When we're creating a function from multiple defaults, we must check that the result returned by the function is not empty, unless we're @@ -311,9 +317,12 @@ let rec rule_tree_to_expr let default = if is_reentrant_var then default else Expr.eerroronempty default emark in + Expr.make_abs - [| Var.Map.find new_param ctx.var_mapping |] - default [typ] def_pos + (new_params + |> List.map (fun x -> Var.Map.find x ctx.var_mapping) + |> Array.of_list) + default tys def_pos else default | _ -> (* should not happen *) assert false @@ -340,7 +349,7 @@ let translate_def let all_rules_not_func = RuleName.Map.for_all (fun n r -> not (is_rule_func n r)) def in - let is_def_func_param_typ : typ option = + let is_def_func_param_typ : typ list option = if is_def_func && all_rules_func then match Marked.unmark typ with | TArrow (t_param, _) -> Some t_param @@ -379,7 +388,7 @@ let translate_def | Reentrant -> true | _ -> false in - let top_value = + let top_value : Desugared.Ast.rule option = if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then (* We add the bottom [false] value for conditions, only for the scope where the condition is declared. Except when the variable is an input, @@ -419,13 +428,19 @@ let translate_def let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in let empty_error = Expr.elit LEmptyError m in match is_def_func_param_typ with - | Some ty -> - Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m) + | Some tys -> + Expr.make_abs + (Array.init (List.length tys) (fun _ -> Var.make "_")) + empty_error tys (Expr.mark_pos m) | _ -> empty_error else rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx (Desugared.Ast.ScopeDef.get_position def_info) - (Option.map (fun _ -> Var.make "param") is_def_func_param_typ) + (Option.map + (fun l -> + ListLabels.mapi l ~f:(fun i _ -> + Var.make ("param" ^ string_of_int i))) + is_def_func_param_typ) (match top_list, top_value with | [], None -> (* In this case, there are no rules to define the expression and no