1use std::{borrow::Cow, fmt, time::Duration};
2
3use js_int::UInt;
4use ruma_common::serde::JsonObject;
5use serde::{
6 de::{self, Deserialize, Deserializer, MapAccess, Visitor},
7 ser::{self, Serialize, SerializeMap, Serializer},
8};
9use serde_json::{from_value as from_json_value, map::Entry};
10
11use super::{
12 BadStatusErrorData, CustomErrorKind, ErrorCode, ErrorKind, IncompatibleRoomVersionErrorData,
13 LimitExceededErrorData, ResourceLimitExceededErrorData, RetryAfter, UnknownTokenErrorData,
14 WrongRoomKeysVersionErrorData,
15};
16
/// A key of the JSON object that an [`ErrorKind`] is deserialized from.
///
/// Keys that carry variant-specific data get a dedicated variant; every other
/// key is kept verbatim in [`Field::Other`].
enum Field<'de> {
    /// The `errcode` key, which selects the error kind.
    ErrorCode,
    /// The `soft_logout` key.
    SoftLogout,
    /// The `retry_after_ms` key.
    RetryAfterMs,
    /// The `room_version` key.
    RoomVersion,
    /// The `admin_contact` key.
    AdminContact,
    /// The `status` key.
    Status,
    /// The `body` key.
    Body,
    /// The `current_version` key.
    CurrentVersion,
    /// Any key not listed above, either borrowed from the input or owned.
    Other(Cow<'de, str>),
}
28
29impl<'de> Field<'de> {
30 fn new(s: Cow<'de, str>) -> Field<'de> {
31 match s.as_ref() {
32 "errcode" => Self::ErrorCode,
33 "soft_logout" => Self::SoftLogout,
34 "retry_after_ms" => Self::RetryAfterMs,
35 "room_version" => Self::RoomVersion,
36 "admin_contact" => Self::AdminContact,
37 "status" => Self::Status,
38 "body" => Self::Body,
39 "current_version" => Self::CurrentVersion,
40 _ => Self::Other(s),
41 }
42 }
43}
44
45impl<'de> Deserialize<'de> for Field<'de> {
46 fn deserialize<D>(deserializer: D) -> Result<Field<'de>, D::Error>
47 where
48 D: Deserializer<'de>,
49 {
50 struct FieldVisitor;
51
52 impl<'de> Visitor<'de> for FieldVisitor {
53 type Value = Field<'de>;
54
55 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
56 formatter.write_str("any struct field")
57 }
58
59 fn visit_str<E>(self, value: &str) -> Result<Field<'de>, E>
60 where
61 E: de::Error,
62 {
63 Ok(Field::new(Cow::Owned(value.to_owned())))
64 }
65
66 fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Field<'de>, E>
67 where
68 E: de::Error,
69 {
70 Ok(Field::new(Cow::Borrowed(value)))
71 }
72
73 fn visit_string<E>(self, value: String) -> Result<Field<'de>, E>
74 where
75 E: de::Error,
76 {
77 Ok(Field::new(Cow::Owned(value)))
78 }
79 }
80
81 deserializer.deserialize_identifier(FieldVisitor)
82 }
83}
84
/// Serde visitor that builds an [`ErrorKind`] from a map (JSON object).
///
/// The `errcode` entry selects the variant. A fixed set of other keys carries
/// variant-specific data, and any remaining keys are collected so that a custom
/// error code can keep them.
struct ErrorKindVisitor;

impl<'de> Visitor<'de> for ErrorKindVisitor {
    type Value = ErrorKind;

