diff --git a/src/abstract_tree/assignment.rs b/src/abstract_tree/assignment.rs index fee6b18..b46bf56 100644 --- a/src/abstract_tree/assignment.rs +++ b/src/abstract_tree/assignment.rs @@ -29,10 +29,14 @@ impl<'src> AbstractTree for Assignment<'src> { } fn validate(&self, context: &Context) -> Result<(), ValidationError> { - if let Some(expected) = &self.r#type { - let statement_type = self.statement.expected_type(context)?; + let statement_type = self.statement.expected_type(context)?; + if let Some(expected) = &self.r#type { expected.check(&statement_type)?; + + context.set_type(self.identifier.clone(), expected.clone())?; + } else { + context.set_type(self.identifier.clone(), statement_type)?; } Ok(()) @@ -41,7 +45,7 @@ impl<'src> AbstractTree for Assignment<'src> { fn run(self, context: &Context) -> Result { let value = self.statement.run(context)?; - context.set(self.identifier, value)?; + context.set_value(self.identifier, value)?; Ok(Value::none()) } @@ -69,7 +73,7 @@ mod tests { .unwrap(); assert_eq!( - context.get(&Identifier::new("foobar")), + context.get_value(&Identifier::new("foobar")), Ok(Some(Value::integer(42))) ) } diff --git a/src/abstract_tree/identifier.rs b/src/abstract_tree/identifier.rs index 9a9b653..466ab4f 100644 --- a/src/abstract_tree/identifier.rs +++ b/src/abstract_tree/identifier.rs @@ -26,15 +26,15 @@ impl Identifier { impl AbstractTree for Identifier { fn expected_type(&self, context: &Context) -> Result { - if let Some(value) = context.get(self)? { - Ok(value.r#type()) + if let Some(r#type) = context.get_type(self)? { + Ok(r#type) } else { Err(ValidationError::VariableNotFound(self.clone())) } } fn validate(&self, context: &Context) -> Result<(), ValidationError> { - if let Some(_) = context.get(self)? { + if let Some(_) = context.get_data(self)? { Ok(()) } else { Err(ValidationError::VariableNotFound(self.clone())) @@ -42,7 +42,10 @@ impl AbstractTree for Identifier { } fn run(self, context: &Context) -> Result { - let value = context.get(&self)?.unwrap_or_else(Value::none).clone(); + let value = context + .get_value(&self)? + .unwrap_or_else(Value::none) + .clone(); Ok(value) } diff --git a/src/context.rs b/src/context.rs index d3f7349..e658073 100644 --- a/src/context.rs +++ b/src/context.rs @@ -3,10 +3,20 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{abstract_tree::Identifier, error::RwLockPoisonError, Value}; +use crate::{ + abstract_tree::{Identifier, Type}, + error::RwLockPoisonError, + Value, +}; pub struct Context { - inner: Arc>>, + inner: Arc>>, +} + +#[derive(Clone, Debug)] +pub enum ValueData { + Type(Type), + Value(Value), } impl Context { @@ -16,20 +26,47 @@ impl Context { } } - pub fn with_values(values: BTreeMap) -> Self { + pub fn with_data(data: BTreeMap) -> Self { Self { - inner: Arc::new(RwLock::new(values)), + inner: Arc::new(RwLock::new(data)), } } - pub fn get(&self, identifier: &Identifier) -> Result, RwLockPoisonError> { - let value = self.inner.read()?.get(identifier).cloned(); - - Ok(value) + pub fn get_data( + &self, + identifier: &Identifier, + ) -> Result, RwLockPoisonError> { + Ok(self.inner.read()?.get(identifier).cloned()) } - pub fn set(&self, identifier: Identifier, value: Value) -> Result<(), RwLockPoisonError> { - self.inner.write()?.insert(identifier, value); + pub fn get_type(&self, identifier: &Identifier) -> Result, RwLockPoisonError> { + if let Some(ValueData::Type(r#type)) = self.inner.read()?.get(identifier) { + Ok(Some(r#type.clone())) + } else { + Ok(None) + } + } + + pub fn get_value(&self, identifier: &Identifier) -> Result, RwLockPoisonError> { + if let Some(ValueData::Value(value)) = self.inner.read()?.get(identifier) { + Ok(Some(value.clone())) + } else { + Ok(None) + } + } + + pub fn set_type(&self, identifier: Identifier, r#type: Type) -> Result<(), RwLockPoisonError> { + self.inner + .write()? + .insert(identifier, ValueData::Type(r#type)); + + Ok(()) + } + + pub fn set_value(&self, identifier: Identifier, value: Value) -> Result<(), RwLockPoisonError> { + self.inner + .write()? + .insert(identifier, ValueData::Value(value)); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 3ada3c1..178d41e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,13 @@ use lexer::lex; pub use parser::{parse, parser, DustParser}; pub use value::Value; +pub fn interpret(source: &str) -> Result> { + let context = Context::new(); + let mut interpreter = Interpreter::new(context); + + interpreter.run(source) +} + pub struct Interpreter { context: Context, } diff --git a/tests/variables.rs b/tests/variables.rs new file mode 100644 index 0000000..3e0e1cc --- /dev/null +++ b/tests/variables.rs @@ -0,0 +1,9 @@ +use dust_lang::*; + +#[test] +fn set_and_get_variable() { + assert_eq!( + interpret("foobar = true foobar"), + Ok(Value::boolean(true)) + ); +}