From 025d9ab45df10adb160ae6b96638638cb1e571dc Mon Sep 17 00:00:00 2001 From: collin Date: Wed, 15 Apr 2020 13:47:53 -0700 Subject: [PATCH] constraints array access --- simple.program | 13 +----- src/aleo_program/constraints.rs | 71 +++++++++++++++++++++++++++- src/aleo_program/types.rs | 11 ++++- src/aleo_program/types_display.rs | 21 ++++++++- src/aleo_program/types_from.rs | 78 +++++++++++++++++++++++++++---- src/ast.rs | 16 +++---- src/language.pest | 2 +- 7 files changed, 179 insertions(+), 33 deletions(-) diff --git a/simple.program b/simple.program index b30af8c1d9..c099914c19 100644 --- a/simple.program +++ b/simple.program @@ -1,14 +1,5 @@ -struct Point { - field x - field y -} +bool[2] a = [true, false] -Point p = Point {x: 1, y: 0} - -//bool[2] a = [true, false] -//bool[2] b = [true; 2] - -//field[4] c = [1, 2, 3, 4] -//field[3] d = [1; 3] +p = a[0..2] return p \ No newline at end of file diff --git a/src/aleo_program/constraints.rs b/src/aleo_program/constraints.rs index 4826a47573..3979060e75 100644 --- a/src/aleo_program/constraints.rs +++ b/src/aleo_program/constraints.rs @@ -1,6 +1,7 @@ use crate::aleo_program::{ BooleanExpression, BooleanSpreadOrExpression, Expression, FieldExpression, - FieldSpreadOrExpression, Function, Program, Statement, Struct, StructMember, Type, Variable, + FieldRangeOrExpression, FieldSpreadOrExpression, Function, Program, Statement, Struct, + StructMember, Type, Variable, }; use snarkos_models::curves::{Field, PrimeField}; @@ -463,6 +464,70 @@ impl ResolvedProgram { } } + fn enforce_index>( + &mut self, + cs: &mut CS, + index: FieldExpression, + ) -> usize { + match self.enforce_field_expression(cs, index) { + ResolvedValue::FieldElement(number) => number.value.unwrap() as usize, + value => unimplemented!("From index must resolve to a uint32, got {}", value), + } + } + + fn enforce_array_access_expression>( + &mut self, + cs: &mut CS, + array: Box, + index: FieldRangeOrExpression, + ) -> ResolvedValue { + match self.enforce_expression(cs, *array) { + ResolvedValue::FieldElementArray(field_array) => { + match index { + FieldRangeOrExpression::Range(from, to) => { + let from_resolved = match from { + Some(from_index) => self.enforce_index(cs, from_index), + None => 0usize, // Array slice starts at index 0 + }; + let to_resolved = match to { + Some(to_index) => self.enforce_index(cs, to_index), + None => field_array.len(), // Array slice ends at array length + }; + ResolvedValue::FieldElementArray( + field_array[from_resolved..to_resolved].to_owned(), + ) + } + FieldRangeOrExpression::FieldExpression(index) => { + let index_resolved = self.enforce_index(cs, index); + ResolvedValue::FieldElement(field_array[index_resolved].to_owned()) + } + } + } + ResolvedValue::BooleanArray(bool_array) => { + match index { + FieldRangeOrExpression::Range(from, to) => { + let from_resolved = match from { + Some(from_index) => self.enforce_index(cs, from_index), + None => 0usize, // Array slice starts at index 0 + }; + let to_resolved = match to { + Some(to_index) => self.enforce_index(cs, to_index), + None => bool_array.len(), // Array slice ends at array length + }; + ResolvedValue::BooleanArray( + bool_array[from_resolved..to_resolved].to_owned(), + ) + } + FieldRangeOrExpression::FieldExpression(index) => { + let index_resolved = self.enforce_index(cs, index); + ResolvedValue::Boolean(bool_array[index_resolved].to_owned()) + } + } + } + value => unimplemented!("Cannot access element of untyped array"), + } + } + fn enforce_expression>( &mut self, cs: &mut CS, @@ -499,6 +564,9 @@ impl ResolvedProgram { Expression::Struct(struct_name, members) => { self.enforce_struct_expression(cs, struct_name, members) } + Expression::ArrayAccess(array, index) => { + self.enforce_array_access_expression(cs, array, index) + } // _ => unimplemented!("expression not enforced yet") } } @@ -596,6 +664,7 @@ impl ResolvedProgram { Expression::Struct(_v, _m) => { unimplemented!("return struct not impl"); } + _ => unimplemented!("expression can't be returned yet"), }); } }; diff --git a/src/aleo_program/types.rs b/src/aleo_program/types.rs index 81fc40b58d..b319fccb5e 100644 --- a/src/aleo_program/types.rs +++ b/src/aleo_program/types.rs @@ -21,6 +21,13 @@ pub enum FieldSpreadOrExpression { FieldExpression(FieldExpression), } +/// Range or field expression enum +#[derive(Debug, Clone)] +pub enum FieldRangeOrExpression { + Range(Option, Option), + FieldExpression(FieldExpression), +} + /// Expression that evaluates to a field value #[derive(Debug, Clone)] pub enum FieldExpression { @@ -85,7 +92,9 @@ pub enum Expression { Boolean(BooleanExpression), FieldElement(FieldExpression), Variable(Variable), + ArrayAccess(Box, FieldRangeOrExpression), Struct(Variable, Vec), + // StructMemberAccess(Variable, Variable)// (struct name, struct member name) } /// Program statement that defines some action (or expression) to be carried out. @@ -98,8 +107,8 @@ pub enum Statement { #[derive(Clone, Debug)] pub enum Type { - FieldElement, Boolean, + FieldElement, Array(Box, usize), Struct(Variable), } diff --git a/src/aleo_program/types_display.rs b/src/aleo_program/types_display.rs index 36eb76aa30..41b039aeb6 100644 --- a/src/aleo_program/types_display.rs +++ b/src/aleo_program/types_display.rs @@ -6,7 +6,8 @@ use crate::aleo_program::{ BooleanExpression, BooleanSpread, BooleanSpreadOrExpression, Expression, FieldExpression, - FieldSpread, FieldSpreadOrExpression, Statement, Struct, StructField, Type, Variable, + FieldRangeOrExpression, FieldSpread, FieldSpreadOrExpression, Statement, Struct, StructField, + Type, Variable, }; use std::fmt; @@ -108,6 +109,22 @@ impl<'ast> fmt::Display for BooleanExpression { } } +impl<'ast> fmt::Display for FieldRangeOrExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + FieldRangeOrExpression::Range(ref from, ref to) => write!( + f, + "{}..{}", + from.as_ref() + .map(|e| e.to_string()) + .unwrap_or("".to_string()), + to.as_ref().map(|e| e.to_string()).unwrap_or("".to_string()) + ), + FieldRangeOrExpression::FieldExpression(ref e) => write!(f, "{}", e), + } + } +} + impl<'ast> fmt::Display for Expression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -124,6 +141,8 @@ impl<'ast> fmt::Display for Expression { } write!(f, "}}") } + Expression::ArrayAccess(ref array, ref index) => write!(f, "{}[{}]", array, index), + _ => unimplemented!("can't display expression yet"), } } } diff --git a/src/aleo_program/types_from.rs b/src/aleo_program/types_from.rs index 6c7a52dd8e..50c0a5df95 100644 --- a/src/aleo_program/types_from.rs +++ b/src/aleo_program/types_from.rs @@ -76,11 +76,7 @@ impl<'ast> From> for types::BooleanExpression { match types::Expression::from(expression) { types::Expression::Boolean(boolean_expression) => boolean_expression, types::Expression::Variable(variable) => types::BooleanExpression::Variable(variable), - types::Expression::FieldElement(field_expression) => unimplemented!( - "cannot compare field expression {} in boolean expression", - field_expression - ), - types::Expression::Struct(_v, _m) => unimplemented!("no inline struct yet"), + _ => unimplemented!("expected boolean in boolean expression"), } } } @@ -90,11 +86,7 @@ impl<'ast> From> for types::FieldExpression { match types::Expression::from(expression) { types::Expression::FieldElement(field_expression) => field_expression, types::Expression::Variable(variable) => types::FieldExpression::Variable(variable), - types::Expression::Boolean(boolean_expression) => unimplemented!( - "cannot compare boolean expression {} in field expression", - boolean_expression - ), - types::Expression::Struct(_v, _m) => unimplemented!("no inline struct yet"), + _ => unimplemented!("expected field in field expression"), } } } @@ -281,6 +273,71 @@ impl<'ast> From> for types::Expression { } } +impl<'ast> From> for types::FieldRangeOrExpression { + fn from(range_or_expression: ast::RangeOrExpression<'ast>) -> Self { + match range_or_expression { + ast::RangeOrExpression::Range(range) => { + let from = range + .from + .map(|from| match types::Expression::from(from.0) { + types::Expression::FieldElement(field) => field, + expression => { + unimplemented!("Range bounds should be numbers, found {}", expression) + } + }); + let to = range.to.map(|to| match types::Expression::from(to.0) { + types::Expression::FieldElement(field) => field, + expression => { + unimplemented!("Range bounds should be numbers, found {}", expression) + } + }); + + types::FieldRangeOrExpression::Range(from, to) + } + ast::RangeOrExpression::Expression(expression) => { + match types::Expression::from(expression) { + types::Expression::FieldElement(field_expression) => { + types::FieldRangeOrExpression::FieldExpression(field_expression) + } + // types::Expression::ArrayAccess(expression, field), // recursive array access + expression => unimplemented!("expression must be field, found {}", expression), + } + } + } + } +} + +impl<'ast> From> for types::Expression { + fn from(expression: ast::PostfixExpression<'ast>) -> Self { + let variable = types::Expression::Variable(types::Variable::from(expression.variable)); + + // ast::PostFixExpression contains an array of "accesses": `a(34)[42]` is represented as `[a, [Call(34), Select(42)]]`, but Access call expressions + // are recursive, so it is `Select(Call(a, 34), 42)`. We apply this transformation here + + // we start with the id, and we fold the array of accesses by wrapping the current value + expression + .accesses + .into_iter() + .fold(variable, |acc, access| match access { + ast::Access::Call(a) => match acc { + types::Expression::Variable(_) => { + unimplemented!("function calls not implemented") + } + expression => { + unimplemented!("only function names are callable, found \"{}\"", expression) + } + }, + ast::Access::Member(struct_member) => { + unimplemented!("struct calls not implemented") + } + ast::Access::Select(array) => types::Expression::ArrayAccess( + Box::new(acc), + types::FieldRangeOrExpression::from(array.expression), + ), + }) + } +} + impl<'ast> From> for types::Expression { fn from(expression: ast::Expression<'ast>) -> Self { match expression { @@ -295,6 +352,7 @@ impl<'ast> From> for types::Expression { ast::Expression::ArrayInitializer(_expression) => { unimplemented!("unknown type for array initializer expression") } + ast::Expression::Postfix(expression) => types::Expression::from(expression), _ => unimplemented!(), } } diff --git a/src/ast.rs b/src/ast.rs index e2b358c1fc..55dd0b354c 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -392,13 +392,13 @@ pub enum RangeOrExpression<'ast> { Expression(Expression<'ast>), } -// #[derive(Clone, Debug, FromPest, PartialEq)] -// #[pest_ast(rule(Rule::call_access))] -// pub struct CallAccess<'ast> { -// pub expressions: Vec>, -// #[pest_ast(outer())] -// pub span: Span<'ast>, -// } +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::access_call))] +pub struct CallAccess<'ast> { + pub expressions: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::access_array))] @@ -419,7 +419,7 @@ pub struct MemberAccess<'ast> { #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::access))] pub enum Access<'ast> { - // Call(CallAccess<'ast>), + Call(CallAccess<'ast>), Select(ArrayAccess<'ast>), Member(MemberAccess<'ast>), } diff --git a/src/language.pest b/src/language.pest index a35589be72..6730f3d4fa 100644 --- a/src/language.pest +++ b/src/language.pest @@ -86,7 +86,7 @@ access_array = { "[" ~ range_or_expression ~ "]" } access_call = { "(" ~ expression_tuple ~ ")" } access_member = { "." ~ variable } access = { access_array | access_call | access_member } -expression_postfix = { variable ~ access+ } // add ++ and -- operators +expression_postfix = { variable ~ access+ } spread = { "..." ~ expression } spread_or_expression = { spread | expression }