Skip to content

Commit d5b8a4c

Browse files
merlinio2000maggimerheaths
authored
Support xml error respsonse (Azure#1631)
* core: make HttpError use content-type header to deserialize body * core: respect Content-Type in ErrorKind::http_response_from_body * core: correctly de-serialize XML errors with different case-ing * cleanup if statements * core: only parse xml error bodies if 'xml' feature is present * core: refactor ErrorKind::HttpResponse to make use of headers * Update sdk/core/src/error/mod.rs Co-authored-by: Heath Stewart <heaths@outlook.com> * Update sdk/core/src/error/http_error.rs Co-authored-by: Heath Stewart <heaths@outlook.com> * Update sdk/core/src/error/mod.rs Co-authored-by: Heath Stewart <heaths@outlook.com> * Revert "Update sdk/core/src/error/http_error.rs" This reverts commit a2f3b4b. --------- Co-authored-by: Merlin Maggi <merlin.maggi@jls.ch> Co-authored-by: Heath Stewart <heaths@outlook.com>
1 parent bf92c43 commit d5b8a4c

File tree

8 files changed

+135
-52
lines changed

8 files changed

+135
-52
lines changed

sdk/core/src/error/http_error.rs

+73-31
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
use crate::{from_json, headers, Response, StatusCode};
1+
use crate::{
2+
content_type, from_json,
3+
headers::{self, Headers},
4+
Response, StatusCode,
5+
};
26
use bytes::Bytes;
37
use serde::Deserialize;
4-
use std::collections::HashMap;
58

69
/// An unsuccessful HTTP response
710
#[derive(Debug)]
811
pub struct HttpError {
912
status: StatusCode,
1013
details: ErrorDetails,
11-
headers: std::collections::HashMap<String, String>,
14+
headers: Headers,
1215
body: Bytes,
1316
}
1417

@@ -17,14 +20,8 @@ impl HttpError {
1720
///
1821
/// This does not check whether the response was a success and should only be used with unsuccessful responses.
1922
pub async fn new(response: Response) -> Self {
20-
let status = response.status();
21-
let headers: HashMap<String, String> = response
22-
.headers()
23-
.iter()
24-
.map(|(name, value)| (name.as_str().to_owned(), value.as_str().to_owned()))
25-
.collect();
26-
let body = response
27-
.into_body()
23+
let (status, headers, body) = response.deconstruct();
24+
let body = body
2825
.collect()
2926
.await
3027
.unwrap_or_else(|_| Bytes::from_static(b"<ERROR COLLECTING BODY>"));
@@ -71,11 +68,15 @@ impl std::fmt::Display for HttpError {
7168
write!(f, "{tab}Body: \"{:?}\",{newline}", self.body)?;
7269
write!(f, "{tab}Headers: [{newline}")?;
7370
// TODO: sanitize headers
74-
for (k, v) in &self.headers {
75-
write!(f, "{tab}{tab}{k}:{v}{newline}")?;
71+
for (k, v) in self.headers.iter() {
72+
write!(
73+
f,
74+
"{tab}{tab}{k}:{v}{newline}",
75+
k = k.as_str(),
76+
v = v.as_str()
77+
)?;
7678
}
7779
write!(f, "{tab}],{newline}}}{newline}")?;
78-
7980
Ok(())
8081
}
8182
}
@@ -89,46 +90,87 @@ struct ErrorDetails {
8990
}
9091

9192
impl ErrorDetails {
92-
fn new(headers: &HashMap<String, String>, body: &[u8]) -> Self {
93-
let mut code = get_error_code_from_header(headers);
94-
code = code.or_else(|| get_error_code_from_body(body));
95-
let message = get_error_message_from_body(body);
96-
Self { code, message }
93+
fn new(headers: &Headers, body: &[u8]) -> Self {
94+
let header_err_code = get_error_code_from_header(headers);
95+
let content_type = headers.get_optional_str(&headers::CONTENT_TYPE);
96+
let (body_err_code, body_err_message) =
97+
get_error_code_message_from_body(body, content_type);
98+
99+
let code = header_err_code.or(body_err_code);
100+
Self {
101+
code,
102+
message: body_err_message,
103+
}
97104
}
98105
}
99106

100107
/// Gets the error code if it's present in the headers
101108
///
102109
/// For more info, see [here](https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#handling-errors)
103-
fn get_error_code_from_header(headers: &HashMap<String, String>) -> Option<String> {
104-
headers.get(headers::ERROR_CODE.as_str()).cloned()
110+
pub(crate) fn get_error_code_from_header(headers: &Headers) -> Option<String> {
111+
headers.get_optional_string(&headers::ERROR_CODE)
105112
}
106113

107114
#[derive(Deserialize)]
108115
struct NestedError {
116+
#[serde(alias = "Message")]
109117
message: Option<String>,
118+
#[serde(alias = "Code")]
110119
code: Option<String>,
111120
}
112121

122+
/// Error from a response body, aliases are set because XML responses follow different case-ing
113123
#[derive(Deserialize)]
114124
struct ErrorBody {
125+
#[serde(alias = "Error")]
115126
error: Option<NestedError>,
127+
#[serde(alias = "Message")]
116128
message: Option<String>,
129+
#[serde(alias = "Code")]
117130
code: Option<String>,
118131
}
119132

120-
/// Gets the error code if it's present in the body
121-
///
122-
/// For more info, see [here](https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#handling-errors)
123-
pub(crate) fn get_error_code_from_body(body: &[u8]) -> Option<String> {
124-
let decoded: ErrorBody = from_json(body).ok()?;
125-
decoded.error.and_then(|e| e.code).or(decoded.code)
133+
impl ErrorBody {
134+
/// Deconstructs self into error (code, message)
135+
///
136+
/// The nested errors fields take precedence over those in the root of the structure
137+
fn into_code_message(self) -> (Option<String>, Option<String>) {
138+
let (nested_code, nested_message) = self
139+
.error
140+
.map(|nested_error| (nested_error.code, nested_error.message))
141+
.unwrap_or((None, None));
142+
(nested_code.or(self.code), nested_message.or(self.message))
143+
}
126144
}
127145

128-
/// Gets the error message if it's present in the body
146+
/// Gets the error code and message from the body based on the specified content_type
147+
/// Support for xml decoding is dependent on the 'xml' feature flag
129148
///
149+
/// Assumes JSON if unspecified/inconclusive to maintain old behaviour
150+
/// [#1275](https://github.com/Azure/azure-sdk-for-rust/issues/1275)
130151
/// For more info, see [here](https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#handling-errors)
131-
pub(crate) fn get_error_message_from_body(body: &[u8]) -> Option<String> {
132-
let decoded: ErrorBody = from_json(body).ok()?;
133-
decoded.error.and_then(|e| e.message).or(decoded.message)
152+
pub(crate) fn get_error_code_message_from_body(
153+
body: &[u8],
154+
content_type: Option<&str>,
155+
) -> (Option<String>, Option<String>) {
156+
let err_body: Option<ErrorBody> = if content_type
157+
.is_some_and(|ctype| ctype == content_type::APPLICATION_XML.as_str())
158+
{
159+
#[cfg(feature = "xml")]
160+
{
161+
crate::xml::read_xml(body).ok()
162+
}
163+
#[cfg(not(feature = "xml"))]
164+
{
165+
tracing::warn!("encountered XML response but the 'xml' feature flag was not specified");
166+
None
167+
}
168+
} else {
169+
// keep old default of assuming JSON
170+
from_json(body).ok()
171+
};
172+
173+
err_body
174+
.map(ErrorBody::into_code_message)
175+
.unwrap_or((None, None))
134176
}

sdk/core/src/error/mod.rs

+35-5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ use std::borrow::Cow;
33
use std::fmt::{Debug, Display};
44
mod http_error;
55
mod macros;
6+
use crate::headers::{self, Headers};
67
pub use http_error::HttpError;
78

9+
use self::http_error::get_error_code_from_header;
10+
811
/// A convenience alias for `Result` where the error type is hard coded to `Error`
912
pub type Result<T> = std::result::Result<T, Error>;
1013

@@ -41,9 +44,21 @@ impl ErrorKind {
4144
Self::HttpResponse { status, error_code }
4245
}
4346

44-
pub fn http_response_from_body(status: StatusCode, body: &[u8]) -> Self {
45-
let error_code = http_error::get_error_code_from_body(body);
46-
Self::HttpResponse { status, error_code }
47+
/// Constructs an [`ErrorKind::HttpResponse`] with given status code. The error code is taken from
48+
/// the header [`headers::ERROR_CODE`] if present; otherwise, it is taken from the body.
49+
pub fn http_response_from_parts(status: StatusCode, headers: &Headers, body: &[u8]) -> Self {
50+
if let Some(header_err_code) = get_error_code_from_header(headers) {
51+
Self::HttpResponse {
52+
status,
53+
error_code: Some(header_err_code),
54+
}
55+
} else {
56+
let (error_code, _) = http_error::get_error_code_message_from_body(
57+
body,
58+
headers.get_optional_str(&headers::CONTENT_TYPE),
59+
);
60+
Self::HttpResponse { status, error_code }
61+
}
4762
}
4863
}
4964

@@ -469,7 +484,8 @@ mod tests {
469484

470485
#[test]
471486
fn matching_against_http_error() {
472-
let kind = ErrorKind::http_response_from_body(StatusCode::ImATeapot, b"{}");
487+
let kind =
488+
ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &Headers::new(), b"{}");
473489

474490
assert!(matches!(
475491
kind,
@@ -479,8 +495,9 @@ mod tests {
479495
}
480496
));
481497

482-
let kind = ErrorKind::http_response_from_body(
498+
let kind = ErrorKind::http_response_from_parts(
483499
StatusCode::ImATeapot,
500+
&Headers::new(),
484501
br#"{"error": {"code":"teepot"}}"#,
485502
);
486503

@@ -492,6 +509,19 @@ mod tests {
492509
}
493510
if error_code.as_deref() == Some("teepot")
494511
));
512+
513+
let mut headers = Headers::new();
514+
headers.insert(headers::ERROR_CODE, "teapot");
515+
let kind = ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &headers, br#"{}"#);
516+
517+
assert!(matches!(
518+
kind,
519+
ErrorKind::HttpResponse {
520+
status: StatusCode::ImATeapot,
521+
error_code
522+
}
523+
if error_code.as_deref() == Some("teapot")
524+
));
495525
}
496526

497527
#[test]

sdk/core/src/http_client/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub trait HttpClient: Send + Sync + std::fmt::Debug {
4949
if status.is_success() {
5050
Ok(crate::CollectedResponse::new(status, headers, body))
5151
} else {
52-
Err(ErrorKind::http_response_from_body(status, &body).into_error())
52+
Err(ErrorKind::http_response_from_parts(status, &headers, &body).into_error())
5353
}
5454
}
5555
}

sdk/identity/src/client_credentials_flow/mod.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,12 @@ pub async fn perform(
7676
);
7777
req.set_body(encoded);
7878
let rsp = http_client.execute_request(&req).await?;
79-
let rsp_status = rsp.status();
80-
let rsp_body = rsp.into_body().collect().await?;
79+
let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
80+
let rsp_body = rsp_body.collect().await?;
8181
if !rsp_status.is_success() {
82-
return Err(ErrorKind::http_response_from_body(rsp_status, &rsp_body).into_error());
82+
return Err(
83+
ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body).into_error(),
84+
);
8385
}
8486
from_json(&rsp_body)
8587
}

sdk/identity/src/device_code_flow/mod.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ where
3636
.finish();
3737

3838
let rsp = post_form(http_client.clone(), url, encoded).await?;
39-
let rsp_status = rsp.status();
40-
let rsp_body = rsp.into_body().collect().await?;
39+
let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
40+
let rsp_body = rsp_body.collect().await?;
4141
if !rsp_status.is_success() {
42-
return Err(ErrorKind::http_response_from_body(rsp_status, &rsp_body).into_error());
42+
return Err(
43+
ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body).into_error(),
44+
);
4345
}
4446
let device_code_response: DeviceCodePhaseOneResponse = from_json(&rsp_body)?;
4547

sdk/identity/src/federated_credentials_flow/mod.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ pub async fn perform(
8383
if rsp_status.is_success() {
8484
rsp.json().await
8585
} else {
86-
let rsp_body = rsp.into_body().collect().await?;
86+
let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
87+
let rsp_body = rsp_body.collect().await?;
8788
let text = std::str::from_utf8(&rsp_body)?;
8889
error!("rsp_body == {:?}", text);
89-
Err(ErrorKind::http_response_from_body(rsp_status, &rsp_body).into_error())
90+
Err(ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body).into_error())
9091
}
9192
}

sdk/identity/src/refresh_token.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@ pub async fn exchange(
5050
if rsp_status.is_success() {
5151
rsp.json().await.map_kind(ErrorKind::Credential)
5252
} else {
53-
let rsp_body = rsp.into_body().collect().await?;
54-
let token_error: RefreshTokenError = from_json(&rsp_body)
55-
.map_err(|_| ErrorKind::http_response_from_body(rsp_status, &rsp_body))?;
53+
let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
54+
let rsp_body = rsp_body.collect().await?;
55+
let token_error: RefreshTokenError = from_json(&rsp_body).map_err(|_| {
56+
ErrorKind::http_response_from_parts(rsp_status, &rsp_headers, &rsp_body)
57+
})?;
5658
Err(Error::new(ErrorKind::Credential, token_error))
5759
}
5860
}

sdk/identity/src/token_credentials/imds_managed_identity_credentials.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ impl ImdsManagedIdentityCredential {
8989
};
9090

9191
let rsp = self.http_client.execute_request(&req).await?;
92-
let rsp_status = rsp.status();
93-
let rsp_body = rsp.into_body().collect().await?;
92+
93+
let (rsp_status, rsp_headers, rsp_body) = rsp.deconstruct();
94+
let rsp_body = rsp_body.collect().await?;
9495

9596
if !rsp_status.is_success() {
9697
match rsp_status {
@@ -107,9 +108,12 @@ impl ImdsManagedIdentityCredential {
107108
))
108109
}
109110
rsp_status => {
110-
return Err(
111-
ErrorKind::http_response_from_body(rsp_status, &rsp_body).into_error()
111+
return Err(ErrorKind::http_response_from_parts(
112+
rsp_status,
113+
&rsp_headers,
114+
&rsp_body,
112115
)
116+
.into_error())
113117
}
114118
}
115119
}

0 commit comments

Comments
 (0)