diff --git a/src/abstract_tree/type.rs b/src/abstract_tree/type.rs index c1669e2..f0adb61 100644 --- a/src/abstract_tree/type.rs +++ b/src/abstract_tree/type.rs @@ -48,15 +48,10 @@ impl Type { | (Type::Map, Type::Map) | (Type::None, Type::None) | (Type::Range, Type::Range) - | (Type::String, Type::String) => Ok(()), + | (Type::String, Type::String) => return Ok(()), (Type::ListOf(left), Type::ListOf(right)) => { if let Ok(()) = left.check(right) { - Ok(()) - } else { - Err(TypeConflict { - actual: left.as_ref().clone(), - expected: right.as_ref().clone(), - }) + return Ok(()); } } (Type::ListOf(list_of), Type::ListExact(list_exact)) => { @@ -64,27 +59,50 @@ impl Type { list_of.check(r#type)?; } - Ok(()) + return Ok(()); } (Type::ListExact(list_exact), Type::ListOf(list_of)) => { for r#type in list_exact { r#type.check(&list_of)?; } - Ok(()) + return Ok(()); } (Type::ListExact(left), Type::ListExact(right)) => { for (left, right) in left.iter().zip(right.iter()) { left.check(right)?; } - Ok(()) + return Ok(()); } - _ => Err(TypeConflict { - actual: other.clone(), - expected: self.clone(), - }), + (Type::Named(left), Type::Named(right)) => { + if left == right { + return Ok(()); + } + } + ( + Type::Named(named), + Type::Structure { + name: struct_name, .. + }, + ) + | ( + Type::Structure { + name: struct_name, .. + }, + Type::Named(named), + ) => { + if named == struct_name { + return Ok(()); + } + } + _ => {} } + + Err(TypeConflict { + actual: other.clone(), + expected: self.clone(), + }) } } diff --git a/src/abstract_tree/value_node.rs b/src/abstract_tree/value_node.rs index 9efa603..fc93f3f 100644 --- a/src/abstract_tree/value_node.rs +++ b/src/abstract_tree/value_node.rs @@ -1,5 +1,7 @@ use std::{cmp::Ordering, collections::BTreeMap, ops::Range}; +use chumsky::container::Container; + use crate::{ context::Context, error::{RuntimeError, ValidationError}, @@ -63,7 +65,29 @@ impl AbstractNode for ValueNode { .collect(), return_type: Box::new(return_type.node.clone()), }, - ValueNode::Structure { name, fields } => todo!(), + ValueNode::Structure { + name, + fields: expressions, + } => { + let mut types = Vec::with_capacity(expressions.len()); + + for (identifier, expression) in expressions { + let r#type = expression.node.expected_type(_context)?; + + types.push(( + identifier.clone(), + WithPosition { + node: r#type, + position: expression.position, + }, + )); + } + + Type::Structure { + name: name.clone(), + fields: types, + } + } }; Ok(r#type) @@ -114,12 +138,36 @@ impl AbstractNode for ValueNode { })?; } - if let ValueNode::Structure { name, fields } = self { - let r#type = if let Some(r#type) = context.get_type(name)? { - r#type + if let ValueNode::Structure { + name, + fields: expressions, + } = self + { + let types = if let Some(r#type) = context.get_type(name)? { + if let Type::Structure { + name, + fields: types, + } = r#type + { + types + } else { + todo!() + } } else { return Err(ValidationError::TypeNotFound(name.clone())); }; + + for ((_, expression), (_, expected_type)) in expressions.iter().zip(types.iter()) { + let actual_type = expression.node.expected_type(context)?; + + expected_type.node.check(&actual_type).map_err(|conflict| { + ValidationError::TypeCheck { + conflict, + actual_position: expression.position, + expected_position: expected_type.position, + } + })? + } } Ok(()) diff --git a/src/context.rs b/src/context.rs index f41abf5..df9ccb3 100644 --- a/src/context.rs +++ b/src/context.rs @@ -73,7 +73,13 @@ impl Context { pub fn get_type(&self, identifier: &Identifier) -> Result, ValidationError> { if let Some(value_data) = self.inner.read()?.get(identifier) { let r#type = match value_data { - ValueData::Type(r#type) => r#type.clone(), + ValueData::Type(r#type) => { + if let Type::Named(name) = r#type { + return self.get_type(name); + } + + r#type.clone() + } ValueData::Value(value) => value.r#type(self)?, }; diff --git a/src/error.rs b/src/error.rs index 6c102bb..1dcabe1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -33,7 +33,7 @@ impl Error { let (mut builder, validation_error, error_position) = match self { Error::Parse { expected, span } => { let message = if expected.is_empty() { - "Invalid token.".to_string() + "Invalid character.".to_string() } else { format!("Expected {expected}.") }; @@ -77,7 +77,7 @@ impl Error { } Error::Runtime { error, position } => ( Report::build( - ReportKind::Custom("Dust Error", Color::White), + ReportKind::Custom("Runtime Error", Color::White), "input", position.1, ), @@ -90,10 +90,11 @@ impl Error { ), Error::Validation { error, position } => ( Report::build( - ReportKind::Custom("Dust Error", Color::White), + ReportKind::Custom("Validation Error", Color::White), "input", position.1, - ), + ) + .with_note("This error was detected by the interpreter before running the code."), Some(error), position, ), @@ -130,6 +131,8 @@ impl Error { } => { let TypeConflict { actual, expected } = conflict; + builder = builder.with_message("A type conflict was found."); + builder.add_labels([ Label::new(("input", expected_postion.0..expected_postion.1)).with_message( format!("Type {} established here.", expected.fg(type_color)), diff --git a/src/parser.rs b/src/parser.rs index a00b094..234534f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -106,7 +106,7 @@ pub fn parser<'src>() -> impl Parser< just(Token::Keyword("range")).to(Type::Range), just(Token::Keyword("str")).to(Type::String), just(Token::Keyword("list")).to(Type::List), - identifier.clone().map(|name| Type::Named(name)), + identifier.clone().map(|identifier| Type::Named(identifier)), )) }) .map_with(|r#type, state| r#type.with_position(state.span())); diff --git a/tests/structs.rs b/tests/structs.rs index db98905..83832d9 100644 --- a/tests/structs.rs +++ b/tests/structs.rs @@ -1,4 +1,8 @@ -use dust_lang::{abstract_tree::Identifier, error::Error, *}; +use dust_lang::{ + abstract_tree::{Identifier, Type}, + error::{Error, TypeConflict, ValidationError}, + *, +}; #[test] fn simple_structure() { @@ -26,6 +30,34 @@ fn simple_structure() { ) } +#[test] +fn field_type_error() { + assert_eq!( + interpret( + " + struct Foo { + bar : int, + } + + Foo { + bar = 'hiya', + } + " + ), + Err(vec![Error::Validation { + error: ValidationError::TypeCheck { + conflict: TypeConflict { + actual: Type::String, + expected: Type::Integer + }, + actual_position: (128, 134).into(), + expected_position: (56, 59).into() + }, + position: (96, 166).into() + }]) + ) +} + #[test] fn nested_structure() { assert_eq!(