1use std::{borrow::Cow, fmt};
6
7use bytes::BufMut;
8use ruma_common::{
9 api::{error::IntoHttpError, EndpointError, OutgoingResponse},
10 serde::{from_raw_json_value, JsonObject, StringEnum},
11 thirdparty::Medium,
12 OwnedClientSecret, OwnedSessionId, OwnedUserId,
13};
14use serde::{
15 de::{self, DeserializeOwned},
16 Deserialize, Deserializer, Serialize,
17};
18use serde_json::{
19 from_slice as from_json_slice, value::RawValue as RawJsonValue, Value as JsonValue,
20};
21
22use crate::{
23 error::{Error as MatrixError, StandardErrorBody},
24 PrivOwnedStr,
25};
26
27pub mod get_uiaa_fallback_page;
28mod user_serde;
29
/// Information for one authentication stage.
///
/// Serialized untagged: each variant emits its own serialization as-is (for most variants this
/// includes a `type` field added by their serde `tag` attribute). Deserialization is implemented
/// by hand and dispatches on the `type` field.
#[derive(Clone, Serialize)]
#[non_exhaustive]
#[serde(untagged)]
pub enum AuthData {
    /// Password-based authentication (`m.login.password`).
    Password(Password),

    /// reCAPTCHA-based authentication (`m.login.recaptcha`).
    ReCaptcha(ReCaptcha),

    /// Email-based authentication (`m.login.email.identity`).
    EmailIdentity(EmailIdentity),

    /// Phone number / MSISDN-based authentication (`m.login.msisdn`).
    Msisdn(Msisdn),

    /// Dummy authentication (`m.login.dummy`).
    Dummy(Dummy),

    /// Registration token-based authentication (`m.login.registration_token`).
    RegistrationToken(RegistrationToken),

    /// Fallback acknowledgement: carries only a session and no `type` field.
    FallbackAcknowledgement(FallbackAcknowledgement),

    /// Terms of service authentication (`m.login.terms`).
    Terms(Terms),

