Pass all unit tests

This commit is contained in:
Jeff 2024-08-23 16:33:38 -04:00
parent 8b14d74eba
commit ab53df56bc
8 changed files with 316 additions and 52 deletions

View File

@ -18,7 +18,7 @@ use crate::{
StructExpression, TupleAccessExpression, StructExpression, TupleAccessExpression,
}, },
core_library, parse, Context, ContextError, DustError, Expression, Identifier, RangeableType, core_library, parse, Context, ContextError, DustError, Expression, Identifier, RangeableType,
StructType, Type, StructType, Type, TypeConflict, TypeEvaluation,
}; };
/// Analyzes the abstract syntax tree for errors. /// Analyzes the abstract syntax tree for errors.
@ -97,6 +97,24 @@ impl<'a> Analyzer<'a> {
return; return;
} }
Ok(TypeEvaluation::Constructor(StructType::Unit { name })) => {
let set_type = self.context.set_variable_type(
identifier.inner.clone(),
Type::Struct(StructType::Unit { name }),
statement.position(),
);
if let Err(context_error) = set_type {
self.errors.push(AnalysisError::ContextError {
error: context_error,
position: identifier.position,
});
}
self.analyze_expression(value, statement.position());
return;
}
Ok(evaluation) => evaluation.r#type(), Ok(evaluation) => evaluation.r#type(),
}; };
@ -104,7 +122,7 @@ impl<'a> Analyzer<'a> {
let set_type = self.context.set_variable_type( let set_type = self.context.set_variable_type(
identifier.inner.clone(), identifier.inner.clone(),
r#type.clone(), r#type.clone(),
identifier.position, statement.position(),
); );
if let Err(context_error) = set_type { if let Err(context_error) = set_type {
@ -188,6 +206,110 @@ impl<'a> Analyzer<'a> {
self.analyze_expression(invoker, statement_position); self.analyze_expression(invoker, statement_position);
let invoker_evaluation = match invoker.type_evaluation(&self.context) {
Ok(evaluation) => evaluation,
Err(ast_error) => {
self.errors.push(AnalysisError::AstError(ast_error));
return;
}
};
if let TypeEvaluation::Constructor(StructType::Tuple { fields, .. }) =
invoker_evaluation
{
for (expected_type, argument) in fields.iter().zip(arguments.iter()) {
let actual_type = match argument.type_evaluation(&self.context) {
Ok(evaluation) => evaluation.r#type(),
Err(ast_error) => {
self.errors.push(AnalysisError::AstError(ast_error));
return;
}
};
if let Some(r#type) = actual_type {
let check = expected_type.check(&r#type);
if let Err(type_conflict) = check {
self.errors.push(AnalysisError::TypeConflict {
actual_expression: argument.clone(),
type_conflict,
});
}
}
}
return;
}
let invoked_type = if let Some(r#type) = invoker_evaluation.r#type() {
r#type
} else {
self.errors
.push(AnalysisError::ExpectedValueFromExpression {
expression: invoker.clone(),
});
return;
};
let function_type = if let Type::Function(function_type) = invoked_type {
function_type
} else {
self.errors.push(AnalysisError::ExpectedFunction {
actual: invoked_type,
actual_expression: invoker.clone(),
});
return;
};
let value_parameters =
if let Some(value_parameters) = &function_type.value_parameters {
value_parameters
} else {
if !arguments.is_empty() {
self.errors.push(AnalysisError::ExpectedValueArgumentCount {
expected: 0,
actual: arguments.len(),
position: invoker.position(),
});
}
return;
};
for ((_, expected_type), argument) in value_parameters.iter().zip(arguments) {
self.analyze_expression(argument, statement_position);
let argument_evaluation = match argument.type_evaluation(&self.context) {
Ok(evaluation) => evaluation,
Err(error) => {
self.errors.push(AnalysisError::AstError(error));
continue;
}
};
let actual_type = if let Some(r#type) = argument_evaluation.r#type() {
r#type
} else {
self.errors
.push(AnalysisError::ExpectedValueFromExpression {
expression: argument.clone(),
});
continue;
};
if let Err(type_conflict) = expected_type.check(&actual_type) {
self.errors.push(AnalysisError::TypeConflict {
type_conflict,
actual_expression: argument.clone(),
});
}
}
for argument in arguments { for argument in arguments {
self.analyze_expression(argument, statement_position); self.analyze_expression(argument, statement_position);
} }
@ -766,6 +888,10 @@ pub enum AnalysisError {
error: ContextError, error: ContextError,
position: Span, position: Span,
}, },
ExpectedFunction {
actual: Type,
actual_expression: Expression,
},
ExpectedType { ExpectedType {
expected: Type, expected: Type,
actual: Type, actual: Type,
@ -812,8 +938,11 @@ pub enum AnalysisError {
}, },
TypeConflict { TypeConflict {
actual_expression: Expression, actual_expression: Expression,
actual_type: Type, type_conflict: TypeConflict,
expected: Type, },
UnexpectedArguments {
expected: Option<Vec<Type>>,
actual: Vec<Expression>,
}, },
UndefinedFieldIdentifier { UndefinedFieldIdentifier {
identifier: Node<Identifier>, identifier: Node<Identifier>,
@ -844,6 +973,9 @@ impl AnalysisError {
match self { match self {
AnalysisError::AstError(ast_error) => ast_error.position(), AnalysisError::AstError(ast_error) => ast_error.position(),
AnalysisError::ContextError { position, .. } => *position, AnalysisError::ContextError { position, .. } => *position,
AnalysisError::ExpectedFunction {
actual_expression, ..
} => actual_expression.position(),
AnalysisError::ExpectedType { AnalysisError::ExpectedType {
actual_expression, .. actual_expression, ..
} => actual_expression.position(), } => actual_expression.position(),
@ -864,6 +996,7 @@ impl AnalysisError {
AnalysisError::UndefinedFieldIdentifier { identifier, .. } => identifier.position, AnalysisError::UndefinedFieldIdentifier { identifier, .. } => identifier.position,
AnalysisError::UndefinedType { identifier } => identifier.position, AnalysisError::UndefinedType { identifier } => identifier.position,
AnalysisError::UndefinedVariable { identifier } => identifier.position, AnalysisError::UndefinedVariable { identifier } => identifier.position,
AnalysisError::UnexpectedArguments { actual, .. } => actual[0].position(),
AnalysisError::UnexpectedIdentifier { identifier } => identifier.position, AnalysisError::UnexpectedIdentifier { identifier } => identifier.position,
AnalysisError::UnexectedString { actual } => actual.position(), AnalysisError::UnexectedString { actual } => actual.position(),
} }
@ -877,6 +1010,16 @@ impl Display for AnalysisError {
match self { match self {
AnalysisError::AstError(ast_error) => write!(f, "{}", ast_error), AnalysisError::AstError(ast_error) => write!(f, "{}", ast_error),
AnalysisError::ContextError { error, .. } => write!(f, "{}", error), AnalysisError::ContextError { error, .. } => write!(f, "{}", error),
AnalysisError::ExpectedFunction {
actual,
actual_expression,
} => {
write!(
f,
"Expected function, found {} in {}",
actual, actual_expression
)
}
AnalysisError::ExpectedType { AnalysisError::ExpectedType {
expected, expected,
actual, actual,
@ -952,16 +1095,23 @@ impl Display for AnalysisError {
), ),
AnalysisError::TypeConflict { AnalysisError::TypeConflict {
actual_expression: actual_statement, actual_expression: actual_statement,
actual_type, type_conflict: TypeConflict { expected, actual },
expected,
} => { } => {
write!( write!(
f, f,
"Expected type {}, found {}, which has type {}", "Expected type {}, found {}, which has type {}",
expected, actual_statement, actual_type expected, actual_statement, actual
)
}
AnalysisError::UnexpectedArguments {
actual, expected, ..
} => {
write!(
f,
"Unexpected arguments {:?}, expected {:?}",
actual, expected
) )
} }
AnalysisError::UndefinedFieldIdentifier { AnalysisError::UndefinedFieldIdentifier {
identifier, identifier,
container, container,
@ -1070,10 +1220,12 @@ mod tests {
assert_eq!( assert_eq!(
analyze(source), analyze(source),
Err(DustError::Analysis { Err(DustError::Analysis {
analysis_errors: vec![AnalysisError::ExpectedType { analysis_errors: vec![AnalysisError::TypeConflict {
expected: Type::Float, actual_expression: Expression::literal(2, (56, 57)),
actual: Type::Integer, type_conflict: TypeConflict {
actual_expression: Expression::literal(2, (52, 53)), expected: Type::Float,
actual: Type::Integer,
},
}], }],
source, source,
}) })

View File

@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
BuiltInFunction, Context, FunctionType, Identifier, RangeableType, StructType, Type, BuiltInFunction, Context, FunctionType, Identifier, RangeableType, StructType, Type,
TypeEvaluation, TypeEvaluation, ValueData,
}; };
use super::{AstError, Node, Span, Statement}; use super::{AstError, Node, Span, Statement};
@ -272,7 +272,7 @@ impl Expression {
} }
pub fn type_evaluation(&self, context: &Context) -> Result<TypeEvaluation, AstError> { pub fn type_evaluation(&self, context: &Context) -> Result<TypeEvaluation, AstError> {
let return_type = match self { let evaluation = match self {
Expression::Block(block_expression) => { Expression::Block(block_expression) => {
block_expression.inner.type_evaluation(context)? block_expression.inner.type_evaluation(context)?
} }
@ -287,49 +287,72 @@ impl Expression {
} }
Expression::Call(call_expression) => { Expression::Call(call_expression) => {
let CallExpression { invoker, .. } = call_expression.inner.as_ref(); let CallExpression { invoker, .. } = call_expression.inner.as_ref();
let invoker_evaluation = invoker.type_evaluation(context)?;
let invoker_type = invoker.type_evaluation(context)?.r#type(); match invoker_evaluation {
TypeEvaluation::Return(Some(Type::Function(FunctionType {
let return_type = return_type,
if let Some(Type::Function(FunctionType { return_type, .. })) = invoker_type { ..
return_type.map(|r#type| *r#type) }))) => TypeEvaluation::Return(return_type.map(|boxed| *boxed)),
} else if let Some(Type::Struct(_)) = invoker_type { TypeEvaluation::Constructor(struct_type) => {
invoker_type TypeEvaluation::Return(Some(Type::Struct(struct_type)))
} else { }
None _ => {
}; return Err(AstError::ExpectedFunctionOrConstructor {
position: invoker.position(),
TypeEvaluation::Return(return_type) });
}
}
} }
Expression::FieldAccess(field_access_expression) => { Expression::FieldAccess(field_access_expression) => {
let FieldAccessExpression { container, field } = let FieldAccessExpression { container, field } =
field_access_expression.inner.as_ref(); field_access_expression.inner.as_ref();
let container_type = container.type_evaluation(context)?.r#type(); let container_type = match container.type_evaluation(context)?.r#type() {
Some(r#type) => r#type,
None => {
return Err(AstError::ExpectedNonEmptyEvaluation {
position: container.position(),
})
}
};
if let Some(Type::Struct(StructType::Fields { fields, .. })) = container_type { if let Type::Struct(StructType::Fields { fields, .. }) = container_type {
let found_type = fields let found_type = fields
.into_iter() .into_iter()
.find(|(name, _)| name == &field.inner) .find(|(name, _)| name == &field.inner)
.map(|(_, r#type)| r#type); .map(|(_, r#type)| r#type);
TypeEvaluation::Return(found_type) TypeEvaluation::Return(found_type)
} else if let Some(field_type) = container_type.get_field_type(&field.inner) {
TypeEvaluation::Return(Some(field_type))
} else { } else {
return Err(AstError::ExpectedStructFieldsType { return Err(AstError::ExpectedNonEmptyEvaluation {
position: container.position(), position: container.position(),
}); });
} }
} }
Expression::Grouped(expression) => expression.inner.type_evaluation(context)?, Expression::Grouped(expression) => expression.inner.type_evaluation(context)?,
Expression::Identifier(identifier) => { Expression::Identifier(identifier) => {
let type_option = context.get_type(&identifier.inner).map_err(|error| { if let Some(struct_type) =
AstError::ContextError { context
error, .get_constructor_type(&identifier.inner)
position: identifier.position, .map_err(|error| AstError::ContextError {
} error,
})?; position: identifier.position,
})?
{
TypeEvaluation::Constructor(struct_type)
} else {
let type_option = context.get_type(&identifier.inner).map_err(|error| {
AstError::ContextError {
error,
position: identifier.position,
}
})?;
TypeEvaluation::Return(type_option) TypeEvaluation::Return(type_option)
}
} }
Expression::If(if_expression) => match if_expression.inner.as_ref() { Expression::If(if_expression) => match if_expression.inner.as_ref() {
IfExpression::If { .. } => TypeEvaluation::Return(None), IfExpression::If { .. } => TypeEvaluation::Return(None),
@ -540,7 +563,7 @@ impl Expression {
} }
}; };
Ok(return_type) Ok(evaluation)
} }
pub fn position(&self) -> Span { pub fn position(&self) -> Span {
@ -734,10 +757,12 @@ pub enum PrimitiveValueExpression {
impl Display for PrimitiveValueExpression { impl Display for PrimitiveValueExpression {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self { match self {
PrimitiveValueExpression::Boolean(boolean) => write!(f, "{boolean}"), PrimitiveValueExpression::Boolean(boolean) => ValueData::Boolean(*boolean).fmt(f),
PrimitiveValueExpression::Character(character) => write!(f, "'{character}'"), PrimitiveValueExpression::Character(character) => {
PrimitiveValueExpression::Float(float) => write!(f, "{float}"), ValueData::Character(*character).fmt(f)
PrimitiveValueExpression::Integer(integer) => write!(f, "{integer}"), }
PrimitiveValueExpression::Float(float) => ValueData::Float(*float).fmt(f),
PrimitiveValueExpression::Integer(integer) => ValueData::Integer(*integer).fmt(f),
} }
} }
} }

View File

@ -67,6 +67,9 @@ pub enum AstError {
error: ContextError, error: ContextError,
position: Span, position: Span,
}, },
ExpectedFunctionOrConstructor {
position: Span,
},
ExpectedInteger { ExpectedInteger {
position: Span, position: Span,
}, },
@ -98,6 +101,7 @@ impl AstError {
pub fn position(&self) -> Span { pub fn position(&self) -> Span {
match self { match self {
AstError::ContextError { position, .. } => *position, AstError::ContextError { position, .. } => *position,
AstError::ExpectedFunctionOrConstructor { position } => *position,
AstError::ExpectedInteger { position } => *position, AstError::ExpectedInteger { position } => *position,
AstError::ExpectedListType { position } => *position, AstError::ExpectedListType { position } => *position,
AstError::ExpectedNonEmptyEvaluation { position } => *position, AstError::ExpectedNonEmptyEvaluation { position } => *position,
@ -116,6 +120,9 @@ impl Display for AstError {
AstError::ContextError { error, position } => { AstError::ContextError { error, position } => {
write!(f, "Context error at {:?}: {}", position, error) write!(f, "Context error at {:?}: {}", position, error)
} }
AstError::ExpectedFunctionOrConstructor { position } => {
write!(f, "Expected a function or constructor at {:?}", position)
}
AstError::ExpectedInteger { position } => { AstError::ExpectedInteger { position } => {
write!(f, "Expected an integer at {:?}", position) write!(f, "Expected an integer at {:?}", position)
} }

View File

@ -7,7 +7,7 @@ use std::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{Identifier, Type, Value, ValueData, ValueError}; use crate::{FunctionType, Identifier, Type, Value, ValueData, ValueError};
/// Integrated function that can be called from Dust code. /// Integrated function that can be called from Dust code.
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)]
@ -65,6 +65,15 @@ impl BuiltInFunction {
} }
} }
pub fn r#type(&self) -> Type {
Type::Function(FunctionType {
name: Identifier::new(self.name()),
type_parameters: self.type_parameters(),
value_parameters: self.value_parameters(),
return_type: self.return_type().map(Box::new),
})
}
pub fn call( pub fn call(
&self, &self,
_type_arguments: Option<Vec<Type>>, _type_arguments: Option<Vec<Type>>,

View File

@ -127,6 +127,26 @@ impl Context {
} }
} }
/// Returns the constructor type associated with the identifier.
pub fn get_constructor_type(
&self,
identifier: &Identifier,
) -> Result<Option<StructType>, ContextError> {
let read_associations = self.associations.read()?;
if let Some((context_data, _)) = read_associations.get(identifier) {
match context_data {
ContextData::Constructor(constructor) => Ok(Some(constructor.struct_type.clone())),
ContextData::ConstructorType(struct_type) => Ok(Some(struct_type.clone())),
_ => Ok(None),
}
} else if let Some(parent) = &self.parent {
parent.get_constructor_type(identifier)
} else {
Ok(None)
}
}
/// Associates an identifier with a variable type, with a position given for garbage collection. /// Associates an identifier with a variable type, with a position given for garbage collection.
pub fn set_variable_type( pub fn set_variable_type(
&self, &self,
@ -136,9 +156,22 @@ impl Context {
) -> Result<(), ContextError> { ) -> Result<(), ContextError> {
log::trace!("Setting {identifier} to type {type} at {position:?}"); log::trace!("Setting {identifier} to type {type} at {position:?}");
self.associations let mut associations = self.associations.write()?;
.write()? let last_position = associations
.insert(identifier, (ContextData::VariableType(r#type), position)); .get(&identifier)
.map(|(_, last_position)| {
if last_position.1 > position.1 {
*last_position
} else {
position
}
})
.unwrap_or_default();
associations.insert(
identifier,
(ContextData::VariableType(r#type), last_position),
);
Ok(()) Ok(())
} }
@ -152,7 +185,6 @@ impl Context {
log::trace!("Setting {identifier} to value {value}"); log::trace!("Setting {identifier} to value {value}");
let mut associations = self.associations.write()?; let mut associations = self.associations.write()?;
let last_position = associations let last_position = associations
.get(&identifier) .get(&identifier)
.map(|(_, last_position)| *last_position) .map(|(_, last_position)| *last_position)
@ -175,7 +207,6 @@ impl Context {
log::trace!("Setting {identifier} to constructor {constructor:?}"); log::trace!("Setting {identifier} to constructor {constructor:?}");
let mut associations = self.associations.write()?; let mut associations = self.associations.write()?;
let last_position = associations let last_position = associations
.get(&identifier) .get(&identifier)
.map(|(_, last_position)| *last_position) .map(|(_, last_position)| *last_position)
@ -200,10 +231,20 @@ impl Context {
log::trace!("Setting {identifier} to constructor of type {struct_type}"); log::trace!("Setting {identifier} to constructor of type {struct_type}");
let mut variables = self.associations.write()?; let mut variables = self.associations.write()?;
let last_position = variables
.get(&identifier)
.map(|(_, last_position)| {
if last_position.1 > position.1 {
*last_position
} else {
position
}
})
.unwrap_or_default();
variables.insert( variables.insert(
identifier, identifier,
(ContextData::ConstructorType(struct_type), position), (ContextData::ConstructorType(struct_type), last_position),
); );
Ok(()) Ok(())

View File

@ -1,5 +1,6 @@
use crate::{Constructor, RuntimeError, Span, Type, Value}; use crate::{Constructor, RuntimeError, Span, StructType, Type, Value};
#[derive(Debug, Clone, PartialEq)]
pub enum Evaluation { pub enum Evaluation {
Break(Option<Value>), Break(Option<Value>),
Constructor(Constructor), Constructor(Constructor),
@ -23,9 +24,10 @@ impl Evaluation {
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub enum TypeEvaluation { pub enum TypeEvaluation {
Break(Option<Type>), Break(Option<Type>),
Constructor(Type), Constructor(StructType),
Return(Option<Type>), Return(Option<Type>),
} }

View File

@ -17,7 +17,7 @@ use std::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{constructor::Constructor, Identifier}; use crate::{constructor::Constructor, BuiltInFunction, Identifier};
/// Description of a kind of value. /// Description of a kind of value.
/// ///
@ -255,6 +255,26 @@ impl Type {
}, },
} }
} }
pub fn get_field_type(&self, field: &Identifier) -> Option<Type> {
match field.as_str() {
"to_string" => Some(BuiltInFunction::ToString.r#type()),
"length" => match self {
Type::List { .. } => Some(Type::Integer),
Type::ListOf { .. } => Some(Type::Integer),
Type::ListEmpty => Some(Type::Integer),
Type::Map { .. } => Some(Type::Integer),
Type::String { .. } => Some(Type::Integer),
_ => None,
},
"is_even" | "is_odd" => Some(Type::Boolean),
_ => match self {
Type::Struct(StructType::Fields { fields, .. }) => fields.get(field).cloned(),
Type::Map { pairs } => pairs.get(field).cloned(),
_ => None,
},
}
}
} }
impl Display for Type { impl Display for Type {

View File

@ -1173,7 +1173,15 @@ impl Display for ValueData {
ValueData::Byte(byte) => write!(f, "{byte}"), ValueData::Byte(byte) => write!(f, "{byte}"),
ValueData::Character(character) => write!(f, "{character}"), ValueData::Character(character) => write!(f, "{character}"),
ValueData::Enum(r#enum) => write!(f, "{enum}"), ValueData::Enum(r#enum) => write!(f, "{enum}"),
ValueData::Float(float) => write!(f, "{float}"), ValueData::Float(float) => {
write!(f, "{float}")?;
if float.fract() == 0.0 {
write!(f, ".0")?;
}
Ok(())
}
ValueData::Function(function) => write!(f, "{function}"), ValueData::Function(function) => write!(f, "{function}"),
ValueData::Integer(integer) => write!(f, "{integer}"), ValueData::Integer(integer) => write!(f, "{integer}"),
ValueData::Map(pairs) => { ValueData::Map(pairs) => {