1use std::{error::Error as StdError, fmt, num::ParseIntError, sync::Arc};
6
7use as_variant::as_variant;
8use bytes::{BufMut, Bytes};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value as JsonValue, from_slice as from_json_slice};
11use thiserror::Error;
12
13mod kind;
14mod kind_serde;
15#[cfg(test)]
16mod tests;
17
18pub use self::kind::*;
19use super::{EndpointError, MatrixVersion, OutgoingResponse};
20
/// An HTTP error response: the status code together with the (possibly unparsed) body.
///
/// Marked `non_exhaustive` unless the `ruma_unstable_exhaustive_types` cfg is set, so code
/// outside this crate should build it with [`Error::new`] or [`ErrorBody::into_error`].
#[derive(Clone, Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct Error {
    /// The HTTP status code of the response.
    pub status_code: http::StatusCode,

    /// The response body, classified by how far it could be parsed.
    pub body: ErrorBody,
}
31
32impl Error {
33 pub fn new(status_code: http::StatusCode, body: ErrorBody) -> Self {
37 Self { status_code, body }
38 }
39
40 pub fn error_kind(&self) -> Option<&ErrorKind> {
42 as_variant!(&self.body, ErrorBody::Standard(StandardErrorBody { kind, .. }) => kind)
43 }
44
45 pub fn is_endpoint_not_implemented(&self) -> bool {
53 self.status_code == http::StatusCode::NOT_FOUND
54 && self
55 .error_kind()
56 .is_some_and(|error_kind| matches!(error_kind, ErrorKind::Unrecognized))
57 }
58}
59
60impl fmt::Display for Error {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 let status_code = self.status_code.as_u16();
63 match &self.body {
64 ErrorBody::Standard(StandardErrorBody { kind, message }) => {
65 let errcode = kind.errcode();
66 write!(f, "[{status_code} / {errcode}] {message}")
67 }
68 ErrorBody::Json(json) => write!(f, "[{status_code}] {json}"),
69 ErrorBody::NotJson { .. } => write!(f, "[{status_code}] <non-json bytes>"),
70 }
71 }
72}
73
74impl StdError for Error {}
75
76impl OutgoingResponse for Error {
77 fn try_into_http_response<T: Default + BufMut>(
78 self,
79 ) -> Result<http::Response<T>, IntoHttpError> {
80 let mut builder = http::Response::builder()
81 .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
82 .status(self.status_code);
83
84 if let Some(ErrorKind::LimitExceeded(LimitExceededErrorData {
86 retry_after: Some(retry_after),
87 })) = self.error_kind()
88 {
89 let header_value = http::HeaderValue::try_from(retry_after)?;
90 builder = builder.header(http::header::RETRY_AFTER, header_value);
91 }
92
93 builder
94 .body(match self.body {
95 ErrorBody::Standard(standard_body) => {
96 ruma_common::serde::json_to_buf(&standard_body)?
97 }
98 ErrorBody::Json(json) => ruma_common::serde::json_to_buf(&json)?,
99 ErrorBody::NotJson { .. } => {
100 return Err(IntoHttpError::Json(serde::ser::Error::custom(
101 "attempted to serialize ErrorBody::NotJson",
102 )));
103 }
104 })
105 .map_err(Into::into)
106 }
107}
108
109impl EndpointError for Error {
110 fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
111 let status = response.status();
112
113 let body_bytes = &response.body().as_ref();
114 let error_body: ErrorBody = match from_json_slice::<StandardErrorBody>(body_bytes) {
115 Ok(mut standard_body) => {
116 let headers = response.headers();
117
118 if let ErrorKind::LimitExceeded(LimitExceededErrorData { retry_after }) =
119 &mut standard_body.kind
120 {
121 if let Some(Ok(retry_after_header)) =
124 headers.get(http::header::RETRY_AFTER).map(RetryAfter::try_from)
125 {
126 *retry_after = Some(retry_after_header);
127 }
128 }
129
130 ErrorBody::Standard(standard_body)
131 }
132 Err(_) => match from_json_slice(body_bytes) {
133 Ok(json) => ErrorBody::Json(json),
134 Err(error) => ErrorBody::NotJson {
135 bytes: Bytes::copy_from_slice(body_bytes),
136 deserialization_error: Arc::new(error),
137 },
138 },
139 };
140
141 error_body.into_error(status)
142 }
143}
144
/// The body of an error response, classified by how far it could be parsed.
///
/// The `allow` opts this enum out of `clippy::exhaustive_enums`, meaning it is intentionally
/// left without `#[non_exhaustive]`.
#[derive(Debug, Clone)]
#[allow(clippy::exhaustive_enums)]
pub enum ErrorBody {
    /// A body that deserialized as a [`StandardErrorBody`].
    Standard(StandardErrorBody),