    // Auth types without a dedicated variant; construct through `AuthData::new`.
    #[doc(hidden)]
    _Custom(CustomAuthData),
}
64
65impl AuthData {
66 pub fn new(
77 auth_type: &str,
78 session: Option<String>,
79 data: JsonObject,
80 ) -> serde_json::Result<Self> {
81 fn deserialize_variant<T: DeserializeOwned>(
82 session: Option<String>,
83 mut obj: JsonObject,
84 ) -> serde_json::Result<T> {
85 if let Some(session) = session {
86 obj.insert("session".into(), session.into());
87 }
88 serde_json::from_value(JsonValue::Object(obj))
89 }
90
91 Ok(match auth_type {
92 "m.login.password" => Self::Password(deserialize_variant(session, data)?),
93 "m.login.recaptcha" => Self::ReCaptcha(deserialize_variant(session, data)?),
94 "m.login.email.identity" => Self::EmailIdentity(deserialize_variant(session, data)?),
95 "m.login.msisdn" => Self::Msisdn(deserialize_variant(session, data)?),
96 "m.login.dummy" => Self::Dummy(deserialize_variant(session, data)?),
97 "m.registration_token" => Self::RegistrationToken(deserialize_variant(session, data)?),
98 "m.login.terms" => Self::Terms(deserialize_variant(session, data)?),
99 _ => {
100 Self::_Custom(CustomAuthData { auth_type: auth_type.into(), session, extra: data })
101 }
102 })
103 }
104
105 pub fn fallback_acknowledgement(session: String) -> Self {
107 Self::FallbackAcknowledgement(FallbackAcknowledgement::new(session))
108 }
109
110 pub fn auth_type(&self) -> Option<AuthType> {
112 match self {
113 Self::Password(_) => Some(AuthType::Password),
114 Self::ReCaptcha(_) => Some(AuthType::ReCaptcha),
115 Self::EmailIdentity(_) => Some(AuthType::EmailIdentity),
116 Self::Msisdn(_) => Some(AuthType::Msisdn),
117 Self::Dummy(_) => Some(AuthType::Dummy),
118 Self::RegistrationToken(_) => Some(AuthType::RegistrationToken),
119 Self::FallbackAcknowledgement(_) => None,
120 Self::Terms(_) => Some(AuthType::Terms),
121 Self::_Custom(c) => Some(AuthType::_Custom(PrivOwnedStr(c.auth_type.as_str().into()))),
122 }
123 }
124
125 pub fn session(&self) -> Option<&str> {
127 match self {
128 Self::Password(x) => x.session.as_deref(),
129 Self::ReCaptcha(x) => x.session.as_deref(),
130 Self::EmailIdentity(x) => x.session.as_deref(),
131 Self::Msisdn(x) => x.session.as_deref(),
132 Self::Dummy(x) => x.session.as_deref(),
133 Self::RegistrationToken(x) => x.session.as_deref(),
134 Self::FallbackAcknowledgement(x) => Some(&x.session),
135 Self::Terms(x) => x.session.as_deref(),
136 Self::_Custom(x) => x.session.as_deref(),
137 }
138 }
139
140 pub fn data(&self) -> Cow<'_, JsonObject> {
148 fn serialize<T: Serialize>(obj: T) -> JsonObject {
149 match serde_json::to_value(obj).expect("auth data serialization to succeed") {
150 JsonValue::Object(obj) => obj,
151 _ => panic!("all auth data variants must serialize to objects"),
152 }
153 }
154
155 match self {
156 Self::Password(x) => Cow::Owned(serialize(Password {
157 identifier: x.identifier.clone(),
158 password: x.password.clone(),
159 session: None,
160 })),
161 Self::ReCaptcha(x) => {
162 Cow::Owned(serialize(ReCaptcha { response: x.response.clone(), session: None }))
163 }
164 Self::EmailIdentity(x) => Cow::Owned(serialize(EmailIdentity {
165 thirdparty_id_creds: x.thirdparty_id_creds.clone(),
166 session: None,
167 })),
168 Self::Msisdn(x) => Cow::Owned(serialize(Msisdn {
169 thirdparty_id_creds: x.thirdparty_id_creds.clone(),
170 session: None,
171 })),
172 Self::RegistrationToken(x) => {
173 Cow::Owned(serialize(RegistrationToken { token: x.token.clone(), session: None }))
174 }
175 Self::Dummy(_) | Self::FallbackAcknowledgement(_) | Self::Terms(_) => {
177 Cow::Owned(JsonObject::default())
178 }
179 Self::_Custom(c) => Cow::Borrowed(&c.extra),
180 }
181 }
182}
183
184impl fmt::Debug for AuthData {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 match self {
188 Self::Password(inner) => inner.fmt(f),
189 Self::ReCaptcha(inner) => inner.fmt(f),
190 Self::EmailIdentity(inner) => inner.fmt(f),
191 Self::Msisdn(inner) => inner.fmt(f),
192 Self::Dummy(inner) => inner.fmt(f),
193 Self::RegistrationToken(inner) => inner.fmt(f),
194 Self::FallbackAcknowledgement(inner) => inner.fmt(f),
195 Self::Terms(inner) => inner.fmt(f),
196 Self::_Custom(inner) => inner.fmt(f),
197 }
198 }
199}
200
201impl<'de> Deserialize<'de> for AuthData {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: Deserializer<'de>,
205 {
206 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
207
208 #[derive(Deserialize)]
209 struct ExtractType<'a> {
210 #[serde(borrow, rename = "type")]
211 auth_type: Option<Cow<'a, str>>,
212 }
213
214 let auth_type = serde_json::from_str::<ExtractType<'_>>(json.get())
215 .map_err(de::Error::custom)?
216 .auth_type;
217
218 match auth_type.as_deref() {
219 Some("m.login.password") => from_raw_json_value(&json).map(Self::Password),
220 Some("m.login.recaptcha") => from_raw_json_value(&json).map(Self::ReCaptcha),
221 Some("m.login.email.identity") => from_raw_json_value(&json).map(Self::EmailIdentity),
222 Some("m.login.msisdn") => from_raw_json_value(&json).map(Self::Msisdn),
223 Some("m.login.dummy") => from_raw_json_value(&json).map(Self::Dummy),
224 Some("m.login.registration_token") => {
225 from_raw_json_value(&json).map(Self::RegistrationToken)
226 }
227 Some("m.login.terms") => from_raw_json_value(&json).map(Self::Terms),
228 None => from_raw_json_value(&json).map(Self::FallbackAcknowledgement),
229 Some(_) => from_raw_json_value(&json).map(Self::_Custom),
230 }
231 }
232}
233
/// The type of an authentication stage.
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
    // NOTE: the declaration order of the variants determines the derived `PartialOrd` / `Ord`.

    /// Password-based authentication (`m.login.password`).
    #[ruma_enum(rename = "m.login.password")]
    Password,

    /// reCAPTCHA-based authentication (`m.login.recaptcha`).
    #[ruma_enum(rename = "m.login.recaptcha")]
    ReCaptcha,

    /// Email-based authentication (`m.login.email.identity`).
    #[ruma_enum(rename = "m.login.email.identity")]
    EmailIdentity,

    /// Phone number / MSISDN-based authentication (`m.login.msisdn`).
    #[ruma_enum(rename = "m.login.msisdn")]
    Msisdn,

    /// Single sign-on (`m.login.sso`).
    ///
    /// Unlike the other known types, `AuthData` has no dedicated variant for this one.
    #[ruma_enum(rename = "m.login.sso")]
    Sso,

    /// Dummy authentication (`m.login.dummy`).
    #[ruma_enum(rename = "m.login.dummy")]
    Dummy,

    /// Registration token-based authentication (`m.login.registration_token`).
    #[ruma_enum(rename = "m.login.registration_token")]
    RegistrationToken,

    /// Terms of service authentication (`m.login.terms`).
    #[ruma_enum(rename = "m.login.terms")]
    Terms,

    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
