diff --git a/src/abstract_tree/assignment.rs b/src/abstract_tree/assignment.rs index d5caeeb..29b040a 100644 --- a/src/abstract_tree/assignment.rs +++ b/src/abstract_tree/assignment.rs @@ -66,10 +66,6 @@ impl AbstractTree for Assignment { let key = self.identifier.inner(); let value = self.statement.run(source, context)?; - if let Some(type_definition) = &self.type_definition { - type_definition.check(&value, context)?; - } - let new_value = match self.operator { AssignmentOperator::PlusEqual => { if let Some(mut previous_value) = context.variables()?.get(key).cloned() { @@ -90,6 +86,10 @@ impl AbstractTree for Assignment { AssignmentOperator::Equal => value, }; + if let Some(type_definition) = &self.type_definition { + type_definition.check(&new_value, context)?; + } + context.variables_mut()?.insert(key.clone(), new_value); Ok(Value::Empty) diff --git a/src/abstract_tree/math.rs b/src/abstract_tree/math.rs index b8f131f..4a413f3 100644 --- a/src/abstract_tree/math.rs +++ b/src/abstract_tree/math.rs @@ -56,8 +56,8 @@ impl AbstractTree for Math { Ok(value) } - fn expected_type(&self, _context: &Map) -> Result { - Ok(Type::Number) + fn expected_type(&self, context: &Map) -> Result { + self.left.expected_type(context) } } diff --git a/src/abstract_tree/type_defintion.rs b/src/abstract_tree/type_defintion.rs index 3c614ae..df867ce 100644 --- a/src/abstract_tree/type_defintion.rs +++ b/src/abstract_tree/type_defintion.rs @@ -29,7 +29,7 @@ impl AbstractTree for TypeDefintion { } } -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialOrd, Ord)] pub enum Type { Any, Boolean, @@ -49,85 +49,13 @@ pub enum Type { impl TypeDefintion { pub fn check(&self, value: &Value, context: &Map) -> Result<()> { - match (&self.r#type, value) { - (Type::Any, _) - | (Type::Boolean, Value::Boolean(_)) - | (Type::Empty, Value::Empty) - | (Type::Float, Value::Float(_)) - | (Type::Integer, Value::Integer(_)) - | (Type::Map, Value::Map(_)) - | (Type::Number, Value::Integer(_)) - | (Type::Number, Value::Float(_)) - | (Type::String, Value::String(_)) - | (Type::Table, Value::Table(_)) => Ok(()), - (Type::List(_), Value::List(list)) => { - if let Some(first) = list.items().first() { - self.check(first, context) - } else { - Ok(()) - } - } - ( - Type::Function { - parameter_types, - return_type, - }, - Value::Function(function), - ) => { - let parameter_type_count = parameter_types.len(); - let parameter_count = function.parameters().len(); - - if parameter_type_count != parameter_count - || return_type.as_ref() != &function.body().expected_type(context)? - { - return Err(Error::TypeCheck { - expected: self.r#type.clone(), - actual: value.clone(), - }); - } - - Ok(()) - } - (Type::Boolean, _) => Err(Error::TypeCheck { - expected: Type::Boolean, - actual: value.clone(), - }), - (Type::Empty, _) => Err(Error::TypeCheck { - expected: Type::Empty, - actual: value.clone(), - }), - (Type::Float, _) => Err(Error::TypeCheck { - expected: Type::Float, - actual: value.clone(), - }), - (Type::Function { .. }, _) => Err(Error::TypeCheck { + if self.r#type == value.r#type(context)? { + Ok(()) + } else { + Err(Error::TypeCheck { expected: self.r#type.clone(), actual: value.clone(), - }), - (Type::Integer, _) => Err(Error::TypeCheck { - expected: Type::Integer, - actual: value.clone(), - }), - (Type::List(_), _) => Err(Error::TypeCheck { - expected: self.r#type.clone(), - actual: value.clone(), - }), - (Type::Map, _) => Err(Error::TypeCheck { - expected: Type::Map, - actual: value.clone(), - }), - (Type::Number, _) => Err(Error::TypeCheck { - expected: Type::Number, - actual: value.clone(), - }), - (Type::String, _) => Err(Error::TypeCheck { - expected: Type::String, - actual: value.clone(), - }), - (Type::Table, _) => Err(Error::TypeCheck { - expected: Type::Table, - actual: value.clone(), - }), + }) } } } @@ -169,11 +97,12 @@ impl AbstractTree for Type { Type::List(Box::new(item_type)) } "map" => Type::Map, + "num" => Type::Number, "str" => Type::String, "table" => Type::Table, _ => { return Err(Error::UnexpectedSyntaxNode { - expected: "any, bool, float, fn, int, list, map, str or table", + expected: "any, bool, float, fn, int, list, map, num, str or table", actual: type_node.kind(), location: type_node.start_position(), relevant_source: source[type_node.byte_range()].to_string(), @@ -193,6 +122,30 @@ impl AbstractTree for Type { } } +impl Eq for Type {} + +impl PartialEq for Type { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Type::Any, _) + | (_, Type::Any) + | (Type::Boolean, Type::Boolean) + | (Type::Empty, Type::Empty) + | (Type::Float, Type::Float) + | (Type::Integer, Type::Integer) + | (Type::Map, Type::Map) + | (Type::Number, Type::Number) + | (Type::Number, Type::Integer) + | (Type::Number, Type::Float) + | (Type::Integer, Type::Number) + | (Type::Float, Type::Number) + | (Type::String, Type::String) + | (Type::Table, Type::Table) => true, + _ => false, + } + } +} + impl Display for Type { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { @@ -200,7 +153,18 @@ impl Display for Type { Type::Boolean => write!(f, "bool"), Type::Empty => write!(f, "empty"), Type::Float => write!(f, "float"), - Type::Function { .. } => write!(f, "function"), + Type::Function { + parameter_types, + return_type, + } => { + write!(f, "fn ")?; + + for parameter_type in parameter_types { + write!(f, "{parameter_type} ")?; + } + + write!(f, "-> {return_type}") + } Type::Integer => write!(f, "integer"), Type::List(_) => write!(f, "list"), Type::Map => write!(f, "map"),