Fix type checking bug

This commit is contained in:
Jeff 2024-02-13 08:10:34 -05:00
parent 3c72e4f988
commit 3f4c4ff464
5 changed files with 82 additions and 67 deletions

View File

@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
built_in_functions::Callable, built_in_functions::Callable,
error::{RuntimeError, SyntaxError, ValidationError}, error::{RuntimeError, SyntaxError, ValidationError},
AbstractTree, Context, Expression, Format, FunctionExpression, SourcePosition, SyntaxNode, AbstractTree, Context, Expression, Format, Function, FunctionExpression, SourcePosition,
Type, Value, SyntaxNode, Type, Value,
}; };
/// A function being invoked and the arguments it is being passed. /// A function being invoked and the arguments it is being passed.
@ -94,17 +94,16 @@ impl AbstractTree for FunctionCall {
fn validate(&self, _source: &str, context: &Context) -> Result<(), ValidationError> { fn validate(&self, _source: &str, context: &Context) -> Result<(), ValidationError> {
let function_expression_type = self.function_expression.expected_type(context)?; let function_expression_type = self.function_expression.expected_type(context)?;
let parameter_types = match function_expression_type { let parameter_types = if let Type::Function {
Type::Function {
parameter_types, .. parameter_types, ..
} => parameter_types, } = function_expression_type
Type::Any => return Ok(()), {
_ => { parameter_types
} else {
return Err(ValidationError::TypeCheckExpectedFunction { return Err(ValidationError::TypeCheckExpectedFunction {
actual: function_expression_type, actual: function_expression_type,
position: self.syntax_position, position: self.syntax_position,
}); });
}
}; };
if self.arguments.len() != parameter_types.len() { if self.arguments.len() != parameter_types.len() {
@ -156,7 +155,7 @@ impl AbstractTree for FunctionCall {
let function = value.as_function()?; let function = value.as_function()?;
match function { match function {
crate::Function::BuiltIn(built_in_function) => { Function::BuiltIn(built_in_function) => {
let mut arguments = Vec::with_capacity(self.arguments.len()); let mut arguments = Vec::with_capacity(self.arguments.len());
for expression in &self.arguments { for expression in &self.arguments {
@ -167,7 +166,7 @@ impl AbstractTree for FunctionCall {
built_in_function.call(&arguments, source, &self.context) built_in_function.call(&arguments, source, &self.context)
} }
crate::Function::ContextDefined(function_node) => { Function::ContextDefined(function_node) => {
let parameter_expression_pairs = let parameter_expression_pairs =
function_node.parameters().iter().zip(self.arguments.iter()); function_node.parameters().iter().zip(self.arguments.iter());

View File

@ -109,13 +109,19 @@ impl AbstractTree for FunctionNode {
Ok(self.r#type().clone()) Ok(self.r#type().clone())
} }
fn validate(&self, source: &str, context: &Context) -> Result<(), ValidationError> { fn validate(&self, source: &str, _context: &Context) -> Result<(), ValidationError> {
if let Type::Function { if let Type::Function {
parameter_types: _, parameter_types,
return_type, return_type,
} = &self.r#type } = &self.r#type
{ {
let actual = self.body.expected_type(context)?; let validation_context = Context::new();
for (parameter, r#type) in self.parameters.iter().zip(parameter_types.iter()) {
validation_context.set_type(parameter.inner().clone(), r#type.clone())?;
}
let actual = self.body.expected_type(&validation_context)?;
if !return_type.accepts(&actual) { if !return_type.accepts(&actual) {
return Err(ValidationError::TypeCheck { return Err(ValidationError::TypeCheck {
@ -125,7 +131,7 @@ impl AbstractTree for FunctionNode {
}); });
} }
self.body.validate(source, context)?; self.body.validate(source, &validation_context)?;
Ok(()) Ok(())
} else { } else {

View File

@ -53,13 +53,11 @@ impl Context {
pub fn get_value(&self, key: &str) -> Result<Option<Value>, RwLockError> { pub fn get_value(&self, key: &str) -> Result<Option<Value>, RwLockError> {
if let Some(value_data) = self.inner.read()?.get(key) { if let Some(value_data) = self.inner.read()?.get(key) {
if let ValueData::Value { inner, .. } = value_data { if let ValueData::Value { inner, .. } = value_data {
Ok(Some(inner.clone())) return Ok(Some(inner.clone()));
} else {
Ok(None)
} }
} else {
Ok(None)
} }
Ok(None)
} }
pub fn get_type(&self, key: &str) -> Result<Option<Type>, RwLockError> { pub fn get_type(&self, key: &str) -> Result<Option<Type>, RwLockError> {

View File

@ -116,59 +116,59 @@ fn main() {
} }
} }
struct DustHighlighter { // struct DustHighlighter {
context: Context, // context: Context,
} // }
impl DustHighlighter { // impl DustHighlighter {
fn new(context: Context) -> Self { // fn new(context: Context) -> Self {
Self { context } // Self { context }
} // }
} // }
const HIGHLIGHT_TERMINATORS: [char; 8] = [' ', ':', '(', ')', '{', '}', '[', ']']; // const HIGHLIGHT_TERMINATORS: [char; 8] = [' ', ':', '(', ')', '{', '}', '[', ']'];
impl Highlighter for DustHighlighter { // impl Highlighter for DustHighlighter {
fn highlight(&self, line: &str, _cursor: usize) -> reedline::StyledText { // fn highlight(&self, line: &str, _cursor: usize) -> reedline::StyledText {
let mut styled = StyledText::new(); // let mut styled = StyledText::new();
for word in line.split_inclusive(&HIGHLIGHT_TERMINATORS) { // for word in line.split_inclusive(&HIGHLIGHT_TERMINATORS) {
let mut word_is_highlighted = false; // let mut word_is_highlighted = false;
for key in self.context.inner().unwrap().keys() { // for key in self.context.inner().unwrap().keys() {
if key == &word { // if key == &word {
styled.push((Style::new().bold(), word.to_string())); // styled.push((Style::new().bold(), word.to_string()));
} // }
word_is_highlighted = true; // word_is_highlighted = true;
} // }
for built_in_value in built_in_values() { // for built_in_value in built_in_values() {
if built_in_value.name() == word { // if built_in_value.name() == word {
styled.push((Style::new().bold(), word.to_string())); // styled.push((Style::new().bold(), word.to_string()));
} // }
word_is_highlighted = true; // word_is_highlighted = true;
} // }
if word_is_highlighted { // if word_is_highlighted {
let final_char = word.chars().last().unwrap(); // let final_char = word.chars().last().unwrap();
if HIGHLIGHT_TERMINATORS.contains(&final_char) { // if HIGHLIGHT_TERMINATORS.contains(&final_char) {
let mut terminator_style = Style::new(); // let mut terminator_style = Style::new();
terminator_style.foreground = Some(Color::Cyan); // terminator_style.foreground = Some(Color::Cyan);
styled.push((terminator_style, final_char.to_string())); // styled.push((terminator_style, final_char.to_string()));
} // }
} else { // } else {
styled.push((Style::new(), word.to_string())); // styled.push((Style::new(), word.to_string()));
} // }
} // }
styled // styled
} // }
} // }
struct StarshipPrompt { struct StarshipPrompt {
left: String, left: String,
@ -367,7 +367,6 @@ fn run_shell(context: Context) -> Result<(), Error> {
let mut line_editor = Reedline::create() let mut line_editor = Reedline::create()
.with_edit_mode(edit_mode) .with_edit_mode(edit_mode)
.with_history(history) .with_history(history)
.with_highlighter(Box::new(DustHighlighter::new(context.clone())))
.with_hinter(hinter) .with_hinter(hinter)
.use_kitty_keyboard_enhancement(true) .use_kitty_keyboard_enhancement(true)
.with_completer(Box::new(completer)) .with_completer(Box::new(completer))

View File

@ -32,8 +32,21 @@ fn argument_count_check() {
let result = interpret(source); let result = interpret(source);
assert_eq!( assert_eq!(
"Expected 1 arguments, but got 0. Occured at (5, 12) to (5, 17). Source: foo()", Err(Error::Validation(
result.unwrap_err().to_string() ValidationError::ExpectedFunctionArgumentAmount {
expected: 1,
actual: 0,
position: SourcePosition {
start_byte: 81,
end_byte: 86,
start_row: 5,
start_column: 12,
end_row: 5,
end_column: 17
}
}
)),
result
) )
} }