From c798635e2936160483a73cc6cc1b1e872af56371 Mon Sep 17 00:00:00 2001 From: collin Date: Mon, 23 Mar 2020 15:19:47 -0700 Subject: [PATCH] basic program compilation from ast --- Cargo.toml | 8 + simple.program | 3 +- src/ast.rs | 443 +++++++++++++++++++++++++++++++++++ src/lib.rs | 14 ++ src/main.rs | 431 ++++------------------------------ src/program/mod.rs | 11 + src/program/program.rs | 90 +++++++ src/program/types.rs | 57 +++++ src/program/types_display.rs | 58 +++++ src/program/types_from.rs | 178 ++++++++++++++ 10 files changed, 900 insertions(+), 393 deletions(-) create mode 100644 src/ast.rs create mode 100644 src/lib.rs create mode 100644 src/program/mod.rs create mode 100644 src/program/program.rs create mode 100644 src/program/types.rs create mode 100644 src/program/types_display.rs create mode 100644 src/program/types_from.rs diff --git a/Cargo.toml b/Cargo.toml index 3cc33393c3..38769dc8ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,14 @@ version = "0.1.0" authors = ["howardwu "] edition = "2018" +[lib] +name = "language" +path = "src/lib.rs" + +[[bin]] +name = "snarkLang" +path = "src/main.rs" + [dependencies] from-pest = "0.3.1" lazy_static = "1.3.0" diff --git a/simple.program b/simple.program index 9d815113b8..ee7041432a 100644 --- a/simple.program +++ b/simple.program @@ -1,2 +1 @@ -x = 5 + 3 -y = x * (x * 2) +return a + b \ No newline at end of file diff --git a/src/ast.rs b/src/ast.rs new file mode 100644 index 0000000000..3ed1236d5b --- /dev/null +++ b/src/ast.rs @@ -0,0 +1,443 @@ +use from_pest::{ConversionError, FromPest, Void}; +use pest::{ + error::Error, + iterators::{Pair, Pairs}, + prec_climber::{Assoc, Operator, PrecClimber}, + Parser, Span, +}; +use pest_ast::FromPest; +use std::fmt; + +#[derive(Parser)] +#[grammar = "language.pest"] +pub struct LanguageParser; + +pub fn parse(input: &str) -> Result, Error> { + LanguageParser::parse(Rule::file, input) +} + +fn span_into_string(span: Span) -> String { + span.as_str().to_string() +} + +lazy_static! { + static ref PRECEDENCE_CLIMBER: PrecClimber = precedence_climber(); +} + +fn precedence_climber() -> PrecClimber { + PrecClimber::new(vec![ + Operator::new(Rule::operation_or, Assoc::Left), + Operator::new(Rule::operation_and, Assoc::Left), + Operator::new(Rule::operation_eq, Assoc::Left) + | Operator::new(Rule::operation_neq, Assoc::Left), + Operator::new(Rule::operation_geq, Assoc::Left) + | Operator::new(Rule::operation_gt, Assoc::Left) + | Operator::new(Rule::operation_leq, Assoc::Left) + | Operator::new(Rule::operation_lt, Assoc::Left), + Operator::new(Rule::operation_add, Assoc::Left) + | Operator::new(Rule::operation_sub, Assoc::Left), + Operator::new(Rule::operation_mul, Assoc::Left) + | Operator::new(Rule::operation_div, Assoc::Left), + Operator::new(Rule::operation_pow, Assoc::Left), + ]) +} + +fn parse_term(pair: Pair) -> Box { + Box::new(match pair.as_rule() { + Rule::expression_term => { + let clone = pair.clone(); + let next = clone.into_inner().next().unwrap(); + match next.as_rule() { + Rule::expression_primitive => { + let next = next.into_inner().next().unwrap(); + match next.as_rule() { + Rule::value => Expression::Value( + Value::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap() + ), + Rule::variable => Expression::Variable( + Variable::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + ), + rule => unreachable!("`expression_primitive` should contain one of [`value`, `variable`], found {:#?}", rule) + } + } + Rule::expression_not => { + let span = next.as_span(); + let mut inner = next.into_inner(); + let operation = match inner.next().unwrap().as_rule() { + Rule::operation_pre_not => Not::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + rule => unreachable!("`expression_not` should yield `operation_pre_not`, found {:#?}", rule) + }; + let expression = parse_term(inner.next().unwrap()); + Expression::Not(NotExpression { operation, expression, span }) + }, + Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case + + // Rule::expression_increment => { + // let span = next.as_span(); + // let mut inner = next.into_inner(); + // let expression = parse_expression_term(inner.next().unwrap()); + // let operation = match inner.next().unwrap().as_rule() { + // Rule::operation_post_increment => Increment::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + // rule => unreachable!("`expression_increment` should yield `operation_post_increment`, found {:#?}", rule) + // }; + // Expression::Increment(IncrementExpression { operation, expression, span }) + // }, + // Rule::expression_decrement => { + // let span = next.as_span(); + // let mut inner = next.into_inner(); + // let expression = parse_expression_term(inner.next().unwrap()); + // let operation = match inner.next().unwrap().as_rule() { + // Rule::operation_post_decrement => Decrement::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), + // rule => unreachable!("`expression_decrement` should yield `operation_post_decrement`, found {:#?}", rule) + // }; + // Expression::Decrement(DecrementExpression { operation, expression, span }) + // }, + + rule => unreachable!("`term` should contain one of ['value', 'variable', 'expression', 'expression_not', 'expression_increment', 'expression_decrement'], found {:#?}", rule) + } + } + rule => unreachable!( + "`parse_expression_term` should be invoked on `Rule::expression_term`, found {:#?}", + rule + ), + }) +} + +fn binary_expression<'ast>( + lhs: Box>, + pair: Pair<'ast, Rule>, + rhs: Box>, +) -> Box> { + let (start, _) = lhs.span().clone().split(); + let (_, end) = rhs.span().clone().split(); + let span = start.span(&end); + + Box::new(match pair.as_rule() { + Rule::operation_or => Expression::binary(BinaryOperator::Or, lhs, rhs, span), + Rule::operation_and => Expression::binary(BinaryOperator::And, lhs, rhs, span), + Rule::operation_eq => Expression::binary(BinaryOperator::Eq, lhs, rhs, span), + Rule::operation_neq => Expression::binary(BinaryOperator::Neq, lhs, rhs, span), + Rule::operation_geq => Expression::binary(BinaryOperator::Geq, lhs, rhs, span), + Rule::operation_gt => Expression::binary(BinaryOperator::Gt, lhs, rhs, span), + Rule::operation_leq => Expression::binary(BinaryOperator::Leq, lhs, rhs, span), + Rule::operation_lt => Expression::binary(BinaryOperator::Lt, lhs, rhs, span), + Rule::operation_add => Expression::binary(BinaryOperator::Add, lhs, rhs, span), + Rule::operation_sub => Expression::binary(BinaryOperator::Sub, lhs, rhs, span), + Rule::operation_mul => Expression::binary(BinaryOperator::Mul, lhs, rhs, span), + Rule::operation_div => Expression::binary(BinaryOperator::Div, lhs, rhs, span), + Rule::operation_pow => Expression::binary(BinaryOperator::Pow, lhs, rhs, span), + _ => unreachable!(), + }) +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::file))] +pub struct File<'ast> { + pub statement: Vec>, + pub eoi: EOI, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +// Visibility + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::visibility_public))] +pub struct Public {} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::visibility_private))] +pub struct Private {} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::visibility))] +pub enum Visibility { + Public(Public), + Private(Private), +} + +// Unary Operations + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::operation_pre_not))] +pub struct Not<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::operation_post_increment))] +pub struct Increment<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::operation_post_decrement))] +pub struct Decrement<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +// Binary Operations + +#[derive(Debug, PartialEq, Clone)] +pub enum BinaryOperator { + Or, + And, + Eq, + Neq, + Geq, + Gt, + Leq, + Lt, + Add, + Sub, + Mul, + Div, + Pow, +} + +// Values + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::value_boolean))] +pub struct Boolean<'ast> { + #[pest_ast(outer(with(span_into_string)))] + pub value: String, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for Boolean<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::value_field))] +pub struct Field<'ast> { + #[pest_ast(outer(with(span_into_string)))] + pub value: String, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for Field<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::value))] +pub enum Value<'ast> { + Boolean(Boolean<'ast>), + Field(Field<'ast>), +} + +impl<'ast> Value<'ast> { + pub fn span(&self) -> &Span<'ast> { + match self { + Value::Boolean(value) => &value.span, + Value::Field(value) => &value.span, + } + } +} + +impl<'ast> fmt::Display for Value<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Value::Boolean(ref value) => write!(f, "{}", value), + Value::Field(ref value) => write!(f, "{}", value), + } + } +} + +// Variables + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::variable))] +pub struct Variable<'ast> { + #[pest_ast(outer(with(span_into_string)))] + pub value: String, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for Variable<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +// Expressions + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::expression_not))] +pub struct NotExpression<'ast> { + pub operation: Not<'ast>, + pub expression: Box>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +// #[derive(Clone, Debug, FromPest, PartialEq)] +// #[pest_ast(rule(Rule::expression_increment))] +// pub struct IncrementExpression<'ast> { +// pub expression: Box>, +// pub operation: Increment<'ast>, +// #[pest_ast(outer())] +// pub span: Span<'ast>, +// } +// +// #[derive(Clone, Debug, FromPest, PartialEq)] +// #[pest_ast(rule(Rule::expression_decrement))] +// pub struct DecrementExpression<'ast> { +// pub expression: Box>, +// pub operation: Decrement<'ast>, +// #[pest_ast(outer())] +// pub span: Span<'ast>, +// } + +#[derive(Clone, Debug, PartialEq)] +pub struct BinaryExpression<'ast> { + pub operation: BinaryOperator, + pub left: Box>, + pub right: Box>, + pub span: Span<'ast>, +} + +// #[derive(Clong, Debug, PartialEq)] +// pub struct IdentifierExpression<'ast> { +// pub value: String, +// pub span: Span<'ast>, +// } + +#[derive(Clone, Debug, PartialEq)] +pub enum Expression<'ast> { + Value(Value<'ast>), + Variable(Variable<'ast>), + Not(NotExpression<'ast>), + Binary(BinaryExpression<'ast>), + // Increment(IncrementExpression<'ast>), + // Decrement(DecrementExpression<'ast>), +} + +impl<'ast> Expression<'ast> { + pub fn binary( + operation: BinaryOperator, + left: Box>, + right: Box>, + span: Span<'ast>, + ) -> Self { + Expression::Binary(BinaryExpression { + operation, + left, + right, + span, + }) + } + + pub fn span(&self) -> &Span<'ast> { + match self { + Expression::Value(expression) => &expression.span(), + Expression::Variable(expression) => &expression.span, + Expression::Not(expression) => &expression.span, + Expression::Binary(expression) => &expression.span, + // Expression::Increment(expression) => &expression.span, + // Expression::Decrement(expression) => &expression.span, + } + } +} + +impl<'ast> fmt::Display for Expression<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Expression::Value(ref expression) => write!(f, "{}", expression), + Expression::Variable(ref expression) => write!(f, "{}", expression), + Expression::Not(ref expression) => write!(f, "{}", expression.expression), + Expression::Binary(ref expression) => { + write!(f, "{} == {}", expression.left, expression.right) + } + } + } +} + +impl<'ast> FromPest<'ast> for Expression<'ast> { + type Rule = Rule; + type FatalError = Void; + + fn from_pest(pest: &mut Pairs<'ast, Rule>) -> Result> { + let mut clone = pest.clone(); + let pair = clone.next().ok_or(::from_pest::ConversionError::NoMatch)?; + match pair.as_rule() { + Rule::expression => { + // Transfer iterated state to pest. + *pest = clone; + Ok(*PRECEDENCE_CLIMBER.climb(pair.into_inner(), parse_term, binary_expression)) + } + _ => Err(ConversionError::NoMatch), + } + } +} + +// Statements + +#[derive(Debug, FromPest, PartialEq, Clone)] +#[pest_ast(rule(Rule::statement_assign))] +pub struct AssignStatement<'ast> { + pub variable: Variable<'ast>, + pub expression: Expression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for AssignStatement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.expression) + } +} + +#[derive(Debug, FromPest, PartialEq, Clone)] +#[pest_ast(rule(Rule::statement_return))] +pub struct ReturnStatement<'ast> { + pub expressions: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + +impl<'ast> fmt::Display for ReturnStatement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for (i, expression) in self.expressions.iter().enumerate() { + write!(f, "{}", expression)?; + if i < self.expressions.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "") + } +} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::statement))] +pub enum Statement<'ast> { + Assign(AssignStatement<'ast>), + Return(ReturnStatement<'ast>), +} + +impl<'ast> fmt::Display for Statement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Statement::Assign(ref statement) => write!(f, "{}", statement), + Statement::Return(ref statement) => write!(f, "{}", statement), + } + } +} + +// Utilities + +#[derive(Debug, FromPest, PartialEq, Clone)] +#[pest_ast(rule(Rule::EOI))] +pub struct EOI; diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000000..a56330eed4 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,14 @@ +extern crate pest; +#[macro_use] +extern crate pest_derive; + +extern crate from_pest; +#[macro_use] +extern crate pest_ast; + +#[macro_use] +extern crate lazy_static; + +pub mod ast; + +pub mod program; diff --git a/src/main.rs b/src/main.rs index 65b8d93094..ed84f3945b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,406 +1,56 @@ -extern crate pest; -#[macro_use] -extern crate pest_derive; +use language::*; -extern crate from_pest; -#[macro_use] -extern crate pest_ast; - -#[macro_use] -extern crate lazy_static; - -use pest::Parser; +use from_pest::FromPest; use std::fs; -#[derive(Parser)] -#[grammar = "language.pest"] -pub struct LanguageParser; - -mod ast { - use from_pest::ConversionError; - use from_pest::FromPest; - use from_pest::Void; - use pest::iterators::{Pair, Pairs}; - use pest::prec_climber::{Assoc, Operator, PrecClimber}; - use pest::Span; - use pest_ast::FromPest; - use super::Rule; - - fn span_into_string(span: Span) -> String { - span.as_str().to_string() - } - - lazy_static! { - static ref PRECEDENCE_CLIMBER: PrecClimber = precedence_climber(); - } - - fn precedence_climber() -> PrecClimber { - PrecClimber::new(vec![ - Operator::new(Rule::operation_or, Assoc::Left), - Operator::new(Rule::operation_and, Assoc::Left), - Operator::new(Rule::operation_eq, Assoc::Left) - | Operator::new(Rule::operation_neq, Assoc::Left), - Operator::new(Rule::operation_geq, Assoc::Left) - | Operator::new(Rule::operation_gt, Assoc::Left) - | Operator::new(Rule::operation_leq, Assoc::Left) - | Operator::new(Rule::operation_lt, Assoc::Left), - Operator::new(Rule::operation_add, Assoc::Left) - | Operator::new(Rule::operation_sub, Assoc::Left), - Operator::new(Rule::operation_mul, Assoc::Left) - | Operator::new(Rule::operation_div, Assoc::Left), - Operator::new(Rule::operation_pow, Assoc::Left), - ]) - } - - fn parse_term(pair: Pair) -> Box { - Box::new(match pair.as_rule() { - Rule::expression_term => { - let clone = pair.clone(); - let next = clone.into_inner().next().unwrap(); - match next.as_rule() { - Rule::expression_primitive => { - let next = next.into_inner().next().unwrap(); - match next.as_rule() { - Rule::value => Expression::Value( - Value::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap() - ), - Rule::variable => Expression::Variable( - Variable::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - ), - rule => unreachable!("`expression_primitive` should contain one of [`value`, `variable`], found {:#?}", rule) - } - } - Rule::expression_not => { - let span = next.as_span(); - let mut inner = next.into_inner(); - let operation = match inner.next().unwrap().as_rule() { - Rule::operation_pre_not => Not::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - rule => unreachable!("`expression_not` should yield `operation_pre_not`, found {:#?}", rule) - }; - let expression = parse_term(inner.next().unwrap()); - Expression::Not(NotExpression { operation, expression, span }) - }, - Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case - - // Rule::expression_increment => { - // let span = next.as_span(); - // let mut inner = next.into_inner(); - // let expression = parse_expression_term(inner.next().unwrap()); - // let operation = match inner.next().unwrap().as_rule() { - // Rule::operation_post_increment => Increment::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - // rule => unreachable!("`expression_increment` should yield `operation_post_increment`, found {:#?}", rule) - // }; - // Expression::Increment(IncrementExpression { operation, expression, span }) - // }, - // Rule::expression_decrement => { - // let span = next.as_span(); - // let mut inner = next.into_inner(); - // let expression = parse_expression_term(inner.next().unwrap()); - // let operation = match inner.next().unwrap().as_rule() { - // Rule::operation_post_decrement => Decrement::from_pest(&mut pair.into_inner().next().unwrap().into_inner()).unwrap(), - // rule => unreachable!("`expression_decrement` should yield `operation_post_decrement`, found {:#?}", rule) - // }; - // Expression::Decrement(DecrementExpression { operation, expression, span }) - // }, - - rule => unreachable!("`term` should contain one of ['value', 'variable', 'expression', 'expression_not', 'expression_increment', 'expression_decrement'], found {:#?}", rule) - } - } - rule => unreachable!("`parse_expression_term` should be invoked on `Rule::expression_term`, found {:#?}", rule), - }) - } - - fn binary_expression<'ast>( - lhs: Box>, - pair: Pair<'ast, Rule>, - rhs: Box>, - ) -> Box> { - let (start, _) = lhs.span().clone().split(); - let (_, end) = rhs.span().clone().split(); - let span = start.span(&end); - - Box::new(match pair.as_rule() { - Rule::operation_or => Expression::binary(BinaryOperator::Or, lhs, rhs, span), - Rule::operation_and => Expression::binary(BinaryOperator::And, lhs, rhs, span), - Rule::operation_eq => Expression::binary(BinaryOperator::Eq, lhs, rhs, span), - Rule::operation_neq => Expression::binary(BinaryOperator::Neq, lhs, rhs, span), - Rule::operation_geq => Expression::binary(BinaryOperator::Geq, lhs, rhs, span), - Rule::operation_gt => Expression::binary(BinaryOperator::Gt, lhs, rhs, span), - Rule::operation_leq => Expression::binary(BinaryOperator::Leq, lhs, rhs, span), - Rule::operation_lt => Expression::binary(BinaryOperator::Lt, lhs, rhs, span), - Rule::operation_add => Expression::binary(BinaryOperator::Add, lhs, rhs, span), - Rule::operation_sub => Expression::binary(BinaryOperator::Sub, lhs, rhs, span), - Rule::operation_mul => Expression::binary(BinaryOperator::Mul, lhs, rhs, span), - Rule::operation_div => Expression::binary(BinaryOperator::Div, lhs, rhs, span), - Rule::operation_pow => Expression::binary(BinaryOperator::Pow, lhs, rhs, span), - _ => unreachable!(), - }) - } - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::file))] - pub struct File<'ast> { - pub statement: Vec>, - pub eoi: EOI, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - // Visibility - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::visibility_public))] - pub struct Public {} - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::visibility_private))] - pub struct Private {} - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::visibility))] - pub enum Visibility { - Public(Public), - Private(Private), - } - - // Unary Operations - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::operation_pre_not))] - pub struct Not<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::operation_post_increment))] - pub struct Increment<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::operation_post_decrement))] - pub struct Decrement<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - // Binary Operations - - #[derive(Debug, PartialEq, Clone)] - pub enum BinaryOperator { - Or, - And, - Eq, - Neq, - Geq, - Gt, - Leq, - Lt, - Add, - Sub, - Mul, - Div, - Pow, - } - - // Values - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::value_boolean))] - pub struct Boolean<'ast> { - #[pest_ast(outer(with(span_into_string)))] - pub value: String, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::value_field))] - pub struct Field<'ast> { - #[pest_ast(outer(with(span_into_string)))] - pub value: String, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::value))] - pub enum Value<'ast> { - Boolean(Boolean<'ast>), - Field(Field<'ast>), - } - - impl<'ast> Value<'ast> { - pub fn span(&self) -> &Span<'ast> { - match self { - Value::Boolean(value) => &value.span, - Value::Field(value) => &value.span, - } - } - } - - // Variables - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::variable))] - pub struct Variable<'ast> { - #[pest_ast(outer(with(span_into_string)))] - pub value: String, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - // Expressions - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::expression_not))] - pub struct NotExpression<'ast> { - pub operation: Not<'ast>, - pub expression: Box>, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - // #[derive(Clone, Debug, FromPest, PartialEq)] - // #[pest_ast(rule(Rule::expression_increment))] - // pub struct IncrementExpression<'ast> { - // pub expression: Box>, - // pub operation: Increment<'ast>, - // #[pest_ast(outer())] - // pub span: Span<'ast>, - // } - // - // #[derive(Clone, Debug, FromPest, PartialEq)] - // #[pest_ast(rule(Rule::expression_decrement))] - // pub struct DecrementExpression<'ast> { - // pub expression: Box>, - // pub operation: Decrement<'ast>, - // #[pest_ast(outer())] - // pub span: Span<'ast>, - // } - - #[derive(Clone, Debug, PartialEq)] - pub struct BinaryExpression<'ast> { - pub operation: BinaryOperator, - pub left: Box>, - pub right: Box>, - pub span: Span<'ast>, - } - - #[derive(Clone, Debug, PartialEq)] - pub enum Expression<'ast> { - Value(Value<'ast>), - Variable(Variable<'ast>), - Not(NotExpression<'ast>), - Binary(BinaryExpression<'ast>), - - // Increment(IncrementExpression<'ast>), - // Decrement(DecrementExpression<'ast>), - } - - impl<'ast> Expression<'ast> { - pub fn binary( - operation: BinaryOperator, - left: Box>, - right: Box>, - span: Span<'ast>, - ) -> Self { - Expression::Binary(BinaryExpression { operation, left, right, span }) - } - - pub fn span(&self) -> &Span<'ast> { - match self { - Expression::Value(expression) => &expression.span(), - Expression::Variable(expression) => &expression.span, - Expression::Not(expression) => &expression.span, - Expression::Binary(expression) => &expression.span, - - // Expression::Increment(expression) => &expression.span, - // Expression::Decrement(expression) => &expression.span, - } - } - } - - impl<'ast> FromPest<'ast> for Expression<'ast> { - type Rule = Rule; - type FatalError = Void; - - fn from_pest(pest: &mut Pairs<'ast, Rule>) -> Result> { - let mut clone = pest.clone(); - let pair = clone.next().ok_or(::from_pest::ConversionError::NoMatch)?; - match pair.as_rule() { - Rule::expression => { - // Transfer iterated state to pest. - *pest = clone; - Ok(*PRECEDENCE_CLIMBER.climb(pair.into_inner(), parse_term, binary_expression)) - } - _ => Err(ConversionError::NoMatch), - } - } - } - - // Statements - - #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::statement_assign))] - pub struct AssignStatement<'ast> { - pub variable: Variable<'ast>, - pub expression: Expression<'ast>, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::statement_return))] - pub struct ReturnStatement<'ast> { - pub expressions: Vec>, - #[pest_ast(outer())] - pub span: Span<'ast>, - } - - #[derive(Clone, Debug, FromPest, PartialEq)] - #[pest_ast(rule(Rule::statement))] - pub enum Statement<'ast> { - Assign(AssignStatement<'ast>), - Return(ReturnStatement<'ast>), - } - - // Utilities - - #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::EOI))] - pub struct EOI; -} - fn main() { - use crate::from_pest::FromPest; - use snarkos_gadgets::curves::edwards_bls12::FqGadget; - use snarkos_models::gadgets::{r1cs::{ConstraintSystem, TestConstraintSystem, Fr}, utilities::{alloc::AllocGadget, boolean::Boolean}}; + // use snarkos_gadgets::curves::edwards_bls12::FqGadget; + // use snarkos_models::gadgets::{ + // r1cs::{ConstraintSystem, TestConstraintSystem, Fr}, + // utilities::{ + // alloc::{AllocGadget}, + // boolean::Boolean, + // uint32::UInt32, + // } + // }; + // Read in file as string let unparsed_file = fs::read_to_string("simple.program").expect("cannot read file"); - let mut file = LanguageParser::parse(Rule::file, &unparsed_file).expect("unsuccessful parse"); + + // Parse the file using langauge.pest + let mut file = ast::parse(&unparsed_file).expect("unsuccessful parse"); + + // Build the abstract syntax tree let syntax_tree = ast::File::from_pest(&mut file).expect("infallible"); - for statement in syntax_tree.statement { - match statement { - ast::Statement::Assign(statement) => { - println!("{:#?}", statement); - }, - ast::Statement::Return(statement) => { + let program = program::Program::from(syntax_tree); - } - } - } + println!("{:?}", program); - let mut cs = TestConstraintSystem::::new(); + // // Use this code when proving + // let left_u32 = left_string.parse::().unwrap(); + // let right_u32 = right_string.parse::().unwrap(); + // + // println!("left u32 value: {:#?}", left_u32); + // println!("right u32 value: {:#?}", right_u32); + // + // let left_constraint = UInt32::alloc(cs.ns(|| "left variable"), Some(left_u32)).unwrap(); + // let right_constraint = UInt32::constant(right_u32); + // + // let bool = Boolean::alloc(cs.ns(|| format!("boolean")), || Ok(true)).unwrap(); + // + // left_constraint.conditional_enforce_equal(cs.ns(|| format!("enforce left == right")), &right_constraint, &bool).unwrap(); - Boolean::alloc(cs.ns(|| format!("boolean")), || Ok(true)); + // // Constraint testing + // let bool = Boolean::alloc(cs.ns(|| format!("boolean")), || Ok(true)).unwrap(); + // let a_bit = UInt32::alloc(cs.ns(|| "a_bit"), Some(4u32)).unwrap(); + // let b_bit = UInt32::constant(5u32); + // + // a_bit.conditional_enforce_equal(cs.ns(|| format!("enforce equal")), &b_bit, &bool).unwrap(); + // println!("satisfied: {:?}", cs.is_satisfied()); - println!("\n\n number of constraints for input: {}", cs.num_constraints()); - + // println!("\n\n number of constraints for input: {}", cs.num_constraints()); // for token in file.into_inner() { // match token.as_rule() { @@ -411,7 +61,6 @@ fn main() { // // println!("{:?}", token); // } - // let mut field_sum: f64 = 0.0; // let mut record_count: u64 = 0; // diff --git a/src/program/mod.rs b/src/program/mod.rs new file mode 100644 index 0000000000..976d3ad63b --- /dev/null +++ b/src/program/mod.rs @@ -0,0 +1,11 @@ +pub mod program; +pub use self::program::*; + +pub mod types; +pub use self::types::*; + +pub mod types_display; +pub use self::types_display::*; + +pub mod types_from; +pub use self::types_from::*; diff --git a/src/program/program.rs b/src/program/program.rs new file mode 100644 index 0000000000..085b9a07b9 --- /dev/null +++ b/src/program/program.rs @@ -0,0 +1,90 @@ +use crate::program::{Expression, ExpressionList, Statement, StatementNode, Variable}; + +use pest::Span; +use std::fmt; +use std::fmt::Formatter; + +// AST -> Program + +/// Position in input file +#[derive(Clone, Copy)] +pub struct Position { + pub line: usize, + pub column: usize, +} + +impl fmt::Display for Position { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}", self.line, self.column) + } +} + +impl fmt::Debug for Position { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}:{}", self.line, self.column) + } +} + +/// Building blocks for a program +#[derive(Debug, Clone)] +pub struct Node { + start: Position, + end: Position, + value: T, +} + +impl Node { + pub fn new(start: Position, end: Position, value: T) -> Node { + Self { start, end, value } + } +} + +impl fmt::Display for Node { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +impl std::cmp::PartialEq for Node { + fn eq(&self, other: &Node) -> bool { + self.value.eq(&other.value) + } +} + +pub trait NodeValue: fmt::Display + Sized + PartialEq { + fn span(self, span: Span) -> Node { + let start = span.start_pos().line_col(); + let end = span.end_pos().line_col(); + + let start = Position { + line: start.0, + column: start.1, + }; + + let end = Position { + line: end.0, + column: end.1, + }; + + Node::new(start, end, self) + } +} + +impl From for Node { + fn from(v: V) -> Self { + let mock_position = Position { line: 1, column: 1 }; + + Self::new(mock_position, mock_position, v) + } +} + +impl<'ast> NodeValue for Expression<'ast> {} +impl<'ast> NodeValue for ExpressionList<'ast> {} +impl<'ast> NodeValue for Statement<'ast> {} +impl<'ast> NodeValue for Variable<'ast> {} + +/// A collection of nodes created from an abstract syntax tree. +#[derive(Debug)] +pub struct Program<'ast> { + pub nodes: Vec>, +} diff --git a/src/program/types.rs b/src/program/types.rs new file mode 100644 index 0000000000..a24893d6ba --- /dev/null +++ b/src/program/types.rs @@ -0,0 +1,57 @@ +use crate::program::Node; + +// Program Nodes - Wrappers for different types in a program. +pub type ExpressionNode<'ast> = Node>; +pub type ExpressionListNode<'ast> = Node>; +pub type StatementNode<'ast> = Node>; +pub type VariableNode<'ast> = Node>; + +/// Identifier string +pub type Identifier<'ast> = &'ast str; + +/// Program variable +#[derive(Debug, Clone, PartialEq)] +pub struct Variable<'ast> { + pub id: Identifier<'ast>, +} + +/// Program expression that evaluates to a value +#[derive(Debug, Clone, PartialEq)] +pub enum Expression<'ast> { + // Expression identifier + Identifier(Identifier<'ast>), + // Values + Boolean(bool), + Field(Identifier<'ast>), + // Variable + Variable(VariableNode<'ast>), + // Not expression + Not(Box>), + // Binary expression + Or(Box>, Box>), + And(Box>, Box>), + Eq(Box>, Box>), + Neq(Box>, Box>), + Geq(Box>, Box>), + Gt(Box>, Box>), + Leq(Box>, Box>), + Lt(Box>, Box>), + Add(Box>, Box>), + Sub(Box>, Box>), + Mul(Box>, Box>), + Div(Box>, Box>), + Pow(Box>, Box>), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ExpressionList<'ast> { + pub expressions: Vec>, +} + +/// Program statement that defines some action (or expression) to be carried out +#[derive(Debug, PartialEq, Clone)] +pub enum Statement<'ast> { + Declaration(VariableNode<'ast>), + Definition(VariableNode<'ast>, ExpressionNode<'ast>), + Return(ExpressionListNode<'ast>), +} diff --git a/src/program/types_display.rs b/src/program/types_display.rs new file mode 100644 index 0000000000..a2c0e7bbc8 --- /dev/null +++ b/src/program/types_display.rs @@ -0,0 +1,58 @@ +use crate::program::{Expression, ExpressionList, Statement, Variable}; + +use std::fmt; + +impl<'ast> fmt::Display for Variable<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.id) + } +} + +impl<'ast> fmt::Display for Expression<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Expression::Identifier(ref s) => write!(f, "Identifier({:?})", s), + Expression::Boolean(ref b) => write!(f, "Boolean({:?})", b), + Expression::Field(ref s) => write!(f, "Field({:?})", s), + Expression::Variable(ref v) => write!(f, "{}", v), + Expression::Not(ref e) => write!(f, "{}", e), + Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), + Expression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs), + Expression::Eq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), + Expression::Neq(ref lhs, ref rhs) => write!(f, "{} != {}", lhs, rhs), + Expression::Geq(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), + Expression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), + Expression::Leq(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + Expression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + Expression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), + Expression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), + Expression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), + Expression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), + Expression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), + } + } +} + +impl<'ast> fmt::Display for ExpressionList<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for (i, expression) in self.expressions.iter().enumerate() { + write!(f, "{}", expression)?; + if i < self.expressions.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "") + } +} + +impl<'ast> fmt::Display for Statement<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Statement::Return(ref expressions) => write!(f, "return {}", expressions), + Statement::Declaration(ref variable) => write!(f, "{}", variable), + Statement::Definition(ref variable, ref expression) => { + write!(f, "{} = {}", variable, expression) + } + } + } +} diff --git a/src/program/types_from.rs b/src/program/types_from.rs new file mode 100644 index 0000000000..7a556b64e9 --- /dev/null +++ b/src/program/types_from.rs @@ -0,0 +1,178 @@ +use crate::{ + ast, + program::{program, types, NodeValue}, +}; + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(boolean: ast::Boolean<'ast>) -> Self { + types::Expression::Boolean( + boolean + .value + .parse::() + .expect("unable to parse boolean"), + ) + .span(boolean.span) + } +} + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(field: ast::Field<'ast>) -> Self { + types::Expression::Field(field.span.as_str()).span(field.span) + } +} + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(value: ast::Value<'ast>) -> Self { + match value { + ast::Value::Boolean(boolean) => types::ExpressionNode::from(boolean), + ast::Value::Field(field) => types::ExpressionNode::from(field), + } + } +} + +impl<'ast> From> for types::VariableNode<'ast> { + fn from(variable: ast::Variable<'ast>) -> Self { + types::Variable { + id: variable.span.as_str(), + } + .span(variable.span) + } +} + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(variable: ast::Variable<'ast>) -> Self { + types::Expression::Variable(types::VariableNode::from(variable.clone())).span(variable.span) + } +} + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(expression: ast::NotExpression<'ast>) -> Self { + types::Expression::Not(Box::new(types::ExpressionNode::from( + *expression.expression, + ))) + .span(expression.span) + } +} + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(expression: ast::BinaryExpression<'ast>) -> Self { + match expression.operation { + ast::BinaryOperator::Or => types::Expression::Or( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::And => types::Expression::And( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Eq => types::Expression::Eq( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Neq => types::Expression::Neq( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Geq => types::Expression::Geq( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Gt => types::Expression::Gt( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Leq => types::Expression::Leq( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Lt => types::Expression::Lt( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Add => types::Expression::Add( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Sub => types::Expression::Sub( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Mul => types::Expression::Mul( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Div => types::Expression::Div( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + ast::BinaryOperator::Pow => types::Expression::Pow( + Box::new(types::ExpressionNode::from(*expression.left)), + Box::new(types::ExpressionNode::from(*expression.right)), + ), + } + .span(expression.span) + } +} + +impl<'ast> From> for types::ExpressionNode<'ast> { + fn from(expression: ast::Expression<'ast>) -> Self { + match expression { + ast::Expression::Value(expression) => types::ExpressionNode::from(expression), + ast::Expression::Variable(expression) => types::ExpressionNode::from(expression), + ast::Expression::Not(expression) => types::ExpressionNode::from(expression), + ast::Expression::Binary(expression) => types::ExpressionNode::from(expression), + } + } +} + +impl<'ast> From> for types::StatementNode<'ast> { + fn from(statement: ast::AssignStatement<'ast>) -> Self { + types::Statement::Definition( + types::VariableNode::from(statement.variable), + types::ExpressionNode::from(statement.expression), + ) + .span(statement.span) + } +} + +impl<'ast> From> for types::StatementNode<'ast> { + fn from(statement: ast::ReturnStatement<'ast>) -> Self { + types::Statement::Return( + types::ExpressionList { + expressions: statement + .expressions + .into_iter() + .map(|expression| types::ExpressionNode::from(expression)) + .collect(), + } + .span(statement.span.clone()), + ) + .span(statement.span) + } +} + +impl<'ast> From> for types::StatementNode<'ast> { + fn from(statement: ast::Statement<'ast>) -> Self { + match statement { + ast::Statement::Assign(statement) => types::StatementNode::from(statement), + ast::Statement::Return(statement) => types::StatementNode::from(statement), + } + } +} + +impl<'ast> From> for program::Program<'ast> { + fn from(file: ast::File<'ast>) -> Self { + program::Program { + nodes: file + .statement + .iter() + .map(|statement| types::StatementNode::from(statement.clone())) + .collect(), + } + // for statement in file.statement { + // // println!("statement {:?}", statement); + // let node = program::StatementNode::from(statement); + // println!("node {:?}", node); + // } + } +}