276
/// Data for the password-based UIAA stage (`m.login.password`).
///
/// The serde `tag` attribute adds `"type": "m.login.password"` when serializing.
#[derive(Clone, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.password")]
pub struct Password {
    /// One of the user's identifiers.
    pub identifier: UserIdentifier,

    /// The user's password. Deliberately left out of this type's `Debug` output.
    pub password: String,

    /// The UIAA session key, if any.
    pub session: Option<String>,
}
295
296impl Password {
297 pub fn new(identifier: UserIdentifier, password: String) -> Self {
299 Self { identifier, password, session: None }
300 }
301}
302
303impl fmt::Debug for Password {
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 let Self { identifier, password: _, session } = self;
306 f.debug_struct("Password")
307 .field("identifier", identifier)
308 .field("session", session)
309 .finish_non_exhaustive()
310 }
311}
312
/// Data for the reCAPTCHA UIAA stage (`m.login.recaptcha`).
///
/// The serde `tag` attribute adds `"type": "m.login.recaptcha"` when serializing.
#[derive(Clone, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.recaptcha")]
pub struct ReCaptcha {
    /// The captcha response. Left out of this type's `Debug` output.
    pub response: String,

    /// The UIAA session key, if any.
    pub session: Option<String>,
}
328
329impl ReCaptcha {
330 pub fn new(response: String) -> Self {
332 Self { response, session: None }
333 }
334}
335
336impl fmt::Debug for ReCaptcha {
337 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338 let Self { response: _, session } = self;
339 f.debug_struct("ReCaptcha").field("session", session).finish_non_exhaustive()
340 }
341}
342
343#[derive(Clone, Debug, Deserialize, Serialize)]
349#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
350#[serde(tag = "type", rename = "m.login.email.identity")]
351pub struct EmailIdentity {
352 #[serde(rename = "threepid_creds")]
354 pub thirdparty_id_creds: ThirdpartyIdCredentials,
355
356 pub session: Option<String>,
358}
359
360#[derive(Clone, Debug, Deserialize, Serialize)]
366#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
367#[serde(tag = "type", rename = "m.login.msisdn")]
368pub struct Msisdn {
369 #[serde(rename = "threepid_creds")]
371 pub thirdparty_id_creds: ThirdpartyIdCredentials,
372
373 pub session: Option<String>,
375}
376
/// Data for the dummy UIAA stage (`m.login.dummy`).
///
/// The serde `tag` attribute adds `"type": "m.login.dummy"` when serializing.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.dummy")]
pub struct Dummy {
    /// The UIAA session key, if any.
    pub session: Option<String>,
}
389
390impl Dummy {
391 pub fn new() -> Self {
393 Self::default()
394 }
395}
396
/// Data for the registration token UIAA stage (`m.login.registration_token`).
///
/// The serde `tag` attribute adds `"type": "m.login.registration_token"` when serializing.
#[derive(Clone, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.registration_token")]
pub struct RegistrationToken {
    /// The registration token. Left out of this type's `Debug` output.
    pub token: String,

