From 74248f936ba8fc4be474941fc2a49adef5fe747f Mon Sep 17 00:00:00 2001 From: collin Date: Mon, 10 Aug 2020 22:02:03 -0700 Subject: [PATCH] impl tuples for inputs --- compiler/src/errors/function.rs | 6 ++ .../src/function/input/main_function_input.rs | 3 +- compiler/src/function/input/mod.rs | 3 + compiler/src/function/input/tuple.rs | 56 +++++++++++++++++++ input/src/errors/parser.rs | 9 +++ input/src/expressions/expression.rs | 41 +++++++------- input/src/leo-input.pest | 13 +++-- input/src/types/mod.rs | 3 + input/src/types/tuple_type.rs | 25 +++++++++ input/src/types/type_.rs | 4 +- typed/src/input/input_value.rs | 45 +++++++++++---- typed/src/types/type_.rs | 17 +++++- 12 files changed, 185 insertions(+), 40 deletions(-) create mode 100644 compiler/src/function/input/tuple.rs create mode 100644 input/src/types/tuple_type.rs diff --git a/compiler/src/errors/function.rs b/compiler/src/errors/function.rs index ca267b7c11..a4d26a87cc 100644 --- a/compiler/src/errors/function.rs +++ b/compiler/src/errors/function.rs @@ -78,6 +78,12 @@ impl FunctionError { Self::new_from_span(message, span) } + pub fn invalid_tuple(actual: String, span: Span) -> Self { + let message = format!("Expected function input tuple, found `{}`", actual); + + Self::new_from_span(message, span) + } + pub fn return_arguments_length(expected: usize, actual: usize, span: Span) -> Self { let message = format!("function expected {} returns, found {} returns", expected, actual); diff --git a/compiler/src/function/input/main_function_input.rs b/compiler/src/function/input/main_function_input.rs index 0b24f707b1..1d27ee33a6 100644 --- a/compiler/src/function/input/main_function_input.rs +++ b/compiler/src/function/input/main_function_input.rs @@ -42,7 +42,8 @@ impl> ConstrainedProgram { input_option, span, )?)), - Type::Array(_type, dimensions) => self.allocate_array(cs, name, *_type, dimensions, input_option, span), + Type::Array(type_, dimensions) => self.allocate_array(cs, name, *type_, dimensions, input_option, span), + Type::Tuple(types) => self.allocate_tuple(cs, name, types, input_option, span), _ => unimplemented!("main function input not implemented for type"), } } diff --git a/compiler/src/function/input/mod.rs b/compiler/src/function/input/mod.rs index 0db12df19f..34d417ce4f 100644 --- a/compiler/src/function/input/mod.rs +++ b/compiler/src/function/input/mod.rs @@ -14,3 +14,6 @@ pub use self::input_keyword::*; pub mod input_section; pub use self::input_section::*; + +pub mod tuple; +pub use self::tuple::*; diff --git a/compiler/src/function/input/tuple.rs b/compiler/src/function/input/tuple.rs new file mode 100644 index 0000000000..55a0eba022 --- /dev/null +++ b/compiler/src/function/input/tuple.rs @@ -0,0 +1,56 @@ +//! Allocates an array as a main function input parameter in a compiled Leo program. + +use crate::{ + errors::FunctionError, + program::{new_scope, ConstrainedProgram}, + value::ConstrainedValue, + GroupType, +}; + +use leo_typed::{InputValue, Span, Type}; + +use snarkos_models::{ + curves::{Field, PrimeField}, + gadgets::r1cs::ConstraintSystem, +}; + +impl> ConstrainedProgram { + pub fn allocate_tuple>( + &mut self, + cs: &mut CS, + name: String, + types: Vec, + input_value: Option, + span: Span, + ) -> Result, FunctionError> { + let mut tuple_values = vec![]; + + match input_value { + Some(InputValue::Tuple(values)) => { + // Allocate each value in the tuple + for (i, (value, type_)) in values.into_iter().zip(types.into_iter()).enumerate() { + let value_name = new_scope(name.clone(), i.to_string()); + + tuple_values.push(self.allocate_main_function_input( + cs, + type_, + value_name, + Some(value), + span.clone(), + )?) + } + } + None => { + // Allocate all tuple values as none + for (i, type_) in types.into_iter().enumerate() { + let value_name = new_scope(name.clone(), i.to_string()); + + tuple_values.push(self.allocate_main_function_input(cs, type_, value_name, None, span.clone())?); + } + } + _ => return Err(FunctionError::invalid_tuple(input_value.unwrap().to_string(), span)), + } + + Ok(ConstrainedValue::Tuple(tuple_values)) + } +} diff --git a/input/src/errors/parser.rs b/input/src/errors/parser.rs index 9714c27440..89330d28d8 100644 --- a/input/src/errors/parser.rs +++ b/input/src/errors/parser.rs @@ -114,6 +114,15 @@ impl InputParserError { Self::new_from_span(message, table.span) } + pub fn tuple_length(expected: usize, actual: usize, span: Span) -> Self { + let message = format!( + "expected a tuple with {} elements, found a tuple with {} elements", + expected, actual + ); + + Self::new_from_span(message, span) + } + pub fn section(header: Header) -> Self { let message = format!( "the section header `{}` must have a double bracket visibility in a state `.state` file", diff --git a/input/src/expressions/expression.rs b/input/src/expressions/expression.rs index e990424b68..31c8ab14f5 100644 --- a/input/src/expressions/expression.rs +++ b/input/src/expressions/expression.rs @@ -1,8 +1,4 @@ -use crate::{ - ast::Rule, - expressions::*, - values::{Address, Value}, -}; +use crate::{ast::Rule, expressions::*, values::Value}; use pest::Span; use pest_ast::FromPest; @@ -11,19 +7,19 @@ use std::fmt; #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::expression))] pub enum Expression<'ast> { - ArrayInline(ArrayInlineExpression<'ast>), ArrayInitializer(ArrayInitializerExpression<'ast>), + ArrayInline(ArrayInlineExpression<'ast>), + Tuple(Vec>), Value(Value<'ast>), - ImplicitAddress(Address<'ast>), } impl<'ast> Expression<'ast> { pub fn span(&self) -> &Span { match self { - Expression::ArrayInline(expression) => &expression.span, Expression::ArrayInitializer(expression) => &expression.span, + Expression::ArrayInline(expression) => &expression.span, + Expression::Tuple(tuple) => tuple[0].span(), Expression::Value(value) => value.span(), - Expression::ImplicitAddress(address) => &address.span, } } } @@ -31,20 +27,25 @@ impl<'ast> Expression<'ast> { impl<'ast> fmt::Display for Expression<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Expression::ImplicitAddress(ref address) => write!(f, "{}", address), - Expression::Value(ref expression) => write!(f, "{}", expression), - Expression::ArrayInline(ref expression) => { - for (i, value) in expression.expressions.iter().enumerate() { - write!(f, "array [{}", value)?; - if i < expression.expressions.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") - } Expression::ArrayInitializer(ref expression) => { write!(f, "array [{} ; {}]", expression.expression, expression.count) } + Expression::ArrayInline(ref array) => { + let values = array + .expressions + .iter() + .map(|x| format!("{}", x)) + .collect::>() + .join(", "); + + write!(f, "array [{}]", values) + } + Expression::Tuple(ref tuple) => { + let values = tuple.iter().map(|x| format!("{}", x)).collect::>().join(", "); + + write!(f, "({})", values) + } + Expression::Value(ref expression) => write!(f, "{}", expression), } } } diff --git a/input/src/leo-input.pest b/input/src/leo-input.pest index bc487e98a5..fec003e652 100644 --- a/input/src/leo-input.pest +++ b/input/src/leo-input.pest @@ -33,7 +33,7 @@ LINE_END = { ";" ~ NEWLINE* } /// Types // Declared in types/type_.rs -type_ = { type_array | type_data } +type_ = { type_tuple | type_array | type_data } // Declared in types/integer_type.rs type_integer = { @@ -89,6 +89,8 @@ type_data = { type_field | type_group | type_boolean | type_address | type_integ // Declared in types/array_type.rs type_array = { type_data ~ ("[" ~ number_positive ~ "]")+ } +type_tuple = { "(" ~ type_ ~ ("," ~ (type_tuple | type_))+ ~ ")" } + /// Values // Declared in values/value.rs @@ -148,11 +150,12 @@ inline_array_inner = _{ (expression ~ ("," ~ NEWLINE* ~ expression)*)? } // Declared in expressions/expression.rs expression = { - expression_array_inline + value + | expression_tuple + | expression_array_inline | expression_array_initializer - | value - | address // address conflicts with identifier namespaces so we catch implicit address values as expressions here } +expression_tuple = { "(" ~ expression ~ ("," ~ expression)+ ~")" } /// Parameters @@ -185,7 +188,7 @@ header = { main | record | registers | state_leaf | state | identifier } /// Definitions // Declared in definition/definition.rs -definition = { parameter ~ "=" ~ NEWLINE* ~ expression ~ LINE_END } +definition = { parameter ~ "=" ~ expression ~ LINE_END } /// Table diff --git a/input/src/types/mod.rs b/input/src/types/mod.rs index 364b8f9150..6a247be42b 100644 --- a/input/src/types/mod.rs +++ b/input/src/types/mod.rs @@ -22,6 +22,9 @@ pub use integer_type::*; pub mod signed_integer_type; pub use signed_integer_type::*; +pub mod tuple_type; +pub use tuple_type::*; + pub mod type_; pub use type_::*; diff --git a/input/src/types/tuple_type.rs b/input/src/types/tuple_type.rs new file mode 100644 index 0000000000..6daad6cb1e --- /dev/null +++ b/input/src/types/tuple_type.rs @@ -0,0 +1,25 @@ +use crate::{ast::Rule, types::Type}; + +use pest::Span; +use pest_ast::FromPest; + +#[derive(Clone, Debug, FromPest, PartialEq, Eq)] +#[pest_ast(rule(Rule::type_tuple))] +pub struct TupleType<'ast> { + pub types_: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> std::fmt::Display for TupleType<'ast> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let tuple = self + .types_ + .iter() + .map(|x| format!("{}", x)) + .collect::>() + .join(", "); + + write!(f, "({})", tuple) + } +} diff --git a/input/src/types/type_.rs b/input/src/types/type_.rs index a7a823ecab..d7ea2779e0 100644 --- a/input/src/types/type_.rs +++ b/input/src/types/type_.rs @@ -3,11 +3,12 @@ use crate::{ast::Rule, types::*}; use pest_ast::FromPest; use std::fmt; -#[derive(Clone, Debug, FromPest, PartialEq)] +#[derive(Clone, Debug, FromPest, PartialEq, Eq)] #[pest_ast(rule(Rule::type_))] pub enum Type<'ast> { Basic(DataType), Array(ArrayType<'ast>), + Tuple(TupleType<'ast>), } impl<'ast> fmt::Display for Type<'ast> { @@ -15,6 +16,7 @@ impl<'ast> fmt::Display for Type<'ast> { match *self { Type::Basic(ref basic) => write!(f, "{}", basic), Type::Array(ref array) => write!(f, "{}", array), + Type::Tuple(ref tuple) => write!(f, "{}", tuple), } } } diff --git a/typed/src/input/input_value.rs b/typed/src/input/input_value.rs index 9b51bf953e..fde0861480 100644 --- a/typed/src/input/input_value.rs +++ b/typed/src/input/input_value.rs @@ -5,7 +5,7 @@ use leo_input::{ values::{BooleanValue, FieldValue, GroupValue, NumberValue, Value}, }; -use leo_input::values::Address; +use leo_input::{types::TupleType, values::Address}; use std::fmt; #[derive(Clone, PartialEq, Eq)] @@ -16,6 +16,7 @@ pub enum InputValue { Group(String), Integer(IntegerType, String), Array(Vec), + Tuple(Vec), } impl InputValue { @@ -66,9 +67,6 @@ impl InputValue { pub(crate) fn from_expression(type_: Type, expression: Expression) -> Result { match (type_, expression) { - (Type::Basic(DataType::Address(_)), Expression::ImplicitAddress(address)) => { - Ok(InputValue::from_address(address)) - } (Type::Basic(data_type), Expression::Value(value)) => InputValue::from_value(data_type, value), (Type::Array(array_type), Expression::ArrayInline(inline)) => { InputValue::from_array_inline(array_type, inline) @@ -76,6 +74,7 @@ impl InputValue { (Type::Array(array_type), Expression::ArrayInitializer(initializer)) => { InputValue::from_array_initializer(array_type, initializer) } + (Type::Tuple(tuple_type), Expression::Tuple(tuple)) => InputValue::from_tuple(tuple_type, tuple), (type_, expression) => Err(InputParserError::expression_type_mismatch(type_, expression)), } } @@ -139,6 +138,28 @@ impl InputValue { Ok(InputValue::Array(values)) } + + pub(crate) fn from_tuple(tuple_type: TupleType, tuple: Vec) -> Result { + let num_types = tuple_type.types_.len(); + let num_values = tuple.len(); + + if num_types != num_values { + return Err(InputParserError::tuple_length( + num_types, + num_values, + tuple_type.span.clone(), + )); + } + + let mut values = vec![]; + for (type_, value) in tuple_type.types_.into_iter().zip(tuple.into_iter()) { + let value = InputValue::from_expression(type_, value)?; + + values.push(value) + } + + Ok(InputValue::Tuple(values)) + } } impl fmt::Display for InputValue { @@ -150,14 +171,14 @@ impl fmt::Display for InputValue { InputValue::Field(ref field) => write!(f, "{}", field), InputValue::Integer(ref type_, ref number) => write!(f, "{}{:?}", number, type_), InputValue::Array(ref array) => { - write!(f, "[")?; - for (i, e) in array.iter().enumerate() { - write!(f, "{}", e)?; - if i < array.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") + let values = array.iter().map(|x| format!("{}", x)).collect::>().join(", "); + + write!(f, "array [{}]", values) + } + InputValue::Tuple(ref tuple) => { + let values = tuple.iter().map(|x| format!("{}", x)).collect::>().join(", "); + + write!(f, "({})", values) } } } diff --git a/typed/src/types/type_.rs b/typed/src/types/type_.rs index a894b2b560..703b7325ef 100644 --- a/typed/src/types/type_.rs +++ b/typed/src/types/type_.rs @@ -1,6 +1,6 @@ use crate::{Expression, Identifier, IntegerType}; use leo_ast::types::{ArrayType, CircuitType, DataType, Type as AstType}; -use leo_input::types::{ArrayType as InputArrayType, DataType as InputDataType, Type as InputAstType}; +use leo_input::types::{ArrayType as InputArrayType, DataType as InputDataType, TupleType, Type as InputAstType}; use serde::{Deserialize, Serialize}; use std::fmt; @@ -17,6 +17,7 @@ pub enum Type { // Data type wrappers Array(Box, Vec), + Tuple(Vec), Circuit(Identifier), SelfType, } @@ -108,11 +109,20 @@ impl<'ast> From> for Type { } } +impl<'ast> From> for Type { + fn from(tuple_type: TupleType<'ast>) -> Self { + let types = tuple_type.types_.into_iter().map(|type_| Type::from(type_)).collect(); + + Type::Tuple(types) + } +} + impl<'ast> From> for Type { fn from(type_: InputAstType<'ast>) -> Self { match type_ { InputAstType::Basic(type_) => Type::from(type_), InputAstType::Array(type_) => Type::from(type_), + InputAstType::Tuple(type_) => Type::from(type_), } } } @@ -162,6 +172,11 @@ impl fmt::Display for Type { } write!(f, "") } + Type::Tuple(ref tuple) => { + let types = tuple.iter().map(|x| format!("{}", x)).collect::>().join(", "); + + write!(f, "({})", types) + } } } }