1use std::{error::Error as StdError, fmt, num::ParseIntError, sync::Arc};
6
7use bytes::{BufMut, Bytes};
8use serde_json::{from_slice as from_json_slice, Value as JsonValue};
9use thiserror::Error;
10
11use super::{EndpointError, MatrixVersion, OutgoingResponse};
12
13#[allow(clippy::exhaustive_structs)]
17#[derive(Clone, Debug)]
18pub struct MatrixError {
19    pub status_code: http::StatusCode,
21
22    pub body: MatrixErrorBody,
24}
25
26impl fmt::Display for MatrixError {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        let status_code = self.status_code.as_u16();
29        match &self.body {
30            MatrixErrorBody::Json(json) => write!(f, "[{status_code}] {json}"),
31            MatrixErrorBody::NotJson { .. } => write!(f, "[{status_code}] <non-json bytes>"),
32        }
33    }
34}
35
36impl StdError for MatrixError {}
37
38impl OutgoingResponse for MatrixError {
39    fn try_into_http_response<T: Default + BufMut>(
40        self,
41    ) -> Result<http::Response<T>, IntoHttpError> {
42        http::Response::builder()
43            .header(http::header::CONTENT_TYPE, crate::http_headers::APPLICATION_JSON)
44            .status(self.status_code)
45            .body(match self.body {
46                MatrixErrorBody::Json(json) => crate::serde::json_to_buf(&json)?,
47                MatrixErrorBody::NotJson { .. } => {
48                    return Err(IntoHttpError::Json(serde::ser::Error::custom(
49                        "attempted to serialize MatrixErrorBody::NotJson",
50                    )));
51                }
52            })
53            .map_err(Into::into)
54    }
55}
56
57impl EndpointError for MatrixError {
58    fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
59        let status_code = response.status();
60        let body = MatrixErrorBody::from_bytes(response.body().as_ref());
61        Self { status_code, body }
62    }
63}
64
65#[derive(Clone, Debug)]
67#[allow(clippy::exhaustive_enums)]
68pub enum MatrixErrorBody {
69    Json(JsonValue),
71
72    NotJson {
74        bytes: Bytes,
76
77        deserialization_error: Arc<serde_json::Error>,
79    },
80}
81
82impl MatrixErrorBody {
83    pub fn from_bytes(body_bytes: &[u8]) -> Self {
85        match from_json_slice(body_bytes) {
86            Ok(json) => MatrixErrorBody::Json(json),
87            Err(e) => MatrixErrorBody::NotJson {
88                bytes: Bytes::copy_from_slice(body_bytes),
89                deserialization_error: Arc::new(e),
90            },
91        }
92    }
93}
94
95#[derive(Debug, Error)]
98#[non_exhaustive]
99pub enum IntoHttpError {
100    #[error("failed to add authentication scheme: {0}")]
102    Authentication(Box<dyn std::error::Error + Send + Sync + 'static>),
103
104    #[error(
109        "endpoint was not supported by server-reported versions, \
110         but no unstable path to fall back to was defined"
111    )]
112    NoUnstablePath,
113
114    #[error(
117        "could not create any path variant for endpoint, as it was removed in version {}",
118        .0.as_str().expect("no endpoint was removed in Matrix 1.0")
119    )]
120    EndpointRemoved(MatrixVersion),
121
122    #[error("JSON serialization failed: {0}")]
124    Json(#[from] serde_json::Error),
125
126    #[error("query parameter serialization failed: {0}")]
128    Query(#[from] serde_html_form::ser::Error),
129
130    #[error("header serialization failed: {0}")]
132    Header(#[from] HeaderSerializationError),
133
134    #[error("HTTP request construction failed: {0}")]
136    Http(#[from] http::Error),
137}
138
139impl From<http::header::InvalidHeaderValue> for IntoHttpError {
140    fn from(value: http::header::InvalidHeaderValue) -> Self {
141        Self::Header(value.into())
142    }
143}
144
145#[derive(Debug, Error)]
147#[non_exhaustive]
148pub enum FromHttpRequestError {
149    #[error("deserialization failed: {0}")]
151    Deserialization(DeserializationError),
152
153    #[error("http method mismatch: expected {expected}, received: {received}")]
155    MethodMismatch {
156        expected: http::method::Method,
158        received: http::method::Method,
160    },
161}
162
163impl<T> From<T> for FromHttpRequestError
164where
165    T: Into<DeserializationError>,
166{
167    fn from(err: T) -> Self {
168        Self::Deserialization(err.into())
169    }
170}
171
172#[derive(Debug)]
174#[non_exhaustive]
175pub enum FromHttpResponseError<E> {
176    Deserialization(DeserializationError),
178
179    Server(E),
181}
182
183impl<E> FromHttpResponseError<E> {
184    pub fn map<F>(self, f: impl FnOnce(E) -> F) -> FromHttpResponseError<F> {
187        match self {
188            Self::Deserialization(d) => FromHttpResponseError::Deserialization(d),
189            Self::Server(s) => FromHttpResponseError::Server(f(s)),
190        }
191    }
192}
193
194impl<E, F> FromHttpResponseError<Result<E, F>> {
195    pub fn transpose(self) -> Result<FromHttpResponseError<E>, F> {
197        match self {
198            Self::Deserialization(d) => Ok(FromHttpResponseError::Deserialization(d)),
199            Self::Server(s) => s.map(FromHttpResponseError::Server),
200        }
201    }
202}
203
204impl<E: fmt::Display> fmt::Display for FromHttpResponseError<E> {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        match self {
207            Self::Deserialization(err) => write!(f, "deserialization failed: {err}"),
208            Self::Server(err) => write!(f, "the server returned an error: {err}"),
209        }
210    }
211}
212
213impl<E, T> From<T> for FromHttpResponseError<E>
214where
215    T: Into<DeserializationError>,
216{
217    fn from(err: T) -> Self {
218        Self::Deserialization(err.into())
219    }
220}
221
222impl<E: StdError> StdError for FromHttpResponseError<E> {}
223
224#[derive(Debug, Error)]
227#[non_exhaustive]
228pub enum DeserializationError {
229    #[error(transparent)]
231    Utf8(#[from] std::str::Utf8Error),
232
233    #[error(transparent)]
235    Json(#[from] serde_json::Error),
236
237    #[error(transparent)]
239    Query(#[from] serde_html_form::de::Error),
240
241    #[error(transparent)]
243    Ident(#[from] crate::IdParseError),
244
245    #[error(transparent)]
247    Header(#[from] HeaderDeserializationError),
248
249    #[error(transparent)]
251    MultipartMixed(#[from] MultipartMixedDeserializationError),
252}
253
254impl From<std::convert::Infallible> for DeserializationError {
255    fn from(err: std::convert::Infallible) -> Self {
256        match err {}
257    }
258}
259
260impl From<http::header::ToStrError> for DeserializationError {
261    fn from(err: http::header::ToStrError) -> Self {
262        Self::Header(HeaderDeserializationError::ToStrError(err))
263    }
264}
265
266#[derive(Debug, Error)]
268#[non_exhaustive]
269pub enum HeaderDeserializationError {
270    #[error("{0}")]
272    ToStrError(#[from] http::header::ToStrError),
273
274    #[error("{0}")]
276    ParseIntError(#[from] ParseIntError),
277
278    #[error("failed to parse HTTP date")]
280    InvalidHttpDate,
281
282    #[error("missing header `{0}`")]
284    MissingHeader(String),
285
286    #[error("invalid header: {0}")]
288    InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
289
290    #[error(
292        "The {header} header was received with an unexpected value, \
293         expected {expected}, received {unexpected}"
294    )]
295    InvalidHeaderValue {
296        header: String,
298        expected: String,
300        unexpected: String,
302    },
303
304    #[error(
307        "The `Content-Type` header for a `multipart/mixed` response is missing the `boundary` attribute"
308    )]
309    MissingMultipartBoundary,
310}
311
312#[derive(Debug, Error)]
314#[non_exhaustive]
315pub enum MultipartMixedDeserializationError {
316    #[error(
318        "multipart/mixed response does not have enough body parts, \
319         expected {expected}, found {found}"
320    )]
321    MissingBodyParts {
322        expected: usize,
324        found: usize,
326    },
327
328    #[error("multipart/mixed body part is missing separator between headers and content")]
330    MissingBodyPartInnerSeparator,
331
332    #[error("multipart/mixed body part header is missing separator between name and value")]
334    MissingHeaderSeparator,
335
336    #[error("invalid multipart/mixed header: {0}")]
338    InvalidHeader(Box<dyn std::error::Error + Send + Sync + 'static>),
339}
340
/// Error returned when a version string is not one of the known versions.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[derive(Debug)]
pub struct UnknownVersionError;

impl fmt::Display for UnknownVersionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("version string was unknown")
    }
}

impl StdError for UnknownVersionError {}
353
/// Error returned when the number of path arguments differs from the number expected.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[derive(Debug)]
pub struct IncorrectArgumentCount {
    /// How many path arguments were expected.
    pub expected: usize,

    /// How many path arguments were actually received.
    pub got: usize,
}

impl fmt::Display for IncorrectArgumentCount {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { expected, got } = self;
        write!(f, "incorrect path argument count, expected {expected}, got {got}")
    }
}

impl StdError for IncorrectArgumentCount {}
373
374#[derive(Debug, Error)]
376#[non_exhaustive]
377pub enum HeaderSerializationError {
378    #[error(transparent)]
380    ToHeaderValue(#[from] http::header::InvalidHeaderValue),
381
382    #[error("invalid HTTP date")]
387    InvalidHttpDate,
388}