[Rust Server] Fix panic handling headers (#4877)

[Rust Server] Fix panic handling headers

If we have an API which has multiple auth types, we may panic. This is because
in Hyper 0.11, the following code will panic:

```
use hyper::header::{Authorization, Basic, Bearer, Headers};
fn main() {
    let mut headers = Headers::default();
    let basic = Basic { username: "richard".to_string(), password: None };
    headers.set::<Authorization<Basic>>(Authorization(basic));
    println!("Auth: {:?}", headers.get::<Authorization<Bearer>>());
}
```

as it mixes up an `Authorization<Basic>` and `Authorization<Bearer>` as both
have `Authorization:` as the header name.

This is fixed by using `swagger::SafeHeaders` added in
https://github.com/Metaswitch/swagger-rs/pull/90
This commit is contained in:
Richard Whitehouse 2020-01-05 14:46:09 +00:00 committed by GitHub
parent c2ee4aefe1
commit 79d11d7129
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 62 additions and 38 deletions

View File

@ -17,7 +17,7 @@ conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-
# Common
chrono = { version = "0.4", features = ["serde"] }
futures = "0.1"
swagger = "2"
swagger = "2.2"
lazy_static = "1.4"
log = "0.3.0"
mime = "0.2.6"

View File

@ -33,6 +33,8 @@ use std::sync::Arc;
use std::str;
use std::str::FromStr;
use std::string::ToString;
use swagger::headers::SafeHeaders;
{{#apiUsesMultipart}}
use hyper::mime::Mime;
use std::io::Cursor;
@ -478,7 +480,7 @@ impl<F, C> Api<C> for Client<F> where
{{#responses}}
{{{code}}} => {
{{#headers}} header! { (Response{{{nameInCamelCase}}}, "{{{baseName}}}") => [{{{datatype}}}] }
let response_{{{name}}} = match response.headers().get::<Response{{{nameInCamelCase}}}>() {
let response_{{{name}}} = match response.headers().safe_get::<Response{{{nameInCamelCase}}}>() {
Some(response_{{{name}}}) => response_{{{name}}}.0.clone(),
None => return Box::new(future::err(ApiError(String::from("Required response header {{{baseName}}} for response {{{code}}} was not found.")))) as Box<dyn Future<Item=_, Error=_>>,
};

View File

@ -6,6 +6,7 @@ use hyper::{Request, Response, Error, StatusCode};
use server::url::form_urlencoded;
use swagger::auth::{Authorization, AuthData, Scopes};
use swagger::{Has, Pop, Push, XSpanIdString};
use swagger::headers::SafeHeaders;
use Api;
pub struct NewAddContext<T, A>
@ -88,7 +89,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{
use hyper::header::{Authorization as HyperAuth, Basic, Bearer};
use std::ops::Deref;
if let Some(basic) = req.headers().get::<HyperAuth<Basic>>().cloned() {
if let Some(basic) = req.headers().safe_get::<HyperAuth<Basic>>() {
let auth_data = AuthData::Basic(basic.deref().clone());
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);
@ -100,7 +101,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{
use hyper::header::{Authorization as HyperAuth, Basic, Bearer};
use std::ops::Deref;
if let Some(bearer) = req.headers().get::<HyperAuth<Bearer>>().cloned() {
if let Some(bearer) = req.headers().safe_get::<HyperAuth<Bearer>>() {
let auth_data = AuthData::Bearer(bearer.deref().clone());
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);
@ -112,7 +113,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{{#isKeyInHeader}}
{
header! { (ApiKey{{-index}}, "{{{keyParamName}}}") => [String] }
if let Some(header) = req.headers().get::<ApiKey{{-index}}>().cloned() {
if let Some(header) = req.headers().safe_get::<ApiKey{{-index}}>() {
let auth_data = AuthData::ApiKey(header.0);
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);

View File

@ -45,6 +45,7 @@ use std::collections::BTreeSet;
pub use swagger::auth::Authorization;
use swagger::{ApiError, XSpanId, XSpanIdString, Has, RequestParser};
use swagger::auth::Scopes;
use swagger::headers::SafeHeaders;
use {Api{{#apiInfo}}{{#apis}}{{#operations}}{{#operation}},
{{{operationId}}}Response{{/operation}}{{/operations}}{{/apis}}{{/apiInfo}}
@ -178,7 +179,7 @@ where
{{#vendorExtensions}}
{{#consumesMultipart}}
let boundary = match multipart_boundary(&headers) {
Some(boundary) => boundary.to_string(),
Some(boundary) => boundary,
None => return Box::new(future::ok(Response::new().with_status(StatusCode::BadRequest).with_body("Couldn't find valid multipart body"))),
};
{{/consumesMultipart}}
@ -214,7 +215,7 @@ where
};
{{/required}}
{{^required}}
let param_{{{paramName}}} = headers.get::<Request{{vendorExtensions.typeName}}>().map(|header| header.0.clone());
let param_{{{paramName}}} = headers.safe_get::<Request{{vendorExtensions.typeName}}>().map(|header| header.0.clone());
{{/required}}
{{/headerParams}}
{{#queryParams}}
@ -530,11 +531,11 @@ impl<T, C> Clone for Service<T, C>
{{#apiUsesMultipart}}
/// Utility function to get the multipart boundary marker (if any) from the Headers.
fn multipart_boundary<'a>(headers: &'a Headers) -> Option<&'a str> {
headers.get::<ContentType>().and_then(|content_type| {
let ContentType(ref mime) = *content_type;
fn multipart_boundary(headers: &Headers) -> Option<String> {
headers.safe_get::<ContentType>().and_then(|content_type| {
let ContentType(mime) = content_type;
if mime.type_() == hyper::mime::MULTIPART && mime.subtype() == hyper::mime::FORM_DATA {
mime.get_param(hyper::mime::BOUNDARY).map(|x| x.as_str())
mime.get_param(hyper::mime::BOUNDARY).map(|x| x.as_str().to_string())
} else {
None
}

View File

@ -15,7 +15,7 @@ conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-
# Common
chrono = { version = "0.4", features = ["serde"] }
futures = "0.1"
swagger = "2"
swagger = "2.2"
lazy_static = "1.4"
log = "0.3.0"
mime = "0.2.6"

View File

@ -25,6 +25,8 @@ use std::sync::Arc;
use std::str;
use std::str::FromStr;
use std::string::ToString;
use swagger::headers::SafeHeaders;
use hyper::mime::Mime;
use std::io::Cursor;
use client::multipart::client::lazy::Multipart;

View File

@ -6,6 +6,7 @@ use hyper::{Request, Response, Error, StatusCode};
use server::url::form_urlencoded;
use swagger::auth::{Authorization, AuthData, Scopes};
use swagger::{Has, Pop, Push, XSpanIdString};
use swagger::headers::SafeHeaders;
use Api;
pub struct NewAddContext<T, A>

View File

@ -35,6 +35,7 @@ use std::collections::BTreeSet;
pub use swagger::auth::Authorization;
use swagger::{ApiError, XSpanId, XSpanIdString, Has, RequestParser};
use swagger::auth::Scopes;
use swagger::headers::SafeHeaders;
use {Api,
MultipartRequestPostResponse
@ -123,7 +124,7 @@ where
// MultipartRequestPost - POST /multipart_request
&hyper::Method::Post if path.matched(paths::ID_MULTIPART_REQUEST) => {
let boundary = match multipart_boundary(&headers) {
Some(boundary) => boundary.to_string(),
Some(boundary) => boundary,
None => return Box::new(future::ok(Response::new().with_status(StatusCode::BadRequest).with_body("Couldn't find valid multipart body"))),
};
// Form Body parameters (note that non-required body parameters will ignore garbage
@ -250,11 +251,11 @@ impl<T, C> Clone for Service<T, C>
}
/// Utility function to get the multipart boundary marker (if any) from the Headers.
fn multipart_boundary<'a>(headers: &'a Headers) -> Option<&'a str> {
headers.get::<ContentType>().and_then(|content_type| {
let ContentType(ref mime) = *content_type;
fn multipart_boundary(headers: &Headers) -> Option<String> {
headers.safe_get::<ContentType>().and_then(|content_type| {
let ContentType(mime) = content_type;
if mime.type_() == hyper::mime::MULTIPART && mime.subtype() == hyper::mime::FORM_DATA {
mime.get_param(hyper::mime::BOUNDARY).map(|x| x.as_str())
mime.get_param(hyper::mime::BOUNDARY).map(|x| x.as_str().to_string())
} else {
None
}

View File

@ -15,7 +15,7 @@ conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-
# Common
chrono = { version = "0.4", features = ["serde"] }
futures = "0.1"
swagger = "2"
swagger = "2.2"
lazy_static = "1.4"
log = "0.3.0"
mime = "0.2.6"

View File

@ -25,6 +25,8 @@ use std::sync::Arc;
use std::str;
use std::str::FromStr;
use std::string::ToString;
use swagger::headers::SafeHeaders;
use mimetypes;
use serde_json;
use serde_xml_rs;
@ -675,7 +677,7 @@ impl<F, C> Api<C> for Client<F> where
match response.status().as_u16() {
200 => {
header! { (ResponseSuccessInfo, "Success-Info") => [String] }
let response_success_info = match response.headers().get::<ResponseSuccessInfo>() {
let response_success_info = match response.headers().safe_get::<ResponseSuccessInfo>() {
Some(response_success_info) => response_success_info.0.clone(),
None => return Box::new(future::err(ApiError(String::from("Required response header Success-Info for response 200 was not found.")))) as Box<dyn Future<Item=_, Error=_>>,
};
@ -699,12 +701,12 @@ impl<F, C> Api<C> for Client<F> where
},
412 => {
header! { (ResponseFurtherInfo, "Further-Info") => [String] }
let response_further_info = match response.headers().get::<ResponseFurtherInfo>() {
let response_further_info = match response.headers().safe_get::<ResponseFurtherInfo>() {
Some(response_further_info) => response_further_info.0.clone(),
None => return Box::new(future::err(ApiError(String::from("Required response header Further-Info for response 412 was not found.")))) as Box<dyn Future<Item=_, Error=_>>,
};
header! { (ResponseFailureInfo, "Failure-Info") => [String] }
let response_failure_info = match response.headers().get::<ResponseFailureInfo>() {
let response_failure_info = match response.headers().safe_get::<ResponseFailureInfo>() {
Some(response_failure_info) => response_failure_info.0.clone(),
None => return Box::new(future::err(ApiError(String::from("Required response header Failure-Info for response 412 was not found.")))) as Box<dyn Future<Item=_, Error=_>>,
};

View File

@ -6,6 +6,7 @@ use hyper::{Request, Response, Error, StatusCode};
use server::url::form_urlencoded;
use swagger::auth::{Authorization, AuthData, Scopes};
use swagger::{Has, Pop, Push, XSpanIdString};
use swagger::headers::SafeHeaders;
use Api;
pub struct NewAddContext<T, A>
@ -86,7 +87,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{
use hyper::header::{Authorization as HyperAuth, Basic, Bearer};
use std::ops::Deref;
if let Some(bearer) = req.headers().get::<HyperAuth<Bearer>>().cloned() {
if let Some(bearer) = req.headers().safe_get::<HyperAuth<Bearer>>() {
let auth_data = AuthData::Bearer(bearer.deref().clone());
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);

View File

@ -33,6 +33,7 @@ use std::collections::BTreeSet;
pub use swagger::auth::Authorization;
use swagger::{ApiError, XSpanId, XSpanIdString, Has, RequestParser};
use swagger::auth::Scopes;
use swagger::headers::SafeHeaders;
use {Api,
MultigetGetResponse,

View File

@ -15,7 +15,7 @@ conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-
# Common
chrono = { version = "0.4", features = ["serde"] }
futures = "0.1"
swagger = "2"
swagger = "2.2"
lazy_static = "1.4"
log = "0.3.0"
mime = "0.2.6"

View File

@ -24,6 +24,8 @@ use std::sync::Arc;
use std::str;
use std::str::FromStr;
use std::string::ToString;
use swagger::headers::SafeHeaders;
use mimetypes;
use serde_json;

View File

@ -6,6 +6,7 @@ use hyper::{Request, Response, Error, StatusCode};
use server::url::form_urlencoded;
use swagger::auth::{Authorization, AuthData, Scopes};
use swagger::{Has, Pop, Push, XSpanIdString};
use swagger::headers::SafeHeaders;
use Api;
pub struct NewAddContext<T, A>

View File

@ -31,6 +31,7 @@ use std::collections::BTreeSet;
pub use swagger::auth::Authorization;
use swagger::{ApiError, XSpanId, XSpanIdString, Has, RequestParser};
use swagger::auth::Scopes;
use swagger::headers::SafeHeaders;
use {Api,
Op10GetResponse,

View File

@ -15,7 +15,7 @@ conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-
# Common
chrono = { version = "0.4", features = ["serde"] }
futures = "0.1"
swagger = "2"
swagger = "2.2"
lazy_static = "1.4"
log = "0.3.0"
mime = "0.2.6"

View File

@ -27,6 +27,8 @@ use std::sync::Arc;
use std::str;
use std::str::FromStr;
use std::string::ToString;
use swagger::headers::SafeHeaders;
use hyper::mime::Mime;
use std::io::Cursor;
use client::multipart::client::lazy::Multipart;
@ -2705,12 +2707,12 @@ impl<F, C> Api<C> for Client<F> where
match response.status().as_u16() {
200 => {
header! { (ResponseXRateLimit, "X-Rate-Limit") => [i32] }
let response_x_rate_limit = match response.headers().get::<ResponseXRateLimit>() {
let response_x_rate_limit = match response.headers().safe_get::<ResponseXRateLimit>() {
Some(response_x_rate_limit) => response_x_rate_limit.0.clone(),
None => return Box::new(future::err(ApiError(String::from("Required response header X-Rate-Limit for response 200 was not found.")))) as Box<dyn Future<Item=_, Error=_>>,
};
header! { (ResponseXExpiresAfter, "X-Expires-After") => [chrono::DateTime<chrono::Utc>] }
let response_x_expires_after = match response.headers().get::<ResponseXExpiresAfter>() {
let response_x_expires_after = match response.headers().safe_get::<ResponseXExpiresAfter>() {
Some(response_x_expires_after) => response_x_expires_after.0.clone(),
None => return Box::new(future::err(ApiError(String::from("Required response header X-Expires-After for response 200 was not found.")))) as Box<dyn Future<Item=_, Error=_>>,
};

View File

@ -6,6 +6,7 @@ use hyper::{Request, Response, Error, StatusCode};
use server::url::form_urlencoded;
use swagger::auth::{Authorization, AuthData, Scopes};
use swagger::{Has, Pop, Push, XSpanIdString};
use swagger::headers::SafeHeaders;
use Api;
pub struct NewAddContext<T, A>
@ -85,7 +86,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{
header! { (ApiKey1, "api_key") => [String] }
if let Some(header) = req.headers().get::<ApiKey1>().cloned() {
if let Some(header) = req.headers().safe_get::<ApiKey1>() {
let auth_data = AuthData::ApiKey(header.0);
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);
@ -107,7 +108,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{
use hyper::header::{Authorization as HyperAuth, Basic, Bearer};
use std::ops::Deref;
if let Some(basic) = req.headers().get::<HyperAuth<Basic>>().cloned() {
if let Some(basic) = req.headers().safe_get::<HyperAuth<Basic>>() {
let auth_data = AuthData::Basic(basic.deref().clone());
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);
@ -117,7 +118,7 @@ impl<T, A, B, C, D> hyper::server::Service for AddContext<T, A>
{
use hyper::header::{Authorization as HyperAuth, Basic, Bearer};
use std::ops::Deref;
if let Some(bearer) = req.headers().get::<HyperAuth<Bearer>>().cloned() {
if let Some(bearer) = req.headers().safe_get::<HyperAuth<Bearer>>() {
let auth_data = AuthData::Bearer(bearer.deref().clone());
let context = context.push(Some(auth_data));
let context = context.push(None::<Authorization>);

View File

@ -37,6 +37,7 @@ use std::collections::BTreeSet;
pub use swagger::auth::Authorization;
use swagger::{ApiError, XSpanId, XSpanIdString, Has, RequestParser};
use swagger::auth::Scopes;
use swagger::headers::SafeHeaders;
use {Api,
TestSpecialTagsResponse,
@ -752,9 +753,9 @@ where
&hyper::Method::Get if path.matched(paths::ID_FAKE) => {
// Header parameters
header! { (RequestEnumHeaderStringArray, "enum_header_string_array") => (String)* }
let param_enum_header_string_array = headers.get::<RequestEnumHeaderStringArray>().map(|header| header.0.clone());
let param_enum_header_string_array = headers.safe_get::<RequestEnumHeaderStringArray>().map(|header| header.0.clone());
header! { (RequestEnumHeaderString, "enum_header_string") => [String] }
let param_enum_header_string = headers.get::<RequestEnumHeaderString>().map(|header| header.0.clone());
let param_enum_header_string = headers.safe_get::<RequestEnumHeaderString>().map(|header| header.0.clone());
// Query parameters (note that non-required or collection query parameters will ignore garbage values, rather than causing a 400 response)
let query_params = form_urlencoded::parse(uri.query().unwrap_or_default().as_bytes()).collect::<Vec<_>>();
let param_enum_query_string_array = query_params.iter().filter(|e| e.0 == "enum_query_string_array").map(|e| e.1.to_owned())
@ -1132,7 +1133,7 @@ where
};
// Header parameters
header! { (RequestApiKey, "api_key") => [String] }
let param_api_key = headers.get::<RequestApiKey>().map(|header| header.0.clone());
let param_api_key = headers.safe_get::<RequestApiKey>().map(|header| header.0.clone());
Box::new({
{{
Box::new(api_impl.delete_pet(param_pet_id, param_api_key, &context)
@ -1614,7 +1615,7 @@ where
}
}
let boundary = match multipart_boundary(&headers) {
Some(boundary) => boundary.to_string(),
Some(boundary) => boundary,
None => return Box::new(future::ok(Response::new().with_status(StatusCode::BadRequest).with_body("Couldn't find valid multipart body"))),
};
// Path parameters
@ -2489,11 +2490,11 @@ impl<T, C> Clone for Service<T, C>
}
/// Utility function to get the multipart boundary marker (if any) from the Headers.
fn multipart_boundary<'a>(headers: &'a Headers) -> Option<&'a str> {
headers.get::<ContentType>().and_then(|content_type| {
let ContentType(ref mime) = *content_type;
fn multipart_boundary(headers: &Headers) -> Option<String> {
headers.safe_get::<ContentType>().and_then(|content_type| {
let ContentType(mime) = content_type;
if mime.type_() == hyper::mime::MULTIPART && mime.subtype() == hyper::mime::FORM_DATA {
mime.get_param(hyper::mime::BOUNDARY).map(|x| x.as_str())
mime.get_param(hyper::mime::BOUNDARY).map(|x| x.as_str().to_string())
} else {
None
}

View File

@ -15,7 +15,7 @@ conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-
# Common
chrono = { version = "0.4", features = ["serde"] }
futures = "0.1"
swagger = "2"
swagger = "2.2"
lazy_static = "1.4"
log = "0.3.0"
mime = "0.2.6"

View File

@ -24,6 +24,8 @@ use std::sync::Arc;
use std::str;
use std::str::FromStr;
use std::string::ToString;
use swagger::headers::SafeHeaders;
use mimetypes;
use serde_json;

View File

@ -6,6 +6,7 @@ use hyper::{Request, Response, Error, StatusCode};
use server::url::form_urlencoded;
use swagger::auth::{Authorization, AuthData, Scopes};
use swagger::{Has, Pop, Push, XSpanIdString};
use swagger::headers::SafeHeaders;
use Api;
pub struct NewAddContext<T, A>

View File

@ -31,6 +31,7 @@ use std::collections::BTreeSet;
pub use swagger::auth::Authorization;
use swagger::{ApiError, XSpanId, XSpanIdString, Has, RequestParser};
use swagger::auth::Scopes;
use swagger::headers::SafeHeaders;
use {Api,
DummyGetResponse,