From ee30ae55a8b49723c64d6d00ef5948b66f14903a Mon Sep 17 00:00:00 2001 From: Jeff Date: Wed, 3 Jan 2024 14:11:19 -0500 Subject: [PATCH] Move reference counted rw lock for list values --- src/abstract_tree/for.rs | 25 ++++++++++++++++--------- src/abstract_tree/identifier.rs | 8 ++++---- src/bin/gui/app.rs | 12 ++++++------ src/interpret.rs | 30 ++++++++++++++++-------------- src/main.rs | 6 +++--- src/value/list.rs | 27 ++++++++++----------------- tests/interpret.rs | 13 +++++++++++++ 7 files changed, 68 insertions(+), 53 deletions(-) diff --git a/src/abstract_tree/for.rs b/src/abstract_tree/for.rs index 9300524..5d2f5cd 100644 --- a/src/abstract_tree/for.rs +++ b/src/abstract_tree/for.rs @@ -1,3 +1,5 @@ +use std::sync::RwLock; + use rayon::prelude::*; use serde::{Deserialize, Serialize}; use tree_sitter::Node; @@ -54,21 +56,26 @@ impl AbstractTree for For { let key = self.item_id.inner(); if self.is_async { + let context = RwLock::new(context); + values.par_iter().try_for_each(|value| { - let mut iter_context = Map::clone_from(context)?; + let mut context = match context.write() { + Ok(map) => map, + Err(error) => return Err(error.into()), + }; - iter_context.set(key.clone(), value.clone(), None); - - self.block.run(source, &mut iter_context).map(|_value| ()) + context.set(key.clone(), value.clone(), None); + self.block.run(source, &mut context).map(|_value| ()) })?; + + context.write()?.set(key.clone(), Value::none(), None); } else { - let mut loop_context = Map::clone_from(context)?; - for value in values.iter() { - loop_context.set(key.clone(), value.clone(), None); - - self.block.run(source, &mut loop_context)?; + context.set(key.clone(), value.clone(), None); + self.block.run(source, context)?; } + + context.set(key.clone(), Value::none(), None); } Ok(Value::none()) diff --git a/src/abstract_tree/identifier.rs b/src/abstract_tree/identifier.rs index 934e214..b1c0036 100644 --- a/src/abstract_tree/identifier.rs +++ b/src/abstract_tree/identifier.rs @@ -35,6 +35,10 @@ impl AbstractTree for Identifier { Ok(Identifier(text.to_string())) } + fn check_type(&self, _source: &str, _context: &Map) -> Result<()> { + Ok(()) + } + fn run(&self, _source: &str, context: &mut Map) -> Result { if let Some((value, _)) = context.variables().get(&self.0) { Ok(value.clone()) @@ -43,10 +47,6 @@ impl AbstractTree for Identifier { } } - fn check_type(&self, _source: &str, _context: &Map) -> Result<()> { - Ok(()) - } - fn expected_type(&self, context: &Map) -> Result { if let Some((_value, r#type)) = context.variables().get(&self.0) { Ok(r#type.clone()) diff --git a/src/bin/gui/app.rs b/src/bin/gui/app.rs index 34fff55..af4e250 100644 --- a/src/bin/gui/app.rs +++ b/src/bin/gui/app.rs @@ -6,21 +6,21 @@ use egui_extras::{Column, TableBuilder}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Serialize)] -pub struct App { +pub struct App<'c> { path: String, source: String, context: Map, #[serde(skip)] - interpreter: Interpreter, + interpreter: Interpreter<'c>, output: Result, error: Option, } -impl App { +impl<'c> App<'c> { pub fn new(cc: &eframe::CreationContext<'_>, path: PathBuf) -> Self { - fn create_app(path: PathBuf) -> App { + fn create_app<'c>(path: PathBuf) -> App<'c> { let context = Map::new(); - let mut interpreter = Interpreter::new(context.clone()); + let mut interpreter = Interpreter::new(&mut context); let read_source = read_to_string(&path); let source = if let Ok(source) = read_source { source @@ -54,7 +54,7 @@ impl App { } } -impl eframe::App for App { +impl<'c> eframe::App for App<'c> { /// Called by the frame work to save state before shutdown. fn save(&mut self, storage: &mut dyn eframe::Storage) { eframe::set_value(storage, eframe::APP_KEY, self); diff --git a/src/interpret.rs b/src/interpret.rs index d9baeb6..e4fd79b 100644 --- a/src/interpret.rs +++ b/src/interpret.rs @@ -18,7 +18,7 @@ use crate::{language, AbstractTree, Error, Map, Result, Root, Value}; /// assert_eq!(interpret("1 + 2 + 3"), Ok(Value::Integer(6))); /// ``` pub fn interpret(source: &str) -> Result { - interpret_with_context(source, Map::new()) + interpret_with_context(source, &mut Map::new()) } /// Interpret the given source code with the given context. @@ -40,7 +40,7 @@ pub fn interpret(source: &str) -> Result { /// Ok(Value::Integer(10)) /// ); /// ``` -pub fn interpret_with_context(source: &str, context: Map) -> Result { +pub fn interpret_with_context(source: &str, context: &mut Map) -> Result { let mut interpreter = Interpreter::new(context); let value = interpreter.run(source)?; @@ -48,15 +48,15 @@ pub fn interpret_with_context(source: &str, context: Map) -> Result { } /// A source code interpreter for the Dust language. -pub struct Interpreter { +pub struct Interpreter<'c> { parser: Parser, - context: Map, + context: &'c mut Map, syntax_tree: Option, abstract_tree: Option, } -impl Interpreter { - pub fn new(context: Map) -> Self { +impl<'c> Interpreter<'c> { + pub fn new(context: &'c mut Map) -> Self { let mut parser = Parser::new(); parser @@ -71,6 +71,14 @@ impl Interpreter { } } + pub fn context(&self) -> &Map { + &self.context + } + + pub fn context_mut(&mut self) -> &mut Map { + &mut self.context + } + pub fn parse(&mut self, source: &str) -> Result<()> { fn check_for_error(source: &str, node: Node, cursor: &mut TreeCursor) -> Result<()> { if node.is_error() { @@ -101,14 +109,14 @@ impl Interpreter { Ok(()) } - pub fn run(&mut self, source: &str) -> Result { + pub fn run(&'c mut self, source: &str) -> Result { self.parse(source)?; self.abstract_tree = if let Some(syntax_tree) = &self.syntax_tree { Some(Root::from_syntax_node( source, syntax_tree.root_node(), - &mut self.context, + self.context, )?) } else { return Err(Error::ParserCancelled); @@ -130,9 +138,3 @@ impl Interpreter { } } } - -impl Default for Interpreter { - fn default() -> Self { - Interpreter::new(Map::new()) - } -} diff --git a/src/main.rs b/src/main.rs index 2c26db8..20a2a72 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,7 +72,7 @@ fn main() { let mut parser = TSParser::new(); parser.set_language(language()).unwrap(); - let mut interpreter = Interpreter::new(context); + let mut interpreter = Interpreter::new(&mut context); if args.show_syntax_tree { interpreter.parse(&source).unwrap(); @@ -168,7 +168,7 @@ impl Highlighter for DustReadline { } fn run_cli_shell() { - let context = Map::new(); + let mut context = Map::new(); let mut rl: Editor = Editor::new().unwrap(); rl.set_helper(Some(DustReadline::new())); @@ -185,7 +185,7 @@ fn run_cli_shell() { rl.add_history_entry(line).unwrap(); - let eval_result = interpret_with_context(line, context.clone()); + let eval_result = interpret_with_context(line, &mut context); match eval_result { Ok(value) => println!("{value}"), diff --git a/src/value/list.rs b/src/value/list.rs index 975bc0b..1bf5d2f 100644 --- a/src/value/list.rs +++ b/src/value/list.rs @@ -1,13 +1,12 @@ use std::{ cmp::Ordering, fmt::{self, Display, Formatter}, - sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, }; use crate::Value; #[derive(Debug, Clone)] -pub struct List(Arc>>); +pub struct List(Vec); impl Default for List { fn default() -> Self { @@ -17,23 +16,23 @@ impl Default for List { impl List { pub fn new() -> Self { - List(Arc::new(RwLock::new(Vec::new()))) + List(Vec::new()) } pub fn with_capacity(capacity: usize) -> Self { - List(Arc::new(RwLock::new(Vec::with_capacity(capacity)))) + List(Vec::with_capacity(capacity)) } pub fn with_items(items: Vec) -> Self { - List(Arc::new(RwLock::new(items))) + List(items) } - pub fn items(&self) -> RwLockReadGuard<'_, Vec> { - self.0.read().unwrap() + pub fn items(&self) -> &Vec { + &self.0 } - pub fn items_mut(&self) -> RwLockWriteGuard<'_, Vec> { - self.0.write().unwrap() + pub fn items_mut(&mut self) -> &mut Vec { + &mut self.0 } } @@ -41,19 +40,13 @@ impl Eq for List {} impl PartialEq for List { fn eq(&self, other: &Self) -> bool { - let left = self.0.read().unwrap().clone().into_iter(); - let right = other.0.read().unwrap().clone().into_iter(); - - left.eq(right) + self.0.eq(&other.0) } } impl Ord for List { fn cmp(&self, other: &Self) -> Ordering { - let left = self.0.read().unwrap().clone().into_iter(); - let right = other.0.read().unwrap().clone().into_iter(); - - left.cmp(right) + self.0.cmp(&other.0) } } diff --git a/tests/interpret.rs b/tests/interpret.rs index a9cfa3b..adc2288 100644 --- a/tests/interpret.rs +++ b/tests/interpret.rs @@ -76,6 +76,19 @@ mod for_loop { result ); } + + #[test] + fn async_modify_value() { + let result = interpret( + " + list = [] + async for i in [1 2 3] { list += i } + length(list) + ", + ); + + assert_eq!(Ok(Value::Integer(3)), result); + } } mod logic {