diff --git a/src/abstract_tree/assignment.rs b/src/abstract_tree/assignment.rs index 442b1ec..5b0b8b9 100644 --- a/src/abstract_tree/assignment.rs +++ b/src/abstract_tree/assignment.rs @@ -42,6 +42,8 @@ impl AbstractTree for Assignment { } fn validate(&self, context: &Context) -> Result<(), ValidationError> { + self.statement.validate(context)?; + let statement_type = self.statement.expected_type(context)?; if let Some(expected) = &self.r#type { diff --git a/src/abstract_tree/block.rs b/src/abstract_tree/block.rs index 54f6810..e9a3ee9 100644 --- a/src/abstract_tree/block.rs +++ b/src/abstract_tree/block.rs @@ -1,7 +1,6 @@ use crate::{ context::Context, error::{RuntimeError, ValidationError}, - Value, }; use super::{AbstractTree, Action, Statement, Type}; @@ -19,9 +18,11 @@ impl Block { impl AbstractTree for Block { fn expected_type(&self, _context: &Context) -> Result { - let final_statement = self.statements.last().unwrap(); - - final_statement.expected_type(_context) + if let Some(statement) = self.statements.last() { + statement.expected_type(_context) + } else { + Ok(Type::None) + } } fn validate(&self, _context: &Context) -> Result<(), ValidationError> { @@ -33,23 +34,26 @@ impl AbstractTree for Block { } fn run(self, _context: &Context) -> Result { - let mut previous = Value::none(); + let mut previous = Action::None; for statement in self.statements { let action = statement.run(_context)?; previous = match action { - Action::Return(value) => value, + Action::Return(value) => Action::Return(value), r#break => return Ok(r#break), }; } - Ok(Action::Return(previous)) + Ok(previous) } } #[cfg(test)] mod tests { - use crate::abstract_tree::{Expression, ValueNode}; + use crate::{ + abstract_tree::{Expression, ValueNode}, + Value, + }; use super::*; diff --git a/src/abstract_tree/value_node.rs b/src/abstract_tree/value_node.rs index c0be308..0dd2043 100644 --- a/src/abstract_tree/value_node.rs +++ b/src/abstract_tree/value_node.rs @@ -6,7 +6,7 @@ use crate::{ Value, }; -use super::{AbstractTree, Action, Expression, Identifier, Statement, Type}; +use super::{AbstractTree, Action, Block, Expression, Identifier, Type}; #[derive(Clone, Debug, PartialEq)] pub enum ValueNode { @@ -21,7 +21,7 @@ pub enum ValueNode { Function { parameters: Vec<(Identifier, Type)>, return_type: Type, - body: Box, + body: Block, }, } @@ -72,6 +72,23 @@ impl AbstractTree for ValueNode { } } + if let ValueNode::Function { + parameters, + return_type, + body, + } = self + { + let function_context = Context::new(); + + for (identifier, r#type) in parameters { + function_context.set_type(identifier.clone(), r#type.clone())?; + } + + let actual_return_type = body.expected_type(&function_context)?; + + return_type.check(&actual_return_type)?; + } + Ok(()) } @@ -115,7 +132,7 @@ impl AbstractTree for ValueNode { parameters, return_type, body, - } => Value::function(parameters, return_type, *body), + } => Value::function(parameters, return_type, body), }; Ok(Action::Return(value)) diff --git a/src/lexer.rs b/src/lexer.rs index 59f76e0..e1fe795 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -210,6 +210,7 @@ pub fn lexer<'src>() -> impl Parser< .map(Token::Control); let keyword = choice(( + just("any").padded(), just("bool").padded(), just("break").padded(), just("else").padded(), @@ -218,6 +219,7 @@ pub fn lexer<'src>() -> impl Parser< just("if").padded(), just("list").padded(), just("map").padded(), + just("none").padded(), just("range").padded(), just("str").padded(), just("loop").padded(), diff --git a/src/parser.rs b/src/parser.rs index 47468e3..7f9c7e1 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -47,10 +47,21 @@ pub fn parser<'src>() -> DustParser<'src> { } }; + let basic_value = select! { + Token::Boolean(boolean) => ValueNode::Boolean(boolean), + Token::Integer(integer) => ValueNode::Integer(integer), + Token::Float(float) => ValueNode::Float(float), + Token::String(string) => ValueNode::String(string.to_string()), + } + .map(|value| Expression::Value(value)) + .boxed(); + let basic_type = choice(( + just(Token::Keyword("any")).to(Type::Any), just(Token::Keyword("bool")).to(Type::Boolean), just(Token::Keyword("float")).to(Type::Float), just(Token::Keyword("int")).to(Type::Integer), + just(Token::Keyword("none")).to(Type::None), just(Token::Keyword("range")).to(Type::Range), just(Token::Keyword("str")).to(Type::String), just(Token::Keyword("list")).to(Type::List), @@ -80,150 +91,194 @@ pub fn parser<'src>() -> DustParser<'src> { .map(|identifier| Type::Custom(identifier)), ))); - let expression = recursive(|expression| { - let basic_value = select! { - Token::Boolean(boolean) => ValueNode::Boolean(boolean), - Token::Integer(integer) => ValueNode::Integer(integer), - Token::Float(float) => ValueNode::Float(float), - Token::String(string) => ValueNode::String(string.to_string()), - } - .map(|value| Expression::Value(value)) - .boxed(); - - let identifier_expression = identifier + let statement = recursive(|statement| { + let block = statement .clone() - .map(|identifier| Expression::Identifier(identifier)) - .boxed(); - - let range = { - let raw_integer = select! { - Token::Integer(integer) => integer - }; - - raw_integer - .clone() - .then_ignore(just(Token::Control(Control::DoubleDot))) - .then(raw_integer) - .map(|(start, end)| Expression::Value(ValueNode::Range(start..end))) - }; - - let list = expression - .clone() - .separated_by(just(Token::Control(Control::Comma))) - .allow_trailing() - .collect() - .delimited_by( - just(Token::Control(Control::SquareOpen)), - just(Token::Control(Control::SquareClose)), - ) - .map(|list| Expression::Value(ValueNode::List(list))) - .boxed(); - - let map_assignment = identifier - .clone() - .then(type_specification.clone().or_not()) - .then_ignore(just(Token::Operator(Operator::Assign))) - .then(expression.clone()) - .map(|((identifier, r#type), expression)| (identifier, r#type, expression)); - - let map = map_assignment - .separated_by(just(Token::Control(Control::Comma)).or_not()) - .allow_trailing() + .repeated() .collect() .delimited_by( just(Token::Control(Control::CurlyOpen)), just(Token::Control(Control::CurlyClose)), ) - .map(|map_assigment_list| Expression::Value(ValueNode::Map(map_assigment_list))); + .map(|statements| Block::new(statements)); - let r#enum = identifier - .clone() - .then_ignore(just(Token::Control(Control::DoubleColon))) - .then(identifier.clone()) - .map(|(name, variant)| Expression::Value(ValueNode::Enum(name, variant))) - .boxed(); + let expression = recursive(|expression| { + let identifier_expression = identifier + .clone() + .map(|identifier| Expression::Identifier(identifier)) + .boxed(); - let atom = choice(( - identifier_expression.clone(), - basic_value.clone(), - list.clone(), - r#enum.clone(), - expression.clone().delimited_by( - just(Token::Control(Control::ParenOpen)), - just(Token::Control(Control::ParenClose)), - ), - )); + let range = { + let raw_integer = select! { + Token::Integer(integer) => integer + }; - use Operator::*; + raw_integer + .clone() + .then_ignore(just(Token::Control(Control::DoubleDot))) + .then(raw_integer) + .map(|(start, end)| Expression::Value(ValueNode::Range(start..end))) + }; - let logic_math_and_index = atom - .pratt(( - prefix(2, just(Token::Operator(Not)), |expression| { - Expression::Logic(Box::new(Logic::Not(expression))) - }), - infix(left(1), just(Token::Operator(Equal)), |left, right| { - Expression::Logic(Box::new(Logic::Equal(left, right))) - }), - infix(left(1), just(Token::Operator(NotEqual)), |left, right| { - Expression::Logic(Box::new(Logic::NotEqual(left, right))) - }), - infix(left(1), just(Token::Operator(Greater)), |left, right| { - Expression::Logic(Box::new(Logic::Greater(left, right))) - }), - infix(left(1), just(Token::Operator(Less)), |left, right| { - Expression::Logic(Box::new(Logic::Less(left, right))) - }), - infix( - left(1), - just(Token::Operator(GreaterOrEqual)), - |left, right| Expression::Logic(Box::new(Logic::GreaterOrEqual(left, right))), - ), - infix( - left(1), - just(Token::Operator(LessOrEqual)), - |left, right| Expression::Logic(Box::new(Logic::LessOrEqual(left, right))), - ), - infix(left(1), just(Token::Operator(And)), |left, right| { - Expression::Logic(Box::new(Logic::And(left, right))) - }), - infix(left(1), just(Token::Operator(Or)), |left, right| { - Expression::Logic(Box::new(Logic::Or(left, right))) - }), - infix(left(1), just(Token::Operator(Add)), |left, right| { - Expression::Math(Box::new(Math::Add(left, right))) - }), - infix(left(1), just(Token::Operator(Subtract)), |left, right| { - Expression::Math(Box::new(Math::Subtract(left, right))) - }), - infix(left(2), just(Token::Operator(Multiply)), |left, right| { - Expression::Math(Box::new(Math::Multiply(left, right))) - }), - infix(left(2), just(Token::Operator(Divide)), |left, right| { - Expression::Math(Box::new(Math::Divide(left, right))) - }), - infix(left(1), just(Token::Operator(Modulo)), |left, right| { - Expression::Math(Box::new(Math::Modulo(left, right))) - }), - infix( - left(3), - just(Token::Control(Control::Dot)), - |left, right| Expression::Index(Box::new(Index::new(left, right))), + let list = expression + .clone() + .separated_by(just(Token::Control(Control::Comma))) + .allow_trailing() + .collect() + .delimited_by( + just(Token::Control(Control::SquareOpen)), + just(Token::Control(Control::SquareClose)), + ) + .map(|list| Expression::Value(ValueNode::List(list))) + .boxed(); + + let map_assignment = identifier + .clone() + .then(type_specification.clone().or_not()) + .then_ignore(just(Token::Operator(Operator::Assign))) + .then(expression.clone()) + .map(|((identifier, r#type), expression)| (identifier, r#type, expression)); + + let map = map_assignment + .separated_by(just(Token::Control(Control::Comma)).or_not()) + .allow_trailing() + .collect() + .delimited_by( + just(Token::Control(Control::CurlyOpen)), + just(Token::Control(Control::CurlyClose)), + ) + .map(|map_assigment_list| Expression::Value(ValueNode::Map(map_assigment_list))); + + let r#enum = identifier + .clone() + .then_ignore(just(Token::Control(Control::DoubleColon))) + .then(identifier.clone()) + .map(|(name, variant)| Expression::Value(ValueNode::Enum(name, variant))) + .boxed(); + + let function = identifier + .clone() + .then(type_specification.clone()) + .separated_by(just(Token::Control(Control::Comma))) + .collect::>() + .delimited_by( + just(Token::Control(Control::ParenOpen)), + just(Token::Control(Control::ParenClose)), + ) + .then(type_specification.clone()) + .then(block.clone()) + .map(|((parameters, return_type), body)| { + Expression::Value(ValueNode::Function { + parameters, + return_type, + body, + }) + }) + .boxed(); + + let function_expression = choice((identifier_expression.clone(), function.clone())); + + let function_call = function_expression + .then( + expression + .clone() + .separated_by(just(Token::Control(Control::Comma))) + .collect() + .delimited_by( + just(Token::Control(Control::ParenOpen)), + just(Token::Control(Control::ParenClose)), + ), + ) + .map(|(function, arguments)| { + Expression::FunctionCall(FunctionCall::new(function, arguments)) + }) + .boxed(); + + let atom = choice(( + function_call, + identifier_expression.clone(), + basic_value.clone(), + list.clone(), + r#enum.clone(), + expression.clone().delimited_by( + just(Token::Control(Control::ParenOpen)), + just(Token::Control(Control::ParenClose)), ), + )); + + use Operator::*; + + let logic_math_and_index = atom + .pratt(( + prefix(2, just(Token::Operator(Not)), |expression| { + Expression::Logic(Box::new(Logic::Not(expression))) + }), + infix(left(1), just(Token::Operator(Equal)), |left, right| { + Expression::Logic(Box::new(Logic::Equal(left, right))) + }), + infix(left(1), just(Token::Operator(NotEqual)), |left, right| { + Expression::Logic(Box::new(Logic::NotEqual(left, right))) + }), + infix(left(1), just(Token::Operator(Greater)), |left, right| { + Expression::Logic(Box::new(Logic::Greater(left, right))) + }), + infix(left(1), just(Token::Operator(Less)), |left, right| { + Expression::Logic(Box::new(Logic::Less(left, right))) + }), + infix( + left(1), + just(Token::Operator(GreaterOrEqual)), + |left, right| { + Expression::Logic(Box::new(Logic::GreaterOrEqual(left, right))) + }, + ), + infix( + left(1), + just(Token::Operator(LessOrEqual)), + |left, right| Expression::Logic(Box::new(Logic::LessOrEqual(left, right))), + ), + infix(left(1), just(Token::Operator(And)), |left, right| { + Expression::Logic(Box::new(Logic::And(left, right))) + }), + infix(left(1), just(Token::Operator(Or)), |left, right| { + Expression::Logic(Box::new(Logic::Or(left, right))) + }), + infix(left(1), just(Token::Operator(Add)), |left, right| { + Expression::Math(Box::new(Math::Add(left, right))) + }), + infix(left(1), just(Token::Operator(Subtract)), |left, right| { + Expression::Math(Box::new(Math::Subtract(left, right))) + }), + infix(left(2), just(Token::Operator(Multiply)), |left, right| { + Expression::Math(Box::new(Math::Multiply(left, right))) + }), + infix(left(2), just(Token::Operator(Divide)), |left, right| { + Expression::Math(Box::new(Math::Divide(left, right))) + }), + infix(left(1), just(Token::Operator(Modulo)), |left, right| { + Expression::Math(Box::new(Math::Modulo(left, right))) + }), + infix( + left(3), + just(Token::Control(Control::Dot)), + |left, right| Expression::Index(Box::new(Index::new(left, right))), + ), + )) + .boxed(); + + choice(( + function, + range, + r#enum, + logic_math_and_index, + identifier_expression, + list, + map, + basic_value, )) - .boxed(); + .boxed() + }); - choice(( - range, - r#enum, - logic_math_and_index, - identifier_expression, - list, - map, - basic_value, - )) - }); - - let statement = recursive(|statement| { let expression_statement = expression .clone() .map(|expression| Statement::Expression(expression)) @@ -247,16 +302,7 @@ pub fn parser<'src>() -> DustParser<'src> { }) .boxed(); - let block = statement - .clone() - .repeated() - .collect() - .delimited_by( - just(Token::Control(Control::CurlyOpen)), - just(Token::Control(Control::CurlyClose)), - ) - .map(|statements| Statement::Block(Block::new(statements))) - .boxed(); + let block_statement = block.clone().map(|block| Statement::Block(block)); let r#loop = statement .clone() @@ -283,54 +329,16 @@ pub fn parser<'src>() -> DustParser<'src> { }) .boxed(); - let function = identifier - .clone() - .then(type_specification.clone()) - .separated_by(just(Token::Control(Control::Comma))) - .collect::>() - .delimited_by( - just(Token::Control(Control::ParenOpen)), - just(Token::Control(Control::ParenClose)), - ) - .then(type_specification) - .then(statement.clone()) - .map(|((parameters, return_type), body)| { - Statement::Expression(Expression::Value(ValueNode::Function { - parameters, - return_type, - body: Box::new(body), - })) - }); - - let function_call = expression - .clone() - .then( - expression - .clone() - .separated_by(just(Token::Control(Control::Comma))) - .collect() - .delimited_by( - just(Token::Control(Control::ParenOpen)), - just(Token::Control(Control::ParenClose)), - ), - ) - .map(|(function, arguments)| { - Statement::Expression(Expression::FunctionCall(FunctionCall::new( - function, arguments, - ))) - }); - choice(( - function_call, assignment, expression_statement, r#break, - block, + block_statement, r#loop, if_else, - function, )) .then_ignore(just(Token::Control(Control::Semicolon)).or_not()) + .boxed() }); statement @@ -368,13 +376,13 @@ mod tests { #[test] fn function() { assert_eq!( - parse(&lex("(x: int): int x ").unwrap()).unwrap()[0].0, + parse(&lex("(x: int): int { x }").unwrap()).unwrap()[0].0, Statement::Expression(Expression::Value(ValueNode::Function { parameters: vec![(Identifier::new("x"), Type::Integer)], return_type: Type::Integer, - body: Box::new(Statement::Expression(Expression::Identifier( + body: Block::new(vec![Statement::Expression(Expression::Identifier( Identifier::new("x") - ))) + ))]) })) ) } diff --git a/src/value.rs b/src/value.rs index ec11588..ac4c0a5 100644 --- a/src/value.rs +++ b/src/value.rs @@ -13,7 +13,7 @@ use stanza::{ }; use crate::{ - abstract_tree::{AbstractTree, Action, Identifier, Statement, Type}, + abstract_tree::{AbstractTree, Action, Block, Identifier, Type}, context::Context, error::{RuntimeError, ValidationError}, }; @@ -73,11 +73,7 @@ impl Value { Value(Arc::new(ValueInner::Enum(name, variant))) } - pub fn function( - parameters: Vec<(Identifier, Type)>, - return_type: Type, - body: Statement, - ) -> Self { + pub fn function(parameters: Vec<(Identifier, Type)>, return_type: Type, body: Block) -> Self { Value(Arc::new(ValueInner::Function(Function::Parsed( ParsedFunction { parameters, @@ -369,7 +365,7 @@ impl Function { pub struct ParsedFunction { parameters: Vec<(Identifier, Type)>, return_type: Type, - body: Statement, + body: Block, } #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] diff --git a/tests/functions.rs b/tests/functions.rs index 2c17fbe..bd905f6 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -38,7 +38,7 @@ fn callback() { foobar = (cb : () -> str) : str { cb() } - foobar(() : str { 'Hiya' }) + foobar(() : str 'Hiya') ", ), Ok(Some(Value::string("Hiya".to_string()))) @@ -57,12 +57,12 @@ fn function_context_does_not_capture_values() { " x = 1 - foo = () : any { x } + foo = () : any { x } " ), Err(vec![Error::Validation { error: ValidationError::VariableNotFound(Identifier::new("x")), - span: (0..0).into() + span: (32..66).into() }]) ); diff --git a/tests/variables.rs b/tests/variables.rs index 32a33ce..5712c49 100644 --- a/tests/variables.rs +++ b/tests/variables.rs @@ -1,5 +1,5 @@ use dust_lang::{ - abstract_tree::{Expression, Identifier, Statement, Type}, + abstract_tree::{Block, Expression, Identifier, Statement, Type}, error::{Error, TypeCheckError, ValidationError}, *, }; @@ -41,7 +41,9 @@ fn function_variable() { Ok(Some(Value::function( vec![(Identifier::new("x"), Type::Integer)], Type::Integer, - Statement::Expression(Expression::Identifier(Identifier::new("x"))) + Block::new(vec![Statement::Expression(Expression::Identifier( + Identifier::new("x") + ))]) ))) ); }