Add enum type validation

This commit is contained in:
Jeff 2024-07-12 16:04:47 -04:00
parent 790438d1e3
commit c47d09fd1d
6 changed files with 118 additions and 43 deletions

View File

@ -73,6 +73,7 @@ impl AbstractNode for IfElse {
.map_err(|conflict| ValidationError::TypeCheck {
conflict,
actual_position: self.if_block.node.last_statement().position(),
expected_position: Some(self.if_expression.position()),
})?;
}

View File

@ -111,35 +111,52 @@ impl AbstractNode for ValueNode {
expression.define_and_validate(context, _manage_memory, scope)?;
}
if let Some(Type::Enum { name, variants, .. }) = context.get_type(&type_name.node)? {
let mut found = false;
for (identifier, content) in &variants {
if identifier == &variant.node {
found = true;
if let Some(Type::Enum {
variants,
type_parameters,
..
}) = context.get_type(&type_name.node)?
{
if let (Some(parameters), Some(arguments)) = (type_parameters, type_arguments) {
if arguments.len() != parameters.len() {
return Err(ValidationError::WrongTypeArgumentsCount {
expected: parameters.len(),
actual: arguments.len(),
});
}
if let Some(content) = content {
for r#type in content {
if let Type::Generic {
concrete_type: None,
..
} = r#type
{
return Err(ValidationError::FullTypeNotKnown {
identifier: name,
position: variant.position,
});
let mut arguments = arguments.iter();
let mut found = false;
for (identifier, content) in variants {
if identifier == variant.node {
found = true;
}
if let Some(content) = content {
for expected_type in &content {
if let Type::Generic {
concrete_type: None,
..
} = expected_type
{
arguments.next().ok_or_else(|| {
ValidationError::WrongTypeArgumentsCount {
expected: content.len(),
actual: arguments.len(),
}
})?;
}
}
}
}
}
if !found {
return Err(ValidationError::EnumVariantNotFound {
identifier: variant.node.clone(),
position: variant.position,
});
if !found {
return Err(ValidationError::EnumVariantNotFound {
identifier: variant.node.clone(),
position: variant.position,
});
}
}
} else {
return Err(ValidationError::EnumDefinitionNotFound {
@ -691,10 +708,26 @@ impl Display for ValueNode {
variant,
content,
} => {
write!(f, "{}::{}", type_name.node, variant.node)?;
write!(f, "{}", type_name.node)?;
if let Some(types) = type_arguments {
write!(f, "::<")?;
for (index, r#type) in types.iter().enumerate() {
if index > 0 {
write!(f, ", ")?;
}
write!(f, "{type}")?;
}
write!(f, ">")?;
}
write!(f, "::{}", variant.node)?;
if let Some(expression) = content {
write!(f, "{expression}")?;
write!(f, "({expression})")?;
}
Ok(())

View File

@ -181,6 +181,10 @@ pub enum ValidationError {
identifier: Identifier,
position: SourcePosition,
},
WrongTypeArgumentsCount {
expected: usize,
actual: usize,
},
}
impl From<PoisonError> for ValidationError {

View File

@ -8,9 +8,9 @@ pub mod standard_library;
pub mod value;
use std::{
collections::{hash_map, HashMap},
ops::Range,
sync::{Arc, RwLock},
vec,
};
pub use abstract_tree::Type;
@ -29,21 +29,24 @@ pub fn interpret(source_id: &str, source: &str) -> Result<Option<Value>, Interpr
interpreter.run(Arc::from(source_id), Arc::from(source))
}
type Source = (Arc<str>, Arc<str>);
/// Interpreter, lexer and parser for the Dust programming language.
///
/// You must provide the interpreter with an ID for each piece of code you pass to it. These are
/// used to identify the source of errors and to provide more detailed error messages.
pub struct Interpreter {
context: Context,
sources: Arc<RwLock<Vec<Source>>>,
sources: Arc<RwLock<HashMap<Arc<str>, Arc<str>>>>,
}
impl Interpreter {
pub fn new(context: Context) -> Self {
Interpreter {
context,
sources: Arc::new(RwLock::new(Vec::new())),
sources: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Lexes the source code and returns a list of tokens.
pub fn lex<'src>(
&self,
source_id: Arc<str>,
@ -51,14 +54,14 @@ impl Interpreter {
) -> Result<Vec<Token<'src>>, InterpreterError> {
let mut sources = self.sources.write().unwrap();
sources.clear();
sources.push((source_id.clone(), Arc::from(source)));
sources.insert(source_id.clone(), Arc::from(source));
lex(source)
.map(|tokens| tokens.into_iter().map(|(token, _)| token).collect())
.map_err(|errors| InterpreterError { source_id, errors })
}
/// Parses the source code and returns an abstract syntax tree.
pub fn parse(
&self,
source_id: Arc<str>,
@ -66,8 +69,7 @@ impl Interpreter {
) -> Result<AbstractTree, InterpreterError> {
let mut sources = self.sources.write().unwrap();
sources.clear();
sources.push((source_id.clone(), Arc::from(source)));
sources.insert(source_id.clone(), Arc::from(source));
parse(&lex(source).map_err(|errors| InterpreterError {
source_id: source_id.clone(),
@ -76,6 +78,7 @@ impl Interpreter {
.map_err(|errors| InterpreterError { source_id, errors })
}
/// Runs the source code and returns the result.
pub fn run(
&self,
source_id: Arc<str>,
@ -83,8 +86,7 @@ impl Interpreter {
) -> Result<Option<Value>, InterpreterError> {
let mut sources = self.sources.write().unwrap();
sources.clear();
sources.push((source_id.clone(), source.clone()));
sources.insert(source_id.clone(), source.clone());
let tokens = lex(source.as_ref()).map_err(|errors| InterpreterError {
source_id: source_id.clone(),
@ -101,12 +103,16 @@ impl Interpreter {
Ok(value_option)
}
pub fn sources(&self) -> vec::IntoIter<(Arc<str>, Arc<str>)> {
pub fn sources(&self) -> hash_map::IntoIter<Arc<str>, Arc<str>> {
self.sources.read().unwrap().clone().into_iter()
}
}
#[derive(Debug, PartialEq)]
/// An error that occurred during the interpretation of a piece of code.
///
/// Each error has a source ID that identifies the piece of code that caused the error, and a list
/// of errors that occurred during the interpretation of that code.
pub struct InterpreterError {
source_id: Arc<str>,
errors: Vec<DustError>,
@ -119,6 +125,7 @@ impl InterpreterError {
}
impl InterpreterError {
/// Converts the error into a list of user-friendly reports that can be printed to the console.
pub fn build_reports<'a>(self) -> Vec<Report<'a, (Arc<str>, Range<usize>)>> {
let token_color = Color::Yellow;
let type_color = Color::Green;
@ -281,23 +288,31 @@ impl InterpreterError {
builder = builder.with_help("Try specifying the type using turbofish.");
}
if let Some(position) = expected_position {
let actual_type_message = if let Some(position) = expected_position {
builder.add_label(
Label::new((self.source_id.clone(), position.0..position.1))
.with_message(format!(
"Type {} established here.",
expected.fg(type_color)
)),
);
format!("Got type {} here.", actual.fg(type_color))
} else {
format!(
"Got type {} but expected {}.",
actual.fg(type_color),
expected.fg(type_color)
)
}
};
builder.add_label(
Label::new((
self.source_id.clone(),
actual_position.0..actual_position.1,
))
.with_message(format!("Got type {} here.", actual.fg(type_color))),
);
.with_message(actual_type_message),
)
}
ValidationError::VariableNotFound {
identifier,
@ -430,6 +445,7 @@ impl InterpreterError {
.add_label(Label::new((self.source_id.clone(), 0..0)).with_message(reason)),
ValidationError::CannotUsePath(_) => todo!(),
ValidationError::Uninitialized => todo!(),
ValidationError::WrongTypeArgumentsCount { expected, actual } => todo!(),
}
}

View File

@ -63,6 +63,27 @@ fn type_invokation() {
);
}
#[test]
fn enum_instance_with_type_arguments() {
assert_eq!(
parse(&lex("Foo::<int, str>::Bar(42)").unwrap()).unwrap()[0],
Statement::Expression(Expression::Value(
ValueNode::EnumInstance {
type_name: Identifier::new("Foo").with_position((0, 3)),
type_arguments: Some(vec![
TypeConstructor::Raw(RawTypeConstructor::Integer.with_position((6, 9))),
TypeConstructor::Raw(RawTypeConstructor::String.with_position((11, 14)))
]),
variant: Identifier::new("Bar").with_position((17, 20)),
content: Some(Box::new(Expression::Value(
ValueNode::Integer(42).with_position((21, 23))
)))
}
.with_position((0, 24))
))
);
}
#[test]
fn enum_instance() {
assert_eq!(

View File

@ -8,10 +8,10 @@ use colored::Colorize;
use log::Level;
use std::{
collections::hash_map,
fs::read_to_string,
io::{stderr, Write},
sync::Arc,
vec,
};
use dust_lang::{context::Context, Interpreter};
@ -81,7 +81,7 @@ fn main() {
for report in error.build_reports() {
report
.write_for_stdout(
sources::<Arc<str>, Arc<str>, vec::IntoIter<(Arc<str>, Arc<str>)>>(
sources::<Arc<str>, Arc<str>, hash_map::IntoIter<Arc<str>, Arc<str>>>(
interpreter.sources(),
),
stderr(),