From f3fe03a95fc45e14937e95fcb58365ddf662d105 Mon Sep 17 00:00:00 2001 From: Jeff Date: Thu, 11 Jul 2024 17:22:30 -0400 Subject: [PATCH] Fix function context bug --- dust-lang/src/abstract_tree/value_node.rs | 36 +++++++++++++++++++---- dust-lang/src/value.rs | 16 ++++++++-- dust-lang/tests/functions.rs | 16 +++++----- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/dust-lang/src/abstract_tree/value_node.rs b/dust-lang/src/abstract_tree/value_node.rs index afb20d5..417e91a 100644 --- a/dust-lang/src/abstract_tree/value_node.rs +++ b/dust-lang/src/abstract_tree/value_node.rs @@ -53,6 +53,7 @@ impl ValueNode { value_parameters, return_type, body, + context: Context::new(), }) } } @@ -157,17 +158,20 @@ impl AbstractNode for ValueNode { body, type_parameters, value_parameters, + context: function_context, }) = self { + function_context.inherit_variables_from(context)?; + if let Some(type_parameters) = type_parameters { for identifier in type_parameters { - context.set_type( + function_context.set_type( identifier.clone(), Type::Generic { identifier: identifier.clone(), concrete_type: None, }, - body.position, + (0, usize::MAX).into(), )?; } } @@ -176,15 +180,19 @@ impl AbstractNode for ValueNode { for (identifier, type_constructor) in value_parameters { let r#type = type_constructor.clone().construct(context)?; - context.set_type(identifier.clone(), r#type, body.position)?; + function_context.set_type( + identifier.clone(), + r#type, + (0, usize::MAX).into(), + )?; } } body.node - .define_and_validate(context, _manage_memory, scope)?; + .define_and_validate(function_context, _manage_memory, scope)?; let ((expected_return, expected_position), actual_return) = - match (return_type, body.node.expected_type(context)?) { + match (return_type, body.node.expected_type(function_context)?) { (Some(constructor), Some(r#type)) => ( (constructor.construct(context)?, constructor.position()), r#type, @@ -697,6 +705,7 @@ impl Display for ValueNode { value_parameters, return_type, body, + .. }) => { write!(f, "fn ")?; @@ -730,12 +739,27 @@ impl Display for ValueNode { } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct FunctionNode { type_parameters: Option>, value_parameters: Option>, return_type: Option, body: WithPosition, + + #[serde(skip)] + context: Context, +} + +impl Clone for FunctionNode { + fn clone(&self) -> Self { + FunctionNode { + type_parameters: self.type_parameters.clone(), + value_parameters: self.value_parameters.clone(), + return_type: self.return_type.clone(), + body: self.body.clone(), + context: Context::new(), + } + } } impl PartialEq for FunctionNode { diff --git a/dust-lang/src/value.rs b/dust-lang/src/value.rs index 01e305e..9828b57 100644 --- a/dust-lang/src/value.rs +++ b/dust-lang/src/value.rs @@ -743,7 +743,7 @@ impl Ord for ValueInner { } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Function { type_parameters: Option>, value_parameters: Option>, @@ -820,13 +820,23 @@ impl Function { debug!("Calling function"); - self.body - .define_and_validate(&self.context, false, SourcePosition(0, usize::MAX))?; self.body .evaluate(&self.context, false, SourcePosition(0, usize::MAX)) } } +impl Clone for Function { + fn clone(&self) -> Self { + Function { + type_parameters: self.type_parameters.clone(), + value_parameters: self.value_parameters.clone(), + return_type: self.return_type.clone(), + body: self.body.clone(), + context: Context::new(), + } + } +} + impl Eq for Function {} impl PartialEq for Function { diff --git a/dust-lang/tests/functions.rs b/dust-lang/tests/functions.rs index 11e4535..0b357dd 100644 --- a/dust-lang/tests/functions.rs +++ b/dust-lang/tests/functions.rs @@ -7,15 +7,15 @@ fn function_scope() { "test", " x = 2 - + foo = fn () -> int { x = 42 x } - + x = 1 - foo() + foo() " ), Ok(Some(Value::integer(42))) @@ -96,16 +96,14 @@ fn recursion() { "test", " fib = fn (i: int) -> int { - if i < 0 { - 0 - } else if i <= 1 { - i - } else { + if i <= 1 { + i + } else { fib(i - 1) + fib(i - 2) } } - fib(8) + fib(7) " ), Ok(Some(Value::integer(13)))