support static circuit functions. add :: syntax

This commit is contained in:
collin 2020-05-14 12:31:19 -07:00
parent ba84bf0d6c
commit 5fb0b58b68
10 changed files with 132 additions and 75 deletions

View File

@ -1,7 +1,7 @@
circuit PedersenHash {
parameters: group[1]
function new(b: u32) -> u32 {
static function new(b: u32) -> u32 {
return b
}
}
@ -9,7 +9,7 @@ circuit PedersenHash {
function main() -> u32{
let parameters = [0group; 1];
let pedersen = PedersenHash { parameters: parameters };
let b = pedersen.new(3);
let b = PedersenHash::new(3);
return b
}

View File

@ -438,8 +438,16 @@ pub struct ArrayAccess<'ast> {
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::access_member))]
pub struct MemberAccess<'ast> {
#[pest_ast(rule(Rule::access_object))]
pub struct ObjectAccess<'ast> {
pub identifier: Identifier<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::access_static_object))]
pub struct StaticObjectAccess<'ast> {
pub identifier: Identifier<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
@ -450,7 +458,8 @@ pub struct MemberAccess<'ast> {
pub enum Access<'ast> {
Array(ArrayAccess<'ast>),
Call(CallAccess<'ast>),
Member(MemberAccess<'ast>),
Object(ObjectAccess<'ast>),
StaticObject(StaticObjectAccess<'ast>),
}
#[derive(Clone, Debug, FromPest, PartialEq)]
@ -466,7 +475,7 @@ pub struct PostfixExpression<'ast> {
#[pest_ast(rule(Rule::assignee_access))]
pub enum AssigneeAccess<'ast> {
Array(ArrayAccess<'ast>),
Member(MemberAccess<'ast>),
Member(ObjectAccess<'ast>),
}
impl<'ast> fmt::Display for AssigneeAccess<'ast> {

View File

@ -6,6 +6,7 @@ use crate::{
ConstrainedProgram, ConstrainedValue,
},
errors::ExpressionError,
new_scope,
types::{
CircuitMember, CircuitObject, Expression, Identifier, RangeOrExpression, SpreadOrExpression,
},
@ -20,21 +21,23 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
/// Enforce a variable expression by getting the resolved value
pub(crate) fn evaluate_identifier(
&mut self,
scope: String,
unresolved_variable: Identifier<F, G>,
file_scope: String,
function_scope: String,
unresolved_identifier: Identifier<F, G>,
) -> Result<ConstrainedValue<F, G>, ExpressionError> {
// Evaluate the variable name in the current function scope
let variable_name = new_scope_from_variable(scope, &unresolved_variable);
// Evaluate the identifier name in the current function scope
let variable_name = new_scope(function_scope, unresolved_identifier.to_string());
let identifier_name = new_scope(file_scope, unresolved_identifier.to_string());
if self.contains_name(&variable_name) {
if let Some(variable) = self.get(&variable_name) {
// Reassigning variable to another variable
Ok(self.get_mut(&variable_name).unwrap().clone())
} else if self.contains_variable(&unresolved_variable) {
Ok(variable.clone())
} else if let Some(identifier) = self.get(&identifier_name) {
// Check global scope (function and circuit names)
Ok(self.get_mut_variable(&unresolved_variable).unwrap().clone())
Ok(identifier.clone())
} else {
Err(ExpressionError::UndefinedVariable(
unresolved_variable.to_string(),
Err(ExpressionError::UndefinedIdentifier(
unresolved_identifier.to_string(),
))
}
}
@ -449,6 +452,53 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
}
}
fn enforce_circuit_static_access_expression(
&mut self,
cs: &mut CS,
file_scope: String,
function_scope: String,
circuit_identifier: Box<Expression<F, G>>,
circuit_member: Identifier<F, G>,
) -> Result<ConstrainedValue<F, G>, ExpressionError> {
// Get defined circuit
let circuit = match self.enforce_expression(
cs,
file_scope.clone(),
function_scope.clone(),
*circuit_identifier.clone(),
)? {
ConstrainedValue::CircuitDefinition(circuit_definition) => circuit_definition,
value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())),
};
// Find static circuit function
let matched_function = circuit.objects.into_iter().find(|member| match member {
CircuitObject::CircuitFunction(_static, _function) => *_static,
_ => false,
});
// Return errors if no static function exists
let function = match matched_function {
Some(CircuitObject::CircuitFunction(_static, function)) => {
if _static {
function
} else {
return Err(ExpressionError::InvalidStaticFunction(
function.function_name.to_string(),
));
}
}
_ => {
return Err(ExpressionError::UndefinedStaticFunction(
circuit.identifier.to_string(),
circuit_member.to_string(),
))
}
};
Ok(ConstrainedValue::Function(function))
}
fn enforce_function_call_expression(
&mut self,
cs: &mut CS,
@ -492,7 +542,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
match expression {
// Variables
Expression::Identifier(unresolved_variable) => {
self.evaluate_identifier(function_scope, unresolved_variable)
self.evaluate_identifier(file_scope, function_scope, unresolved_variable)
}
// Values
@ -690,7 +740,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
circuit_name,
members,
),
Expression::CircuitMemberAccess(circuit_variable, circuit_member) => self
Expression::CircuitObjectAccess(circuit_variable, circuit_member) => self
.enforce_circuit_access_expression(
cs,
file_scope,
@ -698,6 +748,14 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
circuit_variable,
circuit_member,
),
Expression::CircuitStaticObjectAccess(circuit_identifier, circuit_member) => self
.enforce_circuit_static_access_expression(
cs,
file_scope,
function_scope,
circuit_identifier,
circuit_member,
),
// Functions
Expression::FunctionCall(function, arguments) => self.enforce_function_call_expression(

View File

@ -31,8 +31,8 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
input: Expression<F, G>,
) -> Result<ConstrainedValue<F, G>, FunctionError> {
match input {
Expression::Identifier(variable) => {
Ok(self.evaluate_identifier(caller_scope, variable)?)
Expression::Identifier(identifier) => {
Ok(self.evaluate_identifier(caller_scope, function_name, identifier)?)
}
expression => Ok(self.enforce_expression(cs, scope, function_name, expression)?),
}

View File

@ -66,14 +66,6 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
self.store(variable.name, value);
}
pub(crate) fn contains_name(&self, name: &String) -> bool {
self.identifiers.contains_key(name)
}
pub(crate) fn contains_variable(&self, variable: &Identifier<F, G>) -> bool {
self.contains_name(&variable.name)
}
pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue<F, G>> {
self.identifiers.get(name)
}

View File

@ -2,9 +2,9 @@ use crate::errors::{BooleanError, FieldElementError, FunctionError, IntegerError
#[derive(Debug, Error)]
pub enum ExpressionError {
// Variables
// Identifiers
#[error("Variable \"{}\" not found", _0)]
UndefinedVariable(String),
UndefinedIdentifier(String),
// Types
#[error("{}", _0)]
@ -63,6 +63,15 @@ pub enum ExpressionError {
#[error("Expected circuit value {}", _0)]
ExpectedCircuitValue(String),
#[error("Circuit {} has no static function {}", _0, _1)]
UndefinedStaticFunction(String, String),
#[error(
"Static access only supported for static circuit functions, got function {}",
_0
)]
InvalidStaticFunction(String),
// Functions
#[error(
"Function {} must be declared before it is used in an inline expression",

View File

@ -111,12 +111,13 @@ range_or_expression = { range | expression }
access_array = { "[" ~ range_or_expression ~ "]" }
access_call = { "(" ~ expression_tuple ~ ")" }
access_member = { "." ~ identifier }
access = { access_array | access_call | access_member }
access_object = { "." ~ identifier }
access_static_object = { "::" ~ identifier }
access = { access_array | access_call | access_object | access_static_object}
expression_postfix = { identifier ~ access+ }
assignee_access = { access_array | access_member }
assignee_access = { access_array | access_object }
assignee = { identifier ~ assignee_access* }
spread = { "..." ~ expression }

View File

@ -154,7 +154,8 @@ pub enum Expression<F: Field + PrimeField, G: Group> {
// Circuits
Circuit(Identifier<F, G>, Vec<CircuitMember<F, G>>),
CircuitMemberAccess(Box<Expression<F, G>>, Identifier<F, G>), // (circuit name, circuit object name)
CircuitObjectAccess(Box<Expression<F, G>>, Identifier<F, G>), // (declared circuit name, circuit object name)
CircuitStaticObjectAccess(Box<Expression<F, G>>, Identifier<F, G>), // (defined circuit name, circuit staic object name)
// Functions
FunctionCall(Box<Expression<F, G>>, Vec<Expression<F, G>>),

View File

@ -154,8 +154,11 @@ impl<'ast, F: Field + PrimeField, G: Group> fmt::Display for Expression<F, G> {
}
write!(f, "}}")
}
Expression::CircuitMemberAccess(ref circuit_variable, ref member) => {
write!(f, "{}.{}", circuit_variable, member)
Expression::CircuitObjectAccess(ref circuit_name, ref member) => {
write!(f, "{}.{}", circuit_name, member)
}
Expression::CircuitStaticObjectAccess(ref circuit_name, ref member) => {
write!(f, "{}::{}", circuit_name, member)
}
// Function calls

View File

@ -317,49 +317,33 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::PostfixExpression<'ast>>
.accesses
.into_iter()
.fold(variable, |acc, access| match access {
// Handle function calls
ast::Access::Call(function) => {
types::Expression::FunctionCall(
Box::new(acc),
function
.expressions
.into_iter()
.map(|expression| types::Expression::from(expression))
.collect(),
)
//match acc {
// Normal function call
// types::Expression::Identifier(identifier) => types::Expression::FunctionCall(
// Box::new(acc),
// function
// .expressions
// .into_iter()
// .map(|expression| types::Expression::from(expression))
// .collect(),
// ),
// // Circuit function call
// types::Expression::CircuitMemberAccess(acc, circuit_function) => types::Expression::FunctionCall(
// circuit_function,
// function
// .expressions
// .into_iter()
// .map(|expression| types::Expression::from(expression))
// .collect(),
// ),
// expression => {
// unimplemented!("only function names are callable, found \"{}\"", expression)
// }
}
// Handle circuit member accesses
ast::Access::Member(circuit_member) => types::Expression::CircuitMemberAccess(
Box::new(acc),
types::Identifier::from(circuit_member.identifier),
),
// Handle array accesses
ast::Access::Array(array) => types::Expression::ArrayAccess(
Box::new(acc),
Box::new(types::RangeOrExpression::from(array.expression)),
),
// Handle function calls
ast::Access::Call(function) => types::Expression::FunctionCall(
Box::new(acc),
function
.expressions
.into_iter()
.map(|expression| types::Expression::from(expression))
.collect(),
),
// Handle circuit member accesses
ast::Access::Object(circuit_object) => types::Expression::CircuitObjectAccess(
Box::new(acc),
types::Identifier::from(circuit_object.identifier),
),
ast::Access::StaticObject(circuit_object) => {
types::Expression::CircuitStaticObjectAccess(
Box::new(acc),
types::Identifier::from(circuit_object.identifier),
)
}
})
}
}
@ -406,7 +390,7 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::Assignee<'ast>> for types:
.into_iter()
.fold(variable, |acc, access| match access {
ast::AssigneeAccess::Member(struct_member) => {
types::Expression::CircuitMemberAccess(
types::Expression::CircuitObjectAccess(
Box::new(acc),
types::Identifier::from(struct_member.identifier),
)