Add constant folding pass to compiler, tests, and benches

This commit is contained in:
Pranav Gaddamadugu 2022-07-12 12:44:10 -07:00
parent e9ab5944af
commit 09e0b4670a
7 changed files with 83 additions and 15 deletions

View File

@ -40,6 +40,7 @@ pub mod groups;
pub use self::groups::*;
pub mod input;
pub use self::input::ProgramInput;
pub use self::input::*;
pub mod passes;

View File

@ -17,8 +17,8 @@
//! The compiler for Leo programs.
//!
//! The [`Compiler`] type compiles Leo programs into R1CS circuits.
use leo_ast::Program;
pub use leo_ast::{Ast, InputAst};
use leo_ast::{Program, ProgramInput};
use leo_errors::emitter::Handler;
use leo_errors::{CompilerError, Result};
pub use leo_passes::SymbolTable;
@ -47,7 +47,7 @@ pub struct Compiler<'a> {
pub network: String,
/// The AST for the program.
pub ast: Ast,
/// The input ast for the program if it exists.
/// The program input, if it exists.
pub input_ast: Option<InputAst>,
/// Compiler options on some optional output files.
output_options: OutputOptions,
@ -99,17 +99,13 @@ impl<'a> Compiler<'a> {
ast = ast.set_program_name(self.program_name.clone());
ast = ast.set_network(self.network.clone());
if self.output_options.initial_ast {
// Write the AST snapshot post parsing.
if self.output_options.spans_enabled {
ast.to_json_file(self.output_directory.clone(), "initial_ast.json")?;
} else {
ast.to_json_file_without_keys(self.output_directory.clone(), "initial_ast.json", &["span"])?;
}
}
// Store the AST.
self.ast = ast;
if self.output_options.initial_ast {
self.write_ast_to_json("initial_ast.json")?;
}
Ok(())
}
@ -144,7 +140,7 @@ impl<'a> Compiler<'a> {
}
}
self.input_ast = Some(input_ast);
self.input_ast = Some(input_ast.try_into()?);
}
Ok(())
}
@ -159,6 +155,23 @@ impl<'a> Compiler<'a> {
TypeChecker::do_pass((&self.ast, self.handler, symbol_table))
}
/// Runs the constant folding pass.
pub fn constant_folding_pass(&mut self, symbol_table: SymbolTable) -> Result<SymbolTable> {
let (ast, symbol_table) = ConstantFolder::do_pass((
std::mem::take(&mut self.ast),
self.handler,
symbol_table,
self.input_ast.as_ref().map(|i| &i.constants),
))?;
self.ast = ast;
if self.output_options.constant_folded_ast {
self.write_ast_to_json("constant_folded_ast.json")?;
}
Ok(symbol_table)
}
/// Runs the loop unrolling pass.
pub fn loop_unrolling_pass(&mut self, symbol_table: SymbolTable) -> Result<SymbolTable> {
let (ast, symbol_table) = LoopUnroller::do_pass((std::mem::take(&mut self.ast), self.handler, symbol_table))?;
@ -176,6 +189,9 @@ impl<'a> Compiler<'a> {
let st = self.symbol_table_pass()?;
let st = self.type_checker_pass(st)?;
// TODO: Make this pass optional.
let st = self.constant_folding_pass(st)?;
// TODO: Make this pass optional.
let st = self.loop_unrolling_pass(st)?;
Ok(st)

View File

@ -22,6 +22,8 @@ pub struct OutputOptions {
pub initial_ast: bool,
/// If enabled writes the input AST after parsing.
pub initial_input_ast: bool,
/// If enabled, writes the AST after constant folding.
pub constant_folded_ast: bool,
/// If enabled writes the AST after loop unrolling.
pub unrolled_ast: bool,
}

View File

@ -50,6 +50,7 @@ fn new_compiler(handler: &Handler, main_file_path: PathBuf) -> Compiler<'_> {
spans_enabled: false,
initial_input_ast: true,
initial_ast: true,
constant_folded_ast: true,
unrolled_ast: true,
}),
)
@ -105,6 +106,7 @@ struct OutputItem {
struct CompileOutput {
pub output: Vec<OutputItem>,
pub initial_ast: String,
pub constant_folded_ast: String,
pub unrolled_ast: String,
}
@ -138,6 +140,7 @@ fn collect_all_inputs(test: &Test) -> Result<Vec<PathBuf>, String> {
fn compile_and_process<'a>(parsed: &'a mut Compiler<'a>) -> Result<SymbolTable, LeoError> {
let st = parsed.symbol_table_pass()?;
let st = parsed.type_checker_pass(st)?;
let st = parsed.constant_folding_pass(st)?;
let st = parsed.loop_unrolling_pass(st)?;
Ok(st)
}
@ -218,6 +221,7 @@ fn run_test(test: Test, handler: &Handler, err_buf: &BufferEmitter) -> Result<Va
}
let initial_ast = hash_file("/tmp/output/initial_ast.json");
let constant_folded_ast = hash_file("/tmp/output/constant_folded_ast.json");
let unrolled_ast = hash_file("/tmp/output/unrolled_ast.json");
if fs::read_dir("/tmp/output").is_ok() {
@ -227,6 +231,7 @@ fn run_test(test: Test, handler: &Handler, err_buf: &BufferEmitter) -> Result<Va
let final_output = CompileOutput {
output: output_items,
initial_ast,
constant_folded_ast,
unrolled_ast,
};
Ok(serde_yaml::to_value(&final_output).expect("serialization failed"))

View File

@ -33,7 +33,7 @@ use crate::{Pass, SymbolTable};
impl<'a> Pass for ConstantFolder<'a> {
type Input = (Ast, &'a Handler, SymbolTable, Option<&'a Definitions>);
type Output = Result<Ast>;
type Output = Result<(Ast, SymbolTable)>;
fn do_pass((ast, handler, st, input_consts): Self::Input) -> Self::Output {
// Reconstructs the AST based off any flattening work that is done.
@ -41,6 +41,6 @@ impl<'a> Pass for ConstantFolder<'a> {
let program = reconstructor.reconstruct_program(ast.into_repr());
handler.last_err()?;
Ok(Ast::new(program))
Ok((Ast::new(program), reconstructor.symbol_table.take()))
}
}

View File

@ -50,6 +50,8 @@ pub struct BuildOptions {
pub enable_initial_input_ast_snapshot: bool,
#[structopt(long, help = "Writes AST snapshot of the initial parse.")]
pub enable_initial_ast_snapshot: bool,
#[structopt(long, help = "Writes AST snapshot of the constant folded AST.")]
pub enable_constant_folded_snapshot: bool,
#[structopt(long, help = "Writes AST snapshot of the unrolled AST.")]
pub enable_unrolled_ast_snapshot: bool,
// Note: This is currently made optional since code generation is just a prototype.
@ -66,11 +68,13 @@ impl From<BuildOptions> for OutputOptions {
spans_enabled: options.enable_spans,
initial_input_ast: options.enable_initial_input_ast_snapshot,
initial_ast: options.enable_initial_ast_snapshot,
constant_folded_ast: options.enable_constant_folded_snapshot,
unrolled_ast: options.enable_unrolled_ast_snapshot,
};
if options.enable_all_ast_snapshots {
out_options.initial_input_ast = true;
out_options.initial_ast = true;
out_options.constant_folded_ast = true;
out_options.unrolled_ast = true;
}
@ -111,7 +115,7 @@ impl Command for Build {
let build_directory = BuildDirectory::open(&package_path)?;
// Initialize error handler
let handler = leo_errors::emitter::Handler::default();
let handler = Handler::default();
// Fetch paths to all .leo files in the source directory.
let source_files = SourceDirectory::files(&package_path)?;

View File

@ -38,6 +38,8 @@ enum BenchMode {
Symbol,
/// Benchmarks type checking.
Type,
/// Benchmarks constant folding
Fold,
/// Benchmarks loop unrolling.
Unroll,
/// Benchmarks all the above stages.
@ -102,6 +104,7 @@ impl Sample {
BenchMode::Parse => self.bench_parse(c),
BenchMode::Symbol => self.bench_symbol_table(c),
BenchMode::Type => self.bench_type_checker(c),
BenchMode::Fold => self.bench_constant_folder(c),
BenchMode::Unroll => self.bench_loop_unroller(c),
BenchMode::Full => self.bench_full(c),
}
@ -180,6 +183,35 @@ impl Sample {
});
}
fn bench_constant_folder(&self, c: &mut Criterion) {
c.bench_function(&format!("loop unrolling pass{}", self.name), |b| {
// Iter custom is used so we can use custom timings around the compiler stages.
// This way we can only time the necessary stage.
b.iter_custom(|iters| {
let mut time = Duration::default();
for _ in 0..iters {
SESSION_GLOBALS.set(&SessionGlobals::default(), || {
let handler = BufEmitter::new_handler();
let mut compiler = new_compiler(&handler);
let (input, name) = self.data();
compiler
.parse_program_from_string(input, name)
.expect("Failed to parse program");
let symbol_table = compiler.symbol_table_pass().expect("failed to generate symbol table");
let symbol_table = compiler
.type_checker_pass(symbol_table)
.expect("failed to run type check pass");
let start = Instant::now();
let out = compiler.constant_folding_pass(symbol_table);
time += start.elapsed();
out.expect("failed to run constant folding pass")
});
}
time
})
});
}
fn bench_loop_unroller(&self, c: &mut Criterion) {
c.bench_function(&format!("loop unrolling pass{}", self.name), |b| {
// Iter custom is used so we can use custom timings around the compiler stages.
@ -198,6 +230,9 @@ impl Sample {
let symbol_table = compiler
.type_checker_pass(symbol_table)
.expect("failed to run type check pass");
let symbol_table = compiler
.constant_folding_pass(symbol_table)
.expect("failed to run constant folding pass");
let start = Instant::now();
let out = compiler.loop_unrolling_pass(symbol_table);
time += start.elapsed();
@ -228,6 +263,9 @@ impl Sample {
let symbol_table = compiler
.type_checker_pass(symbol_table)
.expect("failed to run type check pass");
let symbol_table = compiler
.constant_folding_pass(symbol_table)
.expect("failed to run constant folding pass");
compiler
.loop_unrolling_pass(symbol_table)
.expect("failed to run loop unrolling pass");
@ -251,6 +289,7 @@ macro_rules! bench {
bench!(bench_parse, BenchMode::Parse);
bench!(bench_symbol, BenchMode::Symbol);
bench!(bench_type, BenchMode::Type);
bench!(bench_fold, BenchMode::Fold);
bench!(bench_unroll, BenchMode::Unroll);
bench!(bench_full, BenchMode::Full);
@ -261,6 +300,7 @@ criterion_group!(
bench_parse,
bench_symbol,
bench_type,
bench_fold,
bench_unroll,
bench_full
);