1use std::{borrow::Cow, fmt};
6
7use bytes::BufMut;
8use ruma_common::{
9    api::{error::IntoHttpError, EndpointError, OutgoingResponse},
10    serde::{from_raw_json_value, JsonObject, StringEnum},
11    thirdparty::Medium,
12    OwnedClientSecret, OwnedSessionId, OwnedUserId,
13};
14use serde::{
15    de::{self, DeserializeOwned},
16    Deserialize, Deserializer, Serialize,
17};
18use serde_json::{
19    from_slice as from_json_slice, value::RawValue as RawJsonValue, Value as JsonValue,
20};
21
22use crate::{
23    error::{Error as MatrixError, StandardErrorBody},
24    PrivOwnedStr,
25};
26
27pub mod get_uiaa_fallback_page;
28mod user_serde;
29
/// Information for one authentication stage of a User-Interactive Authentication request.
///
/// Serialization is untagged: each variant is written using its inner type's own
/// serialization. Deserialization is implemented by hand (see the `Deserialize` impl for this
/// type), which dispatches on the object's `type` field.
#[derive(Clone, Serialize)]
#[non_exhaustive]
#[serde(untagged)]
pub enum AuthData {
    /// Password-based authentication (`m.login.password`).
    Password(Password),

    /// ReCaptcha authentication (`m.login.recaptcha`).
    ReCaptcha(ReCaptcha),

    /// Email-based authentication (`m.login.email.identity`).
    EmailIdentity(EmailIdentity),

    /// Phone-number (MSISDN) based authentication (`m.login.msisdn`).
    Msisdn(Msisdn),

    /// Dummy authentication (`m.login.dummy`).
    Dummy(Dummy),

    /// Registration-token authentication (`m.login.registration_token`).
    RegistrationToken(RegistrationToken),

    /// Acknowledgement of a completed fallback stage; this is the variant chosen when the
    /// incoming object has no `type` field.
    FallbackAcknowledgement(FallbackAcknowledgement),

    /// Terms-of-service acceptance (`m.login.terms`).
    Terms(Terms),

