From 915606880166608edcc904bddcf2a79bf8919d55 Mon Sep 17 00:00:00 2001 From: Protryon Date: Sun, 7 Mar 2021 05:45:00 -0800 Subject: [PATCH] constant folding --- Cargo.lock | 8 ++++ Cargo.toml | 1 + asg-passes/constant-folding/Cargo.toml | 25 ++++++++++++ asg-passes/constant-folding/src/lib.rs | 56 ++++++++++++++++++++++++++ asg/src/expression/variable_ref.rs | 8 +++- asg/src/lib.rs | 4 +- asg/src/pass.rs | 4 +- asg/src/reducer/visitor_director.rs | 9 ++--- compiler/Cargo.toml | 4 ++ compiler/src/compiler.rs | 29 +++++++++---- compiler/src/errors/compiler.rs | 5 ++- compiler/src/expression/expression.rs | 38 ++++++++++++----- compiler/src/lib.rs | 3 ++ compiler/src/option.rs | 28 +++++++++++++ 14 files changed, 193 insertions(+), 29 deletions(-) create mode 100644 asg-passes/constant-folding/Cargo.toml create mode 100644 asg-passes/constant-folding/src/lib.rs create mode 100644 compiler/src/option.rs diff --git a/Cargo.lock b/Cargo.lock index 1eaf8481b4..68369879e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1226,6 +1226,7 @@ dependencies = [ "indexmap", "leo-asg", "leo-ast", + "leo-constant-folding", "leo-gadgets", "leo-imports", "leo-input", @@ -1250,6 +1251,13 @@ dependencies = [ "tracing", ] +[[package]] +name = "leo-constant-folding" +version = "1.2.3" +dependencies = [ + "leo-asg", +] + [[package]] name = "leo-gadgets" version = "1.2.3" diff --git a/Cargo.toml b/Cargo.toml index 9efc7f6ecf..d86af3b228 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ members = [ "parser", "state", "synthesizer", + "asg-passes/constant-folding", ] [dependencies.leo-ast] diff --git a/asg-passes/constant-folding/Cargo.toml b/asg-passes/constant-folding/Cargo.toml new file mode 100644 index 0000000000..7948613283 --- /dev/null +++ b/asg-passes/constant-folding/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "leo-constant-folding" +version = "1.2.3" +authors = [ "The Aleo Team " ] +description = "The Leo programming language" +homepage = "https://aleo.org" +repository = "https://github.com/AleoHQ/leo" +keywords = [ + "aleo", + "cryptography", + "leo", + "programming-language", + "zero-knowledge" +] +categories = [ "cryptography::cryptocurrencies", "web-programming" ] +include = [ "Cargo.toml", "leo", "README.md", "LICENSE.md" ] +license = "GPL-3.0" +edition = "2018" + +[lib] +path = "src/lib.rs" + +[dependencies.leo-asg] +path = "../../asg" +version = "1.2.0" diff --git a/asg-passes/constant-folding/src/lib.rs b/asg-passes/constant-folding/src/lib.rs new file mode 100644 index 0000000000..141dab4f2c --- /dev/null +++ b/asg-passes/constant-folding/src/lib.rs @@ -0,0 +1,56 @@ +// Copyright (C) 2019-2021 Aleo Systems Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use std::{cell::Cell}; + +use leo_asg::*; + +pub struct ConstantFolding<'a, 'b> { + program: &'b Program<'a> +} + +impl<'a, 'b> ExpressionVisitor<'a> for ConstantFolding<'a, 'b> { + fn visit_expression(&mut self, input: &Cell<&Expression<'a>>) -> VisitResult { + let expr = input.get(); + if let Some(const_value) = expr.const_value() { + let folded_expr = Expression::Constant(Constant { + parent: Cell::new(expr.get_parent()), + span: expr.span().cloned(), + value: const_value, + }); + let folded_expr = self.program.scope.alloc_expression(folded_expr); + input.set(folded_expr); + VisitResult::SkipChildren + } else { + VisitResult::VisitChildren + } + } +} + +impl<'a, 'b> StatementVisitor<'a> for ConstantFolding<'a, 'b> {} + +impl<'a, 'b> ProgramVisitor<'a> for ConstantFolding<'a, 'b> {} + +impl<'a, 'b> AsgPass<'a> for ConstantFolding<'a, 'b> { + fn do_pass(asg: &Program<'a>) -> Result<(), FormattedError> { + let pass = ConstantFolding { + program: asg, + }; + let mut director = VisitorDirector::new(pass); + director.visit_program(asg).ok(); + Ok(()) + } +} diff --git a/asg/src/expression/variable_ref.rs b/asg/src/expression/variable_ref.rs index d702caa295..48022d843c 100644 --- a/asg/src/expression/variable_ref.rs +++ b/asg/src/expression/variable_ref.rs @@ -80,10 +80,14 @@ impl<'a> ExpressionNode<'a> for VariableRef<'a> { value.get().const_value() } else { - for defined_variable in variables.iter() { + for (i, defined_variable) in variables.iter().enumerate() { let defined_variable = defined_variable.borrow(); if defined_variable.id == variable.id { - return value.get().const_value(); + match value.get().const_value() { + Some(ConstValue::Tuple(values)) => return values.get(i).cloned(), + None => return None, + _ => (), + } } } panic!("no corresponding tuple variable found during const destructuring (corrupt asg?)"); diff --git a/asg/src/lib.rs b/asg/src/lib.rs index 098abf569a..be0de56ce9 100644 --- a/asg/src/lib.rs +++ b/asg/src/lib.rs @@ -102,8 +102,8 @@ impl<'a> Asg<'a> { } /// Returns the internal program ASG representation. - pub fn as_repr(&self) -> Program<'a> { - self.asg.clone() + pub fn as_repr(&self) -> &Program<'a> { + &self.asg } // /// Serializes the ast into a JSON string. diff --git a/asg/src/pass.rs b/asg/src/pass.rs index 9baa4f2a5a..7603b8f1ff 100644 --- a/asg/src/pass.rs +++ b/asg/src/pass.rs @@ -17,6 +17,6 @@ use crate::Program; pub use leo_ast::FormattedError; -pub trait AsgPass { - fn do_pass(asg: &Program) -> Result<(), FormattedError>; +pub trait AsgPass<'a> { + fn do_pass(asg: &Program<'a>) -> Result<(), FormattedError>; } diff --git a/asg/src/reducer/visitor_director.rs b/asg/src/reducer/visitor_director.rs index 83ac86072c..4371671494 100644 --- a/asg/src/reducer/visitor_director.rs +++ b/asg/src/reducer/visitor_director.rs @@ -388,9 +388,8 @@ impl<'a, R: StatementVisitor<'a>> VisitorDirector<'a, R> { } } -#[allow(dead_code)] impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> { - fn visit_function(&mut self, input: &'a Function<'a>) -> ConcreteVisitResult { + pub fn visit_function(&mut self, input: &'a Function<'a>) -> ConcreteVisitResult { match self.visitor.visit_function(input) { VisitResult::VisitChildren => { self.visit_opt_statement(&input.body)?; @@ -400,7 +399,7 @@ impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> { } } - fn visit_circuit_member(&mut self, input: &CircuitMember<'a>) -> ConcreteVisitResult { + pub fn visit_circuit_member(&mut self, input: &CircuitMember<'a>) -> ConcreteVisitResult { match self.visitor.visit_circuit_member(input) { VisitResult::VisitChildren => { if let CircuitMember::Function(f) = input { @@ -412,7 +411,7 @@ impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> { } } - fn visit_circuit(&mut self, input: &'a Circuit<'a>) -> ConcreteVisitResult { + pub fn visit_circuit(&mut self, input: &'a Circuit<'a>) -> ConcreteVisitResult { match self.visitor.visit_circuit(input) { VisitResult::VisitChildren => { for (_, member) in input.members.borrow().iter() { @@ -424,7 +423,7 @@ impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> { } } - fn visit_program(&mut self, input: &Program<'a>) -> ConcreteVisitResult { + pub fn visit_program(&mut self, input: &Program<'a>) -> ConcreteVisitResult { match self.visitor.visit_program(input) { VisitResult::VisitChildren => { for (_, import) in input.imported_modules.iter() { diff --git a/compiler/Cargo.toml b/compiler/Cargo.toml index 663f987286..b0cf389dc5 100644 --- a/compiler/Cargo.toml +++ b/compiler/Cargo.toml @@ -49,6 +49,10 @@ version = "1.2.3" path = "../parser" version = "1.2.3" +[dependencies.leo-constant-folding] +path = "../asg-passes/constant-folding" +version = "1.2.3" + [dependencies.snarkvm-curves] version = "0.2.0" default-features = false diff --git a/compiler/src/compiler.rs b/compiler/src/compiler.rs index c5f2a3b880..0df4669878 100644 --- a/compiler/src/compiler.rs +++ b/compiler/src/compiler.rs @@ -16,15 +16,9 @@ //! Compiles a Leo program from a file path. -use crate::{ - constraints::{generate_constraints, generate_test_constraints}, - errors::CompilerError, - GroupType, - OutputBytes, - OutputFile, -}; +use crate::{CompilerOptions, GroupType, OutputBytes, OutputFile, constraints::{generate_constraints, generate_test_constraints}, errors::CompilerError}; use indexmap::IndexMap; -use leo_asg::Asg; +use leo_asg::{Asg, AsgPass, FormattedError}; pub use leo_asg::{new_context, AsgContext as Context, AsgContext}; use leo_ast::{Input, LeoError, MainInput, Program}; use leo_input::LeoInputParser; @@ -68,6 +62,7 @@ pub struct Compiler<'a, F: PrimeField, G: GroupType> { context: AsgContext<'a>, asg: Option>, file_contents: RefCell>>>, + options: CompilerOptions, _engine: PhantomData, _group: PhantomData, } @@ -90,6 +85,7 @@ impl<'a, F: PrimeField, G: GroupType> Compiler<'a, F, G> { program_input: Input::new(), asg: None, context, + options: CompilerOptions::default(), file_contents: RefCell::new(IndexMap::new()), _engine: PhantomData, _group: PhantomData, @@ -116,6 +112,10 @@ impl<'a, F: PrimeField, G: GroupType> Compiler<'a, F, G> { Ok(compiler) } + pub fn set_options(&mut self, options: CompilerOptions) { + self.options = options; + } + /// /// Returns a new `Compiler` from the given main file path. /// @@ -251,10 +251,21 @@ impl<'a, F: PrimeField, G: GroupType> Compiler<'a, F, G> { Ok(()) } + fn do_asg_passes(&self) -> Result<(), FormattedError> { + assert!(self.asg.is_some()); + if self.options.constant_folding_enabled { + leo_constant_folding::ConstantFolding::do_pass(self.asg.as_ref().unwrap().as_repr())?; + } + + Ok(()) + } + /// /// Synthesizes the circuit with program input to verify correctness. /// pub fn compile_constraints>(&self, cs: &mut CS) -> Result { + self.do_asg_passes().map_err(CompilerError::AsgPassError)?; + generate_constraints::(cs, &self.asg.as_ref().unwrap(), &self.program_input).map_err(|mut error| { if let Some(path) = error.get_path().map(|x| x.to_string()) { let content = match self.resolve_content(&path) { @@ -271,6 +282,8 @@ impl<'a, F: PrimeField, G: GroupType> Compiler<'a, F, G> { /// Synthesizes the circuit for test functions with program input. /// pub fn compile_test_constraints(self, input_pairs: InputPairs) -> Result<(u32, u32), CompilerError> { + self.do_asg_passes().map_err(CompilerError::AsgPassError)?; + generate_test_constraints::( &self.asg.as_ref().unwrap(), input_pairs, diff --git a/compiler/src/errors/compiler.rs b/compiler/src/errors/compiler.rs index 23a99a96f2..be76b88c7f 100644 --- a/compiler/src/errors/compiler.rs +++ b/compiler/src/errors/compiler.rs @@ -15,7 +15,7 @@ // along with the Leo library. If not, see . use crate::errors::{FunctionError, ImportError, OutputBytesError, OutputFileError}; -use leo_asg::AsgConvertError; +use leo_asg::{AsgConvertError, FormattedError}; use leo_ast::LeoError; use leo_imports::ImportParserError; use leo_input::InputParserError; @@ -30,6 +30,9 @@ pub enum CompilerError { #[error("{}", _0)] SyntaxError(#[from] SyntaxError), + #[error("{}", _0)] + AsgPassError(FormattedError), + #[error("{}", _0)] ImportError(#[from] ImportError), diff --git a/compiler/src/expression/expression.rs b/compiler/src/expression/expression.rs index 45f545daa8..5aeec980df 100644 --- a/compiler/src/expression/expression.rs +++ b/compiler/src/expression/expression.rs @@ -27,13 +27,40 @@ use crate::{ FieldType, GroupType, }; -use leo_asg::{expression::*, ConstValue, Expression, Node}; +use leo_asg::{ConstValue, Expression, Node, Span, expression::*}; use snarkvm_fields::PrimeField; use snarkvm_gadgets::traits::utilities::boolean::Boolean; use snarkvm_r1cs::ConstraintSystem; impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { + pub(crate) fn enforce_const_value>( + &mut self, + cs: &mut CS, + value: &ConstValue, + span: &Span, + ) -> Result, ExpressionError> { + Ok(match value { + ConstValue::Address(value) => ConstrainedValue::Address(Address::constant(value.clone(), span)?), + ConstValue::Boolean(value) => ConstrainedValue::Boolean(Boolean::Constant(*value)), + ConstValue::Field(value) => ConstrainedValue::Field(FieldType::constant(value.to_string(), span)?), + ConstValue::Group(value) => ConstrainedValue::Group(G::constant(value, span)?), + ConstValue::Int(value) => ConstrainedValue::Integer(Integer::new(value)), + ConstValue::Tuple(values) => + ConstrainedValue::Tuple( + values.iter() + .map(|x| self.enforce_const_value(cs, x, span)) + .collect::, _>>()? + ), + ConstValue::Array(values) => + ConstrainedValue::Array( + values.iter() + .map(|x| self.enforce_const_value(cs, x, span)) + .collect::, _>>()? + ), + }) + } + pub(crate) fn enforce_expression>( &mut self, cs: &mut CS, @@ -49,14 +76,7 @@ impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { // Values Expression::Constant(Constant { value, .. }) => { - Ok(match value { - ConstValue::Address(value) => ConstrainedValue::Address(Address::constant(value.clone(), span)?), - ConstValue::Boolean(value) => ConstrainedValue::Boolean(Boolean::Constant(*value)), - ConstValue::Field(value) => ConstrainedValue::Field(FieldType::constant(value.to_string(), span)?), - ConstValue::Group(value) => ConstrainedValue::Group(G::constant(value, span)?), - ConstValue::Int(value) => ConstrainedValue::Integer(Integer::new(value)), - ConstValue::Tuple(_) | ConstValue::Array(_) => unimplemented!(), // shouldnt be in the asg here - }) + self.enforce_const_value(cs, value, span) } // Binary operations diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 626733ba3d..d97d1412f5 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -58,3 +58,6 @@ pub use self::value::*; pub mod stage; pub use self::stage::*; + +pub mod option; +pub use self::option::*; diff --git a/compiler/src/option.rs b/compiler/src/option.rs new file mode 100644 index 0000000000..989f574c00 --- /dev/null +++ b/compiler/src/option.rs @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2021 Aleo Systems Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +#[derive(Clone)] +pub struct CompilerOptions { + pub constant_folding_enabled: bool, +} + +impl Default for CompilerOptions { + fn default() -> Self { + CompilerOptions { + constant_folding_enabled: true, + } + } +} \ No newline at end of file