diff --git a/dust-lang/src/abstract_tree/value_node.rs b/dust-lang/src/abstract_tree/value_node.rs index 6b8be9e..597c0a6 100644 --- a/dust-lang/src/abstract_tree/value_node.rs +++ b/dust-lang/src/abstract_tree/value_node.rs @@ -151,24 +151,26 @@ impl AbstractNode for ValueNode { body.node.validate(&mut function_context, _manage_memory)?; - let (expected_return_type, expected_position) = if let Some(constructor) = return_type { - (constructor.construct(context)?, constructor.position()) - } else { - return Err(ValidationError::ExpectedExpression(body.position)); - }; - let actual_return_type = if let Some(r#type) = body.node.expected_type(context)? { - r#type - } else { - return Err(ValidationError::ExpectedExpression(body.position)); - }; + let ((expected_return, expected_position), actual_return) = + match (return_type, body.node.expected_type(context)?) { + (Some(constructor), Some(r#type)) => ( + (constructor.construct(context)?, constructor.position()), + r#type, + ), + (None, Some(_)) => return Err(ValidationError::ExpectedValue(body.position)), + (Some(constructor), None) => { + return Err(ValidationError::ExpectedExpression(constructor.position())) + } + (None, None) => return Ok(()), + }; - expected_return_type - .check(&actual_return_type) - .map_err(|conflict| ValidationError::TypeCheck { + expected_return.check(&actual_return).map_err(|conflict| { + ValidationError::TypeCheck { conflict, actual_position: body.position, expected_position: Some(expected_position), - })?; + } + })?; return Ok(()); } diff --git a/dust-lang/src/parser/tests.rs b/dust-lang/src/parser/tests.rs index 383b4c8..298d9f0 100644 --- a/dust-lang/src/parser/tests.rs +++ b/dust-lang/src/parser/tests.rs @@ -427,6 +427,29 @@ fn list_of_type() { #[test] fn function_type() { + assert_eq!( + parse(&lex("type Foo = fn |T| (int)").unwrap()).unwrap()[0], + Statement::TypeAlias( + TypeAlias::new( + Identifier::new("Foo").with_position((5, 8)), + TypeConstructor::Function( + FunctionTypeConstructor { + type_parameters: Some(vec![Identifier::new("T").with_position((15, 16))]), + value_parameters: vec![TypeConstructor::Raw( + RawTypeConstructor::Integer.with_position((19, 22)) + )], + return_type: None + } + .with_position((11, 23)) + ) + ) + .with_position((0, 23)) + ) + ); +} + +#[test] +fn function_type_with_return() { assert_eq!( parse(&lex("type Foo = fn |T| (int) -> T").unwrap()).unwrap()[0], Statement::TypeAlias( @@ -666,7 +689,7 @@ fn map() { ), ]) .with_position((0, 15)) - ),) + )) ); }