    /// Any other authentication type; not constructible directly.
    #[doc(hidden)]
    _Custom(CustomAuthData),
}
64
65impl AuthData {
66    pub fn new(
77        auth_type: &str,
78        session: Option<String>,
79        data: JsonObject,
80    ) -> serde_json::Result<Self> {
81        fn deserialize_variant<T: DeserializeOwned>(
82            session: Option<String>,
83            mut obj: JsonObject,
84        ) -> serde_json::Result<T> {
85            if let Some(session) = session {
86                obj.insert("session".into(), session.into());
87            }
88            serde_json::from_value(JsonValue::Object(obj))
89        }
90
91        Ok(match auth_type {
92            "m.login.password" => Self::Password(deserialize_variant(session, data)?),
93            "m.login.recaptcha" => Self::ReCaptcha(deserialize_variant(session, data)?),
94            "m.login.email.identity" => Self::EmailIdentity(deserialize_variant(session, data)?),
95            "m.login.msisdn" => Self::Msisdn(deserialize_variant(session, data)?),
96            "m.login.dummy" => Self::Dummy(deserialize_variant(session, data)?),
97            "m.registration_token" => Self::RegistrationToken(deserialize_variant(session, data)?),
98            "m.login.terms" => Self::Terms(deserialize_variant(session, data)?),
99            _ => {
100                Self::_Custom(CustomAuthData { auth_type: auth_type.into(), session, extra: data })
101            }
102        })
103    }
104
105    pub fn fallback_acknowledgement(session: String) -> Self {
107        Self::FallbackAcknowledgement(FallbackAcknowledgement::new(session))
108    }
109
110    pub fn auth_type(&self) -> Option<AuthType> {
112        match self {
113            Self::Password(_) => Some(AuthType::Password),
114            Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
115            Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
116            Self::Msisdn(_) => Some(AuthType::Msisdn),
117            Self::Dummy(_) => Some(AuthType::Dummy),
118            Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
119            Self::FallbackAcknowledgement(_) => None,
120            Self::Terms(_) => Some(AuthType::Terms),
121            Self::_Custom(c) => Some(AuthType::_Custom(PrivOwnedStr(c.auth_type.as_str().into()))),
122        }
123    }
124
125    pub fn session(&self) -> Option<&str> {
127        match self {
128            Self::Password(x) => x.session.as_deref(),
129            Self::ReCaptcha(x) => x.session.as_deref(),
130            Self::EmailIdentity(x) => x.session.as_deref(),
131            Self::Msisdn(x) => x.session.as_deref(),
132            Self::Dummy(x) => x.session.as_deref(),
133            Self::RegistrationToken(x) => x.session.as_deref(),
134            Self::FallbackAcknowledgement(x) => Some(&x.session),
135            Self::Terms(x) => x.session.as_deref(),
136            Self::_Custom(x) => x.session.as_deref(),
137        }
138    }
139
140    pub fn data(&self) -> Cow<'_, JsonObject> {
148        fn serialize<T: Serialize>(obj: T) -> JsonObject {
149            match serde_json::to_value(obj).expect("auth data serialization to succeed") {
150                JsonValue::Object(obj) => obj,
151                _ => panic!("all auth data variants must serialize to objects"),
152            }
153        }
154
155        match self {
156            Self::Password(x) => Cow::Owned(serialize(Password {
157                identifier: x.identifier.clone(),
158                password: x.password.clone(),
159                session: None,
160            })),
161            Self::ReCaptcha(x) => {
162                Cow::Owned(serialize(ReCaptcha { response: x.response.clone(), session: None }))
163            }
164            Self::EmailIdentity(x) => Cow::Owned(serialize(EmailIdentity {
165                thirdparty_id_creds: x.thirdparty_id_creds.clone(),
166                session: None,
167            })),
168            Self::Msisdn(x) => Cow::Owned(serialize(Msisdn {
169                thirdparty_id_creds: x.thirdparty_id_creds.clone(),
170                session: None,
171            })),
172            Self::RegistrationToken(x) => {
173                Cow::Owned(serialize(RegistrationToken { token: x.token.clone(), session: None }))
174            }
175            Self::Dummy(_) | Self::FallbackAcknowledgement(_) | Self::Terms(_) => {
177                Cow::Owned(JsonObject::default())
178            }
179            Self::_Custom(c) => Cow::Borrowed(&c.extra),
180        }
181    }
182}
183
184impl fmt::Debug for AuthData {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        match self {
188            Self::Password(inner) => inner.fmt(f),
189            Self::ReCaptcha(inner) => inner.fmt(f),
190            Self::EmailIdentity(inner) => inner.fmt(f),
191            Self::Msisdn(inner) => inner.fmt(f),
192            Self::Dummy(inner) => inner.fmt(f),
193            Self::RegistrationToken(inner) => inner.fmt(f),
194            Self::FallbackAcknowledgement(inner) => inner.fmt(f),
195            Self::Terms(inner) => inner.fmt(f),
196            Self::_Custom(inner) => inner.fmt(f),
197        }
198    }
199}
200
201impl<'de> Deserialize<'de> for AuthData {
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: Deserializer<'de>,
205    {
206        let json = Box::<RawJsonValue>::deserialize(deserializer)?;
207
208        #[derive(Deserialize)]
209        struct ExtractType<'a> {
210            #[serde(borrow, rename = "type")]
211            auth_type: Option<Cow<'a, str>>,
212        }
213
214        let auth_type = serde_json::from_str::<ExtractType<'_>>(json.get())
215            .map_err(de::Error::custom)?
216            .auth_type;
217
218        match auth_type.as_deref() {
219            Some("m.login.password") => from_raw_json_value(&json).map(Self::Password),
220            Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha),
221            Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity),
222            Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn),
223            Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy),
224            Some("m.login.registration_token") => {
225                from_raw_json_value(&json).map(Self::RegistrationToken)
226            }
227            Some("m.login.terms") => from_raw_json_value(&json).map(Self::Terms),
228            None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement),
229            Some(_) => from_raw_json_value(&json).map(Self::_Custom),
230        }
231    }
232}
233
/// The type of an authentication stage, as listed in [`AuthFlow`] stages and in
/// `UiaaInfo::completed`.
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
    /// Password-based authentication (`m.login.password`).
    #[ruma_enum(rename = "m.login.password")]
    Password,

    /// ReCaptcha authentication (`m.login.recaptcha`).
    #[ruma_enum(rename = "m.login.recaptcha")]
    ReCaptcha,

    /// Email-based authentication (`m.login.email.identity`).
    #[ruma_enum(rename = "m.login.email.identity")]
    EmailIdentity,

    /// Phone-number (MSISDN) based authentication (`m.login.msisdn`).
    #[ruma_enum(rename = "m.login.msisdn")]
    Msisdn,

    /// Single sign-on authentication (`m.login.sso`). There is no dedicated `AuthData`
    /// variant for this type.
    #[ruma_enum(rename = "m.login.sso")]
    Sso,

    /// Dummy authentication (`m.login.dummy`).
    #[ruma_enum(rename = "m.login.dummy")]
    Dummy,

    /// Registration-token authentication (`m.login.registration_token`).
    #[ruma_enum(rename = "m.login.registration_token")]
    RegistrationToken,

    /// Terms-of-service acceptance (`m.login.terms`).
    #[ruma_enum(rename = "m.login.terms")]
    Terms,

    /// Any other authentication type string.
    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
