diff --git a/Cargo.toml b/Cargo.toml index 5906cc8..98f56e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ opt-level = 3 [dependencies] ariadne = "0.4.0" -chumsky = "1.0.0-alpha.6" +chumsky = { version = "1.0.0-alpha.6", features = ["pratt"] } diff --git a/src/lib.rs b/src/lib.rs index fa90e03..1bbc23d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ use std::{ ops::Range, }; -use chumsky::{prelude::*, Parser}; +use chumsky::{pratt::*, prelude::*, Parser}; #[derive(Clone, Debug, PartialEq)] pub enum Statement { @@ -32,22 +32,15 @@ pub enum Expression { } #[derive(Clone, Debug, PartialEq)] -pub struct Logic { - left: Expression, - operator: LogicOperator, - right: Expression, -} - -#[derive(Clone, Debug, PartialEq)] -pub enum LogicOperator { - Equal, - NotEqual, - Greater, - Less, - GreaterOrEqual, - LessOrEqual, - And, - Or, +pub enum Logic { + Equal(Expression, Expression), + NotEqual(Expression, Expression), + Greater(Expression, Expression), + Less(Expression, Expression), + GreaterOrEqual(Expression, Expression), + LessOrEqual(Expression, Expression), + And(Expression, Expression), + Or(Expression, Expression), } #[derive(Clone, Debug, PartialEq)] @@ -84,7 +77,9 @@ impl Display for Value { } } -pub fn parser<'src>() -> impl Parser<'src, &'src str, Statement> { +pub fn parser<'src>() -> impl Parser<'src, &'src str, Expression> { + let operator = |text| just(text).padded(); + let value = recursive(|value| { let boolean = just("true") .or(just("false")) @@ -102,24 +97,22 @@ pub fn parser<'src>() -> impl Parser<'src, &'src str, Statement> { let float = choice((float_numeric, float_other)); - let integer = just('-').or_not().then(text::int(10).padded()).map( - |(negative, integer_text): (Option, &str)| { - let integer = integer_text.parse::().unwrap(); + let integer = just('-') + .or_not() + .then(text::int(10).padded()) + .to_slice() + .map(|text: &str| { + let integer = text.parse::().unwrap(); - if negative.is_some() { - Value::Integer(-integer) - } else { - Value::Integer(integer) - } - }, - ); + Value::Integer(integer) + }); let delimited_string = |delimiter| { just(delimiter) .ignore_then(none_of(delimiter).repeated()) .then_ignore(just(delimiter)) .to_slice() - .map(|text: &str| Value::String(text.to_string())) + .map(|text: &str| Value::String(text[1..text.len() - 1].to_string())) }; let string = choice(( @@ -140,52 +133,36 @@ pub fn parser<'src>() -> impl Parser<'src, &'src str, Statement> { choice((boolean, float, integer, string, list)) }); - let expression = recursive(|expression| { - let logic = expression - .clone() - .then(choice(( - just("==").to(LogicOperator::Equal), - just("!=").to(LogicOperator::NotEqual), - just(">").to(LogicOperator::Greater), - just("<").to(LogicOperator::Less), - just(">=").to(LogicOperator::GreaterOrEqual), - just("<=").to(LogicOperator::LessOrEqual), - just("&&").to(LogicOperator::And), - just("||").to(LogicOperator::Or), - ))) - .padded() - .then(expression) - .map(|((left, operator), right)| { - Expression::Logic(Box::new(Logic { - left, - operator, - right, - })) - }); + let value_expression = value.map(|value| Expression::Value(value)); - let value = value.map(|value| Expression::Value(value)); + let logic_expression = value_expression.pratt(( + infix(left(1), operator("=="), |left, right| { + Expression::Logic(Box::new(Logic::Equal(left, right))) + }), + infix(left(1), operator("!="), |left, right| { + Expression::Logic(Box::new(Logic::NotEqual(left, right))) + }), + infix(left(1), operator(">"), |left, right| { + Expression::Logic(Box::new(Logic::Greater(left, right))) + }), + infix(left(1), operator("<"), |left, right| { + Expression::Logic(Box::new(Logic::Less(left, right))) + }), + infix(left(1), operator(">="), |left, right| { + Expression::Logic(Box::new(Logic::GreaterOrEqual(left, right))) + }), + infix(left(1), operator("<="), |left, right| { + Expression::Logic(Box::new(Logic::LessOrEqual(left, right))) + }), + infix(left(1), operator("&&"), |left, right| { + Expression::Logic(Box::new(Logic::And(left, right))) + }), + infix(left(1), operator("||"), |left, right| { + Expression::Logic(Box::new(Logic::Or(left, right))) + }), + )); - choice((logic, value)) - }); - - let statement = recursive(|statement| { - let assignment = text::ident() - .map(|text| Identifier::new(text)) - .then(just("=").padded()) - .then(statement) - .map(|((identifier, _), statement)| { - Statement::Assignment(Box::new(Assignment { - identifier, - statement, - })) - }); - - let expression = expression.map(|expression| Statement::Expression(expression)); - - choice((assignment, expression)) - }); - - statement.then_ignore(end()) + logic_expression.then_ignore(end()) } #[cfg(test)] @@ -196,15 +173,15 @@ mod tests { fn parse_list() { assert_eq!( parser().parse("[]").unwrap(), - Statement::value(Value::List(vec![])) + Expression::Value(Value::List(vec![])) ); assert_eq!( parser().parse("[42]").unwrap(), - Statement::value(Value::List(vec![Value::Integer(42)])) + Expression::Value(Value::List(vec![Value::Integer(42)])) ); assert_eq!( parser().parse("[42, 'foo', \"bar\", [1, 2, 3,]]").unwrap(), - Statement::value(Value::List(vec![ + Expression::Value(Value::List(vec![ Value::Integer(42), Value::String("foo".to_string()), Value::String("bar".to_string()), @@ -221,7 +198,7 @@ mod tests { fn parse_true() { assert_eq!( parser().parse("true").unwrap(), - Statement::value(Value::Boolean(true)) + Expression::Value(Value::Boolean(true)) ); } @@ -229,7 +206,7 @@ mod tests { fn parse_false() { assert_eq!( parser().parse("false").unwrap(), - Statement::value(Value::Boolean(false)) + Expression::Value(Value::Boolean(false)) ); } @@ -237,25 +214,25 @@ mod tests { fn parse_positive_float() { assert_eq!( parser().parse("0.0").unwrap(), - Statement::value(Value::Float(0.0)) + Expression::Value(Value::Float(0.0)) ); assert_eq!( parser().parse("42.0").unwrap(), - Statement::value(Value::Float(42.0)) + Expression::Value(Value::Float(42.0)) ); let max_float = f64::MAX.to_string() + ".0"; assert_eq!( parser().parse(&max_float).unwrap(), - Statement::value(Value::Float(f64::MAX)) + Expression::Value(Value::Float(f64::MAX)) ); let min_positive_float = f64::MIN_POSITIVE.to_string(); assert_eq!( parser().parse(&min_positive_float).unwrap(), - Statement::value(Value::Float(f64::MIN_POSITIVE)) + Expression::Value(Value::Float(f64::MIN_POSITIVE)) ); } @@ -263,25 +240,25 @@ mod tests { fn parse_negative_float() { assert_eq!( parser().parse("-0.0").unwrap(), - Statement::value(Value::Float(-0.0)) + Expression::Value(Value::Float(-0.0)) ); assert_eq!( parser().parse("-42.0").unwrap(), - Statement::value(Value::Float(-42.0)) + Expression::Value(Value::Float(-42.0)) ); let min_float = f64::MIN.to_string() + ".0"; assert_eq!( parser().parse(&min_float).unwrap(), - Statement::value(Value::Float(f64::MIN)) + Expression::Value(Value::Float(f64::MIN)) ); - let max_negative_float = f64::MIN_POSITIVE.to_string(); + let max_negative_float = format!("-{}", f64::MIN_POSITIVE); assert_eq!( parser().parse(&max_negative_float).unwrap(), - Statement::value(Value::Float(-f64::MIN_POSITIVE)) + Expression::Value(Value::Float(-f64::MIN_POSITIVE)) ); } @@ -289,16 +266,14 @@ mod tests { fn parse_other_float() { assert_eq!( parser().parse("Infinity").unwrap(), - Statement::value(Value::Float(f64::INFINITY)) + Expression::Value(Value::Float(f64::INFINITY)) ); assert_eq!( parser().parse("-Infinity").unwrap(), - Statement::value(Value::Float(f64::NEG_INFINITY)) + Expression::Value(Value::Float(f64::NEG_INFINITY)) ); - if let Statement::Expression(Expression::Value(Value::Float(float))) = - parser().parse("NaN").unwrap() - { + if let Expression::Value(Value::Float(float)) = parser().parse("NaN").unwrap() { assert!(float.is_nan()) } else { panic!("Expected a float.") @@ -309,54 +284,54 @@ mod tests { fn parse_positive_integer() { assert_eq!( parser().parse("0").unwrap(), - Statement::value(Value::Integer(0)) + Expression::Value(Value::Integer(0)) ); assert_eq!( parser().parse("1").unwrap(), - Statement::value(Value::Integer(1)) + Expression::Value(Value::Integer(1)) ); assert_eq!( parser().parse("2").unwrap(), - Statement::value(Value::Integer(2)) + Expression::Value(Value::Integer(2)) ); assert_eq!( parser().parse("3").unwrap(), - Statement::value(Value::Integer(3)) + Expression::Value(Value::Integer(3)) ); assert_eq!( parser().parse("4").unwrap(), - Statement::value(Value::Integer(4)) + Expression::Value(Value::Integer(4)) ); assert_eq!( parser().parse("5").unwrap(), - Statement::value(Value::Integer(5)) + Expression::Value(Value::Integer(5)) ); assert_eq!( parser().parse("6").unwrap(), - Statement::value(Value::Integer(6)) + Expression::Value(Value::Integer(6)) ); assert_eq!( parser().parse("7").unwrap(), - Statement::value(Value::Integer(7)) + Expression::Value(Value::Integer(7)) ); assert_eq!( parser().parse("8").unwrap(), - Statement::value(Value::Integer(8)) + Expression::Value(Value::Integer(8)) ); assert_eq!( parser().parse("9").unwrap(), - Statement::value(Value::Integer(9)) + Expression::Value(Value::Integer(9)) ); assert_eq!( parser().parse("42").unwrap(), - Statement::value(Value::Integer(42)) + Expression::Value(Value::Integer(42)) ); let maximum_integer = i64::MAX.to_string(); assert_eq!( parser().parse(&maximum_integer).unwrap(), - Statement::value(Value::Integer(i64::MAX)) + Expression::Value(Value::Integer(i64::MAX)) ); } @@ -364,54 +339,54 @@ mod tests { fn parse_negative_integer() { assert_eq!( parser().parse("-0").unwrap(), - Statement::value(Value::Integer(-0)) + Expression::Value(Value::Integer(-0)) ); assert_eq!( parser().parse("-1").unwrap(), - Statement::value(Value::Integer(-1)) + Expression::Value(Value::Integer(-1)) ); assert_eq!( parser().parse("-2").unwrap(), - Statement::value(Value::Integer(-2)) + Expression::Value(Value::Integer(-2)) ); assert_eq!( parser().parse("-3").unwrap(), - Statement::value(Value::Integer(-3)) + Expression::Value(Value::Integer(-3)) ); assert_eq!( parser().parse("-4").unwrap(), - Statement::value(Value::Integer(-4)) + Expression::Value(Value::Integer(-4)) ); assert_eq!( parser().parse("-5").unwrap(), - Statement::value(Value::Integer(-5)) + Expression::Value(Value::Integer(-5)) ); assert_eq!( parser().parse("-6").unwrap(), - Statement::value(Value::Integer(-6)) + Expression::Value(Value::Integer(-6)) ); assert_eq!( parser().parse("-7").unwrap(), - Statement::value(Value::Integer(-7)) + Expression::Value(Value::Integer(-7)) ); assert_eq!( parser().parse("-8").unwrap(), - Statement::value(Value::Integer(-8)) + Expression::Value(Value::Integer(-8)) ); assert_eq!( parser().parse("-9").unwrap(), - Statement::value(Value::Integer(-9)) + Expression::Value(Value::Integer(-9)) ); assert_eq!( parser().parse("-42").unwrap(), - Statement::value(Value::Integer(-42)) + Expression::Value(Value::Integer(-42)) ); let minimum_integer = i64::MIN.to_string(); assert_eq!( parser().parse(&minimum_integer).unwrap(), - Statement::value(Value::Integer(i64::MIN)) + Expression::Value(Value::Integer(i64::MIN)) ); } @@ -419,19 +394,19 @@ mod tests { fn double_quoted_string() { assert_eq!( parser().parse("\"\"").unwrap(), - Statement::value(Value::String("".to_string())) + Expression::Value(Value::String("".to_string())) ); assert_eq!( parser().parse("\"1\"").unwrap(), - Statement::value(Value::String("1".to_string())) + Expression::Value(Value::String("1".to_string())) ); assert_eq!( parser().parse("\"42\"").unwrap(), - Statement::value(Value::String("42".to_string())) + Expression::Value(Value::String("42".to_string())) ); assert_eq!( parser().parse("\"foobar\"").unwrap(), - Statement::value(Value::String("foobar".to_string())) + Expression::Value(Value::String("foobar".to_string())) ); } @@ -439,19 +414,19 @@ mod tests { fn single_quoted_string() { assert_eq!( parser().parse("''").unwrap(), - Statement::value(Value::String("".to_string())) + Expression::Value(Value::String("".to_string())) ); assert_eq!( parser().parse("'1'").unwrap(), - Statement::value(Value::String("1".to_string())) + Expression::Value(Value::String("1".to_string())) ); assert_eq!( parser().parse("'42'").unwrap(), - Statement::value(Value::String("42".to_string())) + Expression::Value(Value::String("42".to_string())) ); assert_eq!( parser().parse("'foobar'").unwrap(), - Statement::value(Value::String("foobar".to_string())) + Expression::Value(Value::String("foobar".to_string())) ); } @@ -459,19 +434,19 @@ mod tests { fn grave_quoted_string() { assert_eq!( parser().parse("``").unwrap(), - Statement::value(Value::String("".to_string())) + Expression::Value(Value::String("".to_string())) ); assert_eq!( parser().parse("`1`").unwrap(), - Statement::value(Value::String("1".to_string())) + Expression::Value(Value::String("1".to_string())) ); assert_eq!( parser().parse("`42`").unwrap(), - Statement::value(Value::String("42".to_string())) + Expression::Value(Value::String("42".to_string())) ); assert_eq!( parser().parse("`foobar`").unwrap(), - Statement::value(Value::String("foobar".to_string())) + Expression::Value(Value::String("foobar".to_string())) ); } }