diff --git a/src/abstract_tree/assignment.rs b/src/abstract_tree/assignment.rs index 963eaf6..e9cfec9 100644 --- a/src/abstract_tree/assignment.rs +++ b/src/abstract_tree/assignment.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use crate::{ context::Context, error::{RuntimeError, SyntaxError, ValidationError}, - AbstractTree, AssignmentOperator, Format, Identifier, SourcePosition, Statement, SyntaxNode, - Type, TypeSpecification, Value, + AbstractTree, AssignmentOperator, Format, Function, Identifier, SourcePosition, Statement, + SyntaxNode, Type, TypeSpecification, Value, }; /// Variable assignment, including add-assign and subtract-assign operations. @@ -149,6 +149,12 @@ impl AbstractTree for Assignment { AssignmentOperator::Equal => right, }; + if let Value::Function(Function::ContextDefined(function_node)) = &new_value { + function_node + .context() + .set_value(self.identifier.clone(), new_value.clone())?; + } + context.set_value(self.identifier.clone(), new_value)?; Ok(Value::none()) diff --git a/src/abstract_tree/function_call.rs b/src/abstract_tree/function_call.rs index 74b897c..1e538fe 100644 --- a/src/abstract_tree/function_call.rs +++ b/src/abstract_tree/function_call.rs @@ -94,6 +94,8 @@ impl AbstractTree for FunctionCall { } fn validate(&self, _source: &str, context: &Context) -> Result<(), ValidationError> { + self.function_expression.validate(_source, context)?; + let function_expression_type = self.function_expression.expected_type(context)?; let parameter_types = if let Type::Function { @@ -117,6 +119,8 @@ impl AbstractTree for FunctionCall { } for (index, expression) in self.arguments.iter().enumerate() { + expression.validate(_source, context)?; + if let Some(expected) = parameter_types.get(index) { let actual = expression.expected_type(context)?; @@ -178,8 +182,6 @@ impl AbstractTree for FunctionCall { call_context.set_value(identifier.clone(), value)?; } - println!("{}", call_context); - function_node.body().run(source, &call_context) } } diff --git a/src/abstract_tree/if_else.rs b/src/abstract_tree/if_else.rs index aa9a807..d0ac50a 100644 --- a/src/abstract_tree/if_else.rs +++ b/src/abstract_tree/if_else.rs @@ -89,6 +89,8 @@ impl AbstractTree for IfElse { } if let Some(block) = &self.else_block { + block.validate(_source, context)?; + let actual = block.expected_type(context)?; if !expected.accepts(&actual) { diff --git a/src/context/mod.rs b/src/context/mod.rs index de14843..cd184a0 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -91,8 +91,11 @@ impl Context { pub fn with_variables_from(other: &Context) -> Result { let mut new_variables = BTreeMap::new(); - for (identifier, value_data) in other.inner.read()?.iter() { - new_variables.insert(identifier.clone(), value_data.clone()); + for (identifier, (value_data, counter)) in other.inner.read()?.iter() { + let (allowances, _runtime_uses) = counter.get_counts()?; + let new_counter = UsageCounter::with_counts(allowances, 0); + + new_variables.insert(identifier.clone(), (value_data.clone(), new_counter)); } Ok(Context { @@ -123,30 +126,20 @@ impl Context { pub fn inherit_from(&self, other: &Context) -> Result<(), RwLockError> { let mut self_variables = self.inner.write()?; - for (identifier, (value_data, _counter)) in other.inner.read()?.iter() { + for (identifier, (value_data, counter)) in other.inner.read()?.iter() { + let (allowances, _runtime_uses) = counter.get_counts()?; + let new_counter = UsageCounter::with_counts(allowances, 0); + if let ValueData::Value(value) = value_data { if value.is_function() { - self_variables.insert( - identifier.clone(), - (value_data.clone(), UsageCounter::new()), - ); + self_variables.insert(identifier.clone(), (value_data.clone(), new_counter)); } - } - - if let ValueData::TypeHint(r#type) = value_data { + } else if let ValueData::TypeHint(r#type) = value_data { if r#type.is_function() { - self_variables.insert( - identifier.clone(), - (value_data.clone(), UsageCounter::new()), - ); + self_variables.insert(identifier.clone(), (value_data.clone(), new_counter)); } - } - - if let ValueData::TypeDefinition(_) = value_data { - self_variables.insert( - identifier.clone(), - (value_data.clone(), UsageCounter::new()), - ); + } else if let ValueData::TypeDefinition(_) = value_data { + self_variables.insert(identifier.clone(), (value_data.clone(), new_counter)); } } diff --git a/src/context/usage_counter.rs b/src/context/usage_counter.rs index 53b268d..02a01f5 100644 --- a/src/context/usage_counter.rs +++ b/src/context/usage_counter.rs @@ -16,6 +16,13 @@ impl UsageCounter { }))) } + pub fn with_counts(allowances: usize, runtime_uses: usize) -> UsageCounter { + UsageCounter(Arc::new(RwLock::new(UsageCounterInner { + allowances, + runtime_uses, + }))) + } + pub fn get_counts(&self) -> Result<(usize, usize), RwLockError> { let inner = self.0.read()?; Ok((inner.allowances, inner.runtime_uses)) diff --git a/tests/functions.rs b/tests/functions.rs index 4bd427e..e52d5ad 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -115,6 +115,8 @@ fn function_context_captures_structure_definitions() { #[test] fn recursion() { + env_logger::builder().is_test(true).try_init().unwrap(); + assert_eq!( interpret( "