mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-12-23 18:21:38 +03:00
support static circuit functions. add :: syntax
This commit is contained in:
parent
ba84bf0d6c
commit
5fb0b58b68
@ -1,7 +1,7 @@
|
||||
circuit PedersenHash {
|
||||
parameters: group[1]
|
||||
|
||||
function new(b: u32) -> u32 {
|
||||
static function new(b: u32) -> u32 {
|
||||
return b
|
||||
}
|
||||
}
|
||||
@ -9,7 +9,7 @@ circuit PedersenHash {
|
||||
function main() -> u32{
|
||||
let parameters = [0group; 1];
|
||||
let pedersen = PedersenHash { parameters: parameters };
|
||||
let b = pedersen.new(3);
|
||||
let b = PedersenHash::new(3);
|
||||
|
||||
return b
|
||||
}
|
@ -438,8 +438,16 @@ pub struct ArrayAccess<'ast> {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, FromPest, PartialEq)]
|
||||
#[pest_ast(rule(Rule::access_member))]
|
||||
pub struct MemberAccess<'ast> {
|
||||
#[pest_ast(rule(Rule::access_object))]
|
||||
pub struct ObjectAccess<'ast> {
|
||||
pub identifier: Identifier<'ast>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, FromPest, PartialEq)]
|
||||
#[pest_ast(rule(Rule::access_static_object))]
|
||||
pub struct StaticObjectAccess<'ast> {
|
||||
pub identifier: Identifier<'ast>,
|
||||
#[pest_ast(outer())]
|
||||
pub span: Span<'ast>,
|
||||
@ -450,7 +458,8 @@ pub struct MemberAccess<'ast> {
|
||||
pub enum Access<'ast> {
|
||||
Array(ArrayAccess<'ast>),
|
||||
Call(CallAccess<'ast>),
|
||||
Member(MemberAccess<'ast>),
|
||||
Object(ObjectAccess<'ast>),
|
||||
StaticObject(StaticObjectAccess<'ast>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, FromPest, PartialEq)]
|
||||
@ -466,7 +475,7 @@ pub struct PostfixExpression<'ast> {
|
||||
#[pest_ast(rule(Rule::assignee_access))]
|
||||
pub enum AssigneeAccess<'ast> {
|
||||
Array(ArrayAccess<'ast>),
|
||||
Member(MemberAccess<'ast>),
|
||||
Member(ObjectAccess<'ast>),
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for AssigneeAccess<'ast> {
|
||||
|
@ -6,6 +6,7 @@ use crate::{
|
||||
ConstrainedProgram, ConstrainedValue,
|
||||
},
|
||||
errors::ExpressionError,
|
||||
new_scope,
|
||||
types::{
|
||||
CircuitMember, CircuitObject, Expression, Identifier, RangeOrExpression, SpreadOrExpression,
|
||||
},
|
||||
@ -20,21 +21,23 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
/// Enforce a variable expression by getting the resolved value
|
||||
pub(crate) fn evaluate_identifier(
|
||||
&mut self,
|
||||
scope: String,
|
||||
unresolved_variable: Identifier<F, G>,
|
||||
file_scope: String,
|
||||
function_scope: String,
|
||||
unresolved_identifier: Identifier<F, G>,
|
||||
) -> Result<ConstrainedValue<F, G>, ExpressionError> {
|
||||
// Evaluate the variable name in the current function scope
|
||||
let variable_name = new_scope_from_variable(scope, &unresolved_variable);
|
||||
// Evaluate the identifier name in the current function scope
|
||||
let variable_name = new_scope(function_scope, unresolved_identifier.to_string());
|
||||
let identifier_name = new_scope(file_scope, unresolved_identifier.to_string());
|
||||
|
||||
if self.contains_name(&variable_name) {
|
||||
if let Some(variable) = self.get(&variable_name) {
|
||||
// Reassigning variable to another variable
|
||||
Ok(self.get_mut(&variable_name).unwrap().clone())
|
||||
} else if self.contains_variable(&unresolved_variable) {
|
||||
Ok(variable.clone())
|
||||
} else if let Some(identifier) = self.get(&identifier_name) {
|
||||
// Check global scope (function and circuit names)
|
||||
Ok(self.get_mut_variable(&unresolved_variable).unwrap().clone())
|
||||
Ok(identifier.clone())
|
||||
} else {
|
||||
Err(ExpressionError::UndefinedVariable(
|
||||
unresolved_variable.to_string(),
|
||||
Err(ExpressionError::UndefinedIdentifier(
|
||||
unresolved_identifier.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@ -449,6 +452,53 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
}
|
||||
}
|
||||
|
||||
fn enforce_circuit_static_access_expression(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
file_scope: String,
|
||||
function_scope: String,
|
||||
circuit_identifier: Box<Expression<F, G>>,
|
||||
circuit_member: Identifier<F, G>,
|
||||
) -> Result<ConstrainedValue<F, G>, ExpressionError> {
|
||||
// Get defined circuit
|
||||
let circuit = match self.enforce_expression(
|
||||
cs,
|
||||
file_scope.clone(),
|
||||
function_scope.clone(),
|
||||
*circuit_identifier.clone(),
|
||||
)? {
|
||||
ConstrainedValue::CircuitDefinition(circuit_definition) => circuit_definition,
|
||||
value => return Err(ExpressionError::InvalidCircuitAccess(value.to_string())),
|
||||
};
|
||||
|
||||
// Find static circuit function
|
||||
let matched_function = circuit.objects.into_iter().find(|member| match member {
|
||||
CircuitObject::CircuitFunction(_static, _function) => *_static,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
// Return errors if no static function exists
|
||||
let function = match matched_function {
|
||||
Some(CircuitObject::CircuitFunction(_static, function)) => {
|
||||
if _static {
|
||||
function
|
||||
} else {
|
||||
return Err(ExpressionError::InvalidStaticFunction(
|
||||
function.function_name.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(ExpressionError::UndefinedStaticFunction(
|
||||
circuit.identifier.to_string(),
|
||||
circuit_member.to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ConstrainedValue::Function(function))
|
||||
}
|
||||
|
||||
fn enforce_function_call_expression(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
@ -492,7 +542,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
match expression {
|
||||
// Variables
|
||||
Expression::Identifier(unresolved_variable) => {
|
||||
self.evaluate_identifier(function_scope, unresolved_variable)
|
||||
self.evaluate_identifier(file_scope, function_scope, unresolved_variable)
|
||||
}
|
||||
|
||||
// Values
|
||||
@ -690,7 +740,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
circuit_name,
|
||||
members,
|
||||
),
|
||||
Expression::CircuitMemberAccess(circuit_variable, circuit_member) => self
|
||||
Expression::CircuitObjectAccess(circuit_variable, circuit_member) => self
|
||||
.enforce_circuit_access_expression(
|
||||
cs,
|
||||
file_scope,
|
||||
@ -698,6 +748,14 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
circuit_variable,
|
||||
circuit_member,
|
||||
),
|
||||
Expression::CircuitStaticObjectAccess(circuit_identifier, circuit_member) => self
|
||||
.enforce_circuit_static_access_expression(
|
||||
cs,
|
||||
file_scope,
|
||||
function_scope,
|
||||
circuit_identifier,
|
||||
circuit_member,
|
||||
),
|
||||
|
||||
// Functions
|
||||
Expression::FunctionCall(function, arguments) => self.enforce_function_call_expression(
|
||||
|
@ -31,8 +31,8 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
input: Expression<F, G>,
|
||||
) -> Result<ConstrainedValue<F, G>, FunctionError> {
|
||||
match input {
|
||||
Expression::Identifier(variable) => {
|
||||
Ok(self.evaluate_identifier(caller_scope, variable)?)
|
||||
Expression::Identifier(identifier) => {
|
||||
Ok(self.evaluate_identifier(caller_scope, function_name, identifier)?)
|
||||
}
|
||||
expression => Ok(self.enforce_expression(cs, scope, function_name, expression)?),
|
||||
}
|
||||
|
@ -66,14 +66,6 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
|
||||
self.store(variable.name, value);
|
||||
}
|
||||
|
||||
pub(crate) fn contains_name(&self, name: &String) -> bool {
|
||||
self.identifiers.contains_key(name)
|
||||
}
|
||||
|
||||
pub(crate) fn contains_variable(&self, variable: &Identifier<F, G>) -> bool {
|
||||
self.contains_name(&variable.name)
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue<F, G>> {
|
||||
self.identifiers.get(name)
|
||||
}
|
||||
|
@ -2,9 +2,9 @@ use crate::errors::{BooleanError, FieldElementError, FunctionError, IntegerError
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ExpressionError {
|
||||
// Variables
|
||||
// Identifiers
|
||||
#[error("Variable \"{}\" not found", _0)]
|
||||
UndefinedVariable(String),
|
||||
UndefinedIdentifier(String),
|
||||
|
||||
// Types
|
||||
#[error("{}", _0)]
|
||||
@ -63,6 +63,15 @@ pub enum ExpressionError {
|
||||
#[error("Expected circuit value {}", _0)]
|
||||
ExpectedCircuitValue(String),
|
||||
|
||||
#[error("Circuit {} has no static function {}", _0, _1)]
|
||||
UndefinedStaticFunction(String, String),
|
||||
|
||||
#[error(
|
||||
"Static access only supported for static circuit functions, got function {}",
|
||||
_0
|
||||
)]
|
||||
InvalidStaticFunction(String),
|
||||
|
||||
// Functions
|
||||
#[error(
|
||||
"Function {} must be declared before it is used in an inline expression",
|
||||
|
@ -111,12 +111,13 @@ range_or_expression = { range | expression }
|
||||
|
||||
access_array = { "[" ~ range_or_expression ~ "]" }
|
||||
access_call = { "(" ~ expression_tuple ~ ")" }
|
||||
access_member = { "." ~ identifier }
|
||||
access = { access_array | access_call | access_member }
|
||||
access_object = { "." ~ identifier }
|
||||
access_static_object = { "::" ~ identifier }
|
||||
access = { access_array | access_call | access_object | access_static_object}
|
||||
|
||||
expression_postfix = { identifier ~ access+ }
|
||||
|
||||
assignee_access = { access_array | access_member }
|
||||
assignee_access = { access_array | access_object }
|
||||
assignee = { identifier ~ assignee_access* }
|
||||
|
||||
spread = { "..." ~ expression }
|
||||
|
@ -154,7 +154,8 @@ pub enum Expression<F: Field + PrimeField, G: Group> {
|
||||
|
||||
// Circuits
|
||||
Circuit(Identifier<F, G>, Vec<CircuitMember<F, G>>),
|
||||
CircuitMemberAccess(Box<Expression<F, G>>, Identifier<F, G>), // (circuit name, circuit object name)
|
||||
CircuitObjectAccess(Box<Expression<F, G>>, Identifier<F, G>), // (declared circuit name, circuit object name)
|
||||
CircuitStaticObjectAccess(Box<Expression<F, G>>, Identifier<F, G>), // (defined circuit name, circuit staic object name)
|
||||
|
||||
// Functions
|
||||
FunctionCall(Box<Expression<F, G>>, Vec<Expression<F, G>>),
|
||||
|
@ -154,8 +154,11 @@ impl<'ast, F: Field + PrimeField, G: Group> fmt::Display for Expression<F, G> {
|
||||
}
|
||||
write!(f, "}}")
|
||||
}
|
||||
Expression::CircuitMemberAccess(ref circuit_variable, ref member) => {
|
||||
write!(f, "{}.{}", circuit_variable, member)
|
||||
Expression::CircuitObjectAccess(ref circuit_name, ref member) => {
|
||||
write!(f, "{}.{}", circuit_name, member)
|
||||
}
|
||||
Expression::CircuitStaticObjectAccess(ref circuit_name, ref member) => {
|
||||
write!(f, "{}::{}", circuit_name, member)
|
||||
}
|
||||
|
||||
// Function calls
|
||||
|
@ -317,49 +317,33 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::PostfixExpression<'ast>>
|
||||
.accesses
|
||||
.into_iter()
|
||||
.fold(variable, |acc, access| match access {
|
||||
// Handle function calls
|
||||
ast::Access::Call(function) => {
|
||||
types::Expression::FunctionCall(
|
||||
Box::new(acc),
|
||||
function
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|expression| types::Expression::from(expression))
|
||||
.collect(),
|
||||
)
|
||||
//match acc {
|
||||
// Normal function call
|
||||
// types::Expression::Identifier(identifier) => types::Expression::FunctionCall(
|
||||
// Box::new(acc),
|
||||
// function
|
||||
// .expressions
|
||||
// .into_iter()
|
||||
// .map(|expression| types::Expression::from(expression))
|
||||
// .collect(),
|
||||
// ),
|
||||
// // Circuit function call
|
||||
// types::Expression::CircuitMemberAccess(acc, circuit_function) => types::Expression::FunctionCall(
|
||||
// circuit_function,
|
||||
// function
|
||||
// .expressions
|
||||
// .into_iter()
|
||||
// .map(|expression| types::Expression::from(expression))
|
||||
// .collect(),
|
||||
// ),
|
||||
// expression => {
|
||||
// unimplemented!("only function names are callable, found \"{}\"", expression)
|
||||
// }
|
||||
}
|
||||
// Handle circuit member accesses
|
||||
ast::Access::Member(circuit_member) => types::Expression::CircuitMemberAccess(
|
||||
Box::new(acc),
|
||||
types::Identifier::from(circuit_member.identifier),
|
||||
),
|
||||
// Handle array accesses
|
||||
ast::Access::Array(array) => types::Expression::ArrayAccess(
|
||||
Box::new(acc),
|
||||
Box::new(types::RangeOrExpression::from(array.expression)),
|
||||
),
|
||||
|
||||
// Handle function calls
|
||||
ast::Access::Call(function) => types::Expression::FunctionCall(
|
||||
Box::new(acc),
|
||||
function
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|expression| types::Expression::from(expression))
|
||||
.collect(),
|
||||
),
|
||||
|
||||
// Handle circuit member accesses
|
||||
ast::Access::Object(circuit_object) => types::Expression::CircuitObjectAccess(
|
||||
Box::new(acc),
|
||||
types::Identifier::from(circuit_object.identifier),
|
||||
),
|
||||
ast::Access::StaticObject(circuit_object) => {
|
||||
types::Expression::CircuitStaticObjectAccess(
|
||||
Box::new(acc),
|
||||
types::Identifier::from(circuit_object.identifier),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -406,7 +390,7 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::Assignee<'ast>> for types:
|
||||
.into_iter()
|
||||
.fold(variable, |acc, access| match access {
|
||||
ast::AssigneeAccess::Member(struct_member) => {
|
||||
types::Expression::CircuitMemberAccess(
|
||||
types::Expression::CircuitObjectAccess(
|
||||
Box::new(acc),
|
||||
types::Identifier::from(struct_member.identifier),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user