diff --git a/src/abstract_tree/assignment.rs b/src/abstract_tree/assignment.rs index 5b0b8b9..81b2803 100644 --- a/src/abstract_tree/assignment.rs +++ b/src/abstract_tree/assignment.rs @@ -42,8 +42,6 @@ impl AbstractTree for Assignment { } fn validate(&self, context: &Context) -> Result<(), ValidationError> { - self.statement.validate(context)?; - let statement_type = self.statement.expected_type(context)?; if let Some(expected) = &self.r#type { @@ -54,6 +52,8 @@ impl AbstractTree for Assignment { context.set_type(self.identifier.clone(), statement_type)?; } + self.statement.validate(context)?; + Ok(()) } diff --git a/src/abstract_tree/value_node.rs b/src/abstract_tree/value_node.rs index acd9656..8ddb0b6 100644 --- a/src/abstract_tree/value_node.rs +++ b/src/abstract_tree/value_node.rs @@ -84,6 +84,8 @@ impl AbstractTree for ValueNode { function_context.set_type(identifier.clone(), r#type.clone())?; } + body.validate(&function_context)?; + let actual_return_type = body.expected_type(&function_context)?; return_type.check(&actual_return_type)?; diff --git a/src/context.rs b/src/context.rs index c8f0a25..3bbccb3 100644 --- a/src/context.rs +++ b/src/context.rs @@ -6,7 +6,7 @@ use std::{ use crate::{ abstract_tree::{Identifier, Type}, error::RwLockPoisonError, - value::BuiltInFunction, + value::{BuiltInFunction, ValueInner}, Value, }; @@ -37,8 +37,10 @@ impl Context { let mut new_data = BTreeMap::new(); for (identifier, value_data) in other.inner.read()?.iter() { - if let ValueData::Type(_) = value_data { - new_data.insert(identifier.clone(), value_data.clone()); + if let ValueData::Type(r#type) = value_data { + if let Type::Function { .. } = r#type { + new_data.insert(identifier.clone(), value_data.clone()); + } } } @@ -49,7 +51,16 @@ impl Context { let mut new_data = BTreeMap::new(); for (identifier, value_data) in other.inner.read()?.iter() { - new_data.insert(identifier.clone(), value_data.clone()); + if let ValueData::Type(r#type) = value_data { + if let Type::Function { .. } = r#type { + new_data.insert(identifier.clone(), value_data.clone()); + } + } + if let ValueData::Value(value) = value_data { + if let ValueInner::Function { .. } = value.inner().as_ref() { + new_data.insert(identifier.clone(), value_data.clone()); + } + } } Ok(Self::with_data(new_data)) diff --git a/tests/functions.rs b/tests/functions.rs index 7cbda53..b0060e4 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -94,8 +94,6 @@ fn function_context_captures_functions() { #[test] fn recursion() { - env_logger::builder().is_test(true).try_init().unwrap(); - assert_eq!( interpret( "