276
/// Data for the password-based authentication stage (`m.login.password`).
///
/// Its `Debug` output leaves out the password.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.password")]
pub struct Password {
    /// One of the user's identifiers.
    pub identifier: UserIdentifier,

    /// The user's password.
    pub password: String,

    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
295
296impl Password {
297    pub fn new(identifier: UserIdentifier, password: String) -> Self {
299        Self { identifier, password, session: None }
300    }
301}
302
303impl fmt::Debug for Password {
304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305        let Self { identifier, password: _, session } = self;
306        f.debug_struct("Password")
307            .field("identifier", identifier)
308            .field("session", session)
309            .finish_non_exhaustive()
310    }
311}
312
/// Data for the ReCaptcha authentication stage (`m.login.recaptcha`).
///
/// Its `Debug` output leaves out `response`.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.recaptcha")]
pub struct ReCaptcha {
    /// The captcha response produced on the client.
    pub response: String,

    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
328
329impl ReCaptcha {
330    pub fn new(response: String) -> Self {
332        Self { response, session: None }
333    }
334}
335
336impl fmt::Debug for ReCaptcha {
337    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338        let Self { response: _, session } = self;
339        f.debug_struct("ReCaptcha").field("session", session).finish_non_exhaustive()
340    }
341}
342
/// Data for the email-based authentication stage (`m.login.email.identity`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.email.identity")]
pub struct EmailIdentity {
    /// Third-party identifier credentials, serialized under the `threepid_creds` key.
    #[serde(rename = "threepid_creds")]
    pub thirdparty_id_creds: ThirdpartyIdCredentials,

    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
359
/// Data for the phone-number (MSISDN) authentication stage (`m.login.msisdn`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.msisdn")]
pub struct Msisdn {
    /// Third-party identifier credentials, serialized under the `threepid_creds` key.
    #[serde(rename = "threepid_creds")]
    pub thirdparty_id_creds: ThirdpartyIdCredentials,

    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
376
/// Data for the dummy authentication stage (`m.login.dummy`), which carries only a session.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.dummy")]
pub struct Dummy {
    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
389
390impl Dummy {
391    pub fn new() -> Self {
393        Self::default()
394    }
395}
396
/// Data for the registration-token authentication stage (`m.login.registration_token`).
///
/// Its `Debug` output leaves out the token.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.registration_token")]
pub struct RegistrationToken {
    /// The registration token.
    pub token: String,

    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
412
413impl RegistrationToken {
414    pub fn new(token: String) -> Self {
416        Self { token, session: None }
417    }
418}
419
420impl fmt::Debug for RegistrationToken {
421    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422        let Self { token: _, session } = self;
423        f.debug_struct("RegistrationToken").field("session", session).finish_non_exhaustive()
424    }
425}
426
/// Acknowledgement that a fallback authentication stage was completed.
///
/// Unlike the other stage types this struct has no `type` tag; `AuthData` deserialization
/// selects it when the incoming object lacks a `type` field.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct FallbackAcknowledgement {
    /// The session value; required for this variant.
    pub session: String,
}
438
439impl FallbackAcknowledgement {
440    pub fn new(session: String) -> Self {
442        Self { session }
443    }
444}
445
/// Data for the terms-of-service authentication stage (`m.login.terms`), which carries only a
/// session.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.terms")]
pub struct Terms {
    /// The session value given by the homeserver, if any.
    pub session: Option<String>,
}
460
461impl Terms {
462    pub fn new() -> Self {
464        Self::default()
465    }
466}
467
/// Stage data for an authentication type without a dedicated `AuthData` variant.
#[doc(hidden)]
#[derive(Clone, Deserialize, Serialize)]
#[non_exhaustive]
pub struct CustomAuthData {
    /// The raw authentication type string (the JSON `type` field).
    #[serde(rename = "type")]
    auth_type: String,
    /// The session value, if any.
    session: Option<String>,
    /// All remaining fields of the object, flattened into it on (de)serialization.
    #[serde(flatten)]
    extra: JsonObject,
}
478
479impl fmt::Debug for CustomAuthData {
480    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
481        let Self { auth_type, session, extra: _ } = self;
482        f.debug_struct("CustomAuthData")
483            .field("auth_type", auth_type)
484            .field("session", session)
485            .finish_non_exhaustive()
486    }
487}
488
/// An identifier for the user who is authenticating.
///
/// NOTE(review): the (de)serialization of this type is not derived here; presumably it lives
/// in the `user_serde` module — verify there for the wire format.
#[derive(Clone, Debug, PartialEq, Eq)]
// Exhaustive by design: other third-party mediums are routed into `_CustomThirdParty`.
#[allow(clippy::exhaustive_enums)]
pub enum UserIdentifier {
    /// A full Matrix user ID or just its localpart.
    UserIdOrLocalpart(String),