    /// The UIAA session key, if any.
    pub session: Option<String>,
}
412
413impl RegistrationToken {
414 pub fn new(token: String) -> Self {
416 Self { token, session: None }
417 }
418}
419
420impl fmt::Debug for RegistrationToken {
421 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422 let Self { token: _, session } = self;
423 f.debug_struct("RegistrationToken").field("session", session).finish_non_exhaustive()
424 }
425}
426
/// Acknowledgement that a fallback UIAA stage was completed.
///
/// Unlike the other auth data types this has no serde `tag`, so it serializes without a `type`
/// field; `AuthData`'s deserializer picks this variant when `type` is absent.
#[derive(Clone, Debug, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct FallbackAcknowledgement {
    /// The UIAA session key (required here, unlike in the other auth data types).
    pub session: String,
}
438
439impl FallbackAcknowledgement {
440 pub fn new(session: String) -> Self {
442 Self { session }
443 }
444}
445
/// Data for the terms of service UIAA stage (`m.login.terms`).
///
/// The serde `tag` attribute adds `"type": "m.login.terms"` when serializing.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.login.terms")]
pub struct Terms {
    /// The UIAA session key, if any.
    pub session: Option<String>,
}
460
461impl Terms {
462 pub fn new() -> Self {
464 Self::default()
465 }
466}
467
/// Authentication data for an auth type that has no dedicated `AuthData` variant.
///
/// `type` and `session` live in their own fields; every other key of the JSON object is
/// collected into `extra` via serde's `flatten`.
#[doc(hidden)]
#[derive(Clone, Deserialize, Serialize)]
#[non_exhaustive]
pub struct CustomAuthData {
    /// The auth type string, (de)serialized as `type`.
    #[serde(rename = "type")]
    auth_type: String,
    /// The UIAA session key, if any.
    session: Option<String>,
    /// All remaining fields of the object. Left out of this type's `Debug` output.
    #[serde(flatten)]
    extra: JsonObject,
}
478
479impl fmt::Debug for CustomAuthData {
480 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
481 let Self { auth_type, session, extra: _ } = self;
482 f.debug_struct("CustomAuthData")
483 .field("auth_type", auth_type)
484 .field("session", session)
485 .finish_non_exhaustive()
486 }
487}
488
/// Identification information for the user.
///
/// NOTE(review): no serde derives here; (de)serialization is presumably implemented in the
/// `user_serde` module, which is not visible in this file — confirm there.
#[derive(Clone, Debug, PartialEq, Eq)]
// Deliberately exhaustive: unknown third-party media go through the hidden
// `_CustomThirdParty` variant instead.
#[allow(clippy::exhaustive_enums)]
pub enum UserIdentifier {
    /// Either a fully qualified user ID or just its localpart (see the `From<OwnedUserId>`
    /// impls, which store the full ID string).
    UserIdOrLocalpart(String),

    /// Email third-party identifier (`Medium::Email`).
    Email {
        /// The email address.
        address: String,
    },

    /// MSISDN (phone number) third-party identifier (`Medium::Msisdn`).
    Msisdn {
        /// The phone number; presumably in MSISDN form — confirm expected formatting.
        number: String,
    },

    /// Phone number given as a country plus a number.
    PhoneNumber {
        /// The country the phone number belongs to; presumably a country code — confirm.
        country: String,

        /// The phone number.
        phone: String,
    },

