Merge pull request #1113 from AleoHQ/feature/circuit-constant-value

Feature/circuit constant value
This commit is contained in:
Alessandro Coglio 2021-07-08 12:58:41 -07:00 committed by GitHub
commit cf290ed097
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 66 additions and 33 deletions

View File

@ -30,8 +30,6 @@ jobs:
- name: cargo fmt --check
uses: actions-rs/cargo@v1
env:
CARGO_NET_GIT_FETCH_WITH_CLI: true
with:
command: fmt
args: --all -- --check
@ -120,7 +118,6 @@ jobs:
command: test
args: --all --features ci_skip
env:
CARGO_NET_GIT_FETCH_WITH_CLI: true
CARGO_INCREMENTAL: "0"
- name: Install dependencies for code coverage

View File

@ -23,7 +23,7 @@ pub struct ConstantFolding<'a, 'b> {
}
impl<'a, 'b> ExpressionVisitor<'a> for ConstantFolding<'a, 'b> {
fn visit_expression(&mut self, input: &Cell<&Expression<'a>>) -> VisitResult {
fn visit_expression(&mut self, input: &Cell<&'a Expression<'a>>) -> VisitResult {
let expr = input.get();
if let Some(const_value) = expr.const_value() {
let folded_expr = Expression::Constant(Constant {

View File

@ -14,8 +14,9 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{AsgConvertError, IntegerType, Span, Type};
use crate::{AsgConvertError, Circuit, Identifier, IntegerType, Span, Type};
use indexmap::IndexMap;
use num_bigint::BigInt;
use std::{convert::TryInto, fmt};
use tendril::StrTendril;
@ -118,8 +119,8 @@ impl From<leo_ast::CharValue> for CharValue {
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum ConstValue {
#[derive(Clone, PartialEq)]
pub enum ConstValue<'a> {
Int(ConstInt),
Group(GroupValue),
Field(BigInt),
@ -128,8 +129,9 @@ pub enum ConstValue {
Char(CharValue),
// compounds
Tuple(Vec<ConstValue>),
Array(Vec<ConstValue>),
Tuple(Vec<ConstValue<'a>>),
Array(Vec<ConstValue<'a>>),
Circuit(&'a Circuit<'a>, IndexMap<String, (Identifier, ConstValue<'a>)>),
}
macro_rules! const_int_op {
@ -311,8 +313,8 @@ impl ConstInt {
}
}
impl ConstValue {
pub fn get_type<'a>(&self) -> Option<Type<'a>> {
impl<'a> ConstValue<'a> {
pub fn get_type(&'a self) -> Option<Type<'a>> {
Some(match self {
ConstValue::Int(i) => i.get_type(),
ConstValue::Group(_) => Type::Group,
@ -324,6 +326,7 @@ impl ConstValue {
Type::Tuple(sub_consts.iter().map(|x| x.get_type()).collect::<Option<Vec<Type>>>()?)
}
ConstValue::Array(values) => Type::Array(Box::new(values.get(0)?.get_type()?), values.len()),
ConstValue::Circuit(circuit, _) => Type::Circuit(circuit),
})
}

View File

@ -58,7 +58,7 @@ impl<'a> ExpressionNode<'a> for ArrayAccessExpression<'a> {
self.array.get().is_mut_ref()
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let mut array = match self.array.get().const_value()? {
ConstValue::Array(values) => values,
_ => return None,

View File

@ -53,7 +53,7 @@ impl<'a> ExpressionNode<'a> for ArrayInitExpression<'a> {
false
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let element = self.element.get().const_value()?;
Some(ConstValue::Array(vec![element; self.len]))
}

View File

@ -78,7 +78,7 @@ impl<'a> ExpressionNode<'a> for ArrayInlineExpression<'a> {
false
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let mut const_values = vec![];
for (expr, spread) in self.elements.iter() {
if *spread {

View File

@ -70,7 +70,7 @@ impl<'a> ExpressionNode<'a> for ArrayRangeAccessExpression<'a> {
self.array.get().is_mut_ref()
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let mut array = match self.array.get().const_value()? {
ConstValue::Array(values) => values,
_ => return None,
@ -176,6 +176,7 @@ impl<'a> FromAst<'a, leo_ast::ArrayRangeAccessExpression> for ArrayRangeAccessEx
} else {
None
};
if let Some(expected_len) = expected_len {
if let Some(length) = length {
if length != expected_len {

View File

@ -83,8 +83,14 @@ impl<'a> ExpressionNode<'a> for CircuitAccessExpression<'a> {
}
}
fn const_value(&self) -> Option<ConstValue> {
None
fn const_value(&self) -> Option<ConstValue<'a>> {
match self.target.get()?.const_value()? {
ConstValue::Circuit(_, members) => {
let (_, const_value) = members.get(&self.member.name.to_string())?.clone();
Some(const_value)
}
_ => None,
}
}
fn is_consty(&self) -> bool {

View File

@ -70,8 +70,17 @@ impl<'a> ExpressionNode<'a> for CircuitInitExpression<'a> {
true
}
fn const_value(&self) -> Option<ConstValue> {
None
fn const_value(&self) -> Option<ConstValue<'a>> {
let mut members = IndexMap::new();
for (identifier, member) in self.values.iter() {
// insert by name because accessmembers identifiers are different.
members.insert(
identifier.name.to_string(),
(identifier.clone(), member.get().const_value()?),
);
}
// Store circuit as well for get_type.
Some(ConstValue::Circuit(self.circuit.get(), members))
}
fn is_consty(&self) -> bool {

View File

@ -36,7 +36,7 @@ use std::cell::Cell;
pub struct Constant<'a> {
pub parent: Cell<Option<&'a Expression<'a>>>,
pub span: Option<Span>,
pub value: ConstValue, // should not be compound constants
pub value: ConstValue<'a>, // should not be compound constants
}
impl<'a> Node for Constant<'a> {
@ -56,7 +56,7 @@ impl<'a> ExpressionNode<'a> for Constant<'a> {
fn enforce_parents(&self, _expr: &'a Expression<'a>) {}
fn get_type(&self) -> Option<Type<'a>> {
fn get_type(&'a self) -> Option<Type<'a>> {
self.value.get_type()
}
@ -64,7 +64,7 @@ impl<'a> ExpressionNode<'a> for Constant<'a> {
false
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
Some(self.value.clone())
}
@ -267,6 +267,7 @@ impl<'a> Into<leo_ast::ValueExpression> for &Constant<'a> {
),
ConstValue::Tuple(_) => unimplemented!(),
ConstValue::Array(_) => unimplemented!(),
ConstValue::Circuit(_, _) => unimplemented!(),
}
}
}

View File

@ -124,9 +124,9 @@ pub trait ExpressionNode<'a>: Node {
fn get_parent(&self) -> Option<&'a Expression<'a>>;
fn enforce_parents(&self, expr: &'a Expression<'a>);
fn get_type(&self) -> Option<Type<'a>>;
fn get_type(&'a self) -> Option<Type<'a>>;
fn is_mut_ref(&self) -> bool;
fn const_value(&self) -> Option<ConstValue>; // todo: memoize
fn const_value(&'a self) -> Option<ConstValue>; // todo: memoize
fn is_consty(&self) -> bool;
}
@ -194,7 +194,7 @@ impl<'a> ExpressionNode<'a> for Expression<'a> {
}
}
fn get_type(&self) -> Option<Type<'a>> {
fn get_type(&'a self) -> Option<Type<'a>> {
use Expression::*;
match self {
VariableRef(x) => x.get_type(),
@ -236,7 +236,7 @@ impl<'a> ExpressionNode<'a> for Expression<'a> {
}
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&'a self) -> Option<ConstValue<'a>> {
use Expression::*;
match self {
VariableRef(x) => x.const_value(),

View File

@ -56,7 +56,7 @@ impl<'a> ExpressionNode<'a> for TernaryExpression<'a> {
self.if_true.get().is_mut_ref() && self.if_false.get().is_mut_ref()
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
if let Some(ConstValue::Boolean(switch)) = self.condition.get().const_value() {
if switch {
self.if_true.get().const_value()

View File

@ -56,7 +56,7 @@ impl<'a> ExpressionNode<'a> for TupleAccessExpression<'a> {
self.tuple_ref.get().is_mut_ref()
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let tuple_const = self.tuple_ref.get().const_value()?;
match tuple_const {
ConstValue::Tuple(sub_consts) => sub_consts.get(self.index).cloned(),

View File

@ -58,7 +58,7 @@ impl<'a> ExpressionNode<'a> for TupleInitExpression<'a> {
false
}
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let mut consts = vec![];
for element in self.elements.iter() {
if let Some(const_value) = element.get().const_value() {

View File

@ -66,7 +66,7 @@ impl<'a> ExpressionNode<'a> for VariableRef<'a> {
}
// todo: we can use use hacky ssa here to catch more cases, or just enforce ssa before asg generation finished
fn const_value(&self) -> Option<ConstValue> {
fn const_value(&self) -> Option<ConstValue<'a>> {
let variable = self.variable.borrow();
if variable.mutable || variable.assignments.len() != 1 {
return None;

View File

@ -23,7 +23,7 @@ use crate::{
program::ConstrainedProgram,
relational::*,
resolve_core_circuit,
value::{Address, Char, CharType, ConstrainedValue, Integer},
value::{Address, Char, CharType, ConstrainedCircuitMember, ConstrainedValue, Integer},
FieldType,
GroupType,
};
@ -37,7 +37,7 @@ impl<'a, F: PrimeField, G: GroupType<F>> ConstrainedProgram<'a, F, G> {
pub(crate) fn enforce_const_value<CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
value: &ConstValue,
value: &'a ConstValue<'a>,
span: &Span,
) -> Result<ConstrainedValue<'a, F, G>, ExpressionError> {
Ok(match value {
@ -75,6 +75,17 @@ impl<'a, F: PrimeField, G: GroupType<F>> ConstrainedProgram<'a, F, G> {
.map(|x| self.enforce_const_value(cs, x, span))
.collect::<Result<Vec<_>, _>>()?,
),
ConstValue::Circuit(circuit, members) => {
let mut constrained_members = Vec::new();
for (_, (identifier, member)) in members.iter() {
constrained_members.push(ConstrainedCircuitMember(
identifier.clone(),
self.enforce_const_value(cs, member, span)?,
));
}
ConstrainedValue::CircuitExpression(circuit, constrained_members)
}
})
}

View File

@ -5,6 +5,10 @@ input_file:
- input/complex_access.in
*/
circuit Circ {
f: u32
}
function main (a: [u8; 8], b: u32, c: [[u8; 3]; 3], d: [(u8, u32); 1], e: [u8; (3, 4)] ) -> bool {
a[0..3][b] = 93;
a[2..6][1] = 87;
@ -16,6 +20,7 @@ function main (a: [u8; 8], b: u32, c: [[u8; 3]; 3], d: [(u8, u32); 1], e: [u8; (
c[0..2][0] = [1u8; 3];
c[1..][1][1..2][0] = 126;
c[1..][0] = [42, 43, 44];
c[Circ {f: 0}.f..1][0][0] += 2;
d[..1][0].1 = 1;
@ -24,7 +29,7 @@ function main (a: [u8; 8], b: u32, c: [[u8; 3]; 3], d: [(u8, u32); 1], e: [u8; (
return
a == [200u8, 93, 42, 174, 5, 6, 43, 8]
&& c == [[1u8, 1, 1], [42, 43, 44], [7, 126, 9]]
&& c == [[3u8, 1, 1], [42, 43, 44], [7, 126, 9]]
&& d == [(0u8, 1u32)]
&& e == [[33u8, 22, 22, 22], [0, 0, 0, 0], [0, 0, 0, 0]];
}