diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 9e1e0cbe..90141e3a 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -226,17 +226,20 @@ let rec translate_expr let rec_helper ?(local_vars = local_vars) e = translate_expr scope inside_definition_of ctxt local_vars e in - let rec detuplify_list = function + let rec detuplify_list names = function (* Where a list is expected (e.g. after [among]), as syntactic sugar, if a tuple is found instead we transpose it into a list of tuples *) | S.Tuple ls, pos -> let m = Untyped { pos } in - let ls = List.map detuplify_list ls in - let rec zip = function + let ls = List.map (detuplify_list []) ls in + let rec zip names = function | [] -> assert false | [l] -> l | l1 :: r -> - let rhs = zip r in + let name1, names = + match names with name1 :: names -> name1, names | [] -> "x", [] + in + let rhs = zip names r in let rtys, explode = match List.length r with | 1 -> (TAny, pos), fun e -> [e] @@ -248,7 +251,11 @@ let rec translate_expr in let tys = [TAny, pos; rtys] in let f_join = - let x1 = Var.make "x1" and x2 = Var.make "x2" in + let x1 = Var.make name1 in + let x2 = + Var.make + (match names with [] -> "zip" | _ -> String.concat "_" names) + in Expr.make_abs [| x1; x2 |] (Expr.make_tuple (Expr.evar x1 m :: explode (Expr.evar x2 m)) m) tys pos @@ -257,8 +264,10 @@ let rec translate_expr ~tys:((TAny, pos) :: List.map (fun ty -> TArray ty, pos) tys) m in - zip ls - | e -> rec_helper e + zip names ls + | e -> + (* If the input is not a tuple, we assume it's already a list *) + rec_helper e in let pos = Mark.get expr in let emark = Untyped { pos } in @@ -663,8 +672,10 @@ let rec translate_expr | ArrayLit es -> Expr.earray (List.map rec_helper es) emark | Tuple es -> Expr.etuple (List.map rec_helper es) emark | CollectionOp (((S.Filter { f } | S.Map { f }) as op), collection) -> - let collection = detuplify_list collection in let param_names, predicate = f in + let collection = + detuplify_list (List.map Mark.remove param_names) collection + in let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in let local_vars = List.fold_left2 @@ -709,7 +720,9 @@ let rec translate_expr collection ) -> let default = rec_helper default in let pos_dft = Expr.pos default in - let collection = detuplify_list collection in + let collection = + detuplify_list (List.map Mark.remove param_names) collection + in let params = List.map (fun n -> Var.make (Mark.remove n)) param_names in let local_vars = List.fold_left2 @@ -767,7 +780,9 @@ let rec translate_expr Expr.etupleaccess ~e:weighted_result ~index:0 ~size:2 emark | CollectionOp (((Exists { predicate } | Forall { predicate }) as op), collection) -> - let collection = detuplify_list collection in + let collection = + detuplify_list (List.map Mark.remove (fst predicate)) collection + in let init, op = match op with | Exists _ -> false, S.Or @@ -850,7 +865,7 @@ let rec translate_expr | MemCollection (member, collection) -> let param_var = Var.make "collection_member" in let param = Expr.make_var param_var emark in - let collection = detuplify_list collection in + let collection = detuplify_list ["collection_member"] collection in let init = Expr.elit (LBool false) emark in let acc_var = Var.make "acc" in let acc = Expr.make_var acc_var emark in diff --git a/tests/test_tuples/good/tuplists.catala_en b/tests/test_tuples/good/tuplists.catala_en index bcb89da1..4c47a6fc 100644 --- a/tests/test_tuples/good/tuplists.catala_en +++ b/tests/test_tuples/good/tuplists.catala_en @@ -156,11 +156,11 @@ let grok : (decimal, money, money) → (money * decimal) = in let tlist : list of (decimal * money * money) = map2 - (λ (x1: decimal) (x2: (money * money)) → - let a_b_c : (decimal * money * money) = (x1, x2.0, x2.1) in + (λ (a: decimal) (b_c: (money * money)) → + let a_b_c : (decimal * money * money) = (a, b_c.0, b_c.1) in (a_b_c.0, a_b_c.1, a_b_c.2)) lis1 - map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + map2 (λ (b: money) (c: money) → (b, c)) lis2 lis3 in let S : S_in → S = λ (S_in: S_in) → @@ -169,19 +169,19 @@ let S : S_in → S = in let r2 : list of (money * decimal) = map2 - (λ (x1: decimal) (x2: (money * money)) → - let x3 : (decimal * money * money) = (x1, x2.0, x2.1) in - grok x3.0 x3.1 x3.2) + (λ (x: decimal) (zip: (money * money)) → + let x1 : (decimal * money * money) = (x, zip.0, zip.1) in + grok x1.0 x1.1 x1.2) lis1 - map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + map2 (λ (x: money) (zip: money) → (x, zip)) lis2 lis3 in let r3 : list of (money * decimal) = map2 - (λ (x1: decimal) (x2: (money * money)) → - let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in + (λ (x: decimal) (y_z: (money * money)) → + let x_y_z : (decimal * money * money) = (x, y_z.0, y_z.1) in grok x_y_z.0 x_y_z.1 x_y_z.2) lis1 - map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + map2 (λ (y: money) (z: money) → (y, z)) lis2 lis3 in let r4 : list of (money * decimal) = map (λ (x_y_z: (decimal * money * money)) → @@ -190,22 +190,22 @@ let S : S_in → S = in let r5 : list of (money * decimal) = map2 - (λ (x1: decimal) (x2: (money * money)) → - let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in + (λ (x: decimal) (y_z: (money * money)) → + let x_y_z : (decimal * money * money) = (x, y_z.0, y_z.1) in (x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2)) lis1 - map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + map2 (λ (y: money) (z: money) → (y, z)) lis2 lis3 in let r6 : list of (money * decimal) = map2 - (λ (x1: (decimal * money)) (x2: money) → - let xy_z : ((decimal * money) * money) = (x1, x2) in - let xy : (decimal * money) = xy_z.0 in - let z : money = xy_z.1 in - (xy.1 * xy.0, xy.1 / z)) + (λ (xy: (decimal * money)) (z: money) → + let xy_z : ((decimal * money) * money) = (xy, z) in + let xy1 : (decimal * money) = xy_z.0 in + let z1 : money = xy_z.1 in + (xy1.1 * xy1.0, xy1.1 / z1)) map2 - (λ (x1: decimal) (x2: money) → - let x_y : (decimal * money) = (x1, x2) in + (λ (x: decimal) (y: money) → + let x_y : (decimal * money) = (x, y) in (x_y.0, x_y.1)) lis1 lis2