Skip to content

Commit e70b79f

Browse files
committed
Implement #2871 by matching on outcome
1 parent 3bf9ef0 commit e70b79f

File tree

3 files changed

+88
-4
lines changed

3 files changed

+88
-4
lines changed

core/lib/src/request/from_request.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,13 @@ impl<'r, T: FromRequest<'r>> FromRequest<'r> for Result<T, T::Error> {
521521

522522
#[crate::async_trait]
523523
impl<'r, T: FromRequest<'r>> FromRequest<'r> for Option<T> {
524-
type Error = Infallible;
524+
type Error = T::Error;
525525

526-
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
526+
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
527527
match T::from_request(request).await {
528528
Success(val) => Success(Some(val)),
529-
Error(_) | Forward(_) => Success(None),
529+
Forward(_) => Success(None),
530+
Error((status, error)) => Error((status, error)),
530531
}
531532
}
532533
}

core/lib/src/response/flash.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ impl<'r> FromRequest<'r> for FlashMessage<'r> {
256256
Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)),
257257
_ => Err(())
258258
}
259-
}).or_error(Status::BadRequest)
259+
}).or_forward(Status::BadRequest)
260260
}
261261
}
262262

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#[macro_use]
2+
extern crate rocket;
3+
4+
use std::num::ParseIntError;
5+
6+
use rocket::{outcome::IntoOutcome, request::{FromRequest, Outcome}, Request};
7+
use rocket_http::{Header, Status};
8+
9+
pub struct SessionId {
10+
session_id: u64,
11+
}
12+
13+
#[rocket::async_trait]
14+
impl<'r> FromRequest<'r> for SessionId {
15+
type Error = ParseIntError;
16+
17+
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ParseIntError> {
18+
let session_id_string = request.headers().get("Session-Id").next()
19+
.or_forward(Status::BadRequest);
20+
session_id_string.and_then(|v| v.parse()
21+
.map(|id| SessionId { session_id: id })
22+
.or_error(Status::BadRequest))
23+
}
24+
}
25+
26+
#[get("/mandatory")]
27+
fn get_data_with_mandatory_header(header: SessionId) -> String {
28+
format!("GET for session {:}", header.session_id)
29+
}
30+
31+
#[get("/optional")]
32+
fn get_data_with_opt_header(opt_header: Option<SessionId>) -> String {
33+
if let Some(id) = opt_header {
34+
format!("GET for session {:}", id.session_id)
35+
} else {
36+
format!("GET for new session")
37+
}
38+
}
39+
40+
#[test]
41+
fn read_optional_header() {
42+
let rocket = rocket::build().mount(
43+
"/",
44+
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
45+
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
46+
47+
// If we supply the header, the handler sees it
48+
let response = client.get("/optional")
49+
.header(Header::new("session-id", "1234567")).dispatch();
50+
assert_eq!(response.into_string().unwrap(), "GET for session 1234567".to_string());
51+
52+
// If no header, means that the handler sees a None
53+
let response = client.get("/optional").dispatch();
54+
assert_eq!(response.into_string().unwrap(), "GET for new session".to_string());
55+
56+
// If we supply a malformed header, the handler will not be called, but the request will fail
57+
let response = client.get("/optional")
58+
.header(Header::new("session-id", "Xw23")).dispatch();
59+
assert_eq!(response.status(), Status::BadRequest);
60+
}
61+
62+
#[test]
63+
fn read_mandatory_header() {
64+
let rocket = rocket::build().mount(
65+
"/",
66+
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
67+
let client = rocket::local::blocking::Client::debug(rocket).unwrap();
68+
69+
// If the header is missing, it's a bad request (extra info would be nice, though)
70+
let response = client.get("/mandatory").dispatch();
71+
assert_eq!(response.status(), Status::BadRequest);
72+
73+
// If the header is malformed, it's a bad request too (extra info would be nice, though)
74+
let response = client.get("/mandatory")
75+
.header(Header::new("session-id", "Xw23")).dispatch();
76+
assert_eq!(response.status(), Status::BadRequest);
77+
78+
// If the header is fine, just do the stuff
79+
let response = client.get("/mandatory")
80+
.header(Header::new("session-id", "64535")).dispatch();
81+
assert_eq!(response.status(), Status::Ok);
82+
assert_eq!(response.into_string().unwrap(), "GET for session 64535".to_string());
83+
}

0 commit comments

Comments
 (0)