ruma_client_api/uiaa/auth_data/data_serde.rs

//! Custom Serialize / Deserialize implementations for the authentication data types.
2
3use std::borrow::Cow;
4
5use ruma_common::{serde::from_raw_json_value, thirdparty::Medium};
6use serde::{Deserialize, Deserializer, Serialize, de, ser::SerializeStruct};
7use serde_json::value::RawValue as RawJsonValue;
8
9use super::{AuthData, CustomThirdPartyId, UserIdentifier};
10
11impl<'de> Deserialize<'de> for AuthData {
12    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
13    where
14        D: Deserializer<'de>,
15    {
16        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
17
18        #[derive(Deserialize)]
19        struct ExtractType<'a> {
20            #[serde(borrow, rename = "type")]
21            auth_type: Option<Cow<'a, str>>,
22        }
23
24        let auth_type = serde_json::from_str::<ExtractType<'_>>(json.get())
25            .map_err(de::Error::custom)?
26            .auth_type;
27
28        match auth_type.as_deref() {
29            Some("m.login.password") => from_raw_json_value(&json).map(Self::Password),
30            Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha),
31            Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity),
32            Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn),
33            Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy),
34            Some("m.login.registration_token") => {
35                from_raw_json_value(&json).map(Self::RegistrationToken)
36            }
37            Some("m.login.terms") => from_raw_json_value(&json).map(Self::Terms),
38            Some("m.oauth" | "org.matrix.cross_signing_reset") => {
39                from_raw_json_value(&json).map(Self::OAuth)
40            }
41            None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement),
42            Some(_) => from_raw_json_value(&json).map(Self::_Custom),
43        }
44    }
45}
46
47impl Serialize for UserIdentifier {
48    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49    where
50        S: serde::Serializer,
51    {
52        let mut id;
53        match self {
54            Self::UserIdOrLocalpart(user) => {
55                id = serializer.serialize_struct("UserIdentifier", 2)?;
56                id.serialize_field("type", "m.id.user")?;
57                id.serialize_field("user", user)?;
58            }
59            Self::PhoneNumber { country, phone } => {
60                id = serializer.serialize_struct("UserIdentifier", 3)?;
61                id.serialize_field("type", "m.id.phone")?;
62                id.serialize_field("country", country)?;
63                id.serialize_field("phone", phone)?;
64            }
65            Self::Email { address } => {
66                id = serializer.serialize_struct("UserIdentifier", 3)?;
67                id.serialize_field("type", "m.id.thirdparty")?;
68                id.serialize_field("medium", &Medium::Email)?;
69                id.serialize_field("address", address)?;
70            }
71            Self::Msisdn { number } => {
72                id = serializer.serialize_struct("UserIdentifier", 3)?;
73                id.serialize_field("type", "m.id.thirdparty")?;
74                id.serialize_field("medium", &Medium::Msisdn)?;
75                id.serialize_field("address", number)?;
76            }
77            Self::_CustomThirdParty(CustomThirdPartyId { medium, address }) => {
78                id = serializer.serialize_struct("UserIdentifier", 3)?;
79                id.serialize_field("type", "m.id.thirdparty")?;
80                id.serialize_field("medium", &medium)?;
81                id.serialize_field("address", address)?;
82            }
83        }
84        id.end()
85    }
86}
87
88impl<'de> Deserialize<'de> for UserIdentifier {
89    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
90    where
91        D: Deserializer<'de>,
92    {
93        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
94
95        #[derive(Deserialize)]
96        #[serde(tag = "type")]
97        enum ExtractType {
98            #[serde(rename = "m.id.user")]
99            User,
100            #[serde(rename = "m.id.phone")]
101            Phone,
102            #[serde(rename = "m.id.thirdparty")]
103            ThirdParty,
104        }
105
106        #[derive(Deserialize)]
107        struct UserIdOrLocalpart {
108            user: String,
109        }
110
111        #[derive(Deserialize)]
112        struct ThirdPartyId {
113            medium: Medium,
114            address: String,
115        }
116
117        #[derive(Deserialize)]
118        struct PhoneNumber {
119            country: String,
120            phone: String,
121        }
122
123        let id_type = serde_json::from_str::<ExtractType>(json.get()).map_err(de::Error::custom)?;
124
125        match id_type {
126            ExtractType::User => from_raw_json_value(&json)
127                .map(|user_id: UserIdOrLocalpart| Self::UserIdOrLocalpart(user_id.user)),
128            ExtractType::Phone => from_raw_json_value(&json)
129                .map(|nb: PhoneNumber| Self::PhoneNumber { country: nb.country, phone: nb.phone }),
130            ExtractType::ThirdParty => {
131                let ThirdPartyId { medium, address } = from_raw_json_value(&json)?;
132                match medium {
133                    Medium::Email => Ok(Self::Email { address }),
134                    Medium::Msisdn => Ok(Self::Msisdn { number: address }),
135                    _ => Ok(Self::_CustomThirdParty(CustomThirdPartyId { medium, address })),
136                }
137            }
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::canonical_json::assert_to_canonical_json_eq;
    use serde_json::{from_value as from_json_value, json};

    use crate::uiaa::UserIdentifier;

    /// Every `UserIdentifier` variant serializes to the flat, `type`-tagged JSON object.
    #[test]
    fn serialize() {
        assert_to_canonical_json_eq!(
            UserIdentifier::UserIdOrLocalpart("@user:notareal.hs".to_owned()),
            json!({
                "type": "m.id.user",
                "user": "@user:notareal.hs",
            })
        );

        assert_to_canonical_json_eq!(
            UserIdentifier::PhoneNumber {
                country: "33".to_owned(),
                phone: "0102030405".to_owned()
            },
            json!({
                "type": "m.id.phone",
                "country": "33",
                "phone": "0102030405",
            })
        );

        assert_to_canonical_json_eq!(
            UserIdentifier::Email { address: "me@myprovider.net".to_owned() },
            json!({
                "type": "m.id.thirdparty",
                "medium": "email",
                "address": "me@myprovider.net",
            })
        );

        assert_to_canonical_json_eq!(
            UserIdentifier::Msisdn { number: "330102030405".to_owned() },
            json!({
                "type": "m.id.thirdparty",
                "medium": "msisdn",
                "address": "330102030405",
            })
        );

        assert_to_canonical_json_eq!(
            UserIdentifier::third_party_id("robot".into(), "01001110".to_owned()),
            json!({
                "type": "m.id.thirdparty",
                "medium": "robot",
                "address": "01001110",
            })
        );
    }

    /// Each well-formed identifier object deserializes into the expected variant;
    /// `m.id.thirdparty` is dispatched on its `medium`.
    #[test]
    fn deserialize() {
        let json = json!({
            "type": "m.id.user",
            "user": "@user:notareal.hs",
        });
        assert_matches!(from_json_value(json), Ok(UserIdentifier::UserIdOrLocalpart(user)));
        assert_eq!(user, "@user:notareal.hs");

        let json = json!({
            "type": "m.id.phone",
            "country": "33",
            "phone": "0102030405",
        });
        assert_matches!(from_json_value(json), Ok(UserIdentifier::PhoneNumber { country, phone }));
        assert_eq!(country, "33");
        assert_eq!(phone, "0102030405");

        let json = json!({
            "type": "m.id.thirdparty",
            "medium": "email",
            "address": "me@myprovider.net",
        });
        assert_matches!(from_json_value(json), Ok(UserIdentifier::Email { address }));
        assert_eq!(address, "me@myprovider.net");

        let json = json!({
            "type": "m.id.thirdparty",
            "medium": "msisdn",
            "address": "330102030405",
        });
        assert_matches!(from_json_value(json), Ok(UserIdentifier::Msisdn { number }));
        assert_eq!(number, "330102030405");

        let json = json!({
            "type": "m.id.thirdparty",
            "medium": "robot",
            "address": "01110010",
        });
        let id = from_json_value::<UserIdentifier>(json).unwrap();
        let (medium, address) = id.as_third_party_id().unwrap();
        assert_eq!(medium.as_str(), "robot");
        assert_eq!(address, "01110010");
    }

    /// Malformed identifiers are rejected rather than silently coerced: an unknown or
    /// missing `type`, or a known `type` lacking one of its required fields.
    #[test]
    fn deserialize_invalid() {
        // Unknown identifier type.
        let json = json!({
            "type": "m.id.unknown",
            "user": "@user:notareal.hs",
        });
        assert!(from_json_value::<UserIdentifier>(json).is_err());

        // Missing identifier type.
        let json = json!({
            "user": "@user:notareal.hs",
        });
        assert!(from_json_value::<UserIdentifier>(json).is_err());

        // `m.id.user` without `user`.
        let json = json!({
            "type": "m.id.user",
        });
        assert!(from_json_value::<UserIdentifier>(json).is_err());

        // `m.id.phone` without `phone`.
        let json = json!({
            "type": "m.id.phone",
            "country": "33",
        });
        assert!(from_json_value::<UserIdentifier>(json).is_err());

        // `m.id.thirdparty` without `address`.
        let json = json!({
            "type": "m.id.thirdparty",
            "medium": "email",
        });
        assert!(from_json_value::<UserIdentifier>(json).is_err());
    }
}