    /// An email-address third-party identifier.
    Email {
        /// The email address.
        address: String,
    },

    /// A phone-number (MSISDN) third-party identifier.
    Msisdn {
        /// The phone number.
        number: String,
    },

    /// A phone number given as a country plus a number.
    PhoneNumber {
        /// The country the `phone` number belongs to (presumably an ISO 3166-1 alpha-2
        /// code — confirm against the serde implementation).
        country: String,

        /// The phone number.
        phone: String,
    },

    /// A third-party identifier with a medium other than email or MSISDN.
    #[doc(hidden)]
    _CustomThirdParty(CustomThirdPartyId),
}
529
530impl UserIdentifier {
531    pub fn third_party_id(medium: Medium, address: String) -> Self {
533        match medium {
534            Medium::Email => Self::Email { address },
535            Medium::Msisdn => Self::Msisdn { number: address },
536            _ => Self::_CustomThirdParty(CustomThirdPartyId { medium, address }),
537        }
538    }
539
540    pub fn as_third_party_id(&self) -> Option<(&Medium, &str)> {
542        match self {
543            Self::Email { address } => Some((&Medium::Email, address)),
544            Self::Msisdn { number } => Some((&Medium::Msisdn, number)),
545            Self::_CustomThirdParty(CustomThirdPartyId { medium, address }) => {
546                Some((medium, address))
547            }
548            _ => None,
549        }
550    }
551}
552
553impl From<OwnedUserId> for UserIdentifier {
554    fn from(id: OwnedUserId) -> Self {
555        Self::UserIdOrLocalpart(id.into())
556    }
557}
558
559impl From<&OwnedUserId> for UserIdentifier {
560    fn from(id: &OwnedUserId) -> Self {
561        Self::UserIdOrLocalpart(id.as_str().to_owned())
562    }
563}
564
/// Third-party identifier whose medium is neither email nor MSISDN (built by
/// `UserIdentifier::third_party_id`); serialize-only.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[non_exhaustive]
pub struct CustomThirdPartyId {
    /// The third-party medium.
    medium: Medium,
    /// The address within that medium.
    address: String,
}
572
/// Deserialize-only counterpart of `CustomThirdPartyId`.
///
/// NOTE(review): no user is visible in this chunk; presumably consumed by `user_serde` —
/// confirm before changing.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[non_exhaustive]
pub struct IncomingCustomThirdPartyId {
    /// The third-party medium.
    medium: Medium,
    /// The address within that medium.
    address: String,
}
580
/// Credentials proving control of a third-party identifier (email or phone).
///
/// Its `Debug` output leaves out `client_secret`.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct ThirdpartyIdCredentials {
    /// The validation session ID.
    pub sid: OwnedSessionId,

