add circuit variable access and type check

This commit is contained in:
collin 2022-06-22 15:14:00 -10:00
parent a79196c945
commit 0c89c1b5fb
11 changed files with 117 additions and 55 deletions

View File

@ -14,27 +14,27 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{Expression, Identifier, Node};
use crate::{Identifier, Node, Type};
use leo_span::Span;
use serde::{Deserialize, Serialize};
use std::fmt;
/// An access expression to an associated member variable., e.g. `u8::MAX`.
/// An access expression to an circuit constant., e.g. `u8::MAX`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AssociatedVariableAccess {
pub struct AssociatedConstant {
/// The inner circuit type.
pub inner: Box<Expression>,
/// The static circuit member variable that is being accessed.
pub ty: Type,
/// The circuit constant that is being accessed.
pub name: Identifier,
/// The span for the entire expression `Foo::bar()`.
pub span: Span,
}
impl fmt::Display for AssociatedVariableAccess {
impl fmt::Display for AssociatedConstant {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}::{}", self.inner, self.name)
write!(f, "{}::{}", self.ty, self.name)
}
}
crate::simple_node_impl!(AssociatedVariableAccess);
crate::simple_node_impl!(AssociatedConstant);

View File

@ -14,7 +14,7 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{Expression, Identifier, Node};
use crate::{Expression, Identifier, Node, Type};
use leo_span::Span;
use serde::{Deserialize, Serialize};
@ -22,9 +22,9 @@ use std::fmt;
/// An access expression to an associated function in a circuit, e.g.`Pedersen64::hash()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AssociatedFunctionCall {
pub struct AssociatedFunction {
/// The inner circuit type.
pub inner: Box<Expression>,
pub ty: Type,
/// The static circuit member function that is being accessed.
pub name: Identifier,
/// The arguments passed to the function `name`.
@ -33,10 +33,10 @@ pub struct AssociatedFunctionCall {
pub span: Span,
}
impl fmt::Display for AssociatedFunctionCall {
impl fmt::Display for AssociatedFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}::{}", self.inner, self.name)
write!(f, "{}::{}", self.ty, self.name)
}
}
crate::simple_node_impl!(AssociatedFunctionCall);
crate::simple_node_impl!(AssociatedFunction);

View File

@ -20,5 +20,5 @@ pub use member_access::*;
mod associated_function_access;
pub use associated_function_access::*;
mod associated_variable_access;
pub use associated_variable_access::*;
mod associated_constant_access;
pub use associated_constant_access::*;

View File

@ -24,7 +24,7 @@ use std::fmt;
/// A member of a circuit definition.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CircuitMember {
// CAUTION: circuit functions are unstable for Leo testnet3.
// CAUTION: circuit constants are unstable for Leo testnet3.
// /// A static constant in a circuit.
// /// For example: `const foobar: u8 = 42;`.
// CircuitConst(

View File

