fix return type namespaces

This commit is contained in:
collin 2020-06-23 16:58:22 -07:00
parent c2862c7a0c
commit 2291cc51fd
6 changed files with 95 additions and 81 deletions

View File

@ -220,15 +220,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
right: ConstrainedValue<F, G>, right: ConstrainedValue<F, G>,
span: Span, span: Span,
) -> Result<ConstrainedValue<F, G>, ExpressionError> { ) -> Result<ConstrainedValue<F, G>, ExpressionError> {
let mut unique_namespace = cs.ns(|| { let mut unique_namespace = cs.ns(|| format!("evaluate {} == {} {}:{}", left, right, span.line, span.start));
format!(
"evaluate {} == {} {}:{}",
left.to_string(),
right.to_string(),
span.line,
span.start
)
});
let constraint_result = match (left, right) { let constraint_result = match (left, right) {
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => {
bool_1.evaluate_equal(unique_namespace, &bool_2) bool_1.evaluate_equal(unique_namespace, &bool_2)
@ -425,10 +417,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
let unique_namespace = cs.ns(|| { let unique_namespace = cs.ns(|| {
format!( format!(
"select {} or {} {}:{}", "select {} or {} {}:{}",
resolved_second.to_string(), resolved_second, resolved_third, span.line, span.start
resolved_third.to_string(),
span.line,
span.start
) )
}); });

View File

@ -21,6 +21,7 @@ use leo_types::{
Variable, Variable,
}; };
use crate::errors::ValueError;
use snarkos_models::{ use snarkos_models::{
curves::{Field, PrimeField}, curves::{Field, PrimeField},
gadgets::{ gadgets::{
@ -317,6 +318,27 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
Ok(()) Ok(())
} }
fn check_return_types(expected: &Vec<Type>, actual: &Vec<Type>, span: Span) -> Result<(), StatementError> {
expected
.iter()
.zip(actual.iter())
.map(|(type_1, type_2)| {
if type_1.ne(type_2) {
// catch return Self type
if type_1.is_self() && type_2.is_circuit() {
Ok(())
} else {
Err(StatementError::arguments_type(type_1, type_2, span.clone()))
}
} else {
Ok(())
}
})
.collect::<Result<Vec<()>, StatementError>>()?;
Ok(())
}
fn enforce_return_statement<CS: ConstraintSystem<F>>( fn enforce_return_statement<CS: ConstraintSystem<F>>(
&mut self, &mut self,
cs: &mut CS, cs: &mut CS,
@ -336,7 +358,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
} }
let mut returns = vec![]; let mut returns = vec![];
for (expression, ty) in expressions.into_iter().zip(return_types.into_iter()) { for (expression, ty) in expressions.into_iter().zip(return_types.clone().into_iter()) {
let expected_types = vec![ty.clone()]; let expected_types = vec![ty.clone()];
let result = self.enforce_expression_value( let result = self.enforce_expression_value(
cs, cs,
@ -350,6 +372,13 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
returns.push(result); returns.push(result);
} }
let actual_types = returns
.iter()
.map(|value| value.to_type(span.clone()))
.collect::<Result<Vec<Type>, ValueError>>()?;
Self::check_return_types(&return_types, &actual_types, span)?;
Ok(ConstrainedValue::Return(returns)) Ok(ConstrainedValue::Return(returns))
} }

View File

@ -72,6 +72,20 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedValue<F, G> {
ConstrainedValue::Field(_field) => Type::Field, ConstrainedValue::Field(_field) => Type::Field,
ConstrainedValue::Group(_group) => Type::Group, ConstrainedValue::Group(_group) => Type::Group,
ConstrainedValue::Boolean(_bool) => Type::Boolean, ConstrainedValue::Boolean(_bool) => Type::Boolean,
ConstrainedValue::Array(types) => {
let array_type = types[0].to_type(span.clone())?;
let count = types.len();
// nested array type
if let Type::Array(inner_type, inner_dimensions) = &array_type {
let mut dimensions = inner_dimensions.clone();
dimensions.push(count);
return Ok(Type::Array(inner_type.clone(), dimensions));
}
Type::Array(Box::new(array_type), vec![count])
}
ConstrainedValue::CircuitExpression(id, _members) => Type::Circuit(id.clone()),
value => return Err(ValueError::implicit(value.to_string(), span)), value => return Err(ValueError::implicit(value.to_string(), span)),
}) })
} }
@ -256,33 +270,21 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConditionalEqGadget<F> for Constrai
condition: &Boolean, condition: &Boolean,
) -> Result<(), SynthesisError> { ) -> Result<(), SynthesisError> {
match (self, other) { match (self, other) {
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal( (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => {
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), bool_1.conditional_enforce_equal(cs, bool_2, &condition)
bool_2, }
&condition, (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => {
), num_1.conditional_enforce_equal(cs, num_2, &condition)
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal( }
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => {
num_2, field_1.conditional_enforce_equal(cs, field_2, &condition)
&condition, }
), (ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => {
(ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => field_1.conditional_enforce_equal( group_1.conditional_enforce_equal(cs, group_2, &condition)
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), }
field_2,
&condition,
),
(ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => group_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
group_2,
&condition,
),
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
for (i, (left, right)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() { for (i, (left, right)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
left.conditional_enforce_equal( left.conditional_enforce_equal(cs.ns(|| format!("array[{}]", i)), right, &condition)?;
cs.ns(|| format!("array[{}] equal {} == {}", i, left.to_string(), right.to_string())),
right,
&condition,
)?;
} }
Ok(()) Ok(())
} }
@ -304,49 +306,22 @@ impl<F: Field + PrimeField, G: GroupType<F>> CondSelectGadget<F> for Constrained
) -> Result<Self, SynthesisError> { ) -> Result<Self, SynthesisError> {
Ok(match (first, second) { Ok(match (first, second) {
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => {
ConstrainedValue::Boolean(Boolean::conditionally_select( ConstrainedValue::Boolean(Boolean::conditionally_select(cs, cond, bool_1, bool_2)?)
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
bool_1,
bool_2,
)?)
} }
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => {
ConstrainedValue::Integer(Integer::conditionally_select( ConstrainedValue::Integer(Integer::conditionally_select(cs, cond, num_1, num_2)?)
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
num_1,
num_2,
)?)
} }
(ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => {
ConstrainedValue::Field(FieldType::conditionally_select( ConstrainedValue::Field(FieldType::conditionally_select(cs, cond, field_1, field_2)?)
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
field_1,
field_2,
)?)
} }
(ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => { (ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => {
ConstrainedValue::Group(G::conditionally_select( ConstrainedValue::Group(G::conditionally_select(cs, cond, group_1, group_2)?)
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
group_1,
group_2,
)?)
} }
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
let mut array = vec![]; let mut array = vec![];
for (i, (first, second)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() { for (i, (first, second)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
array.push(Self::conditionally_select( array.push(Self::conditionally_select(
cs.ns(|| { cs.ns(|| format!("array[{}]", i,)),
format!(
"array[{}] = if cond ? {} else {}",
i,
first.to_string(),
second.to_string()
)
}),
cond, cond,
first, first,
second, second,

View File

@ -1,5 +1,5 @@
use crate::errors::{BooleanError, ExpressionError, ValueError}; use crate::errors::{BooleanError, ExpressionError, ValueError};
use leo_types::{Error as FormattedError, IntegerError, Span}; use leo_types::{Error as FormattedError, IntegerError, Span, Type};
use std::path::PathBuf; use std::path::PathBuf;
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -35,6 +35,12 @@ impl StatementError {
StatementError::Error(FormattedError::new_from_span(message, span)) StatementError::Error(FormattedError::new_from_span(message, span))
} }
pub fn arguments_type(expected: &Type, actual: &Type, span: Span) -> Self {
let message = format!("expected return argument type `{}`, found type `{}`", expected, actual);
Self::new_from_span(message, span)
}
pub fn array_assign_index(span: Span) -> Self { pub fn array_assign_index(span: Span) -> Self {
let message = format!("Cannot assign single index to array of values"); let message = format!("Cannot assign single index to array of values");

View File

@ -25,7 +25,6 @@ fn test_undefined() {
CompilerError::FunctionError(FunctionError::StatementError(StatementError::ExpressionError( CompilerError::FunctionError(FunctionError::StatementError(StatementError::ExpressionError(
ExpressionError::Error(error), ExpressionError::Error(error),
))) => { ))) => {
println!("{}", error);
assert_eq!( assert_eq!(
format!("{}", error), format!("{}", error),
vec![ vec![

View File

@ -15,6 +15,22 @@ pub enum Type {
SelfType, SelfType,
} }
impl Type {
pub fn is_self(&self) -> bool {
if let Type::SelfType = self {
return true;
}
false
}
pub fn is_circuit(&self) -> bool {
if let Type::Circuit(_) = self {
return true;
}
false
}
}
/// pest ast -> Explicit Type for defining circuit members and function params /// pest ast -> Explicit Type for defining circuit members and function params
impl From<DataType> for Type { impl From<DataType> for Type {
@ -60,29 +76,29 @@ impl<'ast> From<AstType<'ast>> for Type {
impl Type { impl Type {
pub fn outer_dimension(&self, dimensions: &Vec<usize>) -> Self { pub fn outer_dimension(&self, dimensions: &Vec<usize>) -> Self {
let _type = self.clone(); let type_ = self.clone();
if dimensions.len() > 1 { if dimensions.len() > 1 {
let mut next = vec![]; let mut next = vec![];
next.extend_from_slice(&dimensions[1..]); next.extend_from_slice(&dimensions[1..]);
return Type::Array(Box::new(_type), next); return Type::Array(Box::new(type_), next);
} }
_type type_
} }
pub fn inner_dimension(&self, dimensions: &Vec<usize>) -> Self { pub fn inner_dimension(&self, dimensions: &Vec<usize>) -> Self {
let _type = self.clone(); let type_ = self.clone();
if dimensions.len() > 1 { if dimensions.len() > 1 {
let mut next = vec![]; let mut next = vec![];
next.extend_from_slice(&dimensions[..dimensions.len() - 1]); next.extend_from_slice(&dimensions[..dimensions.len() - 1]);
return Type::Array(Box::new(_type), next); return Type::Array(Box::new(type_), next);
} }
_type type_
} }
} }
@ -93,8 +109,8 @@ impl fmt::Display for Type {
Type::Field => write!(f, "field"), Type::Field => write!(f, "field"),
Type::Group => write!(f, "group"), Type::Group => write!(f, "group"),
Type::Boolean => write!(f, "bool"), Type::Boolean => write!(f, "bool"),
Type::Circuit(ref variable) => write!(f, "{}", variable), Type::Circuit(ref variable) => write!(f, "circuit {}", variable),
Type::SelfType => write!(f, "Self"), Type::SelfType => write!(f, "SelfType"),
Type::Array(ref array, ref dimensions) => { Type::Array(ref array, ref dimensions) => {
write!(f, "{}", *array)?; write!(f, "{}", *array)?;
for row in dimensions { for row in dimensions {