Fix function validation and parsing

This commit is contained in:
Jeff 2024-03-09 12:58:29 -05:00
parent e272d99bae
commit 2dd1628bca
8 changed files with 238 additions and 207 deletions

View File

@ -42,6 +42,8 @@ impl AbstractTree for Assignment {
} }
fn validate(&self, context: &Context) -> Result<(), ValidationError> { fn validate(&self, context: &Context) -> Result<(), ValidationError> {
self.statement.validate(context)?;
let statement_type = self.statement.expected_type(context)?; let statement_type = self.statement.expected_type(context)?;
if let Some(expected) = &self.r#type { if let Some(expected) = &self.r#type {

View File

@ -1,7 +1,6 @@
use crate::{ use crate::{
context::Context, context::Context,
error::{RuntimeError, ValidationError}, error::{RuntimeError, ValidationError},
Value,
}; };
use super::{AbstractTree, Action, Statement, Type}; use super::{AbstractTree, Action, Statement, Type};
@ -19,9 +18,11 @@ impl Block {
impl AbstractTree for Block { impl AbstractTree for Block {
fn expected_type(&self, _context: &Context) -> Result<Type, ValidationError> { fn expected_type(&self, _context: &Context) -> Result<Type, ValidationError> {
let final_statement = self.statements.last().unwrap(); if let Some(statement) = self.statements.last() {
statement.expected_type(_context)
final_statement.expected_type(_context) } else {
Ok(Type::None)
}
} }
fn validate(&self, _context: &Context) -> Result<(), ValidationError> { fn validate(&self, _context: &Context) -> Result<(), ValidationError> {
@ -33,23 +34,26 @@ impl AbstractTree for Block {
} }
fn run(self, _context: &Context) -> Result<Action, RuntimeError> { fn run(self, _context: &Context) -> Result<Action, RuntimeError> {
let mut previous = Value::none(); let mut previous = Action::None;
for statement in self.statements { for statement in self.statements {
let action = statement.run(_context)?; let action = statement.run(_context)?;
previous = match action { previous = match action {
Action::Return(value) => value, Action::Return(value) => Action::Return(value),
r#break => return Ok(r#break), r#break => return Ok(r#break),
}; };
} }
Ok(Action::Return(previous)) Ok(previous)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::abstract_tree::{Expression, ValueNode}; use crate::{
abstract_tree::{Expression, ValueNode},
Value,
};
use super::*; use super::*;

View File

@ -6,7 +6,7 @@ use crate::{
Value, Value,
}; };
use super::{AbstractTree, Action, Expression, Identifier, Statement, Type}; use super::{AbstractTree, Action, Block, Expression, Identifier, Type};
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum ValueNode { pub enum ValueNode {
@ -21,7 +21,7 @@ pub enum ValueNode {
Function { Function {
parameters: Vec<(Identifier, Type)>, parameters: Vec<(Identifier, Type)>,
return_type: Type, return_type: Type,
body: Box<Statement>, body: Block,
}, },
} }
@ -72,6 +72,23 @@ impl AbstractTree for ValueNode {
} }
} }
if let ValueNode::Function {
parameters,
return_type,
body,
} = self
{
let function_context = Context::new();
for (identifier, r#type) in parameters {
function_context.set_type(identifier.clone(), r#type.clone())?;
}
let actual_return_type = body.expected_type(&function_context)?;
return_type.check(&actual_return_type)?;
}
Ok(()) Ok(())
} }
@ -115,7 +132,7 @@ impl AbstractTree for ValueNode {
parameters, parameters,
return_type, return_type,
body, body,
} => Value::function(parameters, return_type, *body), } => Value::function(parameters, return_type, body),
}; };
Ok(Action::Return(value)) Ok(Action::Return(value))

View File

@ -210,6 +210,7 @@ pub fn lexer<'src>() -> impl Parser<
.map(Token::Control); .map(Token::Control);
let keyword = choice(( let keyword = choice((
just("any").padded(),
just("bool").padded(), just("bool").padded(),
just("break").padded(), just("break").padded(),
just("else").padded(), just("else").padded(),
@ -218,6 +219,7 @@ pub fn lexer<'src>() -> impl Parser<
just("if").padded(), just("if").padded(),
just("list").padded(), just("list").padded(),
just("map").padded(), just("map").padded(),
just("none").padded(),
just("range").padded(), just("range").padded(),
just("str").padded(), just("str").padded(),
just("loop").padded(), just("loop").padded(),

View File

@ -47,10 +47,21 @@ pub fn parser<'src>() -> DustParser<'src> {
} }
}; };
let basic_value = select! {
Token::Boolean(boolean) => ValueNode::Boolean(boolean),
Token::Integer(integer) => ValueNode::Integer(integer),
Token::Float(float) => ValueNode::Float(float),
Token::String(string) => ValueNode::String(string.to_string()),
}
.map(|value| Expression::Value(value))
.boxed();
let basic_type = choice(( let basic_type = choice((
just(Token::Keyword("any")).to(Type::Any),
just(Token::Keyword("bool")).to(Type::Boolean), just(Token::Keyword("bool")).to(Type::Boolean),
just(Token::Keyword("float")).to(Type::Float), just(Token::Keyword("float")).to(Type::Float),
just(Token::Keyword("int")).to(Type::Integer), just(Token::Keyword("int")).to(Type::Integer),
just(Token::Keyword("none")).to(Type::None),
just(Token::Keyword("range")).to(Type::Range), just(Token::Keyword("range")).to(Type::Range),
just(Token::Keyword("str")).to(Type::String), just(Token::Keyword("str")).to(Type::String),
just(Token::Keyword("list")).to(Type::List), just(Token::Keyword("list")).to(Type::List),
@ -80,16 +91,18 @@ pub fn parser<'src>() -> DustParser<'src> {
.map(|identifier| Type::Custom(identifier)), .map(|identifier| Type::Custom(identifier)),
))); )));
let expression = recursive(|expression| { let statement = recursive(|statement| {
let basic_value = select! { let block = statement
Token::Boolean(boolean) => ValueNode::Boolean(boolean), .clone()
Token::Integer(integer) => ValueNode::Integer(integer), .repeated()
Token::Float(float) => ValueNode::Float(float), .collect()
Token::String(string) => ValueNode::String(string.to_string()), .delimited_by(
} just(Token::Control(Control::CurlyOpen)),
.map(|value| Expression::Value(value)) just(Token::Control(Control::CurlyClose)),
.boxed(); )
.map(|statements| Block::new(statements));
let expression = recursive(|expression| {
let identifier_expression = identifier let identifier_expression = identifier
.clone() .clone()
.map(|identifier| Expression::Identifier(identifier)) .map(|identifier| Expression::Identifier(identifier))
@ -143,7 +156,46 @@ pub fn parser<'src>() -> DustParser<'src> {
.map(|(name, variant)| Expression::Value(ValueNode::Enum(name, variant))) .map(|(name, variant)| Expression::Value(ValueNode::Enum(name, variant)))
.boxed(); .boxed();
let function = identifier
.clone()
.then(type_specification.clone())
.separated_by(just(Token::Control(Control::Comma)))
.collect::<Vec<(Identifier, Type)>>()
.delimited_by(
just(Token::Control(Control::ParenOpen)),
just(Token::Control(Control::ParenClose)),
)
.then(type_specification.clone())
.then(block.clone())
.map(|((parameters, return_type), body)| {
Expression::Value(ValueNode::Function {
parameters,
return_type,
body,
})
})
.boxed();
let function_expression = choice((identifier_expression.clone(), function.clone()));
let function_call = function_expression
.then(
expression
.clone()
.separated_by(just(Token::Control(Control::Comma)))
.collect()
.delimited_by(
just(Token::Control(Control::ParenOpen)),
just(Token::Control(Control::ParenClose)),
),
)
.map(|(function, arguments)| {
Expression::FunctionCall(FunctionCall::new(function, arguments))
})
.boxed();
let atom = choice(( let atom = choice((
function_call,
identifier_expression.clone(), identifier_expression.clone(),
basic_value.clone(), basic_value.clone(),
list.clone(), list.clone(),
@ -176,7 +228,9 @@ pub fn parser<'src>() -> DustParser<'src> {
infix( infix(
left(1), left(1),
just(Token::Operator(GreaterOrEqual)), just(Token::Operator(GreaterOrEqual)),
|left, right| Expression::Logic(Box::new(Logic::GreaterOrEqual(left, right))), |left, right| {
Expression::Logic(Box::new(Logic::GreaterOrEqual(left, right)))
},
), ),
infix( infix(
left(1), left(1),
@ -213,6 +267,7 @@ pub fn parser<'src>() -> DustParser<'src> {
.boxed(); .boxed();
choice(( choice((
function,
range, range,
r#enum, r#enum,
logic_math_and_index, logic_math_and_index,
@ -221,9 +276,9 @@ pub fn parser<'src>() -> DustParser<'src> {
map, map,
basic_value, basic_value,
)) ))
.boxed()
}); });
let statement = recursive(|statement| {
let expression_statement = expression let expression_statement = expression
.clone() .clone()
.map(|expression| Statement::Expression(expression)) .map(|expression| Statement::Expression(expression))
@ -247,16 +302,7 @@ pub fn parser<'src>() -> DustParser<'src> {
}) })
.boxed(); .boxed();
let block = statement let block_statement = block.clone().map(|block| Statement::Block(block));
.clone()
.repeated()
.collect()
.delimited_by(
just(Token::Control(Control::CurlyOpen)),
just(Token::Control(Control::CurlyClose)),
)
.map(|statements| Statement::Block(Block::new(statements)))
.boxed();
let r#loop = statement let r#loop = statement
.clone() .clone()
@ -283,54 +329,16 @@ pub fn parser<'src>() -> DustParser<'src> {
}) })
.boxed(); .boxed();
let function = identifier
.clone()
.then(type_specification.clone())
.separated_by(just(Token::Control(Control::Comma)))
.collect::<Vec<(Identifier, Type)>>()
.delimited_by(
just(Token::Control(Control::ParenOpen)),
just(Token::Control(Control::ParenClose)),
)
.then(type_specification)
.then(statement.clone())
.map(|((parameters, return_type), body)| {
Statement::Expression(Expression::Value(ValueNode::Function {
parameters,
return_type,
body: Box::new(body),
}))
});
let function_call = expression
.clone()
.then(
expression
.clone()
.separated_by(just(Token::Control(Control::Comma)))
.collect()
.delimited_by(
just(Token::Control(Control::ParenOpen)),
just(Token::Control(Control::ParenClose)),
),
)
.map(|(function, arguments)| {
Statement::Expression(Expression::FunctionCall(FunctionCall::new(
function, arguments,
)))
});
choice(( choice((
function_call,
assignment, assignment,
expression_statement, expression_statement,
r#break, r#break,
block, block_statement,
r#loop, r#loop,
if_else, if_else,
function,
)) ))
.then_ignore(just(Token::Control(Control::Semicolon)).or_not()) .then_ignore(just(Token::Control(Control::Semicolon)).or_not())
.boxed()
}); });
statement statement
@ -368,13 +376,13 @@ mod tests {
#[test] #[test]
fn function() { fn function() {
assert_eq!( assert_eq!(
parse(&lex("(x: int): int x ").unwrap()).unwrap()[0].0, parse(&lex("(x: int): int { x }").unwrap()).unwrap()[0].0,
Statement::Expression(Expression::Value(ValueNode::Function { Statement::Expression(Expression::Value(ValueNode::Function {
parameters: vec![(Identifier::new("x"), Type::Integer)], parameters: vec![(Identifier::new("x"), Type::Integer)],
return_type: Type::Integer, return_type: Type::Integer,
body: Box::new(Statement::Expression(Expression::Identifier( body: Block::new(vec![Statement::Expression(Expression::Identifier(
Identifier::new("x") Identifier::new("x")
))) ))])
})) }))
) )
} }

View File

@ -13,7 +13,7 @@ use stanza::{
}; };
use crate::{ use crate::{
abstract_tree::{AbstractTree, Action, Identifier, Statement, Type}, abstract_tree::{AbstractTree, Action, Block, Identifier, Type},
context::Context, context::Context,
error::{RuntimeError, ValidationError}, error::{RuntimeError, ValidationError},
}; };
@ -73,11 +73,7 @@ impl Value {
Value(Arc::new(ValueInner::Enum(name, variant))) Value(Arc::new(ValueInner::Enum(name, variant)))
} }
pub fn function( pub fn function(parameters: Vec<(Identifier, Type)>, return_type: Type, body: Block) -> Self {
parameters: Vec<(Identifier, Type)>,
return_type: Type,
body: Statement,
) -> Self {
Value(Arc::new(ValueInner::Function(Function::Parsed( Value(Arc::new(ValueInner::Function(Function::Parsed(
ParsedFunction { ParsedFunction {
parameters, parameters,
@ -369,7 +365,7 @@ impl Function {
pub struct ParsedFunction { pub struct ParsedFunction {
parameters: Vec<(Identifier, Type)>, parameters: Vec<(Identifier, Type)>,
return_type: Type, return_type: Type,
body: Statement, body: Block,
} }
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]

View File

@ -38,7 +38,7 @@ fn callback() {
foobar = (cb : () -> str) : str { foobar = (cb : () -> str) : str {
cb() cb()
} }
foobar(() : str { 'Hiya' }) foobar(() : str 'Hiya')
", ",
), ),
Ok(Some(Value::string("Hiya".to_string()))) Ok(Some(Value::string("Hiya".to_string())))
@ -62,7 +62,7 @@ fn function_context_does_not_capture_values() {
), ),
Err(vec![Error::Validation { Err(vec![Error::Validation {
error: ValidationError::VariableNotFound(Identifier::new("x")), error: ValidationError::VariableNotFound(Identifier::new("x")),
span: (0..0).into() span: (32..66).into()
}]) }])
); );

View File

@ -1,5 +1,5 @@
use dust_lang::{ use dust_lang::{
abstract_tree::{Expression, Identifier, Statement, Type}, abstract_tree::{Block, Expression, Identifier, Statement, Type},
error::{Error, TypeCheckError, ValidationError}, error::{Error, TypeCheckError, ValidationError},
*, *,
}; };
@ -41,7 +41,9 @@ fn function_variable() {
Ok(Some(Value::function( Ok(Some(Value::function(
vec![(Identifier::new("x"), Type::Integer)], vec![(Identifier::new("x"), Type::Integer)],
Type::Integer, Type::Integer,
Statement::Expression(Expression::Identifier(Identifier::new("x"))) Block::new(vec![Statement::Expression(Expression::Identifier(
Identifier::new("x")
))])
))) )))
); );
} }