@ -31,17 +31,17 @@ pub enum AccessExpression {
Member(MemberAccess),
// /// Access to a tuple field using its position, e.g., `tuple.1`.
// Tuple(TupleAccess),
/// Access to an associated variable of a circuit.
AssociatedVariable(AssociatedVariableAccess),
/// Access to an associated function of a circuit.
AssociatedFunction(AssociatedFunctionCall),
/// Access to an associated variable of a circuit e.g `u8::MAX`.
AssociatedConstant(AssociatedConstant),
/// Access to an associated function of a circuit e.g `Pedersen64::hash()`.
AssociatedFunction(AssociatedFunction),
}
impl Node for AccessExpression {
fn span(&self) -> Span {
match self {
AccessExpression::Member(n) => n.span(),
AccessExpression::AssociatedVariable(n) => n.span(),
AccessExpression::AssociatedConstant(n) => n.span(),
AccessExpression::AssociatedFunction(n) => n.span(),
}
}
@ -49,7 +49,7 @@ impl Node for AccessExpression {
fn set_span(&mut self, span: Span) {
match self {
AccessExpression::Member(n) => n.set_span(span),
AccessExpression::AssociatedVariable(n) => n.set_span(span),
AccessExpression::AssociatedConstant(n) => n.set_span(span),
AccessExpression::AssociatedFunction(n) => n.set_span(span),
}
}
@ -64,7 +64,7 @@ impl fmt::Display for AccessExpression {
// ArrayRange(access) => access.fmt(f),
Member(access) => access.fmt(f),
// Tuple(access) => access.fmt(f),
AssociatedVariable(access) => access.fmt(f),
AssociatedConstant(access) => access.fmt(f),
AssociatedFunction(access) => access.fmt(f),
}
}

View File

@ -67,13 +67,12 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_access(input) {
match input {
AccessExpression::Member(member) => self.visit_expression(&member.inner, additional),
AccessExpression::AssociatedVariable(member) => self.visit_expression(&member.inner, additional),
AccessExpression::Member(member) => return self.visit_expression(&member.inner, additional),
AccessExpression::AssociatedConstant(_member) => {},
AccessExpression::AssociatedFunction(member) => {
member.args.iter().for_each(|member| {
self.visit_expression(member, additional);
member.args.iter().for_each(|expr| {
self.visit_expression(expr, additional);
});
self.visit_expression(&member.inner, additional)
}
};
}

View File

@ -258,10 +258,7 @@ impl ParserContext<'_> {
/// Returns an [`Expression`] AST node if the next tokens represent a
/// method call expression.
fn parse_method_call_expression(&mut self, receiver: Expression) -> Result<Expression> {
// Parse the method name.
let method = self.expect_ident()?;
fn parse_method_call_expression(&mut self, receiver: Expression, method: Identifier) -> Result<Expression> {
// Parse the argument list.
let (mut args, _, span) = self.parse_expr_tuple()?;
let span = receiver.span() + span;
@ -283,7 +280,6 @@ impl ParserContext<'_> {
}))
} else {
// Either an invalid unary/binary operator, or more arguments given.
// todo: add circuit member access
self.emit_err(ParserError::expr_arbitrary_method_call(span));
Ok(Expression::Err(ErrExpression { span }))
}
@ -291,7 +287,14 @@ impl ParserContext<'_> {
/// Returns an [`Expression`] AST node if the next tokens represent a
/// static access expression.
fn parse_static_access_expression(&mut self, circuit_name: Expression) -> Result<Expression> {
fn parse_associated_access_expression(&mut self, circuit_name: Expression) -> Result<Expression> {
// Parse circuit name expression into circuit type.
let circuit_type = if let Expression::Identifier(ident) = circuit_name {
Type::Identifier(ident)
} else {
return Err(ParserError::invalid_associated_access(&circuit_name, circuit_name.span()).into());
};
// Parse the circuit member name (can be variable or function name).
let member_name = self.expect_ident()?;
@ -300,18 +303,18 @@ impl ParserContext<'_> {
// Parse the arguments
let (args, _, end) = self.parse_expr_tuple()?;
// Return the static function access expression.
AccessExpression::AssociatedFunction(AssociatedFunctionCall {
// Return the circuit function.
AccessExpression::AssociatedFunction(AssociatedFunction {
span: circuit_name.span() + end,
inner: Box::new(circuit_name),
ty: circuit_type,
name: member_name,
args,
})
} else {
// Return the static variable access expression.
AccessExpression::AssociatedVariable(AssociatedVariableAccess {
// Return the circuit constant.
AccessExpression::AssociatedConstant(AssociatedConstant {
span: circuit_name.span() + member_name.span(),
inner: Box::new(circuit_name),
ty: circuit_type,
name: member_name,
})
}))
@ -333,10 +336,23 @@ impl ParserContext<'_> {
let mut expr = self.parse_primary_expression()?;
loop {
if self.eat(&Token::Dot) {
// Eat a method call on a type
expr = self.parse_method_call_expression(expr)?
// Parse the method name.
let name = self.expect_ident()?;
if self.check(&Token::LeftParen) {
// Eat a method call on a type
expr = self.parse_method_call_expression(expr, name)?
} else {
// Eat a circuit member access.
expr = Expression::Access(AccessExpression::Member(MemberAccess {
span: expr.span(),
inner: Box::new(expr),
name,
}))
}
} else if self.eat(&Token::DoubleColon) {
expr = self.parse_static_access_expression(expr)?;
// Eat a core circuit constant or core circuit function call.
expr = self.parse_associated_access_expression(expr)?;
} else if self.check(&Token::LeftParen) {
// Parse a function call that's by itself.
let (arguments, _, span) = self.parse_paren_comma_list(|p| p.parse_expression().map(Some))?;

View File

@ -209,10 +209,29 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
// CAUTION: This implementation only allows access to core circuits.
if let VisitResult::VisitChildren = self.visitor.visit_access(input) {
match input {
AccessExpression::Member(access) => {
// Lookup circuit type.
if let Some(circuit) = self.visitor.symbol_table.lookup_circuit(&access.name.name) {
// Lookup circuit variable.
if let Some(member) = circuit.members.iter().find(|member| member.name() == access.name.name) {
match member {
CircuitMember::CircuitVariable(_ident, type_) => { return Some(type_.clone()) }
}
} else {
self.visitor.handler.emit_err(
TypeCheckerError::invalid_circuit_variable(&access.name, &access.inner, access.span())
.into(),
);
}
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_circuit(&access.inner, access.span()).into());
}
}
AccessExpression::AssociatedFunction(access) => {
// Visit core circuit function
let circuit = self.visit_expression(&access.inner, &None);
if let Some(core_instruction) = self.visitor.assert_core_circuit_call(&circuit, &access.name) {
// Check core circuit name and function.
if let Some(core_instruction) = self.visitor.assert_core_circuit_call(&access.ty, &access.name) {
// Check num input arguments.
if core_instruction.num_args() != access.args.len() {
self.visitor.handler.emit_err(
@ -251,12 +270,17 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
expected,
access.span(),
));
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_access_expression(access, access.span()).into());
}
}
expr => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_access_expression(expr, expr.span()).into()),
// todo: Add support for associated constants (u8::MAX).
}
}
None

View File

@ -113,12 +113,8 @@ impl<'a> TypeChecker<'a> {
/// Emits an error if the `circuit` is not a core library circuit.
/// Emits an error if the `function` is not supported by the circuit.
pub(crate) fn assert_core_circuit_call(
&self,
circuit: &Option<Type>,
function: &Identifier,
) -> Option<CoreInstruction> {
if let Some(Type::Identifier(ident)) = circuit {
pub(crate) fn assert_core_circuit_call(&self, circuit: &Type, function: &Identifier) -> Option<CoreInstruction> {
if let Type::Identifier(ident) = circuit {
// Lookup core circuit
match CoreInstruction::from_symbols(ident.name, function.name) {
None => {

View File

@ -379,7 +379,7 @@ create_messages!(
@formatted
circuit_functions_unstable {
args: (),
msg: "Circuit functions are currently an unstable feature and are disabled in Leo for testnet3",
msg: "Circuit functions are currently an unstable feature and are disabled in Leo for testnet3.",
help: None,
}
@ -387,7 +387,14 @@ create_messages!(
@formatted
circuit_constants_unstable {
args: (),
msg: "Circuit constants are currently an unstable feature and are disabled in Leo for testnet3",
msg: "Circuit constants are currently an unstable feature and are disabled in Leo for testnet3.",
help: None,
}
@formatted
invalid_associated_access {
args: (name: impl Display),
msg: format!("Invalid associated access call to circuit {name}."),
help: Some("Double colon `::` syntax is only supported for core circuits in Leo for testnet3.".to_string()),
}
);

View File

@ -202,4 +202,24 @@ create_messages!(
),
help: None,
}
/// Attempted to access an invalid circuit.
@formatted
invalid_circuit {
args: (circuit: impl Display),
msg: format!(
"Circuit {circuit} is not found in the current scope."
),
help: None,
}
/// Attempted to access an invalid circuit variable.
@formatted
invalid_circuit_variable {
args: (variable: impl Display, circuit: impl Display),
msg: format!(
"Circuit variable {variable} is not a member of circuit {circuit}."
),
help: None,
}
);