diff --git a/src/lib.rs b/src/lib.rs index 1bbc23d..3a07342 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,46 +1,41 @@ -use std::{ - collections::BTreeMap, - fmt::{self, Display, Formatter}, - ops::Range, -}; +use std::{collections::BTreeMap, ops::Range}; -use chumsky::{pratt::*, prelude::*, Parser}; +use chumsky::{prelude::*, Parser}; #[derive(Clone, Debug, PartialEq)] pub enum Statement { - Assignment(Box), - Expression(Expression), - Sequence(Vec), -} - -impl Statement { - pub fn value(value: Value) -> Statement { - Statement::Expression(Expression::Value(value)) - } -} - -#[derive(Clone, Debug, PartialEq)] -pub struct Assignment { - identifier: Identifier, - statement: Statement, -} - -#[derive(Clone, Debug, PartialEq)] -pub enum Expression { - Logic(Box), + Assignment(Assignment), + Identifier(Identifier), + Logic(Logic), Value(Value), } #[derive(Clone, Debug, PartialEq)] -pub enum Logic { - Equal(Expression, Expression), - NotEqual(Expression, Expression), - Greater(Expression, Expression), - Less(Expression, Expression), - GreaterOrEqual(Expression, Expression), - LessOrEqual(Expression, Expression), - And(Expression, Expression), - Or(Expression, Expression), +pub struct Identifier(String); + +#[derive(Clone, Debug, PartialEq)] +pub struct Assignment { + identifier: Identifier, + value: Value, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Logic { + left: LogicExpression, + operator: LogicOperator, + right: LogicExpression, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum LogicOperator { + Equal, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum LogicExpression { + Identifier(Identifier), + Logic(Box), + Value(Value), } #[derive(Clone, Debug, PartialEq)] @@ -49,36 +44,13 @@ pub enum Value { Float(f64), Integer(i64), List(Vec), - Map(BTreeMap), + Map(BTreeMap), Range(Range), String(String), } -#[derive(Clone, Debug, PartialEq)] -pub struct Identifier(String); - -impl Identifier { - pub fn new(text: impl ToString) -> Self { - Identifier(text.to_string()) - } -} - -impl Display for Value { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Value::Boolean(boolean) => write!(f, "{boolean}"), - Value::Float(float) => write!(f, "{float}"), - Value::Integer(integer) => write!(f, "{integer}"), - Value::List(_list) => todo!(), - Value::Map(_map) => todo!(), - Value::Range(range) => write!(f, "{}..{}", range.start, range.end), - Value::String(string) => write!(f, "{string}"), - } - } -} - -pub fn parser<'src>() -> impl Parser<'src, &'src str, Expression> { - let operator = |text| just(text).padded(); +pub fn parser<'src>() -> impl Parser<'src, &'src str, Statement, extra::Err>> { + let operator = |text: &'src str| just(text).padded(); let value = recursive(|value| { let boolean = just("true") @@ -133,55 +105,98 @@ pub fn parser<'src>() -> impl Parser<'src, &'src str, Expression> { choice((boolean, float, integer, string, list)) }); - let value_expression = value.map(|value| Expression::Value(value)); + let identifier = text::ident().map(|text: &str| Identifier(text.to_string())); - let logic_expression = value_expression.pratt(( - infix(left(1), operator("=="), |left, right| { - Expression::Logic(Box::new(Logic::Equal(left, right))) - }), - infix(left(1), operator("!="), |left, right| { - Expression::Logic(Box::new(Logic::NotEqual(left, right))) - }), - infix(left(1), operator(">"), |left, right| { - Expression::Logic(Box::new(Logic::Greater(left, right))) - }), - infix(left(1), operator("<"), |left, right| { - Expression::Logic(Box::new(Logic::Less(left, right))) - }), - infix(left(1), operator(">="), |left, right| { - Expression::Logic(Box::new(Logic::GreaterOrEqual(left, right))) - }), - infix(left(1), operator("<="), |left, right| { - Expression::Logic(Box::new(Logic::LessOrEqual(left, right))) - }), - infix(left(1), operator("&&"), |left, right| { - Expression::Logic(Box::new(Logic::And(left, right))) - }), - infix(left(1), operator("||"), |left, right| { - Expression::Logic(Box::new(Logic::Or(left, right))) - }), - )); + let assignment = identifier + .then_ignore(operator("=")) + .then(value.clone()) + .map(|(identifier, value)| Assignment { identifier, value }); - logic_expression.then_ignore(end()) + let logic = recursive(|logic| { + choice(( + value.clone().map(|value| LogicExpression::Value(value)), + identifier.map(|identifier| LogicExpression::Identifier(identifier)), + logic + .clone() + .map(|logic| LogicExpression::Logic(Box::new(logic))), + )) + .then(operator("==").map(|_| LogicOperator::Equal)) + .then(choice(( + value.clone().map(|value| LogicExpression::Value(value)), + identifier.map(|identifier| LogicExpression::Identifier(identifier)), + logic.map(|logic| LogicExpression::Logic(Box::new(logic))), + ))) + .map(|((left, operator), right)| Logic { + left, + operator, + right, + }) + }); + + choice(( + logic.map(|logic| Statement::Logic(logic)), + assignment.map(|assignment| Statement::Assignment(assignment)), + value.map(|value| Statement::Value(value)), + identifier.map(|identifier| Statement::Identifier(identifier)), + )) } #[cfg(test)] mod tests { use super::*; + #[test] + fn parse_identifier() { + assert_eq!( + parser().parse("x").unwrap(), + Statement::Identifier(Identifier("x".to_string())) + ); + assert_eq!( + parser().parse("foobar").unwrap(), + Statement::Identifier(Identifier("foobar".to_string())), + ); + assert_eq!( + parser().parse("HELLO").unwrap(), + Statement::Identifier(Identifier("HELLO".to_string())), + ); + } + + #[test] + fn parse_assignment() { + assert_eq!( + parser().parse("foobar=1").unwrap(), + Statement::Assignment(Assignment { + identifier: Identifier("foobar".to_string()), + value: Value::Integer(1) + }) + ); + } + + #[test] + fn parse_logic() { + assert_eq!( + parser().parse("x == 1").unwrap(), + Statement::Logic(Logic { + left: LogicExpression::Identifier(Identifier("x".to_string())), + operator: LogicOperator::Equal, + right: LogicExpression::Value(Value::Integer(1)) + }) + ); + } + #[test] fn parse_list() { assert_eq!( parser().parse("[]").unwrap(), - Expression::Value(Value::List(vec![])) + Statement::Value(Value::List(vec![])) ); assert_eq!( parser().parse("[42]").unwrap(), - Expression::Value(Value::List(vec![Value::Integer(42)])) + Statement::Value(Value::List(vec![Value::Integer(42)])) ); assert_eq!( parser().parse("[42, 'foo', \"bar\", [1, 2, 3,]]").unwrap(), - Expression::Value(Value::List(vec![ + Statement::Value(Value::List(vec![ Value::Integer(42), Value::String("foo".to_string()), Value::String("bar".to_string()), @@ -198,7 +213,7 @@ mod tests { fn parse_true() { assert_eq!( parser().parse("true").unwrap(), - Expression::Value(Value::Boolean(true)) + Statement::Value(Value::Boolean(true)) ); } @@ -206,7 +221,7 @@ mod tests { fn parse_false() { assert_eq!( parser().parse("false").unwrap(), - Expression::Value(Value::Boolean(false)) + Statement::Value(Value::Boolean(false)) ); } @@ -214,25 +229,25 @@ mod tests { fn parse_positive_float() { assert_eq!( parser().parse("0.0").unwrap(), - Expression::Value(Value::Float(0.0)) + Statement::Value(Value::Float(0.0)) ); assert_eq!( parser().parse("42.0").unwrap(), - Expression::Value(Value::Float(42.0)) + Statement::Value(Value::Float(42.0)) ); let max_float = f64::MAX.to_string() + ".0"; assert_eq!( parser().parse(&max_float).unwrap(), - Expression::Value(Value::Float(f64::MAX)) + Statement::Value(Value::Float(f64::MAX)) ); let min_positive_float = f64::MIN_POSITIVE.to_string(); assert_eq!( parser().parse(&min_positive_float).unwrap(), - Expression::Value(Value::Float(f64::MIN_POSITIVE)) + Statement::Value(Value::Float(f64::MIN_POSITIVE)) ); } @@ -240,25 +255,25 @@ mod tests { fn parse_negative_float() { assert_eq!( parser().parse("-0.0").unwrap(), - Expression::Value(Value::Float(-0.0)) + Statement::Value(Value::Float(-0.0)) ); assert_eq!( parser().parse("-42.0").unwrap(), - Expression::Value(Value::Float(-42.0)) + Statement::Value(Value::Float(-42.0)) ); let min_float = f64::MIN.to_string() + ".0"; assert_eq!( parser().parse(&min_float).unwrap(), - Expression::Value(Value::Float(f64::MIN)) + Statement::Value(Value::Float(f64::MIN)) ); let max_negative_float = format!("-{}", f64::MIN_POSITIVE); assert_eq!( parser().parse(&max_negative_float).unwrap(), - Expression::Value(Value::Float(-f64::MIN_POSITIVE)) + Statement::Value(Value::Float(-f64::MIN_POSITIVE)) ); } @@ -266,14 +281,14 @@ mod tests { fn parse_other_float() { assert_eq!( parser().parse("Infinity").unwrap(), - Expression::Value(Value::Float(f64::INFINITY)) + Statement::Value(Value::Float(f64::INFINITY)) ); assert_eq!( parser().parse("-Infinity").unwrap(), - Expression::Value(Value::Float(f64::NEG_INFINITY)) + Statement::Value(Value::Float(f64::NEG_INFINITY)) ); - if let Expression::Value(Value::Float(float)) = parser().parse("NaN").unwrap() { + if let Statement::Value(Value::Float(float)) = parser().parse("NaN").unwrap() { assert!(float.is_nan()) } else { panic!("Expected a float.") @@ -284,54 +299,54 @@ mod tests { fn parse_positive_integer() { assert_eq!( parser().parse("0").unwrap(), - Expression::Value(Value::Integer(0)) + Statement::Value(Value::Integer(0)) ); assert_eq!( parser().parse("1").unwrap(), - Expression::Value(Value::Integer(1)) + Statement::Value(Value::Integer(1)) ); assert_eq!( parser().parse("2").unwrap(), - Expression::Value(Value::Integer(2)) + Statement::Value(Value::Integer(2)) ); assert_eq!( parser().parse("3").unwrap(), - Expression::Value(Value::Integer(3)) + Statement::Value(Value::Integer(3)) ); assert_eq!( parser().parse("4").unwrap(), - Expression::Value(Value::Integer(4)) + Statement::Value(Value::Integer(4)) ); assert_eq!( parser().parse("5").unwrap(), - Expression::Value(Value::Integer(5)) + Statement::Value(Value::Integer(5)) ); assert_eq!( parser().parse("6").unwrap(), - Expression::Value(Value::Integer(6)) + Statement::Value(Value::Integer(6)) ); assert_eq!( parser().parse("7").unwrap(), - Expression::Value(Value::Integer(7)) + Statement::Value(Value::Integer(7)) ); assert_eq!( parser().parse("8").unwrap(), - Expression::Value(Value::Integer(8)) + Statement::Value(Value::Integer(8)) ); assert_eq!( parser().parse("9").unwrap(), - Expression::Value(Value::Integer(9)) + Statement::Value(Value::Integer(9)) ); assert_eq!( parser().parse("42").unwrap(), - Expression::Value(Value::Integer(42)) + Statement::Value(Value::Integer(42)) ); let maximum_integer = i64::MAX.to_string(); assert_eq!( parser().parse(&maximum_integer).unwrap(), - Expression::Value(Value::Integer(i64::MAX)) + Statement::Value(Value::Integer(i64::MAX)) ); } @@ -339,54 +354,54 @@ mod tests { fn parse_negative_integer() { assert_eq!( parser().parse("-0").unwrap(), - Expression::Value(Value::Integer(-0)) + Statement::Value(Value::Integer(-0)) ); assert_eq!( parser().parse("-1").unwrap(), - Expression::Value(Value::Integer(-1)) + Statement::Value(Value::Integer(-1)) ); assert_eq!( parser().parse("-2").unwrap(), - Expression::Value(Value::Integer(-2)) + Statement::Value(Value::Integer(-2)) ); assert_eq!( parser().parse("-3").unwrap(), - Expression::Value(Value::Integer(-3)) + Statement::Value(Value::Integer(-3)) ); assert_eq!( parser().parse("-4").unwrap(), - Expression::Value(Value::Integer(-4)) + Statement::Value(Value::Integer(-4)) ); assert_eq!( parser().parse("-5").unwrap(), - Expression::Value(Value::Integer(-5)) + Statement::Value(Value::Integer(-5)) ); assert_eq!( parser().parse("-6").unwrap(), - Expression::Value(Value::Integer(-6)) + Statement::Value(Value::Integer(-6)) ); assert_eq!( parser().parse("-7").unwrap(), - Expression::Value(Value::Integer(-7)) + Statement::Value(Value::Integer(-7)) ); assert_eq!( parser().parse("-8").unwrap(), - Expression::Value(Value::Integer(-8)) + Statement::Value(Value::Integer(-8)) ); assert_eq!( parser().parse("-9").unwrap(), - Expression::Value(Value::Integer(-9)) + Statement::Value(Value::Integer(-9)) ); assert_eq!( parser().parse("-42").unwrap(), - Expression::Value(Value::Integer(-42)) + Statement::Value(Value::Integer(-42)) ); let minimum_integer = i64::MIN.to_string(); assert_eq!( parser().parse(&minimum_integer).unwrap(), - Expression::Value(Value::Integer(i64::MIN)) + Statement::Value(Value::Integer(i64::MIN)) ); } @@ -394,19 +409,19 @@ mod tests { fn double_quoted_string() { assert_eq!( parser().parse("\"\"").unwrap(), - Expression::Value(Value::String("".to_string())) + Statement::Value(Value::String("".to_string())) ); assert_eq!( parser().parse("\"1\"").unwrap(), - Expression::Value(Value::String("1".to_string())) + Statement::Value(Value::String("1".to_string())) ); assert_eq!( parser().parse("\"42\"").unwrap(), - Expression::Value(Value::String("42".to_string())) + Statement::Value(Value::String("42".to_string())) ); assert_eq!( parser().parse("\"foobar\"").unwrap(), - Expression::Value(Value::String("foobar".to_string())) + Statement::Value(Value::String("foobar".to_string())) ); } @@ -414,19 +429,19 @@ mod tests { fn single_quoted_string() { assert_eq!( parser().parse("''").unwrap(), - Expression::Value(Value::String("".to_string())) + Statement::Value(Value::String("".to_string())) ); assert_eq!( parser().parse("'1'").unwrap(), - Expression::Value(Value::String("1".to_string())) + Statement::Value(Value::String("1".to_string())) ); assert_eq!( parser().parse("'42'").unwrap(), - Expression::Value(Value::String("42".to_string())) + Statement::Value(Value::String("42".to_string())) ); assert_eq!( parser().parse("'foobar'").unwrap(), - Expression::Value(Value::String("foobar".to_string())) + Statement::Value(Value::String("foobar".to_string())) ); } @@ -434,19 +449,19 @@ mod tests { fn grave_quoted_string() { assert_eq!( parser().parse("``").unwrap(), - Expression::Value(Value::String("".to_string())) + Statement::Value(Value::String("".to_string())) ); assert_eq!( parser().parse("`1`").unwrap(), - Expression::Value(Value::String("1".to_string())) + Statement::Value(Value::String("1".to_string())) ); assert_eq!( parser().parse("`42`").unwrap(), - Expression::Value(Value::String("42".to_string())) + Statement::Value(Value::String("42".to_string())) ); assert_eq!( parser().parse("`foobar`").unwrap(), - Expression::Value(Value::String("foobar".to_string())) + Statement::Value(Value::String("foobar".to_string())) ); } }