From 3f4c4ff464f4a3e7157df0c1ed86b49de40d41b8 Mon Sep 17 00:00:00 2001 From: Jeff Date: Tue, 13 Feb 2024 08:10:34 -0500 Subject: [PATCH] Fix type checking bug --- src/abstract_tree/function_call.rs | 29 ++++++----- src/abstract_tree/function_node.rs | 14 ++++-- src/context.rs | 8 ++- src/main.rs | 81 +++++++++++++++--------------- tests/types.rs | 17 ++++++- 5 files changed, 82 insertions(+), 67 deletions(-) diff --git a/src/abstract_tree/function_call.rs b/src/abstract_tree/function_call.rs index 815c627..621ff2e 100644 --- a/src/abstract_tree/function_call.rs +++ b/src/abstract_tree/function_call.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use crate::{ built_in_functions::Callable, error::{RuntimeError, SyntaxError, ValidationError}, - AbstractTree, Context, Expression, Format, FunctionExpression, SourcePosition, SyntaxNode, - Type, Value, + AbstractTree, Context, Expression, Format, Function, FunctionExpression, SourcePosition, + SyntaxNode, Type, Value, }; /// A function being invoked and the arguments it is being passed. @@ -94,17 +94,16 @@ impl AbstractTree for FunctionCall { fn validate(&self, _source: &str, context: &Context) -> Result<(), ValidationError> { let function_expression_type = self.function_expression.expected_type(context)?; - let parameter_types = match function_expression_type { - Type::Function { - parameter_types, .. - } => parameter_types, - Type::Any => return Ok(()), - _ => { - return Err(ValidationError::TypeCheckExpectedFunction { - actual: function_expression_type, - position: self.syntax_position, - }); - } + let parameter_types = if let Type::Function { + parameter_types, .. + } = function_expression_type + { + parameter_types + } else { + return Err(ValidationError::TypeCheckExpectedFunction { + actual: function_expression_type, + position: self.syntax_position, + }); }; if self.arguments.len() != parameter_types.len() { @@ -156,7 +155,7 @@ impl AbstractTree for FunctionCall { let function = value.as_function()?; match function { - crate::Function::BuiltIn(built_in_function) => { + Function::BuiltIn(built_in_function) => { let mut arguments = Vec::with_capacity(self.arguments.len()); for expression in &self.arguments { @@ -167,7 +166,7 @@ impl AbstractTree for FunctionCall { built_in_function.call(&arguments, source, &self.context) } - crate::Function::ContextDefined(function_node) => { + Function::ContextDefined(function_node) => { let parameter_expression_pairs = function_node.parameters().iter().zip(self.arguments.iter()); diff --git a/src/abstract_tree/function_node.rs b/src/abstract_tree/function_node.rs index b498e63..04b5197 100644 --- a/src/abstract_tree/function_node.rs +++ b/src/abstract_tree/function_node.rs @@ -109,13 +109,19 @@ impl AbstractTree for FunctionNode { Ok(self.r#type().clone()) } - fn validate(&self, source: &str, context: &Context) -> Result<(), ValidationError> { + fn validate(&self, source: &str, _context: &Context) -> Result<(), ValidationError> { if let Type::Function { - parameter_types: _, + parameter_types, return_type, } = &self.r#type { - let actual = self.body.expected_type(context)?; + let validation_context = Context::new(); + + for (parameter, r#type) in self.parameters.iter().zip(parameter_types.iter()) { + validation_context.set_type(parameter.inner().clone(), r#type.clone())?; + } + + let actual = self.body.expected_type(&validation_context)?; if !return_type.accepts(&actual) { return Err(ValidationError::TypeCheck { @@ -125,7 +131,7 @@ impl AbstractTree for FunctionNode { }); } - self.body.validate(source, context)?; + self.body.validate(source, &validation_context)?; Ok(()) } else { diff --git a/src/context.rs b/src/context.rs index 25418d0..7c362e6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -53,13 +53,11 @@ impl Context { pub fn get_value(&self, key: &str) -> Result, RwLockError> { if let Some(value_data) = self.inner.read()?.get(key) { if let ValueData::Value { inner, .. } = value_data { - Ok(Some(inner.clone())) - } else { - Ok(None) + return Ok(Some(inner.clone())); } - } else { - Ok(None) } + + Ok(None) } pub fn get_type(&self, key: &str) -> Result, RwLockError> { diff --git a/src/main.rs b/src/main.rs index 2bafb38..112dfbd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -116,59 +116,59 @@ fn main() { } } -struct DustHighlighter { - context: Context, -} +// struct DustHighlighter { +// context: Context, +// } -impl DustHighlighter { - fn new(context: Context) -> Self { - Self { context } - } -} +// impl DustHighlighter { +// fn new(context: Context) -> Self { +// Self { context } +// } +// } -const HIGHLIGHT_TERMINATORS: [char; 8] = [' ', ':', '(', ')', '{', '}', '[', ']']; +// const HIGHLIGHT_TERMINATORS: [char; 8] = [' ', ':', '(', ')', '{', '}', '[', ']']; -impl Highlighter for DustHighlighter { - fn highlight(&self, line: &str, _cursor: usize) -> reedline::StyledText { - let mut styled = StyledText::new(); +// impl Highlighter for DustHighlighter { +// fn highlight(&self, line: &str, _cursor: usize) -> reedline::StyledText { +// let mut styled = StyledText::new(); - for word in line.split_inclusive(&HIGHLIGHT_TERMINATORS) { - let mut word_is_highlighted = false; +// for word in line.split_inclusive(&HIGHLIGHT_TERMINATORS) { +// let mut word_is_highlighted = false; - for key in self.context.inner().unwrap().keys() { - if key == &word { - styled.push((Style::new().bold(), word.to_string())); - } +// for key in self.context.inner().unwrap().keys() { +// if key == &word { +// styled.push((Style::new().bold(), word.to_string())); +// } - word_is_highlighted = true; - } +// word_is_highlighted = true; +// } - for built_in_value in built_in_values() { - if built_in_value.name() == word { - styled.push((Style::new().bold(), word.to_string())); - } +// for built_in_value in built_in_values() { +// if built_in_value.name() == word { +// styled.push((Style::new().bold(), word.to_string())); +// } - word_is_highlighted = true; - } +// word_is_highlighted = true; +// } - if word_is_highlighted { - let final_char = word.chars().last().unwrap(); +// if word_is_highlighted { +// let final_char = word.chars().last().unwrap(); - if HIGHLIGHT_TERMINATORS.contains(&final_char) { - let mut terminator_style = Style::new(); +// if HIGHLIGHT_TERMINATORS.contains(&final_char) { +// let mut terminator_style = Style::new(); - terminator_style.foreground = Some(Color::Cyan); +// terminator_style.foreground = Some(Color::Cyan); - styled.push((terminator_style, final_char.to_string())); - } - } else { - styled.push((Style::new(), word.to_string())); - } - } +// styled.push((terminator_style, final_char.to_string())); +// } +// } else { +// styled.push((Style::new(), word.to_string())); +// } +// } - styled - } -} +// styled +// } +// } struct StarshipPrompt { left: String, @@ -367,7 +367,6 @@ fn run_shell(context: Context) -> Result<(), Error> { let mut line_editor = Reedline::create() .with_edit_mode(edit_mode) .with_history(history) - .with_highlighter(Box::new(DustHighlighter::new(context.clone()))) .with_hinter(hinter) .use_kitty_keyboard_enhancement(true) .with_completer(Box::new(completer)) diff --git a/tests/types.rs b/tests/types.rs index 5076b2a..c0081d8 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -32,8 +32,21 @@ fn argument_count_check() { let result = interpret(source); assert_eq!( - "Expected 1 arguments, but got 0. Occured at (5, 12) to (5, 17). Source: foo()", - result.unwrap_err().to_string() + Err(Error::Validation( + ValidationError::ExpectedFunctionArgumentAmount { + expected: 1, + actual: 0, + position: SourcePosition { + start_byte: 81, + end_byte: 86, + start_row: 5, + start_column: 12, + end_row: 5, + end_column: 17 + } + } + )), + result ) }