Skip to content

Commit 5cf4c23

Browse files
committed
Add Cast operator
1 parent 56a39c0 commit 5cf4c23

File tree

4 files changed

+199
-1
lines changed

4 files changed

+199
-1
lines changed

diesel/src/expression/cast.rs

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
//! SQL `CAST(expr AS sql_type)` expression support
2+
3+
use crate::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping};
4+
use crate::query_source::aliasing::{AliasSource, FieldAliasMapper};
5+
use crate::result::QueryResult;
6+
use crate::{query_builder, query_source, sql_types};
7+
8+
use std::marker::PhantomData;
9+
10+
pub(crate) mod private {
11+
use super::*;
12+
13+
#[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, sql_types::DieselNumericOps)]
14+
pub struct Cast<E, ST> {
15+
pub(super) expr: E,
16+
pub(super) sql_type: PhantomData<ST>,
17+
}
18+
}
19+
pub(crate) use private::Cast;
20+
21+
impl<E, ST> Cast<E, ST> {
22+
pub(crate) fn new(expr: E) -> Self {
23+
Self {
24+
expr,
25+
sql_type: PhantomData,
26+
}
27+
}
28+
}
29+
30+
impl<E, ST, GroupByClause> ValidGrouping<GroupByClause> for Cast<E, ST>
31+
where
32+
E: ValidGrouping<GroupByClause>,
33+
{
34+
type IsAggregate = E::IsAggregate;
35+
}
36+
37+
impl<E, ST, QS> SelectableExpression<QS> for Cast<E, ST>
38+
where
39+
Cast<E, ST>: AppearsOnTable<QS>,
40+
E: SelectableExpression<QS>,
41+
{
42+
}
43+
44+
impl<E, ST, QS> AppearsOnTable<QS> for Cast<E, ST>
45+
where
46+
Cast<E, ST>: Expression,
47+
E: AppearsOnTable<QS>,
48+
{
49+
}
50+
51+
impl<E, ST> Expression for Cast<E, ST>
52+
where
53+
E: Expression,
54+
ST: sql_types::SingleValue,
55+
{
56+
type SqlType = ST;
57+
}
58+
impl<E, ST, DB> query_builder::QueryFragment<DB> for Cast<E, ST>
59+
where
60+
E: query_builder::QueryFragment<DB>,
61+
DB: diesel::backend::Backend,
62+
ST: KnownCastSqlTypeName<DB>,
63+
{
64+
fn walk_ast<'b>(&'b self, mut out: query_builder::AstPass<'_, 'b, DB>) -> QueryResult<()> {
65+
out.push_sql("CAST(");
66+
self.expr.walk_ast(out.reborrow())?;
67+
out.push_sql(" AS ");
68+
out.push_sql(ST::sql_type_name());
69+
out.push_sql(")");
70+
Ok(())
71+
}
72+
}
73+
74+
/// We know what to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for
75+
/// `Self`
76+
///
77+
/// That is what is returned by `Self::sql_type_name()`
78+
#[diagnostic::on_unimplemented(
79+
note = "In order to use `CAST`, it is necessary that Diesel knows how to express the name \
80+
of this type in the given backend.",
81+
note = "This can be PRed into Diesel if the type is a standard SQL type."
82+
)]
83+
pub trait KnownCastSqlTypeName<DB> {
84+
/// What to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for
85+
/// `Self`
86+
fn sql_type_name() -> &'static str;
87+
}
88+
impl<ST, DB> KnownCastSqlTypeName<DB> for sql_types::Nullable<ST>
89+
where
90+
ST: KnownCastSqlTypeName<DB>,
91+
{
92+
fn sql_type_name() -> &'static str {
93+
<ST as KnownCastSqlTypeName<DB>>::sql_type_name()
94+
}
95+
}
96+
97+
macro_rules! type_name {
98+
($($backend: ty: $backend_feature: literal { $($type: ident => $val: literal,)+ })*) => {
99+
$(
100+
$(
101+
#[cfg(feature = $backend_feature)]
102+
impl KnownCastSqlTypeName<$backend> for sql_types::$type {
103+
fn sql_type_name() -> &'static str {
104+
$val
105+
}
106+
}
107+
)*
108+
)*
109+
};
110+
}
111+
type_name! {
112+
diesel::pg::Pg: "postgres_backend" {
113+
Int4 => "int4",
114+
Int8 => "int8",
115+
Text => "text",
116+
}
117+
diesel::mysql::Mysql: "mysql_backend" {
118+
Int4 => "integer",
119+
Int8 => "integer",
120+
Text => "text",
121+
}
122+
diesel::sqlite::Sqlite: "sqlite" {
123+
Int4 => "integer",
124+
Int8 => "bigint",
125+
Text => "text",
126+
}
127+
}
128+
129+
impl<S, E, ST> FieldAliasMapper<S> for Cast<E, ST>
130+
where
131+
S: AliasSource,
132+
E: FieldAliasMapper<S>,
133+
{
134+
type Out = Cast<<E as FieldAliasMapper<S>>::Out, ST>;
135+
fn map(self, alias: &query_source::Alias<S>) -> Self::Out {
136+
Cast {
137+
expr: self.expr.map(alias),
138+
sql_type: self.sql_type,
139+
}
140+
}
141+
}
142+
143+
/// Marker trait: this SQL type (`Self`) can be casted to the target SQL type
144+
/// (`ST`) using `CAST(expr AS target_sql_type)`
145+
pub trait CastsTo<ST> {}
146+
impl<ST1, ST2> CastsTo<sql_types::Nullable<ST2>> for sql_types::Nullable<ST1> where ST1: CastsTo<ST2>
147+
{}
148+
149+
impl CastsTo<sql_types::Int8> for sql_types::Int4 {}
150+
impl CastsTo<sql_types::Int4> for sql_types::Int8 {}
151+
impl CastsTo<sql_types::Text> for sql_types::Int4 {}
152+
impl CastsTo<sql_types::Text> for sql_types::Int8 {}

