diff --git a/dust-lang/src/abstract_tree/function_call.rs b/dust-lang/src/abstract_tree/function_call.rs index 75538d8..d6971b3 100644 --- a/dust-lang/src/abstract_tree/function_call.rs +++ b/dust-lang/src/abstract_tree/function_call.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use serde::{Deserialize, Serialize}; use crate::{ @@ -6,13 +8,19 @@ use crate::{ value::ValueInner, }; -use super::{AbstractNode, Evaluation, Expression, Type, TypeConstructor, ValueNode, WithPosition}; +use super::{ + expression, AbstractNode, Evaluation, Expression, Type, TypeConstructor, ValueNode, + WithPosition, +}; -#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct FunctionCall { function_expression: Box, type_arguments: Option>, value_arguments: Option>, + + #[serde(skip)] + context: Context, } impl FunctionCall { @@ -25,6 +33,7 @@ impl FunctionCall { function_expression: Box::new(function_expression), type_arguments, value_arguments, + context: Context::new(None), } } @@ -43,6 +52,8 @@ impl AbstractNode for FunctionCall { } } + self.context.set_parent(context.clone())?; + Ok(()) } @@ -70,20 +81,35 @@ impl AbstractNode for FunctionCall { return_type: _, } = function_node_type { - match (type_parameters, &self.type_arguments) { - (Some(parameters), Some(type_arguments)) => { - if parameters.len() != type_arguments.len() { - return Err(ValidationError::WrongTypeArguments { - arguments: type_arguments.clone(), - parameters: parameters.clone(), - }); - } + if let (Some(parameters), Some(arguments)) = (type_parameters, &self.type_arguments) { + if parameters.len() != arguments.len() { + return Err(ValidationError::WrongTypeArguments { + arguments: arguments.clone(), + parameters: parameters.clone(), + }); + } + + for (identifier, constructor) in parameters.into_iter().zip(arguments.into_iter()) { + let r#type = constructor.construct(context)?; + + self.context.set_type(identifier, r#type)?; } - _ => {} } match (value_parameters, &self.value_arguments) { (Some(parameters), Some(arguments)) => { + for ((identifier, _), expression) in + parameters.iter().zip(arguments.into_iter()) + { + let r#type = if let Some(r#type) = expression.expected_type(context)? { + r#type + } else { + return Err(ValidationError::ExpectedExpression(expression.position())); + }; + + self.context.set_type(identifier.clone(), r#type)?; + } + if parameters.len() != arguments.len() { return Err(ValidationError::WrongValueArguments { parameters, @@ -150,8 +176,6 @@ impl AbstractNode for FunctionCall { )); }; - let function_context = Context::new(Some(context.clone())); - if let (Some(type_parameters), Some(type_arguments)) = (function.type_parameters(), self.type_arguments) { @@ -160,7 +184,7 @@ impl AbstractNode for FunctionCall { { let r#type = constructor.construct(context)?; - function_context.set_type(identifier.clone(), r#type)?; + self.context.set_type(identifier.clone(), r#type)?; } } @@ -178,31 +202,71 @@ impl AbstractNode for FunctionCall { )); }; - function_context.set_value(identifier.clone(), value)?; + self.context.set_value(identifier.clone(), value)?; } } - function.call(&function_context, manage_memory) + function.call(&self.context, manage_memory) } fn expected_type(&self, context: &Context) -> Result, ValidationError> { - let return_type = if let Some(r#type) = self.function_expression.expected_type(context)? { - if let Type::Function { return_type, .. } = r#type { - return_type - } else { - return Err(ValidationError::ExpectedFunction { - actual: r#type, - position: self.function_expression.position(), - }); - } + let expression_type = self.function_expression.expected_type(context)?.ok_or( + ValidationError::ExpectedExpression(self.function_expression.position()), + )?; + + let (type_parameters, value_parameters, return_type) = if let Type::Function { + type_parameters, + value_parameters, + return_type, + } = expression_type + { + (type_parameters, value_parameters, return_type) } else { - return Err(ValidationError::ExpectedExpression( - self.function_expression.position(), - )); + return Err(ValidationError::ExpectedFunction { + actual: expression_type, + position: self.function_expression.position(), + }); }; - let return_type = return_type.map(|r#box| *r#box); + if let Some(Type::Generic { + identifier: return_identifier, + concrete_type: None, + }) = return_type.clone().map(|r#box| *r#box) + { + if let (Some(parameters), Some(arguments)) = (type_parameters, &self.type_arguments) { + for (identifier, constructor) in parameters.into_iter().zip(arguments.into_iter()) { + if identifier == return_identifier { + let r#type = constructor.construct(context)?; - Ok(return_type) + return Ok(Some(Type::Generic { + identifier, + concrete_type: Some(Box::new(r#type)), + })); + } + } + } + } + + Ok(return_type.map(|r#box| *r#box)) + } +} + +impl Eq for FunctionCall {} + +impl PartialEq for FunctionCall { + fn eq(&self, other: &Self) -> bool { + todo!() + } +} + +impl PartialOrd for FunctionCall { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for FunctionCall { + fn cmp(&self, other: &Self) -> Ordering { + todo!() } } diff --git a/dust-lang/src/value.rs b/dust-lang/src/value.rs index d35d936..f36f7c1 100644 --- a/dust-lang/src/value.rs +++ b/dust-lang/src/value.rs @@ -142,7 +142,15 @@ impl Display for Value { write!(f, "{type_name}::{variant}") } } - ValueInner::Float(float) => write!(f, "{float}"), + ValueInner::Float(float) => { + write!(f, "{float}")?; + + if &float.floor() == float { + write!(f, ".0")?; + } + + Ok(()) + } ValueInner::Integer(integer) => write!(f, "{integer}"), ValueInner::List(list) => { write!(f, "[")?;