// ruma_client_api/uiaa/auth_data/data_serde.rs

//! Custom `Serialize` / `Deserialize` implementations for the authentication data types.
2
3use std::borrow::Cow;
4
5use ruma_common::{serde::from_raw_json_value, thirdparty::Medium};
6use serde::{Deserialize, Deserializer, Serialize, de, ser::SerializeStruct};
7use serde_json::value::RawValue as RawJsonValue;
8
9use super::{AuthData, CustomThirdPartyId, UserIdentifier};
10
11impl<'de> Deserialize<'de> for AuthData {
12    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
13    where
14        D: Deserializer<'de>,
15    {
16        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
17
18        #[derive(Deserialize)]
19        struct ExtractType<'a> {
20            #[serde(borrow, rename = "type")]
21            auth_type: Option<Cow<'a, str>>,
22        }
23
24        let auth_type = serde_json::from_str::<ExtractType<'_>>(json.get())
25            .map_err(de::Error::custom)?
26            .auth_type;
27
28        match auth_type.as_deref() {
29            Some("m.login.password") => from_raw_json_value(&json).map(Self::Password),
30            Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha),
31            Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity),
32            Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn),
33            Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy),
34            Some("m.login.registration_token") => {
35                from_raw_json_value(&json).map(Self::RegistrationToken)
36            }
37            Some("m.login.terms") => from_raw_json_value(&json).map(Self::Terms),
38            Some("m.oauth" | "org.matrix.cross_signing_reset") => {
39                from_raw_json_value(&json).map(Self::OAuth)
40            }
41            None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement),
42            Some(_) => from_raw_json_value(&json).map(Self::_Custom),
43        }
44    }
45}
46
47impl Serialize for UserIdentifier {
48    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49    where
50        S: serde::Serializer,
51    {
52        let mut id;
53        match self {
54            Self::UserIdOrLocalpart(user) => {
55                id = serializer.serialize_struct("UserIdentifier", 2)?;
56                id.serialize_field("type", "m.id.user")?;
57                id.serialize_field("user", user)?;
58            }
59            Self::PhoneNumber { country, phone } => {
60                id = serializer.serialize_struct("UserIdentifier", 3)?;
61                id.serialize_field("type", "m.id.phone")?;
62                id.serialize_field("country", country)?;
63                id.serialize_field("phone", phone)?;
64            }
65            Self::Email { address } => {
66                id = serializer.serialize_struct("UserIdentifier", 3)?;
67                id.serialize_field("type", "m.id.thirdparty")?;
68                id.serialize_field("medium", &Medium::Email)?;
69                id.serialize_field("address", address)?;
70            }
71            Self::Msisdn { number } => {
72                id = serializer.serialize_struct("UserIdentifier", 3)?;
73                id.serialize_field("type", "m.id.thirdparty")?;
74                id.serialize_field("medium", &Medium::Msisdn)?;
75                id.serialize_field("address", number)?;
76            }
77            Self::_CustomThirdParty(CustomThirdPartyId { medium, address }) => {
78                id = serializer.serialize_struct("UserIdentifier", 3)?;
79                id.serialize_field("type", "m.id.thirdparty")?;
80                id.serialize_field("medium", &medium)?;
81                id.serialize_field("address", address)?;
82            }
83        }
84        id.end()
85    }
86}
87
88impl<'de> Deserialize<'de> for UserIdentifier {
89    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
90    where
91        D: Deserializer<'de>,
92    {
93        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
94
95        #[derive(Deserialize)]
96        #[serde(tag = "type")]
97        enum ExtractType {
98            #[serde(rename = "m.id.user")]
99            User,
100            #[serde(rename = "m.id.phone")]
101            Phone,
102            #[serde(rename = "m.id.thirdparty")]
103            ThirdParty,
104        }
105
106        #[derive(Deserialize)]
107        struct UserIdOrLocalpart {
108            user: String,
109        }
110
111        #[derive(Deserialize)]
112        struct ThirdPartyId {
113            medium: Medium,
114            address: String,
115        }
116
117        #[derive(Deserialize)]
118        struct PhoneNumber {
119            country: String,
120            phone: String,
121        }
122
123        let id_type = serde_json::from_str::<ExtractType>(json.get()).map_err(de::Error::custom)?;
124
125        match id_type {
126            ExtractType::User => from_raw_json_value(&json)
127                .map(|user_id: UserIdOrLocalpart| Self::UserIdOrLocalpart(user_id.user)),
128            ExtractType::Phone => from_raw_json_value(&json)
129                .map(|nb: PhoneNumber| Self::PhoneNumber { country: nb.country, phone: nb.phone }),
130            ExtractType::ThirdParty => {
131                let ThirdPartyId { medium, address } = from_raw_json_value(&json)?;
132                match medium {
133                    Medium::Email => Ok(Self::Email { address }),
134                    Medium::Msisdn => Ok(Self::Msisdn { number: address }),
135                    _ => Ok(Self::_CustomThirdParty(CustomThirdPartyId { medium, address })),
136                }
137            }
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use serde_json::{from_value as from_json_value, json, to_value as to_json_value};

    use crate::uiaa::UserIdentifier;

    /// Every identifier flavour serializes to its expected tagged JSON object.
    #[test]
    fn serialize() {
        // (identifier, exact JSON it must produce)
        let cases = [
            (
                UserIdentifier::UserIdOrLocalpart("@user:notareal.hs".to_owned()),
                json!({
                    "type": "m.id.user",
                    "user": "@user:notareal.hs",
                }),
            ),
            (
                UserIdentifier::PhoneNumber {
                    country: "33".to_owned(),
                    phone: "0102030405".to_owned(),
                },
                json!({
                    "type": "m.id.phone",
                    "country": "33",
                    "phone": "0102030405",
                }),
            ),
            (
                UserIdentifier::Email { address: "me@myprovider.net".to_owned() },
                json!({
                    "type": "m.id.thirdparty",
                    "medium": "email",
                    "address": "me@myprovider.net",
                }),
            ),
            (
                UserIdentifier::Msisdn { number: "330102030405".to_owned() },
                json!({
                    "type": "m.id.thirdparty",
                    "medium": "msisdn",
                    "address": "330102030405",
                }),
            ),
            (
                UserIdentifier::third_party_id("robot".into(), "01001110".to_owned()),
                json!({
                    "type": "m.id.thirdparty",
                    "medium": "robot",
                    "address": "01001110",
                }),
            ),
        ];

        for (identifier, expected) in cases {
            assert_eq!(to_json_value(identifier).unwrap(), expected);
        }
    }

    /// Every tagged JSON object deserializes to the matching identifier variant.
    #[test]
    fn deserialize() {
        let parsed = from_json_value::<UserIdentifier>(json!({
            "type": "m.id.user",
            "user": "@user:notareal.hs",
        }));
        assert_matches!(parsed, Ok(UserIdentifier::UserIdOrLocalpart(user)));
        assert_eq!(user, "@user:notareal.hs");

        let parsed = from_json_value::<UserIdentifier>(json!({
            "type": "m.id.phone",
            "country": "33",
            "phone": "0102030405",
        }));
        assert_matches!(parsed, Ok(UserIdentifier::PhoneNumber { country, phone }));
        assert_eq!(country, "33");
        assert_eq!(phone, "0102030405");

        let parsed = from_json_value::<UserIdentifier>(json!({
            "type": "m.id.thirdparty",
            "medium": "email",
            "address": "me@myprovider.net",
        }));
        assert_matches!(parsed, Ok(UserIdentifier::Email { address }));
        assert_eq!(address, "me@myprovider.net");

        let parsed = from_json_value::<UserIdentifier>(json!({
            "type": "m.id.thirdparty",
            "medium": "msisdn",
            "address": "330102030405",
        }));
        assert_matches!(parsed, Ok(UserIdentifier::Msisdn { number }));
        assert_eq!(number, "330102030405");

        // A medium with no dedicated variant is kept as a generic third-party ID.
        let custom = from_json_value::<UserIdentifier>(json!({
            "type": "m.id.thirdparty",
            "medium": "robot",
            "address": "01110010",
        }))
        .unwrap();
        let (medium, address) = custom.as_third_party_id().unwrap();
        assert_eq!(medium.as_str(), "robot");
        assert_eq!(address, "01110010");
    }
}