diesel/src/expression/helper_types.rs

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ pub type NotBetween<Lhs, Lower, Upper> = Grouped<
9090
/// [`lhs.concat(rhs)`](crate::expression_methods::TextExpressionMethods::concat())
9191
pub type Concat<Lhs, Rhs> = Grouped<super::operators::Concat<Lhs, AsExpr<Rhs, Lhs>>>;
9292

93+
/// The return type of
94+
/// [`expr.cast<ST>()`](crate::expression_methods::ExpressionMethods::cast())
95+
pub type Cast<Expr, ST> = super::cast::Cast<Expr, ST>;
96+
9397
/// The return type of
9498
/// [`expr.desc()`](crate::expression_methods::ExpressionMethods::desc())
9599
pub type Desc<Expr> = super::operators::Desc<Expr>;

diesel/src/expression/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pub(crate) mod nullable;
3737
#[macro_use]
3838
pub(crate) mod operators;
3939
mod case_when;
40+
pub mod cast;
4041
pub(crate) mod select_by;
4142
mod sql_literal;
4243
pub(crate) mod subselect;

diesel/src/expression_methods/global_expression_methods.rs

+42-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::dsl;
22
use crate::expression::array_comparison::{AsInExpression, In, NotIn};
33
use crate::expression::grouped::Grouped;
44
use crate::expression::operators::*;
5-
use crate::expression::{assume_not_null, nullable, AsExpression, Expression};
5+
use crate::expression::{assume_not_null, cast, nullable, AsExpression, Expression};
66
use crate::sql_types::{SingleValue, SqlType};
77

88
/// Methods present on all expressions, except tuples
@@ -437,6 +437,47 @@ pub trait ExpressionMethods: Expression + Sized {
437437
))
438438
}
439439

440+
/// Generates a `CAST(expr AS sql_type)` expression
441+
///
442+
/// It is necessary that the expression's SQL type can be casted to the
443+
/// target SQL type (represented by the [`CastsTo`](cast::CastsTo) trait),
444+
/// and that we know how the corresponding SQL type is named for the
445+
/// specific backend (represented by the
446+
/// [`KnownCastSqlTypeName`](cast::KnownCastSqlTypeName) trait).
447+
///
448+
/// # Example
449+
///
450+
/// ```rust
451+
/// # include!("../doctest_setup.rs");
452+
/// #
453+
/// # fn main() {
454+
/// # run_test().unwrap();
455+
/// # }
456+
/// #
457+
/// # fn run_test() -> QueryResult<()> {
458+
/// # use schema::animals::dsl::*;
459+
/// # let connection = &mut establish_connection();
460+
/// #
461+
/// use diesel::sql_types;
462+
///
463+
/// let data = diesel::select(
464+
/// 12_i32
465+
/// .into_sql::<sql_types::Int4>()
466+
/// .cast::<sql_types::Text>(),
467+
/// )
468+
/// .first::<String>(connection)?;
469+
/// assert_eq!("12", data);
470+
/// # Ok(())
471+
/// # }
472+
/// ```
473+
fn cast<ST>(self) -> dsl::Cast<Self, ST>
474+
where
475+
ST: SingleValue,
476+
Self::SqlType: cast::CastsTo<ST>,
477+
{
478+
cast::Cast::new(self)
479+
}
480+
440481
/// Creates a SQL `DESC` expression, representing this expression in
441482
/// descending order.
442483
///

0 commit comments

Comments
 (0)