diff --git a/src/abstract_tree/block.rs b/src/abstract_tree/block.rs index f8fb685..482ee8f 100644 --- a/src/abstract_tree/block.rs +++ b/src/abstract_tree/block.rs @@ -22,15 +22,23 @@ use crate::{ #[derive(Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)] pub struct Block { is_async: bool, + contains_return: bool, statements: Vec, } +impl Block { + pub fn contains_return(&self) -> bool { + self.contains_return + } +} + impl AbstractTree for Block { fn from_syntax(node: SyntaxNode, source: &str, context: &Context) -> Result { SyntaxError::expect_syntax_node(source, "block", node)?; let first_child = node.child(0).unwrap(); let is_async = first_child.kind() == "async"; + let mut contains_return = false; let statement_count = if is_async { node.child_count() - 3 @@ -46,12 +54,17 @@ impl AbstractTree for Block { if child_node.kind() == "statement" { let statement = Statement::from_syntax(child_node, source, &block_context)?; + if statement.is_return() { + contains_return = true; + } + statements.push(statement); } } Ok(Block { is_async, + contains_return, statements, }) } @@ -74,9 +87,13 @@ impl AbstractTree for Block { .enumerate() .find_map_first(|(index, statement)| { let result = statement.run(_source, _context); - let is_last_statement = index == statements.len() - 1; + let should_return = if self.contains_return { + statement.is_return() + } else { + index == statements.len() - 1 + }; - if is_last_statement { + if should_return { let get_write_lock = final_result.write(); match get_write_lock { @@ -92,22 +109,32 @@ impl AbstractTree for Block { }) .unwrap_or(final_result.into_inner().map_err(|_| RwLockError)?) } else { - let mut prev_result = None; + for (index, statement) in self.statements.iter().enumerate() { + if statement.is_return() { + return statement.run(_source, _context); + } - for statement in &self.statements { - prev_result = Some(statement.run(_source, _context)); + if index == self.statements.len() - 1 { + return statement.run(_source, _context); + } } - prev_result.unwrap_or(Ok(Value::none())) + Ok(Value::none()) } } fn expected_type(&self, _context: &Context) -> Result { - if let Some(statement) = self.statements.last() { - statement.expected_type(_context) - } else { - Ok(Type::None) + for (index, statement) in self.statements.iter().enumerate() { + if statement.is_return() { + return statement.expected_type(_context); + } + + if index == self.statements.len() - 1 { + return statement.expected_type(_context); + } } + + Ok(Type::None) } } diff --git a/src/abstract_tree/statement.rs b/src/abstract_tree/statement.rs index 41e8db2..7f85146 100644 --- a/src/abstract_tree/statement.rs +++ b/src/abstract_tree/statement.rs @@ -10,7 +10,13 @@ use crate::{ #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)] pub struct Statement { is_return: bool, - statement_inner: StatementKind, + statement_kind: StatementKind, +} + +impl Statement { + pub fn is_return(&self) -> bool { + self.is_return + } } impl AbstractTree for Statement { @@ -22,35 +28,43 @@ impl AbstractTree for Statement { SyntaxError::expect_syntax_node(source, "statement", node)?; let first_child = node.child(0).unwrap(); - let is_return = first_child.kind() == "return"; + let mut is_return = first_child.kind() == "return"; let child = if is_return { node.child(1).unwrap() } else { first_child }; + let statement_kind = StatementKind::from_syntax(child, source, _context)?; + + if let StatementKind::Block(block) = &statement_kind { + if block.contains_return() { + is_return = true; + } + }; + Ok(Statement { is_return, - statement_inner: StatementKind::from_syntax(child, source, _context)?, + statement_kind, }) } fn expected_type(&self, _context: &Context) -> Result { - self.statement_inner.expected_type(_context) + self.statement_kind.expected_type(_context) } fn validate(&self, _source: &str, _context: &Context) -> Result<(), ValidationError> { - self.statement_inner.validate(_source, _context) + self.statement_kind.validate(_source, _context) } fn run(&self, _source: &str, _context: &Context) -> Result { - self.statement_inner.run(_source, _context) + self.statement_kind.run(_source, _context) } } impl Format for Statement { fn format(&self, _output: &mut String, _indent_level: u8) { - self.statement_inner.format(_output, _indent_level) + self.statement_kind.format(_output, _indent_level) } }