1use std::{error::Error as StdError, fmt, num::ParseIntError, sync::Arc};
6
7use as_variant::as_variant;
8use bytes::{BufMut, Bytes};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value as JsonValue, from_slice as from_json_slice};
11use thiserror::Error;
12
13mod kind;
14mod kind_serde;
15#[cfg(test)]
16mod tests;
17
18pub use self::kind::*;
19use super::{EndpointError, MatrixVersion, OutgoingResponse};
20
/// An error response: an HTTP status code together with an [`ErrorBody`].
///
/// It can be decoded from a received response (`EndpointError::from_http_response`) and encoded
/// back into one (`OutgoingResponse::try_into_http_response`).
///
/// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is enabled.
#[derive(Clone, Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct Error {
    /// The HTTP status code of the response.
    pub status_code: http::StatusCode,

    /// The body of the response.
    pub body: ErrorBody,
}
31
32impl Error {
33 pub fn new(status_code: http::StatusCode, body: ErrorBody) -> Self {
37 Self { status_code, body }
38 }
39
40 pub fn error_kind(&self) -> Option<&ErrorKind> {
43 as_variant!(&self.body, ErrorBody::Standard(StandardErrorBody { kind, .. }) => kind)
44 }
45}
46
47impl fmt::Display for Error {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 let status_code = self.status_code.as_u16();
50 match &self.body {
51 ErrorBody::Standard(StandardErrorBody { kind, message }) => {
52 let errcode = kind.errcode();
53 write!(f, "[{status_code} / {errcode}] {message}")
54 }
55 ErrorBody::Json(json) => write!(f, "[{status_code}] {json}"),
56 ErrorBody::NotJson { .. } => write!(f, "[{status_code}] <non-json bytes>"),
57 }
58 }
59}
60
61impl StdError for Error {}
62
63impl OutgoingResponse for Error {
64 fn try_into_http_response<T: Default + BufMut>(
65 self,
66 ) -> Result<http::Response<T>, IntoHttpError> {
67 let mut builder = http::Response::builder()
68 .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
69 .status(self.status_code);
70
71 if let Some(ErrorKind::LimitExceeded(LimitExceededErrorData {
73 retry_after: Some(retry_after),
74 })) = self.error_kind()
75 {
76 let header_value = http::HeaderValue::try_from(retry_after)?;
77 builder = builder.header(http::header::RETRY_AFTER, header_value);
78 }
79
80 builder
81 .body(match self.body {
82 ErrorBody::Standard(standard_body) => {
83 ruma_common::serde::json_to_buf(&standard_body)?
84 }
85 ErrorBody::Json(json) => ruma_common::serde::json_to_buf(&json)?,
86 ErrorBody::NotJson { .. } => {
87 return Err(IntoHttpError::Json(serde::ser::Error::custom(
88 "attempted to serialize ErrorBody::NotJson",
89 )));
90 }
91 })
92 .map_err(Into::into)
93 }
94}
95
96impl EndpointError for Error {
97 fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
98 let status = response.status();
99
100 let body_bytes = &response.body().as_ref();
101 let error_body: ErrorBody = match from_json_slice::<StandardErrorBody>(body_bytes) {
102 Ok(mut standard_body) => {
103 let headers = response.headers();
104
105 if let ErrorKind::LimitExceeded(LimitExceededErrorData { retry_after }) =
106 &mut standard_body.kind
107 {
108 if let Some(Ok(retry_after_header)) =
111 headers.get(http::header::RETRY_AFTER).map(RetryAfter::try_from)
112 {
113 *retry_after = Some(retry_after_header);
114 }
115 }
116
117 ErrorBody::Standard(standard_body)
118 }
119 Err(_) => match from_json_slice(body_bytes) {
120 Ok(json) => ErrorBody::Json(json),
121 Err(error) => ErrorBody::NotJson {
122 bytes: Bytes::copy_from_slice(body_bytes),
123 deserialization_error: Arc::new(error),
124 },
125 },
126 };
127
128 error_body.into_error(status)
129 }
130}
131
/// The body of an error response, decoded as far as was possible.
#[derive(Debug, Clone)]
// This enum is not marked `non_exhaustive`, so clippy's `exhaustive_enums` lint is silenced.
#[allow(clippy::exhaustive_enums)]
pub enum ErrorBody {
    /// A body in the standard error format: an [`ErrorKind`] plus an error message.
    Standard(StandardErrorBody),

    /// Any other JSON body.
    ///
    /// `EndpointError::from_http_response` produces this when the body is valid JSON but does not
    /// deserialize as a [`StandardErrorBody`].
    Json(JsonValue),

