diff --git a/examples/fibonacci.ds b/examples/fibonacci.ds index 17ab762..7d47b8f 100644 --- a/examples/fibonacci.ds +++ b/examples/fibonacci.ds @@ -1,4 +1,4 @@ -fib = (i , fib <(int) -> int>) { +fib = (i ) { if i <= 1 { 1 } else { @@ -6,4 +6,4 @@ fib = (i , fib <(int) -> int>) { } } -fib(8, fib) +fib(8) diff --git a/src/abstract_tree/assignment.rs b/src/abstract_tree/assignment.rs index 586546b..c082f58 100644 --- a/src/abstract_tree/assignment.rs +++ b/src/abstract_tree/assignment.rs @@ -97,9 +97,7 @@ impl AbstractTree for Assignment { } } - self.statement - .check_type(source, context) - .map_err(|error| error.at_source_position(source, self.syntax_position))?; + self.statement.check_type(source, context)?; Ok(()) } diff --git a/src/abstract_tree/function_node.rs b/src/abstract_tree/function_node.rs index 913d7e1..91609c3 100644 --- a/src/abstract_tree/function_node.rs +++ b/src/abstract_tree/function_node.rs @@ -61,6 +61,7 @@ impl FunctionNode { pub fn call(&self, arguments: &[Value], source: &str, outer_context: &Map) -> Result { let function_context = Map::new(); + let parameter_argument_pairs = self.parameters.iter().zip(arguments.iter()); for (key, (value, r#type)) in outer_context.variables()?.iter() { if r#type.is_function() { @@ -68,8 +69,6 @@ impl FunctionNode { } } - let parameter_argument_pairs = self.parameters.iter().zip(arguments.iter()); - for (identifier, value) in parameter_argument_pairs { let key = identifier.inner().clone(); @@ -116,6 +115,12 @@ impl AbstractTree for FunctionNode { function_context.set_type(parameter.inner().clone(), parameter_type.clone())?; } + for (key, (value, r#type)) in outer_context.variables()?.iter() { + if r#type.is_function() { + function_context.set(key.clone(), value.clone())?; + } + } + let body_node = node.child(child_count - 1).unwrap(); let body = Block::from_syntax(body_node, source, &function_context)?; @@ -130,32 +135,35 @@ impl AbstractTree for FunctionNode { }) } - fn check_type(&self, source: &str, _context: &Map) -> Result<()> { + fn check_type(&self, source: &str, context: &Map) -> Result<()> { let function_context = Map::new(); + for (key, (_value, r#type)) in context.variables()?.iter() { + if r#type.is_function() { + function_context.set_type(key.clone(), r#type.clone())?; + } + } + if let Type::Function { parameter_types, - return_type: _, + return_type, } = &self.r#type { for (parameter, parameter_type) in self.parameters.iter().zip(parameter_types.iter()) { function_context.set_type(parameter.inner().clone(), parameter_type.clone())?; } - self.return_type() + return_type .check(&self.body.expected_type(&function_context)?) .map_err(|error| error.at_source_position(source, self.syntax_position))?; + self.body.check_type(source, &function_context)?; + + Ok(()) } else { - return Err(Error::TypeCheckExpectedFunction { + Err(Error::TypeCheckExpectedFunction { actual: self.r#type.clone(), - }); - }; - - self.body - .check_type(source, &function_context) - .map_err(|error| error.at_source_position(source, self.syntax_position))?; - - Ok(()) + }) + } } fn run(&self, _source: &str, _context: &Map) -> Result { diff --git a/tests/functions.rs b/tests/functions.rs index e2263c3..25b8067 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -137,13 +137,13 @@ fn recursion() { if i <= 1 { 1 } else { - self(i - 1) + self(i - 2) + fib(i - 1) + fib(i - 2) } } - fib(3) + fib(8) " ), - Ok(Value::Integer(3)) + Ok(Value::Integer(34)) ); }