support static circuit functions. add :: syntax

This commit is contained in:
collin 2020-05-14 12:31:19 -07:00
parent ba84bf0d6c
commit 5fb0b58b68
10 changed files with 132 additions and 75 deletions

View File

@ -1,7 +1,7 @@
circuit PedersenHash { circuit PedersenHash {
parameters: group[1] parameters: group[1]
function new(b: u32) -> u32 { static function new(b: u32) -> u32 {
return b return b
} }
} }
@ -9,7 +9,7 @@ circuit PedersenHash {
function main() -> u32{ function main() -> u32{
let parameters = [0group; 1]; let parameters = [0group; 1];
let pedersen = PedersenHash { parameters: parameters }; let pedersen = PedersenHash { parameters: parameters };
let b = pedersen.new(3); let b = PedersenHash::new(3);
return b return b
} }

View File

@ -438,8 +438,16 @@ pub struct ArrayAccess<'ast> {
} }
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::access_member))] #[pest_ast(rule(Rule::access_object))]
pub struct MemberAccess<'ast> { pub struct ObjectAccess<'ast> {
pub identifier: Identifier<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::access_static_object))]
pub struct StaticObjectAccess<'ast> {
pub identifier: Identifier<'ast>, pub identifier: Identifier<'ast>,
#[pest_ast(outer())] #[pest_ast(outer())]
pub span: Span<'ast>, pub span: Span<'ast>,
@ -450,7 +458,8 @@ pub struct MemberAccess<'ast> {
pub enum Access<'ast> { pub enum Access<'ast> {
Array(ArrayAccess<'ast>), Array(ArrayAccess<'ast>),
Call(CallAccess<'ast>), Call(CallAccess<'ast>),
Member(MemberAccess<'ast>), Object(ObjectAccess<'ast>),
StaticObject(StaticObjectAccess<'ast>),
} }
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
@ -466,7 +475,7 @@ pub struct PostfixExpression<'ast> {
#[pest_ast(rule(Rule::assignee_access))] #[pest_ast(rule(Rule::assignee_access))]
pub enum AssigneeAccess<'ast> { pub enum AssigneeAccess<'ast> {
Array(ArrayAccess<'ast>), Array(ArrayAccess<'ast>),
Member(MemberAccess<'ast>), Member(ObjectAccess<'ast>),
} }
impl<'ast> fmt::Display for AssigneeAccess<'ast> { impl<'ast> fmt::Display for AssigneeAccess<'ast> {

View File

@ -6,6 +6,7 @@ use crate::{
ConstrainedProgram, ConstrainedValue, ConstrainedProgram, ConstrainedValue,
}, },
errors::ExpressionError, errors::ExpressionError,
new_scope,
types::{ types::{
CircuitMember, CircuitObject, Expression, Identifier, RangeOrExpression, SpreadOrExpression, CircuitMember, CircuitObject, Expression, Identifier, RangeOrExpression, SpreadOrExpression,
}, },
@ -20,21 +21,23 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
/// Enforce a variable expression by getting the resolved value /// Enforce a variable expression by getting the resolved value
pub(crate) fn evaluate_identifier( pub(crate) fn evaluate_identifier(
&mut self, &mut self,
scope: String, file_scope: String,
unresolved_variable: Identifier<F, G>, function_scope: String,
unresolved_identifier: Identifier<F, G>,
) -> Result<ConstrainedValue<F, G>, ExpressionError> { ) -> Result<ConstrainedValue<F, G>, ExpressionError> {
// Evaluate the variable name in the current function scope // Evaluate the identifier name in the current function scope
let variable_name = new_scope_from_variable(scope, &unresolved_variable); let variable_name = new_scope(function_scope, unresolved_identifier.to_string());
let identifier_name = new_scope(file_scope, unresolved_identifier.to_string());
if self.contains_name(&variable_name) { if let Some(variable) = self.get(&variable_name) {
// Reassigning variable to another variable // Reassigning variable to another variable
Ok(self.get_mut(&variable_name).unwrap().clone()) Ok(variable.clone())
} else if self.contains_variable(&unresolved_variable) { } else if let Some(identifier) = self.get(&identifier_name) {
// Check global scope (function and circuit names) // Check global scope (function and circuit names)
Ok(self.get_mut_variable(&unresolved_variable).unwrap().clone()) Ok(identifier.clone())
} else { } else {
Err(ExpressionError::UndefinedVariable( Err(ExpressionError::UndefinedIdentifier(
unresolved_variable.to_string(), unresolved_identifier.to_string(),
)) ))
} }
} }
@ -449,6 +452,53 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
} }
} }
fn enforce_circuit_static_access_expression(
&mut self,
cs: &mut CS,
file_scope: String,
function_scope: String,
circuit_identifier: Box<Expression<F, G>>,
circuit_member: Identifier<F, G>,
) -> Result<ConstrainedValue<F, G>, ExpressionError> {
// Get defined circuit
let circuit = match self.enforce_expression(
cs,
file_scope.clone(),
function_scope.clone(),
*circuit_identifier.clone(),
)? {
ConstrainedValue::CircuitDefinition(circuit_definition) => circuit_definition,
value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())),
};
// Find static circuit function
let matched_function = circuit.objects.into_iter().find(|member| match member {
CircuitObject::CircuitFunction(_static, _function) => *_static,
_ => false,
});
// Return errors if no static function exists
let function = match matched_function {
Some(CircuitObject::CircuitFunction(_static, function)) => {
if _static {
function
} else {
return Err(ExpressionError::InvalidStaticFunction(
function.function_name.to_string(),
));
}
}
_ => {
return Err(ExpressionError::UndefinedStaticFunction(
circuit.identifier.to_string(),
circuit_member.to_string(),
))
}
};
Ok(ConstrainedValue::Function(function))
}
fn enforce_function_call_expression( fn enforce_function_call_expression(
&mut self, &mut self,
cs: &mut CS, cs: &mut CS,
@ -492,7 +542,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
match expression { match expression {
// Variables // Variables
Expression::Identifier(unresolved_variable) => { Expression::Identifier(unresolved_variable) => {
self.evaluate_identifier(function_scope, unresolved_variable) self.evaluate_identifier(file_scope, function_scope, unresolved_variable)
} }
// Values // Values
@ -690,7 +740,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
circuit_name, circuit_name,
members, members,
), ),
Expression::CircuitMemberAccess(circuit_variable, circuit_member) => self Expression::CircuitObjectAccess(circuit_variable, circuit_member) => self
.enforce_circuit_access_expression( .enforce_circuit_access_expression(
cs, cs,
file_scope, file_scope,
@ -698,6 +748,14 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
circuit_variable, circuit_variable,
circuit_member, circuit_member,
), ),
Expression::CircuitStaticObjectAccess(circuit_identifier, circuit_member) => self
.enforce_circuit_static_access_expression(
cs,
file_scope,
function_scope,
circuit_identifier,
circuit_member,
),
// Functions // Functions
Expression::FunctionCall(function, arguments) => self.enforce_function_call_expression( Expression::FunctionCall(function, arguments) => self.enforce_function_call_expression(

View File

@ -31,8 +31,8 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
input: Expression<F, G>, input: Expression<F, G>,
) -> Result<ConstrainedValue<F, G>, FunctionError> { ) -> Result<ConstrainedValue<F, G>, FunctionError> {
match input { match input {
Expression::Identifier(variable) => { Expression::Identifier(identifier) => {
Ok(self.evaluate_identifier(caller_scope, variable)?) Ok(self.evaluate_identifier(caller_scope, function_name, identifier)?)
} }
expression => Ok(self.enforce_expression(cs, scope, function_name, expression)?), expression => Ok(self.enforce_expression(cs, scope, function_name, expression)?),
} }

View File

@ -66,14 +66,6 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
self.store(variable.name, value); self.store(variable.name, value);
} }
pub(crate) fn contains_name(&self, name: &String) -> bool {
self.identifiers.contains_key(name)
}
pub(crate) fn contains_variable(&self, variable: &Identifier<F, G>) -> bool {
self.contains_name(&variable.name)
}
pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue<F, G>> { pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue<F, G>> {
self.identifiers.get(name) self.identifiers.get(name)
} }

View File

@ -2,9 +2,9 @@ use crate::errors::{BooleanError, FieldElementError, FunctionError, IntegerError
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ExpressionError { pub enum ExpressionError {
// Variables // Identifiers
#[error("Variable \"{}\" not found", _0)] #[error("Variable \"{}\" not found", _0)]
UndefinedVariable(String), UndefinedIdentifier(String),
// Types // Types
#[error("{}", _0)] #[error("{}", _0)]
@ -63,6 +63,15 @@ pub enum ExpressionError {
#[error("Expected circuit value {}", _0)] #[error("Expected circuit value {}", _0)]
ExpectedCircuitValue(String), ExpectedCircuitValue(String),
#[error("Circuit {} has no static function {}", _0, _1)]
UndefinedStaticFunction(String, String),
#[error(
"Static access only supported for static circuit functions, got function {}",
_0
)]
InvalidStaticFunction(String),
// Functions // Functions
#[error( #[error(
"Function {} must be declared before it is used in an inline expression", "Function {} must be declared before it is used in an inline expression",

View File

@ -111,12 +111,13 @@ range_or_expression = { range | expression }
access_array = { "[" ~ range_or_expression ~ "]" } access_array = { "[" ~ range_or_expression ~ "]" }
access_call = { "(" ~ expression_tuple ~ ")" } access_call = { "(" ~ expression_tuple ~ ")" }
access_member = { "." ~ identifier } access_object = { "." ~ identifier }
access = { access_array | access_call | access_member } access_static_object = { "::" ~ identifier }
access = { access_array | access_call | access_object | access_static_object}
expression_postfix = { identifier ~ access+ } expression_postfix = { identifier ~ access+ }
assignee_access = { access_array | access_member } assignee_access = { access_array | access_object }
assignee = { identifier ~ assignee_access* } assignee = { identifier ~ assignee_access* }
spread = { "..." ~ expression } spread = { "..." ~ expression }

View File

@ -154,7 +154,8 @@ pub enum Expression<F: Field + PrimeField, G: Group> {
// Circuits // Circuits
Circuit(Identifier<F, G>, Vec<CircuitMember<F, G>>), Circuit(Identifier<F, G>, Vec<CircuitMember<F, G>>),
CircuitMemberAccess(Box<Expression<F, G>>, Identifier<F, G>), // (circuit name, circuit object name) CircuitObjectAccess(Box<Expression<F, G>>, Identifier<F, G>), // (declared circuit name, circuit object name)
CircuitStaticObjectAccess(Box<Expression<F, G>>, Identifier<F, G>), // (defined circuit name, circuit staic object name)
// Functions // Functions
FunctionCall(Box<Expression<F, G>>, Vec<Expression<F, G>>), FunctionCall(Box<Expression<F, G>>, Vec<Expression<F, G>>),

View File

@ -154,8 +154,11 @@ impl<'ast, F: Field + PrimeField, G: Group> fmt::Display for Expression<F, G> {
} }
write!(f, "}}") write!(f, "}}")
} }
Expression::CircuitMemberAccess(ref circuit_variable, ref member) => { Expression::CircuitObjectAccess(ref circuit_name, ref member) => {
write!(f, "{}.{}", circuit_variable, member) write!(f, "{}.{}", circuit_name, member)
}
Expression::CircuitStaticObjectAccess(ref circuit_name, ref member) => {
write!(f, "{}::{}", circuit_name, member)
} }
// Function calls // Function calls

View File

@ -317,49 +317,33 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::PostfixExpression<'ast>>
.accesses .accesses
.into_iter() .into_iter()
.fold(variable, |acc, access| match access { .fold(variable, |acc, access| match access {
// Handle function calls
ast::Access::Call(function) => {
types::Expression::FunctionCall(
Box::new(acc),
function
.expressions
.into_iter()
.map(|expression| types::Expression::from(expression))
.collect(),
)
//match acc {
// Normal function call
// types::Expression::Identifier(identifier) => types::Expression::FunctionCall(
// Box::new(acc),
// function
// .expressions
// .into_iter()
// .map(|expression| types::Expression::from(expression))
// .collect(),
// ),
// // Circuit function call
// types::Expression::CircuitMemberAccess(acc, circuit_function) => types::Expression::FunctionCall(
// circuit_function,
// function
// .expressions
// .into_iter()
// .map(|expression| types::Expression::from(expression))
// .collect(),
// ),
// expression => {
// unimplemented!("only function names are callable, found \"{}\"", expression)
// }
}
// Handle circuit member accesses
ast::Access::Member(circuit_member) => types::Expression::CircuitMemberAccess(
Box::new(acc),
types::Identifier::from(circuit_member.identifier),
),
// Handle array accesses // Handle array accesses
ast::Access::Array(array) => types::Expression::ArrayAccess( ast::Access::Array(array) => types::Expression::ArrayAccess(
Box::new(acc), Box::new(acc),
Box::new(types::RangeOrExpression::from(array.expression)), Box::new(types::RangeOrExpression::from(array.expression)),
), ),
// Handle function calls
ast::Access::Call(function) => types::Expression::FunctionCall(
Box::new(acc),
function
.expressions
.into_iter()
.map(|expression| types::Expression::from(expression))
.collect(),
),
// Handle circuit member accesses
ast::Access::Object(circuit_object) => types::Expression::CircuitObjectAccess(
Box::new(acc),
types::Identifier::from(circuit_object.identifier),
),
ast::Access::StaticObject(circuit_object) => {
types::Expression::CircuitStaticObjectAccess(
Box::new(acc),
types::Identifier::from(circuit_object.identifier),
)
}
}) })
} }
} }
@ -406,7 +390,7 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::Assignee<'ast>> for types:
.into_iter() .into_iter()
.fold(variable, |acc, access| match access { .fold(variable, |acc, access| match access {
ast::AssigneeAccess::Member(struct_member) => { ast::AssigneeAccess::Member(struct_member) => {
types::Expression::CircuitMemberAccess( types::Expression::CircuitObjectAccess(
Box::new(acc), Box::new(acc),
types::Identifier::from(struct_member.identifier), types::Identifier::from(struct_member.identifier),
) )