From c47d09fd1d4f20c6c3abf2c86069ffa988ac9c4a Mon Sep 17 00:00:00 2001 From: Jeff Date: Fri, 12 Jul 2024 16:04:47 -0400 Subject: [PATCH] Add enum type validation --- dust-lang/src/abstract_tree/if_else.rs | 1 + dust-lang/src/abstract_tree/value_node.rs | 83 ++++++++++++++++------- dust-lang/src/error.rs | 4 ++ dust-lang/src/lib.rs | 48 ++++++++----- dust-lang/src/parser/tests.rs | 21 ++++++ dust-shell/src/main.rs | 4 +- 6 files changed, 118 insertions(+), 43 deletions(-) diff --git a/dust-lang/src/abstract_tree/if_else.rs b/dust-lang/src/abstract_tree/if_else.rs index 79302cd..d6cbd6a 100644 --- a/dust-lang/src/abstract_tree/if_else.rs +++ b/dust-lang/src/abstract_tree/if_else.rs @@ -73,6 +73,7 @@ impl AbstractNode for IfElse { .map_err(|conflict| ValidationError::TypeCheck { conflict, actual_position: self.if_block.node.last_statement().position(), + expected_position: Some(self.if_expression.position()), })?; } diff --git a/dust-lang/src/abstract_tree/value_node.rs b/dust-lang/src/abstract_tree/value_node.rs index 6a1bde7..526161b 100644 --- a/dust-lang/src/abstract_tree/value_node.rs +++ b/dust-lang/src/abstract_tree/value_node.rs @@ -111,35 +111,52 @@ impl AbstractNode for ValueNode { expression.define_and_validate(context, _manage_memory, scope)?; } - if let Some(Type::Enum { name, variants, .. }) = context.get_type(&type_name.node)? { - let mut found = false; - - for (identifier, content) in &variants { - if identifier == &variant.node { - found = true; + if let Some(Type::Enum { + variants, + type_parameters, + .. + }) = context.get_type(&type_name.node)? + { + if let (Some(parameters), Some(arguments)) = (type_parameters, type_arguments) { + if arguments.len() != parameters.len() { + return Err(ValidationError::WrongTypeArgumentsCount { + expected: parameters.len(), + actual: arguments.len(), + }); } - if let Some(content) = content { - for r#type in content { - if let Type::Generic { - concrete_type: None, - .. - } = r#type - { - return Err(ValidationError::FullTypeNotKnown { - identifier: name, - position: variant.position, - }); + let mut arguments = arguments.iter(); + let mut found = false; + + for (identifier, content) in variants { + if identifier == variant.node { + found = true; + } + + if let Some(content) = content { + for expected_type in &content { + if let Type::Generic { + concrete_type: None, + .. + } = expected_type + { + arguments.next().ok_or_else(|| { + ValidationError::WrongTypeArgumentsCount { + expected: content.len(), + actual: arguments.len(), + } + })?; + } } } } - } - if !found { - return Err(ValidationError::EnumVariantNotFound { - identifier: variant.node.clone(), - position: variant.position, - }); + if !found { + return Err(ValidationError::EnumVariantNotFound { + identifier: variant.node.clone(), + position: variant.position, + }); + } } } else { return Err(ValidationError::EnumDefinitionNotFound { @@ -691,10 +708,26 @@ impl Display for ValueNode { variant, content, } => { - write!(f, "{}::{}", type_name.node, variant.node)?; + write!(f, "{}", type_name.node)?; + + if let Some(types) = type_arguments { + write!(f, "::<")?; + + for (index, r#type) in types.iter().enumerate() { + if index > 0 { + write!(f, ", ")?; + } + + write!(f, "{type}")?; + } + + write!(f, ">")?; + } + + write!(f, "::{}", variant.node)?; if let Some(expression) = content { - write!(f, "{expression}")?; + write!(f, "({expression})")?; } Ok(()) diff --git a/dust-lang/src/error.rs b/dust-lang/src/error.rs index 3e82e06..b47b0b1 100644 --- a/dust-lang/src/error.rs +++ b/dust-lang/src/error.rs @@ -181,6 +181,10 @@ pub enum ValidationError { identifier: Identifier, position: SourcePosition, }, + WrongTypeArgumentsCount { + expected: usize, + actual: usize, + }, } impl From for ValidationError { diff --git a/dust-lang/src/lib.rs b/dust-lang/src/lib.rs index c6b1318..12f8de2 100644 --- a/dust-lang/src/lib.rs +++ b/dust-lang/src/lib.rs @@ -8,9 +8,9 @@ pub mod standard_library; pub mod value; use std::{ + collections::{hash_map, HashMap}, ops::Range, sync::{Arc, RwLock}, - vec, }; pub use abstract_tree::Type; @@ -29,21 +29,24 @@ pub fn interpret(source_id: &str, source: &str) -> Result, Interpr interpreter.run(Arc::from(source_id), Arc::from(source)) } -type Source = (Arc, Arc); - +/// Interpreter, lexer and parser for the Dust programming language. +/// +/// You must provide the interpreter with an ID for each piece of code you pass to it. These are +/// used to identify the source of errors and to provide more detailed error messages. pub struct Interpreter { context: Context, - sources: Arc>>, + sources: Arc, Arc>>>, } impl Interpreter { pub fn new(context: Context) -> Self { Interpreter { context, - sources: Arc::new(RwLock::new(Vec::new())), + sources: Arc::new(RwLock::new(HashMap::new())), } } + /// Lexes the source code and returns a list of tokens. pub fn lex<'src>( &self, source_id: Arc, @@ -51,14 +54,14 @@ impl Interpreter { ) -> Result>, InterpreterError> { let mut sources = self.sources.write().unwrap(); - sources.clear(); - sources.push((source_id.clone(), Arc::from(source))); + sources.insert(source_id.clone(), Arc::from(source)); lex(source) .map(|tokens| tokens.into_iter().map(|(token, _)| token).collect()) .map_err(|errors| InterpreterError { source_id, errors }) } + /// Parses the source code and returns an abstract syntax tree. pub fn parse( &self, source_id: Arc, @@ -66,8 +69,7 @@ impl Interpreter { ) -> Result { let mut sources = self.sources.write().unwrap(); - sources.clear(); - sources.push((source_id.clone(), Arc::from(source))); + sources.insert(source_id.clone(), Arc::from(source)); parse(&lex(source).map_err(|errors| InterpreterError { source_id: source_id.clone(), @@ -76,6 +78,7 @@ impl Interpreter { .map_err(|errors| InterpreterError { source_id, errors }) } + /// Runs the source code and returns the result. pub fn run( &self, source_id: Arc, @@ -83,8 +86,7 @@ impl Interpreter { ) -> Result, InterpreterError> { let mut sources = self.sources.write().unwrap(); - sources.clear(); - sources.push((source_id.clone(), source.clone())); + sources.insert(source_id.clone(), source.clone()); let tokens = lex(source.as_ref()).map_err(|errors| InterpreterError { source_id: source_id.clone(), @@ -101,12 +103,16 @@ impl Interpreter { Ok(value_option) } - pub fn sources(&self) -> vec::IntoIter<(Arc, Arc)> { + pub fn sources(&self) -> hash_map::IntoIter, Arc> { self.sources.read().unwrap().clone().into_iter() } } #[derive(Debug, PartialEq)] +/// An error that occurred during the interpretation of a piece of code. +/// +/// Each error has a source ID that identifies the piece of code that caused the error, and a list +/// of errors that occurred during the interpretation of that code. pub struct InterpreterError { source_id: Arc, errors: Vec, @@ -119,6 +125,7 @@ impl InterpreterError { } impl InterpreterError { + /// Converts the error into a list of user-friendly reports that can be printed to the console. pub fn build_reports<'a>(self) -> Vec, Range)>> { let token_color = Color::Yellow; let type_color = Color::Green; @@ -281,23 +288,31 @@ impl InterpreterError { builder = builder.with_help("Try specifying the type using turbofish."); } - if let Some(position) = expected_position { + let actual_type_message = if let Some(position) = expected_position { builder.add_label( Label::new((self.source_id.clone(), position.0..position.1)) .with_message(format!( "Type {} established here.", expected.fg(type_color) )), + ); + + format!("Got type {} here.", actual.fg(type_color)) + } else { + format!( + "Got type {} but expected {}.", + actual.fg(type_color), + expected.fg(type_color) ) - } + }; builder.add_label( Label::new(( self.source_id.clone(), actual_position.0..actual_position.1, )) - .with_message(format!("Got type {} here.", actual.fg(type_color))), - ); + .with_message(actual_type_message), + ) } ValidationError::VariableNotFound { identifier, @@ -430,6 +445,7 @@ impl InterpreterError { .add_label(Label::new((self.source_id.clone(), 0..0)).with_message(reason)), ValidationError::CannotUsePath(_) => todo!(), ValidationError::Uninitialized => todo!(), + ValidationError::WrongTypeArgumentsCount { expected, actual } => todo!(), } } diff --git a/dust-lang/src/parser/tests.rs b/dust-lang/src/parser/tests.rs index 9267c62..9103cbb 100644 --- a/dust-lang/src/parser/tests.rs +++ b/dust-lang/src/parser/tests.rs @@ -63,6 +63,27 @@ fn type_invokation() { ); } +#[test] +fn enum_instance_with_type_arguments() { + assert_eq!( + parse(&lex("Foo::::Bar(42)").unwrap()).unwrap()[0], + Statement::Expression(Expression::Value( + ValueNode::EnumInstance { + type_name: Identifier::new("Foo").with_position((0, 3)), + type_arguments: Some(vec![ + TypeConstructor::Raw(RawTypeConstructor::Integer.with_position((6, 9))), + TypeConstructor::Raw(RawTypeConstructor::String.with_position((11, 14))) + ]), + variant: Identifier::new("Bar").with_position((17, 20)), + content: Some(Box::new(Expression::Value( + ValueNode::Integer(42).with_position((21, 23)) + ))) + } + .with_position((0, 24)) + )) + ); +} + #[test] fn enum_instance() { assert_eq!( diff --git a/dust-shell/src/main.rs b/dust-shell/src/main.rs index 7b27866..2e1a174 100644 --- a/dust-shell/src/main.rs +++ b/dust-shell/src/main.rs @@ -8,10 +8,10 @@ use colored::Colorize; use log::Level; use std::{ + collections::hash_map, fs::read_to_string, io::{stderr, Write}, sync::Arc, - vec, }; use dust_lang::{context::Context, Interpreter}; @@ -81,7 +81,7 @@ fn main() { for report in error.build_reports() { report .write_for_stdout( - sources::, Arc, vec::IntoIter<(Arc, Arc)>>( + sources::, Arc, hash_map::IntoIter, Arc>>( interpreter.sources(), ), stderr(),