Skip to main content

ruma_client_api/uiaa/auth_data/data_serde.rs

1//! Custom Serialize / Deserialize implementations for the authentication data types.
2
3use std::borrow::Cow;
4
5use ruma_common::{serde::from_raw_json_value, thirdparty::Medium};
6use serde::{Deserialize, Deserializer, Serialize, de};
7use serde_json::value::RawValue as RawJsonValue;
8
9use super::{
10    AuthData, CustomThirdPartyUserIdentifier, EmailUserIdentifier, MsisdnUserIdentifier,
11    UserIdentifier,
12};
13
14impl<'de> Deserialize<'de> for AuthData {
15    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
16    where
17        D: Deserializer<'de>,
18    {
19        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
20
21        #[derive(Deserialize)]
22        struct ExtractType<'a> {
23            #[serde(borrow, rename = "type")]
24            auth_type: Option<Cow<'a, str>>,
25        }
26
27        let auth_type = serde_json::from_str::<ExtractType<'_>>(json.get())
28            .map_err(de::Error::custom)?
29            .auth_type;
30
31        match auth_type.as_deref() {
32            Some("m.login.password") => from_raw_json_value(&json).map(Self::Password),
33            Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha),
34            Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity),
35            Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn),
36            Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy),
37            Some("m.login.registration_token") => {
38                from_raw_json_value(&json).map(Self::RegistrationToken)
39            }
40            Some("m.login.terms") => from_raw_json_value(&json).map(Self::Terms),
41            Some("m.oauth" | "org.matrix.cross_signing_reset") => {
42                from_raw_json_value(&json).map(Self::OAuth)
43            }
44            None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement),
45            Some(_) => from_raw_json_value(&json).map(Self::_Custom),
46        }
47    }
48}
49
50impl<'de> Deserialize<'de> for UserIdentifier {
51    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
52    where
53        D: Deserializer<'de>,
54    {
55        #[derive(Deserialize)]
56        struct ExtractType<'a> {
57            #[serde(borrow, rename = "type")]
58            identifier_type: Cow<'a, str>,
59        }
60
61        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
62        let ExtractType { identifier_type } =
63            serde_json::from_str(json.get()).map_err(de::Error::custom)?;
64
65        match identifier_type.as_ref() {
66            "m.id.user" => from_raw_json_value(&json).map(Self::Matrix),
67            "m.id.phone" => from_raw_json_value(&json).map(Self::PhoneNumber),
68            "m.id.thirdparty" => {
69                let id: CustomThirdPartyUserIdentifier = from_raw_json_value(&json)?;
70                match &id.medium {
71                    Medium::Email => Ok(Self::Email(EmailUserIdentifier { address: id.address })),
72                    Medium::Msisdn => Ok(Self::Msisdn(MsisdnUserIdentifier { number: id.address })),
73                    _ => Ok(Self::_CustomThirdParty(id)),
74                }
75            }
76            _ => from_raw_json_value(&json).map(Self::_Custom),
77        }
78    }
79}
80
81impl Serialize for EmailUserIdentifier {
82    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83    where
84        S: serde::Serializer,
85    {
86        let Self { address } = self;
87
88        CustomThirdPartyUserIdentifier { medium: Medium::Email, address: address.clone() }
89            .serialize(serializer)
90    }
91}
92
93impl<'de> Deserialize<'de> for EmailUserIdentifier {
94    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95    where
96        D: Deserializer<'de>,
97    {
98        let CustomThirdPartyUserIdentifier { medium, address } =
99            CustomThirdPartyUserIdentifier::deserialize(deserializer)?;
100
101        if medium != Medium::Email {
102            return Err(de::Error::invalid_value(
103                de::Unexpected::Str(medium.as_str()),
104                &Medium::Email.as_str(),
105            ));
106        }
107
108        Ok(Self { address })
109    }
110}
111
112impl Serialize for MsisdnUserIdentifier {
113    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
114    where
115        S: serde::Serializer,
116    {
117        let Self { number } = self;
118
119        CustomThirdPartyUserIdentifier { medium: Medium::Msisdn, address: number.clone() }
120            .serialize(serializer)
121    }
122}
123
124impl<'de> Deserialize<'de> for MsisdnUserIdentifier {
125    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126    where
127        D: Deserializer<'de>,
128    {
129        let CustomThirdPartyUserIdentifier { medium, address } =
130            CustomThirdPartyUserIdentifier::deserialize(deserializer)?;
131
132        if medium != Medium::Msisdn {
133            return Err(de::Error::invalid_value(
134                de::Unexpected::Str(medium.as_str()),
135                &Medium::Msisdn.as_str(),
136            ));
137        }
138
139        Ok(Self { number: address })
140    }
141}
142
#[cfg(test)]
mod tests {
    use assert_matches2::assert_let;
    use ruma_common::canonical_json::assert_to_canonical_json_eq;
    use serde_json::{Value as JsonValue, from_value as from_json_value, json};

