enforce function return type

This commit is contained in:
collin 2020-04-22 13:26:52 -07:00
parent fa37bc4a40
commit 1e5c2a7ef9
5 changed files with 60 additions and 14 deletions

View File

@ -1,5 +1,5 @@
from "./simple_import" import Point
def foo() -> (fe):
return 3fe
def main() -> (Point):
Point p = Point { x: 1, y: 2}
return p
def main() -> (fe):
return foo()

View File

@ -113,8 +113,12 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
);
}
Statement::Return(expressions) => {
return_values =
self.enforce_return_statement(cs, function.get_name(), expressions)
return_values = self.enforce_return_statement(
cs,
function.get_name(),
expressions,
function.returns.to_owned(),
)
}
});

View File

@ -4,7 +4,7 @@
//! @author Collin Chin <collin@aleo.org>
//! @date 2020
use crate::program::types::{Function, Struct, StructMember, Variable};
use crate::program::types::{Function, Struct, StructMember, Type, Variable};
use snarkos_models::curves::{Field, PrimeField};
use snarkos_models::gadgets::{utilities::boolean::Boolean, utilities::uint32::UInt32};
@ -24,6 +24,33 @@ pub enum ResolvedValue<F: Field + PrimeField> {
Return(Vec<ResolvedValue<F>>), // add Null for function returns
}
impl<F: Field + PrimeField> ResolvedValue<F> {
pub(crate) fn match_type(&self, ty: &Type<F>) -> bool {
match (self, ty) {
(ResolvedValue::U32(ref _a), Type::U32) => true,
(ResolvedValue::U32Array(ref arr), Type::Array(ref arr_type, ref len)) => {
(arr.len() == *len) & (**arr_type == Type::U32)
}
(ResolvedValue::FieldElement(ref _a), Type::FieldElement) => true,
(ResolvedValue::FieldElementArray(ref arr), Type::Array(ref arr_type, ref len)) => {
(arr.len() == *len) & (**arr_type == Type::FieldElement)
}
(ResolvedValue::Boolean(ref _a), Type::Boolean) => true,
(ResolvedValue::BooleanArray(ref arr), Type::Array(ref arr_type, ref len)) => {
(arr.len() == *len) & (**arr_type == Type::Boolean)
}
(ResolvedValue::Return(ref values), ty) => {
let mut res = true;
for value in values {
res &= value.match_type(ty)
}
res
}
(_, _) => false,
}
}
}
impl<F: Field + PrimeField> fmt::Display for ResolvedValue<F> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {

View File

@ -6,7 +6,7 @@
use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue};
use crate::program::{
Assignee, Expression, IntegerExpression, IntegerRangeOrExpression, Statement, Variable,
Assignee, Expression, IntegerExpression, IntegerRangeOrExpression, Statement, Type, Variable,
};
use snarkos_models::curves::{Field, PrimeField};
@ -148,16 +148,31 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
cs: &mut CS,
scope: String,
statements: Vec<Expression<F>>,
return_types: Vec<Type<F>>,
) -> ResolvedValue<F> {
ResolvedValue::Return(
statements
.into_iter()
.map(|expression| self.enforce_expression(cs, scope.clone(), expression))
.zip(return_types.into_iter())
.map(|(expression, ty)| {
let result = self.enforce_expression(cs, scope.clone(), expression);
if !result.match_type(&ty) {
unimplemented!("expected return type {}, got {}", ty, result)
} else {
result
}
})
.collect::<Vec<ResolvedValue<F>>>(),
)
}
fn enforce_statement(&mut self, cs: &mut CS, scope: String, statement: Statement<F>) {
fn enforce_statement(
&mut self,
cs: &mut CS,
scope: String,
statement: Statement<F>,
return_types: Vec<Type<F>>,
) {
match statement {
Statement::Definition(variable, expression) => {
self.enforce_definition_statement(cs, scope, variable, expression);
@ -167,7 +182,7 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
}
Statement::Return(statements) => {
// TODO: add support for early termination
let _res = self.enforce_return_statement(cs, scope, statements);
let _res = self.enforce_return_statement(cs, scope, statements, return_types);
}
};
}
@ -190,11 +205,11 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
let index_name = new_scope_from_variable(scope.clone(), &index);
self.store(index_name, ResolvedValue::U32(UInt32::constant(i as u32)));
// Evaluate statements
// Evaluate statements (for loop statements should not have a return type)
statements
.clone()
.into_iter()
.for_each(|statement| self.enforce_statement(cs, scope.clone(), statement));
.for_each(|statement| self.enforce_statement(cs, scope.clone(), statement, vec![]));
}
}
}

View File

@ -160,7 +160,7 @@ pub enum Statement<F: Field + PrimeField> {
}
/// Explicit type used for defining struct members and function parameters
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub enum Type<F: Field + PrimeField> {
U32,
FieldElement,