diff --git a/src/abstract_tree/function_call.rs b/src/abstract_tree/function_call.rs
index 815c627..621ff2e 100644
--- a/src/abstract_tree/function_call.rs
+++ b/src/abstract_tree/function_call.rs
@@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize};
use crate::{
built_in_functions::Callable,
error::{RuntimeError, SyntaxError, ValidationError},
- AbstractTree, Context, Expression, Format, FunctionExpression, SourcePosition, SyntaxNode,
- Type, Value,
+ AbstractTree, Context, Expression, Format, Function, FunctionExpression, SourcePosition,
+ SyntaxNode, Type, Value,
};
/// A function being invoked and the arguments it is being passed.
@@ -94,17 +94,16 @@ impl AbstractTree for FunctionCall {
fn validate(&self, _source: &str, context: &Context) -> Result<(), ValidationError> {
let function_expression_type = self.function_expression.expected_type(context)?;
- let parameter_types = match function_expression_type {
- Type::Function {
- parameter_types, ..
- } => parameter_types,
- Type::Any => return Ok(()),
- _ => {
- return Err(ValidationError::TypeCheckExpectedFunction {
- actual: function_expression_type,
- position: self.syntax_position,
- });
- }
+ let parameter_types = if let Type::Function {
+ parameter_types, ..
+ } = function_expression_type
+ {
+ parameter_types
+ } else {
+ return Err(ValidationError::TypeCheckExpectedFunction {
+ actual: function_expression_type,
+ position: self.syntax_position,
+ });
};
if self.arguments.len() != parameter_types.len() {
@@ -156,7 +155,7 @@ impl AbstractTree for FunctionCall {
let function = value.as_function()?;
match function {
- crate::Function::BuiltIn(built_in_function) => {
+ Function::BuiltIn(built_in_function) => {
let mut arguments = Vec::with_capacity(self.arguments.len());
for expression in &self.arguments {
@@ -167,7 +166,7 @@ impl AbstractTree for FunctionCall {
built_in_function.call(&arguments, source, &self.context)
}
- crate::Function::ContextDefined(function_node) => {
+ Function::ContextDefined(function_node) => {
let parameter_expression_pairs =
function_node.parameters().iter().zip(self.arguments.iter());
diff --git a/src/abstract_tree/function_node.rs b/src/abstract_tree/function_node.rs
index b498e63..04b5197 100644
--- a/src/abstract_tree/function_node.rs
+++ b/src/abstract_tree/function_node.rs
@@ -109,13 +109,19 @@ impl AbstractTree for FunctionNode {
Ok(self.r#type().clone())
}
- fn validate(&self, source: &str, context: &Context) -> Result<(), ValidationError> {
+ fn validate(&self, source: &str, _context: &Context) -> Result<(), ValidationError> {
if let Type::Function {
- parameter_types: _,
+ parameter_types,
return_type,
} = &self.r#type
{
- let actual = self.body.expected_type(context)?;
+ let validation_context = Context::new();
+
+ for (parameter, r#type) in self.parameters.iter().zip(parameter_types.iter()) {
+ validation_context.set_type(parameter.inner().clone(), r#type.clone())?;
+ }
+
+ let actual = self.body.expected_type(&validation_context)?;
if !return_type.accepts(&actual) {
return Err(ValidationError::TypeCheck {
@@ -125,7 +131,7 @@ impl AbstractTree for FunctionNode {
});
}
- self.body.validate(source, context)?;
+ self.body.validate(source, &validation_context)?;
Ok(())
} else {
diff --git a/src/context.rs b/src/context.rs
index 25418d0..7c362e6 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -53,13 +53,11 @@ impl Context {
pub fn get_value(&self, key: &str) -> Result