diff --git a/Cargo.lock b/Cargo.lock index 902274c662..4bcf7231ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1289,6 +1289,7 @@ dependencies = [ "serde", "serde_json", "thiserror", + "typed-arena", "uuid", ] @@ -1414,7 +1415,7 @@ dependencies = [ "notify", "num-bigint", "rand", - "rand_core 0.6.1", + "rand_core 0.6.2", "reqwest", "rusty-hook", "self_update", @@ -2175,9 +2176,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c026d7df8b298d90ccbbc5190bd04d85e159eaf5576caeacf8741da93ccbd2e5" +checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7" [[package]] name = "rand_hc" @@ -3149,6 +3150,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "typed-arena" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0685c84d5d54d1c26f7d3eb96cd41550adb97baed141a761cf335d3d33bcd0ae" + [[package]] name = "typenum" version = "1.12.0" @@ -3442,9 +3449,9 @@ dependencies = [ [[package]] name = "zip" -version = "0.5.9" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2896475a242c41366941faa27264df2cb935185a92e059a004d0048feb2ac5" +checksum = "5a8977234acab718eb2820494b2f96cbb16004c19dddf88b7445b27381450997" dependencies = [ "byteorder", "bzip2", diff --git a/Cargo.toml b/Cargo.toml index 334bde9ef3..757570cba3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ members = [ "input", "linter", "package", - "state" + "state", ] [dependencies.leo-ast] @@ -123,7 +123,7 @@ version = "0.3" version = "0.7" [dependencies.rand_core] -version = "0.6.1" +version = "0.6.2" [dependencies.reqwest] version = "0.11.0" diff --git a/asg/Cargo.toml b/asg/Cargo.toml index 6af3f5d876..b6d7fbb219 100644 --- a/asg/Cargo.toml +++ b/asg/Cargo.toml @@ -44,5 +44,8 @@ features = [ "v4", "serde" ] [dependencies.num-bigint] version = "0.3" +[dependencies.typed-arena] +version = "2.0" + [dev-dependencies.criterion] version = "0.3" diff --git a/asg/src/checks/return_path.rs b/asg/src/checks/return_path.rs index 15be40c2ed..8a21f0514e 100644 --- a/asg/src/checks/return_path.rs +++ b/asg/src/checks/return_path.rs @@ -25,8 +25,6 @@ use crate::{ Span, }; -use std::sync::Arc; - pub struct ReturnPathReducer { pub errors: Vec<(Span, String)>, } @@ -48,14 +46,14 @@ impl Default for ReturnPathReducer { } #[allow(unused_variables)] -impl MonoidalReducerExpression for ReturnPathReducer { - fn reduce_expression(&mut self, input: &Arc, value: BoolAnd) -> BoolAnd { +impl<'a> MonoidalReducerExpression<'a, BoolAnd> for ReturnPathReducer { + fn reduce_expression(&mut self, input: &'a Expression<'a>, value: BoolAnd) -> BoolAnd { BoolAnd(false) } } #[allow(unused_variables)] -impl MonoidalReducerStatement for ReturnPathReducer { +impl<'a> MonoidalReducerStatement<'a, BoolAnd> for ReturnPathReducer { fn reduce_assign_access(&mut self, input: &AssignAccess, left: Option, right: Option) -> BoolAnd { BoolAnd(false) } @@ -69,7 +67,7 @@ impl MonoidalReducerStatement for ReturnPathReducer { BoolAnd(false) } else if let Some(index) = statements[..statements.len() - 1].iter().map(|x| x.0).position(|x| x) { self.record_error( - input.statements[index].span(), + input.statements[index].get().span(), "dead code due to unconditional early return".to_string(), ); BoolAnd(true) diff --git a/asg/src/const_value.rs b/asg/src/const_value.rs index f022afbef6..bab4aa5b3b 100644 --- a/asg/src/const_value.rs +++ b/asg/src/const_value.rs @@ -226,7 +226,7 @@ impl ConstInt { } } - pub fn get_type(&self) -> Type { + pub fn get_type<'a>(&self) -> Type<'a> { Type::Integer(self.get_int_type()) } @@ -247,7 +247,7 @@ impl ConstInt { } impl ConstValue { - pub fn get_type(&self) -> Option { + pub fn get_type<'a>(&self) -> Option> { Some(match self { ConstValue::Int(i) => i.get_type(), ConstValue::Group(_) => Type::Group, diff --git a/asg/src/expression/array_access.rs b/asg/src/expression/array_access.rs index 01e9b22855..3bbbd7909d 100644 --- a/asg/src/expression/array_access.rs +++ b/asg/src/expression/array_access.rs @@ -17,56 +17,53 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; use leo_ast::IntegerType; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct ArrayAccessExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct ArrayAccessExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub array: Arc, - pub index: Arc, + pub array: Cell<&'a Expression<'a>>, + pub index: Cell<&'a Expression<'a>>, } -impl Node for ArrayAccessExpression { +impl<'a> Node for ArrayAccessExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for ArrayAccessExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for ArrayAccessExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.array.set_parent(Arc::downgrade(expr)); - self.index.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.array.get().set_parent(expr); + self.index.get().set_parent(expr); } - fn get_type(&self) -> Option { - match self.array.get_type() { + fn get_type(&self) -> Option> { + match self.array.get().get_type() { Some(Type::Array(element, _)) => Some(*element), _ => None, } } fn is_mut_ref(&self) -> bool { - self.array.is_mut_ref() + self.array.get().is_mut_ref() } fn const_value(&self) -> Option { - let mut array = match self.array.const_value()? { + let mut array = match self.array.get().const_value()? { ConstValue::Array(values) => values, _ => return None, }; - let const_index = match self.index.const_value()? { + let const_index = match self.index.get().const_value()? { ConstValue::Int(x) => x.to_usize()?, _ => return None, }; @@ -77,17 +74,17 @@ impl ExpressionNode for ArrayAccessExpression { } fn is_consty(&self) -> bool { - self.array.is_consty() + self.array.get().is_consty() } } -impl FromAst for ArrayAccessExpression { +impl<'a> FromAst<'a, leo_ast::ArrayAccessExpression> for ArrayAccessExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::ArrayAccessExpression, - expected_type: Option, - ) -> Result { - let array = Arc::::from_ast( + expected_type: Option>, + ) -> Result, AsgConvertError> { + let array = <&Expression<'a>>::from_ast( scope, &*value.array, Some(PartialType::Array(expected_type.map(Box::new), None)), @@ -103,7 +100,7 @@ impl FromAst for ArrayAccessExpression { } } - let index = Arc::::from_ast( + let index = <&Expression<'a>>::from_ast( scope, &*value.index, Some(PartialType::Integer(None, Some(IntegerType::U32))), @@ -116,19 +113,19 @@ impl FromAst for ArrayAccessExpression { } Ok(ArrayAccessExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - array, - index, + array: Cell::new(array), + index: Cell::new(index), }) } } -impl Into for &ArrayAccessExpression { +impl<'a> Into for &ArrayAccessExpression<'a> { fn into(self) -> leo_ast::ArrayAccessExpression { leo_ast::ArrayAccessExpression { - array: Box::new(self.array.as_ref().into()), - index: Box::new(self.index.as_ref().into()), + array: Box::new(self.array.get().into()), + index: Box::new(self.index.get().into()), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/array_init.rs b/asg/src/expression/array_init.rs index 73033884ff..063485b4c4 100644 --- a/asg/src/expression/array_init.rs +++ b/asg/src/expression/array_init.rs @@ -16,40 +16,37 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct ArrayInitExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct ArrayInitExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub element: Arc, + pub element: Cell<&'a Expression<'a>>, pub len: usize, } -impl Node for ArrayInitExpression { +impl<'a> Node for ArrayInitExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for ArrayInitExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for ArrayInitExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.element.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.element.get().set_parent(expr); } - fn get_type(&self) -> Option { - Some(Type::Array(Box::new(self.element.get_type()?), self.len)) + fn get_type(&self) -> Option> { + Some(Type::Array(Box::new(self.element.get().get_type()?), self.len)) } fn is_mut_ref(&self) -> bool { @@ -62,16 +59,16 @@ impl ExpressionNode for ArrayInitExpression { } fn is_consty(&self) -> bool { - self.element.is_consty() + self.element.get().is_consty() } } -impl FromAst for ArrayInitExpression { +impl<'a> FromAst<'a, leo_ast::ArrayInitExpression> for ArrayInitExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::ArrayInitExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let (mut expected_item, expected_len) = match expected_type { Some(PartialType::Array(item, dims)) => (item.map(|x| *x), dims), None => (None, None), @@ -130,17 +127,19 @@ impl FromAst for ArrayInitExpression { } } } - let mut element = Some(Arc::::from_ast(scope, &*value.element, expected_item)?); + let mut element = Some(<&'a Expression<'a>>::from_ast(scope, &*value.element, expected_item)?); let mut output = None; for dimension in dimensions.iter().rev().copied() { output = Some(ArrayInitExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - element: output - .map(Expression::ArrayInit) - .map(Arc::new) - .unwrap_or_else(|| element.take().unwrap()), + element: Cell::new( + output + .map(Expression::ArrayInit) + .map(|expr| &*scope.alloc_expression(expr)) + .unwrap_or_else(|| element.take().unwrap()), + ), len: dimension, }); } @@ -148,10 +147,10 @@ impl FromAst for ArrayInitExpression { } } -impl Into for &ArrayInitExpression { +impl<'a> Into for &ArrayInitExpression<'a> { fn into(self) -> leo_ast::ArrayInitExpression { leo_ast::ArrayInitExpression { - element: Box::new(self.element.as_ref().into()), + element: Box::new(self.element.get().into()), dimensions: leo_ast::ArrayDimensions(vec![leo_ast::PositiveNumber { value: self.len.to_string(), }]), diff --git a/asg/src/expression/array_inline.rs b/asg/src/expression/array_inline.rs index ad4de160fd..5ee2167547 100644 --- a/asg/src/expression/array_inline.rs +++ b/asg/src/expression/array_inline.rs @@ -17,25 +17,22 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; use leo_ast::SpreadOrExpression; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct ArrayInlineExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct ArrayInlineExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub elements: Vec<(Arc, bool)>, // bool = if spread + pub elements: Vec<(Cell<&'a Expression<'a>>, bool)>, // bool = if spread } -impl ArrayInlineExpression { +impl<'a> ArrayInlineExpression<'a> { pub fn expanded_length(&self) -> usize { self.elements .iter() .map(|(expr, is_spread)| { if *is_spread { - match expr.get_type() { + match expr.get().get_type() { Some(Type::Array(_item, len)) => len, _ => 0, } @@ -47,30 +44,30 @@ impl ArrayInlineExpression { } } -impl Node for ArrayInlineExpression { +impl<'a> Node for ArrayInlineExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for ArrayInlineExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for ArrayInlineExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { + fn enforce_parents(&self, expr: &'a Expression<'a>) { self.elements.iter().for_each(|(element, _)| { - element.set_parent(Arc::downgrade(expr)); + element.get().set_parent(expr); }) } - fn get_type(&self) -> Option { + fn get_type(&self) -> Option> { Some(Type::Array( - Box::new(self.elements.first()?.0.get_type()?), + Box::new(self.elements.first()?.0.get().get_type()?), self.expanded_length(), )) } @@ -83,28 +80,28 @@ impl ExpressionNode for ArrayInlineExpression { let mut const_values = vec![]; for (expr, spread) in self.elements.iter() { if *spread { - match expr.const_value()? { + match expr.get().const_value()? { ConstValue::Array(items) => const_values.extend(items), _ => return None, } } else { - const_values.push(expr.const_value()?); + const_values.push(expr.get().const_value()?); } } Some(ConstValue::Array(const_values)) } fn is_consty(&self) -> bool { - self.elements.iter().all(|x| x.0.is_consty()) + self.elements.iter().all(|x| x.0.get().is_consty()) } } -impl FromAst for ArrayInlineExpression { +impl<'a> FromAst<'a, leo_ast::ArrayInlineExpression> for ArrayInlineExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::ArrayInlineExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let (mut expected_item, expected_len) = match expected_type { Some(PartialType::Array(item, dims)) => (item.map(|x| *x), dims), None => (None, None), @@ -119,22 +116,22 @@ impl FromAst for ArrayInlineExpression { let mut len = 0; let output = ArrayInlineExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), elements: value .elements .iter() .map(|e| match e { SpreadOrExpression::Expression(e) => { - let expr = Arc::::from_ast(scope, e, expected_item.clone())?; + let expr = <&Expression<'a>>::from_ast(scope, e, expected_item.clone())?; if expected_item.is_none() { expected_item = expr.get_type().map(Type::partial); } len += 1; - Ok((expr, false)) + Ok((Cell::new(expr), false)) } SpreadOrExpression::Spread(e) => { - let expr = Arc::::from_ast( + let expr = <&Expression<'a>>::from_ast( scope, e, Some(PartialType::Array(expected_item.clone().map(Box::new), None)), @@ -160,7 +157,7 @@ impl FromAst for ArrayInlineExpression { )); } } - Ok((expr, true)) + Ok((Cell::new(expr), true)) } }) .collect::, AsgConvertError>>()?, @@ -178,14 +175,14 @@ impl FromAst for ArrayInlineExpression { } } -impl Into for &ArrayInlineExpression { +impl<'a> Into for &ArrayInlineExpression<'a> { fn into(self) -> leo_ast::ArrayInlineExpression { leo_ast::ArrayInlineExpression { elements: self .elements .iter() .map(|(element, spread)| { - let element = element.as_ref().into(); + let element = element.get().into(); if *spread { SpreadOrExpression::Spread(element) } else { diff --git a/asg/src/expression/array_range_access.rs b/asg/src/expression/array_range_access.rs index 4ca809b720..f070116aa0 100644 --- a/asg/src/expression/array_range_access.rs +++ b/asg/src/expression/array_range_access.rs @@ -17,57 +17,54 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; use leo_ast::IntegerType; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct ArrayRangeAccessExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct ArrayRangeAccessExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub array: Arc, - pub left: Option>, - pub right: Option>, + pub array: Cell<&'a Expression<'a>>, + pub left: Cell>>, + pub right: Cell>>, } -impl Node for ArrayRangeAccessExpression { +impl<'a> Node for ArrayRangeAccessExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for ArrayRangeAccessExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for ArrayRangeAccessExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.array.set_parent(Arc::downgrade(expr)); - self.array.enforce_parents(&self.array); - if let Some(left) = self.left.as_ref() { - left.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.array.get().set_parent(expr); + self.array.get().enforce_parents(self.array.get()); + if let Some(left) = self.left.get() { + left.set_parent(expr); } - if let Some(right) = self.right.as_ref() { - right.set_parent(Arc::downgrade(expr)); + if let Some(right) = self.right.get() { + right.set_parent(expr); } } - fn get_type(&self) -> Option { - let (element, array_len) = match self.array.get_type() { + fn get_type(&self) -> Option> { + let (element, array_len) = match self.array.get().get_type() { Some(Type::Array(element, len)) => (element, len), _ => return None, }; - let const_left = match self.left.as_ref().map(|x| x.const_value()) { + let const_left = match self.left.get().map(|x| x.const_value()) { Some(Some(ConstValue::Int(x))) => x.to_usize()?, None => 0, _ => return None, }; - let const_right = match self.right.as_ref().map(|x| x.const_value()) { + let const_right = match self.right.get().map(|x| x.const_value()) { Some(Some(ConstValue::Int(x))) => x.to_usize()?, None => array_len, _ => return None, @@ -80,20 +77,20 @@ impl ExpressionNode for ArrayRangeAccessExpression { } fn is_mut_ref(&self) -> bool { - self.array.is_mut_ref() + self.array.get().is_mut_ref() } fn const_value(&self) -> Option { - let mut array = match self.array.const_value()? { + let mut array = match self.array.get().const_value()? { ConstValue::Array(values) => values, _ => return None, }; - let const_left = match self.left.as_ref().map(|x| x.const_value()) { + let const_left = match self.left.get().map(|x| x.const_value()) { Some(Some(ConstValue::Int(x))) => x.to_usize()?, None => 0, _ => return None, }; - let const_right = match self.right.as_ref().map(|x| x.const_value()) { + let const_right = match self.right.get().map(|x| x.const_value()) { Some(Some(ConstValue::Int(x))) => x.to_usize()?, None => array.len(), _ => return None, @@ -106,16 +103,16 @@ impl ExpressionNode for ArrayRangeAccessExpression { } fn is_consty(&self) -> bool { - self.array.is_consty() + self.array.get().is_consty() } } -impl FromAst for ArrayRangeAccessExpression { +impl<'a> FromAst<'a, leo_ast::ArrayRangeAccessExpression> for ArrayRangeAccessExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::ArrayRangeAccessExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let expected_array = match expected_type { Some(PartialType::Array(element, _len)) => Some(PartialType::Array(element, None)), None => None, @@ -127,7 +124,7 @@ impl FromAst for ArrayRangeAccessExpression )); } }; - let array = Arc::::from_ast(scope, &*value.array, expected_array)?; + let array = <&Expression<'a>>::from_ast(scope, &*value.array, expected_array)?; let array_type = array.get_type(); match array_type { Some(Type::Array(_, _)) => (), @@ -143,14 +140,14 @@ impl FromAst for ArrayRangeAccessExpression .left .as_deref() .map(|left| { - Arc::::from_ast(scope, left, Some(PartialType::Integer(None, Some(IntegerType::U32)))) + <&Expression<'a>>::from_ast(scope, left, Some(PartialType::Integer(None, Some(IntegerType::U32)))) }) .transpose()?; let right = value .right .as_deref() .map(|right| { - Arc::::from_ast(scope, right, Some(PartialType::Integer(None, Some(IntegerType::U32)))) + <&Expression<'a>>::from_ast(scope, right, Some(PartialType::Integer(None, Some(IntegerType::U32)))) }) .transpose()?; @@ -169,21 +166,21 @@ impl FromAst for ArrayRangeAccessExpression } } Ok(ArrayRangeAccessExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - array, - left, - right, + array: Cell::new(array), + left: Cell::new(left), + right: Cell::new(right), }) } } -impl Into for &ArrayRangeAccessExpression { +impl<'a> Into for &ArrayRangeAccessExpression<'a> { fn into(self) -> leo_ast::ArrayRangeAccessExpression { leo_ast::ArrayRangeAccessExpression { - array: Box::new(self.array.as_ref().into()), - left: self.left.as_ref().map(|left| Box::new(left.as_ref().into())), - right: self.right.as_ref().map(|right| Box::new(right.as_ref().into())), + array: Box::new(self.array.get().into()), + left: self.left.get().map(|left| Box::new(left.into())), + right: self.right.get().map(|right| Box::new(right.into())), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/binary.rs b/asg/src/expression/binary.rs index 3c4bfc207c..9322b1d587 100644 --- a/asg/src/expression/binary.rs +++ b/asg/src/expression/binary.rs @@ -17,44 +17,41 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; pub use leo_ast::{BinaryOperation, BinaryOperationClass}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct BinaryExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct BinaryExpression<'a> { + pub parent: Cell>>, pub span: Option, pub operation: BinaryOperation, - pub left: Arc, - pub right: Arc, + pub left: Cell<&'a Expression<'a>>, + pub right: Cell<&'a Expression<'a>>, } -impl Node for BinaryExpression { +impl<'a> Node for BinaryExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for BinaryExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for BinaryExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.left.set_parent(Arc::downgrade(expr)); - self.right.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.left.get().set_parent(expr); + self.right.get().set_parent(expr); } - fn get_type(&self) -> Option { + fn get_type(&self) -> Option> { match self.operation.class() { BinaryOperationClass::Boolean => Some(Type::Boolean), - BinaryOperationClass::Numeric => self.left.get_type(), + BinaryOperationClass::Numeric => self.left.get().get_type(), } } @@ -64,8 +61,8 @@ impl ExpressionNode for BinaryExpression { fn const_value(&self) -> Option { use BinaryOperation::*; - let left = self.left.const_value()?; - let right = self.right.const_value()?; + let left = self.left.get().const_value()?; + let right = self.right.get().const_value()?; match (left, right) { (ConstValue::Int(left), ConstValue::Int(right)) => Some(match self.operation { @@ -110,16 +107,16 @@ impl ExpressionNode for BinaryExpression { } fn is_consty(&self) -> bool { - self.left.is_consty() && self.right.is_consty() + self.left.get().is_consty() && self.right.get().is_consty() } } -impl FromAst for BinaryExpression { +impl<'a> FromAst<'a, leo_ast::BinaryExpression> for BinaryExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::BinaryExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let class = value.op.class(); let expected_type = match class { BinaryOperationClass::Boolean => match expected_type { @@ -148,16 +145,16 @@ impl FromAst for BinaryExpression { }; // left - let (left, right) = match Arc::::from_ast(scope, &*value.left, expected_type.clone()) { + let (left, right) = match <&Expression<'a>>::from_ast(scope, &*value.left, expected_type.clone()) { Ok(left) => { if let Some(left_type) = left.get_type() { - let right = Arc::::from_ast(scope, &*value.right, Some(left_type.partial()))?; + let right = <&Expression<'a>>::from_ast(scope, &*value.right, Some(left_type.partial()))?; (left, right) } else { - let right = Arc::::from_ast(scope, &*value.right, expected_type)?; + let right = <&Expression<'a>>::from_ast(scope, &*value.right, expected_type)?; if let Some(right_type) = right.get_type() { ( - Arc::::from_ast(scope, &*value.left, Some(right_type.partial()))?, + <&Expression<'a>>::from_ast(scope, &*value.left, Some(right_type.partial()))?, right, ) } else { @@ -166,10 +163,10 @@ impl FromAst for BinaryExpression { } } Err(e) => { - let right = Arc::::from_ast(scope, &*value.right, expected_type)?; + let right = <&Expression<'a>>::from_ast(scope, &*value.right, expected_type)?; if let Some(right_type) = right.get_type() { ( - Arc::::from_ast(scope, &*value.left, Some(right_type.partial()))?, + <&Expression<'a>>::from_ast(scope, &*value.left, Some(right_type.partial()))?, right, ) } else { @@ -244,21 +241,21 @@ impl FromAst for BinaryExpression { (_, _) => (), } Ok(BinaryExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), operation: value.op.clone(), - left, - right, + left: Cell::new(left), + right: Cell::new(right), }) } } -impl Into for &BinaryExpression { +impl<'a> Into for &BinaryExpression<'a> { fn into(self) -> leo_ast::BinaryExpression { leo_ast::BinaryExpression { op: self.operation.clone(), - left: Box::new(self.left.as_ref().into()), - right: Box::new(self.right.as_ref().into()), + left: Box::new(self.left.get().into()), + right: Box::new(self.right.get().into()), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/call.rs b/asg/src/expression/call.rs index 306301576c..52572a4907 100644 --- a/asg/src/expression/call.rs +++ b/asg/src/expression/call.rs @@ -31,46 +31,43 @@ use crate::{ }; pub use leo_ast::{BinaryOperation, Node as AstNode}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct CallExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct CallExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub function: Arc, - pub target: Option>, - pub arguments: Vec>, + pub function: Cell<&'a Function<'a>>, + pub target: Cell>>, + pub arguments: Vec>>, } -impl Node for CallExpression { +impl<'a> Node for CallExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for CallExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for CallExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - if let Some(target) = self.target.as_ref() { - target.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + if let Some(target) = self.target.get() { + target.set_parent(expr); } self.arguments.iter().for_each(|element| { - element.set_parent(Arc::downgrade(expr)); + element.get().set_parent(expr); }) } - fn get_type(&self) -> Option { - Some(self.function.output.clone().into()) + fn get_type(&self) -> Option> { + Some(self.function.get().output.clone()) } fn is_mut_ref(&self) -> bool { @@ -83,21 +80,20 @@ impl ExpressionNode for CallExpression { } fn is_consty(&self) -> bool { - self.target.as_ref().map(|x| x.is_consty()).unwrap_or(true) && self.arguments.iter().all(|x| x.is_consty()) + self.target.get().map(|x| x.is_consty()).unwrap_or(true) && self.arguments.iter().all(|x| x.get().is_consty()) } } -impl FromAst for CallExpression { +impl<'a> FromAst<'a, leo_ast::CallExpression> for CallExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::CallExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let (target, function) = match &*value.function { leo_ast::Expression::Identifier(name) => ( None, scope - .borrow() .resolve_function(&name.name) .ok_or_else(|| AsgConvertError::unresolved_function(&name.name, &name.span))?, ), @@ -106,7 +102,7 @@ impl FromAst for CallExpression { name, span, }) => { - let target = Arc::::from_ast(scope, &**ast_circuit, None)?; + let target = <&Expression<'a>>::from_ast(scope, &**ast_circuit, None)?; let circuit = match target.get_type() { Some(Type::Circuit(circuit)) => circuit, type_ => { @@ -137,7 +133,7 @@ impl FromAst for CallExpression { &span, )); } - (Some(target), body.clone()) + (Some(target), *body) } CircuitMember::Variable(_) => { return Err(AsgConvertError::circuit_variable_call(&circuit_name, &name.name, &span)); @@ -151,7 +147,6 @@ impl FromAst for CallExpression { }) => { let circuit = if let leo_ast::Expression::Identifier(circuit_name) = &**ast_circuit { scope - .borrow() .resolve_circuit(&circuit_name.name) .ok_or_else(|| AsgConvertError::unresolved_circuit(&circuit_name.name, &circuit_name.span))? } else { @@ -172,7 +167,7 @@ impl FromAst for CallExpression { &span, )); } - (None, body.clone()) + (None, *body) } CircuitMember::Variable(_) => { return Err(AsgConvertError::circuit_variable_call(&circuit_name, &name.name, &span)); @@ -186,7 +181,7 @@ impl FromAst for CallExpression { } }; if let Some(expected) = expected_type { - let output: Type = function.output.clone().into(); + let output: Type = function.output.clone(); if !expected.matches(&output) { return Err(AsgConvertError::unexpected_type( &expected.to_string(), @@ -207,46 +202,45 @@ impl FromAst for CallExpression { .arguments .iter() .zip(function.arguments.iter()) - .map(|(expr, argument)| { - let argument = argument.borrow(); - let converted = - Arc::::from_ast(scope, expr, Some(argument.type_.clone().strong().partial()))?; + .map(|(expr, (_, argument))| { + let argument = argument.get().borrow(); + let converted = <&Expression<'a>>::from_ast(scope, expr, Some(argument.type_.clone().partial()))?; if argument.const_ && !converted.is_consty() { return Err(AsgConvertError::unexpected_nonconst(&expr.span())); } - Ok(converted) + Ok(Cell::new(converted)) }) .collect::, AsgConvertError>>()?; Ok(CallExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), arguments, - function, - target, + function: Cell::new(function), + target: Cell::new(target), }) } } -impl Into for &CallExpression { +impl<'a> Into for &CallExpression<'a> { fn into(self) -> leo_ast::CallExpression { - let target_function = if let Some(target) = &self.target { - target.as_ref().into() + let target_function = if let Some(target) = self.target.get() { + target.into() } else { - let circuit = self.function.circuit.borrow().as_ref().map(|x| x.upgrade()).flatten(); + let circuit = self.function.get().circuit.get(); if let Some(circuit) = circuit { leo_ast::Expression::CircuitStaticFunctionAccess(leo_ast::CircuitStaticFunctionAccessExpression { circuit: Box::new(leo_ast::Expression::Identifier(circuit.name.borrow().clone())), - name: self.function.name.borrow().clone(), + name: self.function.get().name.borrow().clone(), span: self.span.clone().unwrap_or_default(), }) } else { - leo_ast::Expression::Identifier(self.function.name.borrow().clone()) + leo_ast::Expression::Identifier(self.function.get().name.borrow().clone()) } }; leo_ast::CallExpression { function: Box::new(target_function), - arguments: self.arguments.iter().map(|arg| arg.as_ref().into()).collect(), + arguments: self.arguments.iter().map(|arg| arg.get().into()).collect(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/circuit_access.rs b/asg/src/expression/circuit_access.rs index f7e0a07139..739579cabd 100644 --- a/asg/src/expression/circuit_access.rs +++ b/asg/src/expression/circuit_access.rs @@ -18,7 +18,6 @@ use crate::{ AsgConvertError, Circuit, CircuitMember, - CircuitMemberBody, ConstValue, Expression, ExpressionNode, @@ -31,56 +30,53 @@ use crate::{ Type, }; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct CircuitAccessExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct CircuitAccessExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub circuit: Arc, - pub target: Option>, + pub circuit: Cell<&'a Circuit<'a>>, + pub target: Cell>>, pub member: Identifier, } -impl Node for CircuitAccessExpression { +impl<'a> Node for CircuitAccessExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for CircuitAccessExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for CircuitAccessExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - if let Some(target) = self.target.as_ref() { - target.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + if let Some(target) = self.target.get() { + target.set_parent(expr); } } - fn get_type(&self) -> Option { - if self.target.is_none() { + fn get_type(&self) -> Option> { + if self.target.get().is_none() { None // function target only for static } else { - let members = self.circuit.members.borrow(); + let members = self.circuit.get().members.borrow(); let member = members.get(&self.member.name)?; match member { - CircuitMember::Variable(type_) => Some(type_.clone().into()), + CircuitMember::Variable(type_) => Some(type_.clone()), CircuitMember::Function(_) => None, } } } fn is_mut_ref(&self) -> bool { - if let Some(target) = self.target.as_ref() { + if let Some(target) = self.target.get() { target.is_mut_ref() } else { false @@ -92,17 +88,17 @@ impl ExpressionNode for CircuitAccessExpression { } fn is_consty(&self) -> bool { - self.target.as_ref().map(|x| x.is_consty()).unwrap_or(true) + self.target.get().map(|x| x.is_consty()).unwrap_or(true) } } -impl FromAst for CircuitAccessExpression { +impl<'a> FromAst<'a, leo_ast::CircuitMemberAccessExpression> for CircuitAccessExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::CircuitMemberAccessExpression, - expected_type: Option, - ) -> Result { - let target = Arc::::from_ast(scope, &*value.circuit, None)?; + expected_type: Option>, + ) -> Result, AsgConvertError> { + let target = <&'a Expression<'a>>::from_ast(scope, &*value.circuit, None)?; let circuit = match target.get_type() { Some(Type::Circuit(circuit)) => circuit, x => { @@ -119,7 +115,7 @@ impl FromAst for CircuitAccessExpression if let Some(member) = circuit.members.borrow().get(&value.name.name) { if let Some(expected_type) = &expected_type { if let CircuitMember::Variable(type_) = &member { - let type_: Type = type_.clone().into(); + let type_: Type = type_.clone(); if !expected_type.matches(&type_) { return Err(AsgConvertError::unexpected_type( &expected_type.to_string(), @@ -140,15 +136,10 @@ impl FromAst for CircuitAccessExpression } else if circuit.is_input_pseudo_circuit() { // add new member to implicit input if let Some(expected_type) = expected_type.map(PartialType::full).flatten() { - circuit.members.borrow_mut().insert( - value.name.name.clone(), - CircuitMember::Variable(expected_type.clone().into()), - ); - let body = circuit.body.borrow().upgrade().expect("stale input circuit body"); - - body.members + circuit + .members .borrow_mut() - .insert(value.name.name.clone(), CircuitMemberBody::Variable(expected_type)); + .insert(value.name.name.clone(), CircuitMember::Variable(expected_type.clone())); } else { return Err(AsgConvertError::input_ref_needs_type( &circuit.name.borrow().name, @@ -165,24 +156,23 @@ impl FromAst for CircuitAccessExpression } Ok(CircuitAccessExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - target: Some(target), - circuit, + target: Cell::new(Some(target)), + circuit: Cell::new(circuit), member: value.name.clone(), }) } } -impl FromAst for CircuitAccessExpression { +impl<'a> FromAst<'a, leo_ast::CircuitStaticFunctionAccessExpression> for CircuitAccessExpression<'a> { fn from_ast( - scope: &Scope, + scope: &Scope<'a>, value: &leo_ast::CircuitStaticFunctionAccessExpression, expected_type: Option, - ) -> Result { + ) -> Result, AsgConvertError> { let circuit = match &*value.circuit { leo_ast::Expression::Identifier(name) => scope - .borrow() .resolve_circuit(&name.name) .ok_or_else(|| AsgConvertError::unresolved_circuit(&name.name, &name.span))?, _ => { @@ -213,26 +203,28 @@ impl FromAst for CircuitAccessEx } Ok(CircuitAccessExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - target: None, - circuit, + target: Cell::new(None), + circuit: Cell::new(circuit), member: value.name.clone(), }) } } -impl Into for &CircuitAccessExpression { +impl<'a> Into for &CircuitAccessExpression<'a> { fn into(self) -> leo_ast::Expression { - if let Some(target) = self.target.as_ref() { + if let Some(target) = self.target.get() { leo_ast::Expression::CircuitMemberAccess(leo_ast::CircuitMemberAccessExpression { - circuit: Box::new(target.as_ref().into()), + circuit: Box::new(target.into()), name: self.member.clone(), span: self.span.clone().unwrap_or_default(), }) } else { leo_ast::Expression::CircuitStaticFunctionAccess(leo_ast::CircuitStaticFunctionAccessExpression { - circuit: Box::new(leo_ast::Expression::Identifier(self.circuit.name.borrow().clone())), + circuit: Box::new(leo_ast::Expression::Identifier( + self.circuit.get().name.borrow().clone(), + )), name: self.member.clone(), span: self.span.clone().unwrap_or_default(), }) diff --git a/asg/src/expression/circuit_init.rs b/asg/src/expression/circuit_init.rs index c05dbe6ae8..b7ec7a764b 100644 --- a/asg/src/expression/circuit_init.rs +++ b/asg/src/expression/circuit_init.rs @@ -31,42 +31,39 @@ use crate::{ }; use indexmap::{IndexMap, IndexSet}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct CircuitInitExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct CircuitInitExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub circuit: Arc, - pub values: Vec<(Identifier, Arc)>, + pub circuit: Cell<&'a Circuit<'a>>, + pub values: Vec<(Identifier, Cell<&'a Expression<'a>>)>, } -impl Node for CircuitInitExpression { +impl<'a> Node for CircuitInitExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for CircuitInitExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for CircuitInitExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { + fn enforce_parents(&self, expr: &'a Expression<'a>) { self.values.iter().for_each(|(_, element)| { - element.set_parent(Arc::downgrade(expr)); + element.get().set_parent(expr); }) } - fn get_type(&self) -> Option { - Some(Type::Circuit(self.circuit.clone())) + fn get_type(&self) -> Option> { + Some(Type::Circuit(self.circuit.get())) } fn is_mut_ref(&self) -> bool { @@ -78,18 +75,17 @@ impl ExpressionNode for CircuitInitExpression { } fn is_consty(&self) -> bool { - self.values.iter().all(|(_, value)| value.is_consty()) + self.values.iter().all(|(_, value)| value.get().is_consty()) } } -impl FromAst for CircuitInitExpression { +impl<'a> FromAst<'a, leo_ast::CircuitInitExpression> for CircuitInitExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::CircuitInitExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let circuit = scope - .borrow() .resolve_circuit(&value.name.name) .ok_or_else(|| AsgConvertError::unresolved_circuit(&value.name.name, &value.name.span))?; match expected_type { @@ -109,7 +105,7 @@ impl FromAst for CircuitInitExpression { .map(|x| (&x.identifier.name, (&x.identifier, &x.expression))) .collect(); - let mut values: Vec<(Identifier, Arc)> = vec![]; + let mut values: Vec<(Identifier, Cell<&'a Expression<'a>>)> = vec![]; let mut defined_variables = IndexSet::::new(); { @@ -124,13 +120,13 @@ impl FromAst for CircuitInitExpression { } defined_variables.insert(name.clone()); let type_: Type = if let CircuitMember::Variable(type_) = &member { - type_.clone().into() + type_.clone() } else { continue; }; if let Some((identifier, receiver)) = members.get(&name) { - let received = Arc::::from_ast(scope, *receiver, Some(type_.partial()))?; - values.push(((*identifier).clone(), received)); + let received = <&Expression<'a>>::from_ast(scope, *receiver, Some(type_.partial()))?; + values.push(((*identifier).clone(), Cell::new(received))); } else { return Err(AsgConvertError::missing_circuit_member( &circuit.name.borrow().name, @@ -152,24 +148,24 @@ impl FromAst for CircuitInitExpression { } Ok(CircuitInitExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - circuit, + circuit: Cell::new(circuit), values, }) } } -impl Into for &CircuitInitExpression { +impl<'a> Into for &CircuitInitExpression<'a> { fn into(self) -> leo_ast::CircuitInitExpression { leo_ast::CircuitInitExpression { - name: self.circuit.name.borrow().clone(), + name: self.circuit.get().name.borrow().clone(), members: self .values .iter() .map(|(name, value)| leo_ast::CircuitImpliedVariableDefinition { identifier: name.clone(), - expression: value.as_ref().into(), + expression: value.get().into(), }) .collect(), span: self.span.clone().unwrap_or_default(), diff --git a/asg/src/expression/constant.rs b/asg/src/expression/constant.rs index e2cfe82512..ae2a1fdd13 100644 --- a/asg/src/expression/constant.rs +++ b/asg/src/expression/constant.rs @@ -29,36 +29,33 @@ use crate::{ Type, }; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct Constant { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct Constant<'a> { + pub parent: Cell>>, pub span: Option, pub value: ConstValue, // should not be compound constants } -impl Node for Constant { +impl<'a> Node for Constant<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for Constant { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for Constant<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, _expr: &Arc) {} + fn enforce_parents(&self, _expr: &'a Expression<'a>) {} - fn get_type(&self) -> Option { + fn get_type(&self) -> Option> { self.value.get_type() } @@ -75,12 +72,12 @@ impl ExpressionNode for Constant { } } -impl FromAst for Constant { +impl<'a> FromAst<'a, leo_ast::ValueExpression> for Constant<'a> { fn from_ast( - _scope: &Scope, + _scope: &'a Scope<'a>, value: &leo_ast::ValueExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { use leo_ast::ValueExpression::*; Ok(match value { Address(value, span) => { @@ -95,7 +92,7 @@ impl FromAst for Constant { } } Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Address(value.clone()), } @@ -112,7 +109,7 @@ impl FromAst for Constant { } } Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Boolean( value @@ -133,7 +130,7 @@ impl FromAst for Constant { } } Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Field(value.parse().map_err(|_| AsgConvertError::invalid_int(&value, span))?), } @@ -150,7 +147,7 @@ impl FromAst for Constant { } } Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span().clone()), value: ConstValue::Group(match &**value { leo_ast::GroupValue::Single(value, _) => GroupValue::Single(value.clone()), @@ -164,23 +161,23 @@ impl FromAst for Constant { None => return Err(AsgConvertError::unresolved_type("unknown", span)), Some(PartialType::Integer(Some(sub_type), _)) | Some(PartialType::Integer(None, Some(sub_type))) => { Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Int(ConstInt::parse(&sub_type, value, span)?), } } Some(PartialType::Type(Type::Field)) => Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Field(value.parse().map_err(|_| AsgConvertError::invalid_int(&value, span))?), }, Some(PartialType::Type(Type::Group)) => Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Group(GroupValue::Single(value.to_string())), }, Some(PartialType::Type(Type::Address)) => Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Address(value.to_string()), }, @@ -200,7 +197,7 @@ impl FromAst for Constant { } } Constant { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(span.clone()), value: ConstValue::Int(ConstInt::parse(int_type, value, span)?), } @@ -209,7 +206,7 @@ impl FromAst for Constant { } } -impl Into for &Constant { +impl<'a> Into for &Constant<'a> { fn into(self) -> leo_ast::ValueExpression { match &self.value { ConstValue::Address(value) => { diff --git a/asg/src/expression/mod.rs b/asg/src/expression/mod.rs index a60747022e..720bac0781 100644 --- a/asg/src/expression/mod.rs +++ b/asg/src/expression/mod.rs @@ -64,31 +64,29 @@ pub use variable_ref::*; use crate::{AsgConvertError, ConstValue, FromAst, Node, PartialType, Scope, Span, Type}; -use std::sync::{Arc, Weak}; +#[derive(Clone)] +pub enum Expression<'a> { + VariableRef(VariableRef<'a>), + Constant(Constant<'a>), + Binary(BinaryExpression<'a>), + Unary(UnaryExpression<'a>), + Ternary(TernaryExpression<'a>), -#[derive(Debug)] -pub enum Expression { - VariableRef(VariableRef), - Constant(Constant), - Binary(BinaryExpression), - Unary(UnaryExpression), - Ternary(TernaryExpression), + ArrayInline(ArrayInlineExpression<'a>), + ArrayInit(ArrayInitExpression<'a>), + ArrayAccess(ArrayAccessExpression<'a>), + ArrayRangeAccess(ArrayRangeAccessExpression<'a>), - ArrayInline(ArrayInlineExpression), - ArrayInit(ArrayInitExpression), - ArrayAccess(ArrayAccessExpression), - ArrayRangeAccess(ArrayRangeAccessExpression), + TupleInit(TupleInitExpression<'a>), + TupleAccess(TupleAccessExpression<'a>), - TupleInit(TupleInitExpression), - TupleAccess(TupleAccessExpression), + CircuitInit(CircuitInitExpression<'a>), + CircuitAccess(CircuitAccessExpression<'a>), - CircuitInit(CircuitInitExpression), - CircuitAccess(CircuitAccessExpression), - - Call(CallExpression), + Call(CallExpression<'a>), } -impl Node for Expression { +impl<'a> Node for Expression<'a> { fn span(&self) -> Option<&Span> { use Expression::*; match self { @@ -110,19 +108,19 @@ impl Node for Expression { } } -pub trait ExpressionNode: Node { - fn set_parent(&self, parent: Weak); - fn get_parent(&self) -> Option>; - fn enforce_parents(&self, expr: &Arc); +pub trait ExpressionNode<'a>: Node { + fn set_parent(&self, parent: &'a Expression<'a>); + fn get_parent(&self) -> Option<&'a Expression<'a>>; + fn enforce_parents(&self, expr: &'a Expression<'a>); - fn get_type(&self) -> Option; + fn get_type(&self) -> Option>; fn is_mut_ref(&self) -> bool; fn const_value(&self) -> Option; // todo: memoize fn is_consty(&self) -> bool; } -impl ExpressionNode for Expression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for Expression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { use Expression::*; match self { VariableRef(x) => x.set_parent(parent), @@ -142,7 +140,7 @@ impl ExpressionNode for Expression { } } - fn get_parent(&self) -> Option> { + fn get_parent(&self) -> Option<&'a Expression<'a>> { use Expression::*; match self { VariableRef(x) => x.get_parent(), @@ -162,7 +160,7 @@ impl ExpressionNode for Expression { } } - fn enforce_parents(&self, expr: &Arc) { + fn enforce_parents(&self, expr: &'a Expression<'a>) { use Expression::*; match self { VariableRef(x) => x.enforce_parents(expr), @@ -182,7 +180,7 @@ impl ExpressionNode for Expression { } } - fn get_type(&self) -> Option { + fn get_type(&self) -> Option> { use Expression::*; match self { VariableRef(x) => x.get_type(), @@ -263,65 +261,70 @@ impl ExpressionNode for Expression { } } -impl FromAst for Arc { +impl<'a> FromAst<'a, leo_ast::Expression> for &'a Expression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::Expression, - expected_type: Option, + expected_type: Option>, ) -> Result { use leo_ast::Expression::*; let expression = match value { Identifier(identifier) => Self::from_ast(scope, identifier, expected_type)?, - Value(value) => Arc::new(Constant::from_ast(scope, value, expected_type).map(Expression::Constant)?), - Binary(binary) => { - Arc::new(BinaryExpression::from_ast(scope, binary, expected_type).map(Expression::Binary)?) + Value(value) => { + scope.alloc_expression(Constant::from_ast(scope, value, expected_type).map(Expression::Constant)?) } - Unary(unary) => Arc::new(UnaryExpression::from_ast(scope, unary, expected_type).map(Expression::Unary)?), - Ternary(conditional) => { - Arc::new(TernaryExpression::from_ast(scope, conditional, expected_type).map(Expression::Ternary)?) + Binary(binary) => scope + .alloc_expression(BinaryExpression::from_ast(scope, binary, expected_type).map(Expression::Binary)?), + Unary(unary) => { + scope.alloc_expression(UnaryExpression::from_ast(scope, unary, expected_type).map(Expression::Unary)?) } + Ternary(conditional) => scope.alloc_expression( + TernaryExpression::from_ast(scope, conditional, expected_type).map(Expression::Ternary)?, + ), - ArrayInline(array_inline) => Arc::new( + ArrayInline(array_inline) => scope.alloc_expression( ArrayInlineExpression::from_ast(scope, array_inline, expected_type).map(Expression::ArrayInline)?, ), - ArrayInit(array_init) => { - Arc::new(ArrayInitExpression::from_ast(scope, array_init, expected_type).map(Expression::ArrayInit)?) - } - ArrayAccess(array_access) => Arc::new( + ArrayInit(array_init) => scope.alloc_expression( + ArrayInitExpression::from_ast(scope, array_init, expected_type).map(Expression::ArrayInit)?, + ), + ArrayAccess(array_access) => scope.alloc_expression( ArrayAccessExpression::from_ast(scope, array_access, expected_type).map(Expression::ArrayAccess)?, ), - ArrayRangeAccess(array_range_access) => Arc::new( + ArrayRangeAccess(array_range_access) => scope.alloc_expression( ArrayRangeAccessExpression::from_ast(scope, array_range_access, expected_type) .map(Expression::ArrayRangeAccess)?, ), - TupleInit(tuple_init) => { - Arc::new(TupleInitExpression::from_ast(scope, tuple_init, expected_type).map(Expression::TupleInit)?) - } - TupleAccess(tuple_access) => Arc::new( + TupleInit(tuple_init) => scope.alloc_expression( + TupleInitExpression::from_ast(scope, tuple_init, expected_type).map(Expression::TupleInit)?, + ), + TupleAccess(tuple_access) => scope.alloc_expression( TupleAccessExpression::from_ast(scope, tuple_access, expected_type).map(Expression::TupleAccess)?, ), - CircuitInit(circuit_init) => Arc::new( + CircuitInit(circuit_init) => scope.alloc_expression( CircuitInitExpression::from_ast(scope, circuit_init, expected_type).map(Expression::CircuitInit)?, ), - CircuitMemberAccess(circuit_member) => Arc::new( + CircuitMemberAccess(circuit_member) => scope.alloc_expression( CircuitAccessExpression::from_ast(scope, circuit_member, expected_type) .map(Expression::CircuitAccess)?, ), - CircuitStaticFunctionAccess(circuit_member) => Arc::new( + CircuitStaticFunctionAccess(circuit_member) => scope.alloc_expression( CircuitAccessExpression::from_ast(scope, circuit_member, expected_type) .map(Expression::CircuitAccess)?, ), - Call(call) => Arc::new(CallExpression::from_ast(scope, call, expected_type).map(Expression::Call)?), + Call(call) => { + scope.alloc_expression(CallExpression::from_ast(scope, call, expected_type).map(Expression::Call)?) + } }; expression.enforce_parents(&expression); Ok(expression) } } -impl Into for &Expression { +impl<'a> Into for &Expression<'a> { fn into(self) -> leo_ast::Expression { use Expression::*; match self { diff --git a/asg/src/expression/ternary.rs b/asg/src/expression/ternary.rs index 7f000eb581..93d1a7df0f 100644 --- a/asg/src/expression/ternary.rs +++ b/asg/src/expression/ternary.rs @@ -16,55 +16,52 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct TernaryExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct TernaryExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub condition: Arc, - pub if_true: Arc, - pub if_false: Arc, + pub condition: Cell<&'a Expression<'a>>, + pub if_true: Cell<&'a Expression<'a>>, + pub if_false: Cell<&'a Expression<'a>>, } -impl Node for TernaryExpression { +impl<'a> Node for TernaryExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for TernaryExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for TernaryExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.condition.set_parent(Arc::downgrade(expr)); - self.if_true.set_parent(Arc::downgrade(expr)); - self.if_false.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.condition.get().set_parent(expr); + self.if_true.get().set_parent(expr); + self.if_false.get().set_parent(expr); } - fn get_type(&self) -> Option { - self.if_true.get_type() + fn get_type(&self) -> Option> { + self.if_true.get().get_type() } fn is_mut_ref(&self) -> bool { - self.if_true.is_mut_ref() && self.if_false.is_mut_ref() + self.if_true.get().is_mut_ref() && self.if_false.get().is_mut_ref() } fn const_value(&self) -> Option { - if let Some(ConstValue::Boolean(switch)) = self.condition.const_value() { + if let Some(ConstValue::Boolean(switch)) = self.condition.get().const_value() { if switch { - self.if_true.const_value() + self.if_true.get().const_value() } else { - self.if_false.const_value() + self.if_false.get().const_value() } } else { None @@ -72,32 +69,40 @@ impl ExpressionNode for TernaryExpression { } fn is_consty(&self) -> bool { - self.condition.is_consty() && self.if_true.is_consty() && self.if_false.is_consty() + self.condition.get().is_consty() && self.if_true.get().is_consty() && self.if_false.get().is_consty() } } -impl FromAst for TernaryExpression { +impl<'a> FromAst<'a, leo_ast::TernaryExpression> for TernaryExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::TernaryExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { Ok(TernaryExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - condition: Arc::::from_ast(scope, &*value.condition, Some(Type::Boolean.partial()))?, - if_true: Arc::::from_ast(scope, &*value.if_true, expected_type.clone())?, - if_false: Arc::::from_ast(scope, &*value.if_false, expected_type)?, + condition: Cell::new(<&Expression<'a>>::from_ast( + scope, + &*value.condition, + Some(Type::Boolean.partial()), + )?), + if_true: Cell::new(<&Expression<'a>>::from_ast( + scope, + &*value.if_true, + expected_type.clone(), + )?), + if_false: Cell::new(<&Expression<'a>>::from_ast(scope, &*value.if_false, expected_type)?), }) } } -impl Into for &TernaryExpression { +impl<'a> Into for &TernaryExpression<'a> { fn into(self) -> leo_ast::TernaryExpression { leo_ast::TernaryExpression { - condition: Box::new(self.condition.as_ref().into()), - if_true: Box::new(self.if_true.as_ref().into()), - if_false: Box::new(self.if_false.as_ref().into()), + condition: Box::new(self.condition.get().into()), + if_true: Box::new(self.if_true.get().into()), + if_false: Box::new(self.if_false.get().into()), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/tuple_access.rs b/asg/src/expression/tuple_access.rs index 3a491eec18..4886ff4aa6 100644 --- a/asg/src/expression/tuple_access.rs +++ b/asg/src/expression/tuple_access.rs @@ -16,51 +16,48 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct TupleAccessExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct TupleAccessExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub tuple_ref: Arc, + pub tuple_ref: Cell<&'a Expression<'a>>, pub index: usize, } -impl Node for TupleAccessExpression { +impl<'a> Node for TupleAccessExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for TupleAccessExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for TupleAccessExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.tuple_ref.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.tuple_ref.get().set_parent(expr); } - fn get_type(&self) -> Option { - match self.tuple_ref.get_type()? { + fn get_type(&self) -> Option> { + match self.tuple_ref.get().get_type()? { Type::Tuple(subtypes) => subtypes.get(self.index).cloned(), _ => None, } } fn is_mut_ref(&self) -> bool { - self.tuple_ref.is_mut_ref() + self.tuple_ref.get().is_mut_ref() } fn const_value(&self) -> Option { - let tuple_const = self.tuple_ref.const_value()?; + let tuple_const = self.tuple_ref.get().const_value()?; match tuple_const { ConstValue::Tuple(sub_consts) => sub_consts.get(self.index).cloned(), _ => None, @@ -68,16 +65,16 @@ impl ExpressionNode for TupleAccessExpression { } fn is_consty(&self) -> bool { - self.tuple_ref.is_consty() + self.tuple_ref.get().is_consty() } } -impl FromAst for TupleAccessExpression { +impl<'a> FromAst<'a, leo_ast::TupleAccessExpression> for TupleAccessExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::TupleAccessExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let index = value .index .value @@ -87,7 +84,7 @@ impl FromAst for TupleAccessExpression { let mut expected_tuple = vec![None; index + 1]; expected_tuple[index] = expected_type; - let tuple = Arc::::from_ast(scope, &*value.tuple, Some(PartialType::Tuple(expected_tuple)))?; + let tuple = <&Expression<'a>>::from_ast(scope, &*value.tuple, Some(PartialType::Tuple(expected_tuple)))?; let tuple_type = tuple.get_type(); if let Some(Type::Tuple(_items)) = tuple_type { } else { @@ -99,18 +96,18 @@ impl FromAst for TupleAccessExpression { } Ok(TupleAccessExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - tuple_ref: tuple, + tuple_ref: Cell::new(tuple), index, }) } } -impl Into for &TupleAccessExpression { +impl<'a> Into for &TupleAccessExpression<'a> { fn into(self) -> leo_ast::TupleAccessExpression { leo_ast::TupleAccessExpression { - tuple: Box::new(self.tuple_ref.as_ref().into()), + tuple: Box::new(self.tuple_ref.get().into()), index: leo_ast::PositiveNumber { value: self.index.to_string(), }, diff --git a/asg/src/expression/tuple_init.rs b/asg/src/expression/tuple_init.rs index 2772238e2c..195b4fe4af 100644 --- a/asg/src/expression/tuple_init.rs +++ b/asg/src/expression/tuple_init.rs @@ -16,43 +16,40 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct TupleInitExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct TupleInitExpression<'a> { + pub parent: Cell>>, pub span: Option, - pub elements: Vec>, + pub elements: Vec>>, } -impl Node for TupleInitExpression { +impl<'a> Node for TupleInitExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for TupleInitExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for TupleInitExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { + fn enforce_parents(&self, expr: &'a Expression<'a>) { self.elements.iter().for_each(|element| { - element.set_parent(Arc::downgrade(expr)); + element.get().set_parent(expr); }) } - fn get_type(&self) -> Option { + fn get_type(&self) -> Option> { let mut output = vec![]; for element in self.elements.iter() { - output.push(element.get_type()?); + output.push(element.get().get_type()?); } Some(Type::Tuple(output)) } @@ -64,7 +61,7 @@ impl ExpressionNode for TupleInitExpression { fn const_value(&self) -> Option { let mut consts = vec![]; for element in self.elements.iter() { - if let Some(const_value) = element.const_value() { + if let Some(const_value) = element.get().const_value() { consts.push(const_value); } else { return None; @@ -74,16 +71,16 @@ impl ExpressionNode for TupleInitExpression { } fn is_consty(&self) -> bool { - self.elements.iter().all(|x| x.is_consty()) + self.elements.iter().all(|x| x.get().is_consty()) } } -impl FromAst for TupleInitExpression { +impl<'a> FromAst<'a, leo_ast::TupleInitExpression> for TupleInitExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::TupleInitExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let tuple_types = match expected_type { Some(PartialType::Tuple(sub_types)) => Some(sub_types), None => None, @@ -111,26 +108,27 @@ impl FromAst for TupleInitExpression { .iter() .enumerate() .map(|(i, e)| { - Arc::::from_ast( + <&Expression<'a>>::from_ast( scope, e, tuple_types.as_ref().map(|x| x.get(i)).flatten().cloned().flatten(), ) + .map(Cell::new) }) .collect::, AsgConvertError>>()?; Ok(TupleInitExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), elements, }) } } -impl Into for &TupleInitExpression { +impl<'a> Into for &TupleInitExpression<'a> { fn into(self) -> leo_ast::TupleInitExpression { leo_ast::TupleInitExpression { - elements: self.elements.iter().map(|e| e.as_ref().into()).collect(), + elements: self.elements.iter().map(|e| e.get().into()).collect(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/unary.rs b/asg/src/expression/unary.rs index 2d0dafc49e..0e5078e71c 100644 --- a/asg/src/expression/unary.rs +++ b/asg/src/expression/unary.rs @@ -17,40 +17,37 @@ use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type}; pub use leo_ast::UnaryOperation; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct UnaryExpression { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct UnaryExpression<'a> { + pub parent: Cell>>, pub span: Option, pub operation: UnaryOperation, - pub inner: Arc, + pub inner: Cell<&'a Expression<'a>>, } -impl Node for UnaryExpression { +impl<'a> Node for UnaryExpression<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for UnaryExpression { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for UnaryExpression<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, expr: &Arc) { - self.inner.set_parent(Arc::downgrade(expr)); + fn enforce_parents(&self, expr: &'a Expression<'a>) { + self.inner.get().set_parent(expr); } - fn get_type(&self) -> Option { - self.inner.get_type() + fn get_type(&self) -> Option> { + self.inner.get().get_type() } fn is_mut_ref(&self) -> bool { @@ -58,7 +55,7 @@ impl ExpressionNode for UnaryExpression { } fn const_value(&self) -> Option { - if let Some(inner) = self.inner.const_value() { + if let Some(inner) = self.inner.get().const_value() { match self.operation { UnaryOperation::Not => match inner { ConstValue::Boolean(value) => Some(ConstValue::Boolean(!value)), @@ -79,16 +76,16 @@ impl ExpressionNode for UnaryExpression { } fn is_consty(&self) -> bool { - self.inner.is_consty() + self.inner.get().is_consty() } } -impl FromAst for UnaryExpression { +impl<'a> FromAst<'a, leo_ast::UnaryExpression> for UnaryExpression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::UnaryExpression, - expected_type: Option, - ) -> Result { + expected_type: Option>, + ) -> Result, AsgConvertError> { let expected_type = match value.op { UnaryOperation::Not => match expected_type.map(|x| x.full()).flatten() { Some(Type::Boolean) | None => Some(Type::Boolean), @@ -115,19 +112,23 @@ impl FromAst for UnaryExpression { }, }; Ok(UnaryExpression { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), operation: value.op.clone(), - inner: Arc::::from_ast(scope, &*value.inner, expected_type.map(Into::into))?, + inner: Cell::new(<&Expression<'a>>::from_ast( + scope, + &*value.inner, + expected_type.map(Into::into), + )?), }) } } -impl Into for &UnaryExpression { +impl<'a> Into for &UnaryExpression<'a> { fn into(self) -> leo_ast::UnaryExpression { leo_ast::UnaryExpression { op: self.operation.clone(), - inner: Box::new(self.inner.as_ref().into()), + inner: Box::new(self.inner.get().into()), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/expression/variable_ref.rs b/asg/src/expression/variable_ref.rs index 9bfdf6da3f..d702caa295 100644 --- a/asg/src/expression/variable_ref.rs +++ b/asg/src/expression/variable_ref.rs @@ -31,37 +31,34 @@ use crate::{ Variable, }; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::Cell; -#[derive(Debug)] -pub struct VariableRef { - pub parent: RefCell>>, +#[derive(Clone)] +pub struct VariableRef<'a> { + pub parent: Cell>>, pub span: Option, - pub variable: Variable, + pub variable: &'a Variable<'a>, } -impl Node for VariableRef { +impl<'a> Node for VariableRef<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl ExpressionNode for VariableRef { - fn set_parent(&self, parent: Weak) { +impl<'a> ExpressionNode<'a> for VariableRef<'a> { + fn set_parent(&self, parent: &'a Expression<'a>) { self.parent.replace(Some(parent)); } - fn get_parent(&self) -> Option> { - self.parent.borrow().as_ref().map(Weak::upgrade).flatten() + fn get_parent(&self) -> Option<&'a Expression<'a>> { + self.parent.get() } - fn enforce_parents(&self, _expr: &Arc) {} + fn enforce_parents(&self, _expr: &'a Expression<'a>) {} - fn get_type(&self) -> Option { - Some(self.variable.borrow().type_.clone().strong()) + fn get_type(&self) -> Option> { + Some(self.variable.borrow().type_.clone()) } fn is_mut_ref(&self) -> bool { @@ -74,24 +71,19 @@ impl ExpressionNode for VariableRef { if variable.mutable || variable.assignments.len() != 1 { return None; } - let assignment = variable - .assignments - .get(0) - .unwrap() - .upgrade() - .expect("stale assignment for variable"); + let assignment = variable.assignments.get(0).unwrap(); match &*assignment { Statement::Definition(DefinitionStatement { variables, value, .. }) => { if variables.len() == 1 { let defined_variable = variables.get(0).unwrap().borrow(); assert_eq!(variable.id, defined_variable.id); - value.const_value() + value.get().const_value() } else { for defined_variable in variables.iter() { let defined_variable = defined_variable.borrow(); if defined_variable.id == variable.id { - return value.const_value(); + return value.get().const_value(); } } panic!("no corresponding tuple variable found during const destructuring (corrupt asg?)"); @@ -109,12 +101,7 @@ impl ExpressionNode for VariableRef { if variable.mutable || variable.assignments.len() != 1 { return false; } - let assignment = variable - .assignments - .get(0) - .unwrap() - .upgrade() - .expect("stale assignment for variable"); + let assignment = variable.assignments.get(0).unwrap(); match &*assignment { Statement::Definition(DefinitionStatement { variables, value, .. }) => { @@ -122,12 +109,12 @@ impl ExpressionNode for VariableRef { let defined_variable = variables.get(0).unwrap().borrow(); assert_eq!(variable.id, defined_variable.id); - value.is_consty() + value.get().is_consty() } else { for defined_variable in variables.iter() { let defined_variable = defined_variable.borrow(); if defined_variable.id == variable.id { - return value.is_consty(); + return value.get().is_consty(); } } panic!("no corresponding tuple variable found during const destructuring (corrupt asg?)"); @@ -139,21 +126,21 @@ impl ExpressionNode for VariableRef { } } -impl FromAst for Arc { +impl<'a> FromAst<'a, leo_ast::Identifier> for &'a Expression<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::Identifier, - expected_type: Option, - ) -> Result, AsgConvertError> { + expected_type: Option>, + ) -> Result<&'a Expression<'a>, AsgConvertError> { let variable = if value.name == "input" { - if let Some(function) = scope.borrow().resolve_current_function() { + if let Some(function) = scope.resolve_current_function() { if !function.has_input { return Err(AsgConvertError::unresolved_reference(&value.name, &value.span)); } } else { return Err(AsgConvertError::unresolved_reference(&value.name, &value.span)); } - if let Some(input) = scope.borrow().resolve_input() { + if let Some(input) = scope.resolve_input() { input.container } else { return Err(AsgConvertError::InternalError( @@ -161,12 +148,12 @@ impl FromAst for Arc { )); } } else { - match scope.borrow().resolve_variable(&value.name) { + match scope.resolve_variable(&value.name) { Some(v) => v, None => { if value.name.starts_with("aleo1") { - return Ok(Arc::new(Expression::Constant(Constant { - parent: RefCell::new(None), + return Ok(scope.alloc_expression(Expression::Constant(Constant { + parent: Cell::new(None), span: Some(value.span.clone()), value: ConstValue::Address(value.name.clone()), }))); @@ -177,11 +164,11 @@ impl FromAst for Arc { }; let variable_ref = VariableRef { - parent: RefCell::new(None), + parent: Cell::new(None), span: Some(value.span.clone()), - variable: variable.clone(), + variable, }; - let expression = Arc::new(Expression::VariableRef(variable_ref)); + let expression = scope.alloc_expression(Expression::VariableRef(variable_ref)); if let Some(expected_type) = expected_type { let type_ = expression @@ -197,13 +184,13 @@ impl FromAst for Arc { } let mut variable_ref = variable.borrow_mut(); - variable_ref.references.push(Arc::downgrade(&expression)); + variable_ref.references.push(expression); Ok(expression) } } -impl Into for &VariableRef { +impl<'a> Into for &VariableRef<'a> { fn into(self) -> leo_ast::Identifier { self.variable.borrow().name.clone() } diff --git a/asg/src/import.rs b/asg/src/import.rs index e28d18ee28..4dc3c5ae43 100644 --- a/asg/src/import.rs +++ b/asg/src/import.rs @@ -16,56 +16,74 @@ //! Helper methods for resolving imported packages. -use crate::{AsgConvertError, Program, Span}; +use std::marker::PhantomData; + +use crate::{AsgContext, AsgConvertError, Program, Span}; use indexmap::IndexMap; -pub trait ImportResolver { - fn resolve_package(&mut self, package_segments: &[&str], span: &Span) -> Result, AsgConvertError>; +pub trait ImportResolver<'a> { + fn resolve_package( + &mut self, + context: AsgContext<'a>, + package_segments: &[&str], + span: &Span, + ) -> Result>, AsgConvertError>; } pub struct NullImportResolver; -impl ImportResolver for NullImportResolver { +impl<'a> ImportResolver<'a> for NullImportResolver { fn resolve_package( &mut self, + _context: AsgContext<'a>, _package_segments: &[&str], _span: &Span, - ) -> Result, AsgConvertError> { + ) -> Result>, AsgConvertError> { Ok(None) } } -pub struct CoreImportResolver<'a, T: ImportResolver + 'static>(pub &'a mut T); +pub struct CoreImportResolver<'a, 'b, T: ImportResolver<'b>> { + inner: &'a mut T, + lifetime: PhantomData<&'b ()>, +} -impl<'a, T: ImportResolver + 'static> ImportResolver for CoreImportResolver<'a, T> { - fn resolve_package(&mut self, package_segments: &[&str], span: &Span) -> Result, AsgConvertError> { - if !package_segments.is_empty() && package_segments.get(0).unwrap() == &"core" { - Ok(crate::resolve_core_module(&*package_segments[1..].join("."))?) - } else { - self.0.resolve_package(package_segments, span) +impl<'a, 'b, T: ImportResolver<'b>> CoreImportResolver<'a, 'b, T> { + pub fn new(inner: &'a mut T) -> Self { + CoreImportResolver { + inner, + lifetime: PhantomData, } } } -pub struct StandardImportResolver; - -impl ImportResolver for StandardImportResolver { +impl<'a, 'b, T: ImportResolver<'b>> ImportResolver<'b> for CoreImportResolver<'a, 'b, T> { fn resolve_package( &mut self, - _package_segments: &[&str], - _span: &Span, - ) -> Result, AsgConvertError> { - Ok(None) + context: AsgContext<'b>, + package_segments: &[&str], + span: &Span, + ) -> Result>, AsgConvertError> { + if !package_segments.is_empty() && package_segments.get(0).unwrap() == &"core" { + Ok(crate::resolve_core_module(context, &*package_segments[1..].join("."))?) + } else { + self.inner.resolve_package(context, package_segments, span) + } } } -pub struct MockedImportResolver { - pub packages: IndexMap, +pub struct MockedImportResolver<'a> { + pub packages: IndexMap>, } -impl ImportResolver for MockedImportResolver { - fn resolve_package(&mut self, package_segments: &[&str], _span: &Span) -> Result, AsgConvertError> { +impl<'a> ImportResolver<'a> for MockedImportResolver<'a> { + fn resolve_package( + &mut self, + _context: AsgContext<'a>, + package_segments: &[&str], + _span: &Span, + ) -> Result>, AsgConvertError> { Ok(self.packages.get(&package_segments.join(".")).cloned()) } } diff --git a/asg/src/input.rs b/asg/src/input.rs index b814927f2a..11789a7425 100644 --- a/asg/src/input.rs +++ b/asg/src/input.rs @@ -14,23 +14,20 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{Circuit, CircuitBody, CircuitMember, CircuitMemberBody, Identifier, Scope, Type, Variable, WeakType}; +use crate::{Circuit, CircuitMember, Identifier, Scope, Type, Variable}; use indexmap::IndexMap; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::RefCell; -/// Stores program input values as asg nodes. -#[derive(Debug, Clone)] -pub struct Input { - pub registers: Arc, - pub state: Arc, - pub state_leaf: Arc, - pub record: Arc, - pub container_circuit: Arc, - pub container: Variable, +/// Stores program input values as ASG nodes. +#[derive(Clone, Copy)] +pub struct Input<'a> { + pub registers: &'a Circuit<'a>, + pub state: &'a Circuit<'a>, + pub state_leaf: &'a Circuit<'a>, + pub record: &'a Circuit<'a>, + pub container_circuit: &'a Circuit<'a>, + pub container: &'a Variable<'a>, } pub const CONTAINER_PSEUDO_CIRCUIT: &str = "$InputContainer"; @@ -39,95 +36,56 @@ pub const RECORD_PSEUDO_CIRCUIT: &str = "$InputRecord"; pub const STATE_PSEUDO_CIRCUIT: &str = "$InputState"; pub const STATE_LEAF_PSEUDO_CIRCUIT: &str = "$InputStateLeaf"; -impl Input { - fn make_header(name: &str) -> Arc { - Arc::new(Circuit { +impl<'a> Input<'a> { + fn make_header(scope: &'a Scope<'a>, name: &str) -> &'a Circuit<'a> { + scope.alloc_circuit(Circuit { id: uuid::Uuid::new_v4(), name: RefCell::new(Identifier::new(name.to_string())), - body: RefCell::new(Weak::new()), members: RefCell::new(IndexMap::new()), core_mapping: RefCell::new(None), + scope, + span: Default::default(), }) } - fn make_body(scope: &Scope, circuit: &Arc) -> Arc { - let body = Arc::new(CircuitBody { - scope: scope.clone(), - span: None, - circuit: circuit.clone(), - members: RefCell::new(IndexMap::new()), - }); - circuit.body.replace(Arc::downgrade(&body)); - body - } - - pub fn new(scope: &Scope) -> Self { - let registers = Self::make_header(REGISTERS_PSEUDO_CIRCUIT); - let record = Self::make_header(RECORD_PSEUDO_CIRCUIT); - let state = Self::make_header(STATE_PSEUDO_CIRCUIT); - let state_leaf = Self::make_header(STATE_LEAF_PSEUDO_CIRCUIT); + pub fn new(scope: &'a Scope<'a>) -> Self { + let input_scope = scope.make_subscope(); + let registers = Self::make_header(input_scope, REGISTERS_PSEUDO_CIRCUIT); + let record = Self::make_header(input_scope, RECORD_PSEUDO_CIRCUIT); + let state = Self::make_header(input_scope, STATE_PSEUDO_CIRCUIT); + let state_leaf = Self::make_header(input_scope, STATE_LEAF_PSEUDO_CIRCUIT); let mut container_members = IndexMap::new(); container_members.insert( "registers".to_string(), - CircuitMember::Variable(WeakType::Circuit(Arc::downgrade(®isters))), - ); - container_members.insert( - "record".to_string(), - CircuitMember::Variable(WeakType::Circuit(Arc::downgrade(&record))), - ); - container_members.insert( - "state".to_string(), - CircuitMember::Variable(WeakType::Circuit(Arc::downgrade(&state))), + CircuitMember::Variable(Type::Circuit(registers)), ); + container_members.insert("record".to_string(), CircuitMember::Variable(Type::Circuit(record))); + container_members.insert("state".to_string(), CircuitMember::Variable(Type::Circuit(state))); container_members.insert( "state_leaf".to_string(), - CircuitMember::Variable(WeakType::Circuit(Arc::downgrade(&state_leaf))), + CircuitMember::Variable(Type::Circuit(state_leaf)), ); - let container_circuit = Arc::new(Circuit { + let container_circuit = input_scope.alloc_circuit(Circuit { id: uuid::Uuid::new_v4(), name: RefCell::new(Identifier::new(CONTAINER_PSEUDO_CIRCUIT.to_string())), - body: RefCell::new(Weak::new()), members: RefCell::new(container_members), core_mapping: RefCell::new(None), + scope: input_scope, + span: Default::default(), }); - let registers_body = Self::make_body(scope, ®isters); - let record_body = Self::make_body(scope, &record); - let state_body = Self::make_body(scope, &state); - let state_leaf_body = Self::make_body(scope, &state_leaf); - - let mut container_body_members = IndexMap::new(); - container_body_members.insert( - "registers".to_string(), - CircuitMemberBody::Variable(Type::Circuit(registers)), - ); - container_body_members.insert("record".to_string(), CircuitMemberBody::Variable(Type::Circuit(record))); - container_body_members.insert("state".to_string(), CircuitMemberBody::Variable(Type::Circuit(state))); - container_body_members.insert( - "state_leaf".to_string(), - CircuitMemberBody::Variable(Type::Circuit(state_leaf)), - ); - - let container_circuit_body = Arc::new(CircuitBody { - scope: scope.clone(), - span: None, - circuit: container_circuit.clone(), - members: RefCell::new(container_body_members), - }); - container_circuit.body.replace(Arc::downgrade(&container_circuit_body)); - Input { - registers: registers_body, - record: record_body, - state: state_body, - state_leaf: state_leaf_body, - container_circuit: container_circuit_body, - container: Arc::new(RefCell::new(crate::InnerVariable { + registers, + record, + state, + state_leaf, + container_circuit, + container: input_scope.alloc_variable(RefCell::new(crate::InnerVariable { id: uuid::Uuid::new_v4(), name: Identifier::new("input".to_string()), - type_: Type::Circuit(container_circuit).weak(), + type_: Type::Circuit(container_circuit), mutable: false, const_: false, declaration: crate::VariableDeclaration::Input, @@ -138,7 +96,7 @@ impl Input { } } -impl Circuit { +impl<'a> Circuit<'a> { pub fn is_input_pseudo_circuit(&self) -> bool { matches!( &*self.name.borrow().name, diff --git a/asg/src/lib.rs b/asg/src/lib.rs index f93d3b4791..1fe1d8f0da 100644 --- a/asg/src/lib.rs +++ b/asg/src/lib.rs @@ -65,11 +65,17 @@ pub mod type_; pub use type_::*; pub mod variable; +use typed_arena::Arena; pub use variable::*; +pub mod pass; +pub use pass::*; + pub use leo_ast::{Ast, Identifier, Span}; -use std::{cell::RefCell, path::Path, sync::Arc}; +pub type AsgContext<'a> = &'a Arena>; + +use std::path::Path; /// The abstract semantic graph (ASG) for a Leo program. /// @@ -77,21 +83,27 @@ use std::{cell::RefCell, path::Path, sync::Arc}; /// These data types form a graph that begins from a [`Program`] type node. /// /// A new [`Asg`] can be created from an [`Ast`] generated in the `ast` module. -#[derive(Debug, Clone)] -pub struct Asg { - asg: Arc>, +#[derive(Clone)] +pub struct Asg<'a> { + context: AsgContext<'a>, + asg: Program<'a>, } -impl Asg { +impl<'a> Asg<'a> { /// Creates a new ASG from a given AST and import resolver. - pub fn new(ast: &Ast, resolver: &mut T) -> Result { + pub fn new>( + context: AsgContext<'a>, + ast: &Ast, + resolver: &mut T, + ) -> Result { Ok(Self { - asg: InternalProgram::new(&ast.as_repr(), resolver)?, + context, + asg: InternalProgram::new(context, &ast.as_repr(), resolver)?, }) } /// Returns the internal program ASG representation. - pub fn as_repr(&self) -> Arc> { + pub fn as_repr(&self) -> Program<'a> { self.asg.clone() } @@ -108,10 +120,18 @@ impl Asg { } // TODO (howardwu): Remove this. -pub fn load_asg(content: &str, resolver: &mut T) -> Result { +pub fn load_asg<'a, T: ImportResolver<'a>>( + context: AsgContext<'a>, + content: &str, + resolver: &mut T, +) -> Result, AsgConvertError> { // Parses the Leo file and constructs a grammar ast. let ast = leo_grammar::Grammar::new(&Path::new("input.leo"), content) .map_err(|e| AsgConvertError::InternalError(format!("ast: {:?}", e)))?; - InternalProgram::new(leo_ast::Ast::new("load_ast", &ast)?.as_repr(), resolver) + InternalProgram::new(context, leo_ast::Ast::new("load_ast", &ast)?.as_repr(), resolver) +} + +pub fn new_context<'a>() -> Arena> { + Arena::new() } diff --git a/asg/src/node.rs b/asg/src/node.rs index eece9c6fa8..c27d16547a 100644 --- a/asg/src/node.rs +++ b/asg/src/node.rs @@ -14,15 +14,28 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{AsgConvertError, PartialType, Scope, Span}; +use crate::{AsgConvertError, Circuit, Expression, Function, PartialType, Scope, Span, Statement, Variable}; /// A node in the abstract semantic graph. pub trait Node { fn span(&self) -> Option<&Span>; } -pub(super) trait FromAst: Sized + 'static { +pub(super) trait FromAst<'a, T: leo_ast::Node + 'static>: Sized { // expected_type contract: if present, output expression must be of type expected_type. // type of an element may NEVER be None unless it is functionally a non-expression. (static call targets, function ref call targets are not expressions) - fn from_ast(scope: &Scope, value: &T, expected_type: Option) -> Result; + fn from_ast( + scope: &'a Scope<'a>, + value: &T, + expected_type: Option>, + ) -> Result; +} + +pub enum ArenaNode<'a> { + Expression(Expression<'a>), + Scope(Scope<'a>), + Statement(Statement<'a>), + Variable(Variable<'a>), + Circuit(Circuit<'a>), + Function(Function<'a>), } diff --git a/asg/src/pass.rs b/asg/src/pass.rs new file mode 100644 index 0000000000..facdf10e5f --- /dev/null +++ b/asg/src/pass.rs @@ -0,0 +1,22 @@ +// Copyright (C) 2019-2021 Aleo Systems Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use crate::Program; +pub use leo_ast::Error as FormattedError; + +pub trait AsgPass { + fn do_pass(asg: &Program) -> Result<(), FormattedError>; +} diff --git a/asg/src/prelude.rs b/asg/src/prelude.rs index a567b2ff72..e977fd137f 100644 --- a/asg/src/prelude.rs +++ b/asg/src/prelude.rs @@ -16,15 +16,16 @@ // TODO (protryon): We should merge this with core -use crate::{AsgConvertError, Program}; +use crate::{AsgContext, AsgConvertError, Program}; // TODO (protryon): Make asg deep copy so we can cache resolved core modules // TODO (protryon): Figure out how to do headers without bogus returns -pub fn resolve_core_module(module: &str) -> Result, AsgConvertError> { +pub fn resolve_core_module<'a>(context: AsgContext<'a>, module: &str) -> Result>, AsgConvertError> { match module { "unstable.blake2s" => { let asg = crate::load_asg( + context, r#" circuit Blake2s { function hash(seed: [u8; 32], message: [u8; 32]) -> [u8; 32] { @@ -34,7 +35,7 @@ pub fn resolve_core_module(module: &str) -> Result, AsgConvertEr "#, &mut crate::NullImportResolver, )?; - asg.borrow().set_core_mapping("blake2s"); + asg.set_core_mapping("blake2s"); Ok(Some(asg)) } _ => Ok(None), diff --git a/asg/src/program/circuit.rs b/asg/src/program/circuit.rs index 120ab48322..ace54ca2c2 100644 --- a/asg/src/program/circuit.rs +++ b/asg/src/program/circuit.rs @@ -14,37 +14,29 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{AsgConvertError, Function, FunctionBody, Identifier, InnerScope, Node, Scope, Span, Type, WeakType}; +use crate::{AsgConvertError, Function, Identifier, Node, Scope, Span, Type}; use indexmap::IndexMap; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::RefCell; use uuid::Uuid; -#[derive(Debug)] -pub enum CircuitMemberBody { - Variable(Type), - Function(Arc), +#[derive(Clone)] +pub enum CircuitMember<'a> { + Variable(Type<'a>), + Function(&'a Function<'a>), } -#[derive(Debug)] -pub enum CircuitMember { - Variable(WeakType), - Function(Arc), -} - -#[derive(Debug)] -pub struct Circuit { +#[derive(Clone)] +pub struct Circuit<'a> { pub id: Uuid, pub name: RefCell, pub core_mapping: RefCell>, - pub body: RefCell>, - pub members: RefCell>, + pub scope: &'a Scope<'a>, + pub span: Option, + pub members: RefCell>>, } -impl PartialEq for Circuit { +impl<'a> PartialEq for Circuit<'a> { fn eq(&self, other: &Circuit) -> bool { if self.name != other.name { return false; @@ -52,81 +44,30 @@ impl PartialEq for Circuit { self.id == other.id } } -impl Eq for Circuit {} -#[derive(Debug)] -pub struct CircuitBody { - pub scope: Scope, - pub span: Option, - pub circuit: Arc, - pub members: RefCell>, -} +impl<'a> Eq for Circuit<'a> {} -impl PartialEq for CircuitBody { - fn eq(&self, other: &CircuitBody) -> bool { - self.circuit == other.circuit - } -} -impl Eq for CircuitBody {} - -impl Node for CircuitMemberBody { +impl<'a> Node for Circuit<'a> { fn span(&self) -> Option<&Span> { - None + self.span.as_ref() } } -impl Circuit { - pub(super) fn init(value: &leo_ast::Circuit) -> Arc { - Arc::new(Circuit { +impl<'a> Circuit<'a> { + pub(super) fn init(scope: &'a Scope<'a>, value: &leo_ast::Circuit) -> Result<&'a Circuit<'a>, AsgConvertError> { + let new_scope = scope.make_subscope(); + + let circuit = scope.alloc_circuit(Circuit { id: Uuid::new_v4(), name: RefCell::new(value.circuit_name.clone()), - body: RefCell::new(Weak::new()), members: RefCell::new(IndexMap::new()), core_mapping: RefCell::new(None), - }) - } - - pub(super) fn from_ast(self: Arc, scope: &Scope, value: &leo_ast::Circuit) -> Result<(), AsgConvertError> { - let new_scope = InnerScope::make_subscope(scope); // temporary scope for function headers - new_scope.borrow_mut().circuit_self = Some(self.clone()); - - let mut members = self.members.borrow_mut(); - for member in value.members.iter() { - match member { - leo_ast::CircuitMember::CircuitVariable(name, type_) => { - members.insert( - name.name.clone(), - CircuitMember::Variable(new_scope.borrow().resolve_ast_type(type_)?.into()), - ); - } - leo_ast::CircuitMember::CircuitFunction(function) => { - let asg_function = Arc::new(Function::from_ast(&new_scope, function)?); - - members.insert(function.identifier.name.clone(), CircuitMember::Function(asg_function)); - } - } - } - - for (_, member) in members.iter() { - if let CircuitMember::Function(func) = member { - func.circuit.borrow_mut().replace(Arc::downgrade(&self)); - } - } - - Ok(()) - } -} - -impl CircuitBody { - pub(super) fn from_ast( - scope: &Scope, - value: &leo_ast::Circuit, - circuit: Arc, - ) -> Result { - let mut members = IndexMap::new(); - let new_scope = InnerScope::make_subscope(scope); - new_scope.borrow_mut().circuit_self = Some(circuit.clone()); + span: Some(value.circuit_name.span.clone()), + scope: new_scope, + }); + new_scope.circuit_self.replace(Some(circuit)); + let mut members = circuit.members.borrow_mut(); for member in value.members.iter() { match member { leo_ast::CircuitMember::CircuitVariable(name, type_) => { @@ -139,7 +80,7 @@ impl CircuitBody { } members.insert( name.name.clone(), - CircuitMemberBody::Variable(new_scope.borrow().resolve_ast_type(type_)?), + CircuitMember::Variable(new_scope.resolve_ast_type(type_)?), ); } leo_ast::CircuitMember::CircuitFunction(function) => { @@ -150,51 +91,51 @@ impl CircuitBody { &function.identifier.span, )); } - let asg_function = { - let circuit_members = circuit.members.borrow(); - match circuit_members.get(&function.identifier.name).unwrap() { - CircuitMember::Function(f) => f.clone(), - _ => unimplemented!(), - } - }; - let function_body = Arc::new(FunctionBody::from_ast(&new_scope, function, asg_function.clone())?); - asg_function.body.replace(Arc::downgrade(&function_body)); - - members.insert( - function.identifier.name.clone(), - CircuitMemberBody::Function(function_body), - ); + let asg_function = Function::init(new_scope, function)?; + asg_function.circuit.replace(Some(circuit)); + members.insert(function.identifier.name.clone(), CircuitMember::Function(asg_function)); } } } - Ok(CircuitBody { - span: Some(value.circuit_name.span.clone()), - circuit, - members: RefCell::new(members), - scope: scope.clone(), - }) + Ok(circuit) + } + + pub(super) fn fill_from_ast(self: &'a Circuit<'a>, value: &leo_ast::Circuit) -> Result<(), AsgConvertError> { + for member in value.members.iter() { + match member { + leo_ast::CircuitMember::CircuitVariable(..) => {} + leo_ast::CircuitMember::CircuitFunction(function) => { + let asg_function = match *self + .members + .borrow() + .get(&function.identifier.name) + .expect("missing header for defined circuit function") + { + CircuitMember::Function(f) => f, + _ => unimplemented!(), + }; + Function::fill_from_ast(asg_function, function)?; + } + } + } + Ok(()) } } -impl Into for &Circuit { +impl<'a> Into for &Circuit<'a> { fn into(self) -> leo_ast::Circuit { - let members = match self.body.borrow().upgrade() { - Some(body) => body - .members - .borrow() - .iter() - .map(|(name, member)| match &member { - CircuitMemberBody::Variable(type_) => { - leo_ast::CircuitMember::CircuitVariable(Identifier::new(name.clone()), type_.into()) - } - CircuitMemberBody::Function(func) => { - leo_ast::CircuitMember::CircuitFunction(func.function.as_ref().into()) - } - }) - .collect(), - None => vec![], - }; + let members = self + .members + .borrow() + .iter() + .map(|(name, member)| match &member { + CircuitMember::Variable(type_) => { + leo_ast::CircuitMember::CircuitVariable(Identifier::new(name.clone()), type_.into()) + } + CircuitMember::Function(func) => leo_ast::CircuitMember::CircuitFunction((*func).into()), + }) + .collect(); leo_ast::Circuit { circuit_name: self.name.borrow().clone(), members, diff --git a/asg/src/program/function.rs b/asg/src/program/function.rs index daff1741f6..0032b24070 100644 --- a/asg/src/program/function.rs +++ b/asg/src/program/function.rs @@ -20,7 +20,6 @@ use crate::{ Circuit, FromAst, Identifier, - InnerScope, MonoidalDirector, ReturnPathReducer, Scope, @@ -28,72 +27,58 @@ use crate::{ Statement, Type, Variable, - WeakType, }; +use indexmap::IndexMap; use leo_ast::FunctionInput; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::{Cell, RefCell}; use uuid::Uuid; -#[derive(Debug, PartialEq)] +#[derive(Clone, Copy, PartialEq)] pub enum FunctionQualifier { SelfRef, MutSelfRef, Static, } -#[derive(Debug)] -pub struct Function { +#[derive(Clone)] +pub struct Function<'a> { pub id: Uuid, pub name: RefCell, - pub output: WeakType, + pub output: Type<'a>, pub has_input: bool, - pub arguments: Vec, - pub circuit: RefCell>>, - pub body: RefCell>, + pub arguments: IndexMap>>, + pub circuit: Cell>>, + pub span: Option, + pub body: Cell>>, + pub scope: &'a Scope<'a>, pub qualifier: FunctionQualifier, } -impl PartialEq for Function { - fn eq(&self, other: &Function) -> bool { +impl<'a> PartialEq for Function<'a> { + fn eq(&self, other: &Function<'a>) -> bool { if self.name.borrow().name != other.name.borrow().name { return false; } self.id == other.id } } -impl Eq for Function {} -#[derive(Debug)] -pub struct FunctionBody { - pub span: Option, - pub function: Arc, - pub body: Arc, - pub scope: Scope, -} +impl<'a> Eq for Function<'a> {} -impl PartialEq for FunctionBody { - fn eq(&self, other: &FunctionBody) -> bool { - self.function == other.function - } -} -impl Eq for FunctionBody {} - -impl Function { - pub(crate) fn from_ast(scope: &Scope, value: &leo_ast::Function) -> Result { - let output: Type = value +impl<'a> Function<'a> { + pub(crate) fn init(scope: &'a Scope<'a>, value: &leo_ast::Function) -> Result<&'a Function<'a>, AsgConvertError> { + let output: Type<'a> = value .output .as_ref() - .map(|t| scope.borrow().resolve_ast_type(t)) + .map(|t| scope.resolve_ast_type(t)) .transpose()? .unwrap_or_else(|| Type::Tuple(vec![])); let mut qualifier = FunctionQualifier::Static; let mut has_input = false; + let new_scope = scope.make_subscope(); - let mut arguments = vec![]; + let mut arguments = IndexMap::new(); { for input in value.input.iter() { match input { @@ -107,77 +92,74 @@ impl Function { qualifier = FunctionQualifier::MutSelfRef; } FunctionInput::Variable(leo_ast::FunctionInputVariable { - identifier, - mutable, - const_, type_, - span: _span, + identifier, + const_, + mutable, + .. }) => { - let variable = Arc::new(RefCell::new(crate::InnerVariable { + let variable = scope.alloc_variable(RefCell::new(crate::InnerVariable { id: Uuid::new_v4(), name: identifier.clone(), - type_: scope.borrow().resolve_ast_type(&type_)?.weak(), + type_: scope.resolve_ast_type(&type_)?, mutable: *mutable, const_: *const_, declaration: crate::VariableDeclaration::Parameter, references: vec![], assignments: vec![], })); - arguments.push(variable.clone()); + arguments.insert(identifier.name.clone(), Cell::new(&*variable)); } } } } - if qualifier != FunctionQualifier::Static && scope.borrow().circuit_self.is_none() { + if qualifier != FunctionQualifier::Static && scope.circuit_self.get().is_none() { return Err(AsgConvertError::invalid_self_in_global(&value.span)); } - Ok(Function { + let function = scope.alloc_function(Function { id: Uuid::new_v4(), name: RefCell::new(value.identifier.clone()), - output: output.into(), + output, has_input, arguments, - circuit: RefCell::new(None), - body: RefCell::new(Weak::new()), + circuit: Cell::new(None), + body: Cell::new(None), qualifier, - }) - } -} + scope: new_scope, + span: Some(value.span.clone()), + }); + function.scope.function.replace(Some(function)); -impl FunctionBody { - pub(super) fn from_ast( - scope: &Scope, - value: &leo_ast::Function, - function: Arc, - ) -> Result { - let new_scope = InnerScope::make_subscope(scope); - { - let mut scope_borrow = new_scope.borrow_mut(); - if function.qualifier != FunctionQualifier::Static { - let circuit = function.circuit.borrow(); - let self_variable = Arc::new(RefCell::new(crate::InnerVariable { - id: Uuid::new_v4(), - name: Identifier::new("self".to_string()), - type_: WeakType::Circuit(circuit.as_ref().unwrap().clone()), - mutable: function.qualifier == FunctionQualifier::MutSelfRef, - const_: false, - declaration: crate::VariableDeclaration::Parameter, - references: vec![], - assignments: vec![], - })); - scope_borrow.variables.insert("self".to_string(), self_variable); - } - scope_borrow.function = Some(function.clone()); - for argument in function.arguments.iter() { - let name = argument.borrow().name.name.clone(); - scope_borrow.variables.insert(name, argument.clone()); - } + Ok(function) + } + + pub(super) fn fill_from_ast(self: &'a Function<'a>, value: &leo_ast::Function) -> Result<(), AsgConvertError> { + if self.qualifier != FunctionQualifier::Static { + let circuit = self.circuit.get(); + let self_variable = self.scope.alloc_variable(RefCell::new(crate::InnerVariable { + id: Uuid::new_v4(), + name: Identifier::new("self".to_string()), + type_: Type::Circuit(circuit.as_ref().unwrap()), + mutable: self.qualifier == FunctionQualifier::MutSelfRef, + const_: false, + declaration: crate::VariableDeclaration::Parameter, + references: vec![], + assignments: vec![], + })); + self.scope + .variables + .borrow_mut() + .insert("self".to_string(), self_variable); } - let main_block = BlockStatement::from_ast(&new_scope, &value.block, None)?; + for (name, argument) in self.arguments.iter() { + self.scope.variables.borrow_mut().insert(name.clone(), argument.get()); + } + + let main_block = BlockStatement::from_ast(self.scope, &value.block, None)?; let mut director = MonoidalDirector::new(ReturnPathReducer::new()); - if !director.reduce_block(&main_block).0 && !function.output.is_unit() { + if !director.reduce_block(&main_block).0 && !self.output.is_unit() { return Err(AsgConvertError::function_missing_return( - &function.name.borrow().name, + &self.name.borrow().name, &value.span, )); } @@ -185,47 +167,39 @@ impl FunctionBody { #[allow(clippy::never_loop)] // TODO @Protryon: How should we return multiple errors? for (span, error) in director.reducer().errors { return Err(AsgConvertError::function_return_validation( - &function.name.borrow().name, + &self.name.borrow().name, &error, &span, )); } - Ok(FunctionBody { - span: Some(value.span.clone()), - function, - body: Arc::new(Statement::Block(main_block)), - scope: new_scope, - }) + self.body + .replace(Some(self.scope.alloc_statement(Statement::Block(main_block)))); + + Ok(()) } } -impl Into for &Function { +impl<'a> Into for &Function<'a> { fn into(self) -> leo_ast::Function { - let (input, body, span) = match self.body.borrow().upgrade() { - Some(body) => ( - body.function - .arguments - .iter() - .map(|variable| { - let variable = variable.borrow(); - leo_ast::FunctionInput::Variable(leo_ast::FunctionInputVariable { - identifier: variable.name.clone(), - mutable: variable.mutable, - const_: variable.const_, - type_: (&variable.type_.clone().strong()).into(), - span: Span::default(), - }) - }) - .collect(), - match body.body.as_ref() { - Statement::Block(block) => block.into(), - _ => unimplemented!(), - }, - body.span.clone().unwrap_or_default(), - ), + let input = self + .arguments + .iter() + .map(|(_, variable)| { + let variable = variable.get().borrow(); + leo_ast::FunctionInput::Variable(leo_ast::FunctionInputVariable { + identifier: variable.name.clone(), + mutable: variable.mutable, + const_: variable.const_, + type_: (&variable.type_).into(), + span: Span::default(), + }) + }) + .collect(); + let (body, span) = match self.body.get() { + Some(Statement::Block(block)) => (block.into(), block.span.clone().unwrap_or_default()), + Some(_) => unimplemented!(), None => ( - vec![], leo_ast::Block { statements: vec![], span: Default::default(), @@ -233,7 +207,7 @@ impl Into for &Function { Default::default(), ), }; - let output: Type = self.output.clone().into(); + let output: Type = self.output.clone(); leo_ast::Function { identifier: self.name.borrow().clone(), input, diff --git a/asg/src/program/mod.rs b/asg/src/program/mod.rs index 4de5e81f54..f7ae6005d3 100644 --- a/asg/src/program/mod.rs +++ b/asg/src/program/mod.rs @@ -24,16 +24,18 @@ pub use circuit::*; mod function; pub use function::*; -use crate::{AsgConvertError, ImportResolver, InnerScope, Input, Scope}; +use crate::{ArenaNode, AsgContext, AsgConvertError, ImportResolver, Input, Scope}; use leo_ast::{Identifier, PackageAccess, PackageOrPackages, Span}; use indexmap::IndexMap; -use std::{cell::RefCell, sync::Arc}; +use std::cell::{Cell, RefCell}; use uuid::Uuid; /// Stores the Leo program abstract semantic graph (ASG). -#[derive(Debug, Clone)] -pub struct InternalProgram { +#[derive(Clone)] +pub struct InternalProgram<'a> { + pub context: AsgContext<'a>, + /// The unique id of the program. pub id: Uuid, @@ -42,24 +44,25 @@ pub struct InternalProgram { /// The packages imported by this program. /// these should generally not be accessed directly, but through scoped imports - pub imported_modules: IndexMap, + pub imported_modules: IndexMap>, /// Maps test name => test code block. - pub test_functions: IndexMap, Option)>, // identifier = test input file + pub test_functions: IndexMap, Option)>, // identifier = test input file /// Maps function name => function code block. - pub functions: IndexMap>, + pub functions: IndexMap>, /// Maps circuit name => circuit code block. - pub circuits: IndexMap>, + pub circuits: IndexMap>, /// Bindings for names and additional program context. - pub scope: Scope, + pub scope: &'a Scope<'a>, } -pub type Program = Arc>; +pub type Program<'a> = InternalProgram<'a>; /// Enumerates what names are imported from a package. +#[derive(Clone)] enum ImportSymbol { /// Import the symbol by name. Direct(String), @@ -124,7 +127,7 @@ fn resolve_import_package_access( } } -impl InternalProgram { +impl<'a> InternalProgram<'a> { /// Returns a new Leo program ASG from the given Leo program AST and its imports. /// /// Stages: @@ -133,10 +136,11 @@ impl InternalProgram { /// 3. finalize declared functions /// 4. resolve all asg nodes /// - pub fn new( + pub fn new>( + arena: AsgContext<'a>, program: &leo_ast::Program, import_resolver: &mut T, - ) -> Result { + ) -> Result, AsgConvertError> { // Recursively extract imported symbols. let mut imported_symbols: Vec<(Vec, ImportSymbol, Span)> = vec![]; for import in program.imports.iter() { @@ -149,24 +153,27 @@ impl InternalProgram { deduplicated_imports.insert(package.clone(), span.clone()); } - let mut wrapped_resolver = crate::CoreImportResolver(import_resolver); + let mut wrapped_resolver = crate::CoreImportResolver::new(import_resolver); // Load imported programs. let mut resolved_packages: IndexMap, Program> = IndexMap::new(); for (package, span) in deduplicated_imports.iter() { let pretty_package = package.join("."); - let resolved_package = - match wrapped_resolver.resolve_package(&package.iter().map(|x| &**x).collect::>()[..], span)? { - Some(x) => x, - None => return Err(AsgConvertError::unresolved_import(&*pretty_package, &Span::default())), - }; + let resolved_package = match wrapped_resolver.resolve_package( + arena, + &package.iter().map(|x| &**x).collect::>()[..], + span, + )? { + Some(x) => x, + None => return Err(AsgConvertError::unresolved_import(&*pretty_package, &Span::default())), + }; resolved_packages.insert(package.clone(), resolved_package); } - let mut imported_functions: IndexMap> = IndexMap::new(); - let mut imported_circuits: IndexMap> = IndexMap::new(); + let mut imported_functions: IndexMap> = IndexMap::new(); + let mut imported_circuits: IndexMap> = IndexMap::new(); // Prepare locally relevant scope of imports. for (package, symbol, span) in imported_symbols.into_iter() { @@ -175,7 +182,6 @@ impl InternalProgram { let resolved_package = resolved_packages .get(&package) .expect("could not find preloaded package"); - let resolved_package = resolved_package.borrow(); match symbol { ImportSymbol::All => { imported_functions.extend(resolved_package.functions.clone().into_iter()); @@ -183,9 +189,9 @@ impl InternalProgram { } ImportSymbol::Direct(name) => { if let Some(function) = resolved_package.functions.get(&name) { - imported_functions.insert(name.clone(), function.clone()); - } else if let Some(function) = resolved_package.circuits.get(&name) { - imported_circuits.insert(name.clone(), function.clone()); + imported_functions.insert(name.clone(), *function); + } else if let Some(circuit) = resolved_package.circuits.get(&name) { + imported_circuits.insert(name.clone(), *circuit); } else { return Err(AsgConvertError::unresolved_import( &*format!("{}.{}", pretty_package, name), @@ -195,9 +201,9 @@ impl InternalProgram { } ImportSymbol::Alias(name, alias) => { if let Some(function) = resolved_package.functions.get(&name) { - imported_functions.insert(alias.clone(), function.clone()); - } else if let Some(function) = resolved_package.circuits.get(&name) { - imported_circuits.insert(alias.clone(), function.clone()); + imported_functions.insert(alias.clone(), *function); + } else if let Some(circuit) = resolved_package.circuits.get(&name) { + imported_circuits.insert(alias.clone(), *circuit); } else { return Err(AsgConvertError::unresolved_import( &*format!("{}.{}", pretty_package, name), @@ -208,71 +214,54 @@ impl InternalProgram { } } - let import_scope = Arc::new(RefCell::new(InnerScope { + let import_scope = match arena.alloc(ArenaNode::Scope(Scope { + arena, id: uuid::Uuid::new_v4(), - parent_scope: None, - circuit_self: None, - variables: IndexMap::new(), - functions: imported_functions - .iter() - .map(|(name, func)| (name.clone(), func.function.clone())) - .collect(), - circuits: imported_circuits - .iter() - .map(|(name, circuit)| (name.clone(), circuit.circuit.clone())) - .collect(), - function: None, - input: None, - })); + parent_scope: Cell::new(None), + circuit_self: Cell::new(None), + variables: RefCell::new(IndexMap::new()), + functions: RefCell::new(imported_functions), + circuits: RefCell::new(imported_circuits), + function: Cell::new(None), + input: Cell::new(None), + })) { + ArenaNode::Scope(c) => c, + _ => unimplemented!(), + }; + + let scope = import_scope.alloc_scope(Scope { + arena, + input: Cell::new(Some(Input::new(import_scope))), // we use import_scope to avoid recursive scope ref here + id: uuid::Uuid::new_v4(), + parent_scope: Cell::new(Some(import_scope)), + circuit_self: Cell::new(None), + variables: RefCell::new(IndexMap::new()), + functions: RefCell::new(IndexMap::new()), + circuits: RefCell::new(IndexMap::new()), + function: Cell::new(None), + }); // Prepare header-like scope entries. - let mut proto_circuits = IndexMap::new(); for (name, circuit) in program.circuits.iter() { assert_eq!(name.name, circuit.circuit_name.name); - let asg_circuit = Circuit::init(circuit); + let asg_circuit = Circuit::init(scope, circuit)?; - proto_circuits.insert(name.name.clone(), asg_circuit); - } - - let scope = Arc::new(RefCell::new(InnerScope { - input: Some(Input::new(&import_scope)), // we use import_scope to avoid recursive scope ref here - id: uuid::Uuid::new_v4(), - parent_scope: Some(import_scope), - circuit_self: None, - variables: IndexMap::new(), - functions: IndexMap::new(), - circuits: proto_circuits - .iter() - .map(|(name, circuit)| (name.clone(), circuit.clone())) - .collect(), - function: None, - })); - - for (name, circuit) in program.circuits.iter() { - assert_eq!(name.name, circuit.circuit_name.name); - let asg_circuit = proto_circuits.get(&name.name).unwrap(); - - asg_circuit.clone().from_ast(&scope, &circuit)?; + scope.circuits.borrow_mut().insert(name.name.clone(), asg_circuit); } let mut proto_test_functions = IndexMap::new(); for (name, test_function) in program.tests.iter() { assert_eq!(name.name, test_function.function.identifier.name); - let function = Arc::new(Function::from_ast(&scope, &test_function.function)?); + let function = Function::init(scope, &test_function.function)?; proto_test_functions.insert(name.name.clone(), function); } - let mut proto_functions = IndexMap::new(); for (name, function) in program.functions.iter() { assert_eq!(name.name, function.identifier.name); - let asg_function = Arc::new(Function::from_ast(&scope, function)?); + let function = Function::init(scope, function)?; - scope - .borrow_mut() - .functions - .insert(name.name.clone(), asg_function.clone()); - proto_functions.insert(name.name.clone(), asg_function); + scope.functions.borrow_mut().insert(name.name.clone(), function); } // Load concrete definitions. @@ -281,38 +270,33 @@ impl InternalProgram { assert_eq!(name.name, test_function.function.identifier.name); let function = proto_test_functions.get(&name.name).unwrap(); - let body = Arc::new(FunctionBody::from_ast( - &scope, - &test_function.function, - function.clone(), - )?); - function.body.replace(Arc::downgrade(&body)); + function.fill_from_ast(&test_function.function)?; - test_functions.insert(name.name.clone(), (body, test_function.input_file.clone())); + test_functions.insert(name.name.clone(), (*function, test_function.input_file.clone())); } let mut functions = IndexMap::new(); for (name, function) in program.functions.iter() { assert_eq!(name.name, function.identifier.name); - let asg_function = proto_functions.get(&name.name).unwrap(); + let asg_function = *scope.functions.borrow().get(&name.name).unwrap(); - let body = Arc::new(FunctionBody::from_ast(&scope, function, asg_function.clone())?); - asg_function.body.replace(Arc::downgrade(&body)); + asg_function.fill_from_ast(function)?; - functions.insert(name.name.clone(), body); + functions.insert(name.name.clone(), asg_function); } let mut circuits = IndexMap::new(); for (name, circuit) in program.circuits.iter() { assert_eq!(name.name, circuit.circuit_name.name); - let asg_circuit = proto_circuits.get(&name.name).unwrap(); - let body = Arc::new(CircuitBody::from_ast(&scope, circuit, asg_circuit.clone())?); - asg_circuit.body.replace(Arc::downgrade(&body)); + let asg_circuit = *scope.circuits.borrow().get(&name.name).unwrap(); - circuits.insert(name.name.clone(), body); + asg_circuit.fill_from_ast(circuit)?; + + circuits.insert(name.name.clone(), asg_circuit); } - Ok(Arc::new(RefCell::new(InternalProgram { + Ok(InternalProgram { + context: arena, id: Uuid::new_v4(), name: program.name.clone(), test_functions, @@ -323,12 +307,12 @@ impl InternalProgram { .map(|(package, program)| (package.join("."), program)) .collect(), scope, - }))) + }) } pub(crate) fn set_core_mapping(&self, mapping: &str) { for (_, circuit) in self.circuits.iter() { - circuit.circuit.core_mapping.replace(Some(mapping.to_string())); + circuit.core_mapping.replace(Some(mapping.to_string())); } } } @@ -347,15 +331,15 @@ impl Iterator for InternalIdentifierGenerator { } } /// Returns an AST from the given ASG program. -pub fn reform_ast(program: &Program) -> leo_ast::Program { +pub fn reform_ast<'a>(program: &Program<'a>) -> leo_ast::Program { let mut all_programs: IndexMap = IndexMap::new(); - let mut program_stack = program.borrow().imported_modules.clone(); + let mut program_stack = program.imported_modules.clone(); while let Some((module, program)) = program_stack.pop() { if all_programs.contains_key(&module) { continue; } all_programs.insert(module, program.clone()); - program_stack.extend(program.borrow().imported_modules.clone()); + program_stack.extend(program.imported_modules.clone()); } all_programs.insert("".to_string(), program.clone()); let core_programs: Vec<_> = all_programs @@ -365,16 +349,15 @@ pub fn reform_ast(program: &Program) -> leo_ast::Program { .collect(); all_programs.retain(|module, _| !module.starts_with("core.")); - let mut all_circuits: IndexMap> = IndexMap::new(); - let mut all_functions: IndexMap> = IndexMap::new(); - let mut all_test_functions: IndexMap, Option)> = IndexMap::new(); + let mut all_circuits: IndexMap> = IndexMap::new(); + let mut all_functions: IndexMap> = IndexMap::new(); + let mut all_test_functions: IndexMap, Option)> = IndexMap::new(); let mut identifiers = InternalIdentifierGenerator { next: 0 }; for (_, program) in all_programs.into_iter() { - let program = program.borrow(); for (name, circuit) in program.circuits.iter() { let identifier = format!("{}{}", identifiers.next().unwrap(), name); - circuit.circuit.name.borrow_mut().name = identifier.clone(); - all_circuits.insert(identifier, circuit.clone()); + circuit.name.borrow_mut().name = identifier.clone(); + all_circuits.insert(identifier, *circuit); } for (name, function) in program.functions.iter() { let identifier = if name == "main" { @@ -382,12 +365,12 @@ pub fn reform_ast(program: &Program) -> leo_ast::Program { } else { format!("{}{}", identifiers.next().unwrap(), name) }; - function.function.name.borrow_mut().name = identifier.clone(); - all_functions.insert(identifier, function.clone()); + function.name.borrow_mut().name = identifier.clone(); + all_functions.insert(identifier, *function); } for (name, function) in program.test_functions.iter() { let identifier = format!("{}{}", identifiers.next().unwrap(), name); - function.0.function.name.borrow_mut().name = identifier.clone(); + function.0.name.borrow_mut().name = identifier.clone(); all_test_functions.insert(identifier, function.clone()); } } @@ -409,29 +392,24 @@ pub fn reform_ast(program: &Program) -> leo_ast::Program { tests: all_test_functions .into_iter() .map(|(_, (function, ident))| { - (function.function.name.borrow().clone(), leo_ast::TestFunction { - function: function.function.as_ref().into(), + (function.name.borrow().clone(), leo_ast::TestFunction { + function: function.into(), input_file: ident, }) }) .collect(), functions: all_functions .into_iter() - .map(|(_, function)| { - ( - function.function.name.borrow().clone(), - function.function.as_ref().into(), - ) - }) + .map(|(_, function)| (function.name.borrow().clone(), function.into())) .collect(), circuits: all_circuits .into_iter() - .map(|(_, circuit)| (circuit.circuit.name.borrow().clone(), circuit.circuit.as_ref().into())) + .map(|(_, circuit)| (circuit.name.borrow().clone(), circuit.into())) .collect(), } } -impl Into for &InternalProgram { +impl<'a> Into for &InternalProgram<'a> { fn into(self) -> leo_ast::Program { leo_ast::Program { name: self.name.clone(), @@ -440,24 +418,19 @@ impl Into for &InternalProgram { circuits: self .circuits .iter() - .map(|(_, circuit)| (circuit.circuit.name.borrow().clone(), circuit.circuit.as_ref().into())) + .map(|(_, circuit)| (circuit.name.borrow().clone(), (*circuit).into())) .collect(), functions: self .functions .iter() - .map(|(_, function)| { - ( - function.function.name.borrow().clone(), - function.function.as_ref().into(), - ) - }) + .map(|(_, function)| (function.name.borrow().clone(), (*function).into())) .collect(), tests: self .test_functions .iter() .map(|(_, function)| { - (function.0.function.name.borrow().clone(), leo_ast::TestFunction { - function: function.0.function.as_ref().into(), + (function.0.name.borrow().clone(), leo_ast::TestFunction { + function: function.0.into(), input_file: function.1.clone(), }) }) diff --git a/asg/src/reducer/mod.rs b/asg/src/reducer/mod.rs index 7af5e14c7b..92a123ed29 100644 --- a/asg/src/reducer/mod.rs +++ b/asg/src/reducer/mod.rs @@ -25,3 +25,9 @@ pub use monoidal_director::*; mod monoidal_reducer; pub use monoidal_reducer::*; + +mod visitor; +pub use visitor::*; + +mod visitor_director; +pub use visitor_director::*; diff --git a/asg/src/reducer/monoidal_director.rs b/asg/src/reducer/monoidal_director.rs index 8c37436e67..40a55d4a30 100644 --- a/asg/src/reducer/monoidal_director.rs +++ b/asg/src/reducer/monoidal_director.rs @@ -17,14 +17,14 @@ use super::*; use crate::{expression::*, program::*, statement::*}; -use std::{marker::PhantomData, sync::Arc}; +use std::marker::PhantomData; -pub struct MonoidalDirector> { +pub struct MonoidalDirector<'a, T: Monoid, R: MonoidalReducerExpression<'a, T>> { reducer: R, - _monoid: PhantomData, + _monoid: PhantomData<&'a T>, } -impl> MonoidalDirector { +impl<'a, T: Monoid, R: MonoidalReducerExpression<'a, T>> MonoidalDirector<'a, T, R> { pub fn new(reducer: R) -> Self { Self { reducer, @@ -36,8 +36,8 @@ impl> MonoidalDirector { self.reducer } - pub fn reduce_expression(&mut self, input: &Arc) -> T { - match &**input { + pub fn reduce_expression(&mut self, input: &'a Expression<'a>) -> T { + let value = match input { Expression::ArrayAccess(e) => self.reduce_array_access(e), Expression::ArrayInit(e) => self.reduce_array_init(e), Expression::ArrayInline(e) => self.reduce_array_inline(e), @@ -52,101 +52,115 @@ impl> MonoidalDirector { Expression::TupleInit(e) => self.reduce_tuple_init(e), Expression::Unary(e) => self.reduce_unary(e), Expression::VariableRef(e) => self.reduce_variable_ref(e), - } + }; + + self.reducer.reduce_expression(input, value) } - pub fn reduce_array_access(&mut self, input: &ArrayAccessExpression) -> T { - let array = self.reduce_expression(&input.array); - let index = self.reduce_expression(&input.index); + pub fn reduce_array_access(&mut self, input: &ArrayAccessExpression<'a>) -> T { + let array = self.reduce_expression(input.array.get()); + let index = self.reduce_expression(input.index.get()); self.reducer.reduce_array_access(input, array, index) } - pub fn reduce_array_init(&mut self, input: &ArrayInitExpression) -> T { - let element = self.reduce_expression(&input.element); + pub fn reduce_array_init(&mut self, input: &ArrayInitExpression<'a>) -> T { + let element = self.reduce_expression(input.element.get()); self.reducer.reduce_array_init(input, element) } - pub fn reduce_array_inline(&mut self, input: &ArrayInlineExpression) -> T { - let elements = input.elements.iter().map(|(x, _)| self.reduce_expression(x)).collect(); + pub fn reduce_array_inline(&mut self, input: &ArrayInlineExpression<'a>) -> T { + let elements = input + .elements + .iter() + .map(|(x, _)| self.reduce_expression(x.get())) + .collect(); self.reducer.reduce_array_inline(input, elements) } - pub fn reduce_array_range_access(&mut self, input: &ArrayRangeAccessExpression) -> T { - let array = self.reduce_expression(&input.array); - let left = input.left.as_ref().map(|e| self.reduce_expression(e)); - let right = input.right.as_ref().map(|e| self.reduce_expression(e)); + pub fn reduce_array_range_access(&mut self, input: &ArrayRangeAccessExpression<'a>) -> T { + let array = self.reduce_expression(input.array.get()); + let left = input.left.get().map(|e| self.reduce_expression(e)); + let right = input.right.get().map(|e| self.reduce_expression(e)); self.reducer.reduce_array_range_access(input, array, left, right) } - pub fn reduce_binary(&mut self, input: &BinaryExpression) -> T { - let left = self.reduce_expression(&input.left); - let right = self.reduce_expression(&input.right); + pub fn reduce_binary(&mut self, input: &BinaryExpression<'a>) -> T { + let left = self.reduce_expression(input.left.get()); + let right = self.reduce_expression(input.right.get()); self.reducer.reduce_binary(input, left, right) } - pub fn reduce_call(&mut self, input: &CallExpression) -> T { - let target = input.target.as_ref().map(|e| self.reduce_expression(e)); - let arguments = input.arguments.iter().map(|e| self.reduce_expression(e)).collect(); + pub fn reduce_call(&mut self, input: &CallExpression<'a>) -> T { + let target = input.target.get().map(|e| self.reduce_expression(e)); + let arguments = input + .arguments + .iter() + .map(|e| self.reduce_expression(e.get())) + .collect(); self.reducer.reduce_call(input, target, arguments) } - pub fn reduce_circuit_access(&mut self, input: &CircuitAccessExpression) -> T { - let target = input.target.as_ref().map(|e| self.reduce_expression(e)); + pub fn reduce_circuit_access(&mut self, input: &CircuitAccessExpression<'a>) -> T { + let target = input.target.get().map(|e| self.reduce_expression(e)); self.reducer.reduce_circuit_access(input, target) } - pub fn reduce_circuit_init(&mut self, input: &CircuitInitExpression) -> T { - let values = input.values.iter().map(|(_, e)| self.reduce_expression(e)).collect(); + pub fn reduce_circuit_init(&mut self, input: &CircuitInitExpression<'a>) -> T { + let values = input + .values + .iter() + .map(|(_, e)| self.reduce_expression(e.get())) + .collect(); self.reducer.reduce_circuit_init(input, values) } - pub fn reduce_ternary_expression(&mut self, input: &TernaryExpression) -> T { - let condition = self.reduce_expression(&input.condition); - let if_true = self.reduce_expression(&input.if_true); - let if_false = self.reduce_expression(&input.if_false); + pub fn reduce_ternary_expression(&mut self, input: &TernaryExpression<'a>) -> T { + let condition = self.reduce_expression(input.condition.get()); + let if_true = self.reduce_expression(input.if_true.get()); + let if_false = self.reduce_expression(input.if_false.get()); self.reducer .reduce_ternary_expression(input, condition, if_true, if_false) } - pub fn reduce_constant(&mut self, input: &Constant) -> T { + pub fn reduce_constant(&mut self, input: &Constant<'a>) -> T { self.reducer.reduce_constant(input) } - pub fn reduce_tuple_access(&mut self, input: &TupleAccessExpression) -> T { - let tuple_ref = self.reduce_expression(&input.tuple_ref); + pub fn reduce_tuple_access(&mut self, input: &TupleAccessExpression<'a>) -> T { + let tuple_ref = self.reduce_expression(input.tuple_ref.get()); self.reducer.reduce_tuple_access(input, tuple_ref) } - pub fn reduce_tuple_init(&mut self, input: &TupleInitExpression) -> T { - let values = input.elements.iter().map(|e| self.reduce_expression(e)).collect(); + pub fn reduce_tuple_init(&mut self, input: &TupleInitExpression<'a>) -> T { + let values = input.elements.iter().map(|e| self.reduce_expression(e.get())).collect(); self.reducer.reduce_tuple_init(input, values) } - pub fn reduce_unary(&mut self, input: &UnaryExpression) -> T { - let inner = self.reduce_expression(&input.inner); + pub fn reduce_unary(&mut self, input: &UnaryExpression<'a>) -> T { + let inner = self.reduce_expression(input.inner.get()); self.reducer.reduce_unary(input, inner) } - pub fn reduce_variable_ref(&mut self, input: &VariableRef) -> T { + pub fn reduce_variable_ref(&mut self, input: &VariableRef<'a>) -> T { self.reducer.reduce_variable_ref(input) } } -impl> MonoidalDirector { - pub fn reduce_statement(&mut self, input: &Arc) -> T { - match &**input { +impl<'a, T: Monoid, R: MonoidalReducerStatement<'a, T>> MonoidalDirector<'a, T, R> { + pub fn reduce_statement(&mut self, input: &'a Statement<'a>) -> T { + let value = match input { Statement::Assign(s) => self.reduce_assign(s), Statement::Block(s) => self.reduce_block(s), Statement::Conditional(s) => self.reduce_conditional_statement(s), @@ -155,57 +169,67 @@ impl> MonoidalDirector { Statement::Expression(s) => self.reduce_expression_statement(s), Statement::Iteration(s) => self.reduce_iteration(s), Statement::Return(s) => self.reduce_return(s), - } + }; + + self.reducer.reduce_statement(input, value) } - pub fn reduce_assign_access(&mut self, input: &AssignAccess) -> T { + pub fn reduce_assign_access(&mut self, input: &AssignAccess<'a>) -> T { let (left, right) = match input { AssignAccess::ArrayRange(left, right) => ( - left.as_ref().map(|e| self.reduce_expression(e)), - right.as_ref().map(|e| self.reduce_expression(e)), + left.get().map(|e| self.reduce_expression(e)), + right.get().map(|e| self.reduce_expression(e)), ), - AssignAccess::ArrayIndex(index) => (Some(self.reduce_expression(index)), None), + AssignAccess::ArrayIndex(index) => (Some(self.reduce_expression(index.get())), None), _ => (None, None), }; self.reducer.reduce_assign_access(input, left, right) } - pub fn reduce_assign(&mut self, input: &AssignStatement) -> T { + pub fn reduce_assign(&mut self, input: &AssignStatement<'a>) -> T { let accesses = input .target_accesses .iter() .map(|x| self.reduce_assign_access(x)) .collect(); - let value = self.reduce_expression(&input.value); + let value = self.reduce_expression(input.value.get()); self.reducer.reduce_assign(input, accesses, value) } - pub fn reduce_block(&mut self, input: &BlockStatement) -> T { - let statements = input.statements.iter().map(|x| self.reduce_statement(x)).collect(); + pub fn reduce_block(&mut self, input: &BlockStatement<'a>) -> T { + let statements = input + .statements + .iter() + .map(|x| self.reduce_statement(x.get())) + .collect(); self.reducer.reduce_block(input, statements) } - pub fn reduce_conditional_statement(&mut self, input: &ConditionalStatement) -> T { - let condition = self.reduce_expression(&input.condition); - let if_true = self.reduce_statement(&input.result); - let if_false = input.next.as_ref().map(|s| self.reduce_statement(s)); + pub fn reduce_conditional_statement(&mut self, input: &ConditionalStatement<'a>) -> T { + let condition = self.reduce_expression(input.condition.get()); + let if_true = self.reduce_statement(input.result.get()); + let if_false = input.next.get().map(|s| self.reduce_statement(s)); self.reducer .reduce_conditional_statement(input, condition, if_true, if_false) } - pub fn reduce_formatted_string(&mut self, input: &FormattedString) -> T { - let parameters = input.parameters.iter().map(|e| self.reduce_expression(e)).collect(); + pub fn reduce_formatted_string(&mut self, input: &FormattedString<'a>) -> T { + let parameters = input + .parameters + .iter() + .map(|e| self.reduce_expression(e.get())) + .collect(); self.reducer.reduce_formatted_string(input, parameters) } - pub fn reduce_console(&mut self, input: &ConsoleStatement) -> T { + pub fn reduce_console(&mut self, input: &ConsoleStatement<'a>) -> T { let argument = match &input.function { - ConsoleFunction::Assert(e) => self.reduce_expression(e), + ConsoleFunction::Assert(e) => self.reduce_expression(e.get()), ConsoleFunction::Debug(f) | ConsoleFunction::Error(f) | ConsoleFunction::Log(f) => { self.reduce_formatted_string(f) } @@ -214,51 +238,51 @@ impl> MonoidalDirector { self.reducer.reduce_console(input, argument) } - pub fn reduce_definition(&mut self, input: &DefinitionStatement) -> T { - let value = self.reduce_expression(&input.value); + pub fn reduce_definition(&mut self, input: &DefinitionStatement<'a>) -> T { + let value = self.reduce_expression(input.value.get()); self.reducer.reduce_definition(input, value) } - pub fn reduce_expression_statement(&mut self, input: &ExpressionStatement) -> T { - let value = self.reduce_expression(&input.expression); + pub fn reduce_expression_statement(&mut self, input: &ExpressionStatement<'a>) -> T { + let value = self.reduce_expression(input.expression.get()); self.reducer.reduce_expression_statement(input, value) } - pub fn reduce_iteration(&mut self, input: &IterationStatement) -> T { - let start = self.reduce_expression(&input.start); - let stop = self.reduce_expression(&input.stop); - let body = self.reduce_statement(&input.body); + pub fn reduce_iteration(&mut self, input: &IterationStatement<'a>) -> T { + let start = self.reduce_expression(input.start.get()); + let stop = self.reduce_expression(input.stop.get()); + let body = self.reduce_statement(input.body.get()); self.reducer.reduce_iteration(input, start, stop, body) } - pub fn reduce_return(&mut self, input: &ReturnStatement) -> T { - let value = self.reduce_expression(&input.expression); + pub fn reduce_return(&mut self, input: &ReturnStatement<'a>) -> T { + let value = self.reduce_expression(input.expression.get()); self.reducer.reduce_return(input, value) } } #[allow(dead_code)] -impl> MonoidalDirector { - fn reduce_function(&mut self, input: &Arc) -> T { - let body = self.reduce_statement(&input.body); +impl<'a, T: Monoid, R: MonoidalReducerProgram<'a, T>> MonoidalDirector<'a, T, R> { + fn reduce_function(&mut self, input: &'a Function<'a>) -> T { + let body = input.body.get().map(|s| self.reduce_statement(s)).unwrap_or_default(); self.reducer.reduce_function(input, body) } - fn reduce_circuit_member(&mut self, input: &CircuitMemberBody) -> T { + fn reduce_circuit_member(&mut self, input: &CircuitMember<'a>) -> T { let function = match input { - CircuitMemberBody::Function(f) => Some(self.reduce_function(f)), + CircuitMember::Function(f) => Some(self.reduce_function(f)), _ => None, }; self.reducer.reduce_circuit_member(input, function) } - fn reduce_circuit(&mut self, input: &Arc) -> T { + fn reduce_circuit(&mut self, input: &'a Circuit<'a>) -> T { let members = input .members .borrow() @@ -269,8 +293,7 @@ impl> MonoidalDirector { self.reducer.reduce_circuit(input, members) } - fn reduce_program(&mut self, input: &Program) -> T { - let input = input.borrow(); + fn reduce_program(&mut self, input: &Program<'a>) -> T { let imported_modules = input .imported_modules .iter() diff --git a/asg/src/reducer/monoidal_reducer.rs b/asg/src/reducer/monoidal_reducer.rs index cd37dd601d..ba7a177f5b 100644 --- a/asg/src/reducer/monoidal_reducer.rs +++ b/asg/src/reducer/monoidal_reducer.rs @@ -16,29 +16,27 @@ use crate::{expression::*, program::*, statement::*, Monoid}; -use std::sync::Arc; - #[allow(unused_variables)] -pub trait MonoidalReducerExpression { - fn reduce_expression(&mut self, input: &Arc, value: T) -> T { +pub trait MonoidalReducerExpression<'a, T: Monoid> { + fn reduce_expression(&mut self, input: &'a Expression<'a>, value: T) -> T { value } - fn reduce_array_access(&mut self, input: &ArrayAccessExpression, array: T, index: T) -> T { + fn reduce_array_access(&mut self, input: &ArrayAccessExpression<'a>, array: T, index: T) -> T { array.append(index) } - fn reduce_array_init(&mut self, input: &ArrayInitExpression, element: T) -> T { + fn reduce_array_init(&mut self, input: &ArrayInitExpression<'a>, element: T) -> T { element } - fn reduce_array_inline(&mut self, input: &ArrayInlineExpression, elements: Vec) -> T { + fn reduce_array_inline(&mut self, input: &ArrayInlineExpression<'a>, elements: Vec) -> T { T::default().append_all(elements.into_iter()) } fn reduce_array_range_access( &mut self, - input: &ArrayRangeAccessExpression, + input: &ArrayRangeAccessExpression<'a>, array: T, left: Option, right: Option, @@ -46,69 +44,69 @@ pub trait MonoidalReducerExpression { array.append_option(left).append_option(right) } - fn reduce_binary(&mut self, input: &BinaryExpression, left: T, right: T) -> T { + fn reduce_binary(&mut self, input: &BinaryExpression<'a>, left: T, right: T) -> T { left.append(right) } - fn reduce_call(&mut self, input: &CallExpression, target: Option, arguments: Vec) -> T { + fn reduce_call(&mut self, input: &CallExpression<'a>, target: Option, arguments: Vec) -> T { target.unwrap_or_default().append_all(arguments.into_iter()) } - fn reduce_circuit_access(&mut self, input: &CircuitAccessExpression, target: Option) -> T { + fn reduce_circuit_access(&mut self, input: &CircuitAccessExpression<'a>, target: Option) -> T { target.unwrap_or_default() } - fn reduce_circuit_init(&mut self, input: &CircuitInitExpression, values: Vec) -> T { + fn reduce_circuit_init(&mut self, input: &CircuitInitExpression<'a>, values: Vec) -> T { T::default().append_all(values.into_iter()) } - fn reduce_ternary_expression(&mut self, input: &TernaryExpression, condition: T, if_true: T, if_false: T) -> T { + fn reduce_ternary_expression(&mut self, input: &TernaryExpression<'a>, condition: T, if_true: T, if_false: T) -> T { condition.append(if_true).append(if_false) } - fn reduce_constant(&mut self, input: &Constant) -> T { + fn reduce_constant(&mut self, input: &Constant<'a>) -> T { T::default() } - fn reduce_tuple_access(&mut self, input: &TupleAccessExpression, tuple_ref: T) -> T { + fn reduce_tuple_access(&mut self, input: &TupleAccessExpression<'a>, tuple_ref: T) -> T { tuple_ref } - fn reduce_tuple_init(&mut self, input: &TupleInitExpression, values: Vec) -> T { + fn reduce_tuple_init(&mut self, input: &TupleInitExpression<'a>, values: Vec) -> T { T::default().append_all(values.into_iter()) } - fn reduce_unary(&mut self, input: &UnaryExpression, inner: T) -> T { + fn reduce_unary(&mut self, input: &UnaryExpression<'a>, inner: T) -> T { inner } - fn reduce_variable_ref(&mut self, input: &VariableRef) -> T { + fn reduce_variable_ref(&mut self, input: &VariableRef<'a>) -> T { T::default() } } #[allow(unused_variables)] -pub trait MonoidalReducerStatement: MonoidalReducerExpression { - fn reduce_statement(&mut self, input: &Arc, value: T) -> T { +pub trait MonoidalReducerStatement<'a, T: Monoid>: MonoidalReducerExpression<'a, T> { + fn reduce_statement(&mut self, input: &'a Statement<'a>, value: T) -> T { value } // left = Some(ArrayIndex.0) always if AssignAccess::ArrayIndex. if member/tuple, always None - fn reduce_assign_access(&mut self, input: &AssignAccess, left: Option, right: Option) -> T { + fn reduce_assign_access(&mut self, input: &AssignAccess<'a>, left: Option, right: Option) -> T { left.unwrap_or_default().append_option(right) } - fn reduce_assign(&mut self, input: &AssignStatement, accesses: Vec, value: T) -> T { + fn reduce_assign(&mut self, input: &AssignStatement<'a>, accesses: Vec, value: T) -> T { T::default().append_all(accesses.into_iter()).append(value) } - fn reduce_block(&mut self, input: &BlockStatement, statements: Vec) -> T { + fn reduce_block(&mut self, input: &BlockStatement<'a>, statements: Vec) -> T { T::default().append_all(statements.into_iter()) } fn reduce_conditional_statement( &mut self, - input: &ConditionalStatement, + input: &ConditionalStatement<'a>, condition: T, if_true: T, if_false: Option, @@ -116,42 +114,42 @@ pub trait MonoidalReducerStatement: MonoidalReducerExpression { condition.append(if_true).append_option(if_false) } - fn reduce_formatted_string(&mut self, input: &FormattedString, parameters: Vec) -> T { + fn reduce_formatted_string(&mut self, input: &FormattedString<'a>, parameters: Vec) -> T { T::default().append_all(parameters.into_iter()) } - fn reduce_console(&mut self, input: &ConsoleStatement, argument: T) -> T { + fn reduce_console(&mut self, input: &ConsoleStatement<'a>, argument: T) -> T { argument } - fn reduce_definition(&mut self, input: &DefinitionStatement, value: T) -> T { + fn reduce_definition(&mut self, input: &DefinitionStatement<'a>, value: T) -> T { value } - fn reduce_expression_statement(&mut self, input: &ExpressionStatement, expression: T) -> T { + fn reduce_expression_statement(&mut self, input: &ExpressionStatement<'a>, expression: T) -> T { expression } - fn reduce_iteration(&mut self, input: &IterationStatement, start: T, stop: T, body: T) -> T { + fn reduce_iteration(&mut self, input: &IterationStatement<'a>, start: T, stop: T, body: T) -> T { start.append(stop).append(body) } - fn reduce_return(&mut self, input: &ReturnStatement, value: T) -> T { + fn reduce_return(&mut self, input: &ReturnStatement<'a>, value: T) -> T { value } } #[allow(unused_variables)] -pub trait MonoidalReducerProgram: MonoidalReducerStatement { - fn reduce_function(&mut self, input: &Arc, body: T) -> T { +pub trait MonoidalReducerProgram<'a, T: Monoid>: MonoidalReducerStatement<'a, T> { + fn reduce_function(&mut self, input: &'a Function<'a>, body: T) -> T { body } - fn reduce_circuit_member(&mut self, input: &CircuitMemberBody, function: Option) -> T { + fn reduce_circuit_member(&mut self, input: &CircuitMember<'a>, function: Option) -> T { function.unwrap_or_default() } - fn reduce_circuit(&mut self, input: &Arc, members: Vec) -> T { + fn reduce_circuit(&mut self, input: &'a Circuit<'a>, members: Vec) -> T { T::default().append_all(members.into_iter()) } diff --git a/asg/src/reducer/visitor.rs b/asg/src/reducer/visitor.rs new file mode 100644 index 0000000000..415f3ed265 --- /dev/null +++ b/asg/src/reducer/visitor.rs @@ -0,0 +1,161 @@ +// Copyright (C) 2019-2021 Aleo Systems Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use std::cell::Cell; + +use crate::{expression::*, program::*, statement::*}; + +pub enum VisitResult { + VisitChildren, + SkipChildren, + Exit, +} + +impl Default for VisitResult { + fn default() -> Self { + VisitResult::VisitChildren + } +} + +#[allow(unused_variables)] +pub trait ExpressionVisitor<'a> { + fn visit_expression(&mut self, input: &Cell<&'a Expression<'a>>) -> VisitResult { + Default::default() + } + + fn visit_array_access(&mut self, input: &ArrayAccessExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_array_init(&mut self, input: &ArrayInitExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_array_inline(&mut self, input: &ArrayInlineExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_array_range_access(&mut self, input: &ArrayRangeAccessExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_binary(&mut self, input: &BinaryExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_call(&mut self, input: &CallExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_circuit_access(&mut self, input: &CircuitAccessExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_circuit_init(&mut self, input: &CircuitInitExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_ternary_expression(&mut self, input: &TernaryExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_constant(&mut self, input: &Constant<'a>) -> VisitResult { + Default::default() + } + + fn visit_tuple_access(&mut self, input: &TupleAccessExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_tuple_init(&mut self, input: &TupleInitExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_unary(&mut self, input: &UnaryExpression<'a>) -> VisitResult { + Default::default() + } + + fn visit_variable_ref(&mut self, input: &VariableRef<'a>) -> VisitResult { + Default::default() + } +} + +#[allow(unused_variables)] +pub trait StatementVisitor<'a>: ExpressionVisitor<'a> { + fn visit_statement(&mut self, input: &Cell<&'a Statement<'a>>) -> VisitResult { + Default::default() + } + + // left = Some(ArrayIndex.0) always if AssignAccess::ArrayIndex. if member/tuple, always None + fn visit_assign_access(&mut self, input: &AssignAccess<'a>) -> VisitResult { + Default::default() + } + + fn visit_assign(&mut self, input: &AssignStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_block(&mut self, input: &BlockStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_conditional_statement(&mut self, input: &ConditionalStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_formatted_string(&mut self, input: &FormattedString<'a>) -> VisitResult { + Default::default() + } + + fn visit_console(&mut self, input: &ConsoleStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_definition(&mut self, input: &DefinitionStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_expression_statement(&mut self, input: &ExpressionStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_iteration(&mut self, input: &IterationStatement<'a>) -> VisitResult { + Default::default() + } + + fn visit_return(&mut self, input: &ReturnStatement<'a>) -> VisitResult { + Default::default() + } +} + +#[allow(unused_variables)] +pub trait ProgramVisitor<'a>: StatementVisitor<'a> { + fn visit_function(&mut self, input: &'a Function<'a>) -> VisitResult { + Default::default() + } + + fn visit_circuit_member(&mut self, input: &CircuitMember<'a>) -> VisitResult { + Default::default() + } + + fn visit_circuit(&mut self, input: &'a Circuit<'a>) -> VisitResult { + Default::default() + } + + fn visit_program(&mut self, input: &Program<'a>) -> VisitResult { + Default::default() + } +} diff --git a/asg/src/reducer/visitor_director.rs b/asg/src/reducer/visitor_director.rs new file mode 100644 index 0000000000..d71926d2d0 --- /dev/null +++ b/asg/src/reducer/visitor_director.rs @@ -0,0 +1,442 @@ +// Copyright (C) 2019-2021 Aleo Systems Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use super::*; +use crate::{expression::*, program::*, statement::*}; + +use std::{cell::Cell, marker::PhantomData}; + +pub struct VisitorDirector<'a, R: ExpressionVisitor<'a>> { + visitor: R, + lifetime: PhantomData<&'a ()>, +} + +pub type ConcreteVisitResult = Result<(), ()>; + +impl Into for VisitResult { + fn into(self) -> ConcreteVisitResult { + match self { + VisitResult::VisitChildren => Ok(()), + VisitResult::SkipChildren => Ok(()), + VisitResult::Exit => Err(()), + } + } +} + +impl<'a, R: ExpressionVisitor<'a>> VisitorDirector<'a, R> { + pub fn new(visitor: R) -> Self { + Self { + visitor, + lifetime: PhantomData, + } + } + + pub fn visitor(self) -> R { + self.visitor + } + + pub fn visit_expression(&mut self, input: &Cell<&'a Expression<'a>>) -> ConcreteVisitResult { + match self.visitor.visit_expression(input) { + VisitResult::VisitChildren => match input.get() { + Expression::ArrayAccess(e) => self.visit_array_access(e), + Expression::ArrayInit(e) => self.visit_array_init(e), + Expression::ArrayInline(e) => self.visit_array_inline(e), + Expression::ArrayRangeAccess(e) => self.visit_array_range_access(e), + Expression::Binary(e) => self.visit_binary(e), + Expression::Call(e) => self.visit_call(e), + Expression::CircuitAccess(e) => self.visit_circuit_access(e), + Expression::CircuitInit(e) => self.visit_circuit_init(e), + Expression::Ternary(e) => self.visit_ternary_expression(e), + Expression::Constant(e) => self.visit_constant(e), + Expression::TupleAccess(e) => self.visit_tuple_access(e), + Expression::TupleInit(e) => self.visit_tuple_init(e), + Expression::Unary(e) => self.visit_unary(e), + Expression::VariableRef(e) => self.visit_variable_ref(e), + }, + x => x.into(), + } + } + + fn visit_opt_expression(&mut self, input: &Cell>>) -> ConcreteVisitResult { + let interior = match input.get() { + Some(expr) => Some(Cell::new(expr)), + None => None, + }; + if let Some(interior) = interior.as_ref() { + let result = self.visit_expression(interior); + input.replace(Some(interior.get())); + result + } else { + Ok(()) + } + } + + pub fn visit_array_access(&mut self, input: &ArrayAccessExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_array_access(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.array)?; + self.visit_expression(&input.index)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_array_init(&mut self, input: &ArrayInitExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_array_init(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.element)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_array_inline(&mut self, input: &ArrayInlineExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_array_inline(input) { + VisitResult::VisitChildren => { + for (element, _) in input.elements.iter() { + self.visit_expression(element)?; + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_array_range_access(&mut self, input: &ArrayRangeAccessExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_array_range_access(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.array)?; + self.visit_opt_expression(&input.left)?; + self.visit_opt_expression(&input.right)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_binary(&mut self, input: &BinaryExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_binary(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.left)?; + self.visit_expression(&input.right)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_call(&mut self, input: &CallExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_call(input) { + VisitResult::VisitChildren => { + self.visit_opt_expression(&input.target)?; + for argument in input.arguments.iter() { + self.visit_expression(argument)?; + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_circuit_access(&mut self, input: &CircuitAccessExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_circuit_access(input) { + VisitResult::VisitChildren => { + self.visit_opt_expression(&input.target)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_circuit_init(&mut self, input: &CircuitInitExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_circuit_init(input) { + VisitResult::VisitChildren => { + for (_, argument) in input.values.iter() { + self.visit_expression(argument)?; + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_ternary_expression(&mut self, input: &TernaryExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_ternary_expression(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.condition)?; + self.visit_expression(&input.if_true)?; + self.visit_expression(&input.if_false)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_constant(&mut self, input: &Constant<'a>) -> ConcreteVisitResult { + self.visitor.visit_constant(input).into() + } + + pub fn visit_tuple_access(&mut self, input: &TupleAccessExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_tuple_access(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.tuple_ref)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_tuple_init(&mut self, input: &TupleInitExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_tuple_init(input) { + VisitResult::VisitChildren => { + for argument in input.elements.iter() { + self.visit_expression(argument)?; + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_unary(&mut self, input: &UnaryExpression<'a>) -> ConcreteVisitResult { + match self.visitor.visit_unary(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.inner)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_variable_ref(&mut self, input: &VariableRef<'a>) -> ConcreteVisitResult { + self.visitor.visit_variable_ref(input).into() + } +} + +impl<'a, R: StatementVisitor<'a>> VisitorDirector<'a, R> { + pub fn visit_statement(&mut self, input: &Cell<&'a Statement<'a>>) -> ConcreteVisitResult { + match self.visitor.visit_statement(input) { + VisitResult::VisitChildren => match input.get() { + Statement::Assign(s) => self.visit_assign(s), + Statement::Block(s) => self.visit_block(s), + Statement::Conditional(s) => self.visit_conditional_statement(s), + Statement::Console(s) => self.visit_console(s), + Statement::Definition(s) => self.visit_definition(s), + Statement::Expression(s) => self.visit_expression_statement(s), + Statement::Iteration(s) => self.visit_iteration(s), + Statement::Return(s) => self.visit_return(s), + }, + x => x.into(), + } + } + + fn visit_opt_statement(&mut self, input: &Cell>>) -> ConcreteVisitResult { + let interior = match input.get() { + Some(expr) => Some(Cell::new(expr)), + None => None, + }; + if let Some(interior) = interior.as_ref() { + let result = self.visit_statement(interior); + input.replace(Some(interior.get())); + result + } else { + Ok(()) + } + } + + pub fn visit_assign_access(&mut self, input: &AssignAccess<'a>) -> ConcreteVisitResult { + match self.visitor.visit_assign_access(input) { + VisitResult::VisitChildren => { + match input { + AssignAccess::ArrayRange(left, right) => { + self.visit_opt_expression(left)?; + self.visit_opt_expression(right)?; + } + AssignAccess::ArrayIndex(index) => self.visit_expression(index)?, + _ => (), + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_assign(&mut self, input: &AssignStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_assign(input) { + VisitResult::VisitChildren => { + for access in input.target_accesses.iter() { + self.visit_assign_access(access)?; + } + self.visit_expression(&input.value)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_block(&mut self, input: &BlockStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_block(input) { + VisitResult::VisitChildren => { + for statement in input.statements.iter() { + self.visit_statement(statement)?; + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_conditional_statement(&mut self, input: &ConditionalStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_conditional_statement(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.condition)?; + self.visit_statement(&input.result)?; + self.visit_opt_statement(&input.next)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_formatted_string(&mut self, input: &FormattedString<'a>) -> ConcreteVisitResult { + match self.visitor.visit_formatted_string(input) { + VisitResult::VisitChildren => { + for parameter in input.parameters.iter() { + self.visit_expression(parameter)?; + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_console(&mut self, input: &ConsoleStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_console(input) { + VisitResult::VisitChildren => { + match &input.function { + ConsoleFunction::Assert(e) => self.visit_expression(e)?, + ConsoleFunction::Debug(f) | ConsoleFunction::Error(f) | ConsoleFunction::Log(f) => { + self.visit_formatted_string(f)? + } + } + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_definition(&mut self, input: &DefinitionStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_definition(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.value)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_expression_statement(&mut self, input: &ExpressionStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_expression_statement(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.expression)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_iteration(&mut self, input: &IterationStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_iteration(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.start)?; + self.visit_expression(&input.stop)?; + self.visit_statement(&input.body)?; + Ok(()) + } + x => x.into(), + } + } + + pub fn visit_return(&mut self, input: &ReturnStatement<'a>) -> ConcreteVisitResult { + match self.visitor.visit_return(input) { + VisitResult::VisitChildren => { + self.visit_expression(&input.expression)?; + Ok(()) + } + x => x.into(), + } + } +} + +#[allow(dead_code)] +impl<'a, R: ProgramVisitor<'a>> VisitorDirector<'a, R> { + fn visit_function(&mut self, input: &'a Function<'a>) -> ConcreteVisitResult { + match self.visitor.visit_function(input) { + VisitResult::VisitChildren => { + self.visit_opt_statement(&input.body)?; + Ok(()) + } + x => x.into(), + } + } + + fn visit_circuit_member(&mut self, input: &CircuitMember<'a>) -> ConcreteVisitResult { + match self.visitor.visit_circuit_member(input) { + VisitResult::VisitChildren => { + if let CircuitMember::Function(f) = input { + self.visit_function(f)?; + } + Ok(()) + } + x => x.into(), + } + } + + fn visit_circuit(&mut self, input: &'a Circuit<'a>) -> ConcreteVisitResult { + match self.visitor.visit_circuit(input) { + VisitResult::VisitChildren => { + for (_, member) in input.members.borrow().iter() { + self.visit_circuit_member(member)?; + } + Ok(()) + } + x => x.into(), + } + } + + fn visit_program(&mut self, input: &Program<'a>) -> ConcreteVisitResult { + match self.visitor.visit_program(input) { + VisitResult::VisitChildren => { + for (_, import) in input.imported_modules.iter() { + self.visit_program(import)?; + } + for (_, (function, _)) in input.test_functions.iter() { + self.visit_function(function)?; + } + for (_, function) in input.functions.iter() { + self.visit_function(function)?; + } + for (_, circuit) in input.circuits.iter() { + self.visit_circuit(circuit)?; + } + Ok(()) + } + x => x.into(), + } + } +} diff --git a/asg/src/scope.rs b/asg/src/scope.rs index 8c8d29e0ca..d08837172e 100644 --- a/asg/src/scope.rs +++ b/asg/src/scope.rs @@ -14,58 +14,98 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{AsgConvertError, Circuit, Function, Input, Type, Variable}; +use crate::{ArenaNode, AsgConvertError, Circuit, Expression, Function, Input, Statement, Type, Variable}; use indexmap::IndexMap; -use std::{cell::RefCell, sync::Arc}; +use std::cell::{Cell, RefCell}; +use typed_arena::Arena; use uuid::Uuid; /// An abstract data type that track the current bindings for variables, functions, and circuits. -#[derive(Debug)] -pub struct InnerScope { +#[derive(Clone)] +pub struct Scope<'a> { + pub arena: &'a Arena>, + /// The unique id of the scope. pub id: Uuid, /// The parent scope that this scope inherits. - pub parent_scope: Option, + pub parent_scope: Cell>>, /// The function definition that this scope occurs in. - pub function: Option>, + pub function: Cell>>, /// The circuit definition that this scope occurs in. - pub circuit_self: Option>, + pub circuit_self: Cell>>, /// Maps variable name => variable. - pub variables: IndexMap, + pub variables: RefCell>>, /// Maps function name => function. - pub functions: IndexMap>, + pub functions: RefCell>>, /// Maps circuit name => circuit. - pub circuits: IndexMap>, + pub circuits: RefCell>>, /// The main input to the program. - pub input: Option, + pub input: Cell>>, } -pub type Scope = Arc>; +#[allow(clippy::mut_from_ref)] +impl<'a> Scope<'a> { + pub fn alloc_expression(&'a self, expr: Expression<'a>) -> &'a mut Expression<'a> { + match self.arena.alloc(ArenaNode::Expression(expr)) { + ArenaNode::Expression(e) => e, + _ => unimplemented!(), + } + } + + pub fn alloc_statement(&'a self, statement: Statement<'a>) -> &'a mut Statement<'a> { + match self.arena.alloc(ArenaNode::Statement(statement)) { + ArenaNode::Statement(e) => e, + _ => unimplemented!(), + } + } + + pub fn alloc_variable(&'a self, variable: Variable<'a>) -> &'a mut Variable<'a> { + match self.arena.alloc(ArenaNode::Variable(variable)) { + ArenaNode::Variable(e) => e, + _ => unimplemented!(), + } + } + + pub fn alloc_scope(&'a self, scope: Scope<'a>) -> &'a mut Scope<'a> { + match self.arena.alloc(ArenaNode::Scope(scope)) { + ArenaNode::Scope(e) => e, + _ => unimplemented!(), + } + } + + pub fn alloc_circuit(&'a self, circuit: Circuit<'a>) -> &'a mut Circuit<'a> { + match self.arena.alloc(ArenaNode::Circuit(circuit)) { + ArenaNode::Circuit(e) => e, + _ => unimplemented!(), + } + } + + pub fn alloc_function(&'a self, function: Function<'a>) -> &'a mut Function<'a> { + match self.arena.alloc(ArenaNode::Function(function)) { + ArenaNode::Function(e) => e, + _ => unimplemented!(), + } + } -impl InnerScope { /// /// Returns a reference to the variable corresponding to the name. /// /// If the current scope did not have this name present, then the parent scope is checked. /// If there is no parent scope, then `None` is returned. /// - pub fn resolve_variable(&self, name: &str) -> Option { - if let Some(resolved) = self.variables.get(name) { - Some(resolved.clone()) - } else if let Some(resolved) = self.parent_scope.as_ref() { - if let Some(resolved) = resolved.borrow().resolve_variable(name) { - Some(resolved) - } else { - None - } + pub fn resolve_variable(&self, name: &str) -> Option<&'a Variable<'a>> { + if let Some(resolved) = self.variables.borrow().get(name) { + Some(*resolved) + } else if let Some(scope) = self.parent_scope.get() { + scope.resolve_variable(name) } else { None } @@ -77,15 +117,11 @@ impl InnerScope { /// If the current scope did not have a function present, then the parent scope is checked. /// If there is no parent scope, then `None` is returned. /// - pub fn resolve_current_function(&self) -> Option> { - if let Some(resolved) = self.function.as_ref() { - Some(resolved.clone()) - } else if let Some(resolved) = self.parent_scope.as_ref() { - if let Some(resolved) = resolved.borrow().resolve_current_function() { - Some(resolved) - } else { - None - } + pub fn resolve_current_function(&self) -> Option<&'a Function> { + if let Some(resolved) = self.function.get() { + Some(resolved) + } else if let Some(scope) = self.parent_scope.get() { + scope.resolve_current_function() } else { None } @@ -97,15 +133,11 @@ impl InnerScope { /// If the current scope did not have an input present, then the parent scope is checked. /// If there is no parent scope, then `None` is returned. /// - pub fn resolve_input(&self) -> Option { - if let Some(input) = self.input.as_ref() { - Some(input.clone()) - } else if let Some(resolved) = self.parent_scope.as_ref() { - if let Some(resolved) = resolved.borrow().resolve_input() { - Some(resolved) - } else { - None - } + pub fn resolve_input(&self) -> Option> { + if let Some(input) = self.input.get() { + Some(input) + } else if let Some(resolved) = self.parent_scope.get() { + resolved.resolve_input() } else { None } @@ -117,15 +149,11 @@ impl InnerScope { /// If the current scope did not have this name present, then the parent scope is checked. /// If there is no parent scope, then `None` is returned. /// - pub fn resolve_function(&self, name: &str) -> Option> { - if let Some(resolved) = self.functions.get(name) { - Some(resolved.clone()) - } else if let Some(resolved) = self.parent_scope.as_ref() { - if let Some(resolved) = resolved.borrow().resolve_function(name) { - Some(resolved) - } else { - None - } + pub fn resolve_function(&self, name: &str) -> Option<&'a Function<'a>> { + if let Some(resolved) = self.functions.borrow().get(name) { + Some(*resolved) + } else if let Some(resolved) = self.parent_scope.get() { + resolved.resolve_function(name) } else { None } @@ -137,17 +165,13 @@ impl InnerScope { /// If the current scope did not have this name present, then the parent scope is checked. /// If there is no parent scope, then `None` is returned. /// - pub fn resolve_circuit(&self, name: &str) -> Option> { - if let Some(resolved) = self.circuits.get(name) { - Some(resolved.clone()) - } else if name == "Self" && self.circuit_self.is_some() { - self.circuit_self.clone() - } else if let Some(resolved) = self.parent_scope.as_ref() { - if let Some(resolved) = resolved.borrow().resolve_circuit(name) { - Some(resolved) - } else { - None - } + pub fn resolve_circuit(&self, name: &str) -> Option<&'a Circuit<'a>> { + if let Some(resolved) = self.circuits.borrow().get(name) { + Some(*resolved) + } else if name == "Self" && self.circuit_self.get().is_some() { + self.circuit_self.get() + } else if let Some(resolved) = self.parent_scope.get() { + resolved.resolve_circuit(name) } else { None } @@ -159,15 +183,11 @@ impl InnerScope { /// If the current scope did not have a circuit self present, then the parent scope is checked. /// If there is no parent scope, then `None` is returned. /// - pub fn resolve_circuit_self(&self) -> Option> { - if let Some(resolved) = self.circuit_self.as_ref() { - Some(resolved.clone()) - } else if let Some(resolved) = self.parent_scope.as_ref() { - if let Some(resolved) = resolved.borrow().resolve_circuit_self() { - Some(resolved) - } else { - None - } + pub fn resolve_circuit_self(&self) -> Option<&'a Circuit<'a>> { + if let Some(resolved) = self.circuit_self.get() { + Some(resolved) + } else if let Some(resolved) = self.parent_scope.get() { + resolved.resolve_circuit_self() } else { None } @@ -176,23 +196,24 @@ impl InnerScope { /// /// Returns a new scope given a parent scope. /// - pub fn make_subscope(scope: &Scope) -> Scope { - Arc::new(RefCell::new(InnerScope { + pub fn make_subscope(self: &'a Scope<'a>) -> &'a Scope<'a> { + self.alloc_scope(Scope::<'a> { + arena: self.arena, id: Uuid::new_v4(), - parent_scope: Some(scope.clone()), - circuit_self: None, - variables: IndexMap::new(), - functions: IndexMap::new(), - circuits: IndexMap::new(), - function: None, - input: None, - })) + parent_scope: Cell::new(Some(self)), + circuit_self: Cell::new(None), + variables: RefCell::new(IndexMap::new()), + functions: RefCell::new(IndexMap::new()), + circuits: RefCell::new(IndexMap::new()), + function: Cell::new(None), + input: Cell::new(None), + }) } /// /// Returns the type returned by the current scope. /// - pub fn resolve_ast_type(&self, type_: &leo_ast::Type) -> Result { + pub fn resolve_ast_type(&self, type_: &leo_ast::Type) -> Result, AsgConvertError> { use leo_ast::Type::*; Ok(match type_ { Address => Type::Address, diff --git a/asg/src/statement/assign.rs b/asg/src/statement/assign.rs index 10d7be7dc2..1eda9119ba 100644 --- a/asg/src/statement/assign.rs +++ b/asg/src/statement/assign.rs @@ -35,49 +35,49 @@ use crate::{ pub use leo_ast::AssignOperation; use leo_ast::AssigneeAccess as AstAssigneeAccess; -use std::sync::{Arc, Weak}; +use std::cell::Cell; -#[derive(Debug)] -pub enum AssignAccess { - ArrayRange(Option>, Option>), - ArrayIndex(Arc), +#[derive(Clone)] +pub enum AssignAccess<'a> { + ArrayRange(Cell>>, Cell>>), + ArrayIndex(Cell<&'a Expression<'a>>), Tuple(usize), Member(Identifier), } -#[derive(Debug)] -pub struct AssignStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct AssignStatement<'a> { + pub parent: Cell>>, pub span: Option, pub operation: AssignOperation, - pub target_variable: Variable, - pub target_accesses: Vec, - pub value: Arc, + pub target_variable: Cell<&'a Variable<'a>>, + pub target_accesses: Vec>, + pub value: Cell<&'a Expression<'a>>, } -impl Node for AssignStatement { +impl<'a> Node for AssignStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for Arc { +impl<'a> FromAst<'a, leo_ast::AssignStatement> for &'a Statement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::AssignStatement, - _expected_type: Option, - ) -> Result, AsgConvertError> { + _expected_type: Option>, + ) -> Result { let (name, span) = (&statement.assignee.identifier.name, &statement.assignee.identifier.span); let variable = if name == "input" { - if let Some(function) = scope.borrow().resolve_current_function() { + if let Some(function) = scope.resolve_current_function() { if !function.has_input { return Err(AsgConvertError::unresolved_reference(name, span)); } } else { return Err(AsgConvertError::unresolved_reference(name, span)); } - if let Some(input) = scope.borrow().resolve_input() { + if let Some(input) = scope.resolve_input() { input.container } else { return Err(AsgConvertError::InternalError( @@ -86,7 +86,6 @@ impl FromAst for Arc { } } else { scope - .borrow() .resolve_variable(&name) .ok_or_else(|| AsgConvertError::unresolved_reference(name, span))? }; @@ -94,7 +93,7 @@ impl FromAst for Arc { if !variable.borrow().mutable { return Err(AsgConvertError::immutable_assignment(&name, &statement.span)); } - let mut target_type: Option = Some(variable.borrow().type_.clone().strong().into()); + let mut target_type: Option = Some(variable.borrow().type_.clone().into()); let mut target_accesses = vec![]; for access in statement.assignee.accesses.iter() { @@ -104,16 +103,16 @@ impl FromAst for Arc { let left = left .as_ref() .map( - |left: &leo_ast::Expression| -> Result, AsgConvertError> { - Arc::::from_ast(scope, left, index_type.clone()) + |left: &leo_ast::Expression| -> Result<&'a Expression<'a>, AsgConvertError> { + <&Expression<'a>>::from_ast(scope, left, index_type.clone()) }, ) .transpose()?; let right = right .as_ref() .map( - |right: &leo_ast::Expression| -> Result, AsgConvertError> { - Arc::::from_ast(scope, right, index_type) + |right: &leo_ast::Expression| -> Result<&'a Expression<'a>, AsgConvertError> { + <&Expression<'a>>::from_ast(scope, right, index_type) }, ) .transpose()?; @@ -156,18 +155,18 @@ impl FromAst for Arc { _ => return Err(AsgConvertError::index_into_non_array(&name, &statement.span)), } - AssignAccess::ArrayRange(left, right) + AssignAccess::ArrayRange(Cell::new(left), Cell::new(right)) } AstAssigneeAccess::ArrayIndex(index) => { target_type = match target_type.clone() { Some(PartialType::Array(item, _)) => item.map(|x| *x), _ => return Err(AsgConvertError::index_into_non_array(&name, &statement.span)), }; - AssignAccess::ArrayIndex(Arc::::from_ast( + AssignAccess::ArrayIndex(Cell::new(<&Expression<'a>>::from_ast( scope, index, Some(PartialType::Integer(None, Some(IntegerType::U32))), - )?) + )?)) } AstAssigneeAccess::Tuple(index, _) => { let index = index @@ -203,7 +202,7 @@ impl FromAst for Arc { return Err(AsgConvertError::illegal_function_assign(&name.name, &statement.span)); } }; - Some(x.strong().partial()) + Some(x.partial()) } _ => { return Err(AsgConvertError::index_into_non_tuple( @@ -216,41 +215,40 @@ impl FromAst for Arc { } }); } - let value = Arc::::from_ast(scope, &statement.value, target_type)?; + let value = <&Expression<'a>>::from_ast(scope, &statement.value, target_type)?; - let statement = Arc::new(Statement::Assign(AssignStatement { - parent: None, + let statement = scope.alloc_statement(Statement::Assign(AssignStatement { + parent: Cell::new(None), span: Some(statement.span.clone()), operation: statement.operation.clone(), - target_variable: variable.clone(), + target_variable: Cell::new(variable), target_accesses, - value, + value: Cell::new(value), })); { let mut variable = variable.borrow_mut(); - variable.assignments.push(Arc::downgrade(&statement)); + variable.assignments.push(statement); } Ok(statement) } } -impl Into for &AssignStatement { +impl<'a> Into for &AssignStatement<'a> { fn into(self) -> leo_ast::AssignStatement { leo_ast::AssignStatement { operation: self.operation.clone(), assignee: leo_ast::Assignee { - identifier: self.target_variable.borrow().name.clone(), + identifier: self.target_variable.get().borrow().name.clone(), accesses: self .target_accesses .iter() .map(|access| match access { - AssignAccess::ArrayRange(left, right) => AstAssigneeAccess::ArrayRange( - left.as_ref().map(|e| e.as_ref().into()), - right.as_ref().map(|e| e.as_ref().into()), - ), - AssignAccess::ArrayIndex(index) => AstAssigneeAccess::ArrayIndex(index.as_ref().into()), + AssignAccess::ArrayRange(left, right) => { + AstAssigneeAccess::ArrayRange(left.get().map(|e| e.into()), right.get().map(|e| e.into())) + } + AssignAccess::ArrayIndex(index) => AstAssigneeAccess::ArrayIndex(index.get().into()), AssignAccess::Tuple(index) => AstAssigneeAccess::Tuple( leo_ast::PositiveNumber { value: index.to_string(), @@ -262,7 +260,7 @@ impl Into for &AssignStatement { .collect(), span: self.span.clone().unwrap_or_default(), }, - value: self.value.as_ref().into(), + value: self.value.get().into(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/statement/block.rs b/asg/src/statement/block.rs index ce417e24eb..9a64fcdfb9 100644 --- a/asg/src/statement/block.rs +++ b/asg/src/statement/block.rs @@ -14,38 +14,38 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{AsgConvertError, FromAst, InnerScope, Node, PartialType, Scope, Span, Statement}; +use crate::{AsgConvertError, FromAst, Node, PartialType, Scope, Span, Statement}; -use std::sync::{Arc, Weak}; +use std::cell::Cell; -#[derive(Debug)] -pub struct BlockStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct BlockStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub statements: Vec>, - pub scope: Scope, + pub statements: Vec>>, + pub scope: &'a Scope<'a>, } -impl Node for BlockStatement { +impl<'a> Node for BlockStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for BlockStatement { +impl<'a> FromAst<'a, leo_ast::Block> for BlockStatement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::Block, - _expected_type: Option, + _expected_type: Option>, ) -> Result { - let new_scope = InnerScope::make_subscope(scope); + let new_scope = scope.make_subscope(); let mut output = vec![]; for item in statement.statements.iter() { - output.push(Arc::::from_ast(&new_scope, item, None)?); + output.push(Cell::new(<&'a Statement<'a>>::from_ast(&new_scope, item, None)?)); } Ok(BlockStatement { - parent: None, + parent: Cell::new(None), span: Some(statement.span.clone()), statements: output, scope: new_scope, @@ -53,14 +53,10 @@ impl FromAst for BlockStatement { } } -impl Into for &BlockStatement { +impl<'a> Into for &BlockStatement<'a> { fn into(self) -> leo_ast::Block { leo_ast::Block { - statements: self - .statements - .iter() - .map(|statement| statement.as_ref().into()) - .collect(), + statements: self.statements.iter().map(|statement| statement.get().into()).collect(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/statement/conditional.rs b/asg/src/statement/conditional.rs index 449d995ea7..6f0b5f09b7 100644 --- a/asg/src/statement/conditional.rs +++ b/asg/src/statement/conditional.rs @@ -16,31 +16,31 @@ use crate::{AsgConvertError, BlockStatement, Expression, FromAst, Node, PartialType, Scope, Span, Statement, Type}; -use std::sync::{Arc, Weak}; +use std::cell::Cell; -#[derive(Debug)] -pub struct ConditionalStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct ConditionalStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub condition: Arc, - pub result: Arc, - pub next: Option>, + pub condition: Cell<&'a Expression<'a>>, + pub result: Cell<&'a Statement<'a>>, + pub next: Cell>>, } -impl Node for ConditionalStatement { +impl<'a> Node for ConditionalStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for ConditionalStatement { +impl<'a> FromAst<'a, leo_ast::ConditionalStatement> for ConditionalStatement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::ConditionalStatement, - _expected_type: Option, + _expected_type: Option>, ) -> Result { - let condition = Arc::::from_ast(scope, &statement.condition, Some(Type::Boolean.into()))?; - let result = Arc::new(Statement::Block(BlockStatement::from_ast( + let condition = <&Expression<'a>>::from_ast(scope, &statement.condition, Some(Type::Boolean.into()))?; + let result = scope.alloc_statement(Statement::Block(BlockStatement::from_ast( scope, &statement.block, None, @@ -48,28 +48,30 @@ impl FromAst for ConditionalStatement { let next = statement .next .as_deref() - .map(|next| -> Result, AsgConvertError> { Arc::::from_ast(scope, next, None) }) + .map(|next| -> Result<&'a Statement<'a>, AsgConvertError> { + <&'a Statement<'a>>::from_ast(scope, next, None) + }) .transpose()?; Ok(ConditionalStatement { - parent: None, + parent: Cell::new(None), span: Some(statement.span.clone()), - condition, - result, - next, + condition: Cell::new(condition), + result: Cell::new(result), + next: Cell::new(next), }) } } -impl Into for &ConditionalStatement { +impl<'a> Into for &ConditionalStatement<'a> { fn into(self) -> leo_ast::ConditionalStatement { leo_ast::ConditionalStatement { - condition: self.condition.as_ref().into(), - block: match self.result.as_ref() { + condition: self.condition.get().into(), + block: match self.result.get() { Statement::Block(block) => block.into(), _ => unimplemented!(), }, - next: self.next.as_deref().map(|e| Box::new(e.into())), + next: self.next.get().as_deref().map(|e| Box::new(e.into())), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/statement/console.rs b/asg/src/statement/console.rs index 3c0b8af10b..fdcf2858e5 100644 --- a/asg/src/statement/console.rs +++ b/asg/src/statement/console.rs @@ -17,43 +17,43 @@ use crate::{AsgConvertError, Expression, FromAst, Node, PartialType, Scope, Span, Statement, Type}; use leo_ast::ConsoleFunction as AstConsoleFunction; -use std::sync::{Arc, Weak}; +use std::cell::Cell; // TODO (protryon): Refactor to not require/depend on span -#[derive(Debug)] -pub struct FormattedString { +#[derive(Clone)] +pub struct FormattedString<'a> { pub string: String, pub containers: Vec, - pub parameters: Vec>, + pub parameters: Vec>>, pub span: Span, } -#[derive(Debug)] -pub enum ConsoleFunction { - Assert(Arc), - Debug(FormattedString), - Error(FormattedString), - Log(FormattedString), +#[derive(Clone)] +pub enum ConsoleFunction<'a> { + Assert(Cell<&'a Expression<'a>>), + Debug(FormattedString<'a>), + Error(FormattedString<'a>), + Log(FormattedString<'a>), } -#[derive(Debug)] -pub struct ConsoleStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct ConsoleStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub function: ConsoleFunction, + pub function: ConsoleFunction<'a>, } -impl Node for ConsoleStatement { +impl<'a> Node for ConsoleStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for FormattedString { +impl<'a> FromAst<'a, leo_ast::FormattedString> for FormattedString<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::FormattedString, - _expected_type: Option, + _expected_type: Option>, ) -> Result { if value.parameters.len() != value.containers.len() { // + 1 for formatting string as to not confuse user @@ -65,7 +65,7 @@ impl FromAst for FormattedString { } let mut parameters = vec![]; for parameter in value.parameters.iter() { - parameters.push(Arc::::from_ast(scope, parameter, None)?); + parameters.push(Cell::new(<&Expression<'a>>::from_ast(scope, parameter, None)?)); } Ok(FormattedString { string: value.string.clone(), @@ -76,7 +76,7 @@ impl FromAst for FormattedString { } } -impl Into for &FormattedString { +impl<'a> Into for &FormattedString<'a> { fn into(self) -> leo_ast::FormattedString { leo_ast::FormattedString { string: self.string.clone(), @@ -85,27 +85,25 @@ impl Into for &FormattedString { .iter() .map(|span| leo_ast::FormattedContainer { span: span.clone() }) .collect(), - parameters: self.parameters.iter().map(|e| e.as_ref().into()).collect(), + parameters: self.parameters.iter().map(|e| e.get().into()).collect(), span: self.span.clone(), } } } -impl FromAst for ConsoleStatement { +impl<'a> FromAst<'a, leo_ast::ConsoleStatement> for ConsoleStatement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::ConsoleStatement, - _expected_type: Option, + _expected_type: Option>, ) -> Result { Ok(ConsoleStatement { - parent: None, + parent: Cell::new(None), span: Some(statement.span.clone()), function: match &statement.function { - AstConsoleFunction::Assert(expression) => ConsoleFunction::Assert(Arc::::from_ast( - scope, - expression, - Some(Type::Boolean.into()), - )?), + AstConsoleFunction::Assert(expression) => ConsoleFunction::Assert(Cell::new( + <&Expression<'a>>::from_ast(scope, expression, Some(Type::Boolean.into()))?, + )), AstConsoleFunction::Debug(formatted_string) => { ConsoleFunction::Debug(FormattedString::from_ast(scope, formatted_string, None)?) } @@ -120,12 +118,12 @@ impl FromAst for ConsoleStatement { } } -impl Into for &ConsoleStatement { +impl<'a> Into for &ConsoleStatement<'a> { fn into(self) -> leo_ast::ConsoleStatement { use ConsoleFunction::*; leo_ast::ConsoleStatement { function: match &self.function { - Assert(e) => AstConsoleFunction::Assert(e.as_ref().into()), + Assert(e) => AstConsoleFunction::Assert(e.get().into()), Debug(formatted_string) => AstConsoleFunction::Debug(formatted_string.into()), Error(formatted_string) => AstConsoleFunction::Error(formatted_string.into()), Log(formatted_string) => AstConsoleFunction::Log(formatted_string.into()), diff --git a/asg/src/statement/definition.rs b/asg/src/statement/definition.rs index 438b26b29c..846e34cc3a 100644 --- a/asg/src/statement/definition.rs +++ b/asg/src/statement/definition.rs @@ -29,38 +29,35 @@ use crate::{ Variable, }; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::{Cell, RefCell}; -#[derive(Debug)] -pub struct DefinitionStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct DefinitionStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub variables: Vec, - pub value: Arc, + pub variables: Vec<&'a Variable<'a>>, + pub value: Cell<&'a Expression<'a>>, } -impl Node for DefinitionStatement { +impl<'a> Node for DefinitionStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for Arc { +impl<'a> FromAst<'a, leo_ast::DefinitionStatement> for &'a Statement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::DefinitionStatement, - _expected_type: Option, - ) -> Result, AsgConvertError> { + _expected_type: Option>, + ) -> Result { let type_ = statement .type_ .as_ref() - .map(|x| scope.borrow().resolve_ast_type(&x)) + .map(|x| scope.resolve_ast_type(&x)) .transpose()?; - let value = Arc::::from_ast(scope, &statement.value, type_.clone().map(Into::into))?; + let value = <&Expression<'a>>::from_ast(scope, &statement.value, type_.clone().map(Into::into))?; let type_ = type_.or_else(|| value.get_type()); @@ -95,12 +92,11 @@ impl FromAst for Arc { if statement.declaration_type == leo_ast::Declare::Const && variable.mutable { return Err(AsgConvertError::illegal_ast_structure("cannot have const mut")); } - variables.push(Arc::new(RefCell::new(InnerVariable { + variables.push(&*scope.alloc_variable(RefCell::new(InnerVariable { id: uuid::Uuid::new_v4(), name: variable.identifier.clone(), - type_: type_ - .ok_or_else(|| AsgConvertError::unresolved_type(&variable.identifier.name, &statement.span))? - .weak(), + type_: + type_.ok_or_else(|| AsgConvertError::unresolved_type(&variable.identifier.name, &statement.span))?, mutable: variable.mutable, const_: false, declaration: crate::VariableDeclaration::Definition, @@ -109,31 +105,29 @@ impl FromAst for Arc { }))); } - { - let mut scope_borrow = scope.borrow_mut(); - for variable in variables.iter() { - scope_borrow - .variables - .insert(variable.borrow().name.name.clone(), variable.clone()); - } + for variable in variables.iter() { + scope + .variables + .borrow_mut() + .insert(variable.borrow().name.name.clone(), *variable); } - let statement = Arc::new(Statement::Definition(DefinitionStatement { - parent: None, + let statement = scope.alloc_statement(Statement::Definition(DefinitionStatement { + parent: Cell::new(None), span: Some(statement.span.clone()), variables: variables.clone(), - value, + value: Cell::new(value), })); - variables.iter().for_each(|variable| { - variable.borrow_mut().assignments.push(Arc::downgrade(&statement)); - }); + for variable in variables { + variable.borrow_mut().assignments.push(statement); + } Ok(statement) } } -impl Into for &DefinitionStatement { +impl<'a> Into for &DefinitionStatement<'a> { fn into(self) -> leo_ast::DefinitionStatement { assert!(!self.variables.is_empty()); @@ -147,7 +141,7 @@ impl Into for &DefinitionStatement { span: variable.name.span.clone(), }); if type_.is_none() { - type_ = Some((&variable.type_.clone().strong()).into()); + type_ = Some((&variable.type_.clone()).into()); } } @@ -155,7 +149,7 @@ impl Into for &DefinitionStatement { declaration_type: leo_ast::Declare::Let, variable_names, type_, - value: self.value.as_ref().into(), + value: self.value.get().into(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/statement/expression.rs b/asg/src/statement/expression.rs index 03f5c34689..f00c55c488 100644 --- a/asg/src/statement/expression.rs +++ b/asg/src/statement/expression.rs @@ -16,41 +16,41 @@ use crate::{AsgConvertError, Expression, FromAst, Node, PartialType, Scope, Span, Statement}; -use std::sync::{Arc, Weak}; +use std::cell::Cell; -#[derive(Debug)] -pub struct ExpressionStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct ExpressionStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub expression: Arc, + pub expression: Cell<&'a Expression<'a>>, } -impl Node for ExpressionStatement { +impl<'a> Node for ExpressionStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for ExpressionStatement { +impl<'a> FromAst<'a, leo_ast::ExpressionStatement> for ExpressionStatement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::ExpressionStatement, - _expected_type: Option, + _expected_type: Option>, ) -> Result { - let expression = Arc::::from_ast(scope, &statement.expression, None)?; + let expression = <&Expression<'a>>::from_ast(scope, &statement.expression, None)?; Ok(ExpressionStatement { - parent: None, + parent: Cell::new(None), span: Some(statement.span.clone()), - expression, + expression: Cell::new(expression), }) } } -impl Into for &ExpressionStatement { +impl<'a> Into for &ExpressionStatement<'a> { fn into(self) -> leo_ast::ExpressionStatement { leo_ast::ExpressionStatement { - expression: self.expression.as_ref().into(), + expression: self.expression.get().into(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/statement/iteration.rs b/asg/src/statement/iteration.rs index bd19f0ab68..76a095712f 100644 --- a/asg/src/statement/iteration.rs +++ b/asg/src/statement/iteration.rs @@ -30,43 +30,39 @@ use crate::{ Variable, }; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; +use std::cell::{Cell, RefCell}; -#[derive(Debug)] -pub struct IterationStatement { - pub parent: Option>, +#[derive(Clone)] +pub struct IterationStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub variable: Variable, - pub start: Arc, - pub stop: Arc, - pub body: Arc, + pub variable: &'a Variable<'a>, + pub start: Cell<&'a Expression<'a>>, + pub stop: Cell<&'a Expression<'a>>, + pub body: Cell<&'a Statement<'a>>, } -impl Node for IterationStatement { +impl<'a> Node for IterationStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for Arc { +impl<'a> FromAst<'a, leo_ast::IterationStatement> for &'a Statement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::IterationStatement, - _expected_type: Option, - ) -> Result, AsgConvertError> { + _expected_type: Option>, + ) -> Result { let expected_index_type = Some(PartialType::Integer(None, Some(IntegerType::U32))); - let start = Arc::::from_ast(scope, &statement.start, expected_index_type.clone())?; - let stop = Arc::::from_ast(scope, &statement.stop, expected_index_type)?; - let variable = Arc::new(RefCell::new(InnerVariable { + let start = <&Expression<'a>>::from_ast(scope, &statement.start, expected_index_type.clone())?; + let stop = <&Expression<'a>>::from_ast(scope, &statement.stop, expected_index_type)?; + let variable = scope.alloc_variable(RefCell::new(InnerVariable { id: uuid::Uuid::new_v4(), name: statement.variable.clone(), type_: start .get_type() - .ok_or_else(|| AsgConvertError::unresolved_type(&statement.variable.name, &statement.span))? - .weak(), + .ok_or_else(|| AsgConvertError::unresolved_type(&statement.variable.name, &statement.span))?, mutable: false, const_: true, declaration: crate::VariableDeclaration::IterationDefinition, @@ -74,34 +70,34 @@ impl FromAst for Arc { assignments: vec![], })); scope - .borrow_mut() .variables - .insert(statement.variable.name.clone(), variable.clone()); + .borrow_mut() + .insert(statement.variable.name.clone(), variable); - let statement = Arc::new(Statement::Iteration(IterationStatement { - parent: None, + let statement = scope.alloc_statement(Statement::Iteration(IterationStatement { + parent: Cell::new(None), span: Some(statement.span.clone()), - variable: variable.clone(), - stop, - start, - body: Arc::new(Statement::Block(crate::BlockStatement::from_ast( + variable, + stop: Cell::new(stop), + start: Cell::new(start), + body: Cell::new(scope.alloc_statement(Statement::Block(crate::BlockStatement::from_ast( scope, &statement.block, None, - )?)), + )?))), })); - variable.borrow_mut().assignments.push(Arc::downgrade(&statement)); + variable.borrow_mut().assignments.push(statement); Ok(statement) } } -impl Into for &IterationStatement { +impl<'a> Into for &IterationStatement<'a> { fn into(self) -> leo_ast::IterationStatement { leo_ast::IterationStatement { variable: self.variable.borrow().name.clone(), - start: self.start.as_ref().into(), - stop: self.stop.as_ref().into(), - block: match self.body.as_ref() { + start: self.start.get().into(), + stop: self.stop.get().into(), + block: match self.body.get() { Statement::Block(block) => block.into(), _ => unimplemented!(), }, diff --git a/asg/src/statement/mod.rs b/asg/src/statement/mod.rs index d6d8432b32..bc924f2ae1 100644 --- a/asg/src/statement/mod.rs +++ b/asg/src/statement/mod.rs @@ -44,21 +44,19 @@ pub use return_::*; use crate::{AsgConvertError, FromAst, Node, PartialType, Scope, Span}; -use std::sync::Arc; - -#[derive(Debug)] -pub enum Statement { - Return(ReturnStatement), - Definition(DefinitionStatement), - Assign(AssignStatement), - Conditional(ConditionalStatement), - Iteration(IterationStatement), - Console(ConsoleStatement), - Expression(ExpressionStatement), - Block(BlockStatement), +#[derive(Clone)] +pub enum Statement<'a> { + Return(ReturnStatement<'a>), + Definition(DefinitionStatement<'a>), + Assign(AssignStatement<'a>), + Conditional(ConditionalStatement<'a>), + Iteration(IterationStatement<'a>), + Console(ConsoleStatement<'a>), + Expression(ExpressionStatement<'a>), + Block(BlockStatement<'a>), } -impl Node for Statement { +impl<'a> Node for Statement<'a> { fn span(&self) -> Option<&Span> { use Statement::*; match self { @@ -74,31 +72,37 @@ impl Node for Statement { } } -impl FromAst for Arc { +impl<'a> FromAst<'a, leo_ast::Statement> for &'a Statement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, value: &leo_ast::Statement, - _expected_type: Option, - ) -> Result, AsgConvertError> { + _expected_type: Option>, + ) -> Result<&'a Statement<'a>, AsgConvertError> { use leo_ast::Statement::*; Ok(match value { - Return(statement) => Arc::new(Statement::Return(ReturnStatement::from_ast(scope, statement, None)?)), - Definition(statement) => Arc::::from_ast(scope, statement, None)?, - Assign(statement) => Arc::::from_ast(scope, statement, None)?, - Conditional(statement) => Arc::new(Statement::Conditional(ConditionalStatement::from_ast( + Return(statement) => { + scope.alloc_statement(Statement::Return(ReturnStatement::from_ast(scope, statement, None)?)) + } + Definition(statement) => Self::from_ast(scope, statement, None)?, + Assign(statement) => Self::from_ast(scope, statement, None)?, + Conditional(statement) => scope.alloc_statement(Statement::Conditional(ConditionalStatement::from_ast( scope, statement, None, )?)), - Iteration(statement) => Arc::::from_ast(scope, statement, None)?, - Console(statement) => Arc::new(Statement::Console(ConsoleStatement::from_ast(scope, statement, None)?)), - Expression(statement) => Arc::new(Statement::Expression(ExpressionStatement::from_ast( + Iteration(statement) => Self::from_ast(scope, statement, None)?, + Console(statement) => { + scope.alloc_statement(Statement::Console(ConsoleStatement::from_ast(scope, statement, None)?)) + } + Expression(statement) => scope.alloc_statement(Statement::Expression(ExpressionStatement::from_ast( scope, statement, None, )?)), - Block(statement) => Arc::new(Statement::Block(BlockStatement::from_ast(scope, statement, None)?)), + Block(statement) => { + scope.alloc_statement(Statement::Block(BlockStatement::from_ast(scope, statement, None)?)) + } }) } } -impl Into for &Statement { +impl<'a> Into for &Statement<'a> { fn into(self) -> leo_ast::Statement { use Statement::*; match self { diff --git a/asg/src/statement/return_.rs b/asg/src/statement/return_.rs index 2a33e98cb7..fb6b640390 100644 --- a/asg/src/statement/return_.rs +++ b/asg/src/statement/return_.rs @@ -16,44 +16,46 @@ use crate::{AsgConvertError, Expression, FromAst, Node, PartialType, Scope, Span, Statement, Type}; -use std::sync::{Arc, Weak}; - -#[derive(Debug)] -pub struct ReturnStatement { - pub parent: Option>, +use std::cell::Cell; +#[derive(Clone)] +pub struct ReturnStatement<'a> { + pub parent: Cell>>, pub span: Option, - pub expression: Arc, + pub expression: Cell<&'a Expression<'a>>, } -impl Node for ReturnStatement { +impl<'a> Node for ReturnStatement<'a> { fn span(&self) -> Option<&Span> { self.span.as_ref() } } -impl FromAst for ReturnStatement { +impl<'a> FromAst<'a, leo_ast::ReturnStatement> for ReturnStatement<'a> { fn from_ast( - scope: &Scope, + scope: &'a Scope<'a>, statement: &leo_ast::ReturnStatement, - _expected_type: Option, + _expected_type: Option>, ) -> Result { let return_type: Option = scope - .borrow() .resolve_current_function() .map(|x| x.output.clone()) .map(Into::into); Ok(ReturnStatement { - parent: None, + parent: Cell::new(None), span: Some(statement.span.clone()), - expression: Arc::::from_ast(scope, &statement.expression, return_type.map(Into::into))?, + expression: Cell::new(<&Expression<'a>>::from_ast( + scope, + &statement.expression, + return_type.map(Into::into), + )?), }) } } -impl Into for &ReturnStatement { +impl<'a> Into for &ReturnStatement<'a> { fn into(self) -> leo_ast::ReturnStatement { leo_ast::ReturnStatement { - expression: self.expression.as_ref().into(), + expression: self.expression.get().into(), span: self.span.clone().unwrap_or_default(), } } diff --git a/asg/src/type_.rs b/asg/src/type_.rs index 078d9a6cf9..83c5f3d759 100644 --- a/asg/src/type_.rs +++ b/asg/src/type_.rs @@ -17,14 +17,11 @@ use crate::Circuit; pub use leo_ast::IntegerType; -use std::{ - fmt, - sync::{Arc, Weak}, -}; +use std::fmt; -/// A type in an ASG. -#[derive(Debug, Clone, PartialEq)] -pub enum Type { +/// A type in an asg. +#[derive(Clone, PartialEq)] +pub enum Type<'a> { // Data types Address, Boolean, @@ -33,55 +30,21 @@ pub enum Type { Integer(IntegerType), // Data type wrappers - Array(Box, usize), - Tuple(Vec), - Circuit(Arc), + Array(Box>, usize), + Tuple(Vec>), + Circuit(&'a Circuit<'a>), } -#[derive(Debug, Clone)] -pub enum WeakType { - Type(Type), // circuit not allowed - Circuit(Weak), -} - -#[derive(Debug, Clone, PartialEq)] -pub enum PartialType { - Type(Type), // non-array or tuple +#[derive(Clone, PartialEq)] +pub enum PartialType<'a> { + Type(Type<'a>), // non-array or tuple Integer(Option, Option), // specific, context-specific - Array(Option>, Option), - Tuple(Vec>), + Array(Option>>, Option), + Tuple(Vec>>), } -impl Into for WeakType { - fn into(self) -> Type { - match self { - WeakType::Type(t) => t, - WeakType::Circuit(circuit) => Type::Circuit(circuit.upgrade().unwrap()), - } - } -} - -impl WeakType { - pub fn strong(self) -> Type { - self.into() - } - - pub fn is_unit(&self) -> bool { - matches!(self, WeakType::Type(Type::Tuple(t)) if t.is_empty()) - } -} - -impl Into for Type { - fn into(self) -> WeakType { - match self { - Type::Circuit(circuit) => WeakType::Circuit(Arc::downgrade(&circuit)), - t => WeakType::Type(t), - } - } -} - -impl Into> for PartialType { - fn into(self) -> Option { +impl<'a> Into>> for PartialType<'a> { + fn into(self) -> Option> { match self { PartialType::Type(t) => Some(t), PartialType::Integer(sub_type, contextual_type) => Some(Type::Integer(sub_type.or(contextual_type)?)), @@ -96,12 +59,12 @@ impl Into> for PartialType { } } -impl PartialType { - pub fn full(self) -> Option { +impl<'a> PartialType<'a> { + pub fn full(self) -> Option> { self.into() } - pub fn matches(&self, other: &Type) -> bool { + pub fn matches(&self, other: &Type<'a>) -> bool { match (self, other) { (PartialType::Type(t), other) => t.is_assignable_from(other), (PartialType::Integer(self_sub_type, _), Type::Integer(sub_type)) => { @@ -137,8 +100,8 @@ impl PartialType { } } -impl Into for Type { - fn into(self) -> PartialType { +impl<'a> Into> for Type<'a> { + fn into(self) -> PartialType<'a> { match self { Type::Integer(sub_type) => PartialType::Integer(Some(sub_type), None), Type::Array(element, len) => PartialType::Array(Some(Box::new((*element).into())), Some(len)), @@ -148,16 +111,12 @@ impl Into for Type { } } -impl Type { - pub fn is_assignable_from(&self, from: &Type) -> bool { +impl<'a> Type<'a> { + pub fn is_assignable_from(&self, from: &Type<'a>) -> bool { self == from } - pub fn partial(self) -> PartialType { - self.into() - } - - pub fn weak(self) -> WeakType { + pub fn partial(self) -> PartialType<'a> { self.into() } @@ -166,7 +125,7 @@ impl Type { } } -impl fmt::Display for Type { +impl<'a> fmt::Display for Type<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Type::Address => write!(f, "address"), @@ -190,7 +149,7 @@ impl fmt::Display for Type { } } -impl fmt::Display for PartialType { +impl<'a> fmt::Display for PartialType<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { PartialType::Type(t) => t.fmt(f), @@ -230,7 +189,7 @@ impl fmt::Display for PartialType { } } -impl Into for &Type { +impl<'a> Into for &Type<'a> { fn into(self) -> leo_ast::Type { use Type::*; match self { diff --git a/asg/src/variable.rs b/asg/src/variable.rs index d519da53de..09bf9ca749 100644 --- a/asg/src/variable.rs +++ b/asg/src/variable.rs @@ -14,17 +14,15 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{Expression, Statement, WeakType}; +use std::cell::RefCell; + +use crate::{Expression, Statement, Type}; use leo_ast::Identifier; -use std::{ - cell::RefCell, - sync::{Arc, Weak}, -}; use uuid::Uuid; /// Specifies how a program variable was declared. -#[derive(Debug, PartialEq)] +#[derive(Clone, Copy, PartialEq)] pub enum VariableDeclaration { Definition, IterationDefinition, @@ -33,17 +31,16 @@ pub enum VariableDeclaration { } /// Stores information on a program variable. -#[derive(Debug)] -pub struct InnerVariable { +#[derive(Clone)] +pub struct InnerVariable<'a> { pub id: Uuid, pub name: Identifier, - pub type_: WeakType, + pub type_: Type<'a>, pub mutable: bool, pub const_: bool, // only function arguments, const var definitions NOT included pub declaration: VariableDeclaration, - pub references: Vec>, // all Expression::VariableRef or panic - pub assignments: Vec>, // all Statement::Assign or panic -- must be 1 if not mutable, or 0 if declaration == input | parameter + pub references: Vec<&'a Expression<'a>>, // all Expression::VariableRef or panic + pub assignments: Vec<&'a Statement<'a>>, // all Statement::Assign or panic -- must be 1 if not mutable, or 0 if declaration == input | parameter } -pub type Variable = Arc>; -pub type WeakVariable = Weak>; +pub type Variable<'a> = RefCell>; diff --git a/asg/tests/fail/address/mod.rs b/asg/tests/fail/address/mod.rs index 0ee77207c6..bd909d0f70 100644 --- a/asg/tests/fail/address/mod.rs +++ b/asg/tests/fail/address/mod.rs @@ -14,10 +14,12 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_implicit_invalid() { let program_string = include_str!("implicit_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/array/mod.rs b/asg/tests/fail/array/mod.rs index 1a9eaf4ae5..ecd8c89802 100644 --- a/asg/tests/fail/array/mod.rs +++ b/asg/tests/fail/array/mod.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; // Expressions @@ -21,49 +23,49 @@ use crate::load_asg; #[test] fn test_initializer_fail() { let program_string = include_str!("initializer_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_input_nested_3x2_fail() { let program_string = include_str!("input_nested_3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_input_tuple_3x2_fail() { let program_string = include_str!("input_tuple_3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_multi_fail_initializer() { let program_string = include_str!("multi_fail_initializer.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_multi_inline_fail() { let program_string = include_str!("multi_fail_inline.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_multi_initializer_fail() { let program_string = include_str!("multi_initializer_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_nested_3x2_value_fail() { let program_string = include_str!("nested_3x2_value_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_tuple_3x2_value_fail() { let program_string = include_str!("tuple_3x2_value_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } // Array type tests @@ -71,65 +73,65 @@ fn test_tuple_3x2_value_fail() { #[test] fn test_type_fail() { let program_string = include_str!("type_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_nested_value_nested_3x2_fail() { let program_string = include_str!("type_nested_value_nested_3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_nested_value_nested_4x3x2_fail() { let program_string = include_str!("type_nested_value_nested_4x3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_nested_value_tuple_3x2_fail() { let program_string = include_str!("type_nested_value_tuple_3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_nested_value_tuple_4x3x2_fail() { let program_string = include_str!("type_nested_value_tuple_4x3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_tuple_value_nested_3x2_fail() { let program_string = include_str!("type_tuple_value_nested_3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_tuple_value_nested_3x2_swap_fail() { let program_string = include_str!("type_tuple_value_nested_3x2_swap_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_tuple_value_nested_4x3x2_fail() { let program_string = include_str!("type_tuple_value_nested_4x3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_tuple_value_tuple_3x2_fail() { let program_string = include_str!("type_tuple_value_tuple_3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_tuple_value_tuple_3x2_swap_fail() { let program_string = include_str!("type_tuple_value_tuple_3x2_swap_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_type_tuple_value_tuple_4x3x2_fail() { let program_string = include_str!("type_tuple_value_tuple_4x3x2_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/boolean/mod.rs b/asg/tests/fail/boolean/mod.rs index 4712d5b89c..8918bbdbde 100644 --- a/asg/tests/fail/boolean/mod.rs +++ b/asg/tests/fail/boolean/mod.rs @@ -14,22 +14,24 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_not_u32() { let program_string = include_str!("not_u32.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_true_or_u32() { let program_string = include_str!("true_or_u32.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_true_and_u32() { let program_string = include_str!("true_and_u32.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/circuits/mod.rs b/asg/tests/fail/circuits/mod.rs index f88d9ec42d..57338c7d4d 100644 --- a/asg/tests/fail/circuits/mod.rs +++ b/asg/tests/fail/circuits/mod.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; // Expressions @@ -21,13 +23,13 @@ use crate::load_asg; #[test] fn test_inline_fail() { let program_string = include_str!("inline_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_inline_undefined() { let program_string = include_str!("inline_undefined.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } // Members @@ -35,19 +37,19 @@ fn test_inline_undefined() { #[test] fn test_member_variable_fail() { let program_string = include_str!("member_variable_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_member_function_fail() { let program_string = include_str!("member_function_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_member_function_invalid() { let program_string = include_str!("member_function_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] @@ -64,19 +66,19 @@ fn test_mut_member_function_fail() { console.assert(a.echo(1u32) == 1u32); }"#; - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_member_static_function_invalid() { let program_string = include_str!("member_static_function_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_member_static_function_undefined() { let program_string = include_str!("member_static_function_undefined.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } // Mutability @@ -84,37 +86,37 @@ fn test_member_static_function_undefined() { #[test] fn test_mutate_function_fail() { let program_string = include_str!("mut_function_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_mutate_self_variable_fail() { let program_string = include_str!("mut_self_variable_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_mutate_self_function_fail() { let program_string = include_str!("mut_self_function_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_mutate_self_static_function_fail() { let program_string = include_str!("mut_self_static_function_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_mutate_static_function_fail() { let program_string = include_str!("mut_static_function_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_mutate_variable_fail() { let program_string = include_str!("mut_variable_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } // Self @@ -122,17 +124,17 @@ fn test_mutate_variable_fail() { #[test] fn test_self_fail() { let program_string = include_str!("self_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_self_member_invalid() { let program_string = include_str!("self_member_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_self_member_undefined() { let program_string = include_str!("self_member_undefined.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/console/mod.rs b/asg/tests/fail/console/mod.rs index 9f967460f3..8e9b4cfe7f 100644 --- a/asg/tests/fail/console/mod.rs +++ b/asg/tests/fail/console/mod.rs @@ -14,28 +14,30 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_log_fail() { let program_string = include_str!("log_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_log_parameter_fail_unknown() { let program_string = include_str!("log_parameter_fail_unknown.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_log_parameter_fail_empty() { let program_string = include_str!("log_parameter_fail_empty.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_log_parameter_fail_none() { let program_string = include_str!("log_parameter_fail_empty.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/core/mod.rs b/asg/tests/fail/core/mod.rs index 00ee4ad03a..c640c5b8b1 100644 --- a/asg/tests/fail/core/mod.rs +++ b/asg/tests/fail/core/mod.rs @@ -14,28 +14,30 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_core_circuit_invalid() { let program_string = include_str!("core_package_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_core_circuit_star_fail() { let program_string = include_str!("core_circuit_star_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_core_package_invalid() { let program_string = include_str!("core_package_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_core_unstable_package_invalid() { let program_string = include_str!("core_unstable_package_invalid.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/function/mod.rs b/asg/tests/fail/function/mod.rs index 07d80c4a1f..c88dbf6758 100644 --- a/asg/tests/fail/function/mod.rs +++ b/asg/tests/fail/function/mod.rs @@ -14,42 +14,44 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_multiple_returns_fail() { let program_string = include_str!("multiple_returns_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_multiple_returns_input_ambiguous() { let program_string = include_str!("multiple_returns_input_ambiguous.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_multiple_returns_fail_conditional() { let program_string = include_str!("multiple_returns_fail_conditional.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_scope_fail() { let program_string = include_str!("scope_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_undefined() { let program_string = include_str!("undefined.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_array_input() { let program_string = include_str!("array_input.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } // Test return multidimensional arrays @@ -57,11 +59,11 @@ fn test_array_input() { #[test] fn test_return_array_nested_fail() { let program_string = include_str!("return_array_nested_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_return_array_tuple_fail() { let program_string = include_str!("return_array_tuple_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/group/mod.rs b/asg/tests/fail/group/mod.rs index 01a3059bde..d5d6f82e62 100644 --- a/asg/tests/fail/group/mod.rs +++ b/asg/tests/fail/group/mod.rs @@ -19,17 +19,17 @@ use crate::load_asg; #[test] fn test_both_sign_high() { let program_string = include_str!("both_sign_high.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_both_sign_low() { let program_string = include_str!("both_sign_low.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_both_sign_inferred() { let program_string = include_str!("both_sign_inferred.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/integers/int_macro.rs b/asg/tests/fail/integers/int_macro.rs index cd06cba948..0dad4559c5 100644 --- a/asg/tests/fail/integers/int_macro.rs +++ b/asg/tests/fail/integers/int_macro.rs @@ -16,25 +16,27 @@ macro_rules! test_int { ($name: ident) => { + use leo_asg::new_context; + pub struct $name {} // we are not doing constant folding here, so asg doesnt catch this // impl $name { // fn test_negate_min_fail() { // let program_string = include_str!("negate_min.leo"); - // crate::load_asg(program_string).err().unwrap(); + // crate::load_asg(&new_context(), program_string).err().unwrap(); // } // } impl super::IntegerTester for $name { fn test_min_fail() { let program_string = include_str!("min_fail.leo"); - crate::load_asg(program_string).err().unwrap(); + crate::load_asg(&new_context(), program_string).err().unwrap(); } fn test_max_fail() { let program_string = include_str!("max_fail.leo"); - crate::load_asg(program_string).err().unwrap(); + crate::load_asg(&new_context(), program_string).err().unwrap(); } } }; diff --git a/asg/tests/fail/integers/uint_macro.rs b/asg/tests/fail/integers/uint_macro.rs index 6b718dbeda..3f42d8bfbe 100644 --- a/asg/tests/fail/integers/uint_macro.rs +++ b/asg/tests/fail/integers/uint_macro.rs @@ -16,17 +16,19 @@ macro_rules! test_uint { ($name: ident) => { + use leo_asg::new_context; + pub struct $name {} impl super::IntegerTester for $name { fn test_min_fail() { let program_string = include_str!("min_fail.leo"); - crate::load_asg(program_string).err().unwrap(); + crate::load_asg(&new_context(), program_string).err().unwrap(); } fn test_max_fail() { let program_string = include_str!("max_fail.leo"); - crate::load_asg(program_string).err().unwrap(); + crate::load_asg(&new_context(), program_string).err().unwrap(); } } }; diff --git a/asg/tests/fail/mutability/mod.rs b/asg/tests/fail/mutability/mod.rs index 84b481b9f3..6864183a76 100644 --- a/asg/tests/fail/mutability/mod.rs +++ b/asg/tests/fail/mutability/mod.rs @@ -14,52 +14,54 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_let() { let program_string = include_str!("let.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_const_fail() { let program_string = include_str!("const.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_const_mut_fail() { let program_string = include_str!("const_mut.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_array() { let program_string = include_str!("array.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_circuit() { let program_string = include_str!("circuit.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_circuit_function_mut() { let program_string = include_str!("circuit_function_mut.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_circuit_static_function_mut() { let program_string = include_str!("circuit_static_function_mut.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_function_input() { let program_string = include_str!("function_input.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/fail/statements/mod.rs b/asg/tests/fail/statements/mod.rs index 42ba3243ad..f12ed3e983 100644 --- a/asg/tests/fail/statements/mod.rs +++ b/asg/tests/fail/statements/mod.rs @@ -14,10 +14,12 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_num_returns_fail() { let program_string = include_str!("num_returns_fail.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } diff --git a/asg/tests/mod.rs b/asg/tests/mod.rs index 78c1f436e6..1eb54e4854 100644 --- a/asg/tests/mod.rs +++ b/asg/tests/mod.rs @@ -26,20 +26,21 @@ mod pass; const TESTING_FILEPATH: &str = "input.leo"; const TESTING_PROGRAM_NAME: &str = "test_program"; -fn load_asg(program_string: &str) -> Result { - load_asg_imports(program_string, &mut NullImportResolver) +fn load_asg<'a>(context: AsgContext<'a>, program_string: &str) -> Result, AsgConvertError> { + load_asg_imports(context, program_string, &mut NullImportResolver) } -fn load_asg_imports( +fn load_asg_imports<'a, T: ImportResolver<'a>>( + context: AsgContext<'a>, program_string: &str, imports: &mut T, -) -> Result { +) -> Result, AsgConvertError> { let grammar = Grammar::new(Path::new(&TESTING_FILEPATH), program_string)?; let ast = Ast::new(TESTING_PROGRAM_NAME, &grammar)?; - InternalProgram::new(&ast.as_repr(), imports) + InternalProgram::new(context, &ast.as_repr(), imports) } -fn mocked_resolver() -> MockedImportResolver { +fn mocked_resolver<'a>(_ctx: AsgContext<'a>) -> MockedImportResolver<'a> { let packages = indexmap::IndexMap::new(); MockedImportResolver { packages } } diff --git a/asg/tests/pass/address/mod.rs b/asg/tests/pass/address/mod.rs index 7cd18d57a2..40d4da0ea3 100644 --- a/asg/tests/pass/address/mod.rs +++ b/asg/tests/pass/address/mod.rs @@ -14,35 +14,37 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_valid() { let program_string = include_str!("valid.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_implicit_valid() { let program_string = include_str!("implicit_valid.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_console_assert_pass() { let program_string = include_str!("console_assert_pass.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_ternary() { let program_string = include_str!("ternary.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_equal() { let program_string = include_str!("equal.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/array/mod.rs b/asg/tests/pass/array/mod.rs index bca1da1a97..8b16056f3b 100644 --- a/asg/tests/pass/array/mod.rs +++ b/asg/tests/pass/array/mod.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; // Registers @@ -21,7 +23,7 @@ use crate::load_asg; #[test] fn test_registers() { let program_string = include_str!("registers.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Expressions @@ -29,173 +31,173 @@ fn test_registers() { #[test] fn test_inline() { let program_string = include_str!("inline.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_initializer() { let program_string = include_str!("initializer.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_initializer_input() { let program_string = include_str!("initializer_input.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_nested_3x2() { let program_string = include_str!("input_nested_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_tuple_3x2() { let program_string = include_str!("input_tuple_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_multi_initializer() { let program_string = include_str!("multi_initializer.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_nested_3x2_value() { let program_string = include_str!("nested_3x2_value.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_tuple_3x2_value() { let program_string = include_str!("tuple_3x2_value.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_spread() { let program_string = include_str!("spread.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_slice() { let program_string = include_str!("slice.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_index_u8() { let program_string = include_str!("index_u8.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_slice_i8() { let program_string = include_str!("slice_i8.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_slice_lower() { let program_string = include_str!("slice_lower.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_nested_value_nested_3x2() { let program_string = include_str!("type_nested_value_nested_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_nested_value_nested_4x3x2() { let program_string = include_str!("type_nested_value_nested_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_nested_value_tuple_3x2() { let program_string = include_str!("type_nested_value_tuple_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_nested_value_tuple_4x3x2() { let program_string = include_str!("type_nested_value_tuple_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_tuple_value_nested_3x2() { let program_string = include_str!("type_tuple_value_nested_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_tuple_value_nested_4x3x2() { let program_string = include_str!("type_tuple_value_nested_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_tuple_value_tuple_3x2() { let program_string = include_str!("type_tuple_value_tuple_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_type_tuple_value_tuple_4x3x2() { let program_string = include_str!("type_tuple_value_tuple_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_nested_value_nested_3x2() { let program_string = include_str!("type_input_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_nested_value_nested_4x3x2() { let program_string = include_str!("type_input_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_nested_value_tuple_3x2() { let program_string = include_str!("type_input_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_nested_value_tuple_4x3x2() { let program_string = include_str!("type_input_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_tuple_value_nested_3x2() { let program_string = include_str!("type_input_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_tuple_value_nested_4x3x2() { let program_string = include_str!("type_input_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_tuple_value_tuple_3x2() { let program_string = include_str!("type_input_3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_type_tuple_value_tuple_4x3x2() { let program_string = include_str!("type_input_4x3x2.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/boolean/mod.rs b/asg/tests/pass/boolean/mod.rs index 64f9172507..4e5ef256ac 100644 --- a/asg/tests/pass/boolean/mod.rs +++ b/asg/tests/pass/boolean/mod.rs @@ -14,18 +14,20 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_input_pass() { let program_string = include_str!("assert_eq_input.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_registers() { let program_string = include_str!("output_register.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Boolean not ! @@ -33,19 +35,19 @@ fn test_registers() { #[test] fn test_not_true() { let program_string = include_str!("not_true.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_not_false() { let program_string = include_str!("not_false.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_not_mutable() { let program_string = include_str!("not_mutable.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Boolean or || @@ -53,19 +55,19 @@ fn test_not_mutable() { #[test] fn test_true_or_true() { let program_string = include_str!("true_or_true.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_true_or_false() { let program_string = include_str!("true_or_false.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_false_or_false() { let program_string = include_str!("false_or_false.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Boolean and && @@ -73,19 +75,19 @@ fn test_false_or_false() { #[test] fn test_true_and_true() { let program_string = include_str!("true_and_true.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_true_and_false() { let program_string = include_str!("true_and_false.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_false_and_false() { let program_string = include_str!("false_and_false.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // All @@ -93,5 +95,5 @@ fn test_false_and_false() { #[test] fn test_all() { let program_string = include_str!("all.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/circuits/mod.rs b/asg/tests/pass/circuits/mod.rs index 811eedd2d4..3c9f701892 100644 --- a/asg/tests/pass/circuits/mod.rs +++ b/asg/tests/pass/circuits/mod.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; // Expressions @@ -21,7 +23,7 @@ use crate::load_asg; #[test] fn test_inline() { let program_string = include_str!("inline.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Members @@ -29,19 +31,19 @@ fn test_inline() { #[test] fn test_member_variable() { let program_string = include_str!("member_variable.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_member_variable_and_function() { let program_string = include_str!("member_variable_and_function.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_member_function() { let program_string = include_str!("member_function.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -58,25 +60,25 @@ fn test_mut_member_function() { console.assert(a.echo(1u32) == 1u32); }"#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_member_function_nested() { let program_string = include_str!("member_function_nested.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_member_static_function() { let program_string = include_str!("member_static_function.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_member_static_function_nested() { let program_string = include_str!("member_static_function_nested.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Mutability @@ -84,19 +86,19 @@ fn test_member_static_function_nested() { #[test] fn test_mutate_self_variable() { let program_string = include_str!("mut_self_variable.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_mutate_self_variable_conditional() { let program_string = include_str!("mut_self_variable_conditional.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_mutate_variable() { let program_string = include_str!("mut_variable.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Self @@ -104,7 +106,7 @@ fn test_mutate_variable() { #[test] fn test_self_member_pass() { let program_string = include_str!("self_member.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // All @@ -112,13 +114,13 @@ fn test_self_member_pass() { #[test] fn test_pedersen_mock() { let program_string = include_str!("pedersen_mock.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_define_circuit_inside_circuit_function() { let program_string = include_str!("define_circuit_inside_circuit_function.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -131,5 +133,5 @@ fn test_circuit_explicit_define() { let x: One = One {x: 5}; } "#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/console/mod.rs b/asg/tests/pass/console/mod.rs index 4f1a243012..bdf3173007 100644 --- a/asg/tests/pass/console/mod.rs +++ b/asg/tests/pass/console/mod.rs @@ -14,30 +14,32 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_log() { let program_string = include_str!("log.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_log_parameter() { let program_string = include_str!("log_parameter.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_log_parameter_many() { let program_string = include_str!("log_parameter_many.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_log_input() { let program_string = include_str!("log_input.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Debug @@ -45,7 +47,7 @@ fn test_log_input() { #[test] fn test_debug() { let program_string = include_str!("debug.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Error @@ -53,7 +55,7 @@ fn test_debug() { #[test] fn test_error() { let program_string = include_str!("error.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Assertion @@ -61,11 +63,11 @@ fn test_error() { #[test] fn test_assert() { let program_string = include_str!("assert.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_conditional_assert() { let program_string = include_str!("conditional_assert.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/core/mod.rs b/asg/tests/pass/core/mod.rs index 69d528d54c..3094c59cf7 100644 --- a/asg/tests/pass/core/mod.rs +++ b/asg/tests/pass/core/mod.rs @@ -14,22 +14,24 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_unstable_blake2s() { let program_string = include_str!("unstable_blake2s.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_blake2s_input() { let program_string = include_str!("blake2s_input.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_blake2s_random() { let program_string = include_str!("blake2s_random.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/definition/mod.rs b/asg/tests/pass/definition/mod.rs index 921dc3e16d..764387f506 100644 --- a/asg/tests/pass/definition/mod.rs +++ b/asg/tests/pass/definition/mod.rs @@ -14,17 +14,19 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_out_of_order() { let program_string = include_str!("out_of_order.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // #[test] // #[ignore] // fn test_out_of_order_with_import() { // let program_string = include_str!("out_of_order_with_import.leo"); -// load_asg(program_string).unwrap(); +// load_asg(&new_context(), program_string).unwrap(); // } diff --git a/asg/tests/pass/field/mod.rs b/asg/tests/pass/field/mod.rs index 8c895b0c92..eb8553f178 100644 --- a/asg/tests/pass/field/mod.rs +++ b/asg/tests/pass/field/mod.rs @@ -14,18 +14,20 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_negate() { let program_string = include_str!("negate.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_add() { let program_string = include_str!("add.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -35,41 +37,41 @@ fn test_add_explicit() { let c: field = 0field + 1field; } "#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_sub() { let program_string = include_str!("sub.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_div() { let program_string = include_str!("div.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_mul() { let program_string = include_str!("mul.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_eq() { let program_string = include_str!("eq.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_console_assert_pass() { let program_string = include_str!("console_assert.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_ternary() { let program_string = include_str!("ternary.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/form_ast.rs b/asg/tests/pass/form_ast.rs index c45f9e77ae..5248b0ecf9 100644 --- a/asg/tests/pass/form_ast.rs +++ b/asg/tests/pass/form_ast.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; use leo_ast::Ast; use leo_grammar::Grammar; @@ -23,7 +25,8 @@ use std::path::Path; #[test] fn test_basic() { let program_string = include_str!("./circuits/pedersen_mock.leo"); - let asg = load_asg(program_string).unwrap(); + let ctx = new_context(); + let asg = load_asg(&ctx, program_string).unwrap(); let reformed_ast = leo_asg::reform_ast(&asg); println!("{}", reformed_ast); // panic!(); @@ -48,7 +51,8 @@ fn test_function_rename() { console.assert(total == 20); } "#; - let asg = load_asg(program_string).unwrap(); + let ctx = new_context(); + let asg = load_asg(&ctx, program_string).unwrap(); let reformed_ast = leo_asg::reform_ast(&asg); println!("{}", reformed_ast); // panic!(); @@ -56,7 +60,8 @@ fn test_function_rename() { #[test] fn test_imports() { - let mut imports = crate::mocked_resolver(); + let ctx = new_context(); + let mut imports = crate::mocked_resolver(&ctx); let test_import = r#" circuit Point { x: u32 @@ -69,7 +74,7 @@ fn test_imports() { "#; imports .packages - .insert("test-import".to_string(), load_asg(test_import).unwrap()); + .insert("test-import".to_string(), load_asg(&ctx, test_import).unwrap()); let program_string = r#" import test-import.foo; @@ -90,7 +95,7 @@ fn test_imports() { serde_json::to_string(Ast::new("test", &test_grammar).unwrap().as_repr()).unwrap() ); - let asg = crate::load_asg_imports(program_string, &mut imports).unwrap(); + let asg = crate::load_asg_imports(&ctx, program_string, &mut imports).unwrap(); let reformed_ast = leo_asg::reform_ast(&asg); println!("{}", serde_json::to_string(&reformed_ast).unwrap()); // panic!(); diff --git a/asg/tests/pass/function/mod.rs b/asg/tests/pass/function/mod.rs index abdbd93b7c..964f5dab53 100644 --- a/asg/tests/pass/function/mod.rs +++ b/asg/tests/pass/function/mod.rs @@ -14,18 +14,20 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_empty() { let program_string = include_str!("empty.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_iteration() { let program_string = include_str!("iteration.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -45,7 +47,7 @@ fn test_const_args() { console.assert(a == 20u32); } "#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -66,7 +68,7 @@ fn test_const_args_used() { console.assert(a == 6u8); } "#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -85,61 +87,61 @@ fn test_const_args_fail() { console.assert(a == 1u8); } "#; - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_iteration_repeated() { let program_string = include_str!("iteration_repeated.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_newlines() { let program_string = include_str!("newlines.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_multiple_returns() { let program_string = include_str!("multiple_returns.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_multiple_returns_main() { let program_string = include_str!("multiple_returns_main.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_repeated_function_call() { let program_string = include_str!("repeated.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_return() { let program_string = include_str!("return.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_undefined() { let program_string = include_str!("undefined.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } #[test] fn test_value_unchanged() { let program_string = include_str!("value_unchanged.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_array_input() { let program_string = include_str!("array_input.leo"); - load_asg(program_string).err().unwrap(); + load_asg(&new_context(), program_string).err().unwrap(); } // Test return multidimensional arrays @@ -147,13 +149,13 @@ fn test_array_input() { #[test] fn test_return_array_nested_pass() { let program_string = include_str!("return_array_nested_pass.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_return_array_tuple_pass() { let program_string = include_str!("return_array_tuple_pass.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Test return tuples @@ -161,11 +163,11 @@ fn test_return_array_tuple_pass() { #[test] fn test_return_tuple() { let program_string = include_str!("return_tuple.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_return_tuple_conditional() { let program_string = include_str!("return_tuple_conditional.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/group/mod.rs b/asg/tests/pass/group/mod.rs index a8a213e37a..a80dc664d8 100644 --- a/asg/tests/pass/group/mod.rs +++ b/asg/tests/pass/group/mod.rs @@ -14,12 +14,14 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_one() { let program_string = include_str!("one.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -29,79 +31,79 @@ fn test_implicit() { let element: group = 0; } "#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_zero() { let program_string = include_str!("zero.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_point() { let program_string = include_str!("point.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_x_sign_high() { let program_string = include_str!("x_sign_high.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_x_sign_low() { let program_string = include_str!("x_sign_low.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_x_sign_inferred() { let program_string = include_str!("x_sign_inferred.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_y_sign_high() { let program_string = include_str!("y_sign_high.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_y_sign_low() { let program_string = include_str!("y_sign_low.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_y_sign_inferred() { let program_string = include_str!("y_sign_inferred.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_point_input() { let program_string = include_str!("point_input.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input() { let program_string = include_str!("input.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_negate() { let program_string = include_str!("negate.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_add() { let program_string = include_str!("add.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] @@ -111,29 +113,29 @@ fn test_add_explicit() { let c: group = 0group + 1group; } "#; - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_sub() { let program_string = include_str!("sub.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_console_assert_pass() { let program_string = include_str!("assert_eq.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_eq() { let program_string = include_str!("eq.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_ternary() { let program_string = include_str!("ternary.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/import/mod.rs b/asg/tests/pass/import/mod.rs index c7de7adadf..b7ecea72c2 100644 --- a/asg/tests/pass/import/mod.rs +++ b/asg/tests/pass/import/mod.rs @@ -14,134 +14,144 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::{load_asg, load_asg_imports, mocked_resolver}; #[test] fn test_basic() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "test-import".to_string(), - load_asg(include_str!("src/test-import.leo")).unwrap(), + load_asg(&ctx, include_str!("src/test-import.leo")).unwrap(), ); let program_string = include_str!("basic.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } #[test] fn test_multiple() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "test-import".to_string(), - load_asg(include_str!("src/test-import.leo")).unwrap(), + load_asg(&ctx, include_str!("src/test-import.leo")).unwrap(), ); let program_string = include_str!("multiple.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } #[test] fn test_star() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "test-import".to_string(), - load_asg(include_str!("src/test-import.leo")).unwrap(), + load_asg(&ctx, include_str!("src/test-import.leo")).unwrap(), ); let program_string = include_str!("star.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } #[test] fn test_alias() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "test-import".to_string(), - load_asg(include_str!("src/test-import.leo")).unwrap(), + load_asg(&ctx, include_str!("src/test-import.leo")).unwrap(), ); let program_string = include_str!("alias.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } // naming tests #[test] fn test_name() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "hello-world".to_string(), - load_asg(include_str!("src/hello-world.leo")).unwrap(), + load_asg(&ctx, include_str!("src/hello-world.leo")).unwrap(), + ); + imports.packages.insert( + "a0-f".to_string(), + load_asg(&ctx, include_str!("src/a0-f.leo")).unwrap(), ); imports .packages - .insert("a0-f".to_string(), load_asg(include_str!("src/a0-f.leo")).unwrap()); - imports - .packages - .insert("a-9".to_string(), load_asg(include_str!("src/a-9.leo")).unwrap()); + .insert("a-9".to_string(), load_asg(&ctx, include_str!("src/a-9.leo")).unwrap()); let program_string = include_str!("names.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } // more complex tests #[test] fn test_many_import() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "test-import".to_string(), - load_asg(include_str!("src/test-import.leo")).unwrap(), + load_asg(&ctx, include_str!("src/test-import.leo")).unwrap(), ); imports.packages.insert( "bar".to_string(), - load_asg(include_str!("imports/bar/src/lib.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/lib.leo")).unwrap(), ); imports.packages.insert( "bar.baz".to_string(), - load_asg(include_str!("imports/bar/src/baz.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/baz.leo")).unwrap(), ); imports.packages.insert( "bar.baz".to_string(), - load_asg(include_str!("imports/bar/src/baz.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/baz.leo")).unwrap(), ); imports.packages.insert( "bar.bat.bat".to_string(), - load_asg(include_str!("imports/bar/src/bat/bat.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/bat/bat.leo")).unwrap(), ); imports.packages.insert( "car".to_string(), - load_asg(include_str!("imports/car/src/lib.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/car/src/lib.leo")).unwrap(), ); let program_string = include_str!("many_import.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } #[test] fn test_many_import_star() { - let mut imports = mocked_resolver(); + let ctx = new_context(); + let mut imports = mocked_resolver(&ctx); imports.packages.insert( "test-import".to_string(), - load_asg(include_str!("src/test-import.leo")).unwrap(), + load_asg(&ctx, include_str!("src/test-import.leo")).unwrap(), ); imports.packages.insert( "bar".to_string(), - load_asg(include_str!("imports/bar/src/lib.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/lib.leo")).unwrap(), ); imports.packages.insert( "bar.baz".to_string(), - load_asg(include_str!("imports/bar/src/baz.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/baz.leo")).unwrap(), ); imports.packages.insert( "bar.baz".to_string(), - load_asg(include_str!("imports/bar/src/baz.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/baz.leo")).unwrap(), ); imports.packages.insert( "bar.bat.bat".to_string(), - load_asg(include_str!("imports/bar/src/bat/bat.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/bar/src/bat/bat.leo")).unwrap(), ); imports.packages.insert( "car".to_string(), - load_asg(include_str!("imports/car/src/lib.leo")).unwrap(), + load_asg(&ctx, include_str!("imports/car/src/lib.leo")).unwrap(), ); let program_string = include_str!("many_import_star.leo"); - load_asg_imports(program_string, &mut imports).unwrap(); + load_asg_imports(&ctx, program_string, &mut imports).unwrap(); } diff --git a/asg/tests/pass/input_files/program_input/mod.rs b/asg/tests/pass/input_files/program_input/mod.rs index a223bd970e..0a094937c3 100644 --- a/asg/tests/pass/input_files/program_input/mod.rs +++ b/asg/tests/pass/input_files/program_input/mod.rs @@ -14,16 +14,18 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_input_pass() { let program_string = include_str!("main.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_input_multiple() { let program_string = include_str!("main_multiple.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/input_files/program_input_and_program_state/mod.rs b/asg/tests/pass/input_files/program_input_and_program_state/mod.rs index 0df6849508..c38b40c5a7 100644 --- a/asg/tests/pass/input_files/program_input_and_program_state/mod.rs +++ b/asg/tests/pass/input_files/program_input_and_program_state/mod.rs @@ -14,10 +14,12 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_access() { let program_string = include_str!("access.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/input_files/program_state/mod.rs b/asg/tests/pass/input_files/program_state/mod.rs index ce09adf28d..e429df9c58 100644 --- a/asg/tests/pass/input_files/program_state/mod.rs +++ b/asg/tests/pass/input_files/program_state/mod.rs @@ -14,16 +14,18 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_access_state() { let program_string = include_str!("access_state.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_access_all() { let program_string = include_str!("access_all.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/integers/int_macro.rs b/asg/tests/pass/integers/int_macro.rs index b9c7e7bc28..a7370b8f82 100644 --- a/asg/tests/pass/integers/int_macro.rs +++ b/asg/tests/pass/integers/int_macro.rs @@ -16,94 +16,96 @@ macro_rules! test_int { ($name: ident) => { + use leo_asg::new_context; + pub struct $name {} impl $name { fn test_negate() { let program_string = include_str!("negate.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_negate_zero() { let program_string = include_str!("negate_zero.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } } impl super::IntegerTester for $name { fn test_min() { let program_string = include_str!("min.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_max() { let program_string = include_str!("max.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_add() { let program_string = include_str!("add.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_sub() { let program_string = include_str!("sub.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_mul() { let program_string = include_str!("mul.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_div() { let program_string = include_str!("div.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_pow() { let program_string = include_str!("pow.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_eq() { let program_string = include_str!("eq.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_ne() { let program_string = include_str!("ne.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_ge() { let program_string = include_str!("ge.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_gt() { let program_string = include_str!("gt.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_le() { let program_string = include_str!("le.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_lt() { let program_string = include_str!("lt.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_console_assert() { let program_string = include_str!("console_assert.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_ternary() { let program_string = include_str!("ternary.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } } }; diff --git a/asg/tests/pass/integers/uint_macro.rs b/asg/tests/pass/integers/uint_macro.rs index 7d555f10e3..4765636b7c 100644 --- a/asg/tests/pass/integers/uint_macro.rs +++ b/asg/tests/pass/integers/uint_macro.rs @@ -16,82 +16,84 @@ macro_rules! test_uint { ($name: ident) => { + use leo_asg::new_context; + pub struct $name {} impl super::IntegerTester for $name { fn test_min() { let program_string = include_str!("min.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_max() { let program_string = include_str!("max.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_add() { let program_string = include_str!("add.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_sub() { let program_string = include_str!("sub.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_mul() { let program_string = include_str!("mul.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_div() { let program_string = include_str!("div.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_pow() { let program_string = include_str!("pow.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_eq() { let program_string = include_str!("eq.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_ne() { let program_string = include_str!("ne.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_ge() { let program_string = include_str!("ge.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_gt() { let program_string = include_str!("gt.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_le() { let program_string = include_str!("le.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_lt() { let program_string = include_str!("lt.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_console_assert() { let program_string = include_str!("console_assert.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } fn test_ternary() { let program_string = include_str!("ternary.leo"); - crate::load_asg(program_string).unwrap(); + crate::load_asg(&new_context(), program_string).unwrap(); } } }; diff --git a/asg/tests/pass/mutability/mod.rs b/asg/tests/pass/mutability/mod.rs index 5da14da0ae..c1e056514f 100644 --- a/asg/tests/pass/mutability/mod.rs +++ b/asg/tests/pass/mutability/mod.rs @@ -14,58 +14,60 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_let_mut() { let program_string = include_str!("let_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_let_mut_nested() { let program_string = include_str!("let_mut_nested.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_array_mut() { let program_string = include_str!("array_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_array_tuple_mut() { let program_string = include_str!("array_tuple_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_array_splice_mut() { let program_string = include_str!("array_splice_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_circuit_mut() { let program_string = include_str!("circuit_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_circuit_variable_mut() { let program_string = include_str!("circuit_variable_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_function_input_mut() { let program_string = include_str!("function_input_mut.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_swap() { let program_string = include_str!("swap.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/statements/conditional/mod.rs b/asg/tests/pass/statements/conditional/mod.rs index 73a987bb94..c5598eea10 100644 --- a/asg/tests/pass/statements/conditional/mod.rs +++ b/asg/tests/pass/statements/conditional/mod.rs @@ -14,40 +14,42 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_assert() { let program_string = include_str!("assert.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_mutate() { let program_string = include_str!("mutate.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_for_loop() { let program_string = include_str!("for_loop.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_chain() { let program_string = include_str!("chain.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_nested() { let program_string = include_str!("nested.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_multiple_returns() { let program_string = include_str!("multiple_returns.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/statements/mod.rs b/asg/tests/pass/statements/mod.rs index f6123b86db..71a530350f 100644 --- a/asg/tests/pass/statements/mod.rs +++ b/asg/tests/pass/statements/mod.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; pub mod conditional; @@ -23,7 +25,7 @@ pub mod conditional; #[test] fn test_ternary_basic() { let program_string = include_str!("ternary_basic.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } // Iteration for i {start}..{stop} { statements } @@ -31,11 +33,11 @@ fn test_ternary_basic() { #[test] fn test_iteration_basic() { let program_string = include_str!("iteration_basic.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_block() { let program_string = include_str!("block.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/asg/tests/pass/tuples/mod.rs b/asg/tests/pass/tuples/mod.rs index 19f5dc5267..4abf803706 100644 --- a/asg/tests/pass/tuples/mod.rs +++ b/asg/tests/pass/tuples/mod.rs @@ -14,70 +14,72 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_asg::new_context; + use crate::load_asg; #[test] fn test_tuple_basic() { let program_string = include_str!("basic.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_tuple_access() { let program_string = include_str!("access.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_tuple_typed() { let program_string = include_str!("typed.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_multiple() { let program_string = include_str!("multiple.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_multiple_typed() { let program_string = include_str!("multiple_typed.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_function() { let program_string = include_str!("function.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_function_typed() { let program_string = include_str!("function_typed.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_function_multiple() { let program_string = include_str!("function_multiple.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_nested() { let program_string = include_str!("nested.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_nested_access() { let program_string = include_str!("nested_access.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } #[test] fn test_nested_typed() { let program_string = include_str!("nested_typed.leo"); - load_asg(program_string).unwrap(); + load_asg(&new_context(), program_string).unwrap(); } diff --git a/compiler/src/compiler.rs b/compiler/src/compiler.rs index 60f7f87099..0a92f251de 100644 --- a/compiler/src/compiler.rs +++ b/compiler/src/compiler.rs @@ -44,24 +44,36 @@ use std::{ path::{Path, PathBuf}, }; +pub use leo_asg::{new_context, AsgContext as Context, AsgContext}; + +thread_local! { + static THREAD_GLOBAL_CONTEXT: AsgContext<'static> = Box::leak(Box::new(new_context())); +} + +/// Conventience function to return a leaked thread-local global context. Should only be used for transient programs (like cli). +pub fn thread_leaked_context() -> AsgContext<'static> { + THREAD_GLOBAL_CONTEXT.with(|f| *f) +} + /// Stores information to compile a Leo program. #[derive(Clone)] -pub struct Compiler> { +pub struct Compiler<'a, F: PrimeField, G: GroupType> { program_name: String, main_file_path: PathBuf, output_directory: PathBuf, program: Program, program_input: Input, - asg: Option, + ctx: AsgContext<'a>, + asg: Option>, _engine: PhantomData, _group: PhantomData, } -impl> Compiler { +impl<'a, F: PrimeField, G: GroupType> Compiler<'a, F, G> { /// /// Returns a new Leo program compiler. /// - pub fn new(package_name: String, main_file_path: PathBuf, output_directory: PathBuf) -> Self { + pub fn new(package_name: String, main_file_path: PathBuf, output_directory: PathBuf, ctx: AsgContext<'a>) -> Self { Self { program_name: package_name.clone(), main_file_path, @@ -69,6 +81,7 @@ impl> Compiler { program: Program::new(package_name), program_input: Input::new(), asg: None, + ctx, _engine: PhantomData, _group: PhantomData, } @@ -85,8 +98,9 @@ impl> Compiler { package_name: String, main_file_path: PathBuf, output_directory: PathBuf, + ctx: AsgContext<'a>, ) -> Result { - let mut compiler = Self::new(package_name, main_file_path, output_directory); + let mut compiler = Self::new(package_name, main_file_path, output_directory, ctx); compiler.parse_program()?; @@ -101,6 +115,7 @@ impl> Compiler { /// Parses and stores all imported programs. /// Performs type inference checking on the program, imported programs, and program input. /// + #[allow(clippy::too_many_arguments)] pub fn parse_program_with_input( package_name: String, main_file_path: PathBuf, @@ -109,8 +124,9 @@ impl> Compiler { input_path: &Path, state_string: &str, state_path: &Path, + ctx: AsgContext<'a>, ) -> Result { - let mut compiler = Self::new(package_name, main_file_path, output_directory); + let mut compiler = Self::new(package_name, main_file_path, output_directory, ctx); compiler.parse_input(input_string, input_path, state_string, state_path)?; @@ -189,7 +205,7 @@ impl> Compiler { tracing::debug!("Program parsing complete\n{:#?}", self.program); // Create a new symbol table from the program, imported_programs, and program_input. - let asg = Asg::new(&core_ast, &mut leo_imports::ImportParser::default())?; + let asg = Asg::new(self.ctx, &core_ast, &mut leo_imports::ImportParser::default())?; tracing::debug!("ASG generation complete"); @@ -261,7 +277,7 @@ impl> Compiler { } } -impl> ConstraintSynthesizer for Compiler { +impl<'a, F: PrimeField, G: GroupType> ConstraintSynthesizer for Compiler<'a, F, G> { /// /// Synthesizes the circuit with program input. /// diff --git a/compiler/src/console/assert.rs b/compiler/src/console/assert.rs index f7a60082fc..9d47cd5afa 100644 --- a/compiler/src/console/assert.rs +++ b/compiler/src/console/assert.rs @@ -24,19 +24,18 @@ use crate::{ GroupType, }; use leo_asg::{Expression, Span}; -use std::sync::Arc; use snarkvm_models::{ curves::PrimeField, gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn evaluate_console_assert>( &mut self, cs: &mut CS, indicator: &Boolean, - expression: &Arc, + expression: &'a Expression<'a>, span: &Span, ) -> Result<(), ConsoleError> { // Evaluate assert expression diff --git a/compiler/src/console/console.rs b/compiler/src/console/console.rs index 010b0bf5cf..2603cd8db7 100644 --- a/compiler/src/console/console.rs +++ b/compiler/src/console/console.rs @@ -24,16 +24,21 @@ use snarkvm_models::{ gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn evaluate_console_function_call>( &mut self, cs: &mut CS, indicator: &Boolean, - console: &ConsoleStatement, + console: &ConsoleStatement<'a>, ) -> Result<(), ConsoleError> { match &console.function { ConsoleFunction::Assert(expression) => { - self.evaluate_console_assert(cs, indicator, expression, &console.span.clone().unwrap_or_default())?; + self.evaluate_console_assert( + cs, + indicator, + expression.get(), + &console.span.clone().unwrap_or_default(), + )?; } ConsoleFunction::Debug(string) => { let string = self.format(cs, string)?; diff --git a/compiler/src/console/format.rs b/compiler/src/console/format.rs index 2f42776233..048303a023 100644 --- a/compiler/src/console/format.rs +++ b/compiler/src/console/format.rs @@ -21,11 +21,11 @@ use leo_asg::FormattedString; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn format>( &mut self, cs: &mut CS, - formatted: &FormattedString, + formatted: &FormattedString<'a>, ) -> Result { // Check that containers and parameters match if formatted.containers.len() != formatted.parameters.len() { @@ -47,7 +47,7 @@ impl> ConstrainedProgram { let mut result = string.to_string(); for parameter in formatted.parameters.iter() { - let parameter_value = self.enforce_expression(cs, parameter)?; + let parameter_value = self.enforce_expression(cs, parameter.get())?; result = result.replacen("{}", ¶meter_value.to_string(), 1); } diff --git a/compiler/src/constraints/constraints.rs b/compiler/src/constraints/constraints.rs index b97c0a2da6..412d8e30cc 100644 --- a/compiler/src/constraints/constraints.rs +++ b/compiler/src/constraints/constraints.rs @@ -28,16 +28,16 @@ use snarkvm_models::{ }; use std::path::Path; -pub fn generate_constraints, CS: ConstraintSystem>( +pub fn generate_constraints<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - asg: &Asg, + asg: &Asg<'a>, input: &Input, ) -> Result { let program = asg.as_repr(); let mut resolved_program = ConstrainedProgram::::new(program.clone()); let main = { - let program = program.borrow(); + let program = program; program.functions.get("main").cloned() }; @@ -50,20 +50,19 @@ pub fn generate_constraints, CS: ConstraintSystem } } -pub fn generate_test_constraints>( - asg: &Asg, +pub fn generate_test_constraints<'a, F: PrimeField, G: GroupType>( + asg: &Asg<'a>, input: InputPairs, main_file_path: &Path, output_directory: &Path, ) -> Result<(u32, u32), CompilerError> { let program = asg.as_repr(); let mut resolved_program = ConstrainedProgram::::new(program.clone()); - let program_name = program.borrow().name.clone(); + let program_name = program.name.clone(); // Get default input let default = input.pairs.get(&program_name); - let program = program.borrow(); let tests = &program.test_functions; tracing::info!("Running {} tests", tests.len()); diff --git a/compiler/src/definition/definition.rs b/compiler/src/definition/definition.rs index 181054c599..75e225b135 100644 --- a/compiler/src/definition/definition.rs +++ b/compiler/src/definition/definition.rs @@ -21,8 +21,8 @@ use leo_asg::Variable; use snarkvm_models::curves::PrimeField; -impl> ConstrainedProgram { - pub fn store_definition(&mut self, variable: &Variable, value: ConstrainedValue) { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { + pub fn store_definition(&mut self, variable: &Variable, value: ConstrainedValue<'a, F, G>) { let variable = variable.borrow(); self.store(variable.id, value); diff --git a/compiler/src/expression/arithmetic/add.rs b/compiler/src/expression/arithmetic/add.rs index 33b3171496..e2edaa93c1 100644 --- a/compiler/src/expression/arithmetic/add.rs +++ b/compiler/src/expression/arithmetic/add.rs @@ -21,12 +21,12 @@ use leo_ast::Span; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn enforce_add, CS: ConstraintSystem>( +pub fn enforce_add<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.add(cs, num_2, span)?)) diff --git a/compiler/src/expression/arithmetic/div.rs b/compiler/src/expression/arithmetic/div.rs index 29cf9dcaef..1c8ea36f5d 100644 --- a/compiler/src/expression/arithmetic/div.rs +++ b/compiler/src/expression/arithmetic/div.rs @@ -21,12 +21,12 @@ use leo_ast::Span; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn enforce_div, CS: ConstraintSystem>( +pub fn enforce_div<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.div(cs, num_2, span)?)) diff --git a/compiler/src/expression/arithmetic/mul.rs b/compiler/src/expression/arithmetic/mul.rs index c0acdadd40..0500ae7705 100644 --- a/compiler/src/expression/arithmetic/mul.rs +++ b/compiler/src/expression/arithmetic/mul.rs @@ -21,12 +21,12 @@ use leo_ast::Span; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn enforce_mul, CS: ConstraintSystem>( +pub fn enforce_mul<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.mul(cs, num_2, span)?)) diff --git a/compiler/src/expression/arithmetic/negate.rs b/compiler/src/expression/arithmetic/negate.rs index ee920ba70a..72891e9547 100644 --- a/compiler/src/expression/arithmetic/negate.rs +++ b/compiler/src/expression/arithmetic/negate.rs @@ -21,11 +21,11 @@ use leo_ast::Span; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn enforce_negate, CS: ConstraintSystem>( +pub fn enforce_negate<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - value: ConstrainedValue, + value: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { match value { ConstrainedValue::Integer(integer) => Ok(ConstrainedValue::Integer(integer.negate(cs, span)?)), ConstrainedValue::Field(field) => Ok(ConstrainedValue::Field(field.negate(cs, span)?)), diff --git a/compiler/src/expression/arithmetic/pow.rs b/compiler/src/expression/arithmetic/pow.rs index bcac31609b..3df43fac76 100644 --- a/compiler/src/expression/arithmetic/pow.rs +++ b/compiler/src/expression/arithmetic/pow.rs @@ -21,12 +21,12 @@ use leo_ast::Span; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn enforce_pow, CS: ConstraintSystem>( +pub fn enforce_pow<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.pow(cs, num_2, span)?)) diff --git a/compiler/src/expression/arithmetic/sub.rs b/compiler/src/expression/arithmetic/sub.rs index 14658e0afe..1abac1547a 100644 --- a/compiler/src/expression/arithmetic/sub.rs +++ b/compiler/src/expression/arithmetic/sub.rs @@ -21,12 +21,12 @@ use leo_ast::Span; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn enforce_sub, CS: ConstraintSystem>( +pub fn enforce_sub<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.sub(cs, num_2, span)?)) diff --git a/compiler/src/expression/array/access.rs b/compiler/src/expression/array/access.rs index 87f15fba3e..8a9469ca7a 100644 --- a/compiler/src/expression/array/access.rs +++ b/compiler/src/expression/array/access.rs @@ -18,19 +18,18 @@ use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::{Expression, Span}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_array_access>( &mut self, cs: &mut CS, - array: &Arc, - index: &Arc, + array: &'a Expression<'a>, + index: &'a Expression<'a>, span: &Span, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { let array = match self.enforce_expression(cs, array)? { ConstrainedValue::Array(array) => array, value => return Err(ExpressionError::undefined_array(value.to_string(), span.to_owned())), @@ -44,21 +43,21 @@ impl> ConstrainedProgram { pub fn enforce_array_range_access>( &mut self, cs: &mut CS, - array: &Arc, - left: Option<&Arc>, - right: Option<&Arc>, + array: &'a Expression<'a>, + left: Option<&'a Expression<'a>>, + right: Option<&'a Expression<'a>>, span: &Span, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { let array = match self.enforce_expression(cs, array)? { ConstrainedValue::Array(array) => array, value => return Err(ExpressionError::undefined_array(value.to_string(), span.to_owned())), }; - let from_resolved = match left.as_deref() { + let from_resolved = match left { Some(from_index) => self.enforce_index(cs, from_index, span)?, None => 0usize, // Array slice starts at index 0 }; - let to_resolved = match right.as_deref() { + let to_resolved = match right { Some(to_index) => self.enforce_index(cs, to_index, span)?, None => array.len(), // Array slice ends at array length }; diff --git a/compiler/src/expression/array/array.rs b/compiler/src/expression/array/array.rs index 5ec2b52385..04c60441b3 100644 --- a/compiler/src/expression/array/array.rs +++ b/compiler/src/expression/array/array.rs @@ -16,25 +16,26 @@ //! Enforces an array expression in a compiled Leo program. +use std::cell::Cell; + use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::{Expression, Span}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Enforce array expressions pub fn enforce_array>( &mut self, cs: &mut CS, - array: &[(Arc, bool)], + array: &[(Cell<&'a Expression<'a>>, bool)], span: Span, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { let expected_dimension = None; let mut result = vec![]; for (element, is_spread) in array.iter() { - let element_value = self.enforce_expression(cs, element)?; + let element_value = self.enforce_expression(cs, element.get())?; if *is_spread { match element_value { ConstrainedValue::Array(array) => result.extend(array), @@ -63,9 +64,9 @@ impl> ConstrainedProgram { pub fn enforce_array_initializer>( &mut self, cs: &mut CS, - element_expression: &Arc, + element_expression: &'a Expression<'a>, actual_size: usize, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { let mut value = self.enforce_expression(cs, element_expression)?; // Allocate the array. diff --git a/compiler/src/expression/array/index.rs b/compiler/src/expression/array/index.rs index 67be2bc902..4b6bb27c6e 100644 --- a/compiler/src/expression/array/index.rs +++ b/compiler/src/expression/array/index.rs @@ -18,15 +18,14 @@ use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::{Expression, Span}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub(crate) fn enforce_index>( &mut self, cs: &mut CS, - index: &Arc, + index: &'a Expression<'a>, span: &Span, ) -> Result { match self.enforce_expression(cs, index)? { diff --git a/compiler/src/expression/binary/binary.rs b/compiler/src/expression/binary/binary.rs index 9aeba58eb1..0509ce90b8 100644 --- a/compiler/src/expression/binary/binary.rs +++ b/compiler/src/expression/binary/binary.rs @@ -18,20 +18,19 @@ use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::Expression; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -type ConstrainedValuePair = (ConstrainedValue, ConstrainedValue); +type ConstrainedValuePair<'a, T, U> = (ConstrainedValue<'a, T, U>, ConstrainedValue<'a, T, U>); -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_binary_expression>( &mut self, cs: &mut CS, - left: &Arc, - right: &Arc, - ) -> Result, ExpressionError> { + left: &'a Expression<'a>, + right: &'a Expression<'a>, + ) -> Result, ExpressionError> { let resolved_left = self.enforce_expression(cs, left)?; let resolved_right = self.enforce_expression(cs, right)?; diff --git a/compiler/src/expression/circuit/access.rs b/compiler/src/expression/circuit/access.rs index 672d3656c7..09d3327ca1 100644 --- a/compiler/src/expression/circuit/access.rs +++ b/compiler/src/expression/circuit/access.rs @@ -21,24 +21,24 @@ use leo_asg::{CircuitAccessExpression, Node}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_circuit_access>( &mut self, cs: &mut CS, - expr: &CircuitAccessExpression, - ) -> Result, ExpressionError> { - if let Some(target) = &expr.target { + expr: &CircuitAccessExpression<'a>, + ) -> Result, ExpressionError> { + if let Some(target) = expr.target.get() { //todo: we can prob pass values by ref here to avoid copying the entire circuit on access let target_value = self.enforce_expression(cs, target)?; match target_value { ConstrainedValue::CircuitExpression(def, members) => { - assert!(def.circuit == expr.circuit); + assert!(def == expr.circuit.get()); if let Some(member) = members.into_iter().find(|x| x.0.name == expr.member.name) { Ok(member.1) } else { Err(ExpressionError::undefined_member_access( - expr.circuit.name.borrow().to_string(), + expr.circuit.get().name.borrow().to_string(), expr.member.to_string(), expr.member.span.clone(), )) diff --git a/compiler/src/expression/circuit/circuit.rs b/compiler/src/expression/circuit/circuit.rs index 4b8942c462..fba53cab5b 100644 --- a/compiler/src/expression/circuit/circuit.rs +++ b/compiler/src/expression/circuit/circuit.rs @@ -22,23 +22,18 @@ use crate::{ value::{ConstrainedCircuitMember, ConstrainedValue}, GroupType, }; -use leo_asg::{CircuitInitExpression, CircuitMemberBody, Span}; +use leo_asg::{CircuitInitExpression, CircuitMember, Span}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn enforce_circuit>( &mut self, cs: &mut CS, - expr: &CircuitInitExpression, + expr: &CircuitInitExpression<'a>, span: &Span, - ) -> Result, ExpressionError> { - let circuit = expr - .circuit - .body - .borrow() - .upgrade() - .expect("circuit init stale circuit ref"); + ) -> Result, ExpressionError> { + let circuit = expr.circuit.get(); let members = circuit.members.borrow(); let mut resolved_members = Vec::with_capacity(members.len()); @@ -49,15 +44,15 @@ impl> ConstrainedProgram { .get(&name.name) .expect("illegal name in asg circuit init expression"); match target { - CircuitMemberBody::Variable(_type_) => { - let variable_value = self.enforce_expression(cs, inner)?; + CircuitMember::Variable(_type_) => { + let variable_value = self.enforce_expression(cs, inner.get())?; resolved_members.push(ConstrainedCircuitMember(name.clone(), variable_value)); } _ => return Err(ExpressionError::expected_circuit_member(name.to_string(), span.clone())), } } - let value = ConstrainedValue::CircuitExpression(circuit.clone(), resolved_members); + let value = ConstrainedValue::CircuitExpression(circuit, resolved_members); Ok(value) } } diff --git a/compiler/src/expression/conditional/conditional.rs b/compiler/src/expression/conditional/conditional.rs index 3deb2cc47a..aa515d5346 100644 --- a/compiler/src/expression/conditional/conditional.rs +++ b/compiler/src/expression/conditional/conditional.rs @@ -18,24 +18,23 @@ use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::{Expression, Span}; -use std::sync::Arc; use snarkvm_models::{ curves::PrimeField, gadgets::{r1cs::ConstraintSystem, utilities::select::CondSelectGadget}, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Enforce ternary conditional expression #[allow(clippy::too_many_arguments)] pub fn enforce_conditional_expression>( &mut self, cs: &mut CS, - conditional: &Arc, - first: &Arc, - second: &Arc, + conditional: &'a Expression<'a>, + first: &'a Expression<'a>, + second: &'a Expression<'a>, span: &Span, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { let conditional_value = match self.enforce_expression(cs, conditional)? { ConstrainedValue::Boolean(resolved) => resolved, value => return Err(ExpressionError::conditional_boolean(value.to_string(), span.to_owned())), diff --git a/compiler/src/expression/expression.rs b/compiler/src/expression/expression.rs index 261e97e053..e399d090ae 100644 --- a/compiler/src/expression/expression.rs +++ b/compiler/src/expression/expression.rs @@ -28,21 +28,20 @@ use crate::{ GroupType, }; use leo_asg::{expression::*, ConstValue, Expression, Node}; -use std::sync::Arc; use snarkvm_models::{ curves::PrimeField, gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub(crate) fn enforce_expression>( &mut self, cs: &mut CS, - expression: &Arc, - ) -> Result, ExpressionError> { + expression: &'a Expression<'a>, + ) -> Result, ExpressionError> { let span = expression.span().cloned().unwrap_or_default(); - match &**expression { + match expression { // Variables Expression::VariableRef(variable_ref) => self.evaluate_ref(variable_ref), @@ -62,7 +61,7 @@ impl> ConstrainedProgram { Expression::Binary(BinaryExpression { left, right, operation, .. }) => { - let (resolved_left, resolved_right) = self.enforce_binary_expression(cs, left, right)?; + let (resolved_left, resolved_right) = self.enforce_binary_expression(cs, left.get(), right.get())?; match operation { BinaryOperation::Add => enforce_add(cs, resolved_left, resolved_right, &span), @@ -89,10 +88,10 @@ impl> ConstrainedProgram { // Unary operations Expression::Unary(UnaryExpression { inner, operation, .. }) => match operation { UnaryOperation::Negate => { - let resolved_inner = self.enforce_expression(cs, inner)?; + let resolved_inner = self.enforce_expression(cs, inner.get())?; enforce_negate(cs, resolved_inner, &span) } - UnaryOperation::Not => Ok(evaluate_not(self.enforce_expression(cs, inner)?, &span)?), + UnaryOperation::Not => Ok(evaluate_not(self.enforce_expression(cs, inner.get())?, &span)?), }, Expression::Ternary(TernaryExpression { @@ -100,24 +99,26 @@ impl> ConstrainedProgram { if_true, if_false, .. - }) => self.enforce_conditional_expression(cs, condition, if_true, if_false, &span), + }) => self.enforce_conditional_expression(cs, condition.get(), if_true.get(), if_false.get(), &span), // Arrays - Expression::ArrayInline(ArrayInlineExpression { elements, .. }) => self.enforce_array(cs, elements, span), + Expression::ArrayInline(ArrayInlineExpression { elements, .. }) => { + self.enforce_array(cs, &elements[..], span) + } Expression::ArrayInit(ArrayInitExpression { element, len, .. }) => { - self.enforce_array_initializer(cs, element, *len) + self.enforce_array_initializer(cs, element.get(), *len) } Expression::ArrayAccess(ArrayAccessExpression { array, index, .. }) => { - self.enforce_array_access(cs, array, index, &span) + self.enforce_array_access(cs, array.get(), index.get(), &span) } Expression::ArrayRangeAccess(ArrayRangeAccessExpression { array, left, right, .. }) => { - self.enforce_array_range_access(cs, array, left.as_ref(), right.as_ref(), &span) + self.enforce_array_range_access(cs, array.get(), left.get(), right.get(), &span) } // Tuples - Expression::TupleInit(TupleInitExpression { elements, .. }) => self.enforce_tuple(cs, elements), + Expression::TupleInit(TupleInitExpression { elements, .. }) => self.enforce_tuple(cs, &elements[..]), Expression::TupleAccess(TupleAccessExpression { tuple_ref, index, .. }) => { - self.enforce_tuple_access(cs, tuple_ref, *index, &span) + self.enforce_tuple_access(cs, tuple_ref.get(), *index, &span) } // Circuits @@ -131,26 +132,21 @@ impl> ConstrainedProgram { arguments, .. }) => { - if let Some(circuit) = function - .circuit - .borrow() - .as_ref() - .map(|x| x.upgrade().expect("stale circuit for member function")) - { + if let Some(circuit) = function.get().circuit.get() { let core_mapping = circuit.core_mapping.borrow(); if let Some(core_mapping) = core_mapping.as_deref() { let core_circuit = resolve_core_circuit::(core_mapping); return self.enforce_core_circuit_call_expression( cs, &core_circuit, - &function, - target.as_ref(), - arguments, + function.get(), + target.get(), + &arguments[..], &span, ); } } - self.enforce_function_call_expression(cs, &function, target.as_ref(), arguments, &span) + self.enforce_function_call_expression(cs, function.get(), target.get(), &arguments[..], &span) } } } diff --git a/compiler/src/expression/function/core_circuit.rs b/compiler/src/expression/function/core_circuit.rs index 7de16f7fa6..0658d6c8e6 100644 --- a/compiler/src/expression/function/core_circuit.rs +++ b/compiler/src/expression/function/core_circuit.rs @@ -13,31 +13,27 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . + +use std::cell::Cell; + use crate::{program::ConstrainedProgram, value::ConstrainedValue, CoreCircuit, GroupType}; use crate::errors::ExpressionError; use leo_asg::{Expression, Function, Span}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -use std::sync::Arc; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Call a default core circuit function with arguments #[allow(clippy::too_many_arguments)] - pub fn enforce_core_circuit_call_expression, C: CoreCircuit>( + pub fn enforce_core_circuit_call_expression, C: CoreCircuit<'a, F, G>>( &mut self, cs: &mut CS, core_circuit: &C, - function: &Arc, - target: Option<&Arc>, - arguments: &[Arc], + function: &'a Function<'a>, + target: Option<&'a Expression<'a>>, + arguments: &[Cell<&'a Expression<'a>>], span: &Span, - ) -> Result, ExpressionError> { - let function = function - .body - .borrow() - .upgrade() - .expect("stale function in call expression"); - + ) -> Result, ExpressionError> { let target_value = if let Some(target) = target { Some(self.enforce_expression(cs, target)?) } else { @@ -47,7 +43,7 @@ impl> ConstrainedProgram { // Get the value of each core function argument let arguments = arguments .iter() - .map(|argument| self.enforce_expression(cs, argument)) + .map(|argument| self.enforce_expression(cs, argument.get())) .collect::, _>>()?; // Call the core function diff --git a/compiler/src/expression/function/function.rs b/compiler/src/expression/function/function.rs index 3ffe6ad0ed..ba1888c4c6 100644 --- a/compiler/src/expression/function/function.rs +++ b/compiler/src/expression/function/function.rs @@ -16,22 +16,23 @@ //! Enforce a function call expression in a compiled Leo program. +use std::cell::Cell; + use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::{Expression, Function, Span}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_function_call_expression>( &mut self, cs: &mut CS, - function: &Arc, - target: Option<&Arc>, - arguments: &[Arc], + function: &'a Function<'a>, + target: Option<&'a Expression<'a>>, + arguments: &[Cell<&'a Expression<'a>>], span: &Span, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { let name_unique = || { format!( "function call {} {}:{}", @@ -40,11 +41,6 @@ impl> ConstrainedProgram { span.start, ) }; - let function = function - .body - .borrow() - .upgrade() - .expect("stale function in call expression"); let return_value = self .enforce_function(&mut cs.ns(name_unique), &function, target, arguments) diff --git a/compiler/src/expression/logical/and.rs b/compiler/src/expression/logical/and.rs index a7c1613a27..0d8ef3773d 100644 --- a/compiler/src/expression/logical/and.rs +++ b/compiler/src/expression/logical/and.rs @@ -24,12 +24,12 @@ use snarkvm_models::{ gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -pub fn enforce_and, CS: ConstraintSystem>( +pub fn enforce_and<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, BooleanError> { +) -> Result, BooleanError> { let name = format!("{} && {}", left, right); if let (ConstrainedValue::Boolean(left_bool), ConstrainedValue::Boolean(right_bool)) = (left, right) { diff --git a/compiler/src/expression/logical/not.rs b/compiler/src/expression/logical/not.rs index 59d0842695..84d16c5be8 100644 --- a/compiler/src/expression/logical/not.rs +++ b/compiler/src/expression/logical/not.rs @@ -21,10 +21,10 @@ use leo_asg::Span; use snarkvm_models::curves::PrimeField; -pub fn evaluate_not>( - value: ConstrainedValue, +pub fn evaluate_not<'a, F: PrimeField, G: GroupType>( + value: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, BooleanError> { +) -> Result, BooleanError> { match value { ConstrainedValue::Boolean(boolean) => Ok(ConstrainedValue::Boolean(boolean.not())), value => Err(BooleanError::cannot_evaluate(format!("!{}", value), span.clone())), diff --git a/compiler/src/expression/logical/or.rs b/compiler/src/expression/logical/or.rs index d1df2fab88..ec1739d085 100644 --- a/compiler/src/expression/logical/or.rs +++ b/compiler/src/expression/logical/or.rs @@ -24,12 +24,12 @@ use snarkvm_models::{ gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -pub fn enforce_or, CS: ConstraintSystem>( +pub fn enforce_or<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, BooleanError> { +) -> Result, BooleanError> { let name = format!("{} || {}", left, right); if let (ConstrainedValue::Boolean(left_bool), ConstrainedValue::Boolean(right_bool)) = (left, right) { diff --git a/compiler/src/expression/relational/eq.rs b/compiler/src/expression/relational/eq.rs index 2df94b159b..99dfe39e42 100644 --- a/compiler/src/expression/relational/eq.rs +++ b/compiler/src/expression/relational/eq.rs @@ -27,12 +27,12 @@ use snarkvm_models::{ }, }; -pub fn evaluate_eq, CS: ConstraintSystem>( +pub fn evaluate_eq<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { let namespace_string = format!("evaluate {} == {} {}:{}", left, right, span.line, span.start); let constraint_result = match (left, right) { (ConstrainedValue::Address(address_1), ConstrainedValue::Address(address_2)) => { diff --git a/compiler/src/expression/relational/ge.rs b/compiler/src/expression/relational/ge.rs index 9664f9bbd6..0788633554 100644 --- a/compiler/src/expression/relational/ge.rs +++ b/compiler/src/expression/relational/ge.rs @@ -22,12 +22,12 @@ use leo_gadgets::bits::ComparatorGadget; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn evaluate_ge, CS: ConstraintSystem>( +pub fn evaluate_ge<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { let unique_namespace = cs.ns(|| format!("evaluate {} >= {} {}:{}", left, right, span.line, span.start)); let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { diff --git a/compiler/src/expression/relational/gt.rs b/compiler/src/expression/relational/gt.rs index f596e4cf64..8590e17b6c 100644 --- a/compiler/src/expression/relational/gt.rs +++ b/compiler/src/expression/relational/gt.rs @@ -22,12 +22,12 @@ use leo_gadgets::bits::ComparatorGadget; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn evaluate_gt, CS: ConstraintSystem>( +pub fn evaluate_gt<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { let unique_namespace = cs.ns(|| format!("evaluate {} > {} {}:{}", left, right, span.line, span.start)); let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { diff --git a/compiler/src/expression/relational/le.rs b/compiler/src/expression/relational/le.rs index e824a4fb55..38ba4ba6d1 100644 --- a/compiler/src/expression/relational/le.rs +++ b/compiler/src/expression/relational/le.rs @@ -22,12 +22,12 @@ use leo_gadgets::bits::ComparatorGadget; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn evaluate_le, CS: ConstraintSystem>( +pub fn evaluate_le<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { let unique_namespace = cs.ns(|| format!("evaluate {} <= {} {}:{}", left, right, span.line, span.start)); let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { diff --git a/compiler/src/expression/relational/lt.rs b/compiler/src/expression/relational/lt.rs index a4d2b4f746..5282f60335 100644 --- a/compiler/src/expression/relational/lt.rs +++ b/compiler/src/expression/relational/lt.rs @@ -22,12 +22,12 @@ use leo_gadgets::bits::comparator::EvaluateLtGadget; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub fn evaluate_lt, CS: ConstraintSystem>( +pub fn evaluate_lt<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + left: ConstrainedValue<'a, F, G>, + right: ConstrainedValue<'a, F, G>, span: &Span, -) -> Result, ExpressionError> { +) -> Result, ExpressionError> { let unique_namespace = cs.ns(|| format!("evaluate {} < {} {}:{}", left, right, span.line, span.start)); let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { diff --git a/compiler/src/expression/tuple/access.rs b/compiler/src/expression/tuple/access.rs index ccca481aad..1f219a4d9a 100644 --- a/compiler/src/expression/tuple/access.rs +++ b/compiler/src/expression/tuple/access.rs @@ -18,19 +18,18 @@ use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::{Expression, Span}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_tuple_access>( &mut self, cs: &mut CS, - tuple: &Arc, + tuple: &'a Expression<'a>, index: usize, span: &Span, - ) -> Result, ExpressionError> { + ) -> Result, ExpressionError> { // Get the tuple values. let tuple = match self.enforce_expression(cs, tuple)? { ConstrainedValue::Tuple(tuple) => tuple, diff --git a/compiler/src/expression/tuple/tuple.rs b/compiler/src/expression/tuple/tuple.rs index 6884f50e88..4759d0fef1 100644 --- a/compiler/src/expression/tuple/tuple.rs +++ b/compiler/src/expression/tuple/tuple.rs @@ -16,22 +16,23 @@ //! Enforces an tuple expression in a compiled Leo program. +use std::cell::Cell; + use crate::{errors::ExpressionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::Expression; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Enforce tuple expressions pub fn enforce_tuple>( &mut self, cs: &mut CS, - tuple: &[Arc], - ) -> Result, ExpressionError> { + tuple: &[Cell<&'a Expression<'a>>], + ) -> Result, ExpressionError> { let mut result = Vec::with_capacity(tuple.len()); for expression in tuple.iter() { - result.push(self.enforce_expression(cs, expression)?); + result.push(self.enforce_expression(cs, expression.get())?); } Ok(ConstrainedValue::Tuple(result)) diff --git a/compiler/src/expression/variable_ref/variable_ref.rs b/compiler/src/expression/variable_ref/variable_ref.rs index 54de37daf5..07951bbbdc 100644 --- a/compiler/src/expression/variable_ref/variable_ref.rs +++ b/compiler/src/expression/variable_ref/variable_ref.rs @@ -21,9 +21,9 @@ use leo_asg::VariableRef; use snarkvm_models::curves::PrimeField; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Enforce a variable expression by getting the resolved value - pub fn evaluate_ref(&mut self, variable_ref: &VariableRef) -> Result, ExpressionError> { + pub fn evaluate_ref(&mut self, variable_ref: &VariableRef) -> Result, ExpressionError> { // Evaluate the identifier name in the current function scope let variable = variable_ref.variable.borrow(); let result_value = if let Some(value) = self.get(&variable.id) { diff --git a/compiler/src/function/function.rs b/compiler/src/function/function.rs index 434f2b64e0..acde765b47 100644 --- a/compiler/src/function/function.rs +++ b/compiler/src/function/function.rs @@ -18,22 +18,22 @@ use crate::{errors::FunctionError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; -use leo_asg::{Expression, FunctionBody, FunctionQualifier}; -use std::sync::Arc; +use leo_asg::{Expression, Function, FunctionQualifier}; +use std::cell::Cell; use snarkvm_models::{ curves::PrimeField, gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub(crate) fn enforce_function>( &mut self, cs: &mut CS, - function: &Arc, - target: Option<&Arc>, - arguments: &[Arc], - ) -> Result, FunctionError> { + function: &'a Function<'a>, + target: Option<&'a Expression<'a>>, + arguments: &[Cell<&'a Expression<'a>>], + ) -> Result, FunctionError> { let target_value = if let Some(target) = target { Some(self.enforce_expression(cs, target)?) } else { @@ -43,7 +43,6 @@ impl> ConstrainedProgram { let self_var = if let Some(target) = &target_value { let self_var = function .scope - .borrow() .resolve_variable("self") .expect("attempted to call static function from non-static context"); self.store(self_var.borrow().id, target.clone()); @@ -52,7 +51,7 @@ impl> ConstrainedProgram { None }; - if function.function.arguments.len() != arguments.len() { + if function.arguments.len() != arguments.len() { return Err(FunctionError::input_not_found( "arguments length invalid".to_string(), function.span.clone().unwrap_or_default(), @@ -60,9 +59,9 @@ impl> ConstrainedProgram { } // Store input values as new variables in resolved program - for (variable, input_expression) in function.function.arguments.iter().zip(arguments.iter()) { - let input_value = self.enforce_expression(cs, input_expression)?; - let variable = variable.borrow(); + for ((_, variable), input_expression) in function.arguments.iter().zip(arguments.iter()) { + let input_value = self.enforce_expression(cs, input_expression.get())?; + let variable = variable.get().borrow(); self.store(variable.id, input_value); } @@ -71,13 +70,17 @@ impl> ConstrainedProgram { let mut results = vec![]; let indicator = Boolean::constant(true); - let output = function.function.output.clone().strong(); + let output = function.output.clone(); - let mut result = self.enforce_statement(cs, &indicator, &function.body)?; + let mut result = self.enforce_statement( + cs, + &indicator, + function.body.get().expect("attempted to call function header"), + )?; results.append(&mut result); - if function.function.qualifier == FunctionQualifier::MutSelfRef { + if function.qualifier == FunctionQualifier::MutSelfRef { if let (Some(self_var), Some(target)) = (self_var, target) { let new_self = self .get(&self_var.borrow().id) diff --git a/compiler/src/function/input/array.rs b/compiler/src/function/input/array.rs index 7aabc3e93c..ebf3bca676 100644 --- a/compiler/src/function/input/array.rs +++ b/compiler/src/function/input/array.rs @@ -23,7 +23,7 @@ use leo_ast::{InputValue, Span}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn allocate_array>( &mut self, cs: &mut CS, @@ -32,7 +32,7 @@ impl> ConstrainedProgram { array_len: usize, input_value: Option, span: &Span, - ) -> Result, FunctionError> { + ) -> Result, FunctionError> { // Build the array value using the expected types. let mut array_value = vec![]; diff --git a/compiler/src/function/input/input_keyword.rs b/compiler/src/function/input/input_keyword.rs index a26fca1bdd..e7aaed9fd6 100644 --- a/compiler/src/function/input/input_keyword.rs +++ b/compiler/src/function/input/input_keyword.rs @@ -15,9 +15,8 @@ // along with the Leo library. If not, see . use crate::{errors::FunctionError, ConstrainedCircuitMember, ConstrainedProgram, ConstrainedValue, GroupType}; -use leo_asg::{CircuitBody, CircuitMemberBody, Type}; +use leo_asg::{Circuit, CircuitMember, Type}; use leo_ast::{Identifier, Input, Span}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; @@ -26,14 +25,14 @@ pub const REGISTERS_VARIABLE_NAME: &str = "registers"; pub const STATE_VARIABLE_NAME: &str = "state"; pub const STATE_LEAF_VARIABLE_NAME: &str = "state_leaf"; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn allocate_input_keyword>( &mut self, cs: &mut CS, span: Span, - expected_type: &Arc, + expected_type: &'a Circuit<'a>, input: &Input, - ) -> Result, FunctionError> { + ) -> Result, FunctionError> { // Create an identifier for each input variable let registers_name = Identifier { @@ -73,11 +72,7 @@ impl> ConstrainedProgram { for (name, values) in sections { let sub_circuit = match expected_type.members.borrow().get(&name.name) { - Some(CircuitMemberBody::Variable(Type::Circuit(circuit))) => circuit - .body - .borrow() - .upgrade() - .expect("stale circuit body for input subtype"), + Some(CircuitMember::Variable(Type::Circuit(circuit))) => *circuit, _ => panic!("illegal input type definition from asg"), }; @@ -91,6 +86,6 @@ impl> ConstrainedProgram { // Return input variable keyword as circuit expression - Ok(ConstrainedValue::CircuitExpression(expected_type.clone(), members)) + Ok(ConstrainedValue::CircuitExpression(expected_type, members)) } } diff --git a/compiler/src/function/input/input_section.rs b/compiler/src/function/input/input_section.rs index f322d7db7d..65331a3172 100644 --- a/compiler/src/function/input/input_section.rs +++ b/compiler/src/function/input/input_section.rs @@ -15,32 +15,31 @@ // along with the Leo library. If not, see . use crate::{errors::FunctionError, ConstrainedCircuitMember, ConstrainedProgram, ConstrainedValue, GroupType}; -use leo_asg::{AsgConvertError, CircuitBody, CircuitMemberBody}; +use leo_asg::{AsgConvertError, Circuit, CircuitMember}; use leo_ast::{Identifier, InputValue, Parameter}; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; use indexmap::IndexMap; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn allocate_input_section>( &mut self, cs: &mut CS, identifier: Identifier, - expected_type: Arc, + expected_type: &'a Circuit<'a>, section: IndexMap>, - ) -> Result, FunctionError> { + ) -> Result, FunctionError> { let mut members = Vec::with_capacity(section.len()); // Allocate each section definition as a circuit member value for (parameter, option) in section.into_iter() { let section_members = expected_type.members.borrow(); let expected_type = match section_members.get(¶meter.variable.name) { - Some(CircuitMemberBody::Variable(inner)) => inner, + Some(CircuitMember::Variable(inner)) => inner, _ => continue, // present, but unused }; - let declared_type = self.asg.borrow().scope.borrow().resolve_ast_type(¶meter.type_)?; + let declared_type = self.asg.scope.resolve_ast_type(¶meter.type_)?; if !expected_type.is_assignable_from(&declared_type) { return Err(AsgConvertError::unexpected_type( &expected_type.to_string(), diff --git a/compiler/src/function/input/main_function_input.rs b/compiler/src/function/input/main_function_input.rs index c93c4eda29..b4e5d540a3 100644 --- a/compiler/src/function/input/main_function_input.rs +++ b/compiler/src/function/input/main_function_input.rs @@ -34,7 +34,7 @@ use leo_asg::Type; use leo_ast::{InputValue, Span}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn allocate_main_function_input>( &mut self, cs: &mut CS, @@ -42,7 +42,7 @@ impl> ConstrainedProgram { name: &str, input_option: Option, span: &Span, - ) -> Result, FunctionError> { + ) -> Result, FunctionError> { match type_ { Type::Address => Ok(Address::from_input(cs, name, input_option, span)?), Type::Boolean => Ok(bool_from_input(cs, name, input_option, span)?), diff --git a/compiler/src/function/input/tuple.rs b/compiler/src/function/input/tuple.rs index ba9974f787..0530a0790f 100644 --- a/compiler/src/function/input/tuple.rs +++ b/compiler/src/function/input/tuple.rs @@ -23,7 +23,7 @@ use leo_ast::{InputValue, Span}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn allocate_tuple>( &mut self, cs: &mut CS, @@ -31,7 +31,7 @@ impl> ConstrainedProgram { types: &[Type], input_value: Option, span: &Span, - ) -> Result, FunctionError> { + ) -> Result, FunctionError> { let mut tuple_values = vec![]; match input_value { diff --git a/compiler/src/function/main_function.rs b/compiler/src/function/main_function.rs index bf29abf27f..45f6b28733 100644 --- a/compiler/src/function/main_function.rs +++ b/compiler/src/function/main_function.rs @@ -18,33 +18,32 @@ use crate::{errors::FunctionError, program::ConstrainedProgram, GroupType, OutputBytes}; -use leo_asg::{Expression, FunctionBody, FunctionQualifier}; +use leo_asg::{Expression, Function, FunctionQualifier}; use leo_ast::Input; -use std::sync::Arc; +use std::cell::Cell; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn enforce_main_function>( &mut self, cs: &mut CS, - function: &Arc, + function: &'a Function<'a>, input: &Input, ) -> Result { let registers = input.get_registers(); // Iterate over main function input variables and allocate new values - if function.function.has_input { + if function.has_input { // let input_var = function.scope. let asg_input = function .scope - .borrow() .resolve_input() .expect("no input variable in scope when function is qualified"); let value = self.allocate_input_keyword( cs, - function.function.name.borrow().span.clone(), + function.name.borrow().span.clone(), &asg_input.container_circuit, input, )?; @@ -52,7 +51,7 @@ impl> ConstrainedProgram { self.store(asg_input.container.borrow().id, value); } - match function.function.qualifier { + match function.qualifier { FunctionQualifier::SelfRef | FunctionQualifier::MutSelfRef => { unimplemented!("cannot access self variable in main function") } @@ -61,16 +60,16 @@ impl> ConstrainedProgram { let mut arguments = vec![]; - for input_variable in function.function.arguments.iter() { + for (_, input_variable) in function.arguments.iter() { { - let input_variable = input_variable.borrow(); + let input_variable = input_variable.get().borrow(); let name = input_variable.name.name.clone(); let input_option = input.get(&name).ok_or_else(|| { FunctionError::input_not_found(name.clone(), function.span.clone().unwrap_or_default()) })?; let input_value = self.allocate_main_function_input( cs, - &input_variable.type_.clone().strong(), + &input_variable.type_.clone(), &name, input_option, &function.span.clone().unwrap_or_default(), @@ -79,11 +78,13 @@ impl> ConstrainedProgram { // Store a new variable for every allocated main function input self.store(input_variable.id, input_value); } - arguments.push(Arc::new(Expression::VariableRef(leo_asg::VariableRef { - parent: std::cell::RefCell::new(None), - span: Some(input_variable.borrow().name.span.clone()), - variable: input_variable.clone(), - }))); + arguments.push(Cell::new(&*function.scope.alloc_expression(Expression::VariableRef( + leo_asg::VariableRef { + parent: Cell::new(None), + span: Some(input_variable.get().borrow().name.span.clone()), + variable: input_variable.get(), + }, + )))); } let span = function.span.clone().unwrap_or_default(); diff --git a/compiler/src/function/mut_target.rs b/compiler/src/function/mut_target.rs index b3d8f03226..7e774724da 100644 --- a/compiler/src/function/mut_target.rs +++ b/compiler/src/function/mut_target.rs @@ -33,27 +33,26 @@ use leo_asg::{ TupleAccessExpression, Variable, }; -use std::sync::Arc; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { fn prepare_mut_access>( &mut self, cs: &mut CS, - expr: &Arc, + expr: &'a Expression<'a>, span: &Span, output: &mut Vec, - ) -> Result, StatementError> { - match &**expr { + ) -> Result>, StatementError> { + match expr { Expression::ArrayRangeAccess(ArrayRangeAccessExpression { array, left, right, .. }) => { - let inner = self.prepare_mut_access(cs, array, span, output)?; + let inner = self.prepare_mut_access(cs, array.get(), span, output)?; let start_index = left - .as_ref() + .get() .map(|start| self.enforce_index(cs, start, &span)) .transpose()?; let stop_index = right - .as_ref() + .get() .map(|stop| self.enforce_index(cs, stop, &span)) .transpose()?; @@ -61,27 +60,27 @@ impl> ConstrainedProgram { Ok(inner) } Expression::ArrayAccess(ArrayAccessExpression { array, index, .. }) => { - let inner = self.prepare_mut_access(cs, array, span, output)?; - let index = self.enforce_index(cs, index, &span)?; + let inner = self.prepare_mut_access(cs, array.get(), span, output)?; + let index = self.enforce_index(cs, index.get(), &span)?; output.push(ResolvedAssigneeAccess::ArrayIndex(index)); Ok(inner) } Expression::TupleAccess(TupleAccessExpression { tuple_ref, index, .. }) => { - let inner = self.prepare_mut_access(cs, tuple_ref, span, output)?; + let inner = self.prepare_mut_access(cs, tuple_ref.get(), span, output)?; output.push(ResolvedAssigneeAccess::Tuple(*index, span.clone())); Ok(inner) } - Expression::CircuitAccess(CircuitAccessExpression { - target: Some(target), - member, - .. - }) => { - let inner = self.prepare_mut_access(cs, target, span, output)?; + Expression::CircuitAccess(CircuitAccessExpression { target, member, .. }) => { + if let Some(target) = target.get() { + let inner = self.prepare_mut_access(cs, target, span, output)?; - output.push(ResolvedAssigneeAccess::Member(member.clone())); - Ok(inner) + output.push(ResolvedAssigneeAccess::Member(member.clone())); + Ok(inner) + } else { + Ok(None) + } } Expression::VariableRef(variable_ref) => Ok(Some(variable_ref.variable.clone())), _ => Ok(None), // not a valid reference to mutable variable, we copy @@ -93,8 +92,8 @@ impl> ConstrainedProgram { pub fn resolve_mut_ref>( &mut self, cs: &mut CS, - assignee: &Arc, - ) -> Result>>, StatementError> { + assignee: &'a Expression<'a>, + ) -> Result>>, StatementError> { let span = assignee.span().cloned().unwrap_or_default(); let mut accesses = vec![]; diff --git a/compiler/src/function/result/result.rs b/compiler/src/function/result/result.rs index cc3d4caed2..57d0ad0bc5 100644 --- a/compiler/src/function/result/result.rs +++ b/compiler/src/function/result/result.rs @@ -34,17 +34,17 @@ use snarkvm_models::{ }, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// /// Returns a conditionally selected result from the given possible function returns and /// given function return type. /// pub fn conditionally_select_result>( cs: &mut CS, - expected_return: &Type, - results: Vec<(Boolean, ConstrainedValue)>, + expected_return: &Type<'a>, + results: Vec<(Boolean, ConstrainedValue<'a, F, G>)>, span: &Span, - ) -> Result, StatementError> { + ) -> Result, StatementError> { // Initialize empty return value. let mut return_value = None; diff --git a/compiler/src/output/output_bytes.rs b/compiler/src/output/output_bytes.rs index 5e82e70563..df59b28f25 100644 --- a/compiler/src/output/output_bytes.rs +++ b/compiler/src/output/output_bytes.rs @@ -31,10 +31,10 @@ impl OutputBytes { &self.0 } - pub fn new_from_constrained_value>( - program: &Program, + pub fn new_from_constrained_value<'a, F: PrimeField, G: GroupType>( + program: &Program<'a>, registers: &Registers, - value: ConstrainedValue, + value: ConstrainedValue<'a, F, G>, span: Span, ) -> Result { let return_values = match value { @@ -67,7 +67,7 @@ impl OutputBytes { let name = parameter.variable.name; // Check register type == return value type. - let register_type = program.borrow().scope.borrow().resolve_ast_type(¶meter.type_)?; + let register_type = program.scope.resolve_ast_type(¶meter.type_)?; let return_value_type = value.to_type(&span)?; if !register_type.is_assignable_from(&return_value_type) { diff --git a/compiler/src/prelude/blake2s.rs b/compiler/src/prelude/blake2s.rs index 0e3c3854dc..96e68b38c6 100644 --- a/compiler/src/prelude/blake2s.rs +++ b/compiler/src/prelude/blake2s.rs @@ -14,11 +14,9 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use std::sync::Arc; - use super::CoreCircuit; use crate::{errors::ExpressionError, ConstrainedValue, GroupType, Integer}; -use leo_asg::{FunctionBody, Span}; +use leo_asg::{Function, Span}; use snarkvm_gadgets::algorithms::prf::Blake2sGadget; use snarkvm_models::{ curves::PrimeField, @@ -48,17 +46,17 @@ fn unwrap_argument>(arg: ConstrainedValue) } } -impl> CoreCircuit for Blake2s { +impl<'a, F: PrimeField, G: GroupType> CoreCircuit<'a, F, G> for Blake2s { fn call_function>( &self, cs: &mut CS, - function: Arc, + function: &'a Function<'a>, span: &Span, - target: Option>, - mut arguments: Vec>, - ) -> Result, ExpressionError> { + target: Option>, + mut arguments: Vec>, + ) -> Result, ExpressionError> { assert_eq!(arguments.len(), 2); // asg enforced - assert!(function.function.name.borrow().name == "hash"); // asg enforced + assert!(function.name.borrow().name == "hash"); // asg enforced assert!(target.is_none()); // asg enforced let input = unwrap_argument(arguments.remove(1)); let seed = unwrap_argument(arguments.remove(0)); diff --git a/compiler/src/prelude/mod.rs b/compiler/src/prelude/mod.rs index a4ed90c154..0319f67957 100644 --- a/compiler/src/prelude/mod.rs +++ b/compiler/src/prelude/mod.rs @@ -15,26 +15,24 @@ // along with the Leo library. If not, see . pub mod blake2s; -use std::sync::Arc; - pub use blake2s::*; use crate::{errors::ExpressionError, ConstrainedValue, GroupType}; -use leo_asg::{FunctionBody, Span}; +use leo_asg::{Function, Span}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -pub trait CoreCircuit>: Send + Sync { +pub trait CoreCircuit<'a, F: PrimeField, G: GroupType>: Send + Sync { fn call_function>( &self, cs: &mut CS, - function: Arc, + function: &'a Function<'a>, span: &Span, - target: Option>, - arguments: Vec>, - ) -> Result, ExpressionError>; + target: Option>, + arguments: Vec>, + ) -> Result, ExpressionError>; } -pub fn resolve_core_circuit>(name: &str) -> impl CoreCircuit { +pub fn resolve_core_circuit<'a, F: PrimeField, G: GroupType>(name: &str) -> impl CoreCircuit<'a, F, G> { match name { "blake2s" => Blake2s, _ => unimplemented!("invalid core circuit: {}", name), diff --git a/compiler/src/program/program.rs b/compiler/src/program/program.rs index df243555af..cf14eac3d1 100644 --- a/compiler/src/program/program.rs +++ b/compiler/src/program/program.rs @@ -24,28 +24,28 @@ use snarkvm_models::curves::PrimeField; use indexmap::IndexMap; use uuid::Uuid; -pub struct ConstrainedProgram> { - pub asg: Program, - identifiers: IndexMap>, +pub struct ConstrainedProgram<'a, F: PrimeField, G: GroupType> { + pub asg: Program<'a>, + identifiers: IndexMap>, } -impl> ConstrainedProgram { - pub fn new(asg: Program) -> Self { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { + pub fn new(asg: Program<'a>) -> Self { Self { asg, identifiers: IndexMap::new(), } } - pub(crate) fn store(&mut self, name: Uuid, value: ConstrainedValue) { + pub(crate) fn store(&mut self, name: Uuid, value: ConstrainedValue<'a, F, G>) { self.identifiers.insert(name, value); } - pub(crate) fn get(&self, name: &Uuid) -> Option<&ConstrainedValue> { + pub(crate) fn get(&self, name: &Uuid) -> Option<&ConstrainedValue<'a, F, G>> { self.identifiers.get(name) } - pub(crate) fn get_mut(&mut self, name: &Uuid) -> Option<&mut ConstrainedValue> { + pub(crate) fn get_mut(&mut self, name: &Uuid) -> Option<&mut ConstrainedValue<'a, F, G>> { self.identifiers.get_mut(name) } } diff --git a/compiler/src/statement/assign/assign.rs b/compiler/src/statement/assign/assign.rs index 7694dc86ae..625dd7d2b6 100644 --- a/compiler/src/statement/assign/assign.rs +++ b/compiler/src/statement/assign/assign.rs @@ -27,16 +27,16 @@ use snarkvm_models::{ }, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_assign_statement>( &mut self, cs: &mut CS, indicator: &Boolean, - statement: &AssignStatement, + statement: &AssignStatement<'a>, ) -> Result<(), StatementError> { // Get the name of the variable we are assigning to - let new_value = self.enforce_expression(cs, &statement.value)?; + let new_value = self.enforce_expression(cs, statement.value.get())?; let mut resolved_assignee = self.resolve_assign(cs, statement)?; if resolved_assignee.len() == 1 { @@ -86,8 +86,8 @@ impl> ConstrainedProgram { condition: &Boolean, scope: String, operation: &AssignOperation, - target: &mut ConstrainedValue, - new_value: ConstrainedValue, + target: &mut ConstrainedValue<'a, F, G>, + new_value: ConstrainedValue<'a, F, G>, span: &Span, ) -> Result<(), StatementError> { let new_value = match operation { diff --git a/compiler/src/statement/assign/assignee.rs b/compiler/src/statement/assign/assignee.rs index e88597929c..dc505daac4 100644 --- a/compiler/src/statement/assign/assignee.rs +++ b/compiler/src/statement/assign/assignee.rs @@ -28,12 +28,12 @@ pub(crate) enum ResolvedAssigneeAccess { Member(Identifier), } -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn resolve_assign>( &mut self, cs: &mut CS, - assignee: &AssignStatement, - ) -> Result>, StatementError> { + assignee: &AssignStatement<'a>, + ) -> Result>, StatementError> { let span = assignee.span.clone().unwrap_or_default(); let resolved_accesses = assignee @@ -42,17 +42,14 @@ impl> ConstrainedProgram { .map(|access| match access { AssignAccess::ArrayRange(start, stop) => { let start_index = start - .as_ref() + .get() .map(|start| self.enforce_index(cs, start, &span)) .transpose()?; - let stop_index = stop - .as_ref() - .map(|stop| self.enforce_index(cs, stop, &span)) - .transpose()?; + let stop_index = stop.get().map(|stop| self.enforce_index(cs, stop, &span)).transpose()?; Ok(ResolvedAssigneeAccess::ArrayRange(start_index, stop_index)) } AssignAccess::ArrayIndex(index) => { - let index = self.enforce_index(cs, index, &span)?; + let index = self.enforce_index(cs, index.get(), &span)?; Ok(ResolvedAssigneeAccess::ArrayIndex(index)) } @@ -61,7 +58,7 @@ impl> ConstrainedProgram { }) .collect::, crate::errors::ExpressionError>>()?; - let variable = assignee.target_variable.borrow(); + let variable = assignee.target_variable.get().borrow(); let mut result = vec![match self.get_mut(&variable.id) { Some(value) => value, @@ -96,11 +93,11 @@ impl> ConstrainedProgram { } // todo: this can prob have most of its error checking removed - pub(crate) fn resolve_assignee_access<'a>( + pub(crate) fn resolve_assignee_access<'b>( access: ResolvedAssigneeAccess, span: &Span, - mut value: Vec<&'a mut ConstrainedValue>, - ) -> Result>, StatementError> { + mut value: Vec<&'b mut ConstrainedValue<'a, F, G>>, + ) -> Result>, StatementError> { match access { ResolvedAssigneeAccess::ArrayIndex(index) => { if value.len() != 1 { diff --git a/compiler/src/statement/block/block.rs b/compiler/src/statement/block/block.rs index 2c44bbaf25..b7171a69b6 100644 --- a/compiler/src/statement/block/block.rs +++ b/compiler/src/statement/block/block.rs @@ -24,7 +24,7 @@ use snarkvm_models::{ gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Evaluates a branch of one or more statements and returns a result in /// the given scope. #[allow(clippy::too_many_arguments)] @@ -32,12 +32,12 @@ impl> ConstrainedProgram { &mut self, cs: &mut CS, indicator: &Boolean, - block: &BlockStatement, - ) -> StatementResult>> { + block: &BlockStatement<'a>, + ) -> StatementResult>> { let mut results = Vec::with_capacity(block.statements.len()); // Evaluate statements. Only allow a single return argument to be returned. for statement in block.statements.iter() { - let value = self.enforce_statement(cs, indicator, statement)?; + let value = self.enforce_statement(cs, indicator, statement.get())?; results.extend(value); } diff --git a/compiler/src/statement/conditional/conditional.rs b/compiler/src/statement/conditional/conditional.rs index 57291cd72c..b7698b6109 100644 --- a/compiler/src/statement/conditional/conditional.rs +++ b/compiler/src/statement/conditional/conditional.rs @@ -38,7 +38,7 @@ fn indicator_to_string(indicator: &Boolean) -> String { .unwrap_or_else(|| "[input]".to_string()) } -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// Enforces a conditional statement with one or more branches. /// Due to R1CS constraints, we must evaluate every branch to properly construct the circuit. /// At program execution, we will pass an `indicator` bit down to all child statements within each branch. @@ -48,14 +48,14 @@ impl> ConstrainedProgram { &mut self, cs: &mut CS, indicator: &Boolean, - statement: &ConditionalStatement, - ) -> StatementResult>> { + statement: &ConditionalStatement<'a>, + ) -> StatementResult>> { let span = statement.span.clone().unwrap_or_default(); // Inherit an indicator from a previous statement. let outer_indicator = indicator; // Evaluate the conditional boolean as the inner indicator - let inner_indicator = match self.enforce_expression(cs, &statement.condition)? { + let inner_indicator = match self.enforce_expression(cs, statement.condition.get())? { ConstrainedValue::Boolean(resolved) => resolved, value => { return Err(StatementError::conditional_boolean(value.to_string(), span)); @@ -79,7 +79,7 @@ impl> ConstrainedProgram { let mut results = vec![]; // Evaluate branch 1 - let mut branch_1_result = self.enforce_statement(cs, &branch_1_indicator, &statement.result)?; + let mut branch_1_result = self.enforce_statement(cs, &branch_1_indicator, statement.result.get())?; results.append(&mut branch_1_result); @@ -98,7 +98,7 @@ impl> ConstrainedProgram { .map_err(|_| StatementError::indicator_calculation(branch_2_name, span.clone()))?; // Evaluate branch 2 - let mut branch_2_result = match &statement.next { + let mut branch_2_result = match statement.next.get() { Some(next) => self.enforce_statement(cs, &branch_2_indicator, next)?, None => vec![], }; diff --git a/compiler/src/statement/definition/definition.rs b/compiler/src/statement/definition/definition.rs index fce05a7920..96e53b9272 100644 --- a/compiler/src/statement/definition/definition.rs +++ b/compiler/src/statement/definition/definition.rs @@ -21,11 +21,11 @@ use leo_asg::{DefinitionStatement, Span, Variable}; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { fn enforce_multiple_definition( &mut self, - variable_names: &[Variable], - values: Vec>, + variable_names: &[&'a Variable<'a>], + values: Vec>, span: &Span, ) -> Result<(), StatementError> { if values.len() != variable_names.len() { @@ -47,10 +47,10 @@ impl> ConstrainedProgram { pub fn enforce_definition_statement>( &mut self, cs: &mut CS, - statement: &DefinitionStatement, + statement: &DefinitionStatement<'a>, ) -> Result<(), StatementError> { let num_variables = statement.variables.len(); - let expression = self.enforce_expression(cs, &statement.value)?; + let expression = self.enforce_expression(cs, statement.value.get())?; let span = statement.span.clone().unwrap_or_default(); if num_variables == 1 { @@ -65,7 +65,7 @@ impl> ConstrainedProgram { value => return Err(StatementError::multiple_definition(value.to_string(), span)), }; - self.enforce_multiple_definition(&statement.variables, values, &span) + self.enforce_multiple_definition(&statement.variables[..], values, &span) } } } diff --git a/compiler/src/statement/iteration/iteration.rs b/compiler/src/statement/iteration/iteration.rs index 830546e40b..2a9f4c33c9 100644 --- a/compiler/src/statement/iteration/iteration.rs +++ b/compiler/src/statement/iteration/iteration.rs @@ -34,20 +34,20 @@ use snarkvm_models::{ }, }; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { #[allow(clippy::too_many_arguments)] pub fn enforce_iteration_statement>( &mut self, cs: &mut CS, indicator: &Boolean, - statement: &IterationStatement, - ) -> StatementResult>> { + statement: &IterationStatement<'a>, + ) -> StatementResult>> { let mut results = vec![]; let span = statement.span.clone().unwrap_or_default(); - let from = self.enforce_index(cs, &statement.start, &span)?; - let to = self.enforce_index(cs, &statement.stop, &span)?; + let from = self.enforce_index(cs, statement.start.get(), &span)?; + let to = self.enforce_index(cs, statement.stop.get(), &span)?; for i in from..to { // Store index in current function scope. @@ -64,7 +64,7 @@ impl> ConstrainedProgram { let result = self.enforce_statement( &mut cs.ns(|| format!("for loop iteration {} {}:{}", i, &span.line, &span.start)), indicator, - &statement.body, + statement.body.get(), )?; results.extend(result); diff --git a/compiler/src/statement/return_/return_.rs b/compiler/src/statement/return_/return_.rs index 039e87505b..ed2ec85385 100644 --- a/compiler/src/statement/return_/return_.rs +++ b/compiler/src/statement/return_/return_.rs @@ -21,13 +21,13 @@ use leo_asg::ReturnStatement; use snarkvm_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { pub fn enforce_return_statement>( &mut self, cs: &mut CS, - statement: &ReturnStatement, - ) -> Result, StatementError> { - let result = self.enforce_expression(cs, &statement.expression)?; + statement: &ReturnStatement<'a>, + ) -> Result, StatementError> { + let result = self.enforce_expression(cs, statement.expression.get())?; Ok(result) } } diff --git a/compiler/src/statement/statement.rs b/compiler/src/statement/statement.rs index 115f656b77..786b4bf088 100644 --- a/compiler/src/statement/statement.rs +++ b/compiler/src/statement/statement.rs @@ -18,7 +18,6 @@ use crate::{errors::StatementError, program::ConstrainedProgram, value::ConstrainedValue, GroupType}; use leo_asg::Statement; -use std::sync::Arc; use snarkvm_models::{ curves::PrimeField, @@ -26,9 +25,9 @@ use snarkvm_models::{ }; pub type StatementResult = Result; -pub type IndicatorAndConstrainedValue = (Boolean, ConstrainedValue); +pub type IndicatorAndConstrainedValue<'a, T, U> = (Boolean, ConstrainedValue<'a, T, U>); -impl> ConstrainedProgram { +impl<'a, F: PrimeField, G: GroupType> ConstrainedProgram<'a, F, G> { /// /// Enforce a program statement. /// Returns a Vector of (indicator, value) tuples. @@ -41,11 +40,11 @@ impl> ConstrainedProgram { &mut self, cs: &mut CS, indicator: &Boolean, - statement: &Arc, - ) -> StatementResult>> { + statement: &'a Statement<'a>, + ) -> StatementResult>> { let mut results = vec![]; - match &**statement { + match statement { Statement::Return(statement) => { let return_value = (*indicator, self.enforce_return_statement(cs, statement)?); @@ -71,7 +70,7 @@ impl> ConstrainedProgram { self.evaluate_console_function_call(cs, indicator, statement)?; } Statement::Expression(statement) => { - let value = self.enforce_expression(cs, &statement.expression)?; + let value = self.enforce_expression(cs, statement.expression.get())?; // handle empty return value cases match &value { ConstrainedValue::Tuple(values) => { diff --git a/compiler/src/value/address/address.rs b/compiler/src/value/address/address.rs index b8bec3f53c..28ccb53dfb 100644 --- a/compiler/src/value/address/address.rs +++ b/compiler/src/value/address/address.rs @@ -63,12 +63,12 @@ impl Address { self.bytes.iter().all(|byte| byte.is_constant()) } - pub(crate) fn from_input, CS: ConstraintSystem>( + pub(crate) fn from_input<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, name: &str, input_value: Option, span: &Span, - ) -> Result, AddressError> { + ) -> Result, AddressError> { // Check that the input value is the correct type let address_value = match input_value { Some(input) => { diff --git a/compiler/src/value/boolean/input.rs b/compiler/src/value/boolean/input.rs index 32119dfda3..8caa9a8e3e 100644 --- a/compiler/src/value/boolean/input.rs +++ b/compiler/src/value/boolean/input.rs @@ -41,12 +41,12 @@ pub(crate) fn allocate_bool>( .map_err(|_| BooleanError::missing_boolean(format!("{}: bool", name), span.to_owned())) } -pub(crate) fn bool_from_input, CS: ConstraintSystem>( +pub(crate) fn bool_from_input<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, name: &str, input_value: Option, span: &Span, -) -> Result, BooleanError> { +) -> Result, BooleanError> { // Check that the input value is the correct type let option = match input_value { Some(input) => { diff --git a/compiler/src/value/field/field_type.rs b/compiler/src/value/field/field_type.rs index 6b91c4bb16..1da0ea45a4 100644 --- a/compiler/src/value/field/field_type.rs +++ b/compiler/src/value/field/field_type.rs @@ -54,7 +54,22 @@ impl FieldType { } pub fn constant(string: String, span: &Span) -> Result { - let value = F::from_str(&string).map_err(|_| FieldError::invalid_field(string, span.to_owned()))?; + let first_char = string.chars().next().unwrap(); + let new_string: &str; + let value; + + // Check if first symbol is a negative. + // If so strip it, parse rest of string and then negate it. + if first_char == '-' { + new_string = string + .chars() + .next() + .map(|c| &string[c.len_utf8()..]) + .ok_or_else(|| FieldError::invalid_field(string.clone(), span.to_owned()))?; + value = -F::from_str(&new_string).map_err(|_| FieldError::invalid_field(string, span.to_owned()))?; + } else { + value = F::from_str(&string).map_err(|_| FieldError::invalid_field(string, span.to_owned()))?; + } Ok(FieldType::Constant(value)) } diff --git a/compiler/src/value/field/input.rs b/compiler/src/value/field/input.rs index 88a6f5b9fe..aa14e65a16 100644 --- a/compiler/src/value/field/input.rs +++ b/compiler/src/value/field/input.rs @@ -38,12 +38,12 @@ pub(crate) fn allocate_field>( .map_err(|_| FieldError::missing_field(format!("{}: field", name), span.to_owned())) } -pub(crate) fn field_from_input, CS: ConstraintSystem>( +pub(crate) fn field_from_input<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, name: &str, input_value: Option, span: &Span, -) -> Result, FieldError> { +) -> Result, FieldError> { // Check that the parameter value is the correct type let option = match input_value { Some(input) => { diff --git a/compiler/src/value/group/input.rs b/compiler/src/value/group/input.rs index 274d944e15..32e6c90edc 100644 --- a/compiler/src/value/group/input.rs +++ b/compiler/src/value/group/input.rs @@ -36,12 +36,12 @@ pub(crate) fn allocate_group, CS: ConstraintSyste .map_err(|_| GroupError::missing_group(format!("{}: group", name), span.to_owned())) } -pub(crate) fn group_from_input, CS: ConstraintSystem>( +pub(crate) fn group_from_input<'a, F: PrimeField, G: GroupType, CS: ConstraintSystem>( cs: &mut CS, name: &str, input_value: Option, span: &Span, -) -> Result, GroupError> { +) -> Result, GroupError> { // Check that the parameter value is the correct type let option = match input_value { Some(input) => { diff --git a/compiler/src/value/group/targets/edwards_bls12.rs b/compiler/src/value/group/targets/edwards_bls12.rs index 8ff4e4d81e..309ff3aa50 100644 --- a/compiler/src/value/group/targets/edwards_bls12.rs +++ b/compiler/src/value/group/targets/edwards_bls12.rs @@ -133,6 +133,23 @@ impl GroupType for EdwardsGroupType { } } +fn number_string_typing(number: &str, span: &Span) -> Result<(String, bool), GroupError> { + let first_char = number.chars().next().unwrap(); + + // Check if first symbol is a negative. + // If so strip it, parse rest of string and then negate it. + if first_char == '-' { + let uint = number + .chars() + .next() + .map(|c| &number[c.len_utf8()..]) + .ok_or_else(|| GroupError::invalid_group(number.to_string(), span.to_owned()))?; + Ok((uint.to_string(), true)) + } else { + Ok((number.to_string(), false)) + } +} + impl EdwardsGroupType { pub fn edwards_affine_from_value(value: &GroupValue, span: &Span) -> Result { match value { @@ -142,12 +159,19 @@ impl EdwardsGroupType { } pub fn edwards_affine_from_single(number: &str, span: &Span) -> Result { - if number.eq("0") { + let number_info = number_string_typing(number, &span.clone())?; + + if number_info.0.eq("0") { Ok(EdwardsAffine::zero()) } else { let one = edwards_affine_one(); - let number_value = - Fp256::from_str(&number).map_err(|_| GroupError::n_group(number.to_string(), span.clone()))?; + let number_value = match number_info { + (number, neg) if neg => { + -Fp256::from_str(&number).map_err(|_| GroupError::n_group(number, span.clone()))? + } + (number, _) => Fp256::from_str(&number).map_err(|_| GroupError::n_group(number, span.clone()))?, + }; + let result: EdwardsAffine = one.mul(&number_value); Ok(result) @@ -164,32 +188,42 @@ impl EdwardsGroupType { match (x, y) { // (x, y) - (GroupCoordinate::Number(x_string), GroupCoordinate::Number(y_string)) => { - Self::edwards_affine_from_pair(x_string, y_string, span, span, span) - } + (GroupCoordinate::Number(x_string), GroupCoordinate::Number(y_string)) => Self::edwards_affine_from_pair( + number_string_typing(&x_string, &span.clone())?, + number_string_typing(&y_string, &span.clone())?, + span, + span, + span, + ), // (x, +) (GroupCoordinate::Number(x_string), GroupCoordinate::SignHigh) => { - Self::edwards_affine_from_x_str(x_string, span, Some(true), span) + Self::edwards_affine_from_x_str(number_string_typing(&x_string, &span.clone())?, span, Some(true), span) } // (x, -) - (GroupCoordinate::Number(x_string), GroupCoordinate::SignLow) => { - Self::edwards_affine_from_x_str(x_string, span, Some(false), span) - } + (GroupCoordinate::Number(x_string), GroupCoordinate::SignLow) => Self::edwards_affine_from_x_str( + number_string_typing(&x_string, &span.clone())?, + span, + Some(false), + span, + ), // (x, _) (GroupCoordinate::Number(x_string), GroupCoordinate::Inferred) => { - Self::edwards_affine_from_x_str(x_string, span, None, span) + Self::edwards_affine_from_x_str(number_string_typing(&x_string, &span.clone())?, span, None, span) } // (+, y) (GroupCoordinate::SignHigh, GroupCoordinate::Number(y_string)) => { - Self::edwards_affine_from_y_str(y_string, span, Some(true), span) + Self::edwards_affine_from_y_str(number_string_typing(&y_string, &span.clone())?, span, Some(true), span) } // (-, y) - (GroupCoordinate::SignLow, GroupCoordinate::Number(y_string)) => { - Self::edwards_affine_from_y_str(y_string, span, Some(false), span) - } + (GroupCoordinate::SignLow, GroupCoordinate::Number(y_string)) => Self::edwards_affine_from_y_str( + number_string_typing(&y_string, &span.clone())?, + span, + Some(false), + span, + ), // (_, y) (GroupCoordinate::Inferred, GroupCoordinate::Number(y_string)) => { - Self::edwards_affine_from_y_str(y_string, span, None, span) + Self::edwards_affine_from_y_str(number_string_typing(&y_string, &span.clone())?, span, None, span) } // Invalid (x, y) => Err(GroupError::invalid_group(format!("({}, {})", x, y), span.clone())), @@ -197,12 +231,16 @@ impl EdwardsGroupType { } pub fn edwards_affine_from_x_str( - x_string: String, + x_info: (String, bool), x_span: &Span, greatest: Option, element_span: &Span, ) -> Result { - let x = Fq::from_str(&x_string).map_err(|_| GroupError::x_invalid(x_string, x_span.clone()))?; + let x = match x_info { + (x_str, neg) if neg => -Fq::from_str(&x_str).map_err(|_| GroupError::x_invalid(x_str, x_span.clone()))?, + (x_str, _) => Fq::from_str(&x_str).map_err(|_| GroupError::x_invalid(x_str, x_span.clone()))?, + }; + match greatest { // Sign provided Some(greatest) => { @@ -227,12 +265,15 @@ impl EdwardsGroupType { } pub fn edwards_affine_from_y_str( - y_string: String, + y_info: (String, bool), y_span: &Span, greatest: Option, element_span: &Span, ) -> Result { - let y = Fq::from_str(&y_string).map_err(|_| GroupError::y_invalid(y_string, y_span.clone()))?; + let y = match y_info { + (y_str, neg) if neg => -Fq::from_str(&y_str).map_err(|_| GroupError::y_invalid(y_str, y_span.clone()))?, + (y_str, _) => Fq::from_str(&y_str).map_err(|_| GroupError::y_invalid(y_str, y_span.clone()))?, + }; match greatest { // Sign provided @@ -258,14 +299,25 @@ impl EdwardsGroupType { } pub fn edwards_affine_from_pair( - x_string: String, - y_string: String, + x_info: (String, bool), + y_info: (String, bool), x_span: &Span, y_span: &Span, element_span: &Span, ) -> Result { - let x = Fq::from_str(&x_string).map_err(|_| GroupError::x_invalid(x_string, x_span.clone()))?; - let y = Fq::from_str(&y_string).map_err(|_| GroupError::y_invalid(y_string, y_span.clone()))?; + let x = match x_info { + (x_str, neg) if neg => { + -Fq::from_str(&x_str).map_err(|_| GroupError::x_invalid(x_str.to_string(), x_span.clone()))? + } + (x_str, _) => Fq::from_str(&x_str).map_err(|_| GroupError::x_invalid(x_str.to_string(), x_span.clone()))?, + }; + + let y = match y_info { + (y_str, neg) if neg => { + -Fq::from_str(&y_str).map_err(|_| GroupError::y_invalid(y_str.to_string(), y_span.clone()))? + } + (y_str, _) => Fq::from_str(&y_str).map_err(|_| GroupError::y_invalid(y_str.to_string(), y_span.clone()))?, + }; let element = EdwardsAffine::new(x, y); diff --git a/compiler/src/value/value.rs b/compiler/src/value/value.rs index 6fec55729d..873440ad79 100644 --- a/compiler/src/value/value.rs +++ b/compiler/src/value/value.rs @@ -17,7 +17,7 @@ //! The in memory stored value for a defined name in a compiled Leo program. use crate::{errors::ValueError, Address, FieldType, GroupType, Integer}; -use leo_asg::{CircuitBody, Identifier, Span, Type}; +use leo_asg::{Circuit, Identifier, Span, Type}; use snarkvm_errors::gadgets::SynthesisError; use snarkvm_models::{ @@ -27,13 +27,13 @@ use snarkvm_models::{ utilities::{boolean::Boolean, eq::ConditionalEqGadget, select::CondSelectGadget}, }, }; -use std::{fmt, sync::Arc}; +use std::fmt; #[derive(Clone, PartialEq, Eq)] -pub struct ConstrainedCircuitMember>(pub Identifier, pub ConstrainedValue); +pub struct ConstrainedCircuitMember<'a, F: PrimeField, G: GroupType>(pub Identifier, pub ConstrainedValue<'a, F, G>); #[derive(Clone, PartialEq, Eq)] -pub enum ConstrainedValue> { +pub enum ConstrainedValue<'a, F: PrimeField, G: GroupType> { // Data types Address(Address), Boolean(Boolean), @@ -42,17 +42,17 @@ pub enum ConstrainedValue> { Integer(Integer), // Arrays - Array(Vec>), + Array(Vec>), // Tuples - Tuple(Vec>), + Tuple(Vec>), // Circuits - CircuitExpression(Arc, Vec>), + CircuitExpression(&'a Circuit<'a>, Vec>), } -impl> ConstrainedValue { - pub(crate) fn to_type(&self, span: &Span) -> Result { +impl<'a, F: PrimeField, G: GroupType> ConstrainedValue<'a, F, G> { + pub(crate) fn to_type(&self, span: &Span) -> Result, ValueError> { Ok(match self { // Data types ConstrainedValue::Address(_address) => Type::Address, @@ -77,12 +77,12 @@ impl> ConstrainedValue { Type::Tuple(types) } - ConstrainedValue::CircuitExpression(id, _members) => Type::Circuit(id.circuit.clone()), + ConstrainedValue::CircuitExpression(id, _members) => Type::Circuit(*id), }) } } -impl> fmt::Display for ConstrainedValue { +impl<'a, F: PrimeField, G: GroupType> fmt::Display for ConstrainedValue<'a, F, G> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { // Data types @@ -116,7 +116,7 @@ impl> fmt::Display for ConstrainedValue { write!(f, "({})", values) } ConstrainedValue::CircuitExpression(ref circuit, ref members) => { - write!(f, "{} {{", circuit.circuit.name.borrow())?; + write!(f, "{} {{", circuit.name.borrow())?; for (i, member) in members.iter().enumerate() { write!(f, "{}: {}", member.0, member.1)?; if i < members.len() - 1 { @@ -129,13 +129,13 @@ impl> fmt::Display for ConstrainedValue { } } -impl> fmt::Debug for ConstrainedValue { +impl<'a, F: PrimeField, G: GroupType> fmt::Debug for ConstrainedValue<'a, F, G> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) } } -impl> ConditionalEqGadget for ConstrainedValue { +impl<'a, F: PrimeField, G: GroupType> ConditionalEqGadget for ConstrainedValue<'a, F, G> { fn conditional_enforce_equal>( &self, mut cs: CS, @@ -179,7 +179,7 @@ impl> ConditionalEqGadget for ConstrainedValue } } -impl> CondSelectGadget for ConstrainedValue { +impl<'a, F: PrimeField, G: GroupType> CondSelectGadget for ConstrainedValue<'a, F, G> { fn conditionally_select>( mut cs: CS, cond: &Boolean, @@ -245,7 +245,7 @@ impl> CondSelectGadget for ConstrainedValue return Err(SynthesisError::Unsatisfiable), }) @@ -256,7 +256,7 @@ impl> CondSelectGadget for ConstrainedValue> CondSelectGadget for ConstrainedCircuitMember { +impl<'a, F: PrimeField, G: GroupType> CondSelectGadget for ConstrainedCircuitMember<'a, F, G> { fn conditionally_select>( cs: CS, cond: &Boolean, diff --git a/compiler/tests/compiler/mod.rs b/compiler/tests/compiler/mod.rs index ec07bfa672..d07b11b69a 100644 --- a/compiler/tests/compiler/mod.rs +++ b/compiler/tests/compiler/mod.rs @@ -27,7 +27,8 @@ static MAIN_FILE_NAME: &str = "tests/compiler/main.leo"; fn test_parse_program_from_string() { // Parse program from string with compiler. let program_string = include_str!("main.leo"); - let mut compiler_no_path = EdwardsTestCompiler::new("".to_string(), PathBuf::new(), PathBuf::new()); + let context = crate::make_test_context(); + let mut compiler_no_path = EdwardsTestCompiler::new("".to_string(), PathBuf::new(), PathBuf::new(), context); compiler_no_path.parse_program_from_string(program_string).unwrap(); @@ -36,7 +37,7 @@ fn test_parse_program_from_string() { local.push(MAIN_FILE_NAME); let compiler_with_path = - EdwardsTestCompiler::parse_program_without_input("".to_string(), local, PathBuf::new()).unwrap(); + EdwardsTestCompiler::parse_program_without_input("".to_string(), local, PathBuf::new(), context).unwrap(); // Compare output bytes. let expected_output = get_output(compiler_no_path); diff --git a/compiler/tests/field/field.leo b/compiler/tests/field/field.leo new file mode 100644 index 0000000000..3f109881ad --- /dev/null +++ b/compiler/tests/field/field.leo @@ -0,0 +1,4 @@ +function main() { + let negOneField: field = -1field; + let oneField = 1field; +} \ No newline at end of file diff --git a/compiler/tests/field/mod.rs b/compiler/tests/field/mod.rs index f8f97fe5c6..2a768450d6 100644 --- a/compiler/tests/field/mod.rs +++ b/compiler/tests/field/mod.rs @@ -66,6 +66,14 @@ fn test_negate() { } } +#[test] +fn test_field() { + let program_string = include_str!("field.leo"); + let mut program = parse_program(program_string).unwrap(); + + assert_satisfied(program) +} + #[test] fn test_add() { use std::ops::Add; diff --git a/compiler/tests/group/mod.rs b/compiler/tests/group/mod.rs index 65680f0b14..d7fca2a4cd 100644 --- a/compiler/tests/group/mod.rs +++ b/compiler/tests/group/mod.rs @@ -388,3 +388,12 @@ fn test_ternary() { assert_satisfied(program); } + +#[test] +fn test_positive_and_negative() { + let program_string = include_str!("positive_and_negative.leo"); + + let program = parse_program(program_string).unwrap(); + + assert_satisfied(program); +} diff --git a/compiler/tests/group/positive_and_negative.leo b/compiler/tests/group/positive_and_negative.leo new file mode 100644 index 0000000000..fdf9b892be --- /dev/null +++ b/compiler/tests/group/positive_and_negative.leo @@ -0,0 +1,10 @@ +function main() { + let pos_element = 1group; + let neg_element = -1group; + + let pair_x_pos = (1, _)group; + let pair_x_neg = (-1, _)group; + + let pair_y_pos = (_, 1)group; + let pair_y_neg = (_, -1)group; +} \ No newline at end of file diff --git a/compiler/tests/input_files/program_input/input/main_field.in b/compiler/tests/input_files/program_input/input/main_field.in new file mode 100644 index 0000000000..c3512fa1c4 --- /dev/null +++ b/compiler/tests/input_files/program_input/input/main_field.in @@ -0,0 +1,2 @@ +[main] +a: field = 1; \ No newline at end of file diff --git a/compiler/tests/input_files/program_input/main_field.leo b/compiler/tests/input_files/program_input/main_field.leo new file mode 100644 index 0000000000..2fbcbd2b90 --- /dev/null +++ b/compiler/tests/input_files/program_input/main_field.leo @@ -0,0 +1,4 @@ +function main(a: field) { + // Change to assert when == is implemented for field. + console.log("a: {}", a); +} \ No newline at end of file diff --git a/compiler/tests/input_files/program_input/mod.rs b/compiler/tests/input_files/program_input/mod.rs index 58efad93c1..fc0ca80527 100644 --- a/compiler/tests/input_files/program_input/mod.rs +++ b/compiler/tests/input_files/program_input/mod.rs @@ -93,3 +93,13 @@ fn test_input_array_dimensions_mismatch() { expect_fail(program); } + +#[test] +fn test_field_input() { + let program_string = include_str!("main_field.leo"); + let input_string = include_str!("input/main_field.in"); + + let program = parse_program_with_input(program_string, input_string).unwrap(); + + assert_satisfied(program); +} diff --git a/compiler/tests/mod.rs b/compiler/tests/mod.rs index d1eee03f8b..23d59c045c 100644 --- a/compiler/tests/mod.rs +++ b/compiler/tests/mod.rs @@ -36,6 +36,7 @@ pub mod statements; pub mod syntax; pub mod tuples; +use leo_asg::{new_context, AsgContext}; use leo_ast::{InputValue, MainInput}; use leo_compiler::{ compiler::Compiler, @@ -54,18 +55,23 @@ use std::path::PathBuf; pub const TEST_OUTPUT_DIRECTORY: &str = "/output/"; const EMPTY_FILE: &str = ""; -pub type EdwardsTestCompiler = Compiler; -pub type EdwardsConstrainedValue = ConstrainedValue; +pub type EdwardsTestCompiler = Compiler<'static, Fq, EdwardsGroupType>; +pub type EdwardsConstrainedValue = ConstrainedValue<'static, Fq, EdwardsGroupType>; + +//convenience function for tests, leaks memory +pub(crate) fn make_test_context() -> AsgContext<'static> { + Box::leak(Box::new(new_context())) +} fn new_compiler() -> EdwardsTestCompiler { let program_name = "test".to_string(); let path = PathBuf::from("/test/src/main.leo"); let output_dir = PathBuf::from(TEST_OUTPUT_DIRECTORY); - EdwardsTestCompiler::new(program_name, path, output_dir) + EdwardsTestCompiler::new(program_name, path, output_dir, make_test_context()) } -pub(crate) fn parse_program(program_string: &str) -> Result { +pub(crate) fn parse_program<'a>(program_string: &str) -> Result { let mut compiler = new_compiler(); compiler.parse_program_from_string(program_string)?; @@ -73,7 +79,7 @@ pub(crate) fn parse_program(program_string: &str) -> Result Result { +pub(crate) fn parse_input<'a>(input_string: &str) -> Result { let mut compiler = new_compiler(); let path = PathBuf::new(); @@ -82,7 +88,7 @@ pub(crate) fn parse_input(input_string: &str) -> Result Result { +pub(crate) fn parse_state<'a>(state_string: &str) -> Result { let mut compiler = new_compiler(); let path = PathBuf::new(); @@ -91,7 +97,7 @@ pub(crate) fn parse_state(state_string: &str) -> Result( input_string: &str, state_string: &str, ) -> Result { @@ -103,7 +109,7 @@ pub(crate) fn parse_input_and_state( Ok(compiler) } -pub fn parse_program_with_input( +pub fn parse_program_with_input<'a>( program_string: &str, input_string: &str, ) -> Result { @@ -116,7 +122,7 @@ pub fn parse_program_with_input( Ok(compiler) } -pub fn parse_program_with_state( +pub fn parse_program_with_state<'a>( program_string: &str, state_string: &str, ) -> Result { diff --git a/imports/src/parser/import_parser.rs b/imports/src/parser/import_parser.rs index 2e4de8180a..b8376265bb 100644 --- a/imports/src/parser/import_parser.rs +++ b/imports/src/parser/import_parser.rs @@ -15,7 +15,7 @@ // along with the Leo library. If not, see . use crate::errors::ImportParserError; -use leo_asg::{AsgConvertError, ImportResolver, Program, Span}; +use leo_asg::{AsgContext, AsgConvertError, ImportResolver, Program, Span}; use indexmap::{IndexMap, IndexSet}; use std::env::current_dir; @@ -25,14 +25,19 @@ use std::env::current_dir; /// A program can import one or more packages. A package can be found locally in the source /// directory, foreign in the imports directory, or part of the core package list. #[derive(Clone, Default)] -pub struct ImportParser { +pub struct ImportParser<'a> { partial_imports: IndexSet, - imports: IndexMap, + imports: IndexMap>, } //todo: handle relative imports relative to file... -impl ImportResolver for ImportParser { - fn resolve_package(&mut self, package_segments: &[&str], span: &Span) -> Result, AsgConvertError> { +impl<'a> ImportResolver<'a> for ImportParser<'a> { + fn resolve_package( + &mut self, + ctx: AsgContext<'a>, + package_segments: &[&str], + span: &Span, + ) -> Result>, AsgConvertError> { let full_path = package_segments.join("."); if self.partial_imports.contains(&full_path) { return Err(ImportParserError::recursive_imports(&full_path, span).into()); @@ -46,7 +51,7 @@ impl ImportResolver for ImportParser { self.partial_imports.insert(full_path.clone()); let program = imports - .parse_package(path, package_segments, span) + .parse_package(ctx, path, package_segments, span) .map_err(|x| -> AsgConvertError { x.into() })?; self.partial_imports.remove(&full_path); self.imports.insert(full_path, program.clone()); diff --git a/imports/src/parser/parse_package.rs b/imports/src/parser/parse_package.rs index d6aa4706b5..38e7025910 100644 --- a/imports/src/parser/parse_package.rs +++ b/imports/src/parser/parse_package.rs @@ -15,7 +15,7 @@ // along with the Leo library. If not, see . use crate::{errors::ImportParserError, ImportParser}; -use leo_asg::{Identifier, Program, Span}; +use leo_asg::{AsgContext, Identifier, Program, Span}; use std::{fs, fs::DirEntry, path::PathBuf}; @@ -23,18 +23,19 @@ static SOURCE_FILE_EXTENSION: &str = ".leo"; static SOURCE_DIRECTORY_NAME: &str = "src/"; static IMPORTS_DIRECTORY_NAME: &str = "imports/"; -impl ImportParser { +impl<'a> ImportParser<'a> { fn parse_package_access( &mut self, + ctx: AsgContext<'a>, package: &DirEntry, remaining_segments: &[&str], span: &Span, - ) -> Result { + ) -> Result, ImportParserError> { if !remaining_segments.is_empty() { - return self.parse_package(package.path(), remaining_segments, span); + return self.parse_package(ctx, package.path(), remaining_segments, span); } let program = Self::parse_import_file(package, span)?; - let asg = leo_asg::InternalProgram::new(&program, self)?; + let asg = leo_asg::InternalProgram::new(ctx, &program, self)?; Ok(asg) } @@ -46,10 +47,11 @@ impl ImportParser { /// pub(crate) fn parse_package( &mut self, + ctx: AsgContext<'a>, mut path: PathBuf, segments: &[&str], span: &Span, - ) -> Result { + ) -> Result, ImportParserError> { let error_path = path.clone(); let package_name = segments[0]; @@ -111,8 +113,8 @@ impl ImportParser { package_name, span, ))), - (Some(source_entry), None) => self.parse_package_access(&source_entry, &segments[1..], span), - (None, Some(import_entry)) => self.parse_package_access(&import_entry, &segments[1..], span), + (Some(source_entry), None) => self.parse_package_access(ctx, &source_entry, &segments[1..], span), + (None, Some(import_entry)) => self.parse_package_access(ctx, &import_entry, &segments[1..], span), (None, None) => Err(ImportParserError::unknown_package(Identifier::new_with_span( package_name, span, @@ -121,7 +123,7 @@ impl ImportParser { } else { // Enforce local package access with no found imports directory match matched_source_entry { - Some(source_entry) => self.parse_package_access(&source_entry, &segments[1..], span), + Some(source_entry) => self.parse_package_access(ctx, &source_entry, &segments[1..], span), None => Err(ImportParserError::unknown_package(Identifier::new_with_span( package_name, span, diff --git a/imports/src/parser/parse_symbol.rs b/imports/src/parser/parse_symbol.rs index 0b0a5b2c7d..9b02656d75 100644 --- a/imports/src/parser/parse_symbol.rs +++ b/imports/src/parser/parse_symbol.rs @@ -22,7 +22,7 @@ use std::fs::DirEntry; static LIBRARY_FILE: &str = "src/lib.leo"; -impl ImportParser { +impl<'a> ImportParser<'a> { /// /// Returns a Leo syntax tree from a given package. /// diff --git a/leo/commands/build.rs b/leo/commands/build.rs index ac22fdba52..be11d31d4b 100644 --- a/leo/commands/build.rs +++ b/leo/commands/build.rs @@ -19,7 +19,10 @@ use crate::{ context::Context, synthesizer::{CircuitSynthesizer, SerializedCircuit}, }; -use leo_compiler::{compiler::Compiler, group::targets::edwards_bls12::EdwardsGroupType}; +use leo_compiler::{ + compiler::{thread_leaked_context, Compiler}, + group::targets::edwards_bls12::EdwardsGroupType, +}; use leo_package::{ inputs::*, outputs::{ChecksumFile, CircuitFile, OutputsDirectory, OUTPUTS_DIRECTORY_NAME}, @@ -45,7 +48,7 @@ impl Build { impl Command for Build { type Input = (); - type Output = Option<(Compiler, bool)>; + type Output = Option<(Compiler<'static, Fq, EdwardsGroupType>, bool)>; fn log_span(&self) -> Span { tracing::span!(tracing::Level::INFO, "Build") @@ -86,6 +89,7 @@ impl Command for Build { package_name.clone(), lib_file_path, output_directory.clone(), + thread_leaked_context(), )?; tracing::info!("Complete"); }; @@ -118,6 +122,7 @@ impl Command for Build { &input_path, &state_string, &state_path, + thread_leaked_context(), )?; // Compute the current program checksum diff --git a/leo/commands/package/publish.rs b/leo/commands/package/publish.rs index 7b1fceb3a6..97bce33719 100644 --- a/leo/commands/package/publish.rs +++ b/leo/commands/package/publish.rs @@ -119,14 +119,17 @@ impl Command for Publish { .send(); // Get a response result - let result = match response { - Ok(json_result) => match json_result.json::() { - Ok(json) => json, - Err(error) => { - tracing::warn!("{:?}", error); - return Err(anyhow!("Package not published")); + let result: ResponseJson = match response { + Ok(json_result) => { + let text = json_result.text()?; + + match serde_json::from_str(&text) { + Ok(json) => json, + Err(_) => { + return Err(anyhow!("Package not published: {}", text)); + } } - }, + } Err(error) => { tracing::warn!("{:?}", error); return Err(anyhow!("Connection unavailable")); diff --git a/leo/commands/setup.rs b/leo/commands/setup.rs index 550934f1e2..fafb8467a7 100644 --- a/leo/commands/setup.rs +++ b/leo/commands/setup.rs @@ -47,7 +47,7 @@ impl Setup { impl Command for Setup { type Input = ::Output; type Output = ( - Compiler, + Compiler<'static, Fr, EdwardsGroupType>, Parameters, PreparedVerifyingKey, ); diff --git a/leo/commands/test.rs b/leo/commands/test.rs index f9bf89a003..be683a41c6 100644 --- a/leo/commands/test.rs +++ b/leo/commands/test.rs @@ -15,7 +15,10 @@ // along with the Leo library. If not, see . use crate::{commands::Command, context::Context}; -use leo_compiler::{compiler::Compiler, group::targets::edwards_bls12::EdwardsGroupType}; +use leo_compiler::{ + compiler::{thread_leaked_context, Compiler}, + group::targets::edwards_bls12::EdwardsGroupType, +}; use leo_package::{ inputs::*, outputs::{OutputsDirectory, OUTPUTS_DIRECTORY_NAME}, @@ -109,6 +112,7 @@ impl Command for Test { package_name.clone(), file_path, output_directory.clone(), + thread_leaked_context(), )?; let temporary_program = program; diff --git a/package/src/inputs/pairs.rs b/package/src/inputs/pairs.rs index 9ffcca0df4..0d21e7c7f4 100644 --- a/package/src/inputs/pairs.rs +++ b/package/src/inputs/pairs.rs @@ -27,6 +27,7 @@ pub struct InputPairs { pub pairs: HashMap, } +#[derive(Debug)] pub struct InputPair { pub input_file: String, pub state_file: String, @@ -47,9 +48,12 @@ impl TryFrom<&Path> for InputPairs { let mut pairs = HashMap::::new(); for file in files { - let file_extension = file - .extension() - .ok_or_else(|| InputsDirectoryError::GettingFileExtension(file.as_os_str().to_owned()))?; + // if file name starts with . (dot) None is returned - we're + // skipping these files intentionally but not exiting + let file_extension = match file.extension() { + Some(extension) => extension, + None => continue, + }; let file_name = file .file_stem() @@ -84,10 +88,8 @@ impl TryFrom<&Path> for InputPairs { pairs.insert(file_name.to_owned(), pair); } } else { - return Err(InputsDirectoryError::InvalidFileExtension( - file_name.to_owned(), - file_extension.to_owned(), - )); + // kept for verbosity, can be removed + continue; } } diff --git a/package/src/root/zip.rs b/package/src/root/zip.rs index 76f82c1aee..ee09917dd0 100644 --- a/package/src/root/zip.rs +++ b/package/src/root/zip.rs @@ -19,7 +19,7 @@ use crate::{ errors::ZipFileError, imports::IMPORTS_DIRECTORY_NAME, - inputs::{INPUTS_DIRECTORY_NAME, INPUT_FILE_EXTENSION}, + inputs::{INPUTS_DIRECTORY_NAME, INPUT_FILE_EXTENSION, STATE_FILE_EXTENSION}, outputs::{ CHECKSUM_FILE_EXTENSION, CIRCUIT_FILE_EXTENSION, @@ -151,9 +151,8 @@ impl ZipFile { /// Check if the file path should be included in the package zip file. fn is_included(path: &Path) -> bool { - // excluded directories: `input`, `output`, `imports` - if path.ends_with(INPUTS_DIRECTORY_NAME.trim_end_matches('/')) - | path.ends_with(OUTPUTS_DIRECTORY_NAME.trim_end_matches('/')) + // excluded directories: `output`, `imports` + if path.ends_with(OUTPUTS_DIRECTORY_NAME.trim_end_matches('/')) | path.ends_with(IMPORTS_DIRECTORY_NAME.trim_end_matches('/')) { return false; @@ -161,8 +160,7 @@ fn is_included(path: &Path) -> bool { // excluded extensions: `.in`, `.bytes`, `lpk`, `lvk`, `.proof`, `.sum`, `.zip`, `.bytes` if let Some(true) = path.extension().map(|ext| { - ext.eq(INPUT_FILE_EXTENSION.trim_start_matches('.')) - | ext.eq(ZIP_FILE_EXTENSION.trim_start_matches('.')) + ext.eq(ZIP_FILE_EXTENSION.trim_start_matches('.')) | ext.eq(PROVING_KEY_FILE_EXTENSION.trim_start_matches('.')) | ext.eq(VERIFICATION_KEY_FILE_EXTENSION.trim_start_matches('.')) | ext.eq(PROOF_FILE_EXTENSION.trim_start_matches('.')) @@ -173,6 +171,18 @@ fn is_included(path: &Path) -> bool { return false; } + // Allow `inputs` folder + if path.ends_with(INPUTS_DIRECTORY_NAME.trim_end_matches('/')) { + return true; + } + + // Allow `.state` and `.in` files + if let Some(true) = path.extension().map(|ext| { + ext.eq(INPUT_FILE_EXTENSION.trim_start_matches('.')) | ext.eq(STATE_FILE_EXTENSION.trim_start_matches('.')) + }) { + return true; + } + // Allow the README.md and Leo.toml files in the root directory if (path.ends_with(README_FILENAME) | path.ends_with(MANIFEST_FILENAME)) & (path.parent() == Some(Path::new(""))) { return true;