diff --git a/benchmark/simple.leo b/benchmark/simple.leo index cc1caf9289..049ddfb201 100644 --- a/benchmark/simple.leo +++ b/benchmark/simple.leo @@ -1,7 +1,7 @@ circuit PedersenHash { parameters: group[1] - function new(b: u32) -> u32 { + static function new(b: u32) -> u32 { return b } } @@ -9,7 +9,7 @@ circuit PedersenHash { function main() -> u32{ let parameters = [0group; 1]; let pedersen = PedersenHash { parameters: parameters }; - let b = pedersen.new(3); + let b = PedersenHash::new(3); return b } \ No newline at end of file diff --git a/compiler/src/ast.rs b/compiler/src/ast.rs index 48e221a047..081fa3406f 100644 --- a/compiler/src/ast.rs +++ b/compiler/src/ast.rs @@ -438,8 +438,16 @@ pub struct ArrayAccess<'ast> { } #[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::access_member))] -pub struct MemberAccess<'ast> { +#[pest_ast(rule(Rule::access_object))] +pub struct ObjectAccess<'ast> { + pub identifier: Identifier<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::access_static_object))] +pub struct StaticObjectAccess<'ast> { pub identifier: Identifier<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, @@ -450,7 +458,8 @@ pub struct MemberAccess<'ast> { pub enum Access<'ast> { Array(ArrayAccess<'ast>), Call(CallAccess<'ast>), - Member(MemberAccess<'ast>), + Object(ObjectAccess<'ast>), + StaticObject(StaticObjectAccess<'ast>), } #[derive(Clone, Debug, FromPest, PartialEq)] @@ -466,7 +475,7 @@ pub struct PostfixExpression<'ast> { #[pest_ast(rule(Rule::assignee_access))] pub enum AssigneeAccess<'ast> { Array(ArrayAccess<'ast>), - Member(MemberAccess<'ast>), + Member(ObjectAccess<'ast>), } impl<'ast> fmt::Display for AssigneeAccess<'ast> { diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index ea7b35e8d9..136994b1cf 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -6,6 +6,7 @@ use crate::{ ConstrainedProgram, ConstrainedValue, }, errors::ExpressionError, + new_scope, types::{ CircuitMember, CircuitObject, Expression, Identifier, RangeOrExpression, SpreadOrExpression, }, @@ -20,21 +21,23 @@ impl> ConstrainedProgra /// Enforce a variable expression by getting the resolved value pub(crate) fn evaluate_identifier( &mut self, - scope: String, - unresolved_variable: Identifier, + file_scope: String, + function_scope: String, + unresolved_identifier: Identifier, ) -> Result, ExpressionError> { - // Evaluate the variable name in the current function scope - let variable_name = new_scope_from_variable(scope, &unresolved_variable); + // Evaluate the identifier name in the current function scope + let variable_name = new_scope(function_scope, unresolved_identifier.to_string()); + let identifier_name = new_scope(file_scope, unresolved_identifier.to_string()); - if self.contains_name(&variable_name) { + if let Some(variable) = self.get(&variable_name) { // Reassigning variable to another variable - Ok(self.get_mut(&variable_name).unwrap().clone()) - } else if self.contains_variable(&unresolved_variable) { + Ok(variable.clone()) + } else if let Some(identifier) = self.get(&identifier_name) { // Check global scope (function and circuit names) - Ok(self.get_mut_variable(&unresolved_variable).unwrap().clone()) + Ok(identifier.clone()) } else { - Err(ExpressionError::UndefinedVariable( - unresolved_variable.to_string(), + Err(ExpressionError::UndefinedIdentifier( + unresolved_identifier.to_string(), )) } } @@ -449,6 +452,53 @@ impl> ConstrainedProgra } } + fn enforce_circuit_static_access_expression( + &mut self, + cs: &mut CS, + file_scope: String, + function_scope: String, + circuit_identifier: Box>, + circuit_member: Identifier, + ) -> Result, ExpressionError> { + // Get defined circuit + let circuit = match self.enforce_expression( + cs, + file_scope.clone(), + function_scope.clone(), + *circuit_identifier.clone(), + )? { + ConstrainedValue::CircuitDefinition(circuit_definition) => circuit_definition, + value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())), + }; + + // Find static circuit function + let matched_function = circuit.objects.into_iter().find(|member| match member { + CircuitObject::CircuitFunction(_static, _function) => *_static, + _ => false, + }); + + // Return errors if no static function exists + let function = match matched_function { + Some(CircuitObject::CircuitFunction(_static, function)) => { + if _static { + function + } else { + return Err(ExpressionError::InvalidStaticFunction( + function.function_name.to_string(), + )); + } + } + _ => { + return Err(ExpressionError::UndefinedStaticFunction( + circuit.identifier.to_string(), + circuit_member.to_string(), + )) + } + }; + + Ok(ConstrainedValue::Function(function)) + } + fn enforce_function_call_expression( &mut self, cs: &mut CS, @@ -492,7 +542,7 @@ impl> ConstrainedProgra match expression { // Variables Expression::Identifier(unresolved_variable) => { - self.evaluate_identifier(function_scope, unresolved_variable) + self.evaluate_identifier(file_scope, function_scope, unresolved_variable) } // Values @@ -690,7 +740,7 @@ impl> ConstrainedProgra circuit_name, members, ), - Expression::CircuitMemberAccess(circuit_variable, circuit_member) => self + Expression::CircuitObjectAccess(circuit_variable, circuit_member) => self .enforce_circuit_access_expression( cs, file_scope, @@ -698,6 +748,14 @@ impl> ConstrainedProgra circuit_variable, circuit_member, ), + Expression::CircuitStaticObjectAccess(circuit_identifier, circuit_member) => self + .enforce_circuit_static_access_expression( + cs, + file_scope, + function_scope, + circuit_identifier, + circuit_member, + ), // Functions Expression::FunctionCall(function, arguments) => self.enforce_function_call_expression( diff --git a/compiler/src/constraints/function.rs b/compiler/src/constraints/function.rs index 573c6bd068..1f6485b95c 100644 --- a/compiler/src/constraints/function.rs +++ b/compiler/src/constraints/function.rs @@ -31,8 +31,8 @@ impl> ConstrainedProgra input: Expression, ) -> Result, FunctionError> { match input { - Expression::Identifier(variable) => { - Ok(self.evaluate_identifier(caller_scope, variable)?) + Expression::Identifier(identifier) => { + Ok(self.evaluate_identifier(caller_scope, function_name, identifier)?) } expression => Ok(self.enforce_expression(cs, scope, function_name, expression)?), } diff --git a/compiler/src/constraints/program.rs b/compiler/src/constraints/program.rs index b93f940b13..15f691135c 100644 --- a/compiler/src/constraints/program.rs +++ b/compiler/src/constraints/program.rs @@ -66,14 +66,6 @@ impl> ConstrainedProgra self.store(variable.name, value); } - pub(crate) fn contains_name(&self, name: &String) -> bool { - self.identifiers.contains_key(name) - } - - pub(crate) fn contains_variable(&self, variable: &Identifier) -> bool { - self.contains_name(&variable.name) - } - pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue> { self.identifiers.get(name) } diff --git a/compiler/src/errors/constraints/expression.rs b/compiler/src/errors/constraints/expression.rs index 8bf16835fe..d132c851f3 100644 --- a/compiler/src/errors/constraints/expression.rs +++ b/compiler/src/errors/constraints/expression.rs @@ -2,9 +2,9 @@ use crate::errors::{BooleanError, FieldElementError, FunctionError, IntegerError #[derive(Debug, Error)] pub enum ExpressionError { - // Variables + // Identifiers #[error("Variable \"{}\" not found", _0)] - UndefinedVariable(String), + UndefinedIdentifier(String), // Types #[error("{}", _0)] @@ -63,6 +63,15 @@ pub enum ExpressionError { #[error("Expected circuit value {}", _0)] ExpectedCircuitValue(String), + #[error("Circuit {} has no static function {}", _0, _1)] + UndefinedStaticFunction(String, String), + + #[error( + "Static access only supported for static circuit functions, got function {}", + _0 + )] + InvalidStaticFunction(String), + // Functions #[error( "Function {} must be declared before it is used in an inline expression", diff --git a/compiler/src/leo.pest b/compiler/src/leo.pest index 8f84c5d193..512b006d9d 100644 --- a/compiler/src/leo.pest +++ b/compiler/src/leo.pest @@ -111,12 +111,13 @@ range_or_expression = { range | expression } access_array = { "[" ~ range_or_expression ~ "]" } access_call = { "(" ~ expression_tuple ~ ")" } -access_member = { "." ~ identifier } -access = { access_array | access_call | access_member } +access_object = { "." ~ identifier } +access_static_object = { "::" ~ identifier } +access = { access_array | access_call | access_object | access_static_object} expression_postfix = { identifier ~ access+ } -assignee_access = { access_array | access_member } +assignee_access = { access_array | access_object } assignee = { identifier ~ assignee_access* } spread = { "..." ~ expression } diff --git a/compiler/src/types.rs b/compiler/src/types.rs index c9ab64defc..213d31755e 100644 --- a/compiler/src/types.rs +++ b/compiler/src/types.rs @@ -154,7 +154,8 @@ pub enum Expression { // Circuits Circuit(Identifier, Vec>), - CircuitMemberAccess(Box>, Identifier), // (circuit name, circuit object name) + CircuitObjectAccess(Box>, Identifier), // (declared circuit name, circuit object name) + CircuitStaticObjectAccess(Box>, Identifier), // (defined circuit name, circuit staic object name) // Functions FunctionCall(Box>, Vec>), diff --git a/compiler/src/types_display.rs b/compiler/src/types_display.rs index c3b717f251..66dc69d894 100644 --- a/compiler/src/types_display.rs +++ b/compiler/src/types_display.rs @@ -154,8 +154,11 @@ impl<'ast, F: Field + PrimeField, G: Group> fmt::Display for Expression { } write!(f, "}}") } - Expression::CircuitMemberAccess(ref circuit_variable, ref member) => { - write!(f, "{}.{}", circuit_variable, member) + Expression::CircuitObjectAccess(ref circuit_name, ref member) => { + write!(f, "{}.{}", circuit_name, member) + } + Expression::CircuitStaticObjectAccess(ref circuit_name, ref member) => { + write!(f, "{}::{}", circuit_name, member) } // Function calls diff --git a/compiler/src/types_from.rs b/compiler/src/types_from.rs index dd5cee18a4..f94ce6461e 100644 --- a/compiler/src/types_from.rs +++ b/compiler/src/types_from.rs @@ -317,49 +317,33 @@ impl<'ast, F: Field + PrimeField, G: Group> From> .accesses .into_iter() .fold(variable, |acc, access| match access { - // Handle function calls - ast::Access::Call(function) => { - types::Expression::FunctionCall( - Box::new(acc), - function - .expressions - .into_iter() - .map(|expression| types::Expression::from(expression)) - .collect(), - ) - //match acc { - // Normal function call - // types::Expression::Identifier(identifier) => types::Expression::FunctionCall( - // Box::new(acc), - // function - // .expressions - // .into_iter() - // .map(|expression| types::Expression::from(expression)) - // .collect(), - // ), - // // Circuit function call - // types::Expression::CircuitMemberAccess(acc, circuit_function) => types::Expression::FunctionCall( - // circuit_function, - // function - // .expressions - // .into_iter() - // .map(|expression| types::Expression::from(expression)) - // .collect(), - // ), - // expression => { - // unimplemented!("only function names are callable, found \"{}\"", expression) - // } - } - // Handle circuit member accesses - ast::Access::Member(circuit_member) => types::Expression::CircuitMemberAccess( - Box::new(acc), - types::Identifier::from(circuit_member.identifier), - ), // Handle array accesses ast::Access::Array(array) => types::Expression::ArrayAccess( Box::new(acc), Box::new(types::RangeOrExpression::from(array.expression)), ), + + // Handle function calls + ast::Access::Call(function) => types::Expression::FunctionCall( + Box::new(acc), + function + .expressions + .into_iter() + .map(|expression| types::Expression::from(expression)) + .collect(), + ), + + // Handle circuit member accesses + ast::Access::Object(circuit_object) => types::Expression::CircuitObjectAccess( + Box::new(acc), + types::Identifier::from(circuit_object.identifier), + ), + ast::Access::StaticObject(circuit_object) => { + types::Expression::CircuitStaticObjectAccess( + Box::new(acc), + types::Identifier::from(circuit_object.identifier), + ) + } }) } } @@ -406,7 +390,7 @@ impl<'ast, F: Field + PrimeField, G: Group> From> for types: .into_iter() .fold(variable, |acc, access| match access { ast::AssigneeAccess::Member(struct_member) => { - types::Expression::CircuitMemberAccess( + types::Expression::CircuitObjectAccess( Box::new(acc), types::Identifier::from(struct_member.identifier), )