From 9f98617f6ab9b18d212fb6b91534ba6eed759bc1 Mon Sep 17 00:00:00 2001 From: Jeff Date: Mon, 18 Mar 2024 16:00:04 -0400 Subject: [PATCH] Refine enums --- src/abstract_tree/enum_definition.rs | 20 ++++++-- src/abstract_tree/type.rs | 36 +++++++++++--- src/abstract_tree/value_node.rs | 67 ++++++++++++++++++-------- src/context.rs | 29 ++++++++++- src/parser.rs | 48 ++++++++----------- src/value.rs | 20 ++------ tests/enums.rs | 72 ++++++++++++++++++++++++++-- 7 files changed, 214 insertions(+), 78 deletions(-) diff --git a/src/abstract_tree/enum_definition.rs b/src/abstract_tree/enum_definition.rs index e1b0640..a546f4e 100644 --- a/src/abstract_tree/enum_definition.rs +++ b/src/abstract_tree/enum_definition.rs @@ -8,15 +8,15 @@ use super::{AbstractTree, Action, Identifier, Type, WithPosition}; #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] pub struct EnumDefinition { name: Identifier, - type_parameters: Vec, - variants: Vec<(Identifier, Option>>)>, + type_parameters: Option>, + variants: Vec<(Identifier, Option>)>, } impl EnumDefinition { pub fn new( name: Identifier, - type_parameters: Vec, - variants: Vec<(Identifier, Option>>)>, + type_parameters: Option>, + variants: Vec<(Identifier, Option>)>, ) -> Self { Self { name, @@ -24,6 +24,18 @@ impl EnumDefinition { variants, } } + + pub fn name(&self) -> &Identifier { + &self.name + } + + pub fn type_parameters(&self) -> &Option> { + &self.type_parameters + } + + pub fn variants(&self) -> &Vec<(Identifier, Option>)> { + &self.variants + } } impl AbstractTree for EnumDefinition { diff --git a/src/abstract_tree/type.rs b/src/abstract_tree/type.rs index d8d6181..e1bd1dc 100644 --- a/src/abstract_tree/type.rs +++ b/src/abstract_tree/type.rs @@ -12,7 +12,12 @@ use super::{AbstractTree, Action}; pub enum Type { Any, Boolean, - Custom(Identifier), + Parameter(Identifier), + Enum { + name: Identifier, + type_arguments: Option>, + variants: Vec<(Identifier, Option)>, + }, Float, Function { parameter_types: Vec, @@ -45,7 +50,7 @@ impl Type { | (Type::None, Type::None) | (Type::Range, Type::Range) | (Type::String, Type::String) => Ok(()), - (Type::Custom(left), Type::Custom(right)) => { + (Type::Parameter(left), Type::Parameter(right)) => { if left == right { Ok(()) } else { @@ -113,7 +118,26 @@ impl Display for Type { match self { Type::Any => write!(f, "any"), Type::Boolean => write!(f, "boolean"), - Type::Custom(name) => write!(f, "{name}"), + Type::Parameter(name) => write!(f, "{name}"), + Type::Enum { + name, + type_arguments, + variants: _, + } => { + write!(f, "{name}(")?; + + if let Some(type_arguments) = type_arguments { + for (index, r#type) in type_arguments.into_iter().enumerate() { + if index == type_arguments.len() - 1 { + write!(f, "{}", r#type)?; + } else { + write!(f, "{}, ", r#type)?; + } + } + } + + write!(f, ")") + } Type::Float => write!(f, "float"), Type::Integer => write!(f, "integer"), Type::List => write!(f, "list"), @@ -160,7 +184,7 @@ mod tests { assert_eq!(Type::Any.check(&Type::Any), Ok(())); assert_eq!(Type::Boolean.check(&Type::Boolean), Ok(())); assert_eq!( - Type::Custom(Identifier::new("foo")).check(&Type::Custom(Identifier::new("foo"))), + Type::Parameter(Identifier::new("foo")).check(&Type::Parameter(Identifier::new("foo"))), Ok(()) ); assert_eq!(Type::Float.check(&Type::Float), Ok(())); @@ -183,8 +207,8 @@ mod tests { #[test] fn errors() { - let foo = Type::Custom(Identifier::new("foo")); - let bar = Type::Custom(Identifier::new("bar")); + let foo = Type::Parameter(Identifier::new("foo")); + let bar = Type::Parameter(Identifier::new("bar")); assert_eq!( foo.check(&bar), diff --git a/src/abstract_tree/value_node.rs b/src/abstract_tree/value_node.rs index ea306d8..741b78c 100644 --- a/src/abstract_tree/value_node.rs +++ b/src/abstract_tree/value_node.rs @@ -15,7 +15,8 @@ pub enum ValueNode { EnumInstance { name: Identifier, variant: Identifier, - expressions: Vec>, + type_arguments: Option>>, + expression: Box>, }, Float(f64), Integer(i64), @@ -37,18 +38,29 @@ pub enum ValueNode { } impl AbstractTree for ValueNode { - fn expected_type(&self, _context: &Context) -> Result { + fn expected_type(&self, context: &Context) -> Result { let r#type = match self { ValueNode::Boolean(_) => Type::Boolean, - ValueNode::EnumInstance { name, .. } => Type::Custom(name.clone()), + ValueNode::EnumInstance { + name, + variant: _, + type_arguments, + expression: _, + } => { + if let Some(r#type) = context.get_type(name)? { + r#type + } else { + Type::None + } + } ValueNode::Float(_) => Type::Float, ValueNode::Integer(_) => Type::Integer, ValueNode::List(items) => { let mut item_types = Vec::with_capacity(items.len()); for expression in items { - item_types.push(expression.node.expected_type(_context)?); + item_types.push(expression.node.expected_type(context)?); } Type::ListExact(item_types) @@ -117,6 +129,16 @@ impl AbstractTree for ValueNode { })?; } + if let ValueNode::EnumInstance { + name, + variant, + type_arguments, + expression, + } = self + { + let r#type = self.expected_type(context)?; + } + Ok(()) } @@ -126,22 +148,17 @@ impl AbstractTree for ValueNode { ValueNode::EnumInstance { name, variant, - expressions, + type_arguments: _, + expression, } => { - let mut values = Vec::with_capacity(expressions.len()); + let action = expression.node.run(_context)?; + let value = if let Action::Return(value) = action { + value + } else { + todo!() + }; - for expression in expressions { - let action = expression.node.run(_context)?; - let value = if let Action::Return(value) = action { - value - } else { - todo!() - }; - - values.push(value); - } - - Value::enum_instance(EnumInstance::new(name, variant, values)) + Value::enum_instance(EnumInstance::new(name, variant, value)) } ValueNode::Float(float) => Value::float(float), ValueNode::Integer(integer) => Value::integer(integer), @@ -213,12 +230,14 @@ impl Ord for ValueNode { EnumInstance { name: left_name, variant: left_variant, - expressions: left_expressions, + type_arguments: left_types, + expression: left_expression, }, EnumInstance { name: right_name, variant: right_variant, - expressions: right_expressions, + type_arguments: right_types, + expression: right_expression, }, ) => { let name_cmp = left_name.cmp(right_name); @@ -227,7 +246,13 @@ impl Ord for ValueNode { let variant_cmp = left_variant.cmp(right_variant); if variant_cmp.is_eq() { - left_expressions.cmp(right_expressions) + let type_cmp = left_types.cmp(right_types); + + if type_cmp.is_eq() { + left_expression.cmp(right_expression) + } else { + type_cmp + } } else { variant_cmp } diff --git a/src/context.rs b/src/context.rs index 27f320b..be0cde5 100644 --- a/src/context.rs +++ b/src/context.rs @@ -77,7 +77,34 @@ impl Context { let r#type = match value_data { ValueData::Type(r#type) => r#type.clone(), ValueData::Value(value) => value.r#type(), - ValueData::EnumDefinition(_) => return Ok(None), + ValueData::EnumDefinition(enum_definition) => { + let type_arguments = + enum_definition + .type_parameters() + .as_ref() + .map(|identifier_list| { + identifier_list + .into_iter() + .map(|identifier| Type::Parameter(identifier.clone())) + .collect() + }); + let variants = enum_definition + .variants() + .into_iter() + .map(|(identifier, type_option)| { + ( + identifier.clone(), + type_option.clone().map(|r#type| r#type.node), + ) + }) + .collect(); + + Type::Enum { + name: enum_definition.name().clone(), + type_arguments, + variants, + } + } }; return Ok(Some(r#type.clone())); diff --git a/src/parser.rs b/src/parser.rs index a44ce41..08ea1d4 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -107,7 +107,7 @@ pub fn parser<'src>() -> DustParser<'src> { just(Token::Keyword("list")).to(Type::List), identifier .clone() - .map(|identifier| Type::Custom(identifier)), + .map(|identifier| Type::Parameter(identifier)), )) }) .map_with(|r#type, state| r#type.with_position(state.span())); @@ -345,10 +345,8 @@ pub fn parser<'src>() -> DustParser<'src> { let enum_instance = identifier .clone() - .then_ignore(just(Token::Control(Control::DoubleColon))) - .then(identifier.clone()) .then( - positioned_expression + r#type .clone() .separated_by(just(Token::Control(Control::Comma))) .allow_trailing() @@ -356,13 +354,21 @@ pub fn parser<'src>() -> DustParser<'src> { .delimited_by( just(Token::Control(Control::ParenOpen)), just(Token::Control(Control::ParenClose)), - ), + ) + .or_not(), ) - .map_with(|((name, variant), expressions), state| { + .then_ignore(just(Token::Control(Control::DoubleColon))) + .then(identifier.clone()) + .then(positioned_expression.delimited_by( + just(Token::Control(Control::ParenOpen)), + just(Token::Control(Control::ParenClose)), + )) + .map_with(|(((name, type_arguments), variant), expression), state| { Expression::Value(ValueNode::EnumInstance { name, + type_arguments, variant, - expressions, + expression: Box::new(expression), }) .with_position(state.span()) }); @@ -452,9 +458,6 @@ pub fn parser<'src>() -> DustParser<'src> { let enum_variant = identifier.clone().then( r#type .clone() - .separated_by(just(Token::Control(Control::Comma))) - .allow_trailing() - .collect() .delimited_by( just(Token::Control(Control::ParenOpen)), just(Token::Control(Control::ParenClose)), @@ -473,7 +476,8 @@ pub fn parser<'src>() -> DustParser<'src> { .delimited_by( just(Token::Control(Control::ParenOpen)), just(Token::Control(Control::ParenClose)), - ), + ) + .or_not(), ) .then( enum_variant @@ -517,9 +521,9 @@ mod tests { assert_eq!( parse( &lex(" - enum FooBar (F, B) { - Foo(F), - Bar(B), + enum FooBar { + Foo, + Bar, } ") .unwrap() @@ -528,20 +532,10 @@ mod tests { .node, Statement::EnumDefinition(EnumDefinition::new( Identifier::new("FooBar"), - vec![Identifier::new("F"), Identifier::new("B")], + None, vec![ - ( - Identifier::new("Foo"), - Some(vec![ - Type::Custom(Identifier::new("F")).with_position((62, 63)) - ]), - ), - ( - Identifier::new("Bar"), - Some(vec![ - Type::Custom(Identifier::new("B")).with_position((90, 91)) - ]) - ) + (Identifier::new("Foo"), None), + (Identifier::new("Bar"), None) ] )) ); diff --git a/src/value.rs b/src/value.rs index 8f95187..9377a94 100644 --- a/src/value.rs +++ b/src/value.rs @@ -81,7 +81,7 @@ impl Value { pub fn r#type(&self) -> Type { match self.0.as_ref() { ValueInner::Boolean(_) => Type::Boolean, - ValueInner::EnumInstance(EnumInstance { name, .. }) => Type::Custom(name.clone()), + ValueInner::EnumInstance(EnumInstance { name, .. }) => Type::Parameter(name.clone()), ValueInner::Float(_) => Type::Float, ValueInner::Integer(_) => Type::Integer, ValueInner::List(values) => { @@ -146,19 +146,9 @@ impl Display for Value { ValueInner::EnumInstance(EnumInstance { name, variant, - value: content, + value, }) => { - write!(f, "{name}::{variant}(")?; - - for (index, value) in content.into_iter().enumerate() { - if index == content.len() - 1 { - write!(f, "{value}")?; - } else { - write!(f, "{value} ")?; - } - } - - write!(f, ")") + write!(f, "{name}::{variant}({value})") } ValueInner::Float(float) => write!(f, "{float}"), ValueInner::Integer(integer) => write!(f, "{integer}"), @@ -276,11 +266,11 @@ impl Ord for ValueInner { pub struct EnumInstance { name: Identifier, variant: Identifier, - value: Vec, + value: Value, } impl EnumInstance { - pub fn new(name: Identifier, variant: Identifier, value: Vec) -> Self { + pub fn new(name: Identifier, variant: Identifier, value: Value) -> Self { Self { name, variant, diff --git a/tests/enums.rs b/tests/enums.rs index e2740cd..14f4ebe 100644 --- a/tests/enums.rs +++ b/tests/enums.rs @@ -1,11 +1,43 @@ -use dust_lang::*; +use dust_lang::{ + abstract_tree::Type, + error::{Error, TypeConflict, ValidationError}, + *, +}; #[test] -fn define_enum() { +fn simple_enum_type_check() { + assert_eq!( + interpret( + " + enum FooBar { + Foo(int), + Bar, + } + + foo = FooBar::Foo('yo') + foo + ", + ), + Err(vec![Error::Validation { + error: ValidationError::TypeCheck { + conflict: TypeConflict { + actual: Type::String, + expected: Type::Integer, + }, + actual_position: (0, 0).into(), + expected_position: (0, 0).into() + }, + position: (0, 0).into() + }]) + ) +} + +#[test] +fn simple_enum() { interpret( " - enum FooBar(F) { - Foo(F), + enum FooBar { + Foo(int), Bar, } @@ -15,3 +47,35 @@ fn define_enum() { ) .unwrap(); } + +#[test] +fn simple_enum_with_type_argument() { + interpret( + " + enum FooBar(F) { + Foo(F), + Bar, + } + + foo = FooBar(int)::Foo(1) + foo + ", + ) + .unwrap(); +} + +#[test] +fn complex_enum_with_type_arguments() { + interpret( + " + enum FooBar(F, B) { + Foo(F), + Bar(B), + } + + bar = FooBar(int, str)::Bar('bar') + bar + ", + ) + .unwrap(); +}