impl circuit Self type

This commit is contained in:
collin 2020-05-14 17:07:09 -07:00
parent b984c46a51
commit 021379458d
12 changed files with 72 additions and 49 deletions

View File

@ -7,7 +7,7 @@ circuit PedersenHash {
} }
function main() -> u32{ function main() -> PedersenHash {
let parameters = [0group; 1]; let parameters = [0group; 1];
let pedersen = PedersenHash::new(parameters); let pedersen = PedersenHash::new(parameters);

View File

@ -171,24 +171,15 @@ pub enum IntegerType {
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::type_field))] #[pest_ast(rule(Rule::type_field))]
pub struct FieldType<'ast> { pub struct FieldType {}
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::type_group))] #[pest_ast(rule(Rule::type_group))]
pub struct GroupType<'ast> { pub struct GroupType {}
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::type_bool))] #[pest_ast(rule(Rule::type_bool))]
pub struct BooleanType<'ast> { pub struct BooleanType {}
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::type_circuit))] #[pest_ast(rule(Rule::type_circuit))]
@ -204,17 +195,17 @@ pub struct SelfType {}
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::type_basic))] #[pest_ast(rule(Rule::type_basic))]
pub enum BasicType<'ast> { pub enum BasicType {
Integer(IntegerType), Integer(IntegerType),
Field(FieldType<'ast>), Field(FieldType),
Group(GroupType<'ast>), Group(GroupType),
Boolean(BooleanType<'ast>), Boolean(BooleanType),
} }
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::type_array))] #[pest_ast(rule(Rule::type_array))]
pub struct ArrayType<'ast> { pub struct ArrayType<'ast> {
pub _type: BasicType<'ast>, pub _type: BasicType,
pub dimensions: Vec<Value<'ast>>, pub dimensions: Vec<Value<'ast>>,
#[pest_ast(outer())] #[pest_ast(outer())]
pub span: Span<'ast>, pub span: Span<'ast>,
@ -223,7 +214,7 @@ pub struct ArrayType<'ast> {
#[derive(Clone, Debug, FromPest, PartialEq)] #[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::_type))] #[pest_ast(rule(Rule::_type))]
pub enum Type<'ast> { pub enum Type<'ast> {
Basic(BasicType<'ast>), Basic(BasicType),
Array(ArrayType<'ast>), Array(ArrayType<'ast>),
Circuit(CircuitType<'ast>), Circuit(CircuitType<'ast>),
SelfType(SelfType), SelfType(SelfType),
@ -275,7 +266,7 @@ impl<'ast> fmt::Display for Integer<'ast> {
#[pest_ast(rule(Rule::value_field))] #[pest_ast(rule(Rule::value_field))]
pub struct Field<'ast> { pub struct Field<'ast> {
pub number: Number<'ast>, pub number: Number<'ast>,
pub _type: FieldType<'ast>, pub _type: FieldType,
#[pest_ast(outer())] #[pest_ast(outer())]
pub span: Span<'ast>, pub span: Span<'ast>,
} }
@ -290,7 +281,7 @@ impl<'ast> fmt::Display for Field<'ast> {
#[pest_ast(rule(Rule::value_group))] #[pest_ast(rule(Rule::value_group))]
pub struct Group<'ast> { pub struct Group<'ast> {
pub number: Number<'ast>, pub number: Number<'ast>,
pub _type: GroupType<'ast>, pub _type: GroupType,
#[pest_ast(outer())] #[pest_ast(outer())]
pub span: Span<'ast>, pub span: Span<'ast>,
} }

View File

@ -2,8 +2,7 @@
use crate::{ use crate::{
constraints::{ constraints::{
new_scope_from_variable, new_variable_from_variable, ConstrainedCircuitMember, new_scope_from_variable, ConstrainedCircuitMember, ConstrainedProgram, ConstrainedValue,
ConstrainedProgram, ConstrainedValue,
}, },
errors::ExpressionError, errors::ExpressionError,
new_scope, new_scope,
@ -354,14 +353,19 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
cs: &mut CS, cs: &mut CS,
file_scope: String, file_scope: String,
function_scope: String, function_scope: String,
variable: Identifier<F, G>, identifier: Identifier<F, G>,
members: Vec<CircuitFieldDefinition<F, G>>, members: Vec<CircuitFieldDefinition<F, G>>,
) -> Result<ConstrainedValue<F, G>, ExpressionError> { ) -> Result<ConstrainedValue<F, G>, ExpressionError> {
let circuit_name = new_variable_from_variable(file_scope.clone(), &variable); let mut program_identifier = new_scope(file_scope.clone(), identifier.to_string());
if identifier.is_self() {
program_identifier = file_scope.clone();
}
if let Some(ConstrainedValue::CircuitDefinition(circuit_definition)) = if let Some(ConstrainedValue::CircuitDefinition(circuit_definition)) =
self.get_mut_variable(&circuit_name) self.get_mut(&program_identifier)
{ {
let circuit_identifier = circuit_definition.identifier.clone();
let mut resolved_members = vec![]; let mut resolved_members = vec![];
for member in circuit_definition.members.clone().into_iter() { for member in circuit_definition.members.clone().into_iter() {
match member { match member {
@ -395,7 +399,8 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
} }
CircuitMember::CircuitFunction(_static, function) => { CircuitMember::CircuitFunction(_static, function) => {
let identifier = function.function_name.clone(); let identifier = function.function_name.clone();
let mut constrained_function_value = ConstrainedValue::Function(function); let mut constrained_function_value =
ConstrainedValue::Function(Some(circuit_identifier.clone()), function);
if _static { if _static {
constrained_function_value = constrained_function_value =
@ -411,11 +416,11 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
} }
Ok(ConstrainedValue::CircuitExpression( Ok(ConstrainedValue::CircuitExpression(
variable, circuit_identifier.clone(),
resolved_members, resolved_members,
)) ))
} else { } else {
Err(ExpressionError::UndefinedCircuit(variable.to_string())) Err(ExpressionError::UndefinedCircuit(identifier.to_string()))
} }
} }
@ -497,7 +502,10 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
} }
}; };
Ok(ConstrainedValue::Function(function)) Ok(ConstrainedValue::Function(
Some(circuit.identifier),
function,
))
} }
fn enforce_function_call_expression( fn enforce_function_call_expression(
@ -515,12 +523,20 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
*function.clone(), *function.clone(),
)?; )?;
let function_call = match function_value { let (outer_scope, function_call) = match function_value {
ConstrainedValue::Function(function) => function.clone(), ConstrainedValue::Function(circuit_identifier, function) => {
let mut outer_scope = file_scope.clone();
// If this is a circuit function, evaluate inside the circuit scope
if circuit_identifier.is_some() {
outer_scope = new_scope(file_scope, circuit_identifier.unwrap().to_string());
}
(outer_scope, function.clone())
}
value => return Err(ExpressionError::UndefinedFunction(value.to_string())), value => return Err(ExpressionError::UndefinedFunction(value.to_string())),
}; };
match self.enforce_function(cs, file_scope, function_scope, function_call, arguments) { match self.enforce_function(cs, outer_scope, function_scope, function_call, arguments) {
Ok(ConstrainedValue::Return(return_values)) => { Ok(ConstrainedValue::Return(return_values)) => {
if return_values.len() == 1 { if return_values.len() == 1 {
Ok(return_values[0].clone()) Ok(return_values[0].clone())

View File

@ -247,7 +247,10 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
.for_each(|(function_name, function)| { .for_each(|(function_name, function)| {
let resolved_function_name = let resolved_function_name =
new_scope(program_name.to_string(), function_name.to_string()); new_scope(program_name.to_string(), function_name.to_string());
self.store(resolved_function_name, ConstrainedValue::Function(function)); self.store(
resolved_function_name,
ConstrainedValue::Function(None, function),
);
}); });
Ok(()) Ok(())

View File

@ -67,7 +67,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
match matched_function { match matched_function {
Some((_function_name, function)) => { Some((_function_name, function)) => {
ConstrainedValue::Function(function) ConstrainedValue::Function(None, function)
} }
None => unimplemented!( None => unimplemented!(
"cannot find imported symbol {} in imported file {}", "cannot find imported symbol {} in imported file {}",

View File

@ -56,7 +56,7 @@ pub fn generate_constraints<F: Field + PrimeField, G: Group, CS: ConstraintSyste
.ok_or_else(|| CompilerError::NoMain)?; .ok_or_else(|| CompilerError::NoMain)?;
match main.clone() { match main.clone() {
ConstrainedValue::Function(function) => { ConstrainedValue::Function(_circuit_identifier, function) => {
let result = let result =
resolved_program.enforce_main_function(cs, program_name, function, parameters)?; resolved_program.enforce_main_function(cs, program_name, function, parameters)?;
log::debug!("{}", result); log::debug!("{}", result);

View File

@ -65,11 +65,4 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
pub(crate) fn get_mut(&mut self, name: &String) -> Option<&mut ConstrainedValue<F, G>> { pub(crate) fn get_mut(&mut self, name: &String) -> Option<&mut ConstrainedValue<F, G>> {
self.identifiers.get_mut(name) self.identifiers.get_mut(name)
} }
pub(crate) fn get_mut_variable(
&mut self,
variable: &Identifier<F, G>,
) -> Option<&mut ConstrainedValue<F, G>> {
self.get_mut(&variable.name)
}
} }

View File

@ -101,7 +101,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
match matched_field { match matched_field {
Some(object) => match &object.1 { Some(object) => match &object.1 {
ConstrainedValue::Function(function) => { ConstrainedValue::Function(_circuit_identifier, function) => {
return Err(StatementError::ImmutableCircuitFunction( return Err(StatementError::ImmutableCircuitFunction(
function.function_name.to_string(), function.function_name.to_string(),
)) ))
@ -193,6 +193,7 @@ impl<F: Field + PrimeField, G: Group, CS: ConstraintSystem<F>> ConstrainedProgra
variable: Variable<F, G>, variable: Variable<F, G>,
expression: Expression<F, G>, expression: Expression<F, G>,
) -> Result<(), StatementError> { ) -> Result<(), StatementError> {
// println!("evaluating {}", expression);
let value = let value =
self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), expression)?; self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), expression)?;

View File

@ -30,7 +30,7 @@ pub enum ConstrainedValue<F: Field + PrimeField, G: Group> {
CircuitDefinition(Circuit<F, G>), CircuitDefinition(Circuit<F, G>),
CircuitExpression(Identifier<F, G>, Vec<ConstrainedCircuitMember<F, G>>), CircuitExpression(Identifier<F, G>, Vec<ConstrainedCircuitMember<F, G>>),
Function(Function<F, G>), Function(Option<Identifier<F, G>>, Function<F, G>), // (optional circuit identifier, function definition)
Return(Vec<ConstrainedValue<F, G>>), Return(Vec<ConstrainedValue<F, G>>),
Mutable(Box<ConstrainedValue<F, G>>), Mutable(Box<ConstrainedValue<F, G>>),
@ -74,6 +74,17 @@ impl<F: Field + PrimeField, G: Group> ConstrainedValue<F, G> {
)); ));
} }
} }
(
ConstrainedValue::CircuitExpression(ref actual_name, ref _members),
Type::SelfType,
) => {
if Identifier::new("Self".into()) == *actual_name {
return Err(ValueError::CircuitName(
"Self".into(),
actual_name.to_string(),
));
}
}
(ConstrainedValue::Return(ref values), _type) => { (ConstrainedValue::Return(ref values), _type) => {
for value in values { for value in values {
value.expect_type(_type)?; value.expect_type(_type)?;
@ -137,7 +148,9 @@ impl<F: Field + PrimeField, G: Group> fmt::Display for ConstrainedValue<F, G> {
ConstrainedValue::CircuitDefinition(ref _definition) => { ConstrainedValue::CircuitDefinition(ref _definition) => {
unimplemented!("cannot return circuit definition in program") unimplemented!("cannot return circuit definition in program")
} }
ConstrainedValue::Function(ref function) => write!(f, "{}();", function.function_name), ConstrainedValue::Function(ref _circuit_option, ref function) => {
write!(f, "{}();", function.function_name)
}
ConstrainedValue::Mutable(ref value) => write!(f, "mut {}", value), ConstrainedValue::Mutable(ref value) => write!(f, "mut {}", value),
ConstrainedValue::Static(ref value) => write!(f, "static {}", value), ConstrainedValue::Static(ref value) => write!(f, "static {}", value),
} }

View File

@ -30,6 +30,10 @@ impl<F: Field + PrimeField, G: Group> Identifier<F, G> {
_engine: PhantomData::<F>, _engine: PhantomData::<F>,
} }
} }
pub fn is_self(&self) -> bool {
self.name == "Self"
}
} }
/// A variable that is assigned to a value in the constrained program /// A variable that is assigned to a value in the constrained program
@ -188,6 +192,7 @@ pub enum Type<F: Field + PrimeField, G: Group> {
Boolean, Boolean,
Array(Box<Type<F, G>>, Vec<usize>), Array(Box<Type<F, G>>, Vec<usize>),
Circuit(Identifier<F, G>), Circuit(Identifier<F, G>),
SelfType,
} }
impl<F: Field + PrimeField, G: Group> Type<F, G> { impl<F: Field + PrimeField, G: Group> Type<F, G> {

View File

@ -281,6 +281,7 @@ impl<F: Field + PrimeField, G: Group> fmt::Display for Type<F, G> {
Type::GroupElement => write!(f, "group"), Type::GroupElement => write!(f, "group"),
Type::Boolean => write!(f, "bool"), Type::Boolean => write!(f, "bool"),
Type::Circuit(ref variable) => write!(f, "{}", variable), Type::Circuit(ref variable) => write!(f, "{}", variable),
Type::SelfType => write!(f, "Self"),
Type::Array(ref array, ref dimensions) => { Type::Array(ref array, ref dimensions) => {
write!(f, "{}", *array)?; write!(f, "{}", *array)?;
for row in dimensions { for row in dimensions {

View File

@ -656,8 +656,8 @@ impl From<ast::IntegerType> for types::IntegerType {
} }
} }
impl<'ast, F: Field + PrimeField, G: Group> From<ast::BasicType<'ast>> for types::Type<F, G> { impl<F: Field + PrimeField, G: Group> From<ast::BasicType> for types::Type<F, G> {
fn from(basic_type: ast::BasicType<'ast>) -> Self { fn from(basic_type: ast::BasicType) -> Self {
match basic_type { match basic_type {
ast::BasicType::Integer(_type) => { ast::BasicType::Integer(_type) => {
types::Type::IntegerType(types::IntegerType::from(_type)) types::Type::IntegerType(types::IntegerType::from(_type))
@ -694,7 +694,7 @@ impl<'ast, F: Field + PrimeField, G: Group> From<ast::Type<'ast>> for types::Typ
ast::Type::Basic(_type) => types::Type::from(_type), ast::Type::Basic(_type) => types::Type::from(_type),
ast::Type::Array(_type) => types::Type::from(_type), ast::Type::Array(_type) => types::Type::from(_type),
ast::Type::Circuit(_type) => types::Type::from(_type), ast::Type::Circuit(_type) => types::Type::from(_type),
ast::Type::SelfType(_type) => unimplemented!("no Self yet") ast::Type::SelfType(_type) => types::Type::SelfType,
} }
} }
} }