diff --git a/dust-lang/src/analyzer.rs b/dust-lang/src/analyzer.rs index 6af3ca0..fda1b53 100644 --- a/dust-lang/src/analyzer.rs +++ b/dust-lang/src/analyzer.rs @@ -13,9 +13,10 @@ use crate::{ ast::{ AbstractSyntaxTree, BlockExpression, CallExpression, ElseExpression, FieldAccessExpression, IfExpression, LetStatement, ListExpression, ListIndexExpression, LoopExpression, Node, - OperatorExpression, RangeExpression, Statement, StructExpression, TupleAccessExpression, + OperatorExpression, RangeExpression, Span, Statement, StructExpression, + TupleAccessExpression, }, - parse, Context, DustError, Expression, Identifier, Span, Type, + parse, Context, DustError, Expression, Identifier, Type, }; /// Analyzes the abstract syntax tree for errors. diff --git a/dust-lang/src/ast/expression.rs b/dust-lang/src/ast/expression.rs index 835e2ee..f5b352a 100644 --- a/dust-lang/src/ast/expression.rs +++ b/dust-lang/src/ast/expression.rs @@ -1,13 +1,14 @@ use std::{ cmp::Ordering, + collections::HashMap, fmt::{self, Display, Formatter}, }; use serde::{Deserialize, Serialize}; -use crate::{Context, FunctionType, Identifier, Span, StructType, Type, Value}; +use crate::{Context, FieldsStructType, FunctionType, Identifier, StructType, TupleType, Type}; -use super::{Node, Statement}; +use super::{Node, Span, Statement}; #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Expression { @@ -221,6 +222,13 @@ impl Expression { Self::Literal(Node::new(Box::new(literal), position)) } + pub fn has_block(&self) -> bool { + matches!( + self, + Expression::Block(_) | Expression::If(_) | Expression::Loop(_) + ) + } + pub fn as_identifier(&self) -> Option<&Identifier> { if let Expression::Identifier(identifier) = self { Some(&identifier.inner) @@ -253,7 +261,9 @@ impl Expression { let container_type = container.return_type(context)?; - if let Type::Struct(StructType::Fields { fields, .. }) = container_type { + if let Type::Struct(StructType::Fields(FieldsStructType { fields, .. })) = + container_type + { fields .into_iter() .find(|(name, _)| name == &field.inner) @@ -325,26 +335,30 @@ impl Expression { }, Expression::Range(_) => Some(Type::Range), Expression::Struct(struct_expression) => match struct_expression.inner.as_ref() { - StructExpression::Fields { name, fields } => { - let mut field_types = Vec::with_capacity(fields.len()); + StructExpression::Fields { fields, .. } => { + let mut types = HashMap::with_capacity(fields.len()); - for (field_name, expression) in fields { + for (field, expression) in fields { let r#type = expression.return_type(context)?; - field_types.push((field_name.inner.clone(), r#type)); + types.insert(field.inner.clone(), r#type); } - Some(Type::Struct(StructType::Fields { - name: name.inner.clone(), - fields: field_types, - })) + Some(Type::Struct(StructType::Fields(FieldsStructType { + fields: types, + }))) } - StructExpression::Unit { name } => Some(Type::Struct(StructType::Unit { - name: name.inner.clone(), - })), + StructExpression::Unit { .. } => Some(Type::Struct(StructType::Unit)), }, Expression::TupleAccess(tuple_access_expression) => { - todo!() + let TupleAccessExpression { tuple, index } = tuple_access_expression.inner.as_ref(); + let tuple_value = tuple.return_type(context)?; + + if let Type::Tuple(TupleType { fields }) = tuple_value { + fields.get(index.inner).cloned() + } else { + None + } } } } diff --git a/dust-lang/src/ast/mod.rs b/dust-lang/src/ast/mod.rs index 05e8e28..a706cf1 100644 --- a/dust-lang/src/ast/mod.rs +++ b/dust-lang/src/ast/mod.rs @@ -12,7 +12,7 @@ use std::{ use serde::{Deserialize, Serialize}; -use crate::Span; +pub type Span = (usize, usize); /// In-memory representation of a Dust program. #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] diff --git a/dust-lang/src/ast/statement.rs b/dust-lang/src/ast/statement.rs index bb6bffa..a4e2fa7 100644 --- a/dust-lang/src/ast/statement.rs +++ b/dust-lang/src/ast/statement.rs @@ -2,9 +2,9 @@ use std::fmt::{self, Display, Formatter}; use serde::{Deserialize, Serialize}; -use crate::{Context, Identifier, Span, Type}; +use crate::{Context, Identifier, Type}; -use super::{Expression, Node}; +use super::{Expression, Node, Span}; #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Statement { diff --git a/dust-lang/src/context.rs b/dust-lang/src/context.rs index 137c395..d64ab5a 100644 --- a/dust-lang/src/context.rs +++ b/dust-lang/src/context.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, PoisonError as StdPoisonError, RwLock, RwLockWriteGuard}, }; -use crate::{Identifier, Span, Type, Value}; +use crate::{ast::Span, Identifier, Type, Value}; pub type Variables = HashMap; diff --git a/dust-lang/src/lexer.rs b/dust-lang/src/lexer.rs index c80bb05..9614e96 100644 --- a/dust-lang/src/lexer.rs +++ b/dust-lang/src/lexer.rs @@ -8,7 +8,7 @@ use std::{ fmt::{self, Display, Formatter}, }; -use crate::{Span, Token}; +use crate::{ast::Span, Token}; /// Lexes the input and return a vector of tokens and their positions. /// diff --git a/dust-lang/src/lib.rs b/dust-lang/src/lib.rs index 1493807..c393cb9 100644 --- a/dust-lang/src/lib.rs +++ b/dust-lang/src/lib.rs @@ -29,16 +29,14 @@ pub mod value; pub mod vm; pub use analyzer::{analyze, Analyzer, AnalyzerError}; -pub use ast::{AbstractSyntaxTree, Expression, Node, Statement}; +pub use ast::{AbstractSyntaxTree, Expression, Statement}; pub use built_in_function::{BuiltInFunction, BuiltInFunctionError}; pub use context::{Context, VariableData}; pub use dust_error::DustError; pub use identifier::Identifier; pub use lexer::{lex, LexError, Lexer}; pub use parser::{parse, ParseError, Parser}; -pub use r#type::{EnumType, FunctionType, StructType, Type}; +pub use r#type::*; pub use token::{Token, TokenKind, TokenOwned}; -pub use value::{Struct, Value, ValueError}; +pub use value::*; pub use vm::{run, run_with_context, Vm, VmError}; - -pub type Span = (usize, usize); diff --git a/dust-lang/src/parser.rs b/dust-lang/src/parser.rs index 8f546e7..1b39bf8 100644 --- a/dust-lang/src/parser.rs +++ b/dust-lang/src/parser.rs @@ -11,9 +11,7 @@ use std::{ str::ParseBoolError, }; -use crate::{ - ast::*, DustError, Identifier, LexError, Lexer, Span, Token, TokenKind, TokenOwned, Type, -}; +use crate::{ast::*, DustError, Identifier, LexError, Lexer, Token, TokenKind, TokenOwned, Type}; /// Parses the input into an abstract syntax tree. /// @@ -146,6 +144,8 @@ impl<'src> Parser<'src> { let start_position = self.current_position; if let Token::Let = self.current_token { + log::trace!("Parsing let statement"); + self.next_token()?; let is_mutable = if let Token::Mut = self.current_token { @@ -170,21 +170,33 @@ impl<'src> Parser<'src> { let value = self.parse_expression(0)?; - if let Token::Semicolon = self.current_token { + let end = if let Token::Semicolon = self.current_token { + let end = self.current_position.1; + self.next_token()?; - } + + end + } else { + return Err(ParseError::ExpectedToken { + expected: TokenKind::Semicolon, + actual: self.current_token.to_owned(), + position: self.current_position, + }); + }; let r#let = if is_mutable { LetStatement::LetMut { identifier, value } } else { LetStatement::Let { identifier, value } }; - let position = (start_position.0, self.current_position.1); + let position = (start_position.0, end); return Ok(Statement::Let(Node::new(r#let, position))); } if let Token::Struct = self.current_token { + log::trace!("Parsing struct definition"); + self.next_token()?; let (name, name_end) = if let Token::Identifier(_) = self.current_token { @@ -276,18 +288,27 @@ impl<'src> Parser<'src> { } let expression = self.parse_expression(0)?; - - let statement = if let Token::Semicolon = self.current_token { - let position = (start_position.0, self.current_position.1); - + let end = self.current_position.1; + let is_nullified = if let Token::Semicolon = self.current_token { self.next_token()?; - Statement::ExpressionNullified(Node::new(expression, position)) + true } else { - Statement::Expression(expression) + matches!( + expression, + Expression::Block(_) | Expression::Loop(_) | Expression::If(_) + ) }; - Ok(statement) + if is_nullified { + let position = (start_position.0, end); + + Ok(Statement::ExpressionNullified(Node::new( + expression, position, + ))) + } else { + Ok(Statement::Expression(expression)) + } } fn next_token(&mut self) -> Result<(), ParseError> { @@ -348,7 +369,8 @@ impl<'src> Parser<'src> { Ok(Expression::negation(operand, position)) } - _ => Err(ParseError::UnexpectedToken { + _ => Err(ParseError::ExpectedTokenMultiple { + expected: vec![TokenKind::Bang, TokenKind::Minus], actual: self.current_token.to_owned(), position: self.current_position, }), @@ -363,7 +385,7 @@ impl<'src> Parser<'src> { match self.current_token { Token::Async => { let block = self.parse_block()?; - let position = (start_position.0, self.current_position.1); + let position = (start_position.0, block.position.1); Ok(Expression::block(block.inner, position)) } @@ -460,12 +482,10 @@ impl<'src> Parser<'src> { )) } Token::If => { - let start = self.current_position.0; - self.next_token()?; let r#if = self.parse_if()?; - let position = (start, self.current_position.1); + let position = (start_position.0, self.current_position.1); Ok(Expression::r#if(r#if, position)) } @@ -577,7 +597,20 @@ impl<'src> Parser<'src> { Ok(Expression::while_loop(condition, block, position)) } - _ => Err(ParseError::UnexpectedToken { + _ => Err(ParseError::ExpectedTokenMultiple { + expected: vec![ + TokenKind::Async, + TokenKind::Boolean, + TokenKind::Float, + TokenKind::Identifier, + TokenKind::Integer, + TokenKind::If, + TokenKind::LeftCurlyBrace, + TokenKind::LeftParenthesis, + TokenKind::LeftSquareBrace, + TokenKind::String, + TokenKind::While, + ], actual: self.current_token.to_owned(), position: self.current_position, }), @@ -798,7 +831,12 @@ impl<'src> Parser<'src> { Expression::list_index(ListIndexExpression { list: left, index }, position) } _ => { - return Err(ParseError::UnexpectedToken { + return Err(ParseError::ExpectedTokenMultiple { + expected: vec![ + TokenKind::Dot, + TokenKind::LeftParenthesis, + TokenKind::LeftSquareBrace, + ], actual: self.current_token.to_owned(), position: self.current_position, }); @@ -1064,9 +1102,58 @@ mod tests { use super::*; + #[test] + fn let_mut_while_loop() { + env_logger::builder().is_test(true).try_init().ok(); + + let source = "let mut x = 0; while x < 10 { x += 1 };"; + + assert_eq!( + parse(source), + Ok(AbstractSyntaxTree { + statements: [ + Statement::Let(Node::new( + LetStatement::LetMut { + identifier: Node::new(Identifier::new("x"), (8, 9)), + value: Expression::literal(LiteralExpression::Integer(0), (12, 13)), + }, + (0, 14), + )), + Statement::ExpressionNullified(Node::new( + Expression::while_loop( + Expression::comparison( + Expression::identifier(Identifier::new("x"), (21, 22)), + Node::new(ComparisonOperator::LessThan, (23, 24)), + Expression::literal(LiteralExpression::Integer(10), (25, 27)), + (21, 27), + ), + Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::compound_assignment( + Expression::identifier(Identifier::new("x"), (30, 31)), + Node::new(MathOperator::Add, (32, 34)), + Expression::literal( + LiteralExpression::Integer(1), + (35, 36) + ), + (30, 36), + ), + )]), + (28, 38), + ), + (15, 39), + ), + (15, 39) + )) + ] + .into() + }) + ); + } + #[test] fn let_statement() { - let source = "let x = 42"; + let source = "let x = 42;"; assert_eq!( parse(source), @@ -1076,7 +1163,7 @@ mod tests { identifier: Node::new(Identifier::new("x"), (4, 5)), value: Expression::literal(LiteralExpression::Integer(42), (8, 10)), }, - (0, 10), + (0, 11), ))] .into() }) @@ -1085,7 +1172,7 @@ mod tests { #[test] fn let_mut_statement() { - let source = "let mut x = false"; + let source = "let mut x = false;"; assert_eq!( parse(source), @@ -1095,7 +1182,7 @@ mod tests { identifier: Node::new(Identifier::new("x"), (8, 9)), value: Expression::literal(LiteralExpression::Boolean(false), (12, 17)), }, - (0, 17), + (0, 18), ))] .into() }) @@ -1109,31 +1196,43 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree { - statements: [Statement::Expression(Expression::block( - BlockExpression::Async(vec![ - Statement::ExpressionNullified(Node::new( - Expression::operator( + statements: [Statement::ExpressionNullified(Node::new( + Expression::block( + BlockExpression::Async(vec![ + Statement::ExpressionNullified(Node::new( + Expression::operator( + OperatorExpression::Assignment { + assignee: Expression::identifier( + Identifier::new("x"), + (8, 9) + ), + value: Expression::literal( + LiteralExpression::Integer(42), + (12, 14) + ), + }, + (8, 14) + ), + (8, 15) + )), + Statement::Expression(Expression::operator( OperatorExpression::Assignment { - assignee: Expression::identifier(Identifier::new("x"), (8, 9)), + assignee: Expression::identifier( + Identifier::new("y"), + (16, 17) + ), value: Expression::literal( - LiteralExpression::Integer(42), - (12, 14) + LiteralExpression::Float(4.0), + (20, 23) ), }, - (8, 14) - ), - (8, 15) - )), - Statement::Expression(Expression::operator( - OperatorExpression::Assignment { - assignee: Expression::identifier(Identifier::new("y"), (16, 17)), - value: Expression::literal(LiteralExpression::Float(4.0), (20, 23)), - }, - (16, 23) - )) - ]), + (16, 23) + )) + ]), + (0, 25) + ), (0, 25) - ),)] + ))] .into() }) ); @@ -1454,16 +1553,19 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree::with_statements([ - Statement::Expression(Expression::r#if( - IfExpression::If { - condition: Expression::identifier(Identifier::new("x"), (3, 4)), - if_block: Node::new( - BlockExpression::Sync(vec![Statement::Expression( - Expression::identifier(Identifier::new("y"), (7, 8)) - )]), - (5, 10) - ) - }, + Statement::ExpressionNullified(Node::new( + Expression::r#if( + IfExpression::If { + condition: Expression::identifier(Identifier::new("x"), (3, 4)), + if_block: Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::identifier(Identifier::new("y"), (7, 8)) + )]), + (5, 10) + ) + }, + (0, 10) + ), (0, 10) )) ])) @@ -1477,22 +1579,25 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree::with_statements([ - Statement::Expression(Expression::r#if( - IfExpression::IfElse { - condition: Expression::identifier(Identifier::new("x"), (3, 4)), - if_block: Node::new( - BlockExpression::Sync(vec![Statement::Expression( - Expression::identifier(Identifier::new("y"), (7, 8)) - )]), - (5, 10) - ), - r#else: ElseExpression::Block(Node::new( - BlockExpression::Sync(vec![Statement::Expression( - Expression::identifier(Identifier::new("z"), (18, 19)) - )]), - (16, 21) - )) - }, + Statement::ExpressionNullified(Node::new( + Expression::r#if( + IfExpression::IfElse { + condition: Expression::identifier(Identifier::new("x"), (3, 4)), + if_block: Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::identifier(Identifier::new("y"), (7, 8)) + )]), + (5, 10) + ), + r#else: ElseExpression::Block(Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::identifier(Identifier::new("z"), (18, 19)) + )]), + (16, 21) + )) + }, + (0, 21) + ), (0, 21) )) ])) @@ -1506,34 +1611,40 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree::with_statements([ - Statement::Expression(Expression::r#if( - IfExpression::IfElse { - condition: Expression::identifier(Identifier::new("x"), (3, 4)), - if_block: Node::new( - BlockExpression::Sync(vec![Statement::Expression( - Expression::identifier(Identifier::new("y"), (7, 8)) - )]), - (5, 10) - ), - r#else: ElseExpression::If(Node::new( - Box::new(IfExpression::IfElse { - condition: Expression::identifier(Identifier::new("z"), (19, 20)), - if_block: Node::new( - BlockExpression::Sync(vec![Statement::Expression( - Expression::identifier(Identifier::new("a"), (23, 24)) - )]), - (21, 26) - ), - r#else: ElseExpression::Block(Node::new( - BlockExpression::Sync(vec![Statement::Expression( - Expression::identifier(Identifier::new("b"), (34, 35)) - )]), - (32, 37) - )), - }), - (16, 37) - )), - }, + Statement::ExpressionNullified(Node::new( + Expression::r#if( + IfExpression::IfElse { + condition: Expression::identifier(Identifier::new("x"), (3, 4)), + if_block: Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::identifier(Identifier::new("y"), (7, 8)) + )]), + (5, 10) + ), + r#else: ElseExpression::If(Node::new( + Box::new(IfExpression::IfElse { + condition: Expression::identifier( + Identifier::new("z"), + (19, 20) + ), + if_block: Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::identifier(Identifier::new("a"), (23, 24)) + )]), + (21, 26) + ), + r#else: ElseExpression::Block(Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::identifier(Identifier::new("b"), (34, 35)) + )]), + (32, 37) + )), + }), + (16, 37) + )), + }, + (0, 37) + ), (0, 37) )) ])) @@ -1547,28 +1658,39 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree::with_statements([ - Statement::Expression(Expression::while_loop( - Expression::operator( - OperatorExpression::Comparison { - left: Expression::identifier(Identifier::new("x"), (6, 7)), - operator: Node::new(ComparisonOperator::LessThan, (8, 9)), - right: Expression::literal(LiteralExpression::Integer(10), (10, 12)), - }, - (6, 12) - ), - Node::new( - BlockExpression::Sync(vec![Statement::Expression(Expression::operator( - OperatorExpression::CompoundAssignment { - assignee: Expression::identifier(Identifier::new("x"), (15, 16)), - operator: Node::new(MathOperator::Add, (17, 19)), - modifier: Expression::literal( - LiteralExpression::Integer(1), - (20, 21) + Statement::ExpressionNullified(Node::new( + Expression::while_loop( + Expression::operator( + OperatorExpression::Comparison { + left: Expression::identifier(Identifier::new("x"), (6, 7)), + operator: Node::new(ComparisonOperator::LessThan, (8, 9)), + right: Expression::literal( + LiteralExpression::Integer(10), + (10, 12) ), }, - (15, 21) - ))]), - (13, 23) + (6, 12) + ), + Node::new( + BlockExpression::Sync(vec![Statement::Expression( + Expression::operator( + OperatorExpression::CompoundAssignment { + assignee: Expression::identifier( + Identifier::new("x"), + (15, 16) + ), + operator: Node::new(MathOperator::Add, (17, 19)), + modifier: Expression::literal( + LiteralExpression::Integer(1), + (20, 21) + ), + }, + (15, 21) + ) + )]), + (13, 23) + ), + (0, 23) ), (0, 23) )) @@ -1621,15 +1743,18 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree::with_statements([ - Statement::Expression(Expression::block( - BlockExpression::Sync(vec![Statement::Expression(Expression::operator( - OperatorExpression::Math { - left: Expression::literal(LiteralExpression::Integer(40), (2, 4)), - operator: Node::new(MathOperator::Add, (5, 6)), - right: Expression::literal(LiteralExpression::Integer(2), (7, 8)), - }, - (2, 8) - ))]), + Statement::ExpressionNullified(Node::new( + Expression::block( + BlockExpression::Sync(vec![Statement::Expression(Expression::operator( + OperatorExpression::Math { + left: Expression::literal(LiteralExpression::Integer(40), (2, 4)), + operator: Node::new(MathOperator::Add, (5, 6)), + right: Expression::literal(LiteralExpression::Integer(2), (7, 8)), + }, + (2, 8) + ))]), + (0, 10) + ), (0, 10) )) ])) @@ -1643,51 +1768,57 @@ mod tests { assert_eq!( parse(source), Ok(AbstractSyntaxTree::with_statements([ - Statement::Expression(Expression::block( - BlockExpression::Sync(vec![ - Statement::ExpressionNullified(Node::new( - Expression::operator( - OperatorExpression::Assignment { - assignee: Expression::identifier( - Identifier::new("foo"), - (2, 5) - ), - value: Expression::literal( - LiteralExpression::Integer(42), - (8, 10) - ), - }, - (2, 10) - ), - (2, 11) - ),), - Statement::ExpressionNullified(Node::new( - Expression::operator( - OperatorExpression::Assignment { - assignee: Expression::identifier( - Identifier::new("bar"), - (12, 15) - ), - value: Expression::literal( - LiteralExpression::Integer(42), - (18, 20) - ), - }, - (12, 20) - ), - (12, 21) - ),), - Statement::Expression(Expression::operator( - OperatorExpression::Assignment { - assignee: Expression::identifier(Identifier::new("baz"), (22, 25)), - value: Expression::literal( - LiteralExpression::String("42".to_string()), - (28, 32) + Statement::ExpressionNullified(Node::new( + Expression::block( + BlockExpression::Sync(vec![ + Statement::ExpressionNullified(Node::new( + Expression::operator( + OperatorExpression::Assignment { + assignee: Expression::identifier( + Identifier::new("foo"), + (2, 5) + ), + value: Expression::literal( + LiteralExpression::Integer(42), + (8, 10) + ), + }, + (2, 10) ), - }, - (22, 32) - )), - ]), + (2, 11) + ),), + Statement::ExpressionNullified(Node::new( + Expression::operator( + OperatorExpression::Assignment { + assignee: Expression::identifier( + Identifier::new("bar"), + (12, 15) + ), + value: Expression::literal( + LiteralExpression::Integer(42), + (18, 20) + ), + }, + (12, 20) + ), + (12, 21) + ),), + Statement::Expression(Expression::operator( + OperatorExpression::Assignment { + assignee: Expression::identifier( + Identifier::new("baz"), + (22, 25) + ), + value: Expression::literal( + LiteralExpression::String("42".to_string()), + (28, 32) + ), + }, + (22, 32) + )), + ]), + (0, 34) + ), (0, 34) )) ])) diff --git a/dust-lang/src/type.rs b/dust-lang/src/type.rs index 5f4f9f8..6e2bf5b 100644 --- a/dust-lang/src/type.rs +++ b/dust-lang/src/type.rs @@ -10,6 +10,8 @@ //! library's "length" function does not care about the type of item in the list, only the list //! itself. So the input is defined as `[any]`, i.e. `Type::ListOf(Box::new(Type::Any))`. use std::{ + cmp::Ordering, + collections::HashMap, fmt::{self, Display, Formatter}, sync::Arc, }; @@ -47,7 +49,7 @@ pub enum Type { Range, String, Struct(StructType), - Tuple(Vec), + Tuple(TupleType), } impl Type { @@ -264,7 +266,7 @@ impl Display for Type { Type::Range => write!(f, "range"), Type::String => write!(f, "str"), Type::Struct(struct_type) => write!(f, "{struct_type}"), - Type::Tuple(fields) => { + Type::Tuple(TupleType { fields }) => { write!(f, "(")?; for (index, r#type) in fields.iter().enumerate() { @@ -331,24 +333,16 @@ impl Display for FunctionType { #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] pub enum StructType { - Unit { - name: Identifier, - }, - Tuple { - name: Identifier, - fields: Vec, - }, - Fields { - name: Identifier, - fields: Vec<(Identifier, Type)>, - }, + Unit, + Tuple(TupleType), + Fields(FieldsStructType), } impl Display for StructType { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - StructType::Unit { .. } => write!(f, "()"), - StructType::Tuple { fields, .. } => { + StructType::Unit => write!(f, "()"), + StructType::Tuple(TupleType { fields, .. }) => { write!(f, "(")?; for (index, r#type) in fields.iter().enumerate() { @@ -361,12 +355,8 @@ impl Display for StructType { write!(f, ")") } - StructType::Fields { - name: identifier, - fields, - .. - } => { - write!(f, "{identifier} {{ ")?; + StructType::Fields(FieldsStructType { fields, .. }) => { + write!(f, "{{ ")?; for (index, (identifier, r#type)) in fields.iter().enumerate() { write!(f, "{identifier}: {type}")?; @@ -382,6 +372,28 @@ impl Display for StructType { } } +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct TupleType { + pub fields: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct FieldsStructType { + pub fields: HashMap, +} + +impl PartialOrd for FieldsStructType { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for FieldsStructType { + fn cmp(&self, other: &Self) -> Ordering { + self.fields.iter().cmp(other.fields.iter()) + } +} + #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] pub struct EnumType { name: Identifier, diff --git a/dust-lang/src/value.rs b/dust-lang/src/value.rs index 27d244e..4afc5d0 100644 --- a/dust-lang/src/value.rs +++ b/dust-lang/src/value.rs @@ -1,6 +1,7 @@ //! Dust value representation use std::{ cmp::Ordering, + collections::HashMap, error::Error, fmt::{self, Display, Formatter}, ops::{Range, RangeInclusive}, @@ -14,7 +15,8 @@ use serde::{ }; use crate::{ - AbstractSyntaxTree, Context, EnumType, FunctionType, Identifier, StructType, Type, Vm, VmError, + AbstractSyntaxTree, Context, EnumType, FieldsStructType, FunctionType, Identifier, StructType, + TupleType, Type, Vm, VmError, }; /// Dust value representation @@ -197,14 +199,25 @@ impl Value { Value::RangeInclusive(_) => Type::Range, Value::String(_) => Type::String, Value::Struct(r#struct) => match r#struct { - Struct::Unit { r#type } => Type::Struct(r#type.clone()), - Struct::Tuple { r#type, .. } => Type::Struct(r#type.clone()), - Struct::Fields { r#type, .. } => Type::Struct(r#type.clone()), + Struct::Unit { .. } => Type::Struct(StructType::Unit), + Struct::Tuple { fields, .. } => { + let types = fields.iter().map(|field| field.r#type()).collect(); + + Type::Struct(StructType::Tuple(TupleType { fields: types })) + } + Struct::Fields { fields, .. } => { + let types = fields + .iter() + .map(|(identifier, value)| (identifier.clone(), value.r#type())) + .collect(); + + Type::Struct(StructType::Fields(FieldsStructType { fields: types })) + } }, Value::Tuple(values) => { - let item_types = values.iter().map(Value::r#type).collect(); + let fields = values.iter().map(|value| value.r#type()).collect(); - Type::Tuple(item_types) + Type::Tuple(TupleType { fields }) } } } @@ -212,23 +225,7 @@ impl Value { pub fn get_field(&self, field: &Identifier) -> Option { match self { Value::Mutable(inner) => inner.read().unwrap().get_field(field), - Value::Struct(Struct::Fields { - fields, - r#type: - StructType::Fields { - fields: field_types, - .. - }, - }) => field_types - .iter() - .zip(fields.iter()) - .find_map(|((identifier, _), value)| { - if identifier == field { - Some(value.clone()) - } else { - None - } - }), + Value::Struct(Struct::Fields { fields, .. }) => fields.get(field).cloned(), _ => None, } } @@ -237,6 +234,7 @@ impl Value { match self { Value::List(values) => values.get(index).cloned(), Value::Mutable(inner) => inner.read().unwrap().get_index(index), + Value::Struct(Struct::Tuple { fields, .. }) => fields.get(index).cloned(), _ => None, } } @@ -1055,21 +1053,74 @@ impl<'de> Deserialize<'de> for Function { } } -#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum Struct { Unit { - r#type: StructType, + name: Identifier, }, Tuple { - r#type: StructType, + name: Identifier, fields: Vec, }, Fields { - r#type: StructType, - fields: Vec, + name: Identifier, + fields: HashMap, }, } +impl PartialOrd for Struct { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Struct { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (Struct::Unit { name: left }, Struct::Unit { name: right }) => left.cmp(right), + (Struct::Unit { .. }, _) => Ordering::Greater, + ( + Struct::Tuple { + name: left_name, + fields: left_fields, + }, + Struct::Tuple { + name: right_name, + fields: right_fields, + }, + ) => { + let type_cmp = left_name.cmp(right_name); + + if type_cmp != Ordering::Equal { + return type_cmp; + } + + left_fields.cmp(right_fields) + } + (Struct::Tuple { .. }, _) => Ordering::Greater, + ( + Struct::Fields { + name: left_name, + fields: left_fields, + }, + Struct::Fields { + name: right_name, + fields: right_fields, + }, + ) => { + let type_cmp = left_name.cmp(right_name); + + if type_cmp != Ordering::Equal { + return type_cmp; + } + + left_fields.into_iter().cmp(right_fields.into_iter()) + } + (Struct::Fields { .. }, _) => Ordering::Greater, + } + } +} + impl Display for Struct { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { @@ -1087,19 +1138,10 @@ impl Display for Struct { write!(f, ")") } - Struct::Fields { - fields, - r#type: - StructType::Fields { - fields: field_types, - .. - }, - } => { + Struct::Fields { fields, .. } => { write!(f, "{{ ")?; - for (index, ((identifier, _), value)) in - field_types.iter().zip(fields.iter()).enumerate() - { + for (index, (identifier, value)) in fields.iter().enumerate() { if index > 0 { write!(f, ", ")?; } @@ -1115,7 +1157,7 @@ impl Display for Struct { } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -enum Rangeable { +pub enum Rangeable { Byte(u8), Character(char), Float(f64), diff --git a/dust-lang/src/vm.rs b/dust-lang/src/vm.rs index 8c479f6..42923fb 100644 --- a/dust-lang/src/vm.rs +++ b/dust-lang/src/vm.rs @@ -16,10 +16,10 @@ use crate::{ AbstractSyntaxTree, BlockExpression, CallExpression, ComparisonOperator, ElseExpression, FieldAccessExpression, IfExpression, LetStatement, ListExpression, ListIndexExpression, LiteralExpression, LogicOperator, LoopExpression, MathOperator, Node, OperatorExpression, - RangeExpression, Statement, StructDefinition, + RangeExpression, Span, Statement, StructDefinition, }, parse, Analyzer, BuiltInFunctionError, Context, DustError, Expression, Identifier, ParseError, - Span, StructType, Type, Value, ValueError, + StructType, TupleType, Type, Value, ValueError, }; /// Run the source code and return the result. @@ -121,10 +121,12 @@ impl Vm { } Statement::StructDefinition(struct_definition) => { let (name, struct_type) = match struct_definition.inner { - StructDefinition::Unit { name } => { - (name.inner.clone(), StructType::Unit { name: name.inner }) + StructDefinition::Unit { name } => (name.inner.clone(), StructType::Unit), + StructDefinition::Tuple { name, items } => { + let fields = items.into_iter().map(|item| item.inner).collect(); + + (name.inner.clone(), StructType::Tuple(TupleType { fields })) } - StructDefinition::Tuple { name, items } => todo!(), StructDefinition::Fields { name, fields } => todo!(), }; @@ -988,7 +990,9 @@ impl Display for VmError { #[cfg(test)] mod tests { - use crate::{Struct, StructType, Type}; + use std::collections::HashMap; + + use crate::Struct; use super::*; @@ -1006,14 +1010,11 @@ mod tests { assert_eq!( run(input), Ok(Some(Value::Struct(Struct::Fields { - r#type: StructType::Fields { - name: Identifier::new("Foo"), - fields: vec![ - (Identifier::new("bar"), Type::Integer), - (Identifier::new("baz"), Type::Float) - ] - }, - fields: vec![Value::Integer(42), Value::Float(4.0)] + name: Identifier::new("Foo"), + fields: HashMap::from([ + (Identifier::new("bar"), Value::Integer(42)), + (Identifier::new("baz"), Value::Float(4.0)) + ]) }))) ); } @@ -1029,10 +1030,7 @@ mod tests { assert_eq!( run(input), Ok(Some(Value::Struct(Struct::Tuple { - r#type: StructType::Tuple { - name: Identifier::new("Foo"), - fields: vec![Type::Integer] - }, + name: Identifier::new("Foo"), fields: vec![Value::Integer(42)] }))) ) @@ -1045,9 +1043,7 @@ mod tests { assert_eq!( run(input), Ok(Some(Value::Struct(Struct::Tuple { - r#type: StructType::Unit { - name: Identifier::new("Foo") - }, + name: Identifier::new("Foo"), fields: vec![Value::Integer(42)] }))) ); @@ -1064,9 +1060,7 @@ mod tests { assert_eq!( run(input), Ok(Some(Value::Struct(Struct::Unit { - r#type: StructType::Unit { - name: Identifier::new("Foo") - } + name: Identifier::new("Foo") }))) ) } @@ -1078,9 +1072,7 @@ mod tests { assert_eq!( run(input), Ok(Some(Value::Struct(Struct::Unit { - r#type: StructType::Unit { - name: Identifier::new("Foo") - } + name: Identifier::new("Foo") }))) ); } @@ -1163,7 +1155,7 @@ mod tests { #[test] fn if_else() { - let input = "if false { 1 } else { 2 }"; + let input = "let x = if false { 1 } else { 2 }; x"; assert_eq!(run(input), Ok(Some(Value::Integer(2)))); } @@ -1184,7 +1176,7 @@ mod tests { #[test] fn while_loop() { - let input = "let mut x = 0; while x < 5 { x += 1; } x"; + let input = "let mut x = 0; while x < 5 { x += 1 } x"; assert_eq!(run(input), Ok(Some(Value::Integer(5)))); }