Begin passing tests

This commit is contained in:
Jeff 2024-06-22 01:19:30 -04:00
parent 4b89ea0e96
commit 240c045a0c
2 changed files with 40 additions and 15 deletions

View File

@ -151,23 +151,25 @@ impl AbstractNode for ValueNode {
body.node.validate(&mut function_context, _manage_memory)?;
let (expected_return_type, expected_position) = if let Some(constructor) = return_type {
(constructor.construct(context)?, constructor.position())
} else {
return Err(ValidationError::ExpectedExpression(body.position));
};
let actual_return_type = if let Some(r#type) = body.node.expected_type(context)? {
r#type
} else {
return Err(ValidationError::ExpectedExpression(body.position));
let ((expected_return, expected_position), actual_return) =
match (return_type, body.node.expected_type(context)?) {
(Some(constructor), Some(r#type)) => (
(constructor.construct(context)?, constructor.position()),
r#type,
),
(None, Some(_)) => return Err(ValidationError::ExpectedValue(body.position)),
(Some(constructor), None) => {
return Err(ValidationError::ExpectedExpression(constructor.position()))
}
(None, None) => return Ok(()),
};
expected_return_type
.check(&actual_return_type)
.map_err(|conflict| ValidationError::TypeCheck {
expected_return.check(&actual_return).map_err(|conflict| {
ValidationError::TypeCheck {
conflict,
actual_position: body.position,
expected_position: Some(expected_position),
}
})?;
return Ok(());

View File

@ -427,6 +427,29 @@ fn list_of_type() {
#[test]
fn function_type() {
assert_eq!(
parse(&lex("type Foo = fn |T| (int)").unwrap()).unwrap()[0],
Statement::TypeAlias(
TypeAlias::new(
Identifier::new("Foo").with_position((5, 8)),
TypeConstructor::Function(
FunctionTypeConstructor {
type_parameters: Some(vec![Identifier::new("T").with_position((15, 16))]),
value_parameters: vec![TypeConstructor::Raw(
RawTypeConstructor::Integer.with_position((19, 22))
)],
return_type: None
}
.with_position((11, 23))
)
)
.with_position((0, 23))
)
);
}
#[test]
fn function_type_with_return() {
assert_eq!(
parse(&lex("type Foo = fn |T| (int) -> T").unwrap()).unwrap()[0],
Statement::TypeAlias(
@ -666,7 +689,7 @@ fn map() {
),
])
.with_position((0, 15))
),)
))
);
}