diff --git a/src/function/mod.rs b/src/function/mod.rs index 21f27e5..75d5a4c 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -4,6 +4,26 @@ use crate::{error::EvalexprResult, value::Value}; pub(crate) mod builtin; +/// A helper trait to enable cloning through `Fn` trait objects. +trait ClonableFn +where + Self: Fn(&Value) -> EvalexprResult, + Self: Send + Sync + 'static, +{ + fn dyn_clone(&self) -> Box; +} + +impl ClonableFn for F +where + F: Fn(&Value) -> EvalexprResult, + F: Send + Sync + 'static, + F: Clone, +{ + fn dyn_clone(&self) -> Box { + Box::new(self.clone()) as _ + } +} + /// A user-defined function. /// Functions can be used in expressions by storing them in a `Context`. /// @@ -18,17 +38,31 @@ pub(crate) mod builtin; /// })).unwrap(); // Do proper error handling here /// assert_eq!(eval_with_context("id(4)", &context), Ok(Value::from(4))); /// ``` -#[derive(Clone)] pub struct Function { - function: fn(&Value) -> EvalexprResult, + function: Box, +} + +impl Clone for Function { + fn clone(&self) -> Self { + Self { + function: self.function.dyn_clone(), + } + } } impl Function { /// Creates a user-defined function. /// - /// The `function` is a boxed function that takes a `Value` and returns a `EvalexprResult`. - pub fn new(function: fn(&Value) -> EvalexprResult) -> Self { - Self { function } + /// The `function` is boxed for storage. + pub fn new(function: F) -> Self + where + F: Fn(&Value) -> EvalexprResult, + F: Send + Sync + 'static, + F: Clone, + { + Self { + function: Box::new(function) as _, + } } pub(crate) fn call(&self, argument: &Value) -> EvalexprResult { diff --git a/tests/integration.rs b/tests/integration.rs index ab3af88..aa40aff 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -276,6 +276,36 @@ fn test_n_ary_functions() { ); } +#[test] +fn test_capturing_functions() { + let mut context = HashMapContext::new(); + // this variable is captured by the function + let three = 3; + context + .set_function( + "mult_3".into(), + Function::new(move |argument| { + if let Value::Int(int) = argument { + Ok(Value::Int(int * three)) + } else if let Value::Float(float) = argument { + Ok(Value::Float(float * three as f64)) + } else { + Err(EvalexprError::expected_number(argument.clone())) + } + }), + ) + .unwrap(); + + let four = 4; + context + .set_function("function_four".into(), Function::new(move |_| Ok(Value::Int(four)))) + .unwrap(); + + assert_eq!(eval_with_context("mult_3 2", &context), Ok(Value::Int(6))); + assert_eq!(eval_with_context("mult_3(3)", &context), Ok(Value::Int(9))); + assert_eq!(eval_with_context("mult_3(function_four())", &context), Ok(Value::Int(12))); +} + #[test] fn test_builtin_functions() { // Log