    /// A body that is valid JSON but does not match [`StandardErrorBody`].
    Json(JsonValue),

    /// A body that could not be parsed as JSON.
    ///
    /// Converting an [`Error`] with this body into an HTTP response fails.
    NotJson {
        /// The raw bytes of the body.
        bytes: Bytes,

        /// The error produced when parsing `bytes` as JSON.
        ///
        /// Wrapped in `Arc` so that `ErrorBody` can derive `Clone` (`serde_json::Error` is not
        /// `Clone`).
        deserialization_error: Arc<serde_json::Error>,
    },
}
164
165impl ErrorBody {
166 pub fn into_error(self, status_code: http::StatusCode) -> Error {
170 Error { status_code, body: self }
171 }
172}
173
/// An error body in the standard format: an [`ErrorKind`] plus a message.
///
/// Marked `non_exhaustive` unless the `ruma_unstable_exhaustive_types` cfg is set, so code
/// outside this crate should build it with [`StandardErrorBody::new`].
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct StandardErrorBody {
    /// The kind of error. Its serialized fields are flattened into the top-level JSON object.
    #[serde(flatten)]
    pub kind: ErrorKind,

    /// The error message, (de)serialized under the `error` key.
    #[serde(rename = "error")]
    pub message: String,
}
186
187impl StandardErrorBody {
188 pub fn new(kind: ErrorKind, message: String) -> Self {
190 Self { kind, message }
191 }
192}
193
/// An error that occurred while converting a value into an `http` request or response.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum IntoHttpError {
    /// Adding the authentication scheme failed; wraps the underlying error.
    #[error("failed to add authentication scheme: {0}")]
    Authentication(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// The endpoint is not supported by the server-reported versions and defines no unstable
    /// path to fall back to.
    #[error(
        "endpoint was not supported by server-reported versions, \
         but no unstable path to fall back to was defined"
    )]
    NoUnstablePath,

    /// No path could be built because the endpoint was removed in the contained version.
    ///
    /// NOTE(review): the `Display` impl calls `.expect(...)` on `MatrixVersion::as_str()`, so
    /// formatting this variant panics if `as_str()` returns `None`. Per the expect message,
    /// that is presumably only the case for Matrix 1.0.
    #[error(
        "could not create any path variant for endpoint, as it was removed in version {}",
        .0.as_str().expect("no endpoint was removed in Matrix 1.0")
    )]
    EndpointRemoved(MatrixVersion),

    /// Serializing a value to JSON failed.
    #[error("JSON serialization failed: {0}")]
    Json(#[from] serde_json::Error),

    /// Serializing query parameters failed.
    #[error("query parameter serialization failed: {0}")]
    Query(#[from] serde_html_form::ser::Error),

    /// Serializing a header failed.
    #[error("header serialization failed: {0}")]
    Header(#[from] HeaderSerializationError),

    /// The `http` builder failed.
    ///
    /// This variant is also produced when building responses (via `?`/`Into` on
    /// `http::Error`), even though the message mentions requests.
    #[error("HTTP request construction failed: {0}")]
    Http(#[from] http::Error),
}
237
238impl From<http::header::InvalidHeaderValue> for IntoHttpError {
239 fn from(value: http::header::InvalidHeaderValue) -> Self {
240 Self::Header(value.into())
241 }
242}
243
/// An error that occurred while building a request type from an `http` request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FromHttpRequestError {
    /// Deserializing the request failed.
    #[error("deserialization failed: {0}")]
    Deserialization(DeserializationError),

    /// The request's HTTP method differs from the one the endpoint expects.
    #[error("http method mismatch: expected {expected}, received: {received}")]
    MethodMismatch {
        /// The method the endpoint expects.
        expected: http::method::Method,

        /// The method the request actually used.
        received: http::method::Method,
    },
}
261
262impl<T> From<T> for FromHttpRequestError
263where
264 T: Into<DeserializationError>,
265{
266 fn from(err: T) -> Self {
267 Self::Deserialization(err.into())
268 }
269}
270
/// An error that occurred while building a response type from an `http` response.
#[derive(Debug)]
#[non_exhaustive]
pub enum FromHttpResponseError<E> {
    /// Deserializing the response failed.
    Deserialization(DeserializationError),

    /// The server returned an error response, represented as `E`.
    Server(E),
}
281
282impl<E> FromHttpResponseError<E> {
283 pub fn map<F>(self, f: impl FnOnce(E) -> F) -> FromHttpResponseError<F> {
286 match self {
287 Self::Deserialization(d) => FromHttpResponseError::Deserialization(d),
288 Self::Server(s) => FromHttpResponseError::Server(f(s)),
289 }
290 }
291}
292
293impl<E, F> FromHttpResponseError<Result<E, F>> {
294 pub fn transpose(self) -> Result<FromHttpResponseError<E>, F> {
296 match self {
297 Self::Deserialization(d) => Ok(FromHttpResponseError::Deserialization(d)),
298 Self::Server(s) => s.map(FromHttpResponseError::Server),
299 }
300 }
301}
302
303impl<E: fmt::Display> fmt::Display for FromHttpResponseError<E> {
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 match self {
306 Self::Deserialization(err) => write!(f, "deserialization failed: {err}"),
307 Self::Server(err) => write!(f, "the server returned an error: {err}"),
308 }
309 }
310}
311
312impl<E, T> From<T> for FromHttpResponseError<E>
313where
314 T: Into<DeserializationError>,
315{
316 fn from(err: T) -> Self {
317 Self::Deserialization(err.into())
318 }
319}
320
321impl<E: StdError> StdError for FromHttpResponseError<E> {}
322
/// Extension trait giving access to the [`ErrorKind`] of a server error.
///
/// Implemented in this module for `FromHttpResponseError<Error>`.
pub trait FromHttpResponseErrorExt {
    /// Returns the [`ErrorKind`] of the server error, or `None` if there is none (for example
    /// for deserialization failures).
    fn error_kind(&self) -> Option<&ErrorKind>;
}
329
330impl FromHttpResponseErrorExt for FromHttpResponseError<Error> {
331 fn error_kind(&self) -> Option<&ErrorKind> {
332 as_variant!(self, Self::Server)?.error_kind()
333 }
334}
335
/// An error that occurred while deserializing a request or response.
///
/// Every variant is `#[error(transparent)]`: `Display` and `source()` forward to the wrapped
/// error.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// Decoding bytes as UTF-8 failed.
    #[error(transparent)]
    Utf8(#[from] std::str::Utf8Error),

    /// Deserializing JSON failed.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// Deserializing query parameters failed.
    #[error(transparent)]
    Query(#[from] serde_html_form::de::Error),

    /// Parsing an identifier failed (`crate::IdParseError`).
    #[error(transparent)]
    Ident(#[from] crate::IdParseError),

    /// Deserializing a header failed.
    #[error(transparent)]
    Header(#[from] HeaderDeserializationError),

    /// Deserializing a `multipart/mixed` body failed.
    #[error(transparent)]
    MultipartMixed(#[from] MultipartMixedDeserializationError),
}
365
366impl From<std::convert::Infallible> for DeserializationError {
367 fn from(err: std::convert::Infallible) -> Self {
368 match err {}
369 }
370}
371
372impl From<http::header::ToStrError> for DeserializationError {
373 fn from(err: http::header::ToStrError) -> Self {
374 Self::Header(HeaderDeserializationError::ToStrError(err))
375 }
376}
377
/// An error that occurred while deserializing an HTTP header.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderDeserializationError {
    /// A header value could not be converted to a string.
    #[error("{0}")]
    ToStrError(#[from] http::header::ToStrError),

    /// A header value could not be parsed as an integer.
    #[error("{0}")]
    ParseIntError(#[from] ParseIntError),

    /// A header value could not be parsed as an HTTP date.
    #[error("failed to parse HTTP date")]
    InvalidHttpDate,

    /// A required header is missing; contains the header's name.
    #[error("missing header `{0}`")]
    MissingHeader(String),

    /// A header is invalid; wraps the underlying error.
    #[error("invalid header: {0}")]
    InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// A header has a value other than the expected one.
    #[error(
        "The {header} header was received with an unexpected value, \
         expected {expected}, received {unexpected}"
    )]
    InvalidHeaderValue {
        /// The name of the header.
        header: String,

        /// The value that was expected.
        expected: String,

        /// The value that was received.
        unexpected: String,
    },

    /// The `Content-Type` of a `multipart/mixed` response has no `boundary` attribute.
    #[error(
        "The `Content-Type` header for a `multipart/mixed` response is missing the `boundary` attribute"
    )]
    MissingMultipartBoundary,
}
423
/// An error that occurred while deserializing a `multipart/mixed` body.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum MultipartMixedDeserializationError {
    /// The body has fewer parts than expected.
    #[error(
        "multipart/mixed response does not have enough body parts, \
         expected {expected}, found {found}"
    )]
    MissingBodyParts {
        /// The number of body parts that were expected.
        expected: usize,

        /// The number of body parts that were found.
        found: usize,
    },

    /// A body part has no separator between its headers and its content.
    #[error("multipart/mixed body part is missing separator between headers and content")]
    MissingBodyPartInnerSeparator,

    /// A body part header has no separator between its name and its value.
    #[error("multipart/mixed body part header is missing separator between name and value")]
    MissingHeaderSeparator,

    /// A body part header is invalid; wraps the underlying error.
    #[error("invalid multipart/mixed header: {0}")]
    InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
}
452
/// Error returned when a version string is not recognized.
///
/// Marked `non_exhaustive` unless the `ruma_unstable_exhaustive_types` cfg is set.
#[derive(Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UnknownVersionError;

impl fmt::Display for UnknownVersionError {
    /// Writes a fixed, context-free message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("version string was unknown")
    }
}

impl StdError for UnknownVersionError {}
465
/// Error returned when a path has a different number of arguments than expected.
///
/// Marked `non_exhaustive` unless the `ruma_unstable_exhaustive_types` cfg is set.
#[derive(Debug)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct IncorrectArgumentCount {
    /// The number of arguments that were expected.
    pub expected: usize,

    /// The number of arguments that were received.
    pub got: usize,
}

impl fmt::Display for IncorrectArgumentCount {
    /// Reports both the expected and the received argument counts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { expected, got } = self;
        write!(f, "incorrect path argument count, expected {}, got {}", expected, got)
    }
}

impl StdError for IncorrectArgumentCount {}
487
/// An error that occurred while serializing an HTTP header.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HeaderSerializationError {
    /// Converting a value into an `http::HeaderValue` failed; `Display` and `source()` forward
    /// to the wrapped error.
    #[error(transparent)]
    ToHeaderValue(#[from] http::header::InvalidHeaderValue),

    /// The value to serialize is not a valid HTTP date.
    #[error("invalid HTTP date")]
    InvalidHttpDate,
}