    /// The client secret used for the validation session.
    pub client_secret: OwnedClientSecret,

    /// The identity server used, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_server: Option<String>,

    /// An access token for the identity server, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_access_token: Option<String>,
}
599
600impl ThirdpartyIdCredentials {
601    pub fn new(sid: OwnedSessionId, client_secret: OwnedClientSecret) -> Self {
603        Self { sid, client_secret, id_server: None, id_access_token: None }
604    }
605}
606
607impl fmt::Debug for ThirdpartyIdCredentials {
608    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
609        let Self { sid, client_secret: _, id_server, id_access_token } = self;
610        f.debug_struct("ThirdpartyIdCredentials")
611            .field("sid", sid)
612            .field("id_server", id_server)
613            .field("id_access_token", id_access_token)
614            .finish_non_exhaustive()
615    }
616}
617
/// The body of a User-Interactive Authentication challenge sent by the server.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UiaaInfo {
    /// The authentication flows the client may follow.
    pub flows: Vec<AuthFlow>,

    /// Stages already completed; defaults to empty and is omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completed: Vec<AuthType>,

    /// Extra per-stage parameters, kept as raw JSON; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Box<RawJsonValue>>,

    /// The session value for this authentication attempt; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session: Option<String>,

    /// An error from a previous authentication attempt; when present its fields are flattened
    /// into the top-level object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub auth_error: Option<StandardErrorBody>,
}
644
645impl UiaaInfo {
646    pub fn new(flows: Vec<AuthFlow>) -> Self {
648        Self { flows, completed: Vec::new(), params: None, session: None, auth_error: None }
649    }
650}
651
/// One possible sequence of authentication stages.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct AuthFlow {
    /// The stages of this flow, in order; defaults to empty and is omitted from JSON when
    /// empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stages: Vec<AuthType>,
}
660
661impl AuthFlow {
662    pub fn new(stages: Vec<AuthType>) -> Self {
666        Self { stages }
667    }
668}
669
/// Error response type for endpoints that use User-Interactive Authentication.
#[derive(Clone, Debug)]
// Exhaustive by design: a response is either a UIAA challenge or a regular Matrix error.
#[allow(clippy::exhaustive_enums)]
pub enum UiaaResponse {
    /// A UIAA challenge, sent with status 401.
    AuthResponse(UiaaInfo),

    /// Any other Matrix error.
    MatrixError(MatrixError),
}
680
681impl fmt::Display for UiaaResponse {
682    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
683        match self {
684            Self::AuthResponse(_) => write!(f, "User-Interactive Authentication required."),
685            Self::MatrixError(err) => write!(f, "{err}"),
686        }
687    }
688}
689
690impl From<MatrixError> for UiaaResponse {
691    fn from(error: MatrixError) -> Self {
692        Self::MatrixError(error)
693    }
694}
695
696impl EndpointError for UiaaResponse {
697    fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
698        if response.status() == http::StatusCode::UNAUTHORIZED {
699            if let Ok(uiaa_info) = from_json_slice(response.body().as_ref()) {
700                return Self::AuthResponse(uiaa_info);
701            }
702        }
703
704        Self::MatrixError(MatrixError::from_http_response(response))
705    }
706}
707
// Lets `UiaaResponse` be used as a standard error; the empty impl provides no `source`.
impl std::error::Error for UiaaResponse {}
709
710impl OutgoingResponse for UiaaResponse {
711    fn try_into_http_response<T: Default + BufMut>(
712        self,
713    ) -> Result<http::Response<T>, IntoHttpError> {
714        match self {
715            UiaaResponse::AuthResponse(authentication_info) => http::Response::builder()
716                .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
717                .status(http::StatusCode::UNAUTHORIZED)
718                .body(ruma_common::serde::json_to_buf(&authentication_info)?)
719                .map_err(Into::into),
720            UiaaResponse::MatrixError(error) => error.try_into_http_response(),
721        }
722    }
723}