    // Third-party identifier for any other medium; built by `UserIdentifier::third_party_id`.
    #[doc(hidden)]
    _CustomThirdParty(CustomThirdPartyId),
}
529
530impl UserIdentifier {
531 pub fn third_party_id(medium: Medium, address: String) -> Self {
533 match medium {
534 Medium::Email => Self::Email { address },
535 Medium::Msisdn => Self::Msisdn { number: address },
536 _ => Self::_CustomThirdParty(CustomThirdPartyId { medium, address }),
537 }
538 }
539
540 pub fn as_third_party_id(&self) -> Option<(&Medium, &str)> {
542 match self {
543 Self::Email { address } => Some((&Medium::Email, address)),
544 Self::Msisdn { number } => Some((&Medium::Msisdn, number)),
545 Self::_CustomThirdParty(CustomThirdPartyId { medium, address }) => {
546 Some((medium, address))
547 }
548 _ => None,
549 }
550 }
551}
552
553impl From<OwnedUserId> for UserIdentifier {
554 fn from(id: OwnedUserId) -> Self {
555 Self::UserIdOrLocalpart(id.into())
556 }
557}
558
559impl From<&OwnedUserId> for UserIdentifier {
560 fn from(id: &OwnedUserId) -> Self {
561 Self::UserIdOrLocalpart(id.as_str().to_owned())
562 }
563}
564
/// Third-party identifier whose medium has no dedicated `UserIdentifier` variant.
///
/// Serialize-only; `IncomingCustomThirdPartyId` has the same fields and derives `Deserialize`.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[non_exhaustive]
pub struct CustomThirdPartyId {
    /// The third-party medium.
    medium: Medium,
    /// The address on that medium.
    address: String,
}
572
/// Deserialize-only counterpart of `CustomThirdPartyId`, with the same fields.
///
/// NOTE(review): presumably consumed by the `user_serde` module — confirm there.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[non_exhaustive]
pub struct IncomingCustomThirdPartyId {
    /// The third-party medium.
    medium: Medium,
    /// The address on that medium.
    address: String,
}
580
/// Credentials for third-party identifier based authentication stages.
#[derive(Clone, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct ThirdpartyIdCredentials {
    /// The session identifier.
    pub sid: OwnedSessionId,

    /// The client secret for the session. Left out of this type's `Debug` output.
    pub client_secret: OwnedClientSecret,

    /// The identity server used, if any; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_server: Option<String>,

    /// An access token for the identity server, if any; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_access_token: Option<String>,
}
599
600impl ThirdpartyIdCredentials {
601 pub fn new(sid: OwnedSessionId, client_secret: OwnedClientSecret) -> Self {
603 Self { sid, client_secret, id_server: None, id_access_token: None }
604 }
605}
606
607impl fmt::Debug for ThirdpartyIdCredentials {
608 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
609 let Self { sid, client_secret: _, id_server, id_access_token } = self;
610 f.debug_struct("ThirdpartyIdCredentials")
611 .field("sid", sid)
612 .field("id_server", id_server)
613 .field("id_access_token", id_access_token)
614 .finish_non_exhaustive()
615 }
616}
617
/// Information about available authentication flows and status for User-Interactive
/// Authentication.
///
/// Sent as the body of a `401 Unauthorized` response (see `UiaaResponse`).
#[derive(Clone, Debug, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UiaaInfo {
    /// The offered authentication flows; each is a list of stages. Required when deserializing.
    pub flows: Vec<AuthFlow>,

    /// The stage types already completed. Defaults to empty and is omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completed: Vec<AuthType>,

    /// Additional parameters, kept as unparsed JSON; presumably per-stage parameters — confirm
    /// against the spec. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Box<RawJsonValue>>,

    /// The UIAA session key, if any. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session: Option<String>,

    /// An error for the last authentication attempt, if any.
    ///
    /// Flattened: its fields appear at the top level of the JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub auth_error: Option<StandardErrorBody>,
}
644
645impl UiaaInfo {
646 pub fn new(flows: Vec<AuthFlow>) -> Self {
648 Self { flows, completed: Vec::new(), params: None, session: None, auth_error: None }
649 }
650}
651
/// A list of stage types that together make up one authentication flow.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
// Non-exhaustive unless the `ruma_unstable_exhaustive_types` cfg is set.
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct AuthFlow {
    /// The stages of the flow. Defaults to empty and is omitted from serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stages: Vec<AuthType>,
}
660
661impl AuthFlow {
662 pub fn new(stages: Vec<AuthType>) -> Self {
666 Self { stages }
667 }
668}
669
/// Error type for endpoints that use User-Interactive Authentication.
///
/// Either a UIAA challenge (sent and parsed as a `401 Unauthorized` with a `UiaaInfo` body) or a
/// regular Matrix error.
#[derive(Clone, Debug)]
// Deliberately exhaustive: a response is either a UIAA challenge or a Matrix error.
#[allow(clippy::exhaustive_enums)]
pub enum UiaaResponse {
    /// Information about the authentication flows and current status.
    AuthResponse(UiaaInfo),

    /// Any other Matrix error.
    MatrixError(MatrixError),
}
680
681impl fmt::Display for UiaaResponse {
682 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
683 match self {
684 Self::AuthResponse(_) => write!(f, "User-Interactive Authentication required."),
685 Self::MatrixError(err) => write!(f, "{err}"),
686 }
687 }
688}
689
690impl From<MatrixError> for UiaaResponse {
691 fn from(error: MatrixError) -> Self {
692 Self::MatrixError(error)
693 }
694}
695
696impl EndpointError for UiaaResponse {
697 fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
698 if response.status() == http::StatusCode::UNAUTHORIZED {
699 if let Ok(uiaa_info) = from_json_slice(response.body().as_ref()) {
700 return Self::AuthResponse(uiaa_info);
701 }
702 }
703
704 Self::MatrixError(MatrixError::from_http_response(response))
705 }
706}
707
// `UiaaResponse` implements `Debug` (derived) and `Display`, so the default `Error` methods
// suffice.
impl std::error::Error for UiaaResponse {}
709
710impl OutgoingResponse for UiaaResponse {
711 fn try_into_http_response<T: Default + BufMut>(
712 self,
713 ) -> Result<http::Response<T>, IntoHttpError> {
714 match self {
715 UiaaResponse::AuthResponse(authentication_info) => http::Response::builder()
716 .header(http::header::CONTENT_TYPE, "application/json")
717 .status(http::StatusCode::UNAUTHORIZED)
718 .body(ruma_common::serde::json_to_buf(&authentication_info)?)
719 .map_err(Into::into),
720 UiaaResponse::MatrixError(error) => error.try_into_http_response(),
721 }
722 }
723}