mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-12-24 02:31:44 +03:00
constant folding
This commit is contained in:
parent
d4ed69830a
commit
9156068801
8
Cargo.lock
generated
8
Cargo.lock
generated
@ -1226,6 +1226,7 @@ dependencies = [
|
||||
"indexmap",
|
||||
"leo-asg",
|
||||
"leo-ast",
|
||||
"leo-constant-folding",
|
||||
"leo-gadgets",
|
||||
"leo-imports",
|
||||
"leo-input",
|
||||
@ -1250,6 +1251,13 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "leo-constant-folding"
|
||||
version = "1.2.3"
|
||||
dependencies = [
|
||||
"leo-asg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "leo-gadgets"
|
||||
version = "1.2.3"
|
||||
|
@ -37,6 +37,7 @@ members = [
|
||||
"parser",
|
||||
"state",
|
||||
"synthesizer",
|
||||
"asg-passes/constant-folding",
|
||||
]
|
||||
|
||||
[dependencies.leo-ast]
|
||||
|
25
asg-passes/constant-folding/Cargo.toml
Normal file
25
asg-passes/constant-folding/Cargo.toml
Normal file
@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "leo-constant-folding"
|
||||
version = "1.2.3"
|
||||
authors = [ "The Aleo Team <hello@aleo.org>" ]
|
||||
description = "The Leo programming language"
|
||||
homepage = "https://aleo.org"
|
||||
repository = "https://github.com/AleoHQ/leo"
|
||||
keywords = [
|
||||
"aleo",
|
||||
"cryptography",
|
||||
"leo",
|
||||
"programming-language",
|
||||
"zero-knowledge"
|
||||
]
|
||||
categories = [ "cryptography::cryptocurrencies", "web-programming" ]
|
||||
include = [ "Cargo.toml", "leo", "README.md", "LICENSE.md" ]
|
||||
license = "GPL-3.0"
|
||||
edition = "2018"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies.leo-asg]
|
||||
path = "../../asg"
|
||||
version = "1.2.0"
|
56
asg-passes/constant-folding/src/lib.rs
Normal file
56
asg-passes/constant-folding/src/lib.rs
Normal file
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2019-2021 Aleo Systems Inc.
|
||||
// This file is part of the Leo library.
|
||||
|
||||
// The Leo library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
|
||||
// The Leo library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
use std::{cell::Cell};
|
||||
|
||||
use leo_asg::*;
|
||||
|
||||
pub struct ConstantFolding<'a, 'b> {
|
||||
program: &'b Program<'a>
|
||||
}
|
||||
|
||||
impl<'a, 'b> ExpressionVisitor<'a> for ConstantFolding<'a, 'b> {
|
||||
fn visit_expression(&mut self, input: &Cell<&Expression<'a>>) -> VisitResult {
|
||||
let expr = input.get();
|
||||
if let Some(const_value) = expr.const_value() {
|
||||
let folded_expr = Expression::Constant(Constant {
|
||||
parent: Cell::new(expr.get_parent()),
|
||||
span: expr.span().cloned(),
|
||||
value: const_value,
|
||||
});
|
||||
let folded_expr = self.program.scope.alloc_expression(folded_expr);
|
||||
input.set(folded_expr);
|
||||
VisitResult::SkipChildren
|
||||
} else {
|
||||
VisitResult::VisitChildren
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> StatementVisitor<'a> for ConstantFolding<'a, 'b> {}
|
||||
|
||||
impl<'a, 'b> ProgramVisitor<'a> for ConstantFolding<'a, 'b> {}
|
||||
|
||||
impl<'a, 'b> AsgPass<'a> for ConstantFolding<'a, 'b> {
|
||||
fn do_pass(asg: &Program<'a>) -> Result<(), FormattedError> {
|
||||
let pass = ConstantFolding {
|
||||
program: asg,
|
||||
};
|
||||
let mut director = VisitorDirector::new(pass);
|
||||
director.visit_program(asg).ok();
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -80,10 +80,14 @@ impl<'a> ExpressionNode<'a> for VariableRef<'a> {
|
||||
|
||||
value.get().const_value()
|
||||
} else {
|
||||
for defined_variable in variables.iter() {
|
||||
for (i, defined_variable) in variables.iter().enumerate() {
|
||||
let defined_variable = defined_variable.borrow();
|
||||
if defined_variable.id == variable.id {
|
||||
return value.get().const_value();
|
||||
match value.get().const_value() {
|
||||
Some(ConstValue::Tuple(values)) => return values.get(i).cloned(),
|
||||
None => return None,
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!("no corresponding tuple variable found during const destructuring (corrupt asg?)");
|
||||
|
@ -102,8 +102,8 @@ impl<'a> Asg<'a> {
|
||||
}
|
||||
|
||||
/// Returns the internal program ASG representation.
|
||||
pub fn as_repr(&self) -> Program<'a> {
|
||||
self.asg.clone()
|
||||
pub fn as_repr(&self) -> &Program<'a> {
|
||||
&self.asg
|
||||
}
|
||||
|
||||
// /// Serializes the ast into a JSON string.
|
||||
|
@ -17,6 +17,6 @@
|
||||
use crate::Program;
|
||||
pub use leo_ast::FormattedError;
|
||||
|
||||
pub trait AsgPass {
|
||||
fn do_pass(asg: &Program) -> Result<(), FormattedError>;
|
||||
pub trait AsgPass<'a> {
|
||||
fn do_pass(asg: &Program<'a>) -> Result<(), FormattedError>;
|
||||
}
|
||||
|
@ -388,9 +388,8 @@ impl<'a, R: StatementVisitor<'a>> VisitorDirector<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> {
|
||||
fn visit_function(&mut self, input: &'a Function<'a>) -> ConcreteVisitResult {
|
||||
pub fn visit_function(&mut self, input: &'a Function<'a>) -> ConcreteVisitResult {
|
||||
match self.visitor.visit_function(input) {
|
||||
VisitResult::VisitChildren => {
|
||||
self.visit_opt_statement(&input.body)?;
|
||||
@ -400,7 +399,7 @@ impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_circuit_member(&mut self, input: &CircuitMember<'a>) -> ConcreteVisitResult {
|
||||
pub fn visit_circuit_member(&mut self, input: &CircuitMember<'a>) -> ConcreteVisitResult {
|
||||
match self.visitor.visit_circuit_member(input) {
|
||||
VisitResult::VisitChildren => {
|
||||
if let CircuitMember::Function(f) = input {
|
||||
@ -412,7 +411,7 @@ impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_circuit(&mut self, input: &'a Circuit<'a>) -> ConcreteVisitResult {
|
||||
pub fn visit_circuit(&mut self, input: &'a Circuit<'a>) -> ConcreteVisitResult {
|
||||
match self.visitor.visit_circuit(input) {
|
||||
VisitResult::VisitChildren => {
|
||||
for (_, member) in input.members.borrow().iter() {
|
||||
@ -424,7 +423,7 @@ impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_program(&mut self, input: &Program<'a>) -> ConcreteVisitResult {
|
||||
pub fn visit_program(&mut self, input: &Program<'a>) -> ConcreteVisitResult {
|
||||
match self.visitor.visit_program(input) {
|
||||
VisitResult::VisitChildren => {
|
||||
for (_, import) in input.imported_modules.iter() {
|
||||
|
@ -49,6 +49,10 @@ version = "1.2.3"
|
||||
path = "../parser"
|
||||
version = "1.2.3"
|
||||
|
||||
[dependencies.leo-constant-folding]
|
||||
path = "../asg-passes/constant-folding"
|
||||
version = "1.2.3"
|
||||
|
||||
[dependencies.snarkvm-curves]
|
||||
version = "0.2.0"
|
||||
default-features = false
|
||||
|
@ -16,15 +16,9 @@
|
||||
|
||||
//! Compiles a Leo program from a file path.
|
||||
|
||||
use crate::{
|
||||
constraints::{generate_constraints, generate_test_constraints},
|
||||
errors::CompilerError,
|
||||
GroupType,
|
||||
OutputBytes,
|
||||
OutputFile,
|
||||
};
|
||||
use crate::{CompilerOptions, GroupType, OutputBytes, OutputFile, constraints::{generate_constraints, generate_test_constraints}, errors::CompilerError};
|
||||
use indexmap::IndexMap;
|
||||
use leo_asg::Asg;
|
||||
use leo_asg::{Asg, AsgPass, FormattedError};
|
||||
pub use leo_asg::{new_context, AsgContext as Context, AsgContext};
|
||||
use leo_ast::{Input, LeoError, MainInput, Program};
|
||||
use leo_input::LeoInputParser;
|
||||
@ -68,6 +62,7 @@ pub struct Compiler<'a, F: PrimeField, G: GroupType<F>> {
|
||||
context: AsgContext<'a>,
|
||||
asg: Option<Asg<'a>>,
|
||||
file_contents: RefCell<IndexMap<String, Rc<Vec<String>>>>,
|
||||
options: CompilerOptions,
|
||||
_engine: PhantomData<F>,
|
||||
_group: PhantomData<G>,
|
||||
}
|
||||
@ -90,6 +85,7 @@ impl<'a, F: PrimeField, G: GroupType<F>> Compiler<'a, F, G> {
|
||||
program_input: Input::new(),
|
||||
asg: None,
|
||||
context,
|
||||
options: CompilerOptions::default(),
|
||||
file_contents: RefCell::new(IndexMap::new()),
|
||||
_engine: PhantomData,
|
||||
_group: PhantomData,
|
||||
@ -116,6 +112,10 @@ impl<'a, F: PrimeField, G: GroupType<F>> Compiler<'a, F, G> {
|
||||
Ok(compiler)
|
||||
}
|
||||
|
||||
pub fn set_options(&mut self, options: CompilerOptions) {
|
||||
self.options = options;
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns a new `Compiler` from the given main file path.
|
||||
///
|
||||
@ -251,10 +251,21 @@ impl<'a, F: PrimeField, G: GroupType<F>> Compiler<'a, F, G> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn do_asg_passes(&self) -> Result<(), FormattedError> {
|
||||
assert!(self.asg.is_some());
|
||||
if self.options.constant_folding_enabled {
|
||||
leo_constant_folding::ConstantFolding::do_pass(self.asg.as_ref().unwrap().as_repr())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
/// Synthesizes the circuit with program input to verify correctness.
|
||||
///
|
||||
pub fn compile_constraints<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> Result<OutputBytes, CompilerError> {
|
||||
self.do_asg_passes().map_err(CompilerError::AsgPassError)?;
|
||||
|
||||
generate_constraints::<F, G, CS>(cs, &self.asg.as_ref().unwrap(), &self.program_input).map_err(|mut error| {
|
||||
if let Some(path) = error.get_path().map(|x| x.to_string()) {
|
||||
let content = match self.resolve_content(&path) {
|
||||
@ -271,6 +282,8 @@ impl<'a, F: PrimeField, G: GroupType<F>> Compiler<'a, F, G> {
|
||||
/// Synthesizes the circuit for test functions with program input.
|
||||
///
|
||||
pub fn compile_test_constraints(self, input_pairs: InputPairs) -> Result<(u32, u32), CompilerError> {
|
||||
self.do_asg_passes().map_err(CompilerError::AsgPassError)?;
|
||||
|
||||
generate_test_constraints::<F, G>(
|
||||
&self.asg.as_ref().unwrap(),
|
||||
input_pairs,
|
||||
|
@ -15,7 +15,7 @@
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
use crate::errors::{FunctionError, ImportError, OutputBytesError, OutputFileError};
|
||||
use leo_asg::AsgConvertError;
|
||||
use leo_asg::{AsgConvertError, FormattedError};
|
||||
use leo_ast::LeoError;
|
||||
use leo_imports::ImportParserError;
|
||||
use leo_input::InputParserError;
|
||||
@ -30,6 +30,9 @@ pub enum CompilerError {
|
||||
#[error("{}", _0)]
|
||||
SyntaxError(#[from] SyntaxError),
|
||||
|
||||
#[error("{}", _0)]
|
||||
AsgPassError(FormattedError),
|
||||
|
||||
#[error("{}", _0)]
|
||||
ImportError(#[from] ImportError),
|
||||
|
||||
|
@ -27,13 +27,40 @@ use crate::{
|
||||
FieldType,
|
||||
GroupType,
|
||||
};
|
||||
use leo_asg::{expression::*, ConstValue, Expression, Node};
|
||||
use leo_asg::{ConstValue, Expression, Node, Span, expression::*};
|
||||
|
||||
use snarkvm_fields::PrimeField;
|
||||
use snarkvm_gadgets::traits::utilities::boolean::Boolean;
|
||||
use snarkvm_r1cs::ConstraintSystem;
|
||||
|
||||
impl<'a, F: PrimeField, G: GroupType<F>> ConstrainedProgram<'a, F, G> {
|
||||
pub(crate) fn enforce_const_value<CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
value: &ConstValue,
|
||||
span: &Span,
|
||||
) -> Result<ConstrainedValue<'a, F, G>, ExpressionError> {
|
||||
Ok(match value {
|
||||
ConstValue::Address(value) => ConstrainedValue::Address(Address::constant(value.clone(), span)?),
|
||||
ConstValue::Boolean(value) => ConstrainedValue::Boolean(Boolean::Constant(*value)),
|
||||
ConstValue::Field(value) => ConstrainedValue::Field(FieldType::constant(value.to_string(), span)?),
|
||||
ConstValue::Group(value) => ConstrainedValue::Group(G::constant(value, span)?),
|
||||
ConstValue::Int(value) => ConstrainedValue::Integer(Integer::new(value)),
|
||||
ConstValue::Tuple(values) =>
|
||||
ConstrainedValue::Tuple(
|
||||
values.iter()
|
||||
.map(|x| self.enforce_const_value(cs, x, span))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
),
|
||||
ConstValue::Array(values) =>
|
||||
ConstrainedValue::Array(
|
||||
values.iter()
|
||||
.map(|x| self.enforce_const_value(cs, x, span))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn enforce_expression<CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
@ -49,14 +76,7 @@ impl<'a, F: PrimeField, G: GroupType<F>> ConstrainedProgram<'a, F, G> {
|
||||
|
||||
// Values
|
||||
Expression::Constant(Constant { value, .. }) => {
|
||||
Ok(match value {
|
||||
ConstValue::Address(value) => ConstrainedValue::Address(Address::constant(value.clone(), span)?),
|
||||
ConstValue::Boolean(value) => ConstrainedValue::Boolean(Boolean::Constant(*value)),
|
||||
ConstValue::Field(value) => ConstrainedValue::Field(FieldType::constant(value.to_string(), span)?),
|
||||
ConstValue::Group(value) => ConstrainedValue::Group(G::constant(value, span)?),
|
||||
ConstValue::Int(value) => ConstrainedValue::Integer(Integer::new(value)),
|
||||
ConstValue::Tuple(_) | ConstValue::Array(_) => unimplemented!(), // shouldnt be in the asg here
|
||||
})
|
||||
self.enforce_const_value(cs, value, span)
|
||||
}
|
||||
|
||||
// Binary operations
|
||||
|
@ -58,3 +58,6 @@ pub use self::value::*;
|
||||
|
||||
pub mod stage;
|
||||
pub use self::stage::*;
|
||||
|
||||
pub mod option;
|
||||
pub use self::option::*;
|
||||
|
28
compiler/src/option.rs
Normal file
28
compiler/src/option.rs
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2019-2021 Aleo Systems Inc.
|
||||
// This file is part of the Leo library.
|
||||
|
||||
// The Leo library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
|
||||
// The Leo library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompilerOptions {
|
||||
pub constant_folding_enabled: bool,
|
||||
}
|
||||
|
||||
impl Default for CompilerOptions {
|
||||
fn default() -> Self {
|
||||
CompilerOptions {
|
||||
constant_folding_enabled: true,
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user