From 1bbe71d72627c909b948c4798332148e6ba07f54 Mon Sep 17 00:00:00 2001 From: collin Date: Mon, 20 Apr 2020 13:06:47 -0700 Subject: [PATCH] add and constrain explict type for u32 --- simple.program | 7 +- simple_import.program | 7 +- src/aleo_program/constraints.rs | 802 ++++++++++++++++-------------- src/aleo_program/types.rs | 257 ++++++---- src/aleo_program/types_display.rs | 107 ++-- src/aleo_program/types_from.rs | 654 ++++++++++++------------ src/ast.rs | 512 ++++++++++--------- src/language.pest | 13 +- src/main.rs | 2 +- 9 files changed, 1247 insertions(+), 1114 deletions(-) diff --git a/simple.program b/simple.program index 01982b85de..fe2089156d 100644 --- a/simple.program +++ b/simple.program @@ -1,4 +1,5 @@ -from "./simple_import" import foo +from "./simple_import" import Point -def main() -> (field): - return foo() \ No newline at end of file +def main() -> (Point): + Point p = Point { x: 1u32, y: 2u32} + return p \ No newline at end of file diff --git a/simple_import.program b/simple_import.program index 57f9d307c9..667fa6871c 100644 --- a/simple_import.program +++ b/simple_import.program @@ -1,3 +1,4 @@ -def foo() -> (field): - // return myGlobal <- not allowed - return 42 \ No newline at end of file +struct Point { + u32 x + u32 y +} \ No newline at end of file diff --git a/src/aleo_program/constraints.rs b/src/aleo_program/constraints.rs index 5b5e1fbb17..d6a79a66ad 100644 --- a/src/aleo_program/constraints.rs +++ b/src/aleo_program/constraints.rs @@ -1,7 +1,7 @@ use crate::aleo_program::{ - Assignee, BooleanExpression, BooleanSpreadOrExpression, Expression, FieldExpression, - FieldRangeOrExpression, FieldSpreadOrExpression, Function, Import, Program, Statement, Struct, - StructMember, Type, Variable, + Assignee, BooleanExpression, BooleanSpreadOrExpression, Expression, Function, Import, Integer, + IntegerExpression, IntegerRangeOrExpression, IntegerSpreadOrExpression, Program, Statement, + Struct, StructMember, Type, Variable, }; use crate::ast; @@ -15,20 +15,21 @@ use snarkos_models::gadgets::{ use std::collections::HashMap; use std::fmt; use std::fs; +use std::marker::PhantomData; #[derive(Clone)] -pub enum ResolvedValue { +pub enum ResolvedValue { Boolean(Boolean), BooleanArray(Vec), - FieldElement(UInt32), - FieldElementArray(Vec), - StructDefinition(Struct), - StructExpression(Variable, Vec), - Function(Function), - Return(Vec), // add Null for function returns + U32(UInt32), + U32Array(Vec), + StructDefinition(Struct), + StructExpression(Variable, Vec>), + Function(Function), + Return(Vec>), // add Null for function returns } -impl fmt::Display for ResolvedValue { +impl fmt::Display for ResolvedValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { ResolvedValue::Boolean(ref value) => write!(f, "{}", value.get_value().unwrap()), @@ -42,8 +43,8 @@ impl fmt::Display for ResolvedValue { } write!(f, "]") } - ResolvedValue::FieldElement(ref value) => write!(f, "{}", value.value.unwrap()), - ResolvedValue::FieldElementArray(ref array) => { + ResolvedValue::U32(ref value) => write!(f, "{}", value.value.unwrap()), + ResolvedValue::U32Array(ref array) => { write!(f, "[")?; for (i, e) in array.iter().enumerate() { write!(f, "{}", e.value.unwrap())?; @@ -83,59 +84,361 @@ impl fmt::Display for ResolvedValue { } } -pub struct ResolvedProgram { - pub resolved_names: HashMap, +pub struct ResolvedProgram> { + pub resolved_names: HashMap>, + pub _cs: PhantomData, } pub fn new_scope(outer: String, inner: String) -> String { format!("{}_{}", outer, inner) } -pub fn new_scope_from_variable(outer: String, inner: &Variable) -> String { - new_scope(outer, inner.0.clone()) +pub fn new_scope_from_variable( + outer: String, + inner: &Variable, +) -> String { + new_scope(outer, inner.name.clone()) } -impl ResolvedProgram { +impl> ResolvedProgram { fn new() -> Self { Self { resolved_names: HashMap::new(), + _cs: PhantomData::, } } - fn store(&mut self, name: String, value: ResolvedValue) { + fn store(&mut self, name: String, value: ResolvedValue) { self.resolved_names.insert(name, value); } - fn store_variable(&mut self, variable: Variable, value: ResolvedValue) { - self.store(variable.0, value); + fn store_variable(&mut self, variable: Variable, value: ResolvedValue) { + self.store(variable.name, value); } fn contains_name(&self, name: &String) -> bool { self.resolved_names.contains_key(name) } - fn contains_variable(&self, variable: &Variable) -> bool { - self.contains_name(&variable.0) + fn contains_variable(&self, variable: &Variable) -> bool { + self.contains_name(&variable.name) } - fn get(&self, name: &String) -> Option<&ResolvedValue> { + fn get(&self, name: &String) -> Option<&ResolvedValue> { self.resolved_names.get(name) } - fn get_mut(&mut self, name: &String) -> Option<&mut ResolvedValue> { + fn get_mut(&mut self, name: &String) -> Option<&mut ResolvedValue> { self.resolved_names.get_mut(name) } - fn get_mut_variable(&mut self, variable: &Variable) -> Option<&mut ResolvedValue> { - self.get_mut(&variable.0) + fn get_mut_variable(&mut self, variable: &Variable) -> Option<&mut ResolvedValue> { + self.get_mut(&variable.name) } - fn bool_from_variable>( + /// Constrain integers + + fn integer_from_variable( &mut self, cs: &mut CS, scope: String, - variable: Variable, + variable: Variable, + ) -> ResolvedValue { + // Evaluate variable name in current function scope + let variable_name = new_scope_from_variable(scope, &variable); + + if self.contains_name(&variable_name) { + // TODO: return synthesis error: "assignment missing" here + self.get(&variable_name).unwrap().clone() + } else { + // TODO: remove this after resolving arguments + let argument = std::env::args() + .nth(1) + .unwrap_or("1".into()) + .parse::() + .unwrap(); + + println!(" argument passed to command line a = {:?}\n", argument); + + // let a = 1; + ResolvedValue::U32(UInt32::alloc(cs.ns(|| variable.name), Some(argument)).unwrap()) + } + } + + fn get_integer_constant(integer: Integer) -> ResolvedValue { + match integer { + Integer::U32(u32_value) => ResolvedValue::U32(UInt32::constant(u32_value)), + } + } + + fn get_integer_value( + &mut self, + cs: &mut CS, + scope: String, + expression: IntegerExpression, + ) -> ResolvedValue { + match expression { + IntegerExpression::Variable(variable) => { + self.integer_from_variable(cs, scope, variable) + } + IntegerExpression::Number(number) => Self::get_integer_constant(number), + field => self.enforce_integer_expression(cs, scope, field), + } + } + + fn enforce_u32_equality(cs: &mut CS, left: UInt32, right: UInt32) -> Boolean { + left.conditional_enforce_equal( + cs.ns(|| format!("enforce field equal")), + &right, + &Boolean::Constant(true), + ) + .unwrap(); + + Boolean::Constant(true) + } + + fn enforce_integer_equality( + &mut self, + cs: &mut CS, + scope: String, + left: IntegerExpression, + right: IntegerExpression, ) -> Boolean { + let left = self.get_integer_value(cs, scope.clone(), left); + let right = self.get_integer_value(cs, scope.clone(), right); + + match (left, right) { + (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + Self::enforce_u32_equality(cs, left_u32, right_u32) + } + (left_int, right_int) => { + unimplemented!("equality not impl between {} == {}", left_int, right_int) + } + } + } + + fn enforce_u32_add(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + ResolvedValue::U32( + UInt32::addmany( + cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), + &[left, right], + ) + .unwrap(), + ) + } + + fn enforce_add( + &mut self, + cs: &mut CS, + scope: String, + left: IntegerExpression, + right: IntegerExpression, + ) -> ResolvedValue { + let left = self.get_integer_value(cs, scope.clone(), left); + let right = self.get_integer_value(cs, scope.clone(), right); + + match (left, right) { + (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + Self::enforce_u32_add(cs, left_u32, right_u32) + } + (left_int, right_int) => { + unimplemented!("add not impl between {} + {}", left_int, right_int) + } + } + } + + fn enforce_u32_sub(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + ResolvedValue::U32( + left.sub( + cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap(), + ) + } + + fn enforce_sub( + &mut self, + cs: &mut CS, + scope: String, + left: IntegerExpression, + right: IntegerExpression, + ) -> ResolvedValue { + let left = self.get_integer_value(cs, scope.clone(), left); + let right = self.get_integer_value(cs, scope.clone(), right); + + match (left, right) { + (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + Self::enforce_u32_sub(cs, left_u32, right_u32) + } + (left_int, right_int) => { + unimplemented!("add not impl between {} + {}", left_int, right_int) + } + } + } + + fn enforce_u32_mul(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + ResolvedValue::U32( + left.mul( + cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap(), + ) + } + + fn enforce_mul( + &mut self, + cs: &mut CS, + scope: String, + left: IntegerExpression, + right: IntegerExpression, + ) -> ResolvedValue { + let left = self.get_integer_value(cs, scope.clone(), left); + let right = self.get_integer_value(cs, scope.clone(), right); + + match (left, right) { + (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + Self::enforce_u32_mul(cs, left_u32, right_u32) + } + (left_int, right_int) => { + unimplemented!("add not impl between {} + {}", left_int, right_int) + } + } + } + + fn enforce_u32_div(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + ResolvedValue::U32( + left.div( + cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap(), + ) + } + + fn enforce_div( + &mut self, + cs: &mut CS, + scope: String, + left: IntegerExpression, + right: IntegerExpression, + ) -> ResolvedValue { + let left = self.get_integer_value(cs, scope.clone(), left); + let right = self.get_integer_value(cs, scope.clone(), right); + + match (left, right) { + (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + Self::enforce_u32_div(cs, left_u32, right_u32) + } + (left_int, right_int) => { + unimplemented!("add not impl between {} + {}", left_int, right_int) + } + } + } + + fn enforce_u32_pow(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + ResolvedValue::U32( + left.pow( + cs.ns(|| { + format!( + "enforce {} ** {}", + left.value.unwrap(), + right.value.unwrap() + ) + }), + &right, + ) + .unwrap(), + ) + } + + fn enforce_pow( + &mut self, + cs: &mut CS, + scope: String, + left: IntegerExpression, + right: IntegerExpression, + ) -> ResolvedValue { + let left = self.get_integer_value(cs, scope.clone(), left); + let right = self.get_integer_value(cs, scope.clone(), right); + + match (left, right) { + (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + Self::enforce_u32_pow(cs, left_u32, right_u32) + } + (left_int, right_int) => { + unimplemented!("add not impl between {} + {}", left_int, right_int) + } + } + } + + fn enforce_integer_expression( + &mut self, + cs: &mut CS, + scope: String, + expression: IntegerExpression, + ) -> ResolvedValue { + match expression { + IntegerExpression::Variable(variable) => { + self.integer_from_variable(cs, scope, variable) + } + IntegerExpression::Number(number) => Self::get_integer_constant(number), + IntegerExpression::Add(left, right) => self.enforce_add(cs, scope, *left, *right), + IntegerExpression::Sub(left, right) => self.enforce_sub(cs, scope, *left, *right), + IntegerExpression::Mul(left, right) => self.enforce_mul(cs, scope, *left, *right), + IntegerExpression::Div(left, right) => self.enforce_div(cs, scope, *left, *right), + IntegerExpression::Pow(left, right) => self.enforce_pow(cs, scope, *left, *right), + IntegerExpression::IfElse(first, second, third) => { + let resolved_first = + match self.enforce_boolean_expression(cs, scope.clone(), *first) { + ResolvedValue::Boolean(resolved) => resolved, + _ => unimplemented!("if else conditional must resolve to boolean"), + }; + + if resolved_first.eq(&Boolean::Constant(true)) { + self.enforce_integer_expression(cs, scope, *second) + } else { + self.enforce_integer_expression(cs, scope, *third) + } + } + IntegerExpression::Array(array) => { + let mut result = vec![]; + array.into_iter().for_each(|element| match *element { + IntegerSpreadOrExpression::Spread(spread) => match spread { + IntegerExpression::Variable(variable) => { + let array_name = new_scope_from_variable(scope.clone(), &variable); + match self.get(&array_name) { + Some(value) => match value { + ResolvedValue::U32Array(array) => result.extend(array.clone()), + value => unimplemented!( + "spreads only implemented for arrays, got {}", + value + ), + }, + None => unimplemented!( + "cannot copy elements from array that does not exist {}", + variable.name + ), + } + } + value => { + unimplemented!("spreads only implemented for arrays, got {}", value) + } + }, + IntegerSpreadOrExpression::Expression(expression) => { + match self.enforce_integer_expression(cs, scope.clone(), expression) { + ResolvedValue::U32(value) => result.push(value), + _ => unimplemented!("cannot resolve field"), + } + } + }); + ResolvedValue::U32Array(result) + } + } + } + + fn bool_from_variable(&mut self, cs: &mut CS, scope: String, variable: Variable) -> Boolean { // Evaluate variable name in current function scope let variable_name = new_scope_from_variable(scope, &variable); @@ -153,44 +456,15 @@ impl ResolvedProgram { .unwrap(); println!(" argument passed to command line a = {:?}\n", argument); // let a = true; - Boolean::alloc(cs.ns(|| variable.0), || Ok(argument)).unwrap() + Boolean::alloc(cs.ns(|| variable.name), || Ok(argument)).unwrap() } } - fn u32_from_variable>( + fn get_bool_value( &mut self, cs: &mut CS, scope: String, - variable: Variable, - ) -> UInt32 { - // Evaluate variable name in current function scope - let variable_name = new_scope_from_variable(scope, &variable); - - if self.contains_name(&variable_name) { - // TODO: return synthesis error: "assignment missing" here - match self.get(&variable_name).unwrap() { - ResolvedValue::FieldElement(field) => field.clone(), - value => panic!("expected a field, got {}", value), - } - } else { - let argument = std::env::args() - .nth(1) - .unwrap_or("1".into()) - .parse::() - .unwrap(); - - println!(" argument passed to command line a = {:?}\n", argument); - - // let a = 1; - UInt32::alloc(cs.ns(|| variable.0), Some(argument)).unwrap() - } - } - - fn get_bool_value>( - &mut self, - cs: &mut CS, - scope: String, - expression: BooleanExpression, + expression: BooleanExpression, ) -> Boolean { match expression { BooleanExpression::Variable(variable) => self.bool_from_variable(cs, scope, variable), @@ -202,39 +476,23 @@ impl ResolvedProgram { } } - fn get_u32_value>( + fn enforce_not( &mut self, cs: &mut CS, scope: String, - expression: FieldExpression, - ) -> UInt32 { - match expression { - FieldExpression::Variable(variable) => self.u32_from_variable(cs, scope, variable), - FieldExpression::Number(number) => UInt32::constant(number), - field => match self.enforce_field_expression(cs, scope, field) { - ResolvedValue::FieldElement(value) => value, - _ => unimplemented!("field expression did not resolve to field"), - }, - } - } - - fn enforce_not>( - &mut self, - cs: &mut CS, - scope: String, - expression: BooleanExpression, + expression: BooleanExpression, ) -> Boolean { let expression = self.get_bool_value(cs, scope, expression); expression.not() } - fn enforce_or>( + fn enforce_or( &mut self, cs: &mut CS, scope: String, - left: BooleanExpression, - right: BooleanExpression, + left: BooleanExpression, + right: BooleanExpression, ) -> Boolean { let left = self.get_bool_value(cs, scope.clone(), left); let right = self.get_bool_value(cs, scope.clone(), right); @@ -242,12 +500,12 @@ impl ResolvedProgram { Boolean::or(cs, &left, &right).unwrap() } - fn enforce_and>( + fn enforce_and( &mut self, cs: &mut CS, scope: String, - left: BooleanExpression, - right: BooleanExpression, + left: BooleanExpression, + right: BooleanExpression, ) -> Boolean { let left = self.get_bool_value(cs, scope.clone(), left); let right = self.get_bool_value(cs, scope.clone(), right); @@ -255,12 +513,12 @@ impl ResolvedProgram { Boolean::and(cs, &left, &right).unwrap() } - fn enforce_bool_equality>( + fn enforce_bool_equality( &mut self, cs: &mut CS, scope: String, - left: BooleanExpression, - right: BooleanExpression, + left: BooleanExpression, + right: BooleanExpression, ) -> Boolean { let left = self.get_bool_value(cs, scope.clone(), left); let right = self.get_bool_value(cs, scope.clone(), right); @@ -271,32 +529,12 @@ impl ResolvedProgram { Boolean::Constant(true) } - fn enforce_field_equality>( + fn enforce_boolean_expression( &mut self, cs: &mut CS, scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> Boolean { - let left = self.get_u32_value(cs, scope.clone(), left); - let right = self.get_u32_value(cs, scope.clone(), right); - - left.conditional_enforce_equal( - cs.ns(|| format!("enforce field equal")), - &right, - &Boolean::Constant(true), - ) - .unwrap(); - - Boolean::Constant(true) - } - - fn enforce_boolean_expression>( - &mut self, - cs: &mut CS, - scope: String, - expression: BooleanExpression, - ) -> ResolvedValue { + expression: BooleanExpression, + ) -> ResolvedValue { match expression { BooleanExpression::Variable(variable) => { ResolvedValue::Boolean(self.bool_from_variable(cs, scope, variable)) @@ -315,7 +553,7 @@ impl ResolvedProgram { ResolvedValue::Boolean(self.enforce_bool_equality(cs, scope, *left, *right)) } BooleanExpression::FieldEq(left, right) => { - ResolvedValue::Boolean(self.enforce_field_equality(cs, scope, *left, *right)) + ResolvedValue::Boolean(self.enforce_integer_equality(cs, scope, *left, *right)) } BooleanExpression::IfElse(first, second, third) => { let resolved_first = @@ -332,7 +570,7 @@ impl ResolvedProgram { BooleanExpression::Array(array) => { let mut result = vec![]; array.into_iter().for_each(|element| match *element { - BooleanSpreadOrExpression::Spread(spread) => match spread.0 { + BooleanSpreadOrExpression::Spread(spread) => match spread { BooleanExpression::Variable(variable) => { let array_name = new_scope_from_variable(scope.clone(), &variable); match self.get(&array_name) { @@ -347,7 +585,7 @@ impl ResolvedProgram { }, None => unimplemented!( "cannot copy elements from array that does not exist {}", - variable.0 + variable.name ), } } @@ -355,7 +593,7 @@ impl ResolvedProgram { unimplemented!("spreads only implemented for arrays, got {}", value) } }, - BooleanSpreadOrExpression::BooleanExpression(expression) => { + BooleanSpreadOrExpression::Expression(expression) => { match self.enforce_boolean_expression(cs, scope.clone(), expression) { ResolvedValue::Boolean(value) => result.push(value), value => { @@ -370,186 +608,13 @@ impl ResolvedProgram { } } - fn enforce_add>( + fn enforce_struct_expression( &mut self, cs: &mut CS, scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> UInt32 { - let left = self.get_u32_value(cs, scope.clone(), left); - let right = self.get_u32_value(cs, scope.clone(), right); - - UInt32::addmany( - cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), - &[left, right], - ) - .unwrap() - } - - fn enforce_sub>( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> UInt32 { - let left = self.get_u32_value(cs, scope.clone(), left); - let right = self.get_u32_value(cs, scope.clone(), right); - - left.sub( - cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap() - } - - fn enforce_mul>( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> UInt32 { - let left = self.get_u32_value(cs, scope.clone(), left); - let right = self.get_u32_value(cs, scope.clone(), right); - - let res = left - .mul( - cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap(); - - res - } - - fn enforce_div>( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> UInt32 { - let left = self.get_u32_value(cs, scope.clone(), left); - let right = self.get_u32_value(cs, scope.clone(), right); - - left.div( - cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap() - } - - fn enforce_pow>( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> UInt32 { - let left = self.get_u32_value(cs, scope.clone(), left); - let right = self.get_u32_value(cs, scope.clone(), right); - - left.pow( - cs.ns(|| { - format!( - "enforce {} ** {}", - left.value.unwrap(), - right.value.unwrap() - ) - }), - &right, - ) - .unwrap() - } - - fn enforce_field_expression>( - &mut self, - cs: &mut CS, - scope: String, - expression: FieldExpression, - ) -> ResolvedValue { - match expression { - FieldExpression::Variable(variable) => { - ResolvedValue::FieldElement(self.u32_from_variable(cs, scope, variable)) - } - FieldExpression::Number(number) => { - ResolvedValue::FieldElement(UInt32::constant(number)) - } - FieldExpression::Add(left, right) => { - ResolvedValue::FieldElement(self.enforce_add(cs, scope, *left, *right)) - } - FieldExpression::Sub(left, right) => { - ResolvedValue::FieldElement(self.enforce_sub(cs, scope, *left, *right)) - } - FieldExpression::Mul(left, right) => { - ResolvedValue::FieldElement(self.enforce_mul(cs, scope, *left, *right)) - } - FieldExpression::Div(left, right) => { - ResolvedValue::FieldElement(self.enforce_div(cs, scope, *left, *right)) - } - FieldExpression::Pow(left, right) => { - ResolvedValue::FieldElement(self.enforce_pow(cs, scope, *left, *right)) - } - FieldExpression::IfElse(first, second, third) => { - let resolved_first = - match self.enforce_boolean_expression(cs, scope.clone(), *first) { - ResolvedValue::Boolean(resolved) => resolved, - _ => unimplemented!("if else conditional must resolve to boolean"), - }; - - if resolved_first.eq(&Boolean::Constant(true)) { - self.enforce_field_expression(cs, scope, *second) - } else { - self.enforce_field_expression(cs, scope, *third) - } - } - FieldExpression::Array(array) => { - let mut result = vec![]; - array.into_iter().for_each(|element| match *element { - FieldSpreadOrExpression::Spread(spread) => match spread.0 { - FieldExpression::Variable(variable) => { - let array_name = new_scope_from_variable(scope.clone(), &variable); - match self.get(&array_name) { - Some(value) => match value { - ResolvedValue::FieldElementArray(array) => { - result.extend(array.clone()) - } - value => unimplemented!( - "spreads only implemented for arrays, got {}", - value - ), - }, - None => unimplemented!( - "cannot copy elements from array that does not exist {}", - variable.0 - ), - } - } - value => { - unimplemented!("spreads only implemented for arrays, got {}", value) - } - }, - FieldSpreadOrExpression::FieldExpression(expression) => { - match self.enforce_field_expression(cs, scope.clone(), expression) { - ResolvedValue::FieldElement(value) => result.push(value), - _ => unimplemented!("cannot resolve field"), - } - } - }); - ResolvedValue::FieldElementArray(result) - } - } - } - - fn enforce_struct_expression>( - &mut self, - cs: &mut CS, - scope: String, - variable: Variable, - members: Vec, - ) -> ResolvedValue { + variable: Variable, + members: Vec>, + ) -> ResolvedValue { if let Some(resolved_value) = self.get_mut_variable(&variable) { match resolved_value { ResolvedValue::StructDefinition(struct_definition) => { @@ -577,29 +642,24 @@ impl ResolvedProgram { } } - fn enforce_index>( - &mut self, - cs: &mut CS, - scope: String, - index: FieldExpression, - ) -> usize { - match self.enforce_field_expression(cs, scope.clone(), index) { - ResolvedValue::FieldElement(number) => number.value.unwrap() as usize, + fn enforce_index(&mut self, cs: &mut CS, scope: String, index: IntegerExpression) -> usize { + match self.enforce_integer_expression(cs, scope.clone(), index) { + ResolvedValue::U32(number) => number.value.unwrap() as usize, value => unimplemented!("From index must resolve to a uint32, got {}", value), } } - fn enforce_array_access_expression>( + fn enforce_array_access_expression( &mut self, cs: &mut CS, scope: String, - array: Box, - index: FieldRangeOrExpression, - ) -> ResolvedValue { + array: Box>, + index: IntegerRangeOrExpression, + ) -> ResolvedValue { match self.enforce_expression(cs, scope.clone(), *array) { - ResolvedValue::FieldElementArray(field_array) => { + ResolvedValue::U32Array(field_array) => { match index { - FieldRangeOrExpression::Range(from, to) => { + IntegerRangeOrExpression::Range(from, to) => { let from_resolved = match from { Some(from_index) => self.enforce_index(cs, scope.clone(), from_index), None => 0usize, // Array slice starts at index 0 @@ -608,19 +668,17 @@ impl ResolvedProgram { Some(to_index) => self.enforce_index(cs, scope.clone(), to_index), None => field_array.len(), // Array slice ends at array length }; - ResolvedValue::FieldElementArray( - field_array[from_resolved..to_resolved].to_owned(), - ) + ResolvedValue::U32Array(field_array[from_resolved..to_resolved].to_owned()) } - FieldRangeOrExpression::FieldExpression(index) => { + IntegerRangeOrExpression::Expression(index) => { let index_resolved = self.enforce_index(cs, scope.clone(), index); - ResolvedValue::FieldElement(field_array[index_resolved].to_owned()) + ResolvedValue::U32(field_array[index_resolved].to_owned()) } } } ResolvedValue::BooleanArray(bool_array) => { match index { - FieldRangeOrExpression::Range(from, to) => { + IntegerRangeOrExpression::Range(from, to) => { let from_resolved = match from { Some(from_index) => self.enforce_index(cs, scope.clone(), from_index), None => 0usize, // Array slice starts at index 0 @@ -633,7 +691,7 @@ impl ResolvedProgram { bool_array[from_resolved..to_resolved].to_owned(), ) } - FieldRangeOrExpression::FieldExpression(index) => { + IntegerRangeOrExpression::Expression(index) => { let index_resolved = self.enforce_index(cs, scope.clone(), index); ResolvedValue::Boolean(bool_array[index_resolved].to_owned()) } @@ -643,13 +701,13 @@ impl ResolvedProgram { } } - fn enforce_struct_access_expression>( + fn enforce_struct_access_expression( &mut self, cs: &mut CS, scope: String, - struct_variable: Box, - struct_member: Variable, - ) -> ResolvedValue { + struct_variable: Box>, + struct_member: Variable, + ) -> ResolvedValue { match self.enforce_expression(cs, scope.clone(), *struct_variable) { ResolvedValue::StructExpression(_name, members) => { let matched_member = members @@ -657,38 +715,38 @@ impl ResolvedProgram { .find(|member| member.variable == struct_member); match matched_member { Some(member) => self.enforce_expression(cs, scope.clone(), member.expression), - None => unimplemented!("Cannot access struct member {}", struct_member.0), + None => unimplemented!("Cannot access struct member {}", struct_member.name), } } value => unimplemented!("Cannot access element of untyped struct {}", value), } } - fn enforce_function_access_expression>( + fn enforce_function_access_expression( &mut self, cs: &mut CS, scope: String, - function: Box, - arguments: Vec, - ) -> ResolvedValue { + function: Box>, + arguments: Vec>, + ) -> ResolvedValue { match self.enforce_expression(cs, scope, *function) { ResolvedValue::Function(function) => self.enforce_function(cs, function, arguments), value => unimplemented!("Cannot call unknown function {}", value), } } - fn enforce_expression>( + fn enforce_expression( &mut self, cs: &mut CS, scope: String, - expression: Expression, - ) -> ResolvedValue { + expression: Expression, + ) -> ResolvedValue { match expression { Expression::Boolean(boolean_expression) => { self.enforce_boolean_expression(cs, scope, boolean_expression) } - Expression::FieldElement(field_expression) => { - self.enforce_field_expression(cs, scope, field_expression) + Expression::Integer(field_expression) => { + self.enforce_integer_expression(cs, scope, field_expression) } Expression::Variable(unresolved_variable) => { let variable_name = new_scope_from_variable(scope, &unresolved_variable); @@ -714,11 +772,7 @@ impl ResolvedProgram { unresolved_variable, )) } else { - ResolvedValue::FieldElement(self.u32_from_variable( - cs, - variable_name, - unresolved_variable, - )) + self.integer_from_variable(cs, variable_name, unresolved_variable) } } } @@ -734,10 +788,11 @@ impl ResolvedProgram { Expression::FunctionCall(function, arguments) => { self.enforce_function_access_expression(cs, scope, function, arguments) } + expression => unimplemented!("expression not impl {}", expression), } } - fn resolve_assignee(&mut self, scope: String, assignee: Assignee) -> String { + fn resolve_assignee(&mut self, scope: String, assignee: Assignee) -> String { match assignee { Assignee::Variable(name) => new_scope_from_variable(scope, &name), Assignee::Array(array, _index) => self.resolve_assignee(scope, *array), @@ -747,12 +802,12 @@ impl ResolvedProgram { } } - fn enforce_definition_statement>( + fn enforce_definition_statement( &mut self, cs: &mut CS, scope: String, - assignee: Assignee, - expression: Expression, + assignee: Assignee, + expression: Expression, ) { // Create or modify the lhs variable in the current function scope match assignee { @@ -774,16 +829,13 @@ impl ResolvedProgram { // Resolve index so we know if we are assigning to a single value or a range of values match index_expression { - FieldRangeOrExpression::FieldExpression(index) => { + IntegerRangeOrExpression::Expression(index) => { let index = self.enforce_index(cs, scope.clone(), index); // Modify the single value of the array in place match self.get_mut(&expected_array_name) { Some(value) => match (value, result) { - ( - ResolvedValue::FieldElementArray(old), - ResolvedValue::FieldElement(new), - ) => { + (ResolvedValue::U32Array(old), ResolvedValue::U32(new)) => { old[index] = new.to_owned(); } (ResolvedValue::BooleanArray(old), ResolvedValue::Boolean(new)) => { @@ -799,7 +851,7 @@ impl ResolvedProgram { ), } } - FieldRangeOrExpression::Range(from, to) => { + IntegerRangeOrExpression::Range(from, to) => { let from_index = match from { Some(expression) => self.enforce_index(cs, scope.clone(), expression), None => 0usize, @@ -814,10 +866,7 @@ impl ResolvedProgram { // Modify the range of values of the array in place match self.get_mut(&expected_array_name) { Some(value) => match (value, result) { - ( - ResolvedValue::FieldElementArray(old), - ResolvedValue::FieldElementArray(new), - ) => { + (ResolvedValue::U32Array(old), ResolvedValue::U32Array(new)) => { let to_index = to_index_option.unwrap_or(old.len()); old.splice(from_index..to_index, new.iter().cloned()); } @@ -873,26 +922,21 @@ impl ResolvedProgram { }; } - fn enforce_return_statement>( + fn enforce_return_statement( &mut self, cs: &mut CS, scope: String, - statements: Vec, - ) -> ResolvedValue { + statements: Vec>, + ) -> ResolvedValue { ResolvedValue::Return( statements .into_iter() .map(|expression| self.enforce_expression(cs, scope.clone(), expression)) - .collect::>(), + .collect::>>(), ) } - fn enforce_statement>( - &mut self, - cs: &mut CS, - scope: String, - statement: Statement, - ) { + fn enforce_statement(&mut self, cs: &mut CS, scope: String, statement: Statement) { match statement { Statement::Definition(variable, expression) => { self.enforce_definition_statement(cs, scope, variable, expression); @@ -907,14 +951,14 @@ impl ResolvedProgram { }; } - fn enforce_for_statement>( + fn enforce_for_statement( &mut self, cs: &mut CS, scope: String, - index: Variable, - start: FieldExpression, - stop: FieldExpression, - statements: Vec, + index: Variable, + start: IntegerExpression, + stop: IntegerExpression, + statements: Vec>, ) { let start_index = self.enforce_index(cs, scope.clone(), start); let stop_index = self.enforce_index(cs, scope.clone(), stop); @@ -923,10 +967,7 @@ impl ResolvedProgram { // Store index in current function scope. // For loop scope is not implemented. let index_name = new_scope_from_variable(scope.clone(), &index); - self.store( - index_name, - ResolvedValue::FieldElement(UInt32::constant(i as u32)), - ); + self.store(index_name, ResolvedValue::U32(UInt32::constant(i as u32))); // Evaluate statements statements @@ -936,12 +977,12 @@ impl ResolvedProgram { } } - fn enforce_function>( + fn enforce_function( &mut self, cs: &mut CS, - function: Function, - arguments: Vec, - ) -> ResolvedValue { + function: Function, + arguments: Vec>, + ) -> ResolvedValue { // Make sure we are given the correct number of arguments if function.parameters.len() != arguments.len() { unimplemented!( @@ -962,15 +1003,15 @@ impl ResolvedProgram { // Check that argument is correct type match parameter.ty.clone() { - Type::FieldElement => { + Type::U32 => { match self.enforce_expression(cs, function.get_name(), argument) { - ResolvedValue::FieldElement(field) => { + ResolvedValue::U32(field) => { // Store argument as variable with {function_name}_{parameter name} let variable_name = new_scope_from_variable( function.get_name(), ¶meter.variable, ); - self.store(variable_name, ResolvedValue::FieldElement(field)); + self.store(variable_name, ResolvedValue::U32(field)); } argument => unimplemented!("expected field argument got {}", argument), } @@ -1030,11 +1071,7 @@ impl ResolvedProgram { return_values } - fn enforce_import>( - &mut self, - cs: &mut CS, - import: Import, - ) { + fn enforce_import(&mut self, cs: &mut CS, import: Import) { // println!("import: {}", import); // Resolve program file path @@ -1056,11 +1093,7 @@ impl ResolvedProgram { // self.store(name, value) } - pub fn resolve_definitions>( - &mut self, - cs: &mut CS, - program: Program, - ) { + pub fn resolve_definitions(&mut self, cs: &mut CS, program: Program) { program .imports .into_iter() @@ -1079,10 +1112,7 @@ impl ResolvedProgram { }); } - pub fn generate_constraints>( - cs: &mut CS, - program: Program, - ) { + pub fn generate_constraints(cs: &mut CS, program: Program) { let mut resolved_program = ResolvedProgram::new(); resolved_program.resolve_definitions(cs, program); diff --git a/src/aleo_program/types.rs b/src/aleo_program/types.rs index ebf0c09f88..bbd70d0ec1 100644 --- a/src/aleo_program/types.rs +++ b/src/aleo_program/types.rs @@ -1,146 +1,192 @@ -//! A zokrates_program consists of nodes that keep track of position and wrap zokrates_program types. +//! A typed program in aleo consists of import, struct, and function definitions. +//! Each defined type consists of typed statements and expressions. //! //! @file types.rs //! @author Collin Chin //! @date 2020 use crate::aleo_program::Import; + +use snarkos_models::curves::{Field, PrimeField}; use std::collections::HashMap; +use std::marker::PhantomData; /// A variable in a constraint system. #[derive(Clone, PartialEq, Eq, Hash)] -pub struct Variable(pub String); - -/// Spread operator -#[derive(Debug, Clone)] -pub struct FieldSpread(pub FieldExpression); - -/// Spread or field expression enum -#[derive(Debug, Clone)] -pub enum FieldSpreadOrExpression { - Spread(FieldSpread), - FieldExpression(FieldExpression), +pub struct Variable { + pub name: String, + pub(crate) _field: PhantomData, } -/// Range or field expression enum +/// An integer type enum wrapping the integer value #[derive(Debug, Clone)] -pub enum FieldRangeOrExpression { - Range(Option, Option), - FieldExpression(FieldExpression), +pub enum Integer { + // U8(u8), + U32(u32), + // U64(u64), } +/// Spread operator or u32 expression enum +#[derive(Debug, Clone)] +pub enum IntegerSpreadOrExpression { + Spread(IntegerExpression), + Expression(IntegerExpression), +} + +/// Range or integer expression enum +#[derive(Debug, Clone)] +pub enum IntegerRangeOrExpression { + Range(Option>, Option>), + Expression(IntegerExpression), +} + +/// Expression that evaluates to a u32 value +#[derive(Debug, Clone)] +pub enum IntegerExpression { + Variable(Variable), + Number(Integer), + // Operators + Add(Box>, Box>), + Sub(Box>, Box>), + Mul(Box>, Box>), + Div(Box>, Box>), + Pow(Box>, Box>), + // Conditionals + IfElse( + Box>, + Box>, + Box>, + ), + // Arrays + Array(Vec>>), +} + +// /// Spread or field expression enum +// #[derive(Debug, Clone)] +// pub enum FieldSpreadOrExpression { +// Spread(FieldExpression), +// Expression(FieldExpression), +// } + /// Expression that evaluates to a field value #[derive(Debug, Clone)] -pub enum FieldExpression { - Variable(Variable), - Number(u32), - // Operators - Add(Box, Box), - Sub(Box, Box), - Mul(Box, Box), - Div(Box, Box), - Pow(Box, Box), - // Conditionals - IfElse( - Box, - Box, - Box, - ), - // Arrays - Array(Vec>), +pub enum FieldExpression { + Variable(Variable), + Number(F), + // // Operators + // Add(Box>, Box>), + // Sub(Box>, Box>), + // Mul(Box>, Box>), + // Div(Box>, Box>), + // Pow(Box>, Box>), + // // Conditionals + // IfElse( + // Box, + // Box>, + // Box>, + // ), + // // Arrays + // Array(Vec>>), } -/// Spread operator -#[derive(Debug, Clone)] -pub struct BooleanSpread(pub BooleanExpression); - /// Spread or field expression enum #[derive(Debug, Clone)] -pub enum BooleanSpreadOrExpression { - Spread(BooleanSpread), - BooleanExpression(BooleanExpression), +pub enum BooleanSpreadOrExpression { + Spread(BooleanExpression), + Expression(BooleanExpression), } /// Expression that evaluates to a boolean value #[derive(Debug, Clone)] -pub enum BooleanExpression { - Variable(Variable), +pub enum BooleanExpression { + Variable(Variable), Value(bool), // Boolean operators - Not(Box), - Or(Box, Box), - And(Box, Box), - BoolEq(Box, Box), + Not(Box>), + Or(Box>, Box>), + And(Box>, Box>), + BoolEq(Box>, Box>), // Field operators - FieldEq(Box, Box), - Geq(Box, Box), - Gt(Box, Box), - Leq(Box, Box), - Lt(Box, Box), + FieldEq(Box>, Box>), + Geq(Box>, Box>), + Gt(Box>, Box>), + Leq(Box>, Box>), + Lt(Box>, Box>), // Conditionals IfElse( - Box, - Box, - Box, + Box>, + Box>, + Box>, ), // Arrays - Array(Vec>), + Array(Vec>>), } /// Expression that evaluates to a value #[derive(Debug, Clone)] -pub enum Expression { - Boolean(BooleanExpression), - FieldElement(FieldExpression), - Variable(Variable), - Struct(Variable, Vec), - ArrayAccess(Box, FieldRangeOrExpression), - StructMemberAccess(Box, Variable), // (struct name, struct member name) - FunctionCall(Box, Vec), +pub enum Expression { + Integer(IntegerExpression), + FieldElement(FieldExpression), + Boolean(BooleanExpression), + Variable(Variable), + Struct(Variable, Vec>), + ArrayAccess(Box>, IntegerRangeOrExpression), + StructMemberAccess(Box>, Variable), // (struct name, struct member name) + FunctionCall(Box>, Vec>), } +/// Definition assignee: v, arr[0..2], Point p.x #[derive(Debug, Clone)] -pub enum Assignee { - Variable(Variable), - Array(Box, FieldRangeOrExpression), - StructMember(Box, Variable), +pub enum Assignee { + Variable(Variable), + Array(Box>, IntegerRangeOrExpression), + StructMember(Box>, Variable), } /// Program statement that defines some action (or expression) to be carried out. #[derive(Clone)] -pub enum Statement { +pub enum Statement { // Declaration(Variable), - Definition(Assignee, Expression), - For(Variable, FieldExpression, FieldExpression, Vec), - Return(Vec), + Definition(Assignee, Expression), + For( + Variable, + IntegerExpression, + IntegerExpression, + Vec>, + ), + Return(Vec>), } +/// Explicit type used for defining struct members and function parameters #[derive(Clone, Debug)] -pub enum Type { - Boolean, +pub enum Type { + U32, FieldElement, - Array(Box, usize), - Struct(Variable), + Boolean, + Array(Box>, usize), + Struct(Variable), } #[derive(Clone, Debug)] -pub struct StructMember { - pub variable: Variable, - pub expression: Expression, +pub struct StructMember { + pub variable: Variable, + pub expression: Expression, } #[derive(Clone)] -pub struct StructField { - pub variable: Variable, - pub ty: Type, +pub struct StructField { + pub variable: Variable, + pub ty: Type, } #[derive(Clone)] -pub struct Struct { - pub variable: Variable, - pub fields: Vec, +pub struct Struct { + pub variable: Variable, + pub fields: Vec>, } +/// Function parameters + #[derive(Clone, Debug)] pub enum Visibility { Public, @@ -148,10 +194,10 @@ pub enum Visibility { } #[derive(Clone)] -pub struct Parameter { +pub struct Parameter { pub visibility: Option, - pub ty: Type, - pub variable: Variable, + pub ty: Type, + pub variable: Variable, } /// The given name for a defined function in the program. @@ -159,14 +205,14 @@ pub struct Parameter { pub struct FunctionName(pub String); #[derive(Clone)] -pub struct Function { +pub struct Function { pub function_name: FunctionName, - pub parameters: Vec, - pub returns: Vec, - pub statements: Vec, + pub parameters: Vec>, + pub returns: Vec>, + pub statements: Vec>, } -impl Function { +impl Function { pub fn get_name(&self) -> String { self.function_name.0.clone() } @@ -174,28 +220,19 @@ impl Function { /// A simple program with statement expressions, program arguments and program returns. #[derive(Debug, Clone)] -pub struct Program<'ast> { - pub name: Variable, +pub struct Program<'ast, F: Field + PrimeField> { + pub name: Variable, pub imports: Vec>, - pub structs: HashMap, - pub functions: HashMap, + pub structs: HashMap, Struct>, + pub functions: HashMap>, } -impl<'ast> Program<'ast> { +impl<'ast, F: Field + PrimeField> Program<'ast, F> { pub fn name(mut self, name: String) -> Self { - self.name = Variable(name); + self.name = Variable { + name, + _field: PhantomData::, + }; self } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_variable() { - let variable = Variable("1".into()); - - println!("{:#?}", variable); - } -} diff --git a/src/aleo_program/types_display.rs b/src/aleo_program/types_display.rs index 398d9fcf12..1f3c972e45 100644 --- a/src/aleo_program/types_display.rs +++ b/src/aleo_program/types_display.rs @@ -5,53 +5,56 @@ //! @date 2020 use crate::aleo_program::{ - Assignee, BooleanExpression, BooleanSpread, BooleanSpreadOrExpression, Expression, - FieldExpression, FieldRangeOrExpression, FieldSpread, FieldSpreadOrExpression, Function, - FunctionName, Parameter, Statement, Struct, StructField, Type, Variable, + Assignee, BooleanExpression, BooleanSpreadOrExpression, Expression, Function, FunctionName, + Integer, IntegerExpression, IntegerRangeOrExpression, IntegerSpreadOrExpression, Parameter, + Statement, Struct, StructField, Type, Variable, }; +use snarkos_models::curves::{Field, PrimeField}; use std::fmt; -impl fmt::Display for Variable { +impl fmt::Display for Variable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.name) } } -impl fmt::Debug for Variable { +impl fmt::Debug for Variable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.name) } } -impl fmt::Display for FieldSpread { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "...{}", self.0) - } -} - -impl fmt::Display for FieldSpreadOrExpression { +impl fmt::Display for Integer { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldSpreadOrExpression::Spread(ref spread) => write!(f, "{}", spread), - FieldSpreadOrExpression::FieldExpression(ref expression) => write!(f, "{}", expression), + Integer::U32(ref num) => write!(f, "{}", num), } } } -impl<'ast> fmt::Display for FieldExpression { +impl fmt::Display for IntegerSpreadOrExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldExpression::Variable(ref variable) => write!(f, "{}", variable), - FieldExpression::Number(ref number) => write!(f, "{}", number), - FieldExpression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), - FieldExpression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), - FieldExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), - FieldExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), - FieldExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), - FieldExpression::IfElse(ref a, ref b, ref c) => { + IntegerSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), + IntegerSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), + } + } +} + +impl<'ast, F: Field + PrimeField> fmt::Display for IntegerExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + IntegerExpression::Variable(ref variable) => write!(f, "{}", variable), + IntegerExpression::Number(ref number) => write!(f, "{}", number), + IntegerExpression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), + IntegerExpression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), + IntegerExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), + IntegerExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), + IntegerExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), + IntegerExpression::IfElse(ref a, ref b, ref c) => { write!(f, "if {} then {} else {} fi", a, b, c) } - FieldExpression::Array(ref array) => { + IntegerExpression::Array(ref array) => { write!(f, "[")?; for (i, e) in array.iter().enumerate() { write!(f, "{}", e)?; @@ -65,24 +68,16 @@ impl<'ast> fmt::Display for FieldExpression { } } -impl fmt::Display for BooleanSpread { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "...{}", self.0) - } -} - -impl fmt::Display for BooleanSpreadOrExpression { +impl fmt::Display for BooleanSpreadOrExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - BooleanSpreadOrExpression::Spread(ref spread) => write!(f, "{}", spread), - BooleanSpreadOrExpression::BooleanExpression(ref expression) => { - write!(f, "{}", expression) - } + BooleanSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), + BooleanSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), } } } -impl<'ast> fmt::Display for BooleanExpression { +impl<'ast, F: Field + PrimeField> fmt::Display for BooleanExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { BooleanExpression::Variable(ref variable) => write!(f, "{}", variable), @@ -114,10 +109,10 @@ impl<'ast> fmt::Display for BooleanExpression { } } -impl<'ast> fmt::Display for FieldRangeOrExpression { +impl<'ast, F: Field + PrimeField> fmt::Display for IntegerRangeOrExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldRangeOrExpression::Range(ref from, ref to) => write!( + IntegerRangeOrExpression::Range(ref from, ref to) => write!( f, "{}..{}", from.as_ref() @@ -125,16 +120,19 @@ impl<'ast> fmt::Display for FieldRangeOrExpression { .unwrap_or("".to_string()), to.as_ref().map(|e| e.to_string()).unwrap_or("".to_string()) ), - FieldRangeOrExpression::FieldExpression(ref e) => write!(f, "{}", e), + IntegerRangeOrExpression::Expression(ref e) => write!(f, "{}", e), } } } -impl<'ast> fmt::Display for Expression { +impl<'ast, F: Field + PrimeField> fmt::Display for Expression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + Expression::Integer(ref integer_expression) => write!(f, "{}", integer_expression), + Expression::FieldElement(ref _field_expression) => { + unimplemented!("field elem not impl ") + } Expression::Boolean(ref boolean_expression) => write!(f, "{}", boolean_expression), - Expression::FieldElement(ref field_expression) => write!(f, "{}", field_expression), Expression::Variable(ref variable) => write!(f, "{}", variable), Expression::Struct(ref var, ref members) => { write!(f, "{} {{", var)?; @@ -164,7 +162,7 @@ impl<'ast> fmt::Display for Expression { } } -impl fmt::Display for Assignee { +impl fmt::Display for Assignee { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Assignee::Variable(ref variable) => write!(f, "{}", variable), @@ -176,7 +174,7 @@ impl fmt::Display for Assignee { } } -impl fmt::Display for Statement { +impl fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Definition(ref variable, ref statement) => { @@ -199,7 +197,7 @@ impl fmt::Display for Statement { } } -impl fmt::Debug for Statement { +impl fmt::Debug for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Definition(ref variable, ref statement) => { @@ -222,24 +220,25 @@ impl fmt::Debug for Statement { } } -impl fmt::Display for Type { +impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + Type::FieldElement => unimplemented!("field type unimpl"), + Type::U32 => write!(f, "field"), Type::Boolean => write!(f, "bool"), - Type::FieldElement => write!(f, "field"), Type::Struct(ref variable) => write!(f, "{}", variable), Type::Array(ref array, ref count) => write!(f, "[{}; {}]", array, count), } } } -impl fmt::Display for StructField { +impl fmt::Display for StructField { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} : {}", self.ty, self.variable) } } -impl fmt::Debug for Struct { +impl fmt::Debug for Struct { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "struct {} {{ \n", self.variable)?; for field in self.fields.iter() { @@ -249,7 +248,7 @@ impl fmt::Debug for Struct { } } -impl fmt::Display for Parameter { +impl fmt::Display for Parameter { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // let visibility = if self.private { "private " } else { "" }; write!( @@ -262,7 +261,7 @@ impl fmt::Display for Parameter { } } -impl fmt::Debug for Parameter { +impl fmt::Debug for Parameter { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Parameter(variable: {:?})", self.ty) } @@ -274,7 +273,7 @@ impl fmt::Debug for FunctionName { } } -impl fmt::Display for Function { +impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -293,7 +292,7 @@ impl fmt::Display for Function { } } -impl fmt::Debug for Function { +impl fmt::Debug for Function { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, diff --git a/src/aleo_program/types_from.rs b/src/aleo_program/types_from.rs index 39120fac42..10b81387d0 100644 --- a/src/aleo_program/types_from.rs +++ b/src/aleo_program/types_from.rs @@ -1,78 +1,170 @@ -//! Logic to convert from an abstract syntax tree (ast) representation to a typed zokrates_program. +//! Logic to convert from an abstract syntax tree (ast) representation to a typed aleo program. //! -//! @file zokrates_program.rs +//! @file types_from.rs //! @author Collin Chin //! @date 2020 use crate::aleo_program::{types, Import, PathString}; use crate::ast; +use snarkos_models::curves::{Field, PrimeField}; use std::collections::HashMap; +use std::marker::PhantomData; use std::path::Path; -impl<'ast> From> for types::FieldExpression { - fn from(field: ast::Field<'ast>) -> Self { - let number = field.value.parse::().expect("unable to unwrap field"); - types::FieldExpression::Number(number) +/// pest ast -> types::Variable + +impl<'ast, F: Field + PrimeField> From> for types::Variable { + fn from(variable: ast::Variable<'ast>) -> Self { + types::Variable { + name: variable.value, + _field: PhantomData::, + } } } -impl<'ast> From> for types::BooleanExpression { - fn from(boolean: ast::Boolean<'ast>) -> Self { - let boolean = boolean - .value - .parse::() - .expect("unable to unwrap boolean"); - types::BooleanExpression::Value(boolean) +/// pest ast - types::Integer + +impl<'ast, F: Field + PrimeField> From> for types::IntegerExpression { + fn from(variable: ast::Variable<'ast>) -> Self { + types::IntegerExpression::Variable(types::Variable::from(variable)) } } -impl<'ast> From> for types::Expression { - fn from(value: ast::Value<'ast>) -> Self { - match value { - ast::Value::Boolean(value) => { - types::Expression::Boolean(types::BooleanExpression::from(value)) +impl<'ast, F: Field + PrimeField> From> for types::IntegerExpression { + fn from(field: ast::U32<'ast>) -> Self { + types::IntegerExpression::Number(types::Integer::U32( + field + .number + .value + .parse::() + .expect("unable to unwrap u32"), + )) + } +} + +impl<'ast, F: Field + PrimeField> From> for types::IntegerExpression { + fn from(expression: ast::Expression<'ast>) -> Self { + match types::Expression::from(expression) { + types::Expression::Integer(integer_expression) => integer_expression, + types::Expression::Variable(variable) => types::IntegerExpression::Variable(variable), + _ => unimplemented!("expected field in field expression"), + } + } +} + +impl<'ast, F: Field + PrimeField> From> + for types::IntegerSpreadOrExpression +{ + fn from(expression: ast::Expression<'ast>) -> Self { + match types::Expression::from(expression) { + types::Expression::Integer(expression) => { + types::IntegerSpreadOrExpression::Expression(expression) } - ast::Value::Field(value) => { - types::Expression::FieldElement(types::FieldExpression::from(value)) + _ => unimplemented!("cannot create field expression from boolean type"), + } + } +} + +impl<'ast, F: Field + PrimeField> From> + for types::IntegerSpreadOrExpression +{ + fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { + match s_or_e { + ast::SpreadOrExpression::Spread(spread) => { + types::IntegerSpreadOrExpression::from(spread.expression) + } + ast::SpreadOrExpression::Expression(expression) => { + types::IntegerSpreadOrExpression::from(expression) } } } } -impl<'ast> From> for types::Variable { +impl<'ast, F: Field + PrimeField> From> + for types::IntegerRangeOrExpression +{ + fn from(range_or_expression: ast::RangeOrExpression<'ast>) -> Self { + match range_or_expression { + ast::RangeOrExpression::Range(range) => { + let from = range + .from + .map(|from| match types::Expression::from(from.0) { + types::Expression::Integer(field) => field, + expression => { + unimplemented!("Range bounds should be numbers, found {}", expression) + } + }); + let to = range.to.map(|to| match types::Expression::from(to.0) { + types::Expression::Integer(field) => field, + expression => { + unimplemented!("Range bounds should be numbers, found {}", expression) + } + }); + + types::IntegerRangeOrExpression::Range(from, to) + } + ast::RangeOrExpression::Expression(expression) => { + match types::Expression::from(expression) { + types::Expression::Integer(field_expression) => { + types::IntegerRangeOrExpression::Expression(field_expression) + } + // types::Expression::ArrayAccess(expression, field), // recursive array access + expression => unimplemented!("expression must be field, found {}", expression), + } + } + } + } +} + +/// pest ast -> types::Boolean + +impl<'ast, F: Field + PrimeField> From> + for types::BooleanSpreadOrExpression +{ + fn from(expression: ast::Expression<'ast>) -> Self { + match types::Expression::from(expression) { + types::Expression::Boolean(expression) => { + types::BooleanSpreadOrExpression::Expression(expression) + } + _ => unimplemented!("cannot create boolean expression from field type"), + } + } +} + +impl<'ast, F: Field + PrimeField> From> + for types::BooleanSpreadOrExpression +{ + fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { + match s_or_e { + ast::SpreadOrExpression::Spread(spread) => { + types::BooleanSpreadOrExpression::from(spread.expression) + } + ast::SpreadOrExpression::Expression(expression) => { + types::BooleanSpreadOrExpression::from(expression) + } + } + } +} + +impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { fn from(variable: ast::Variable<'ast>) -> Self { - types::Variable(variable.value) + types::BooleanExpression::Variable(types::Variable::from(variable)) } } -impl<'ast> From> for types::FieldExpression { - fn from(variable: ast::Variable<'ast>) -> Self { - types::FieldExpression::Variable(types::Variable(variable.value)) +impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { + fn from(boolean: ast::Boolean<'ast>) -> Self { + types::BooleanExpression::Value( + boolean + .value + .parse::() + .expect("unable to unwrap boolean"), + ) } } -impl<'ast> From> for types::BooleanExpression { - fn from(variable: ast::Variable<'ast>) -> Self { - types::BooleanExpression::Variable(types::Variable(variable.value)) - } -} - -impl<'ast> From> for types::Expression { - fn from(variable: ast::Variable<'ast>) -> Self { - types::Expression::Variable(types::Variable(variable.value)) - } -} - -impl<'ast> From> for types::Expression { - fn from(expression: ast::NotExpression<'ast>) -> Self { - types::Expression::Boolean(types::BooleanExpression::Not(Box::new( - types::BooleanExpression::from(*expression.expression), - ))) - } -} - -impl<'ast> From> for types::BooleanExpression { +impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { fn from(expression: ast::Expression<'ast>) -> Self { match types::Expression::from(expression) { types::Expression::Boolean(boolean_expression) => boolean_expression, @@ -82,17 +174,37 @@ impl<'ast> From> for types::BooleanExpression { } } -impl<'ast> From> for types::FieldExpression { - fn from(expression: ast::Expression<'ast>) -> Self { - match types::Expression::from(expression) { - types::Expression::FieldElement(field_expression) => field_expression, - types::Expression::Variable(variable) => types::FieldExpression::Variable(variable), - _ => unimplemented!("expected field in field expression"), +/// pest ast -> types::Expression + +impl<'ast, F: Field + PrimeField> From> for types::Expression { + fn from(value: ast::Value<'ast>) -> Self { + match value { + ast::Value::U32(value) => { + types::Expression::Integer(types::IntegerExpression::from(value)) + } + ast::Value::Field(_field) => unimplemented!("cannot declare field values yet"), + ast::Value::Boolean(value) => { + types::Expression::Boolean(types::BooleanExpression::from(value)) + } } } } -impl<'ast> types::BooleanExpression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { + fn from(variable: ast::Variable<'ast>) -> Self { + types::Expression::Variable(types::Variable::from(variable)) + } +} + +impl<'ast, F: Field + PrimeField> From> for types::Expression { + fn from(expression: ast::NotExpression<'ast>) -> Self { + types::Expression::Boolean(types::BooleanExpression::Not(Box::new( + types::BooleanExpression::from(*expression.expression), + ))) + } +} + +impl<'ast, F: Field + PrimeField> types::BooleanExpression { /// Find out which types we are comparing and output the corresponding expression. fn from_eq(expression: ast::BinaryExpression<'ast>) -> Self { let left = types::Expression::from(*expression.left); @@ -117,18 +229,18 @@ impl<'ast> types::BooleanExpression { ) } //TODO: check case for two variables? // Field equality - (types::Expression::FieldElement(lhs), types::Expression::FieldElement(rhs)) => { + (types::Expression::Integer(lhs), types::Expression::Integer(rhs)) => { types::BooleanExpression::FieldEq(Box::new(lhs), Box::new(rhs)) } - (types::Expression::FieldElement(lhs), types::Expression::Variable(rhs)) => { + (types::Expression::Integer(lhs), types::Expression::Variable(rhs)) => { types::BooleanExpression::FieldEq( Box::new(lhs), - Box::new(types::FieldExpression::Variable(rhs)), + Box::new(types::IntegerExpression::Variable(rhs)), ) } - (types::Expression::Variable(lhs), types::Expression::FieldElement(rhs)) => { + (types::Expression::Variable(lhs), types::Expression::Integer(rhs)) => { types::BooleanExpression::FieldEq( - Box::new(types::FieldExpression::Variable(lhs)), + Box::new(types::IntegerExpression::Variable(lhs)), Box::new(rhs), ) } @@ -142,7 +254,7 @@ impl<'ast> types::BooleanExpression { } } -impl<'ast> From> for types::Expression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::BinaryExpression<'ast>) -> Self { match expression.operation { // Boolean operations @@ -161,57 +273,47 @@ impl<'ast> From> for types::Expression { types::Expression::Boolean(types::BooleanExpression::from_neq(expression)) } ast::BinaryOperator::Geq => types::Expression::Boolean(types::BooleanExpression::Geq( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), )), ast::BinaryOperator::Gt => types::Expression::Boolean(types::BooleanExpression::Gt( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), )), ast::BinaryOperator::Leq => types::Expression::Boolean(types::BooleanExpression::Leq( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), )), ast::BinaryOperator::Lt => types::Expression::Boolean(types::BooleanExpression::Lt( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), )), // Field operations - ast::BinaryOperator::Add => { - types::Expression::FieldElement(types::FieldExpression::Add( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - ast::BinaryOperator::Sub => { - types::Expression::FieldElement(types::FieldExpression::Sub( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - ast::BinaryOperator::Mul => { - types::Expression::FieldElement(types::FieldExpression::Mul( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - ast::BinaryOperator::Div => { - types::Expression::FieldElement(types::FieldExpression::Div( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - ast::BinaryOperator::Pow => { - types::Expression::FieldElement(types::FieldExpression::Pow( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } + ast::BinaryOperator::Add => types::Expression::Integer(types::IntegerExpression::Add( + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), + )), + ast::BinaryOperator::Sub => types::Expression::Integer(types::IntegerExpression::Sub( + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), + )), + ast::BinaryOperator::Mul => types::Expression::Integer(types::IntegerExpression::Mul( + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), + )), + ast::BinaryOperator::Div => types::Expression::Integer(types::IntegerExpression::Div( + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), + )), + ast::BinaryOperator::Pow => types::Expression::Integer(types::IntegerExpression::Pow( + Box::new(types::IntegerExpression::from(*expression.left)), + Box::new(types::IntegerExpression::from(*expression.right)), + )), } } } -impl<'ast> From> for types::Expression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::TernaryExpression<'ast>) -> Self { // Evaluate expressions to find out result type let first = types::BooleanExpression::from(*expression.first); @@ -242,24 +344,24 @@ impl<'ast> From> for types::Expression { )) } // Field Result - (types::Expression::FieldElement(second), types::Expression::FieldElement(third)) => { - types::Expression::FieldElement(types::FieldExpression::IfElse( + (types::Expression::Integer(second), types::Expression::Integer(third)) => { + types::Expression::Integer(types::IntegerExpression::IfElse( Box::new(first), Box::new(second), Box::new(third), )) } - (types::Expression::FieldElement(second), types::Expression::Variable(third)) => { - types::Expression::FieldElement(types::FieldExpression::IfElse( + (types::Expression::Integer(second), types::Expression::Variable(third)) => { + types::Expression::Integer(types::IntegerExpression::IfElse( Box::new(first), Box::new(second), - Box::new(types::FieldExpression::Variable(third)), + Box::new(types::IntegerExpression::Variable(third)), )) } - (types::Expression::Variable(second), types::Expression::FieldElement(third)) => { - types::Expression::FieldElement(types::FieldExpression::IfElse( + (types::Expression::Variable(second), types::Expression::Integer(third)) => { + types::Expression::Integer(types::IntegerExpression::IfElse( Box::new(first), - Box::new(types::FieldExpression::Variable(second)), + Box::new(types::IntegerExpression::Variable(second)), Box::new(third), )) } @@ -274,41 +376,7 @@ impl<'ast> From> for types::Expression { } } -impl<'ast> From> for types::FieldRangeOrExpression { - fn from(range_or_expression: ast::RangeOrExpression<'ast>) -> Self { - match range_or_expression { - ast::RangeOrExpression::Range(range) => { - let from = range - .from - .map(|from| match types::Expression::from(from.0) { - types::Expression::FieldElement(field) => field, - expression => { - unimplemented!("Range bounds should be numbers, found {}", expression) - } - }); - let to = range.to.map(|to| match types::Expression::from(to.0) { - types::Expression::FieldElement(field) => field, - expression => { - unimplemented!("Range bounds should be numbers, found {}", expression) - } - }); - - types::FieldRangeOrExpression::Range(from, to) - } - ast::RangeOrExpression::Expression(expression) => { - match types::Expression::from(expression) { - types::Expression::FieldElement(field_expression) => { - types::FieldRangeOrExpression::FieldExpression(field_expression) - } - // types::Expression::ArrayAccess(expression, field), // recursive array access - expression => unimplemented!("expression must be field, found {}", expression), - } - } - } - } -} - -impl<'ast> From> for types::Expression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::PostfixExpression<'ast>) -> Self { let variable = types::Expression::Variable(types::Variable::from(expression.variable)); @@ -339,13 +407,13 @@ impl<'ast> From> for types::Expression { ), ast::Access::Array(array) => types::Expression::ArrayAccess( Box::new(acc), - types::FieldRangeOrExpression::from(array.expression), + types::IntegerRangeOrExpression::from(array.expression), ), }) } } -impl<'ast> From> for types::Expression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::Expression<'ast>) -> Self { match expression { ast::Expression::Value(value) => types::Expression::from(value), @@ -359,147 +427,53 @@ impl<'ast> From> for types::Expression { ast::Expression::ArrayInitializer(_expression) => { unimplemented!("unknown type for array initializer expression") } + ast::Expression::StructInline(_expression) => { + unimplemented!("unknown type for inline struct expression") + } ast::Expression::Postfix(expression) => types::Expression::from(expression), _ => unimplemented!(), } } } -impl<'ast> From> for types::Assignee { - fn from(variable: ast::Variable<'ast>) -> Self { - types::Assignee::Variable(types::Variable::from(variable)) - } -} - -impl<'ast> From> for types::Assignee { - fn from(assignee: ast::Assignee<'ast>) -> Self { - let variable = types::Assignee::from(assignee.variable); - - // we start with the id, and we fold the array of accesses by wrapping the current value - assignee - .accesses - .into_iter() - .fold(variable, |acc, access| match access { - ast::AssigneeAccess::Array(array) => types::Assignee::Array( - Box::new(acc), - types::FieldRangeOrExpression::from(array.expression), - ), - ast::AssigneeAccess::Member(struct_member) => types::Assignee::StructMember( - Box::new(acc), - types::Variable::from(struct_member.variable), - ), - }) - } -} - -impl<'ast> From> for types::Statement { - fn from(statement: ast::AssignStatement<'ast>) -> Self { - types::Statement::Definition( - types::Assignee::from(statement.assignee), - types::Expression::from(statement.expression), - ) - } -} - -impl<'ast> From> for types::BooleanSpread { - fn from(spread: ast::Spread<'ast>) -> Self { - let boolean_expression = types::Expression::from(spread.expression); - match boolean_expression { - types::Expression::Boolean(expression) => types::BooleanSpread(expression), - types::Expression::Variable(variable) => { - types::BooleanSpread(types::BooleanExpression::Variable(variable)) - } - _ => unimplemented!("cannot create boolean spread from field type"), - } - } -} - -impl<'ast> From> for types::BooleanSpreadOrExpression { - fn from(expression: ast::Expression<'ast>) -> Self { - match types::Expression::from(expression) { - types::Expression::Boolean(expression) => { - types::BooleanSpreadOrExpression::BooleanExpression(expression) - } - _ => unimplemented!("cannot create boolean expression from field type"), - } - } -} - -impl<'ast> From> for types::BooleanSpreadOrExpression { - fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { - match s_or_e { - ast::SpreadOrExpression::Spread(spread) => { - types::BooleanSpreadOrExpression::Spread(types::BooleanSpread::from(spread)) - } - ast::SpreadOrExpression::Expression(expression) => { - match types::Expression::from(expression) { - types::Expression::Boolean(expression) => { - types::BooleanSpreadOrExpression::BooleanExpression(expression) - } - _ => unimplemented!("cannot create boolean expression from field type"), - } - } - } - } -} - -impl<'ast> From> for types::FieldSpread { - fn from(spread: ast::Spread<'ast>) -> Self { - match types::Expression::from(spread.expression) { - types::Expression::FieldElement(expression) => types::FieldSpread(expression), - types::Expression::Variable(variable) => { - types::FieldSpread(types::FieldExpression::Variable(variable)) - } - expression => unimplemented!( - "cannot create field spread from boolean type {}", - expression - ), - } - } -} - -impl<'ast> From> for types::FieldSpreadOrExpression { - fn from(expression: ast::Expression<'ast>) -> Self { - match types::Expression::from(expression) { - types::Expression::FieldElement(expression) => { - types::FieldSpreadOrExpression::FieldExpression(expression) - } - _ => unimplemented!("cannot create field expression from boolean type"), - } - } -} - -impl<'ast> From> for types::FieldSpreadOrExpression { - fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { - match s_or_e { - ast::SpreadOrExpression::Spread(spread) => { - types::FieldSpreadOrExpression::Spread(types::FieldSpread::from(spread)) - } - ast::SpreadOrExpression::Expression(expression) => { - types::FieldSpreadOrExpression::from(expression) - } - } - } -} - -impl<'ast> From> for types::StructMember { - fn from(member: ast::InlineStructMember<'ast>) -> Self { - types::StructMember { - variable: types::Variable::from(member.variable), - expression: types::Expression::from(member.expression), - } - } -} - -impl<'ast> types::Expression { +/// pest ast -> typed types::Expression +/// For defined types (ex: u32[4]) we manually construct the expression instead of implementing the From trait. +/// This saves us from having to resolve things at a later point in time. +impl<'ast, F: Field + PrimeField> types::Expression { fn from_basic(_ty: ast::BasicType<'ast>, _expression: ast::Expression<'ast>) -> Self { unimplemented!("from basic not impl"); } fn from_array(ty: ast::ArrayType<'ast>, expression: ast::Expression<'ast>) -> Self { match ty.ty { + ast::BasicType::U32(_ty) => { + let elements: Vec>> = match expression { + ast::Expression::ArrayInline(array) => array + .expressions + .into_iter() + .map(|s_or_e| Box::new(types::IntegerSpreadOrExpression::from(s_or_e))) + .collect(), + ast::Expression::ArrayInitializer(array) => { + let count = match array.count { + ast::Value::U32(f) => f + .number + .value + .parse::() + .expect("Unable to read array size"), + _ => unimplemented!("Array size should be an integer"), + }; + let expression = + Box::new(types::IntegerSpreadOrExpression::from(*array.expression)); + + vec![expression; count] + } + _ => unimplemented!("expected array after array type"), + }; + types::Expression::Integer(types::IntegerExpression::Array(elements)) + } + ast::BasicType::Field(_ty) => unimplemented!("from array field basic types unimpl"), ast::BasicType::Boolean(_ty) => { - let elements: Vec> = match expression { + let elements: Vec>> = match expression { ast::Expression::ArrayInline(array) => array .expressions .into_iter() @@ -507,9 +481,11 @@ impl<'ast> types::Expression { .collect(), ast::Expression::ArrayInitializer(array) => { let count = match array.count { - ast::Value::Field(f) => { - f.value.parse::().expect("Unable to read array size") - } + ast::Value::U32(f) => f + .number + .value + .parse::() + .expect("Unable to read array size"), _ => unimplemented!("Array size should be an integer"), }; let expression = @@ -521,29 +497,6 @@ impl<'ast> types::Expression { }; types::Expression::Boolean(types::BooleanExpression::Array(elements)) } - ast::BasicType::Field(_ty) => { - let elements: Vec> = match expression { - ast::Expression::ArrayInline(array) => array - .expressions - .into_iter() - .map(|s_or_e| Box::new(types::FieldSpreadOrExpression::from(s_or_e))) - .collect(), - ast::Expression::ArrayInitializer(array) => { - let count = match array.count { - ast::Value::Field(f) => { - f.value.parse::().expect("Unable to read array size") - } - _ => unimplemented!("Array size should be an integer"), - }; - let expression = - Box::new(types::FieldSpreadOrExpression::from(*array.expression)); - - vec![expression; count] - } - _ => unimplemented!("expected array after array type"), - }; - types::Expression::FieldElement(types::FieldExpression::Array(elements)) - } } } @@ -559,7 +512,7 @@ impl<'ast> types::Expression { .members .into_iter() .map(|member| types::StructMember::from(member)) - .collect::>(); + .collect::>>(); types::Expression::Struct(variable, members) } @@ -576,7 +529,47 @@ impl<'ast> types::Expression { } } -impl<'ast> From> for types::Statement { +/// pest ast -> types::Assignee + +impl<'ast, F: Field + PrimeField> From> for types::Assignee { + fn from(variable: ast::Variable<'ast>) -> Self { + types::Assignee::Variable(types::Variable::from(variable)) + } +} + +impl<'ast, F: Field + PrimeField> From> for types::Assignee { + fn from(assignee: ast::Assignee<'ast>) -> Self { + let variable = types::Assignee::from(assignee.variable); + + // we start with the id, and we fold the array of accesses by wrapping the current value + assignee + .accesses + .into_iter() + .fold(variable, |acc, access| match access { + ast::AssigneeAccess::Array(array) => types::Assignee::Array( + Box::new(acc), + types::IntegerRangeOrExpression::from(array.expression), + ), + ast::AssigneeAccess::Member(struct_member) => types::Assignee::StructMember( + Box::new(acc), + types::Variable::from(struct_member.variable), + ), + }) + } +} + +/// pest ast -> types::Statement + +impl<'ast, F: Field + PrimeField> From> for types::Statement { + fn from(statement: ast::AssignStatement<'ast>) -> Self { + types::Statement::Definition( + types::Assignee::from(statement.assignee), + types::Expression::from(statement.expression), + ) + } +} + +impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::DefinitionStatement<'ast>) -> Self { types::Statement::Definition( types::Assignee::from(statement.variable), @@ -585,7 +578,7 @@ impl<'ast> From> for types::Statement { } } -impl<'ast> From> for types::Statement { +impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::ReturnStatement<'ast>) -> Self { types::Statement::Return( statement @@ -597,12 +590,12 @@ impl<'ast> From> for types::Statement { } } -impl<'ast> From> for types::Statement { +impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::ForStatement<'ast>) -> Self { types::Statement::For( types::Variable::from(statement.index), - types::FieldExpression::from(statement.start), - types::FieldExpression::from(statement.stop), + types::IntegerExpression::from(statement.start), + types::IntegerExpression::from(statement.stop), statement .statements .into_iter() @@ -612,7 +605,7 @@ impl<'ast> From> for types::Statement { } } -impl<'ast> From> for types::Statement { +impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::Statement<'ast>) -> Self { match statement { ast::Statement::Assign(statement) => types::Statement::from(statement), @@ -623,33 +616,40 @@ impl<'ast> From> for types::Statement { } } -impl<'ast> From> for types::Type { +/// pest ast -> Explicit types::Type for defining struct members and function params + +impl<'ast, F: Field + PrimeField> From> for types::Type { fn from(basic_type: ast::BasicType<'ast>) -> Self { match basic_type { - ast::BasicType::Field(_ty) => types::Type::FieldElement, + ast::BasicType::U32(_ty) => types::Type::U32, + ast::BasicType::Field(_ty) => types::Type::U32, ast::BasicType::Boolean(_ty) => types::Type::Boolean, } } } -impl<'ast> From> for types::Type { +impl<'ast, F: Field + PrimeField> From> for types::Type { fn from(array_type: ast::ArrayType<'ast>) -> Self { let element_type = Box::new(types::Type::from(array_type.ty)); let count = match array_type.count { - ast::Value::Field(f) => f.value.parse::().expect("Unable to read array size"), + ast::Value::Field(f) => f + .number + .value + .parse::() + .expect("Unable to read array size"), _ => unimplemented!("Array size should be an integer"), }; types::Type::Array(element_type, count) } } -impl<'ast> From> for types::Type { +impl<'ast, F: Field + PrimeField> From> for types::Type { fn from(struct_type: ast::StructType<'ast>) -> Self { types::Type::Struct(types::Variable::from(struct_type.variable)) } } -impl<'ast> From> for types::Type { +impl<'ast, F: Field + PrimeField> From> for types::Type { fn from(ty: ast::Type<'ast>) -> Self { match ty { ast::Type::Basic(ty) => types::Type::from(ty), @@ -659,7 +659,18 @@ impl<'ast> From> for types::Type { } } -impl<'ast> From> for types::StructField { +/// pest ast -> types::Struct + +impl<'ast, F: Field + PrimeField> From> for types::StructMember { + fn from(member: ast::InlineStructMember<'ast>) -> Self { + types::StructMember { + variable: types::Variable::from(member.variable), + expression: types::Expression::from(member.expression), + } + } +} + +impl<'ast, F: Field + PrimeField> From> for types::StructField { fn from(struct_field: ast::StructField<'ast>) -> Self { types::StructField { variable: types::Variable::from(struct_field.variable), @@ -668,7 +679,7 @@ impl<'ast> From> for types::StructField { } } -impl<'ast> From> for types::Struct { +impl<'ast, F: Field + PrimeField> From> for types::Struct { fn from(struct_definition: ast::Struct<'ast>) -> Self { let variable = types::Variable::from(struct_definition.variable); let fields = struct_definition @@ -681,6 +692,8 @@ impl<'ast> From> for types::Struct { } } +/// pest ast -> function types::Parameters + impl From for types::Visibility { fn from(visibility: ast::Visibility) -> Self { match visibility { @@ -690,7 +703,7 @@ impl From for types::Visibility { } } -impl<'ast> From> for types::Parameter { +impl<'ast, F: Field + PrimeField> From> for types::Parameter { fn from(parameter: ast::Parameter<'ast>) -> Self { let ty = types::Type::from(parameter.ty); let variable = types::Variable::from(parameter.variable); @@ -712,13 +725,15 @@ impl<'ast> From> for types::Parameter { } } +/// pest ast -> types::Function + impl<'ast> From> for types::FunctionName { fn from(name: ast::FunctionName<'ast>) -> Self { types::FunctionName(name.value) } } -impl<'ast> From> for types::Function { +impl<'ast, F: Field + PrimeField> From> for types::Function { fn from(function_definition: ast::Function<'ast>) -> Self { let function_name = types::FunctionName::from(function_definition.function_name); let parameters = function_definition @@ -746,6 +761,8 @@ impl<'ast> From> for types::Function { } } +/// pest ast -> Import + impl<'ast> From> for PathString<'ast> { fn from(import: ast::Variable<'ast>) -> Self { import.span.as_str() @@ -766,7 +783,9 @@ impl<'ast> From> for Import<'ast> { } } -impl<'ast> From> for types::Program<'ast> { +/// pest ast -> types::Program + +impl<'ast, F: Field + PrimeField> From> for types::Program<'ast, F> { fn from(file: ast::File<'ast>) -> Self { // Compiled ast -> aleo program representation let imports = file @@ -792,7 +811,10 @@ impl<'ast> From> for types::Program<'ast> { }); types::Program { - name: types::Variable("".into()), + name: types::Variable { + name: "".into(), + _field: PhantomData::, + }, imports, structs, functions, diff --git a/src/ast.rs b/src/ast.rs index 4c951d8b3c..513b03c331 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -30,203 +30,6 @@ lazy_static! { static ref PRECEDENCE_CLIMBER: PrecClimber = precedence_climber(); } -fn precedence_climber() -> PrecClimber { - PrecClimber::new(vec![ - Operator::new(Rule::operation_or, Assoc::Left), - Operator::new(Rule::operation_and, Assoc::Left), - Operator::new(Rule::operation_eq, Assoc::Left) - | Operator::new(Rule::operation_neq, Assoc::Left), - Operator::new(Rule::operation_geq, Assoc::Left) - | Operator::new(Rule::operation_gt, Assoc::Left) - | Operator::new(Rule::operation_leq, Assoc::Left) - | Operator::new(Rule::operation_lt, Assoc::Left), - Operator::new(Rule::operation_add, Assoc::Left) - | Operator::new(Rule::operation_sub, Assoc::Left), - Operator::new(Rule::operation_mul, Assoc::Left) - | Operator::new(Rule::operation_div, Assoc::Left), - Operator::new(Rule::operation_pow, Assoc::Left), - ]) -} - -fn parse_term(pair: Pair) -> Box { - Box::new(match pair.as_rule() { - Rule::expression_term => { - let clone = pair.clone(); - let next = clone.into_inner().next().unwrap(); - match next.as_rule() { - Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case - Rule::expression_inline_struct => { - Expression::StructInline( - StructInlineExpression::from_pest(&mut pair.into_inner()).unwrap(), - ) - }, - Rule::expression_array_inline => { - Expression::ArrayInline( - ArrayInlineExpression::from_pest(&mut pair.into_inner()).unwrap() - ) - }, - Rule::expression_array_initializer => { - Expression::ArrayInitializer( - ArrayInitializerExpression::from_pest(&mut pair.into_inner()).unwrap() - ) - }, - Rule::expression_conditional => { - Expression::Ternary( - TernaryExpression::from_pest(&mut pair.into_inner()).unwrap(), - ) - }, - Rule::expression_not => { - let span = next.as_span(); - let mut inner = next.into_inner(); - let operation = match inner.next().unwrap().as_rule() { - Rule::operation_pre_not => Not::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - rule => unreachable!("`expression_not` should yield `operation_pre_not`, found {:#?}", rule) - }; - let expression = parse_term(inner.next().unwrap()); - Expression::Not(NotExpression { operation, expression, span }) - }, - Rule::expression_increment => { - println!("expression increment"); - let span = next.as_span(); - let mut inner = next.into_inner(); - let expression = parse_term(inner.next().unwrap()); - let operation = match inner.next().unwrap().as_rule() { - Rule::operation_post_increment => Increment::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - rule => unreachable!("`expression_increment` should yield `operation_post_increment`, found {:#?}", rule) - }; - Expression::Increment(IncrementExpression { operation, expression, span }) - }, - Rule::expression_decrement => { - println!("expression decrement"); - let span = next.as_span(); - let mut inner = next.into_inner(); - let expression = parse_term(inner.next().unwrap()); - let operation = match inner.next().unwrap().as_rule() { - Rule::operation_post_decrement => Decrement::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - rule => unreachable!("`expression_decrement` should yield `operation_post_decrement`, found {:#?}", rule) - }; - Expression::Decrement(DecrementExpression { operation, expression, span }) - }, - Rule::expression_postfix => { - Expression::Postfix( - PostfixExpression::from_pest(&mut pair.into_inner()).unwrap(), - ) - } - Rule::expression_primitive => { - let next = next.into_inner().next().unwrap(); - match next.as_rule() { - Rule::value => Expression::Value( - Value::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap() - ), - Rule::variable => Expression::Variable( - Variable::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - ), - rule => unreachable!("`expression_primitive` should contain one of [`value`, `variable`], found {:#?}", rule) - } - }, - - rule => unreachable!("`term` should contain one of ['value', 'variable', 'expression', 'expression_not', 'expression_increment', 'expression_decrement'], found {:#?}", rule) - } - } - rule => unreachable!( - "`parse_expression_term` should be invoked on `Rule::expression_term`, found {:#?}", - rule - ), - }) -} - -fn binary_expression<'ast>( - lhs: Box>, - pair: Pair<'ast, Rule>, - rhs: Box>, -) -> Box> { - let (start, _) = lhs.span().clone().split(); - let (_, end) = rhs.span().clone().split(); - let span = start.span(&end); - - Box::new(match pair.as_rule() { - Rule::operation_or => Expression::binary(BinaryOperator::Or, lhs, rhs, span), - Rule::operation_and => Expression::binary(BinaryOperator::And, lhs, rhs, span), - Rule::operation_eq => Expression::binary(BinaryOperator::Eq, lhs, rhs, span), - Rule::operation_neq => Expression::binary(BinaryOperator::Neq, lhs, rhs, span), - Rule::operation_geq => Expression::binary(BinaryOperator::Geq, lhs, rhs, span), - Rule::operation_gt => Expression::binary(BinaryOperator::Gt, lhs, rhs, span), - Rule::operation_leq => Expression::binary(BinaryOperator::Leq, lhs, rhs, span), - Rule::operation_lt => Expression::binary(BinaryOperator::Lt, lhs, rhs, span), - Rule::operation_add => Expression::binary(BinaryOperator::Add, lhs, rhs, span), - Rule::operation_sub => Expression::binary(BinaryOperator::Sub, lhs, rhs, span), - Rule::operation_mul => Expression::binary(BinaryOperator::Mul, lhs, rhs, span), - Rule::operation_div => Expression::binary(BinaryOperator::Div, lhs, rhs, span), - Rule::operation_pow => Expression::binary(BinaryOperator::Pow, lhs, rhs, span), - _ => unreachable!(), - }) -} - -// Types - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty_bool))] -pub struct BooleanType<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, -} - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty_field))] -pub struct FieldType<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, -} - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty_struct))] -pub struct StructType<'ast> { - pub variable: Variable<'ast>, - #[pest_ast(outer())] - pub span: Span<'ast>, -} - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty_basic))] -pub enum BasicType<'ast> { - Field(FieldType<'ast>), - Boolean(BooleanType<'ast>), -} - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty_basic_or_struct))] -pub enum BasicOrStructType<'ast> { - Struct(StructType<'ast>), - Basic(BasicType<'ast>), -} - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty_array))] -pub struct ArrayType<'ast> { - pub ty: BasicType<'ast>, - pub count: Value<'ast>, - #[pest_ast(outer())] - pub span: Span<'ast>, -} - -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::ty))] -pub enum Type<'ast> { - Basic(BasicType<'ast>), - Array(ArrayType<'ast>), - Struct(StructType<'ast>), -} - -impl<'ast> fmt::Display for Type<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Type::Basic(ref _ty) => write!(f, "basic"), - Type::Array(ref _ty) => write!(f, "array"), - Type::Struct(ref _ty) => write!(f, "struct"), - } - } -} - // Visibility #[derive(Clone, Debug, FromPest, PartialEq)] @@ -286,7 +89,124 @@ pub enum BinaryOperator { Pow, } +// Types + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_u32))] +pub struct U32Type<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_bool))] +pub struct BooleanType<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_field))] +pub struct FieldType<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_struct))] +pub struct StructType<'ast> { + pub variable: Variable<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_basic))] +pub enum BasicType<'ast> { + U32(U32Type<'ast>), + Field(FieldType<'ast>), + Boolean(BooleanType<'ast>), +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_basic_or_struct))] +pub enum BasicOrStructType<'ast> { + Struct(StructType<'ast>), + Basic(BasicType<'ast>), +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty_array))] +pub struct ArrayType<'ast> { + pub ty: BasicType<'ast>, + pub count: Value<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::ty))] +pub enum Type<'ast> { + Basic(BasicType<'ast>), + Array(ArrayType<'ast>), + Struct(StructType<'ast>), +} + +impl<'ast> fmt::Display for Type<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Type::Basic(ref _ty) => write!(f, "basic"), + Type::Array(ref _ty) => write!(f, "array"), + Type::Struct(ref _ty) => write!(f, "struct"), + } + } +} + // Values +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::value_number))] +pub struct Number<'ast> { + #[pest_ast(outer(with(span_into_string)))] + pub value: String, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for Number<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::value_u32))] +pub struct U32<'ast> { + pub number: Number<'ast>, + pub ty: U32Type<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for U32<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.number) + } +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::value_field))] +pub struct Field<'ast> { + pub number: Number<'ast>, + pub ty: U32Type<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for Field<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.number) + } +} #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::value_boolean))] @@ -303,33 +223,20 @@ impl<'ast> fmt::Display for Boolean<'ast> { } } -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::value_field))] -pub struct Field<'ast> { - #[pest_ast(outer(with(span_into_string)))] - pub value: String, - #[pest_ast(outer())] - pub span: Span<'ast>, -} - -impl<'ast> fmt::Display for Field<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.value) - } -} - #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::value))] pub enum Value<'ast> { - Boolean(Boolean<'ast>), + U32(U32<'ast>), Field(Field<'ast>), + Boolean(Boolean<'ast>), } impl<'ast> Value<'ast> { pub fn span(&self) -> &Span<'ast> { match self { - Value::Boolean(value) => &value.span, + Value::U32(value) => &value.span, Value::Field(value) => &value.span, + Value::Boolean(value) => &value.span, } } } @@ -337,8 +244,9 @@ impl<'ast> Value<'ast> { impl<'ast> fmt::Display for Value<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Value::Boolean(ref value) => write!(f, "{}", value), + Value::U32(ref value) => write!(f, "{}", value), Value::Field(ref value) => write!(f, "{}", value), + Value::Boolean(ref value) => write!(f, "{}", value), } } } @@ -720,6 +628,140 @@ impl<'ast> fmt::Display for Expression<'ast> { } } +fn precedence_climber() -> PrecClimber { + PrecClimber::new(vec![ + Operator::new(Rule::operation_or, Assoc::Left), + Operator::new(Rule::operation_and, Assoc::Left), + Operator::new(Rule::operation_eq, Assoc::Left) + | Operator::new(Rule::operation_neq, Assoc::Left), + Operator::new(Rule::operation_geq, Assoc::Left) + | Operator::new(Rule::operation_gt, Assoc::Left) + | Operator::new(Rule::operation_leq, Assoc::Left) + | Operator::new(Rule::operation_lt, Assoc::Left), + Operator::new(Rule::operation_add, Assoc::Left) + | Operator::new(Rule::operation_sub, Assoc::Left), + Operator::new(Rule::operation_mul, Assoc::Left) + | Operator::new(Rule::operation_div, Assoc::Left), + Operator::new(Rule::operation_pow, Assoc::Left), + ]) +} + +fn parse_term(pair: Pair) -> Box { + Box::new(match pair.as_rule() { + Rule::expression_term => { + let clone = pair.clone(); + let next = clone.into_inner().next().unwrap(); + match next.as_rule() { + Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case + Rule::expression_inline_struct => { + Expression::StructInline( + StructInlineExpression::from_pest(&mut pair.into_inner()).unwrap(), + ) + }, + Rule::expression_array_inline => { + Expression::ArrayInline( + ArrayInlineExpression::from_pest(&mut pair.into_inner()).unwrap() + ) + }, + Rule::expression_array_initializer => { + Expression::ArrayInitializer( + ArrayInitializerExpression::from_pest(&mut pair.into_inner()).unwrap() + ) + }, + Rule::expression_conditional => { + Expression::Ternary( + TernaryExpression::from_pest(&mut pair.into_inner()).unwrap(), + ) + }, + Rule::expression_not => { + let span = next.as_span(); + let mut inner = next.into_inner(); + let operation = match inner.next().unwrap().as_rule() { + Rule::operation_pre_not => Not::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + rule => unreachable!("`expression_not` should yield `operation_pre_not`, found {:#?}", rule) + }; + let expression = parse_term(inner.next().unwrap()); + Expression::Not(NotExpression { operation, expression, span }) + }, + Rule::expression_increment => { + println!("expression increment"); + let span = next.as_span(); + let mut inner = next.into_inner(); + let expression = parse_term(inner.next().unwrap()); + let operation = match inner.next().unwrap().as_rule() { + Rule::operation_post_increment => Increment::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + rule => unreachable!("`expression_increment` should yield `operation_post_increment`, found {:#?}", rule) + }; + Expression::Increment(IncrementExpression { operation, expression, span }) + }, + Rule::expression_decrement => { + println!("expression decrement"); + let span = next.as_span(); + let mut inner = next.into_inner(); + let expression = parse_term(inner.next().unwrap()); + let operation = match inner.next().unwrap().as_rule() { + Rule::operation_post_decrement => Decrement::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + rule => unreachable!("`expression_decrement` should yield `operation_post_decrement`, found {:#?}", rule) + }; + Expression::Decrement(DecrementExpression { operation, expression, span }) + }, + Rule::expression_postfix => { + Expression::Postfix( + PostfixExpression::from_pest(&mut pair.into_inner()).unwrap(), + ) + } + Rule::expression_primitive => { + let next = next.into_inner().next().unwrap(); + match next.as_rule() { + Rule::value => { + Expression::Value( + Value::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap() + ) + }, + Rule::variable => Expression::Variable( + Variable::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + ), + rule => unreachable!("`expression_primitive` should contain one of [`value`, `variable`], found {:#?}", rule) + } + }, + + rule => unreachable!("`term` should contain one of ['value', 'variable', 'expression', 'expression_not', 'expression_increment', 'expression_decrement'], found {:#?}", rule) + } + } + rule => unreachable!( + "`parse_expression_term` should be invoked on `Rule::expression_term`, found {:#?}", + rule + ), + }) +} + +fn binary_expression<'ast>( + lhs: Box>, + pair: Pair<'ast, Rule>, + rhs: Box>, +) -> Box> { + let (start, _) = lhs.span().clone().split(); + let (_, end) = rhs.span().clone().split(); + let span = start.span(&end); + + Box::new(match pair.as_rule() { + Rule::operation_or => Expression::binary(BinaryOperator::Or, lhs, rhs, span), + Rule::operation_and => Expression::binary(BinaryOperator::And, lhs, rhs, span), + Rule::operation_eq => Expression::binary(BinaryOperator::Eq, lhs, rhs, span), + Rule::operation_neq => Expression::binary(BinaryOperator::Neq, lhs, rhs, span), + Rule::operation_geq => Expression::binary(BinaryOperator::Geq, lhs, rhs, span), + Rule::operation_gt => Expression::binary(BinaryOperator::Gt, lhs, rhs, span), + Rule::operation_leq => Expression::binary(BinaryOperator::Leq, lhs, rhs, span), + Rule::operation_lt => Expression::binary(BinaryOperator::Lt, lhs, rhs, span), + Rule::operation_add => Expression::binary(BinaryOperator::Add, lhs, rhs, span), + Rule::operation_sub => Expression::binary(BinaryOperator::Sub, lhs, rhs, span), + Rule::operation_mul => Expression::binary(BinaryOperator::Mul, lhs, rhs, span), + Rule::operation_div => Expression::binary(BinaryOperator::Div, lhs, rhs, span), + Rule::operation_pow => Expression::binary(BinaryOperator::Pow, lhs, rhs, span), + _ => unreachable!(), + }) +} + impl<'ast> FromPest<'ast> for Expression<'ast> { type Rule = Rule; type FatalError = Void; @@ -759,6 +801,14 @@ pub struct DefinitionStatement<'ast> { pub span: Span<'ast>, } +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::statement_return))] +pub struct ReturnStatement<'ast> { + pub expressions: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::statement_for))] pub struct ForStatement<'ast> { @@ -770,21 +820,13 @@ pub struct ForStatement<'ast> { pub span: Span<'ast>, } -#[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::statement_return))] -pub struct ReturnStatement<'ast> { - pub expressions: Vec>, - #[pest_ast(outer())] - pub span: Span<'ast>, -} - #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::statement))] pub enum Statement<'ast> { Assign(AssignStatement<'ast>), Definition(DefinitionStatement<'ast>), - Iteration(ForStatement<'ast>), Return(ReturnStatement<'ast>), + Iteration(ForStatement<'ast>), } impl<'ast> fmt::Display for AssignStatement<'ast> { @@ -799,16 +841,6 @@ impl<'ast> fmt::Display for DefinitionStatement<'ast> { } } -impl<'ast> fmt::Display for ForStatement<'ast> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "for {} in {}..{} do {:#?} endfor", - self.index, self.start, self.stop, self.statements - ) - } -} - impl<'ast> fmt::Display for ReturnStatement<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, expression) in self.expressions.iter().enumerate() { @@ -821,13 +853,23 @@ impl<'ast> fmt::Display for ReturnStatement<'ast> { } } +impl<'ast> fmt::Display for ForStatement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "for {} in {}..{} do {:#?} endfor", + self.index, self.start, self.stop, self.statements + ) + } +} + impl<'ast> fmt::Display for Statement<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Statement::Assign(ref statement) => write!(f, "{}", statement), Statement::Definition(ref statement) => write!(f, "{}", statement), - Statement::Iteration(ref statement) => write!(f, "{}", statement), Statement::Return(ref statement) => write!(f, "{}", statement), + Statement::Iteration(ref statement) => write!(f, "{}", statement), } } } diff --git a/src/language.pest b/src/language.pest index 15e5d67e65..c0bdfb6a62 100644 --- a/src/language.pest +++ b/src/language.pest @@ -48,10 +48,10 @@ operation_binary = _ { // operation_div_assign = { "/=" } /// Types - +ty_u32 = {"u32"} ty_field = {"field"} ty_bool = {"bool"} -ty_basic = { ty_field | ty_bool } +ty_basic = { ty_u32 | ty_field | ty_bool } ty_struct = { variable } ty_basic_or_struct = {ty_basic | ty_struct } ty_array = {ty_basic ~ ("[" ~ value ~ "]")+ } @@ -59,10 +59,11 @@ ty = {ty_array | ty_basic | ty_struct} type_list = _{(ty ~ ("," ~ ty)*)?} /// Values - +value_number = @{ "-"? ~ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*)} +value_u32 = { value_number ~ ty_u32} +value_field = { value_number ~ ty_field } value_boolean = { "true" | "false" } -value_field = @{ "-"? ~ ("0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) } -value = { value_boolean | value_field } +value = { value_u32 | value_field | value_boolean } /// Variables @@ -88,8 +89,8 @@ access_member = { "." ~ variable } access = { access_array | access_call | access_member } expression_postfix = { variable ~ access+ } -assignee = { variable ~ assignee_access* } assignee_access = { access_array | access_member } +assignee = { variable ~ assignee_access* } spread = { "..." ~ expression } spread_or_expression = { spread | expression } diff --git a/src/main.rs b/src/main.rs index 531879e162..e1f20ddbf0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ impl ConstraintSynthesizer for Benchmark { let syntax_tree = ast::File::from_pest(&mut file).expect("infallible"); // println!("{:#?}", syntax_tree); - let program = aleo_program::Program::from(syntax_tree); + let program = aleo_program::Program::<'_, F>::from(syntax_tree); println!(" compiled: {:#?}", program); let program = program.name("simple".into());