diff --git a/dust-lang/src/abstract_tree/function_call.rs b/dust-lang/src/abstract_tree/function_call.rs index 20c564d..fd84d10 100644 --- a/dust-lang/src/abstract_tree/function_call.rs +++ b/dust-lang/src/abstract_tree/function_call.rs @@ -128,10 +128,38 @@ impl AbstractNode for FunctionCall { } impl ExpectedType for FunctionCall { - fn expected_type(&self, _context: &mut Context) -> Result { - let function_node_type = self.function.expected_type(_context)?; + fn expected_type(&self, context: &mut Context) -> Result { + let function_node_type = self.function.expected_type(context)?; + + if let Type::Function { + return_type, + type_parameters, + .. + } = function_node_type + { + for (constructor, identifier) in self + .type_arguments + .as_ref() + .unwrap() + .into_iter() + .zip(type_parameters.unwrap().into_iter()) + { + if let Type::Generic { + identifier: return_identifier, + .. + } = *return_type.clone() + { + if return_identifier == identifier { + let concrete_type = constructor.clone().construct(&context)?; + + return Ok(Type::Generic { + identifier, + concrete_type: Some(Box::new(concrete_type)), + }); + } + } + } - if let Type::Function { return_type, .. } = function_node_type { Ok(*return_type) } else { Err(ValidationError::ExpectedFunction { diff --git a/dust-lang/src/abstract_tree/type.rs b/dust-lang/src/abstract_tree/type.rs index 3ec4b66..35d961b 100644 --- a/dust-lang/src/abstract_tree/type.rs +++ b/dust-lang/src/abstract_tree/type.rs @@ -21,7 +21,10 @@ pub enum Type { value_parameters: Vec, return_type: Box, }, - Generic(Option>), + Generic { + identifier: Identifier, + concrete_type: Option>, + }, Integer, List { length: usize, @@ -50,7 +53,16 @@ impl Type { | (Type::None, Type::None) | (Type::Range, Type::Range) | (Type::String, Type::String) => return Ok(()), - (Type::Generic(left), Type::Generic(right)) => match (left, right) { + ( + Type::Generic { + concrete_type: left, + .. + }, + Type::Generic { + concrete_type: right, + .. + }, + ) => match (left, right) { (Some(left), Some(right)) => { if left.check(&right).is_ok() { return Ok(()); @@ -61,6 +73,14 @@ impl Type { } _ => {} }, + (Type::Generic { concrete_type, .. }, other) + | (other, Type::Generic { concrete_type, .. }) => { + if let Some(concrete_type) = concrete_type { + if other == concrete_type.as_ref() { + return Ok(()); + } + } + } (Type::ListOf(left), Type::ListOf(right)) => { if left.check(&right).is_ok() { return Ok(()); @@ -207,11 +227,11 @@ impl Display for Type { Type::Any => write!(f, "any"), Type::Boolean => write!(f, "bool"), Type::Float => write!(f, "float"), - Type::Generic(type_option) => { - if let Some(concrete_type) = type_option { - write!(f, "implied to be {concrete_type}") + Type::Generic { concrete_type, .. } => { + if let Some(r#type) = concrete_type { + write!(f, "implied to be {}", r#type) } else { - todo!() + write!(f, "unknown") } } Type::Integer => write!(f, "int"), diff --git a/dust-lang/src/abstract_tree/type_constructor.rs b/dust-lang/src/abstract_tree/type_constructor.rs index fde1174..49764c0 100644 --- a/dust-lang/src/abstract_tree/type_constructor.rs +++ b/dust-lang/src/abstract_tree/type_constructor.rs @@ -82,9 +82,15 @@ impl TypeConstructor { node: identifier, .. }) => { if let Some(r#type) = context.get_type(&identifier)? { - Type::Generic(Some(Box::new(r#type))) + Type::Generic { + identifier, + concrete_type: Some(Box::new(r#type)), + } } else { - Type::Generic(None) + Type::Generic { + identifier, + concrete_type: None, + } } } TypeConstructor::List(positioned_constructor) => { diff --git a/dust-lang/src/abstract_tree/value_node.rs b/dust-lang/src/abstract_tree/value_node.rs index ad8d64a..0b11883 100644 --- a/dust-lang/src/abstract_tree/value_node.rs +++ b/dust-lang/src/abstract_tree/value_node.rs @@ -68,7 +68,13 @@ impl AbstractNode for ValueNode { if let Some(type_parameters) = type_parameters { for identifier in type_parameters { - function_context.set_type(identifier.clone(), Type::Generic(None))?; + function_context.set_type( + identifier.clone(), + Type::Generic { + identifier: identifier.clone(), + concrete_type: None, + }, + )?; } } diff --git a/examples/fizzbuzz.ds b/examples/fizzbuzz.ds index cdbde9d..7c4a956 100644 --- a/examples/fizzbuzz.ds +++ b/examples/fizzbuzz.ds @@ -18,3 +18,4 @@ while count <= 15 { count += 1 } + diff --git a/examples/type_inference.ds b/examples/type_inference.ds new file mode 100644 index 0000000..681f7e2 --- /dev/null +++ b/examples/type_inference.ds @@ -0,0 +1,2 @@ +foo = fn |T| (x: T) -> T { x } +bar: str = foo::(str)::("hi")