syntax for nested arrays. enforce_eq for nested arrays

This commit is contained in:
collin 2020-05-12 13:42:10 -07:00
parent 7422c66d95
commit ab4a9c6058
9 changed files with 87 additions and 49 deletions

View File

@ -1,20 +1,10 @@
// Basic Pedersen hash function
// 1 window x 3 bits
struct PedersenCRHParameters {
bases: fe[3]
}
function main() {
let arr1: u32[2][3] = [[0; 3]; 2];
let arr2: u32[2][3] = [[0, 0, 0], [0, 0, 0]];
function main(bits: private bool[3]) -> fe {
let params = PedersenCRHParameters { bases: [1fe, 1fe, 1fe] };
assert_eq(arr1, arr2);
let res = 0fe;
for bit in 0..3 {
//if (bits[bit]) {
res += params.bases[bit];
//}
}
return res
}

View File

@ -94,11 +94,11 @@ fn main() {
// Set main function arguments in compiled program
// let argument = Some(InputValue::Field(Fr::one()));
let bool_true = InputValue::Boolean(true);
let array = InputValue::Array(vec![bool_true.clone(), bool_true.clone(), bool_true.clone()]);
let argument = Some(array);
program.parameters = vec![argument];
// let bool_true = InputValue::Boolean(true);
// let array = InputValue::Array(vec![bool_true.clone(), bool_true.clone(), bool_true.clone()]);
// let argument = Some(array);
//
// program.parameters = vec![argument];
// Generate proof
let proof = create_random_proof(program, &params, rng).unwrap();

View File

@ -193,7 +193,7 @@ pub enum BasicOrStructType<'ast> {
#[pest_ast(rule(Rule::type_array))]
pub struct ArrayType<'ast> {
pub _type: BasicType<'ast>,
pub count: Value<'ast>,
pub dimensions: Vec<Value<'ast>>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}

View File

@ -99,23 +99,27 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ConstrainedProgram<F, CS> {
fn allocate_array(
&mut self,
cs: &mut CS,
input_model: InputModel<F>,
array_name: Variable<F>,
array_private: bool,
array_type: Type<F>,
array_dimensions: Vec<usize>,
input_value: Option<InputValue<F>>,
expected_length: usize,
) -> Result<ConstrainedValue<F>, FunctionError> {
let expected_length = array_dimensions[0];
let mut array_value = vec![];
let array_input_type = input_model.inner_type()?;
match input_value {
Some(InputValue::Array(arr)) => {
// Check the dimension of the array
Self::check_inputs_length(expected_length, arr.len())?;
// Allocate each value in the current row
for (i, value) in arr.into_iter().enumerate() {
let array_input_model = InputModel {
private: input_model.private.clone(),
_type: array_input_type.clone(),
private: array_private,
_type: array_type.next_dimension(&array_dimensions),
variable: new_variable_from_variables(
&input_model.variable,
&array_name,
&Variable::new(i.to_string()),
),
};
@ -128,13 +132,13 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ConstrainedProgram<F, CS> {
}
}
None => {
// Allocate all parameter values
// Allocate all row values as none
for i in 0..expected_length {
let array_input_model = InputModel {
private: input_model.private.clone(),
_type: array_input_type.clone(),
private: array_private,
_type: array_type.next_dimension(&array_dimensions),
variable: new_variable_from_variables(
&input_model.variable,
&array_name,
&Variable::new(i.to_string()),
),
};
@ -170,9 +174,14 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ConstrainedProgram<F, CS> {
Ok(self.field_element_from_input(cs, input_model, input_value)?)
}
Type::Boolean => Ok(self.bool_from_input(cs, input_model, input_value)?),
Type::Array(ref _type, length) => {
self.allocate_array(cs, input_model, input_value, length)
}
Type::Array(_type, dimensions) => self.allocate_array(
cs,
input_model.variable,
input_model.private,
*_type,
dimensions,
input_value,
),
_ => unimplemented!("main function input not implemented for type"),
}
}

View File

@ -374,11 +374,21 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ConstrainedProgram<F, CS> {
(ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => {
self.enforce_field_eq(cs, fe_1, fe_2)
}
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
for (left, right) in arr_1.into_iter().zip(arr_2.into_iter()) {
self.enforce_assert_eq_statement(cs, left, right)?;
}
}
(val_1, val_2) => {
return Err(StatementError::AssertEq(
val_1.to_string(),
val_2.to_string(),
))
unimplemented!(
"assert eq not supported for given types {} == {}",
val_1,
val_2
)
// return Err(StatementError::AssertEq(
// val_1.to_string(),
// val_2.to_string(),
// ))
}
})
}

View File

@ -35,17 +35,21 @@ impl<F: Field + PrimeField> ConstrainedValue<F> {
}
(ConstrainedValue::FieldElement(ref _f), Type::FieldElement) => {}
(ConstrainedValue::Boolean(ref _b), Type::Boolean) => {}
(ConstrainedValue::Array(ref arr), Type::Array(ref ty, ref len)) => {
(ConstrainedValue::Array(ref arr), Type::Array(ref _type, ref dimensions)) => {
// check array lengths are equal
if arr.len() != *len {
if arr.len() != dimensions[0] {
return Err(ValueError::ArrayLength(format!(
"Expected array {:?} to be length {}",
arr, len
"Expected array {:?} to be length {:?}",
arr, dimensions[0]
)));
}
// get next dimension of array if nested
let next_type = _type.next_dimension(dimensions);
// check each value in array matches
for value in arr {
value.expect_type(ty)?;
value.expect_type(&next_type)?;
}
}
(
@ -59,9 +63,9 @@ impl<F: Field + PrimeField> ConstrainedValue<F> {
)));
}
}
(ConstrainedValue::Return(ref values), ty) => {
(ConstrainedValue::Return(ref values), _type) => {
for value in values {
value.expect_type(ty)?;
value.expect_type(_type)?;
}
}
(value, _type) => {

View File

@ -163,10 +163,25 @@ pub enum Type<F: Field + PrimeField> {
IntegerType(IntegerType),
FieldElement,
Boolean,
Array(Box<Type<F>>, usize),
Array(Box<Type<F>>, Vec<usize>),
Struct(Variable<F>),
}
impl<F: Field + PrimeField> Type<F> {
pub fn next_dimension(&self, dimensions: &Vec<usize>) -> Self {
let _type = self.clone();
if dimensions.len() > 1 {
let mut next = vec![];
next.extend_from_slice(&dimensions[1..]);
return Type::Array(Box::new(_type), next);
}
_type
}
}
#[derive(Clone, PartialEq, Eq)]
pub enum ConditionalNestedOrEnd<F: Field + PrimeField> {
Nested(Box<ConditionalStatement<F>>),
@ -223,7 +238,7 @@ pub struct InputModel<F: Field + PrimeField> {
impl<F: Field + PrimeField> InputModel<F> {
pub fn inner_type(&self) -> Result<Type<F>, ValueError> {
match self._type {
match &self._type {
Type::Array(ref _type, _length) => Ok(*_type.clone()),
ref _type => Err(ValueError::ArrayModel(_type.to_string())),
}

View File

@ -261,7 +261,13 @@ impl<F: Field + PrimeField> fmt::Display for Type<F> {
Type::FieldElement => write!(f, "fe"),
Type::Boolean => write!(f, "bool"),
Type::Struct(ref variable) => write!(f, "{}", variable),
Type::Array(ref array, ref count) => write!(f, "{}[{}]", array, count),
Type::Array(ref array, ref dimensions) => {
write!(f, "{}", *array)?;
for row in dimensions {
write!(f, "[{}]", row)?;
}
write!(f, "")
}
}
}
}

View File

@ -610,9 +610,13 @@ impl<'ast, F: Field + PrimeField> From<ast::BasicType<'ast>> for types::Type<F>
impl<'ast, F: Field + PrimeField> From<ast::ArrayType<'ast>> for types::Type<F> {
fn from(array_type: ast::ArrayType<'ast>) -> Self {
let element_type = Box::new(types::Type::from(array_type._type));
let count = types::Expression::<F>::get_count(array_type.count);
let dimensions = array_type
.dimensions
.into_iter()
.map(|row| types::Expression::<F>::get_count(row))
.collect();
types::Type::Array(element_type, count)
types::Type::Array(element_type, dimensions)
}
}