    /// A body that could not be parsed as JSON.
    ///
    /// `OutgoingResponse::try_into_http_response` returns an error for this variant instead of
    /// writing the bytes back out.
    NotJson {
        /// The raw bytes of the body.
        bytes: Bytes,

        /// The error returned when the body was parsed as JSON.
        ///
        /// Wrapped in an [`Arc`], presumably so that `ErrorBody` can derive `Clone`
        /// (`serde_json::Error` does not implement `Clone`).
        deserialization_error: Arc<serde_json::Error>,
    },
}
151
152impl ErrorBody {
153 pub fn into_error(self, status_code: http::StatusCode) -> Error {
157 Error { status_code, body: self }
158 }
159}
160
/// An error body in the standard format: an [`ErrorKind`] plus an error message.
///
/// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is enabled.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct StandardErrorBody {
    /// The kind of error.
    ///
    /// Flattened, so its serialized fields live in the same JSON object as `error`. Which keys it
    /// contributes is decided by `ErrorKind`'s serde implementation, which is not in this block.
    #[serde(flatten)]
    pub kind: ErrorKind,

    /// The error message, (de)serialized under the `error` key.
    #[serde(rename = "error")]
    pub message: String,
}
173
174impl StandardErrorBody {
175 pub fn new(kind: ErrorKind, message: String) -> Self {
177 Self { kind, message }
178 }
179}
180
/// An error that occurs while turning a value into an `http` request or response.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum IntoHttpError {
    /// Adding the authentication scheme failed.
    #[error("failed to add authentication scheme: {0}")]
    Authentication(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// The endpoint is not supported by the versions the server reported, and no unstable path
    /// was defined to fall back to.
    #[error(
        "endpoint was not supported by server-reported versions, \
         but no unstable path to fall back to was defined"
    )]
    NoUnstablePath,

    /// No path variant could be created because the endpoint was removed in the given version.
    ///
    /// NOTE(review): the `Display` impl calls `expect` on `MatrixVersion::as_str`. Formatting
    /// this variant therefore panics if `as_str` returns `None`. The expect message assumes that
    /// only happens for Matrix 1.0, in which no endpoint was removed.
    #[error(
        "could not create any path variant for endpoint, as it was removed in version {}",
        .0.as_str().expect("no endpoint was removed in Matrix 1.0")
    )]
    EndpointRemoved(MatrixVersion),

    /// JSON serialization failed.
    #[error("JSON serialization failed: {0}")]
    Json(#[from] serde_json::Error),

    /// Serializing query parameters failed.
    #[error("query parameter serialization failed: {0}")]
    Query(#[from] serde_html_form::ser::Error),

    /// Serializing a header failed.
    #[error("header serialization failed: {0}")]
    Header(#[from] HeaderSerializationError),

    /// The `http` builder failed.
    ///
    /// Also produced when building a response (`OutgoingResponse::try_into_http_response`),
    /// even though the message mentions requests.
    #[error("HTTP request construction failed: {0}")]
    Http(#[from] http::Error),
}
224
225impl From<http::header::InvalidHeaderValue> for IntoHttpError {
226 fn from(value: http::header::InvalidHeaderValue) -> Self {
227 Self::Header(value.into())
228 }
229}
230
/// An error that occurs while converting an incoming HTTP request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FromHttpRequestError {
    /// Deserializing some part of the request failed.
    #[error("deserialization failed: {0}")]
    Deserialization(DeserializationError),

    /// The request used a different HTTP method than the expected one.
    #[error("http method mismatch: expected {expected}, received: {received}")]
    MethodMismatch {
        /// The HTTP method that was expected.
        expected: http::method::Method,

        /// The HTTP method of the received request.
        received: http::method::Method,
    },
}
248
249impl<T> From<T> for FromHttpRequestError
250where
251 T: Into<DeserializationError>,
252{
253 fn from(err: T) -> Self {
254 Self::Deserialization(err.into())
255 }
256}
257
/// An error that occurs while converting an HTTP response.
///
/// Either the response could not be deserialized, or the server returned an error of type `E`.
#[derive(Debug)]
#[non_exhaustive]
pub enum FromHttpResponseError<E> {
    /// Deserializing the response failed.
    Deserialization(DeserializationError),

    /// The server returned an error.
    Server(E),
}
268
269impl<E> FromHttpResponseError<E> {
270 pub fn map<F>(self, f: impl FnOnce(E) -> F) -> FromHttpResponseError<F> {
273 match self {
274 Self::Deserialization(d) => FromHttpResponseError::Deserialization(d),
275 Self::Server(s) => FromHttpResponseError::Server(f(s)),
276 }
277 }
278}
279
280impl<E, F> FromHttpResponseError<Result<E, F>> {
281 pub fn transpose(self) -> Result<FromHttpResponseError<E>, F> {
283 match self {
284 Self::Deserialization(d) => Ok(FromHttpResponseError::Deserialization(d)),
285 Self::Server(s) => s.map(FromHttpResponseError::Server),
286 }
287 }
288}
289
290impl<E: fmt::Display> fmt::Display for FromHttpResponseError<E> {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 match self {
293 Self::Deserialization(err) => write!(f, "deserialization failed: {err}"),
294 Self::Server(err) => write!(f, "the server returned an error: {err}"),
295 }
296 }
297}
298
299impl<E, T> From<T> for FromHttpResponseError<E>
300where
301 T: Into<DeserializationError>,
302{
303 fn from(err: T) -> Self {
304 Self::Deserialization(err.into())
305 }
306}
307
308impl<E: StdError> StdError for FromHttpResponseError<E> {}
309
310pub trait FromHttpResponseErrorExt {
312 fn error_kind(&self) -> Option<&ErrorKind>;
315}
316
317impl FromHttpResponseErrorExt for FromHttpResponseError<Error> {
318 fn error_kind(&self) -> Option<&ErrorKind> {
319 as_variant!(self, Self::Server)?.error_kind()
320 }
321}
322
/// An error that occurs while deserializing some part of an HTTP message.
///
/// Every variant is `#[error(transparent)]`, so its `Display` and `source` are those of the
/// wrapped error.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The input was not valid UTF-8.
    #[error(transparent)]
    Utf8(#[from] std::str::Utf8Error),

    /// JSON deserialization failed.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// Deserializing query parameters failed.
    #[error(transparent)]
    Query(#[from] serde_html_form::de::Error),

    /// Parsing an identifier failed.
    #[error(transparent)]
    Ident(#[from] crate::IdParseError),

    /// A header could not be deserialized.
    #[error(transparent)]
    Header(#[from] HeaderDeserializationError),

    /// A `multipart/mixed` message could not be deserialized.
    #[error(transparent)]
    MultipartMixed(#[from] MultipartMixedDeserializationError),
}
352
353impl From<std::convert::Infallible> for DeserializationError {
354 fn from(err: std::convert::Infallible) -> Self {
355 match err {}
356 }
357}
358
359impl From<http::header::ToStrError> for DeserializationError {
360 fn from(err: http::header::ToStrError) -> Self {
361 Self::Header(HeaderDeserializationError::ToStrError(err))
362 }
363}
364
/// An error that occurs while deserializing a header.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderDeserializationError {
    /// The header value could not be converted to a string.
    #[error("{0}")]
    ToStrError(#[from] http::header::ToStrError),

    /// The header value could not be parsed as an integer.
    #[error("{0}")]
    ParseIntError(#[from] ParseIntError),

    /// The header value could not be parsed as an HTTP date.
    #[error("failed to parse HTTP date")]
    InvalidHttpDate,

    /// A required header was missing. Holds the header's name.
    #[error("missing header `{0}`")]
    MissingHeader(String),

    /// A header was invalid; holds the underlying error.
    #[error("invalid header: {0}")]
    InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// A header was received with a value other than the expected one.
    #[error(
        "The {header} header was received with an unexpected value, \
         expected {expected}, received {unexpected}"
    )]
    InvalidHeaderValue {
        /// The name of the header.
        header: String,

        /// The value that was expected.
        expected: String,

        /// The value that was actually received.
        unexpected: String,
    },

    /// The `Content-Type` header of a `multipart/mixed` response has no `boundary` attribute.
    #[error(
        "The `Content-Type` header for a `multipart/mixed` response is missing the `boundary` attribute"
    )]
    MissingMultipartBoundary,
}
410
/// An error that occurs while deserializing a `multipart/mixed` response body.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum MultipartMixedDeserializationError {
    /// The response has fewer body parts than expected.
    #[error(
        "multipart/mixed response does not have enough body parts, \
         expected {expected}, found {found}"
    )]
    MissingBodyParts {
        /// The number of body parts that was expected.
        expected: usize,

        /// The number of body parts that was found.
        found: usize,
    },

    /// A body part has no separator between its headers and its content.
    #[error("multipart/mixed body part is missing separator between headers and content")]
    MissingBodyPartInnerSeparator,

    /// A header line of a body part has no separator between its name and its value.
    #[error("multipart/mixed body part header is missing separator between name and value")]
    MissingHeaderSeparator,

    /// A header of a body part was invalid; holds the underlying error.
    #[error("invalid multipart/mixed header: {0}")]
    InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
}
439
/// An error returned when a version string is not recognized.
///
/// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is enabled.
#[derive(Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UnknownVersionError;

impl fmt::Display for UnknownVersionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("version string was unknown")
    }
}

impl StdError for UnknownVersionError {}
452
/// An error returned when a path has a different number of arguments than expected.
///
/// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is enabled.
#[derive(Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct IncorrectArgumentCount {
    /// The number of arguments that was expected.
    pub expected: usize,

    /// The number of arguments that was actually received.
    pub got: usize,
}

impl fmt::Display for IncorrectArgumentCount {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "incorrect path argument count, expected {}, got {}",
            self.expected, self.got
        ))
    }
}

impl StdError for IncorrectArgumentCount {}
474
/// An error that occurs while serializing a header.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderSerializationError {
    /// A value could not be converted into a valid header value.
    #[error(transparent)]
    ToHeaderValue(#[from] http::header::InvalidHeaderValue),

    /// A time could not be written as an HTTP date.
    ///
    /// NOTE(review): presumably raised for times outside the range the HTTP date format can
    /// represent. Confirm against the code that constructs this variant.
    #[error("invalid HTTP date")]
    InvalidHttpDate,
}
489}