@@ 0,0 1,129 @@
+use sqlparser::ast::Expr;
+
+#[derive(Debug, thiserror::Error)]
+enum Error<'sql> {
+ #[error("subquery detected: {0}")]
+ SubqueryDetected(&'sql sqlparser::ast::Query),
+ #[error("function call detected: {0}")]
+ FunctionCallDetected(&'sql sqlparser::ast::Function),
+ #[error("sql parser error: {0}")]
+ Sql(#[from] sqlparser::parser::ParserError),
+}
+
+fn sanitize(expr: &Expr) -> Result<(), Error> {
+ match expr {
+ Expr::Identifier(_) => Ok(()),
+ Expr::CompoundIdentifier(_) => Ok(()),
+ Expr::JsonAccess { left, operator: _, right } => sanitize(left)
+ .and(sanitize(right)),
+ Expr::CompositeAccess { expr, key: _ } => sanitize(expr),
+ Expr::IsFalse(subexpr) => sanitize(subexpr),
+ Expr::IsNotFalse(subexpr) => sanitize(subexpr),
+ Expr::IsTrue(subexpr) => sanitize(subexpr),
+ Expr::IsNotTrue(subexpr) => sanitize(subexpr),
+ Expr::IsNull(subexpr) => sanitize(subexpr),
+ Expr::IsNotNull(subexpr) => sanitize(subexpr),
+ Expr::IsUnknown(subexpr) => sanitize(subexpr),
+ Expr::IsNotUnknown(subexpr) => sanitize(subexpr),
+ Expr::IsDistinctFrom(left, right) => sanitize(left)
+ .and(sanitize(right)),
+ Expr::IsNotDistinctFrom(left, right) => sanitize(left)
+ .and(sanitize(right)),
+ Expr::InList { expr, list, negated: _ } => sanitize(expr)
+ .and(list.iter().try_for_each(sanitize)),
+ Expr::InSubquery { expr: _, subquery, negated: _ } => Err(Error::SubqueryDetected(subquery.as_ref())),
+ Expr::InUnnest { expr, array_expr, negated: _ } => sanitize(expr).and(sanitize(array_expr)),
+ Expr::Between { expr, negated: _, low, high } => sanitize(expr)
+ .and(sanitize(low))
+ .and(sanitize(high)),
+ Expr::BinaryOp { left, op: _, right } => sanitize(left)
+ .and(sanitize(right)),
+ Expr::Like { negated: _, expr, pattern, escape_char: _ } => sanitize(expr)
+ .and(sanitize(pattern)),
+ Expr::ILike { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)),
+ Expr::SimilarTo { negated: _, expr, pattern, escape_char: _ } => sanitize(expr).and(sanitize(pattern)),
+ Expr::RLike { negated: _, expr, pattern, regexp: _ } => sanitize(expr).and(sanitize(pattern)),
+ Expr::AnyOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)),
+ Expr::AllOp { left, compare_op: _, right } => sanitize(left).and(sanitize(right)),
+ Expr::UnaryOp { op: _, expr } => sanitize(expr),
+ Expr::Convert { expr, data_type: _, charset: _, target_before_value: _ } => sanitize(expr),
+ Expr::Cast { expr, data_type: _, format: _ } => sanitize(expr),
+ Expr::TryCast { expr, data_type: _, format: _ } => sanitize(expr),
+ Expr::SafeCast { expr, data_type: _, format: _ } => sanitize(expr),
+ Expr::AtTimeZone { timestamp, time_zone: _ } => sanitize(timestamp),
+ Expr::Extract { field: _, expr } => sanitize(expr),
+ Expr::Ceil { expr, field: _ } => sanitize(expr),
+ Expr::Floor { expr, field: _ } => sanitize(expr),
+ Expr::Position { expr, r#in } => sanitize(expr).and(sanitize(r#in)),
+ Expr::Substring { expr, substring_from, substring_for, special: _ } => sanitize(expr)
+ .and(substring_from.as_deref().map(sanitize).unwrap_or(Ok(())))
+ .and(substring_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
+ Expr::Trim { expr, trim_where: _, trim_what, trim_characters } => sanitize(expr)
+ .and(trim_what.as_deref().map(sanitize).unwrap_or(Ok(())))
+ .and(
+ trim_characters
+ .as_ref()
+ .map(|v| v.iter())
+ .map(|mut iter| iter.try_for_each(sanitize))
+ .unwrap_or(Ok(()))
+ ),
+ Expr::Overlay { expr, overlay_what, overlay_from, overlay_for } => sanitize(expr)
+ .and(sanitize(overlay_what))
+ .and(sanitize(overlay_from))
+ .and(overlay_for.as_deref().map(sanitize).unwrap_or(Ok(()))),
+ Expr::Collate { expr, collation: _ } => sanitize(expr),
+ Expr::Nested(subexpr) => sanitize(subexpr),
+ Expr::Value(_) => Ok(()),
+ Expr::IntroducedString { introducer: _, value: _ } => Ok(()),
+ Expr::TypedString { data_type: _, value: _ } => Ok(()),
+ Expr::MapAccess { column, keys } => sanitize(column).and(keys.iter().try_for_each(sanitize)),
+ Expr::Function(func) => Err(Error::FunctionCallDetected(func)),
+ Expr::AggregateExpressionWithFilter { expr, filter } => sanitize(expr).and(sanitize(filter)),
+ Expr::Case { operand, conditions, results, else_result } => conditions.iter()
+ .chain(results)
+ .chain(operand.iter().map(std::borrow::Borrow::borrow))
+ .chain(else_result.iter().map(std::borrow::Borrow::borrow))
+ .try_for_each(sanitize),
+ Expr::Exists { subquery, negated: _ } => Err(Error::SubqueryDetected(subquery)),
+ Expr::Subquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
+ Expr::ArraySubquery(subquery) => Err(Error::SubqueryDetected(subquery.as_ref())),
+ Expr::ListAgg(agg) => sanitize(&agg.expr),
+ Expr::ArrayAgg(agg) => sanitize(&agg.expr),
+ Expr::GroupingSets(sets) => sets.iter()
+ .map(|i| i.iter())
+ .try_for_each(|mut si| si.try_for_each(sanitize)),
+ Expr::Cube(cube) => cube.iter()
+ .map(|i| i.iter())
+ .try_for_each(|mut si| si.try_for_each(sanitize)),
+ Expr::Rollup(rollup) => rollup.iter()
+ .map(|i| i.iter())
+ .try_for_each(|mut si| si.try_for_each(sanitize)),
+ Expr::Tuple(tuple) => tuple.iter().try_for_each(sanitize),
+ Expr::Struct { values, fields: _ } => values.iter().try_for_each(sanitize),
+ Expr::Named { expr, name: _ } => sanitize(expr),
+ Expr::ArrayIndex { obj, indexes } => sanitize(obj)
+ .and(indexes.iter().try_for_each(sanitize)),
+ Expr::Array(array) => array.elem.iter().try_for_each(sanitize),
+ Expr::Interval(interval) => sanitize(&interval.value),
+ Expr::MatchAgainst { .. } => Ok(()),
+ Expr::Wildcard => Ok(()),
+ Expr::QualifiedWildcard(_) => Ok(()),
+ Expr::OuterJoin(expr) => sanitize(expr),
+ }
+}
+
+fn main() -> Result<(), sqlparser::parser::ParserError> {
+ let query = std::env::args().skip(1).take(1).next().unwrap();
+ static DIALECT: sqlparser::dialect::PostgreSqlDialect = sqlparser::dialect::PostgreSqlDialect {};
+
+ let parser: sqlparser::parser::Parser<'static> = sqlparser::parser::Parser::new(&DIALECT);
+ let expr = parser.try_with_sql(&query)?.parse_expr()?;
+ match sanitize(&expr) {
+ Ok(_) => eprintln!("{0:#?}\n\n{0}", expr),
+ Err(err) => {
+ eprintln!("{}", err);
+ }
+ }
+
+ Ok(())
+}