From fecc62811dc0daa36b655856cbdfc09852f7452e Mon Sep 17 00:00:00 2001 From: Jeff Date: Wed, 19 Jun 2024 12:03:25 -0400 Subject: [PATCH] Improve type inference --- dust-lang/src/abstract_tree/assignment.rs | 19 ++++++------------- dust-lang/src/abstract_tree/loop.rs | 4 ++++ dust-lang/src/abstract_tree/statement.rs | 8 ++++++++ dust-lang/src/abstract_tree/value_node.rs | 10 +++++++++- dust-lang/src/lib.rs | 4 +++- dust-lang/tests/enums.rs | 23 +++++++++++++++++++++++ examples/type_inference.ds | 4 ++++ 7 files changed, 57 insertions(+), 15 deletions(-) diff --git a/dust-lang/src/abstract_tree/assignment.rs b/dust-lang/src/abstract_tree/assignment.rs index 2afe6e2..b4e1bf0 100644 --- a/dust-lang/src/abstract_tree/assignment.rs +++ b/dust-lang/src/abstract_tree/assignment.rs @@ -52,8 +52,13 @@ impl Evaluate for Assignment { )); } + let statement = self + .statement + .last_child_statement() + .unwrap_or(&self.statement); + if let (Some(constructor), Statement::Expression(Expression::FunctionCall(function_call))) = - (&self.constructor, self.statement.as_ref()) + (&self.constructor, statement) { let declared_type = constructor.clone().construct(context)?; let function_type = function_call.node.function().expected_type(context)?; @@ -79,18 +84,6 @@ impl Evaluate for Assignment { position: function_call.position, }); } - } else if let Some(constructor) = &self.constructor { - let r#type = constructor.clone().construct(&context)?; - - r#type - .check(&statement_type) - .map_err(|conflict| ValidationError::TypeCheck { - conflict, - actual_position: self.statement.position(), - expected_position: Some(constructor.position()), - })?; - - context.set_type(self.identifier.node.clone(), r#type.clone())?; } else { context.set_type(self.identifier.node.clone(), statement_type)?; } diff --git a/dust-lang/src/abstract_tree/loop.rs b/dust-lang/src/abstract_tree/loop.rs index 1a51b4e..72d04dc 100644 --- a/dust-lang/src/abstract_tree/loop.rs +++ b/dust-lang/src/abstract_tree/loop.rs @@ -16,6 +16,10 @@ impl Loop { pub fn new(statements: Vec) -> Self { Self { statements } } + + pub fn last_statement(&self) -> &Statement { + self.statements.last().unwrap() + } } impl Evaluate for Loop { diff --git a/dust-lang/src/abstract_tree/statement.rs b/dust-lang/src/abstract_tree/statement.rs index 4c8e34d..7ed2f24 100644 --- a/dust-lang/src/abstract_tree/statement.rs +++ b/dust-lang/src/abstract_tree/statement.rs @@ -39,6 +39,14 @@ impl Statement { Statement::While(inner) => inner.position, } } + + pub fn last_child_statement(&self) -> Option<&Self> { + match self { + Statement::Block(block) => Some(block.node.last_statement()), + Statement::Loop(r#loop) => Some(r#loop.node.last_statement()), + _ => None, + } + } } impl Evaluate for Statement { diff --git a/dust-lang/src/abstract_tree/value_node.rs b/dust-lang/src/abstract_tree/value_node.rs index b85febf..55d1788 100644 --- a/dust-lang/src/abstract_tree/value_node.rs +++ b/dust-lang/src/abstract_tree/value_node.rs @@ -41,6 +41,15 @@ pub enum ValueNode { impl Evaluate for ValueNode { fn validate(&self, context: &mut Context, _manage_memory: bool) -> Result<(), ValidationError> { + if let ValueNode::EnumInstance { type_name, .. } = self { + if let Some(_) = context.get_type(&type_name.node)? { + } else { + return Err(ValidationError::EnumDefinitionNotFound( + type_name.node.clone(), + )); + } + } + if let ValueNode::Map(map_assignments) = self { for (_identifier, constructor_option, expression) in map_assignments { expression.validate(context, _manage_memory)?; @@ -147,7 +156,6 @@ impl Evaluate for ValueNode { ) -> Result { let value = match self { ValueNode::Boolean(boolean) => Value::boolean(boolean), - ValueNode::EnumInstance { type_name, variant, diff --git a/dust-lang/src/lib.rs b/dust-lang/src/lib.rs index fec5c14..7cff0ba 100644 --- a/dust-lang/src/lib.rs +++ b/dust-lang/src/lib.rs @@ -13,7 +13,7 @@ use std::{ }; use abstract_tree::{AbstractTree, Type}; -use ariadne::{Color, Fmt, Label, Report, ReportKind}; +use ariadne::{Color, Config, Fmt, Label, Report, ReportKind}; use chumsky::prelude::*; use context::Context; use error::{DustError, RuntimeError, TypeConflict, ValidationError}; @@ -429,6 +429,8 @@ impl InterpreterError { } } + builder = builder.with_config(Config::default().with_multiline_arrows(false)); + let report = builder.finish(); reports.push(report); diff --git a/dust-lang/tests/enums.rs b/dust-lang/tests/enums.rs index a4c7caa..b453a3d 100644 --- a/dust-lang/tests/enums.rs +++ b/dust-lang/tests/enums.rs @@ -21,3 +21,26 @@ fn simple_enum() { ))) ); } + +#[test] +fn big_enum() { + assert_eq!( + interpret( + "test", + " + type FooBarBaz = enum |T, U, V| { + Foo(T), + Bar(U), + Baz(V), + } + + FooBarBaz::Baz(42.0) + " + ), + Ok(Some(Value::enum_instance( + Identifier::new("FooBarBaz"), + Identifier::new("Baz"), + Some(vec![Value::float(42.0)]), + ))) + ); +} diff --git a/examples/type_inference.ds b/examples/type_inference.ds index fdc907a..0ebc58e 100644 --- a/examples/type_inference.ds +++ b/examples/type_inference.ds @@ -14,3 +14,7 @@ x = json.parse::(int)::("1") // Use type annotation x: int = json.parse("1") + +x: int = { + json.parse("1") +}