diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop index 2c5574cb..5f9ef619 100644 --- a/src/grammar.lalrpop +++ b/src/grammar.lalrpop @@ -269,8 +269,8 @@ UOp: UnaryOp = { "tag" => UnaryOp::Tag(s), "wrap" => UnaryOp::Wrap(), "embed" => UnaryOp::Embed(<>), - "map" => UnaryOp::ListMap(<>), - "recordMap" => UnaryOp::RecordMap(<>), + "map" => UnaryOp::ListMap(), + "recordMap" => UnaryOp::RecordMap(), "seq" => UnaryOp::Seq(), "deepSeq" => UnaryOp::DeepSeq(), "head" => UnaryOp::ListHead(), diff --git a/src/operation.rs b/src/operation.rs index 18305520..cf10eda4 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -434,7 +434,11 @@ fn process_unary_operation( )) } } - UnaryOp::ListMap(f) => { + UnaryOp::ListMap() => { + let (f, _) = stack + .pop_arg() + .ok_or_else(|| EvalError::NotEnoughArgs(2, String::from("map"), pos_op))?; + if let Term::List(ts) = *t { let mut shared_env = Environment::new(); let f_as_var = f.body.closurize(&mut env, f.env); @@ -464,7 +468,11 @@ fn process_unary_operation( )) } } - UnaryOp::RecordMap(f) => { + UnaryOp::RecordMap() => { + let (f, _) = stack + .pop_arg() + .ok_or_else(|| EvalError::NotEnoughArgs(2, String::from("recordMap"), pos_op))?; + if let Term::Record(rec) = *t { let mut shared_env = Environment::new(); let f_as_var = f.body.closurize(&mut env, f.env); diff --git a/src/program.rs b/src/program.rs index 836c0c09..decdad42 100644 --- a/src/program.rs +++ b/src/program.rs @@ -811,15 +811,15 @@ Assume(#alwaysTrue, false) assert_eq!( eval_string( - "(%recordMap% (fun y => fun x => x + 1) { foo = 1; bar = \"it's lazy\"; }).foo" + "(%recordMap% { foo = 1; bar = \"it's lazy\"; } (fun y => fun x => x + 1)).foo" ), Ok(Term::Num(2.)), ); assert_eq!( eval_string( "let r = %recordMap% - (fun y x => if %isNum% x then x + 1 else 0) { foo = 1; bar = \"it's lazy\"; } + (fun y x => if %isNum% x then x + 1 else 0) in (r.foo) + (r.bar)" ), @@ -878,7 +878,7 @@ Assume(#alwaysTrue, false) fn lists() { assert_eq!(eval_string("%elemAt% [1,2,3] 1"), Ok(Term::Num(2.0))); assert_eq!( - eval_string("%elemAt% (%map% (fun x => x + 1) [1,2,3]) 1"), + eval_string("%elemAt% (%map% [1,2,3] (fun x => x + 1)) 1"), Ok(Term::Num(3.0)) ); @@ -929,7 +929,7 @@ Assume(#alwaysTrue, false) if y then true else false else false in - let all = fun pred => fun l => foldr and true (%map% pred l) in + let all = fun pred => fun l => foldr and true (%map% l pred) in let isZ = fun x => x == 0 in all isZ [0, 0, 0, 1]" ), diff --git a/src/term.rs b/src/term.rs index 7d7dc3da..6303d1a8 100644 --- a/src/term.rs +++ b/src/term.rs @@ -498,13 +498,13 @@ pub enum UnaryOp { StaticAccess(Ident), /// Map a function on each element of a list. - ListMap(CapturedTerm), + ListMap(), /// Map a function on a record. /// /// The mapped function must take two arguments, the name of the field as a string, and the /// content of the field. `RecordMap` then replaces the content of each field by the result of the /// function: i.e., `recordMap f {a=2;}` evaluates to `{a=(f "a" 2);}`. - RecordMap(CapturedTerm), + RecordMap(), /// Inverse the polarity of a label. ChangePolarity(), @@ -574,8 +574,8 @@ impl UnaryOp { match self { Switch(has_default) => Switch(has_default), - ListMap(t) => ListMap(f(t)), - RecordMap(t) => RecordMap(f(t)), + ListMap() => ListMap(), + RecordMap() => RecordMap(), Ite() => Ite(), diff --git a/src/typecheck.rs b/src/typecheck.rs index 4e0a7627..c4b99374 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -1335,8 +1335,8 @@ fn instantiate_foralls(state: &mut State, mut ty: TypeWrapper, inst: ForallInst) /// Type of unary operations. pub fn get_uop_type( state: &mut State, - envs: Envs, - strict: bool, + _envs: Envs, + _strict: bool, op: &UnaryOp, ) -> Result { Ok(match op { @@ -1398,20 +1398,16 @@ pub fn get_uop_type( mk_tyw_arrow!(mk_tyw_record!((id.clone(), res.clone()); row), res) } - // List -> List - // Unify f with a -> b. - UnaryOp::ListMap(f) => { + // forall a b. List -> (a -> b) -> List + UnaryOp::ListMap() => { let a = TypeWrapper::Ptr(new_var(state.table)); let b = TypeWrapper::Ptr(new_var(state.table)); let f_type = mk_tyw_arrow!(a.clone(), b.clone()); - type_check_(state, envs.clone(), strict, f, f_type)?; - - mk_tyw_arrow!(AbsType::List(), AbsType::List()) + mk_tyw_arrow!(AbsType::List(), f_type, AbsType::List()) } - // { _ : a} -> { _ : b } - // Unify f with Str -> a -> b. - UnaryOp::RecordMap(f) => { + // forall a b. { _ : a} -> (Str -> a -> b) -> { _ : b } + UnaryOp::RecordMap() => { // Assuming f has type Str -> a -> b, // this has type DynRecord(a) -> DynRecord(b) @@ -1419,9 +1415,11 @@ pub fn get_uop_type( let b = TypeWrapper::Ptr(new_var(state.table)); let f_type = mk_tyw_arrow!(AbsType::Str(), a.clone(), b.clone()); - type_check_(state, envs.clone(), strict, f, f_type)?; - - mk_tyw_arrow!(mk_typewrapper::dyn_record(a), mk_typewrapper::dyn_record(b)) + mk_tyw_arrow!( + mk_typewrapper::dyn_record(a), + f_type, + mk_typewrapper::dyn_record(b) + ) } // forall a b. a -> b -> b UnaryOp::Seq() | UnaryOp::DeepSeq() => { @@ -2079,7 +2077,7 @@ mod tests { parse_and_typecheck("fun l => %tail% l : List -> List").unwrap(); parse_and_typecheck("fun l => %head% l : List -> Dyn").unwrap(); parse_and_typecheck( - "fun f l => %map% f l : forall a. (forall b. (a -> b) -> List -> List)", + "fun f l => %map% l f : forall a. (forall b. (a -> b) -> List -> List)", ) .unwrap(); parse_and_typecheck("(fun l1 => fun l2 => l1 @ l2) : List -> List -> List").unwrap(); @@ -2087,7 +2085,7 @@ mod tests { parse_and_typecheck("(fun l => %head% l) : forall a. (List -> a)").unwrap_err(); parse_and_typecheck( - "(fun f l => %elemAt% (%map% f l) 0) : forall a. (forall b. (a -> b) -> List -> b)", + "(fun f l => %elemAt% (%map% l f) 0) : forall a. (forall b. (a -> b) -> List -> b)", ) .unwrap_err(); } diff --git a/stdlib/contracts.ncl b/stdlib/contracts.ncl index 3606bbab..bde17055 100644 --- a/stdlib/contracts.ncl +++ b/stdlib/contracts.ncl @@ -38,7 +38,7 @@ dyn_record = fun contr l t => if %isRecord% t then - %recordMap% (fun _field => contr l) t + %recordMap% t (fun _field => contr l) else %blame% (%tag% "not a record" l); diff --git a/stdlib/lists.ncl b/stdlib/lists.ncl index 091c41ad..f2e6ca5e 100644 --- a/stdlib/lists.ncl +++ b/stdlib/lists.ncl @@ -6,7 +6,7 @@ length : List -> Num = fun l => %length% l; - map : (Dyn -> Dyn) -> List -> List = fun f l => %map% f l; + map : (Dyn -> Dyn) -> List -> List = fun f l => %map% l f; elemAt : List -> Num -> Dyn = fun l n => %elemAt% l n; diff --git a/stdlib/records.ncl b/stdlib/records.ncl index 2d35c37b..1826a0fc 100644 --- a/stdlib/records.ncl +++ b/stdlib/records.ncl @@ -1,6 +1,6 @@ { records = { - map : forall a b. (Str -> a -> b) -> {_: a} -> {_: b} = fun f r => %recordMap% f r; + map : forall a b. (Str -> a -> b) -> {_: a} -> {_: b} = fun f r => %recordMap% r f; // TODO: change Dyn to { | Dyn} once the PR introducing open contracts lands fieldsOf : Dyn -> List = fun r => %fieldsOf% r;