1
0

Implement block returns

This commit is contained in:
Jeff 2024-02-16 11:21:36 -05:00
parent 122d81f252
commit 8c4b2c9eef
2 changed files with 58 additions and 17 deletions

View File

@ -22,15 +22,23 @@ use crate::{
#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
pub struct Block {
is_async: bool,
contains_return: bool,
statements: Vec<Statement>,
}
impl Block {
pub fn contains_return(&self) -> bool {
self.contains_return
}
}
impl AbstractTree for Block {
fn from_syntax(node: SyntaxNode, source: &str, context: &Context) -> Result<Self, SyntaxError> {
SyntaxError::expect_syntax_node(source, "block", node)?;
let first_child = node.child(0).unwrap();
let is_async = first_child.kind() == "async";
let mut contains_return = false;
let statement_count = if is_async {
node.child_count() - 3
@ -46,12 +54,17 @@ impl AbstractTree for Block {
if child_node.kind() == "statement" {
let statement = Statement::from_syntax(child_node, source, &block_context)?;
if statement.is_return() {
contains_return = true;
}
statements.push(statement);
}
}
Ok(Block {
is_async,
contains_return,
statements,
})
}
@ -74,9 +87,13 @@ impl AbstractTree for Block {
.enumerate()
.find_map_first(|(index, statement)| {
let result = statement.run(_source, _context);
let is_last_statement = index == statements.len() - 1;
let should_return = if self.contains_return {
statement.is_return()
} else {
index == statements.len() - 1
};
if is_last_statement {
if should_return {
let get_write_lock = final_result.write();
match get_write_lock {
@ -92,22 +109,32 @@ impl AbstractTree for Block {
})
.unwrap_or(final_result.into_inner().map_err(|_| RwLockError)?)
} else {
let mut prev_result = None;
for statement in &self.statements {
prev_result = Some(statement.run(_source, _context));
for (index, statement) in self.statements.iter().enumerate() {
if statement.is_return() {
return statement.run(_source, _context);
}
prev_result.unwrap_or(Ok(Value::none()))
if index == self.statements.len() - 1 {
return statement.run(_source, _context);
}
}
Ok(Value::none())
}
}
fn expected_type(&self, _context: &Context) -> Result<Type, ValidationError> {
if let Some(statement) = self.statements.last() {
statement.expected_type(_context)
} else {
Ok(Type::None)
for (index, statement) in self.statements.iter().enumerate() {
if statement.is_return() {
return statement.expected_type(_context);
}
if index == self.statements.len() - 1 {
return statement.expected_type(_context);
}
}
Ok(Type::None)
}
}

View File

@ -10,7 +10,13 @@ use crate::{
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
pub struct Statement {
is_return: bool,
statement_inner: StatementKind,
statement_kind: StatementKind,
}
impl Statement {
pub fn is_return(&self) -> bool {
self.is_return
}
}
impl AbstractTree for Statement {
@ -22,35 +28,43 @@ impl AbstractTree for Statement {
SyntaxError::expect_syntax_node(source, "statement", node)?;
let first_child = node.child(0).unwrap();
let is_return = first_child.kind() == "return";
let mut is_return = first_child.kind() == "return";
let child = if is_return {
node.child(1).unwrap()
} else {
first_child
};
let statement_kind = StatementKind::from_syntax(child, source, _context)?;
if let StatementKind::Block(block) = &statement_kind {
if block.contains_return() {
is_return = true;
}
};
Ok(Statement {
is_return,
statement_inner: StatementKind::from_syntax(child, source, _context)?,
statement_kind,
})
}
fn expected_type(&self, _context: &Context) -> Result<Type, ValidationError> {
self.statement_inner.expected_type(_context)
self.statement_kind.expected_type(_context)
}
fn validate(&self, _source: &str, _context: &Context) -> Result<(), ValidationError> {
self.statement_inner.validate(_source, _context)
self.statement_kind.validate(_source, _context)
}
fn run(&self, _source: &str, _context: &Context) -> Result<Value, RuntimeError> {
self.statement_inner.run(_source, _context)
self.statement_kind.run(_source, _context)
}
}
impl Format for Statement {
fn format(&self, _output: &mut String, _indent_level: u8) {
self.statement_inner.format(_output, _indent_level)
self.statement_kind.format(_output, _indent_level)
}
}