    use crate::uiaa::{
        EmailUserIdentifier, MatrixUserIdentifier, MsisdnUserIdentifier, PhoneNumberUserIdentifier,
        UserIdentifier,
    };

    /// Each identifier kind serializes to its JSON form; email and msisdn use the
    /// `m.id.thirdparty` shape.
    #[test]
    fn serialize() {
        let matrix =
            UserIdentifier::Matrix(MatrixUserIdentifier::new("@user:notareal.hs".to_owned()));
        assert_to_canonical_json_eq!(
            matrix,
            json!({
                "type": "m.id.user",
                "user": "@user:notareal.hs",
            })
        );

        let phone = UserIdentifier::PhoneNumber(PhoneNumberUserIdentifier::new(
            "33".to_owned(),
            "0102030405".to_owned(),
        ));
        assert_to_canonical_json_eq!(
            phone,
            json!({
                "type": "m.id.phone",
                "country": "33",
                "phone": "0102030405",
            })
        );

        let email =
            UserIdentifier::Email(EmailUserIdentifier::new("me@myprovider.net".to_owned()));
        assert_to_canonical_json_eq!(
            email,
            json!({
                "type": "m.id.thirdparty",
                "medium": "email",
                "address": "me@myprovider.net",
            })
        );

        let msisdn =
            UserIdentifier::Msisdn(MsisdnUserIdentifier::new("330102030405".to_owned()));
        assert_to_canonical_json_eq!(
            msisdn,
            json!({
                "type": "m.id.thirdparty",
                "medium": "msisdn",
                "address": "330102030405",
            })
        );

        // A medium without a dedicated variant still serializes in the third-party shape.
        let robot = UserIdentifier::third_party_id("robot".into(), "01001110".to_owned());
        assert_to_canonical_json_eq!(
            robot,
            json!({
                "type": "m.id.thirdparty",
                "medium": "robot",
                "address": "01001110",
            })
        );
    }

    /// Each JSON form deserializes back into the matching variant, with its fields intact.
    #[test]
    fn deserialize() {
        let input = json!({
            "type": "m.id.user",
            "user": "@user:notareal.hs",
        });
        assert_let!(Ok(UserIdentifier::Matrix(matrix)) = from_json_value(input));
        assert_eq!(matrix.user, "@user:notareal.hs");

        let input = json!({
            "type": "m.id.phone",
            "country": "33",
            "phone": "0102030405",
        });
        assert_let!(
            Ok(UserIdentifier::PhoneNumber(PhoneNumberUserIdentifier { country, phone })) =
                from_json_value(input)
        );
        assert_eq!(country, "33");
        assert_eq!(phone, "0102030405");

        // `m.id.thirdparty` with a known medium is routed to the dedicated variant.
        let input = json!({
            "type": "m.id.thirdparty",
            "medium": "email",
            "address": "me@myprovider.net",
        });
        assert_let!(Ok(UserIdentifier::Email(email)) = from_json_value(input));
        assert_eq!(email.address, "me@myprovider.net");

        let input = json!({
            "type": "m.id.thirdparty",
            "medium": "msisdn",
            "address": "330102030405",
        });
        assert_let!(Ok(UserIdentifier::Msisdn(msisdn)) = from_json_value(input));
        assert_eq!(msisdn.number, "330102030405");

        // An unknown medium is still reachable through the third-party accessor.
        let input = json!({
            "type": "m.id.thirdparty",
            "medium": "robot",
            "address": "01110010",
        });
        let identifier = from_json_value::<UserIdentifier>(input).unwrap();
        let (medium, address) = identifier.as_third_party_id().unwrap();
        assert_eq!(medium.as_str(), "robot");
        assert_eq!(address, "01110010");
    }

    /// An unrecognized `type` keeps its extra fields and serializes back to the same JSON.
    #[test]
    fn custom_identifier_roundtrip() {
        let original = json!({
            "type": "local.dev.identifier",
            "foo": "bar",
        });

        let identifier = from_json_value::<UserIdentifier>(original.clone()).unwrap();
        assert_eq!(identifier.identifier_type(), "local.dev.identifier");

        let data = identifier.custom_identifier_data().unwrap();
        assert_let!(Some(JsonValue::String(foo)) = data.get("foo"));
        assert_eq!(foo, "bar");

        assert_to_canonical_json_eq!(identifier, original);
    }
}