diff --git a/benchmark/simple.leo b/benchmark/simple.leo index f67173df8b..5df81712cb 100644 --- a/benchmark/simple.leo +++ b/benchmark/simple.leo @@ -1,19 +1,29 @@ -circuit Circ { - x: u32 +circuit PedersenHash { + + parameters: group[1] + + static function new(parameters: group[1]) -> Self { + return Self { parameters: parameters } + } + + function hash(bits: bool[1]) -> group { + let mut digest: group = 0group; + + for i in 0..1 { + let base: group = if bits[i] ? parameters[i] : 0group; + digest += base; + } + + return digest + } } -function main() -> u32 { - let mut a = 1; - a = 0; +function main() -> group { + let parameters = [0group; 1]; + let pedersen = PedersenHash::new(parameters); - let b = 1; - //b = 0; // <- illegal + let input: bool[1] = [true]; + let output = pedersen.hash(input); - let mut arr = [1, 2]; - arr[0] = 0; - - let mut c = Circ { x: 1 }; - c.x = 0; - - return c.x + return output } \ No newline at end of file diff --git a/compiler/src/ast.rs b/compiler/src/ast.rs index 65dc7a23a9..c496e894b8 100644 --- a/compiler/src/ast.rs +++ b/compiler/src/ast.rs @@ -171,24 +171,15 @@ pub enum IntegerType { #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_field))] -pub struct FieldType<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, -} +pub struct FieldType {} #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_group))] -pub struct GroupType<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, -} +pub struct GroupType {} #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_bool))] -pub struct BooleanType<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, -} +pub struct BooleanType {} #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_circuit))] @@ -198,19 +189,23 @@ pub struct CircuitType<'ast> { pub span: Span<'ast>, } +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::type_self))] +pub struct SelfType {} + #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_basic))] -pub enum BasicType<'ast> { +pub enum BasicType { Integer(IntegerType), - Field(FieldType<'ast>), - Group(GroupType<'ast>), - Boolean(BooleanType<'ast>), + Field(FieldType), + Group(GroupType), + Boolean(BooleanType), } #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_array))] pub struct ArrayType<'ast> { - pub _type: BasicType<'ast>, + pub _type: BasicType, pub dimensions: Vec>, #[pest_ast(outer())] pub span: Span<'ast>, @@ -219,9 +214,10 @@ pub struct ArrayType<'ast> { #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::_type))] pub enum Type<'ast> { - Basic(BasicType<'ast>), + Basic(BasicType), Array(ArrayType<'ast>), Circuit(CircuitType<'ast>), + SelfType(SelfType), } impl<'ast> fmt::Display for Type<'ast> { @@ -230,6 +226,7 @@ impl<'ast> fmt::Display for Type<'ast> { Type::Basic(ref _type) => write!(f, "basic"), Type::Array(ref _type) => write!(f, "array"), Type::Circuit(ref _type) => write!(f, "struct"), + Type::SelfType(ref _type) => write!(f, "Self"), } } } @@ -269,7 +266,7 @@ impl<'ast> fmt::Display for Integer<'ast> { #[pest_ast(rule(Rule::value_field))] pub struct Field<'ast> { pub number: Number<'ast>, - pub _type: FieldType<'ast>, + pub _type: FieldType, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -284,7 +281,7 @@ impl<'ast> fmt::Display for Field<'ast> { #[pest_ast(rule(Rule::value_group))] pub struct Group<'ast> { pub number: Number<'ast>, - pub _type: GroupType<'ast>, + pub _type: GroupType, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -445,12 +442,21 @@ pub struct MemberAccess<'ast> { pub span: Span<'ast>, } +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::access_static_member))] +pub struct StaticMemberAccess<'ast> { + pub identifier: Identifier<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::access))] pub enum Access<'ast> { Array(ArrayAccess<'ast>), Call(CallAccess<'ast>), - Member(MemberAccess<'ast>), + Object(MemberAccess<'ast>), + StaticObject(StaticMemberAccess<'ast>), } #[derive(Clone, Debug, FromPest, PartialEq)] @@ -552,8 +558,8 @@ pub struct ArrayInitializerExpression<'ast> { // Circuits #[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::circuit_object))] -pub struct CircuitObject<'ast> { +#[pest_ast(rule(Rule::circuit_field_definition))] +pub struct CircuitFieldDefinition<'ast> { pub identifier: Identifier<'ast>, pub _type: Type<'ast>, #[pest_ast(outer())] @@ -561,17 +567,37 @@ pub struct CircuitObject<'ast> { } #[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::circuit_definition))] -pub struct Circuit<'ast> { - pub identifier: Identifier<'ast>, - pub fields: Vec>, +#[pest_ast(rule(Rule::_static))] +pub struct Static {} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::circuit_function))] +pub struct CircuitFunction<'ast> { + pub _static: Option, + pub function: Function<'ast>, #[pest_ast(outer())] pub span: Span<'ast>, } #[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::inline_circuit_member))] -pub struct InlineCircuitMember<'ast> { +#[pest_ast(rule(Rule::circuit_member))] +pub enum CircuitMember<'ast> { + CircuitFieldDefinition(CircuitFieldDefinition<'ast>), + CircuitFunction(CircuitFunction<'ast>), +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::circuit_definition))] +pub struct Circuit<'ast> { + pub identifier: Identifier<'ast>, + pub members: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::circuit_field))] +pub struct CircuitField<'ast> { pub identifier: Identifier<'ast>, pub expression: Expression<'ast>, #[pest_ast(outer())] @@ -579,10 +605,10 @@ pub struct InlineCircuitMember<'ast> { } #[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::expression_inline_circuit))] +#[pest_ast(rule(Rule::expression_circuit_inline))] pub struct CircuitInlineExpression<'ast> { pub identifier: Identifier<'ast>, - pub members: Vec>, + pub members: Vec>, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -758,7 +784,7 @@ fn parse_term(pair: Pair) -> Box { let next = clone.into_inner().next().unwrap(); match next.as_rule() { Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case - Rule::expression_inline_circuit => { + Rule::expression_circuit_inline => { Expression::CircuitInline( CircuitInlineExpression::from_pest(&mut pair.into_inner()).unwrap(), ) diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index 1f8f6b7d0e..9740ac5f0c 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -2,11 +2,14 @@ use crate::{ constraints::{ - new_scope_from_variable, new_variable_from_variable, ConstrainedCircuitObject, - ConstrainedProgram, ConstrainedValue, + new_scope_from_variable, ConstrainedCircuitMember, ConstrainedProgram, ConstrainedValue, }, errors::ExpressionError, - types::{CircuitMember, Expression, Identifier, RangeOrExpression, SpreadOrExpression}, + new_scope, + types::{ + CircuitFieldDefinition, CircuitMember, Expression, Identifier, RangeOrExpression, + SpreadOrExpression, + }, }; use snarkos_models::{ @@ -18,21 +21,23 @@ impl> ConstrainedProgra /// Enforce a variable expression by getting the resolved value pub(crate) fn evaluate_identifier( &mut self, - scope: String, - unresolved_variable: Identifier, + file_scope: String, + function_scope: String, + unresolved_identifier: Identifier, ) -> Result, ExpressionError> { - // Evaluate the variable name in the current function scope - let variable_name = new_scope_from_variable(scope, &unresolved_variable); + // Evaluate the identifier name in the current function scope + let variable_name = new_scope(function_scope, unresolved_identifier.to_string()); + let identifier_name = new_scope(file_scope, unresolved_identifier.to_string()); - if self.contains_name(&variable_name) { + if let Some(variable) = self.get(&variable_name) { // Reassigning variable to another variable - Ok(self.get_mut(&variable_name).unwrap().clone()) - } else if self.contains_variable(&unresolved_variable) { + Ok(variable.clone()) + } else if let Some(identifier) = self.get(&identifier_name) { // Check global scope (function and circuit names) - Ok(self.get_mut_variable(&unresolved_variable).unwrap().clone()) + Ok(identifier.clone()) } else { - Err(ExpressionError::UndefinedVariable( - unresolved_variable.to_string(), + Err(ExpressionError::UndefinedIdentifier( + unresolved_identifier.to_string(), )) } } @@ -54,11 +59,18 @@ impl> ConstrainedProgra (ConstrainedValue::GroupElement(ge_1), ConstrainedValue::GroupElement(ge_2)) => { Self::evaluate_group_add(ge_1, ge_2) } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.enforce_add_expression(cs, *val_1, val_2)? + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.enforce_add_expression(cs, val_1, *val_2)? + } (val_1, val_2) => { + println!("not both groups"); return Err(ExpressionError::IncompatibleTypes(format!( "{} + {}", val_1, val_2, - ))) + ))); } }) } @@ -79,6 +91,12 @@ impl> ConstrainedProgra (ConstrainedValue::GroupElement(ge_1), ConstrainedValue::GroupElement(ge_2)) => { Self::evaluate_group_sub(ge_1, ge_2) } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.enforce_sub_expression(cs, *val_1, val_2)? + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.enforce_sub_expression(cs, val_1, *val_2)? + } (val_1, val_2) => { return Err(ExpressionError::IncompatibleTypes(format!( "{} - {}", @@ -101,6 +119,12 @@ impl> ConstrainedProgra (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_mul(cs, fe_1, fe_2)? } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.enforce_mul_expression(cs, *val_1, val_2)? + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.enforce_mul_expression(cs, val_1, *val_2)? + } (val_1, val_2) => { return Err(ExpressionError::IncompatibleTypes(format!( "{} * {}", @@ -123,6 +147,12 @@ impl> ConstrainedProgra (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_div(cs, fe_1, fe_2)? } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.enforce_div_expression(cs, *val_1, val_2)? + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.enforce_div_expression(cs, val_1, *val_2)? + } (val_1, val_2) => { return Err(ExpressionError::IncompatibleTypes(format!( "{} / {}", @@ -144,6 +174,12 @@ impl> ConstrainedProgra (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::Integer(num_2)) => { self.enforce_field_pow(cs, fe_1, num_2)? } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.enforce_pow_expression(cs, *val_1, val_2)? + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.enforce_pow_expression(cs, val_1, *val_2)? + } (_, ConstrainedValue::FieldElement(num_2)) => { return Err(ExpressionError::InvalidExponent(num_2.to_string())) } @@ -175,6 +211,12 @@ impl> ConstrainedProgra (ConstrainedValue::GroupElement(ge_1), ConstrainedValue::GroupElement(ge_2)) => { Self::evaluate_group_eq(ge_1, ge_2) } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.evaluate_eq_expression(*val_1, val_2)? + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.evaluate_eq_expression(val_1, *val_2)? + } (val_1, val_2) => { return Err(ExpressionError::IncompatibleTypes(format!( "{} == {}", @@ -193,6 +235,12 @@ impl> ConstrainedProgra // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_geq(fe_1, fe_2) // } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.evaluate_geq_expression(*val_1, val_2) + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.evaluate_geq_expression(val_1, *val_2) + } (val_1, val_2) => Err(ExpressionError::IncompatibleTypes(format!( "{} >= {}, values must be fields", val_1, val_2 @@ -209,6 +257,8 @@ impl> ConstrainedProgra // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_gt(fe_1, fe_2) // } + (ConstrainedValue::Mutable(val_1), val_2) => self.evaluate_gt_expression(*val_1, val_2), + (val_1, ConstrainedValue::Mutable(val_2)) => self.evaluate_gt_expression(val_1, *val_2), (val_1, val_2) => Err(ExpressionError::IncompatibleTypes(format!( "{} > {}, values must be fields", val_1, val_2 @@ -225,6 +275,12 @@ impl> ConstrainedProgra // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_leq(fe_1, fe_2) // } + (ConstrainedValue::Mutable(val_1), val_2) => { + self.evaluate_leq_expression(*val_1, val_2) + } + (val_1, ConstrainedValue::Mutable(val_2)) => { + self.evaluate_leq_expression(val_1, *val_2) + } (val_1, val_2) => Err(ExpressionError::IncompatibleTypes(format!( "{} <= {}, values must be fields", val_1, val_2 @@ -241,6 +297,8 @@ impl> ConstrainedProgra // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_lt(fe_1, fe_2) // } + (ConstrainedValue::Mutable(val_1), val_2) => self.evaluate_lt_expression(*val_1, val_2), + (val_1, ConstrainedValue::Mutable(val_2)) => self.evaluate_lt_expression(val_1, *val_2), (val_1, val_2) => Err(ExpressionError::IncompatibleTypes(format!( "{} < {}, values must be fields", val_1, val_2, @@ -348,47 +406,74 @@ impl> ConstrainedProgra cs: &mut CS, file_scope: String, function_scope: String, - variable: Identifier, - members: Vec>, + identifier: Identifier, + members: Vec>, ) -> Result, ExpressionError> { - let circuit_name = new_variable_from_variable(file_scope.clone(), &variable); + let mut program_identifier = new_scope(file_scope.clone(), identifier.to_string()); + + if identifier.is_self() { + program_identifier = file_scope.clone(); + } if let Some(ConstrainedValue::CircuitDefinition(circuit_definition)) = - self.get_mut_variable(&circuit_name) + self.get_mut(&program_identifier) { + let circuit_identifier = circuit_definition.identifier.clone(); let mut resolved_members = vec![]; - for (field, member) in circuit_definition - .fields - .clone() - .into_iter() - .zip(members.clone().into_iter()) - { - if field.identifier != member.identifier { - return Err(ExpressionError::InvalidCircuitObject( - field.identifier.name, - member.identifier.name, - )); - } - // Resolve and enforce circuit fields - let member_value = self.enforce_expression( - cs, - file_scope.clone(), - function_scope.clone(), - member.expression, - )?; + for member in circuit_definition.members.clone().into_iter() { + match member { + CircuitMember::CircuitField(identifier, _type) => { + let matched_field = members + .clone() + .into_iter() + .find(|field| field.identifier.eq(&identifier)); + match matched_field { + Some(field) => { + // Resolve and enforce circuit object + let field_value = self.enforce_expression( + cs, + file_scope.clone(), + function_scope.clone(), + field.expression, + )?; - // Check member types - member_value.expect_type(&field._type)?; + // Check field type + field_value.expect_type(&_type)?; - resolved_members.push(ConstrainedCircuitObject(member.identifier, member_value)) + resolved_members + .push(ConstrainedCircuitMember(identifier, field_value)) + } + None => { + return Err(ExpressionError::ExpectedCircuitValue( + identifier.to_string(), + )) + } + } + } + CircuitMember::CircuitFunction(_static, function) => { + let identifier = function.function_name.clone(); + let mut constrained_function_value = + ConstrainedValue::Function(Some(circuit_identifier.clone()), function); + + if _static { + constrained_function_value = + ConstrainedValue::Static(Box::new(constrained_function_value)); + } + + resolved_members.push(ConstrainedCircuitMember( + identifier, + constrained_function_value, + )); + } + }; } Ok(ConstrainedValue::CircuitExpression( - variable, + circuit_identifier.clone(), resolved_members, )) } else { - Err(ExpressionError::UndefinedCircuit(variable.to_string())) + Err(ExpressionError::UndefinedCircuit(identifier.to_string())) } } @@ -397,50 +482,139 @@ impl> ConstrainedProgra cs: &mut CS, file_scope: String, function_scope: String, - circuit_variable: Box>, + circuit_identifier: Box>, circuit_member: Identifier, ) -> Result, ExpressionError> { - let members = match self.enforce_expression( + let (circuit_name, members) = match self.enforce_expression( cs, file_scope.clone(), function_scope.clone(), - *circuit_variable.clone(), + *circuit_identifier.clone(), )? { - ConstrainedValue::CircuitExpression(_name, members) => members, + ConstrainedValue::CircuitExpression(name, members) => (name, members), ConstrainedValue::Mutable(value) => match *value { - ConstrainedValue::CircuitExpression(_name, members) => members, + ConstrainedValue::CircuitExpression(name, members) => (name, members), value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())), }, value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())), }; let matched_member = members + .clone() .into_iter() .find(|member| member.0 == circuit_member); match matched_member { - Some(member) => Ok(member.1), + Some(member) => { + match &member.1 { + ConstrainedValue::Function(ref _circuit_identifier, ref _function) => { + // Pass static circuit fields into function call by value + for stored_member in members { + match &stored_member.1 { + ConstrainedValue::Function(_, _) => {} + ConstrainedValue::Static(_) => {} + _ => { + let circuit_scope = + new_scope(file_scope.clone(), circuit_name.to_string()); + let function_scope = + new_scope(circuit_scope, member.0.to_string()); + let field = + new_scope(function_scope, stored_member.0.to_string()); + + self.store(field, stored_member.1.clone()); + } + } + } + } + _ => {} + } + Ok(member.1) + } None => Err(ExpressionError::UndefinedCircuitObject( circuit_member.to_string(), )), } } + fn enforce_circuit_static_access_expression( + &mut self, + cs: &mut CS, + file_scope: String, + function_scope: String, + circuit_identifier: Box>, + circuit_member: Identifier, + ) -> Result, ExpressionError> { + // Get defined circuit + let circuit = match self.enforce_expression( + cs, + file_scope.clone(), + function_scope.clone(), + *circuit_identifier.clone(), + )? { + ConstrainedValue::CircuitDefinition(circuit_definition) => circuit_definition, + value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())), + }; + + // Find static circuit function + let matched_function = circuit.members.into_iter().find(|member| match member { + CircuitMember::CircuitFunction(_static, _function) => *_static, + _ => false, + }); + + // Return errors if no static function exists + let function = match matched_function { + Some(CircuitMember::CircuitFunction(_static, function)) => { + if _static { + function + } else { + return Err(ExpressionError::InvalidStaticFunction( + function.function_name.to_string(), + )); + } + } + _ => { + return Err(ExpressionError::UndefinedStaticFunction( + circuit.identifier.to_string(), + circuit_member.to_string(), + )) + } + }; + + Ok(ConstrainedValue::Function( + Some(circuit.identifier), + function, + )) + } + fn enforce_function_call_expression( &mut self, cs: &mut CS, file_scope: String, function_scope: String, - function: Identifier, + function: Box>, arguments: Vec>, ) -> Result, ExpressionError> { - let function_name = new_variable_from_variable(file_scope.clone(), &function); - let function_call = match self.get(&function_name.to_string()) { - Some(ConstrainedValue::Function(function)) => function.clone(), - _ => return Err(ExpressionError::UndefinedFunction(function.to_string())), + let function_value = self.enforce_expression( + cs, + file_scope.clone(), + function_scope.clone(), + *function.clone(), + )?; + + let (outer_scope, function_call) = match function_value { + ConstrainedValue::Function(circuit_identifier, function) => { + let mut outer_scope = file_scope.clone(); + // If this is a circuit function, evaluate inside the circuit scope + if circuit_identifier.is_some() { + outer_scope = new_scope(file_scope, circuit_identifier.unwrap().to_string()); + } + + (outer_scope, function.clone()) + } + value => return Err(ExpressionError::UndefinedFunction(value.to_string())), }; - match self.enforce_function(cs, file_scope, function_scope, function_call, arguments) { + match self.enforce_function(cs, outer_scope, function_scope, function_call, arguments) { Ok(ConstrainedValue::Return(return_values)) => { if return_values.len() == 1 { Ok(return_values[0].clone()) @@ -463,7 +637,7 @@ impl> ConstrainedProgra match expression { // Variables Expression::Identifier(unresolved_variable) => { - self.evaluate_identifier(function_scope, unresolved_variable) + self.evaluate_identifier(file_scope, function_scope, unresolved_variable) } // Values @@ -669,6 +843,14 @@ impl> ConstrainedProgra circuit_variable, circuit_member, ), + Expression::CircuitStaticFunctionAccess(circuit_identifier, circuit_member) => self + .enforce_circuit_static_access_expression( + cs, + file_scope, + function_scope, + circuit_identifier, + circuit_member, + ), // Functions Expression::FunctionCall(function, arguments) => self.enforce_function_call_expression( diff --git a/compiler/src/constraints/function.rs b/compiler/src/constraints/function.rs index 573c6bd068..c1e0517070 100644 --- a/compiler/src/constraints/function.rs +++ b/compiler/src/constraints/function.rs @@ -2,7 +2,7 @@ //! a resolved Leo program. use crate::{ - constraints::{new_scope, new_variable_from_variables, ConstrainedProgram, ConstrainedValue}, + constraints::{new_scope, ConstrainedProgram, ConstrainedValue}, errors::{FunctionError, ImportError}, types::{Expression, Function, Identifier, InputValue, Program, Type}, }; @@ -31,8 +31,8 @@ impl> ConstrainedProgra input: Expression, ) -> Result, FunctionError> { match input { - Expression::Identifier(variable) => { - Ok(self.evaluate_identifier(caller_scope, variable)?) + Expression::Identifier(identifier) => { + Ok(self.evaluate_identifier(caller_scope, function_name, identifier)?) } expression => Ok(self.enforce_expression(cs, scope, function_name, expression)?), } @@ -231,12 +231,12 @@ impl> ConstrainedProgra program .circuits .into_iter() - .for_each(|(variable, circuit_def)| { + .for_each(|(identifier, circuit)| { let resolved_circuit_name = - new_variable_from_variables(&program_name.clone(), &variable); - self.store_variable( + new_scope(program_name.to_string(), identifier.to_string()); + self.store( resolved_circuit_name, - ConstrainedValue::CircuitDefinition(circuit_def), + ConstrainedValue::CircuitDefinition(circuit), ); }); @@ -246,8 +246,11 @@ impl> ConstrainedProgra .into_iter() .for_each(|(function_name, function)| { let resolved_function_name = - new_scope(program_name.name.clone(), function_name.name); - self.store(resolved_function_name, ConstrainedValue::Function(function)); + new_scope(program_name.to_string(), function_name.to_string()); + self.store( + resolved_function_name, + ConstrainedValue::Function(None, function), + ); }); Ok(()) diff --git a/compiler/src/constraints/import.rs b/compiler/src/constraints/import.rs index 74beb3f8f9..580e4e8297 100644 --- a/compiler/src/constraints/import.rs +++ b/compiler/src/constraints/import.rs @@ -1,7 +1,8 @@ use crate::{ ast, - constraints::{new_variable_from_variables, ConstrainedProgram, ConstrainedValue}, + constraints::{ConstrainedProgram, ConstrainedValue}, errors::constraints::ImportError, + new_scope, types::Program, Import, }; @@ -46,7 +47,7 @@ impl> ConstrainedProgra let program_name = program.name.clone(); // match each import symbol to a symbol in the imported file - import.symbols.into_iter().for_each(|symbol| { + for symbol in import.symbols.into_iter() { // see if the imported symbol is a circuit let matched_circuit = program .circuits @@ -54,18 +55,9 @@ impl> ConstrainedProgra .into_iter() .find(|(circuit_name, _circuit_def)| symbol.symbol == *circuit_name); - match matched_circuit { + let value = match matched_circuit { Some((_circuit_name, circuit_def)) => { - // take the alias if it is present - let resolved_name = symbol.alias.unwrap_or(symbol.symbol); - let resolved_circuit_name = - new_variable_from_variables(&program_name.clone(), &resolved_name); - - // store imported circuit under resolved name - self.store_variable( - resolved_circuit_name, - ConstrainedValue::CircuitDefinition(circuit_def), - ); + ConstrainedValue::CircuitDefinition(circuit_def) } None => { // see if the imported symbol is a function @@ -75,18 +67,7 @@ impl> ConstrainedProgra match matched_function { Some((_function_name, function)) => { - // take the alias if it is present - let resolved_name = symbol.alias.unwrap_or(symbol.symbol); - let resolved_function_name = new_variable_from_variables( - &program_name.clone(), - &resolved_name, - ); - - // store imported function under resolved name - self.store_variable( - resolved_function_name, - ConstrainedValue::Function(function), - ) + ConstrainedValue::Function(None, function) } None => unimplemented!( "cannot find imported symbol {} in imported file {}", @@ -95,8 +76,16 @@ impl> ConstrainedProgra ), } } - } - }); + }; + + // take the alias if it is present + let resolved_name = symbol.alias.unwrap_or(symbol.symbol); + let resolved_circuit_name = + new_scope(program_name.to_string(), resolved_name.to_string()); + + // store imported circuit under resolved name + self.store(resolved_circuit_name, value); + } // evaluate all import statements in imported file program diff --git a/compiler/src/constraints/mod.rs b/compiler/src/constraints/mod.rs index 0521d779ad..65a1ab2fa5 100644 --- a/compiler/src/constraints/mod.rs +++ b/compiler/src/constraints/mod.rs @@ -56,7 +56,7 @@ pub fn generate_constraints { + ConstrainedValue::Function(_circuit_identifier, function) => { let result = resolved_program.enforce_main_function(cs, program_name, function, parameters)?; log::debug!("{}", result); diff --git a/compiler/src/constraints/program.rs b/compiler/src/constraints/program.rs index b93f940b13..e2a61249ac 100644 --- a/compiler/src/constraints/program.rs +++ b/compiler/src/constraints/program.rs @@ -58,22 +58,6 @@ impl> ConstrainedProgra self.identifiers.insert(name, value); } - pub(crate) fn store_variable( - &mut self, - variable: Identifier, - value: ConstrainedValue, - ) { - self.store(variable.name, value); - } - - pub(crate) fn contains_name(&self, name: &String) -> bool { - self.identifiers.contains_key(name) - } - - pub(crate) fn contains_variable(&self, variable: &Identifier) -> bool { - self.contains_name(&variable.name) - } - pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue> { self.identifiers.get(name) } @@ -81,11 +65,4 @@ impl> ConstrainedProgra pub(crate) fn get_mut(&mut self, name: &String) -> Option<&mut ConstrainedValue> { self.identifiers.get_mut(name) } - - pub(crate) fn get_mut_variable( - &mut self, - variable: &Identifier, - ) -> Option<&mut ConstrainedValue> { - self.get_mut(&variable.name) - } } diff --git a/compiler/src/constraints/statement.rs b/compiler/src/constraints/statement.rs index 504a976770..c491350ab7 100644 --- a/compiler/src/constraints/statement.rs +++ b/compiler/src/constraints/statement.rs @@ -21,12 +21,26 @@ impl> ConstrainedProgra match assignee { Assignee::Identifier(name) => new_scope_from_variable(scope, &name), Assignee::Array(array, _index) => self.resolve_assignee(scope, *array), - Assignee::CircuitMember(circuit_variable, _member) => { - self.resolve_assignee(scope, *circuit_variable) + Assignee::CircuitField(circuit_name, _member) => { + self.resolve_assignee(scope, *circuit_name) } } } + fn get_mutable_assignee( + &mut self, + name: String, + ) -> Result<&mut ConstrainedValue, StatementError> { + // Check that assignee exists and is mutable + Ok(match self.get_mut(&name) { + Some(value) => match value { + ConstrainedValue::Mutable(mutable_value) => mutable_value, + _ => return Err(StatementError::ImmutableAssign(name)), + }, + None => return Err(StatementError::UndefinedVariable(name)), + }) + } + fn mutate_array( &mut self, cs: &mut CS, @@ -43,7 +57,7 @@ impl> ConstrainedProgra self.enforce_index(cs, file_scope.clone(), function_scope.clone(), index)?; // Modify the single value of the array in place - match self.get_mutable_variable(name)? { + match self.get_mutable_assignee(name)? { ConstrainedValue::Array(old) => { old[index] = new_value; } @@ -61,7 +75,7 @@ impl> ConstrainedProgra }; // Modify the range of values of the array in place - match (self.get_mutable_variable(name)?, new_value) { + match (self.get_mutable_assignee(name)?, new_value) { (ConstrainedValue::Array(old), ConstrainedValue::Array(ref new)) => { let to_index = to_index_option.unwrap_or(old.len()); old.splice(from_index..to_index, new.iter().cloned()); @@ -74,19 +88,29 @@ impl> ConstrainedProgra Ok(()) } - fn mutute_circuit_object( + fn mutute_circuit_field( &mut self, circuit_name: String, object_name: Identifier, new_value: ConstrainedValue, ) -> Result<(), StatementError> { - match self.get_mutable_variable(circuit_name)? { - ConstrainedValue::CircuitExpression(_variable, objects) => { - // Modify the circuit member in place - let matched_object = objects.into_iter().find(|object| object.0 == object_name); + match self.get_mutable_assignee(circuit_name)? { + ConstrainedValue::CircuitExpression(_variable, members) => { + // Modify the circuit field in place + let matched_field = members.into_iter().find(|object| object.0 == object_name); - match matched_object { - Some(mut object) => object.1 = new_value.to_owned(), + match matched_field { + Some(object) => match &object.1 { + ConstrainedValue::Function(_circuit_identifier, function) => { + return Err(StatementError::ImmutableCircuitFunction( + function.function_name.to_string(), + )) + } + ConstrainedValue::Static(_value) => { + return Err(StatementError::ImmutableCircuitFunction("static".into())) + } + _ => object.1 = new_value.to_owned(), + }, None => { return Err(StatementError::UndefinedCircuitObject( object_name.to_string(), @@ -100,20 +124,6 @@ impl> ConstrainedProgra Ok(()) } - fn get_mutable_variable( - &mut self, - name: String, - ) -> Result<&mut ConstrainedValue, StatementError> { - // Check that assignee exists and is mutable - Ok(match self.get_mut(&name) { - Some(value) => match value { - ConstrainedValue::Mutable(mutable_value) => mutable_value, - _ => return Err(StatementError::ImmutableAssign(name)), - }, - None => return Err(StatementError::UndefinedVariable(name)), - }) - } - fn enforce_assign_statement( &mut self, cs: &mut CS, @@ -132,7 +142,7 @@ impl> ConstrainedProgra // Mutate the old value into the new value match assignee { Assignee::Identifier(_identifier) => { - let old_value = self.get_mutable_variable(variable_name.clone())?; + let old_value = self.get_mutable_assignee(variable_name.clone())?; *old_value = new_value; @@ -146,8 +156,8 @@ impl> ConstrainedProgra range_or_expression, new_value, ), - Assignee::CircuitMember(_assignee, object_name) => { - self.mutute_circuit_object(variable_name, object_name, new_value) + Assignee::CircuitField(_assignee, object_name) => { + self.mutute_circuit_field(variable_name, object_name, new_value) } } } diff --git a/compiler/src/constraints/value.rs b/compiler/src/constraints/value.rs index 4e87f55a7b..3a2f69360a 100644 --- a/compiler/src/constraints/value.rs +++ b/compiler/src/constraints/value.rs @@ -13,7 +13,7 @@ use snarkos_models::{ use std::fmt; #[derive(Clone, PartialEq, Eq)] -pub struct ConstrainedCircuitObject( +pub struct ConstrainedCircuitMember( pub Identifier, pub ConstrainedValue, ); @@ -24,12 +24,17 @@ pub enum ConstrainedValue { FieldElement(FieldElement), GroupElement(G), Boolean(Boolean), + Array(Vec>), + CircuitDefinition(Circuit), - CircuitExpression(Identifier, Vec>), - Function(Function), + CircuitExpression(Identifier, Vec>), + + Function(Option>, Function), // (optional circuit identifier, function definition) Return(Vec>), + Mutable(Box>), + Static(Box>), } impl ConstrainedValue { @@ -63,10 +68,21 @@ impl ConstrainedValue { Type::Circuit(ref expected_name), ) => { if expected_name != actual_name { - return Err(ValueError::StructName(format!( - "Expected struct name {} got {}", - expected_name, actual_name - ))); + return Err(ValueError::CircuitName( + expected_name.to_string(), + actual_name.to_string(), + )); + } + } + ( + ConstrainedValue::CircuitExpression(ref actual_name, ref _members), + Type::SelfType, + ) => { + if Identifier::new("Self".into()) == *actual_name { + return Err(ValueError::CircuitName( + "Self".into(), + actual_name.to_string(), + )); } } (ConstrainedValue::Return(ref values), _type) => { @@ -77,6 +93,9 @@ impl ConstrainedValue { (ConstrainedValue::Mutable(ref value), _type) => { value.expect_type(&_type)?; } + (ConstrainedValue::Static(ref value), _type) => { + value.expect_type(&_type)?; + } (value, _type) => { return Err(ValueError::TypeError(format!( "expected type {}, got {}", @@ -106,8 +125,8 @@ impl fmt::Display for ConstrainedValue { } write!(f, "]") } - ConstrainedValue::CircuitExpression(ref variable, ref members) => { - write!(f, "{} {{", variable)?; + ConstrainedValue::CircuitExpression(ref identifier, ref members) => { + write!(f, "{} {{", identifier)?; for (i, member) in members.iter().enumerate() { write!(f, "{}: {}", member.0, member.1)?; if i < members.len() - 1 { @@ -127,10 +146,13 @@ impl fmt::Display for ConstrainedValue { write!(f, "]") } ConstrainedValue::CircuitDefinition(ref _definition) => { - unimplemented!("cannot return struct definition in program") + unimplemented!("cannot return circuit definition in program") + } + ConstrainedValue::Function(ref _circuit_option, ref function) => { + write!(f, "{}();", function.function_name) } - ConstrainedValue::Function(ref function) => write!(f, "{}();", function.function_name), ConstrainedValue::Mutable(ref value) => write!(f, "mut {}", value), + ConstrainedValue::Static(ref value) => write!(f, "static {}", value), } } } diff --git a/compiler/src/errors/constraints/expression.rs b/compiler/src/errors/constraints/expression.rs index 514caeee9e..615febaa85 100644 --- a/compiler/src/errors/constraints/expression.rs +++ b/compiler/src/errors/constraints/expression.rs @@ -2,9 +2,9 @@ use crate::errors::{BooleanError, FieldElementError, FunctionError, IntegerError #[derive(Debug, Error)] pub enum ExpressionError { - // Variables - #[error("Variable \"{}\" not found", _0)] - UndefinedVariable(String), + // Identifiers + #[error("Identifier \"{}\" not found", _0)] + UndefinedIdentifier(String), // Types #[error("{}", _0)] @@ -51,12 +51,21 @@ pub enum ExpressionError { #[error("Circuit object {} does not exist", _0)] UndefinedCircuitObject(String), - #[error("Expected circuit object {}, got {}", _0, _1)] - InvalidCircuitObject(String, String), - #[error("Cannot access circuit {}", _0)] InvalidCircuitAccess(String), + #[error("Expected circuit value {}", _0)] + ExpectedCircuitValue(String), + + #[error("Circuit {} has no static function {}", _0, _1)] + UndefinedStaticFunction(String, String), + + #[error( + "Static access only supported for static circuit functions, got function {}", + _0 + )] + InvalidStaticFunction(String), + // Functions #[error( "Function {} must be declared before it is used in an inline expression", diff --git a/compiler/src/errors/constraints/statement.rs b/compiler/src/errors/constraints/statement.rs index 1ee808c666..8e864262f2 100644 --- a/compiler/src/errors/constraints/statement.rs +++ b/compiler/src/errors/constraints/statement.rs @@ -31,6 +31,9 @@ pub enum StatementError { UndefinedArray(String), // Circuits + #[error("Cannot mutate circuit function, {}", _0)] + ImmutableCircuitFunction(String), + #[error("Attempted to assign to unknown circuit {}", _0)] UndefinedCircuit(String), diff --git a/compiler/src/errors/constraints/value.rs b/compiler/src/errors/constraints/value.rs index fc06f90e31..4428083f35 100644 --- a/compiler/src/errors/constraints/value.rs +++ b/compiler/src/errors/constraints/value.rs @@ -12,9 +12,9 @@ pub enum ValueError { #[error("{}", _0)] IntegerError(IntegerError), - /// Unexpected struct name - #[error("{}", _0)] - StructName(String), + /// Unexpected circuit name + #[error("Expected circuit name {} got {}", _0, _1)] + CircuitName(String, String), /// Unexpected type #[error("{}", _0)] diff --git a/compiler/src/leo.pest b/compiler/src/leo.pest index ebb4e3670e..5996f063c4 100644 --- a/compiler/src/leo.pest +++ b/compiler/src/leo.pest @@ -79,10 +79,11 @@ type_integer = { type_field = {"field"} type_group = {"group"} type_bool = {"bool"} +type_self = {"Self"} type_basic = { type_field | type_group | type_bool | type_integer } type_circuit = { identifier } type_array = {type_basic ~ ("[" ~ value ~ "]")+ } -_type = {type_array | type_basic | type_circuit} +_type = {type_self | type_array | type_basic | type_circuit} type_list = _{(_type ~ ("," ~ _type)*)?} /// Values @@ -112,7 +113,8 @@ range_or_expression = { range | expression } access_array = { "[" ~ range_or_expression ~ "]" } access_call = { "(" ~ expression_tuple ~ ")" } access_member = { "." ~ identifier } -access = { access_array | access_call | access_member } +access_static_member = { "::" ~ identifier } +access = { access_array | access_call | access_member | access_static_member} expression_postfix = { identifier ~ access+ } @@ -130,13 +132,18 @@ expression_array_initializer = { "[" ~ spread_or_expression ~ ";" ~ value ~ "]" /// Circuits -circuit_object = { identifier ~ ":" ~ _type } -circuit_object_list = _{(circuit_object ~ (NEWLINE+ ~ circuit_object)*)? } -circuit_definition = { "circuit" ~ identifier ~ "{" ~ NEWLINE* ~ circuit_object_list ~ NEWLINE* ~ "}" ~ NEWLINE* } +circuit_field_definition = { identifier ~ ":" ~ _type ~ NEWLINE* } -inline_circuit_member = { identifier ~ ":" ~ expression } -inline_circuit_member_list = _{(inline_circuit_member ~ ("," ~ NEWLINE* ~ inline_circuit_member)*)? ~ ","? } -expression_inline_circuit = { identifier ~ "{" ~ NEWLINE* ~ inline_circuit_member_list ~ NEWLINE* ~ "}" } +_static = {"static"} +circuit_function = {_static? ~ function_definition } + +circuit_member = { circuit_function | circuit_field_definition } + +circuit_definition = { "circuit" ~ identifier ~ "{" ~ NEWLINE* ~ circuit_member* ~ NEWLINE* ~ "}" ~ NEWLINE* } + +circuit_field = { identifier ~ ":" ~ expression } +circuit_field_list = _{(circuit_field ~ ("," ~ NEWLINE* ~ circuit_field)*)? ~ ","? } +expression_circuit_inline = { identifier ~ "{" ~ NEWLINE* ~ circuit_field_list ~ NEWLINE* ~ "}" } /// Conditionals @@ -146,7 +153,7 @@ expression_conditional = { "if" ~ expression ~ "?" ~ expression ~ ":" ~ expressi expression_term = { ("(" ~ expression ~ ")") - | expression_inline_circuit + | expression_circuit_inline | expression_conditional | expression_postfix | expression_primitive diff --git a/compiler/src/types.rs b/compiler/src/types.rs index 8efef3111c..8fbea85e70 100644 --- a/compiler/src/types.rs +++ b/compiler/src/types.rs @@ -1,4 +1,4 @@ -//! A typed Leo program consists of import, struct, and function definitions. +//! A typed Leo program consists of import, circuit, and function definitions. //! Each defined type consists of typed statements and expressions. use crate::{errors::IntegerError, Import}; @@ -30,6 +30,10 @@ impl Identifier { _engine: PhantomData::, } } + + pub fn is_self(&self) -> bool { + self.name == "Self" + } } /// A variable that is assigned to a value in the constrained program @@ -153,11 +157,12 @@ pub enum Expression { ArrayAccess(Box>, Box>), // (array name, range) // Circuits - Circuit(Identifier, Vec>), - CircuitMemberAccess(Box>, Identifier), // (circuit name, circuit object name) + Circuit(Identifier, Vec>), + CircuitMemberAccess(Box>, Identifier), // (declared circuit name, circuit member name) + CircuitStaticFunctionAccess(Box>, Identifier), // (defined circuit name, circuit static member name) // Functions - FunctionCall(Identifier, Vec>), + FunctionCall(Box>, Vec>), } /// Definition assignee: v, arr[0..2], Point p.x @@ -165,7 +170,7 @@ pub enum Expression { pub enum Assignee { Identifier(Identifier), Array(Box>, RangeOrExpression), - CircuitMember(Box>, Identifier), // (circuit name, circuit object name) + CircuitField(Box>, Identifier), // (circuit name, circuit field name) } /// Explicit integer type @@ -187,6 +192,7 @@ pub enum Type { Boolean, Array(Box>, Vec), Circuit(Identifier), + SelfType, } impl Type { @@ -230,22 +236,24 @@ pub enum Statement { Expression(Expression), } +/// Circuits + #[derive(Clone, Debug, PartialEq, Eq)] -pub struct CircuitMember { +pub struct CircuitFieldDefinition { pub identifier: Identifier, pub expression: Expression, } #[derive(Clone, PartialEq, Eq)] -pub struct CircuitObject { - pub identifier: Identifier, - pub _type: Type, +pub enum CircuitMember { + CircuitField(Identifier, Type), + CircuitFunction(bool, Function), } #[derive(Clone, PartialEq, Eq)] pub struct Circuit { pub identifier: Identifier, - pub fields: Vec>, + pub members: Vec>, } /// Function parameters diff --git a/compiler/src/types_display.rs b/compiler/src/types_display.rs index ce061987ab..c90a1497b6 100644 --- a/compiler/src/types_display.rs +++ b/compiler/src/types_display.rs @@ -1,7 +1,7 @@ //! Format display functions for Leo types. use crate::{ - Assignee, Circuit, CircuitObject, ConditionalNestedOrEnd, ConditionalStatement, Expression, + Assignee, Circuit, CircuitMember, ConditionalNestedOrEnd, ConditionalStatement, Expression, FieldElement, Function, Identifier, InputModel, InputValue, Integer, IntegerType, RangeOrExpression, SpreadOrExpression, Statement, Type, Variable, }; @@ -154,8 +154,11 @@ impl<'ast, F: Field + PrimeField, G: Group> fmt::Display for Expression { } write!(f, "}}") } - Expression::CircuitMemberAccess(ref circuit_variable, ref member) => { - write!(f, "{}.{}", circuit_variable, member) + Expression::CircuitMemberAccess(ref circuit_name, ref member) => { + write!(f, "{}.{}", circuit_name, member) + } + Expression::CircuitStaticFunctionAccess(ref circuit_name, ref member) => { + write!(f, "{}::{}", circuit_name, member) } // Function calls @@ -178,7 +181,7 @@ impl fmt::Display for Assignee { match *self { Assignee::Identifier(ref variable) => write!(f, "{}", variable), Assignee::Array(ref array, ref index) => write!(f, "{}[{}]", array, index), - Assignee::CircuitMember(ref circuit_variable, ref member) => { + Assignee::CircuitField(ref circuit_variable, ref member) => { write!(f, "{}.{}", circuit_variable, member) } } @@ -278,6 +281,7 @@ impl fmt::Display for Type { Type::GroupElement => write!(f, "group"), Type::Boolean => write!(f, "bool"), Type::Circuit(ref variable) => write!(f, "{}", variable), + Type::SelfType => write!(f, "Self"), Type::Array(ref array, ref dimensions) => { write!(f, "{}", *array)?; for row in dimensions { @@ -289,23 +293,33 @@ impl fmt::Display for Type { } } -impl fmt::Display for CircuitObject { +impl fmt::Display for CircuitMember { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}: {}", self.identifier, self._type) + match self { + CircuitMember::CircuitField(ref identifier, ref _type) => { + write!(f, "{}: {}", identifier, _type) + } + CircuitMember::CircuitFunction(ref _static, ref function) => { + if *_static { + write!(f, "static ")?; + } + write!(f, "{}", function) + } + } } } impl Circuit { fn format(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "circuit {} {{ \n", self.identifier)?; - for field in self.fields.iter() { + for field in self.members.iter() { write!(f, " {}\n", field)?; } write!(f, "}}") } } -// impl fmt::Display for Struct {// uncomment when we no longer print out Program +// impl fmt::Display for Circuit {// uncomment when we no longer print out Program // fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // self.format(f) // } @@ -385,11 +399,11 @@ impl Function { } } -// impl fmt::Display for Function {// uncomment when we no longer print out Program -// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { -// self.format(f) -// } -// } +impl fmt::Display for Function { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.format(f) + } +} impl fmt::Debug for Function { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/compiler/src/types_from.rs b/compiler/src/types_from.rs index 4495e8e7ba..bd8513109e 100644 --- a/compiler/src/types_from.rs +++ b/compiler/src/types_from.rs @@ -276,11 +276,11 @@ impl<'ast, F: Field + PrimeField, G: Group> From From> - for types::CircuitMember +impl<'ast, F: Field + PrimeField, G: Group> From> + for types::CircuitFieldDefinition { - fn from(member: ast::InlineCircuitMember<'ast>) -> Self { - types::CircuitMember { + fn from(member: ast::CircuitField<'ast>) -> Self { + types::CircuitFieldDefinition { identifier: types::Identifier::from(member.identifier), expression: types::Expression::from(member.expression), } @@ -295,8 +295,8 @@ impl<'ast, F: Field + PrimeField, G: Group> From>>(); + .map(|member| types::CircuitFieldDefinition::from(member)) + .collect::>>(); types::Expression::Circuit(variable, members) } @@ -317,27 +317,33 @@ impl<'ast, F: Field + PrimeField, G: Group> From> .accesses .into_iter() .fold(variable, |acc, access| match access { - ast::Access::Call(function) => match acc { - types::Expression::Identifier(variable) => types::Expression::FunctionCall( - variable, - function - .expressions - .into_iter() - .map(|expression| types::Expression::from(expression)) - .collect(), - ), - expression => { - unimplemented!("only function names are callable, found \"{}\"", expression) - } - }, - ast::Access::Member(struct_member) => types::Expression::CircuitMemberAccess( - Box::new(acc), - types::Identifier::from(struct_member.identifier), - ), + // Handle array accesses ast::Access::Array(array) => types::Expression::ArrayAccess( Box::new(acc), Box::new(types::RangeOrExpression::from(array.expression)), ), + + // Handle function calls + ast::Access::Call(function) => types::Expression::FunctionCall( + Box::new(acc), + function + .expressions + .into_iter() + .map(|expression| types::Expression::from(expression)) + .collect(), + ), + + // Handle circuit member accesses + ast::Access::Object(circuit_object) => types::Expression::CircuitMemberAccess( + Box::new(acc), + types::Identifier::from(circuit_object.identifier), + ), + ast::Access::StaticObject(circuit_object) => { + types::Expression::CircuitStaticFunctionAccess( + Box::new(acc), + types::Identifier::from(circuit_object.identifier), + ) + } }) } } @@ -383,10 +389,10 @@ impl<'ast, F: Field + PrimeField, G: Group> From> for types: .accesses .into_iter() .fold(variable, |acc, access| match access { - ast::AssigneeAccess::Member(struct_member) => { + ast::AssigneeAccess::Member(circuit_member) => { types::Expression::CircuitMemberAccess( Box::new(acc), - types::Identifier::from(struct_member.identifier), + types::Identifier::from(circuit_member.identifier), ) } ast::AssigneeAccess::Array(array) => types::Expression::ArrayAccess( @@ -418,9 +424,9 @@ impl<'ast, F: Field + PrimeField, G: Group> From> for types: Box::new(acc), types::RangeOrExpression::from(array.expression), ), - ast::AssigneeAccess::Member(struct_member) => types::Assignee::CircuitMember( + ast::AssigneeAccess::Member(circuit_field) => types::Assignee::CircuitField( Box::new(acc), - types::Identifier::from(struct_member.identifier), + types::Identifier::from(circuit_field.identifier), ), }) } @@ -524,7 +530,7 @@ impl<'ast, F: Field + PrimeField, G: Group> From From> for types } } -/// pest ast -> Explicit types::Type for defining struct members and function params +/// pest ast -> Explicit types::Type for defining circuit members and function params impl From for types::IntegerType { fn from(integer_type: ast::IntegerType) -> Self { @@ -650,14 +656,14 @@ impl From for types::IntegerType { } } -impl<'ast, F: Field + PrimeField, G: Group> From> for types::Type { - fn from(basic_type: ast::BasicType<'ast>) -> Self { +impl From for types::Type { + fn from(basic_type: ast::BasicType) -> Self { match basic_type { ast::BasicType::Integer(_type) => { types::Type::IntegerType(types::IntegerType::from(_type)) } ast::BasicType::Field(_type) => types::Type::FieldElement, - ast::BasicType::Group(_type) => unimplemented!(), + ast::BasicType::Group(_type) => types::Type::GroupElement, ast::BasicType::Boolean(_type) => types::Type::Boolean, } } @@ -677,8 +683,8 @@ impl<'ast, F: Field + PrimeField, G: Group> From> for types } impl<'ast, F: Field + PrimeField, G: Group> From> for types::Type { - fn from(struct_type: ast::CircuitType<'ast>) -> Self { - types::Type::Circuit(types::Identifier::from(struct_type.identifier)) + fn from(circuit_type: ast::CircuitType<'ast>) -> Self { + types::Type::Circuit(types::Identifier::from(circuit_type.identifier)) } } @@ -688,35 +694,62 @@ impl<'ast, F: Field + PrimeField, G: Group> From> for types::Typ ast::Type::Basic(_type) => types::Type::from(_type), ast::Type::Array(_type) => types::Type::from(_type), ast::Type::Circuit(_type) => types::Type::from(_type), + ast::Type::SelfType(_type) => types::Type::SelfType, } } } -/// pest ast -> types::Struct +/// pest ast -> types::Circuit -impl<'ast, F: Field + PrimeField, G: Group> From> - for types::CircuitObject +impl<'ast, F: Field + PrimeField, G: Group> From> + for types::CircuitMember { - fn from(struct_field: ast::CircuitObject<'ast>) -> Self { - types::CircuitObject { - identifier: types::Identifier::from(struct_field.identifier), - _type: types::Type::from(struct_field._type), + fn from(circuit_value: ast::CircuitFieldDefinition<'ast>) -> Self { + types::CircuitMember::CircuitField( + types::Identifier::from(circuit_value.identifier), + types::Type::from(circuit_value._type), + ) + } +} + +impl<'ast, F: Field + PrimeField, G: Group> From> + for types::CircuitMember +{ + fn from(circuit_function: ast::CircuitFunction<'ast>) -> Self { + types::CircuitMember::CircuitFunction( + circuit_function._static.is_some(), + types::Function::from(circuit_function.function), + ) + } +} + +impl<'ast, F: Field + PrimeField, G: Group> From> + for types::CircuitMember +{ + fn from(object: ast::CircuitMember<'ast>) -> Self { + match object { + ast::CircuitMember::CircuitFieldDefinition(circuit_value) => { + types::CircuitMember::from(circuit_value) + } + ast::CircuitMember::CircuitFunction(circuit_function) => { + types::CircuitMember::from(circuit_function) + } } } } impl<'ast, F: Field + PrimeField, G: Group> From> for types::Circuit { - fn from(struct_definition: ast::Circuit<'ast>) -> Self { - let variable = types::Identifier::from(struct_definition.identifier); - let fields = struct_definition - .fields + fn from(circuit: ast::Circuit<'ast>) -> Self { + let variable = types::Identifier::from(circuit.identifier); + let members = circuit + .members .into_iter() - .map(|struct_field| types::CircuitObject::from(struct_field)) + .map(|member| types::CircuitMember::from(member)) .collect(); types::Circuit { identifier: variable, - fields, + members, } } } @@ -804,14 +837,14 @@ impl<'ast, F: Field + PrimeField, G: Group> types::Program { .map(|import| Import::from(import)) .collect::>>(); - let mut structs = HashMap::new(); + let mut circuits = HashMap::new(); let mut functions = HashMap::new(); let mut num_parameters = 0usize; - file.circuits.into_iter().for_each(|struct_def| { - structs.insert( - types::Identifier::from(struct_def.identifier.clone()), - types::Circuit::from(struct_def), + file.circuits.into_iter().for_each(|circuit| { + circuits.insert( + types::Identifier::from(circuit.identifier.clone()), + types::Circuit::from(circuit), ); }); file.functions.into_iter().for_each(|function_def| { @@ -829,7 +862,7 @@ impl<'ast, F: Field + PrimeField, G: Group> types::Program { name: types::Identifier::new(name), num_parameters, imports, - circuits: structs, + circuits, functions, } }