Add enum type validation

This commit is contained in:
Jeff 2024-07-12 16:04:47 -04:00
parent 790438d1e3
commit c47d09fd1d
6 changed files with 118 additions and 43 deletions

View File

@ -73,6 +73,7 @@ impl AbstractNode for IfElse {
.map_err(|conflict| ValidationError::TypeCheck { .map_err(|conflict| ValidationError::TypeCheck {
conflict, conflict,
actual_position: self.if_block.node.last_statement().position(), actual_position: self.if_block.node.last_statement().position(),
expected_position: Some(self.if_expression.position()), expected_position: Some(self.if_expression.position()),
})?; })?;
} }

View File

@ -111,35 +111,52 @@ impl AbstractNode for ValueNode {
expression.define_and_validate(context, _manage_memory, scope)?; expression.define_and_validate(context, _manage_memory, scope)?;
} }
if let Some(Type::Enum { name, variants, .. }) = context.get_type(&type_name.node)? { if let Some(Type::Enum {
let mut found = false; variants,
type_parameters,
for (identifier, content) in &variants { ..
if identifier == &variant.node { }) = context.get_type(&type_name.node)?
found = true; {
if let (Some(parameters), Some(arguments)) = (type_parameters, type_arguments) {
if arguments.len() != parameters.len() {
return Err(ValidationError::WrongTypeArgumentsCount {
expected: parameters.len(),
actual: arguments.len(),
});
} }
if let Some(content) = content { let mut arguments = arguments.iter();
for r#type in content { let mut found = false;
if let Type::Generic {
concrete_type: None, for (identifier, content) in variants {
.. if identifier == variant.node {
} = r#type found = true;
{ }
return Err(ValidationError::FullTypeNotKnown {
identifier: name, if let Some(content) = content {
position: variant.position, for expected_type in &content {
}); if let Type::Generic {
concrete_type: None,
..
} = expected_type
{
arguments.next().ok_or_else(|| {
ValidationError::WrongTypeArgumentsCount {
expected: content.len(),
actual: arguments.len(),
}
})?;
}
} }
} }
} }
}
if !found { if !found {
return Err(ValidationError::EnumVariantNotFound { return Err(ValidationError::EnumVariantNotFound {
identifier: variant.node.clone(), identifier: variant.node.clone(),
position: variant.position, position: variant.position,
}); });
}
} }
} else { } else {
return Err(ValidationError::EnumDefinitionNotFound { return Err(ValidationError::EnumDefinitionNotFound {
@ -691,10 +708,26 @@ impl Display for ValueNode {
variant, variant,
content, content,
} => { } => {
write!(f, "{}::{}", type_name.node, variant.node)?; write!(f, "{}", type_name.node)?;
if let Some(types) = type_arguments {
write!(f, "::<")?;
for (index, r#type) in types.iter().enumerate() {
if index > 0 {
write!(f, ", ")?;
}
write!(f, "{type}")?;
}
write!(f, ">")?;
}
write!(f, "::{}", variant.node)?;
if let Some(expression) = content { if let Some(expression) = content {
write!(f, "{expression}")?; write!(f, "({expression})")?;
} }
Ok(()) Ok(())

View File

@ -181,6 +181,10 @@ pub enum ValidationError {
identifier: Identifier, identifier: Identifier,
position: SourcePosition, position: SourcePosition,
}, },
WrongTypeArgumentsCount {
expected: usize,
actual: usize,
},
} }
impl From<PoisonError> for ValidationError { impl From<PoisonError> for ValidationError {

View File

@ -8,9 +8,9 @@ pub mod standard_library;
pub mod value; pub mod value;
use std::{ use std::{
collections::{hash_map, HashMap},
ops::Range, ops::Range,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
vec,
}; };
pub use abstract_tree::Type; pub use abstract_tree::Type;
@ -29,21 +29,24 @@ pub fn interpret(source_id: &str, source: &str) -> Result<Option<Value>, Interpr
interpreter.run(Arc::from(source_id), Arc::from(source)) interpreter.run(Arc::from(source_id), Arc::from(source))
} }
type Source = (Arc<str>, Arc<str>); /// Interpreter, lexer and parser for the Dust programming language.
///
/// You must provide the interpreter with an ID for each piece of code you pass to it. These are
/// used to identify the source of errors and to provide more detailed error messages.
pub struct Interpreter { pub struct Interpreter {
context: Context, context: Context,
sources: Arc<RwLock<Vec<Source>>>, sources: Arc<RwLock<HashMap<Arc<str>, Arc<str>>>>,
} }
impl Interpreter { impl Interpreter {
pub fn new(context: Context) -> Self { pub fn new(context: Context) -> Self {
Interpreter { Interpreter {
context, context,
sources: Arc::new(RwLock::new(Vec::new())), sources: Arc::new(RwLock::new(HashMap::new())),
} }
} }
/// Lexes the source code and returns a list of tokens.
pub fn lex<'src>( pub fn lex<'src>(
&self, &self,
source_id: Arc<str>, source_id: Arc<str>,
@ -51,14 +54,14 @@ impl Interpreter {
) -> Result<Vec<Token<'src>>, InterpreterError> { ) -> Result<Vec<Token<'src>>, InterpreterError> {
let mut sources = self.sources.write().unwrap(); let mut sources = self.sources.write().unwrap();
sources.clear(); sources.insert(source_id.clone(), Arc::from(source));
sources.push((source_id.clone(), Arc::from(source)));
lex(source) lex(source)
.map(|tokens| tokens.into_iter().map(|(token, _)| token).collect()) .map(|tokens| tokens.into_iter().map(|(token, _)| token).collect())
.map_err(|errors| InterpreterError { source_id, errors }) .map_err(|errors| InterpreterError { source_id, errors })
} }
/// Parses the source code and returns an abstract syntax tree.
pub fn parse( pub fn parse(
&self, &self,
source_id: Arc<str>, source_id: Arc<str>,
@ -66,8 +69,7 @@ impl Interpreter {
) -> Result<AbstractTree, InterpreterError> { ) -> Result<AbstractTree, InterpreterError> {
let mut sources = self.sources.write().unwrap(); let mut sources = self.sources.write().unwrap();
sources.clear(); sources.insert(source_id.clone(), Arc::from(source));
sources.push((source_id.clone(), Arc::from(source)));
parse(&lex(source).map_err(|errors| InterpreterError { parse(&lex(source).map_err(|errors| InterpreterError {
source_id: source_id.clone(), source_id: source_id.clone(),
@ -76,6 +78,7 @@ impl Interpreter {
.map_err(|errors| InterpreterError { source_id, errors }) .map_err(|errors| InterpreterError { source_id, errors })
} }
/// Runs the source code and returns the result.
pub fn run( pub fn run(
&self, &self,
source_id: Arc<str>, source_id: Arc<str>,
@ -83,8 +86,7 @@ impl Interpreter {
) -> Result<Option<Value>, InterpreterError> { ) -> Result<Option<Value>, InterpreterError> {
let mut sources = self.sources.write().unwrap(); let mut sources = self.sources.write().unwrap();
sources.clear(); sources.insert(source_id.clone(), source.clone());
sources.push((source_id.clone(), source.clone()));
let tokens = lex(source.as_ref()).map_err(|errors| InterpreterError { let tokens = lex(source.as_ref()).map_err(|errors| InterpreterError {
source_id: source_id.clone(), source_id: source_id.clone(),
@ -101,12 +103,16 @@ impl Interpreter {
Ok(value_option) Ok(value_option)
} }
pub fn sources(&self) -> vec::IntoIter<(Arc<str>, Arc<str>)> { pub fn sources(&self) -> hash_map::IntoIter<Arc<str>, Arc<str>> {
self.sources.read().unwrap().clone().into_iter() self.sources.read().unwrap().clone().into_iter()
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
/// An error that occurred during the interpretation of a piece of code.
///
/// Each error has a source ID that identifies the piece of code that caused the error, and a list
/// of errors that occurred during the interpretation of that code.
pub struct InterpreterError { pub struct InterpreterError {
source_id: Arc<str>, source_id: Arc<str>,
errors: Vec<DustError>, errors: Vec<DustError>,
@ -119,6 +125,7 @@ impl InterpreterError {
} }
impl InterpreterError { impl InterpreterError {
/// Converts the error into a list of user-friendly reports that can be printed to the console.
pub fn build_reports<'a>(self) -> Vec<Report<'a, (Arc<str>, Range<usize>)>> { pub fn build_reports<'a>(self) -> Vec<Report<'a, (Arc<str>, Range<usize>)>> {
let token_color = Color::Yellow; let token_color = Color::Yellow;
let type_color = Color::Green; let type_color = Color::Green;
@ -281,23 +288,31 @@ impl InterpreterError {
builder = builder.with_help("Try specifying the type using turbofish."); builder = builder.with_help("Try specifying the type using turbofish.");
} }
if let Some(position) = expected_position { let actual_type_message = if let Some(position) = expected_position {
builder.add_label( builder.add_label(
Label::new((self.source_id.clone(), position.0..position.1)) Label::new((self.source_id.clone(), position.0..position.1))
.with_message(format!( .with_message(format!(
"Type {} established here.", "Type {} established here.",
expected.fg(type_color) expected.fg(type_color)
)), )),
);
format!("Got type {} here.", actual.fg(type_color))
} else {
format!(
"Got type {} but expected {}.",
actual.fg(type_color),
expected.fg(type_color)
) )
} };
builder.add_label( builder.add_label(
Label::new(( Label::new((
self.source_id.clone(), self.source_id.clone(),
actual_position.0..actual_position.1, actual_position.0..actual_position.1,
)) ))
.with_message(format!("Got type {} here.", actual.fg(type_color))), .with_message(actual_type_message),
); )
} }
ValidationError::VariableNotFound { ValidationError::VariableNotFound {
identifier, identifier,
@ -430,6 +445,7 @@ impl InterpreterError {
.add_label(Label::new((self.source_id.clone(), 0..0)).with_message(reason)), .add_label(Label::new((self.source_id.clone(), 0..0)).with_message(reason)),
ValidationError::CannotUsePath(_) => todo!(), ValidationError::CannotUsePath(_) => todo!(),
ValidationError::Uninitialized => todo!(), ValidationError::Uninitialized => todo!(),
ValidationError::WrongTypeArgumentsCount { expected, actual } => todo!(),
} }
} }

View File

@ -63,6 +63,27 @@ fn type_invokation() {
); );
} }
#[test]
fn enum_instance_with_type_arguments() {
assert_eq!(
parse(&lex("Foo::<int, str>::Bar(42)").unwrap()).unwrap()[0],
Statement::Expression(Expression::Value(
ValueNode::EnumInstance {
type_name: Identifier::new("Foo").with_position((0, 3)),
type_arguments: Some(vec![
TypeConstructor::Raw(RawTypeConstructor::Integer.with_position((6, 9))),
TypeConstructor::Raw(RawTypeConstructor::String.with_position((11, 14)))
]),
variant: Identifier::new("Bar").with_position((17, 20)),
content: Some(Box::new(Expression::Value(
ValueNode::Integer(42).with_position((21, 23))
)))
}
.with_position((0, 24))
))
);
}
#[test] #[test]
fn enum_instance() { fn enum_instance() {
assert_eq!( assert_eq!(

View File

@ -8,10 +8,10 @@ use colored::Colorize;
use log::Level; use log::Level;
use std::{ use std::{
collections::hash_map,
fs::read_to_string, fs::read_to_string,
io::{stderr, Write}, io::{stderr, Write},
sync::Arc, sync::Arc,
vec,
}; };
use dust_lang::{context::Context, Interpreter}; use dust_lang::{context::Context, Interpreter};
@ -81,7 +81,7 @@ fn main() {
for report in error.build_reports() { for report in error.build_reports() {
report report
.write_for_stdout( .write_for_stdout(
sources::<Arc<str>, Arc<str>, vec::IntoIter<(Arc<str>, Arc<str>)>>( sources::<Arc<str>, Arc<str>, hash_map::IntoIter<Arc<str>, Arc<str>>>(
interpreter.sources(), interpreter.sources(),
), ),
stderr(), stderr(),