    /// Describes the expected input for serde's error messages.
    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("enum ErrorKind")
    }

    /// Reads every entry of the map, then builds the variant named by `errcode`.
    ///
    /// # Errors
    ///
    /// Fails when any of the following holds:
    ///
    /// - `errcode` is missing.
    /// - A key that is being recorded appears twice. Keys that are skipped
    ///   because they belong to a different variant are not checked.
    /// - A field required by the selected variant (`room_version`,
    ///   `admin_contact`, `current_version`) is missing.
    /// - A field value cannot be converted to the type the variant expects.
    fn visit_map<V>(self, mut map: V) -> Result<ErrorKind, V::Error>
    where
        V: MapAccess<'de>,
    {
        // Keys may arrive in any order, so `errcode` may not be known yet when a
        // data field is read. Data fields are therefore buffered as untyped JSON
        // values (they are later passed to `from_json_value`) and only converted
        // once the variant is chosen.
        let mut errcode = None;
        let mut soft_logout = None;
        let mut retry_after_ms = None;
        let mut room_version = None;
        let mut admin_contact = None;
        let mut status = None;
        let mut body = None;
        let mut current_version = None;
        // Unrecognized keys; only used when `errcode` is a custom code.
        let mut data = JsonObject::new();

        // `set_field!(name)` reads the next map value into the local `name`.
        // `errcode` is always recorded. Any other known field is recorded only
        // if `errcode` has not been seen yet, or if `errcode` selects the
        // variant that carries the field.
        macro_rules! set_field {
            (errcode) => {
                set_field!(@inner errcode)
            };
            ($field:ident) => {
                match errcode {
                    Some(set_field!(@variant_containing $field)) | None => {
                        set_field!(@inner $field)
                    }
                    // `errcode` already selects a different variant, so this
                    // value can never be used: consume it without storing it.
                    Some(_) => {
                        let _ = map.next_value::<de::IgnoredAny>()?;
                    },
                }
            };
            // The `ErrorCode` whose variant carries each data field.
            (@variant_containing soft_logout) => { ErrorCode::UnknownToken };
            (@variant_containing retry_after_ms) => { ErrorCode::LimitExceeded };
            (@variant_containing room_version) => { ErrorCode::IncompatibleRoomVersion };
            (@variant_containing admin_contact) => { ErrorCode::ResourceLimitExceeded };
            (@variant_containing status) => { ErrorCode::BadStatus };
            (@variant_containing body) => { ErrorCode::BadStatus };
            (@variant_containing current_version) => { ErrorCode::WrongRoomKeysVersion };
            // Stores the next value, rejecting a second occurrence of the key.
            (@inner $field:ident) => {
                {
                    if $field.is_some() {
                        return Err(de::Error::duplicate_field(stringify!($field)));
                    }
                    $field = Some(map.next_value()?);
                }
            };
        }

        while let Some(key) = map.next_key()? {
            match key {
                Field::ErrorCode => set_field!(errcode),
                Field::SoftLogout => set_field!(soft_logout),
                Field::RetryAfterMs => set_field!(retry_after_ms),
                Field::RoomVersion => set_field!(room_version),
                Field::AdminContact => set_field!(admin_contact),
                Field::Status => set_field!(status),
                Field::Body => set_field!(body),
                Field::CurrentVersion => set_field!(current_version),
                // Unrecognized keys are always kept, and duplicates are rejected
                // with a message that mirrors serde's `duplicate_field`.
                Field::Other(other) => match data.entry(other.into_owned()) {
                    Entry::Vacant(v) => {
                        v.insert(map.next_value()?);
                    }
                    Entry::Occupied(o) => {
                        return Err(de::Error::custom(format!("duplicate field `{}`", o.key())));
                    }
                },
            }
        }

        let errcode = errcode.ok_or_else(|| de::Error::missing_field("errcode"))?;

        // Build the variant. Buffered fields that belong to other variants are
        // simply dropped here.
        Ok(match errcode {
            ErrorCode::AppserviceLoginUnsupported => ErrorKind::AppserviceLoginUnsupported,
            ErrorCode::BadAlias => ErrorKind::BadAlias,
            ErrorCode::BadJson => ErrorKind::BadJson,
            ErrorCode::BadState => ErrorKind::BadState,
            // `status` and `body` are both optional. `status` is read as a `u16`
            // and then converted with `try_into`, which can fail.
            ErrorCode::BadStatus => ErrorKind::BadStatus(BadStatusErrorData {
                status: status
                    .map(|s| {
                        from_json_value::<u16>(s)
                            .map_err(de::Error::custom)?
                            .try_into()
                            .map_err(de::Error::custom)
                    })
                    .transpose()?,
                body: body.map(from_json_value).transpose().map_err(de::Error::custom)?,
            }),
            ErrorCode::CannotLeaveServerNoticeRoom => ErrorKind::CannotLeaveServerNoticeRoom,
            ErrorCode::CannotOverwriteMedia => ErrorKind::CannotOverwriteMedia,
            ErrorCode::CaptchaInvalid => ErrorKind::CaptchaInvalid,
            ErrorCode::CaptchaNeeded => ErrorKind::CaptchaNeeded,
            #[cfg(feature = "unstable-msc4306")]
            ErrorCode::ConflictingUnsubscription => ErrorKind::ConflictingUnsubscription,
            ErrorCode::ConnectionFailed => ErrorKind::ConnectionFailed,
            ErrorCode::ConnectionTimeout => ErrorKind::ConnectionTimeout,
            ErrorCode::DuplicateAnnotation => ErrorKind::DuplicateAnnotation,
            ErrorCode::Exclusive => ErrorKind::Exclusive,
            ErrorCode::Forbidden => ErrorKind::Forbidden,
            ErrorCode::GuestAccessForbidden => ErrorKind::GuestAccessForbidden,
            // `room_version` is required for this variant.
            ErrorCode::IncompatibleRoomVersion => {
                ErrorKind::IncompatibleRoomVersion(IncompatibleRoomVersionErrorData {
                    room_version: from_json_value(
                        room_version.ok_or_else(|| de::Error::missing_field("room_version"))?,
                    )
                    .map_err(de::Error::custom)?,
                })
            }
            ErrorCode::InvalidParam => ErrorKind::InvalidParam,
            ErrorCode::InvalidRoomState => ErrorKind::InvalidRoomState,
            ErrorCode::InvalidUsername => ErrorKind::InvalidUsername,
            ErrorCode::InviteBlocked => ErrorKind::InviteBlocked,
            // `retry_after_ms` is optional. When present it is parsed as a `UInt`
            // of milliseconds and becomes a `RetryAfter::Delay`.
            ErrorCode::LimitExceeded => ErrorKind::LimitExceeded(LimitExceededErrorData {
                retry_after: retry_after_ms
                    .map(from_json_value::<UInt>)
                    .transpose()
                    .map_err(de::Error::custom)?
                    .map(Into::into)
                    .map(Duration::from_millis)
                    .map(RetryAfter::Delay),
            }),
            ErrorCode::MissingParam => ErrorKind::MissingParam,
            ErrorCode::MissingToken => ErrorKind::MissingToken,
            ErrorCode::NotFound => ErrorKind::NotFound,
            #[cfg(feature = "unstable-msc4306")]
            ErrorCode::NotInThread => ErrorKind::NotInThread,
            ErrorCode::NotJson => ErrorKind::NotJson,
            ErrorCode::NotYetUploaded => ErrorKind::NotYetUploaded,
            // `admin_contact` is required for this variant.
            ErrorCode::ResourceLimitExceeded => {
                ErrorKind::ResourceLimitExceeded(ResourceLimitExceededErrorData {
                    admin_contact: from_json_value(
                        admin_contact.ok_or_else(|| de::Error::missing_field("admin_contact"))?,
                    )
                    .map_err(de::Error::custom)?,
                })
            }
            ErrorCode::RoomInUse => ErrorKind::RoomInUse,
            ErrorCode::ServerNotTrusted => ErrorKind::ServerNotTrusted,
            ErrorCode::ThreepidAuthFailed => ErrorKind::ThreepidAuthFailed,
            ErrorCode::ThreepidDenied => ErrorKind::ThreepidDenied,
            ErrorCode::ThreepidInUse => ErrorKind::ThreepidInUse,
            ErrorCode::ThreepidMediumNotSupported => ErrorKind::ThreepidMediumNotSupported,
            ErrorCode::ThreepidNotFound => ErrorKind::ThreepidNotFound,
            ErrorCode::TokenIncorrect => ErrorKind::TokenIncorrect,
            ErrorCode::TooLarge => ErrorKind::TooLarge,
            ErrorCode::UnableToAuthorizeJoin => ErrorKind::UnableToAuthorizeJoin,
            ErrorCode::UnableToGrantJoin => ErrorKind::UnableToGrantJoin,
            #[cfg(feature = "unstable-msc3843")]
            ErrorCode::Unactionable => ErrorKind::Unactionable,
            ErrorCode::Unauthorized => ErrorKind::Unauthorized,
            ErrorCode::Unknown => ErrorKind::Unknown,
            #[cfg(feature = "unstable-msc4186")]
            ErrorCode::UnknownPos => ErrorKind::UnknownPos,
            // `soft_logout` is optional and defaults to `false` when absent.
            ErrorCode::UnknownToken => ErrorKind::UnknownToken(UnknownTokenErrorData {
                soft_logout: soft_logout
                    .map(from_json_value)
                    .transpose()
                    .map_err(de::Error::custom)?
                    .unwrap_or_default(),
            }),
            ErrorCode::Unrecognized => ErrorKind::Unrecognized,
            ErrorCode::UnsupportedRoomVersion => ErrorKind::UnsupportedRoomVersion,
            ErrorCode::UrlNotSet => ErrorKind::UrlNotSet,
            ErrorCode::UserDeactivated => ErrorKind::UserDeactivated,
            ErrorCode::UserInUse => ErrorKind::UserInUse,
            ErrorCode::UserLocked => ErrorKind::UserLocked,
            ErrorCode::UserSuspended => ErrorKind::UserSuspended,
            ErrorCode::WeakPassword => ErrorKind::WeakPassword,
            // `current_version` is required for this variant.
            ErrorCode::WrongRoomKeysVersion => {
                ErrorKind::WrongRoomKeysVersion(WrongRoomKeysVersionErrorData {
                    current_version: from_json_value(
                        current_version
                            .ok_or_else(|| de::Error::missing_field("current_version"))?,
                    )
                    .map_err(de::Error::custom)?,
                })
            }
            // Custom codes keep the unrecognized keys collected in `data`.
            // NOTE(review): keys matched by `Field` (e.g. `status`, `body`) never
            // reach `data`, so a custom error code loses them. Confirm this is
            // intended.
            ErrorCode::_Custom(errcode) => {
                ErrorKind::_Custom(CustomErrorKind { errcode: errcode.0.into(), data })
            }
        })
    }
}
275
276impl<'de> Deserialize<'de> for ErrorKind {
277 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
278 where
279 D: Deserializer<'de>,
280 {
281 deserializer.deserialize_map(ErrorKindVisitor)
282 }
283}
284
285impl Serialize for ErrorKind {
286 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
287 where
288 S: Serializer,
289 {
290 let mut st = serializer.serialize_map(None)?;
291 st.serialize_entry("errcode", &self.errcode())?;
292 match self {
293 Self::BadStatus(BadStatusErrorData { status, body }) => {
294 if let Some(status) = status {
295 st.serialize_entry("status", &status.as_u16())?;
296 }
297 if let Some(body) = body {
298 st.serialize_entry("body", body)?;
299 }
300 }
301 Self::IncompatibleRoomVersion(IncompatibleRoomVersionErrorData { room_version }) => {
302 st.serialize_entry("room_version", room_version)?;
303 }
304 Self::LimitExceeded(LimitExceededErrorData {
305 retry_after: Some(RetryAfter::Delay(duration)),
306 }) => {
307 st.serialize_entry(
308 "retry_after_ms",
309 &UInt::try_from(duration.as_millis()).map_err(ser::Error::custom)?,
310 )?;
311 }
312 Self::ResourceLimitExceeded(ResourceLimitExceededErrorData { admin_contact }) => {
313 st.serialize_entry("admin_contact", admin_contact)?;
314 }
315 Self::UnknownToken(UnknownTokenErrorData { soft_logout: true }) | Self::UserLocked => {
316 st.serialize_entry("soft_logout", &true)?;
317 }
318 Self::WrongRoomKeysVersion(WrongRoomKeysVersionErrorData { current_version }) => {
319 st.serialize_entry("current_version", current_version)?;
320 }
321 Self::_Custom(CustomErrorKind { data, .. }) => {
322 for (k, v) in data {
323 st.serialize_entry(k, v)?;
324 }
325 }
326 _ => {}
327 }
328 st.end()
329 }
330}
331
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use ruma_common::room_version_id;
    use serde_json::{
        from_str as from_json_str, from_value as from_json_value, json, to_value as to_json_value,
    };

    use super::{
        ErrorKind, IncompatibleRoomVersionErrorData, LimitExceededErrorData, RetryAfter,
        UnknownTokenErrorData,
    };

    #[test]
    fn deserialize_forbidden() {
        let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_FORBIDDEN" })).unwrap();
        assert_eq!(deserialized, ErrorKind::Forbidden);
    }

    #[test]
    fn deserialize_forbidden_with_extra_fields() {
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_FORBIDDEN",
            "error": "…",
        }))
        .unwrap();

        assert_eq!(deserialized, ErrorKind::Forbidden);
    }

    #[test]
    fn deserialize_incompatible_room_version() {
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_INCOMPATIBLE_ROOM_VERSION",
            "room_version": "7",
        }))
        .unwrap();

        assert_eq!(
            deserialized,
            ErrorKind::IncompatibleRoomVersion(IncompatibleRoomVersionErrorData {
                room_version: room_version_id!("7")
            })
        );
    }

    #[test]
    fn deserialize_incompatible_room_version_without_room_version_fails() {
        let result =
            from_json_value::<ErrorKind>(json!({ "errcode": "M_INCOMPATIBLE_ROOM_VERSION" }));
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_unknown_token_with_soft_logout_before_errcode() {
        // `from_str` keeps the key order of the input, so `soft_logout` really
        // is read before the variant is known and has to be buffered.
        let deserialized: ErrorKind =
            from_json_str(r#"{ "soft_logout": true, "errcode": "M_UNKNOWN_TOKEN" }"#).unwrap();

        assert_eq!(
            deserialized,
            ErrorKind::UnknownToken(UnknownTokenErrorData { soft_logout: true })
        );
    }

    #[test]
    fn deserialize_unknown_token_defaults_soft_logout_to_false() {
        let deserialized: ErrorKind =
            from_json_value(json!({ "errcode": "M_UNKNOWN_TOKEN" })).unwrap();

        assert_eq!(
            deserialized,
            ErrorKind::UnknownToken(UnknownTokenErrorData { soft_logout: false })
        );
    }

    #[test]
    fn deserialize_ignores_fields_of_other_variants() {
        // `errcode` first: the foreign field is skipped without being stored.
        let deserialized: ErrorKind =
            from_json_str(r#"{ "errcode": "M_FORBIDDEN", "retry_after_ms": 1000 }"#).unwrap();
        assert_eq!(deserialized, ErrorKind::Forbidden);

        // `errcode` last: the foreign field is buffered, then dropped.
        let deserialized: ErrorKind =
            from_json_str(r#"{ "retry_after_ms": 1000, "errcode": "M_FORBIDDEN" }"#).unwrap();
        assert_eq!(deserialized, ErrorKind::Forbidden);

        // `M_USER_LOCKED` carries no data even though it serializes `soft_logout`.
        let deserialized: ErrorKind =
            from_json_str(r#"{ "errcode": "M_USER_LOCKED", "soft_logout": true }"#).unwrap();
        assert_eq!(deserialized, ErrorKind::UserLocked);
    }

    #[test]
    fn deserialize_missing_errcode_fails() {
        assert!(from_json_value::<ErrorKind>(json!({ "soft_logout": true })).is_err());
    }

    #[test]
    fn deserialize_duplicate_errcode_fails() {
        let result =
            from_json_str::<ErrorKind>(r#"{ "errcode": "M_FORBIDDEN", "errcode": "M_FORBIDDEN" }"#);
        assert!(result.is_err());
    }

    #[test]
    fn serialize_forbidden() {
        assert_eq!(
            to_json_value(ErrorKind::Forbidden).unwrap(),
            json!({ "errcode": "M_FORBIDDEN" })
        );
    }

    #[test]
    fn serialize_unknown_token() {
        assert_eq!(
            to_json_value(ErrorKind::UnknownToken(UnknownTokenErrorData { soft_logout: true }))
                .unwrap(),
            json!({ "errcode": "M_UNKNOWN_TOKEN", "soft_logout": true })
        );
        // `soft_logout: false` is the default and is therefore omitted.
        assert_eq!(
            to_json_value(ErrorKind::UnknownToken(UnknownTokenErrorData { soft_logout: false }))
                .unwrap(),
            json!({ "errcode": "M_UNKNOWN_TOKEN" })
        );
    }

    #[test]
    fn serialize_user_locked_sets_soft_logout() {
        assert_eq!(
            to_json_value(ErrorKind::UserLocked).unwrap(),
            json!({ "errcode": "M_USER_LOCKED", "soft_logout": true })
        );
    }

    #[test]
    fn limit_exceeded_round_trip() {
        let kind = ErrorKind::LimitExceeded(LimitExceededErrorData {
            retry_after: Some(RetryAfter::Delay(Duration::from_millis(1500))),
        });
        let expected = json!({ "errcode": "M_LIMIT_EXCEEDED", "retry_after_ms": 1500 });

        assert_eq!(to_json_value(&kind).unwrap(), expected);
        assert_eq!(from_json_value::<ErrorKind>(expected).unwrap(), kind);
    }
}