ruma_client_api/
uiaa.rs

1//! Module for [User-Interactive Authentication API][uiaa] types.
2//!
3//! [uiaa]: https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api
4
5use std::{borrow::Cow, fmt, marker::PhantomData};
6
7use bytes::BufMut;
8use ruma_common::{
9    api::{EndpointError, OutgoingResponse, error::IntoHttpError},
10    serde::StringEnum,
11};
12use serde::{Deserialize, Deserializer, Serialize, de};
13use serde_json::{from_slice as from_json_slice, value::RawValue as RawJsonValue};
14
15use crate::{
16    PrivOwnedStr,
17    error::{Error as MatrixError, StandardErrorBody},
18};
19
20mod auth_data;
21mod auth_params;
22pub mod get_uiaa_fallback_page;
23
24pub use self::{auth_data::*, auth_params::*};
25
26/// The type of an authentication stage.
27#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
28#[derive(Clone, StringEnum)]
29#[non_exhaustive]
30pub enum AuthType {
31    /// Password-based authentication (`m.login.password`).
32    #[ruma_enum(rename = "m.login.password")]
33    Password,
34
35    /// Google ReCaptcha 2.0 authentication (`m.login.recaptcha`).
36    #[ruma_enum(rename = "m.login.recaptcha")]
37    ReCaptcha,
38
39    /// Email-based authentication (`m.login.email.identity`).
40    #[ruma_enum(rename = "m.login.email.identity")]
41    EmailIdentity,
42
43    /// Phone number-based authentication (`m.login.msisdn`).
44    #[ruma_enum(rename = "m.login.msisdn")]
45    Msisdn,
46
47    /// SSO-based authentication (`m.login.sso`).
48    #[ruma_enum(rename = "m.login.sso")]
49    Sso,
50
51    /// Dummy authentication (`m.login.dummy`).
52    #[ruma_enum(rename = "m.login.dummy")]
53    Dummy,
54
55    /// Registration token-based authentication (`m.login.registration_token`).
56    #[ruma_enum(rename = "m.login.registration_token")]
57    RegistrationToken,
58
59    /// Terms of service (`m.login.terms`).
60    ///
61    /// This type is only valid during account registration.
62    #[ruma_enum(rename = "m.login.terms")]
63    Terms,
64
65    /// OAuth 2.0 (`m.oauth`).
66    ///
67    /// This type is only valid with the cross-signing keys upload endpoint, after logging in with
68    /// the OAuth 2.0 API.
69    #[ruma_enum(rename = "m.oauth", alias = "org.matrix.cross_signing_reset")]
70    OAuth,
71
72    #[doc(hidden)]
73    _Custom(PrivOwnedStr),
74}
75
76/// Information about available authentication flows and status for User-Interactive Authenticiation
77/// API.
78#[derive(Clone, Debug, Deserialize, Serialize)]
79#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
80pub struct UiaaInfo {
81    /// List of authentication flows available for this endpoint.
82    pub flows: Vec<AuthFlow>,
83
84    /// List of stages in the current flow completed by the client.
85    #[serde(default, skip_serializing_if = "Vec::is_empty")]
86    pub completed: Vec<AuthType>,
87
88    /// Authentication parameters required for the client to complete authentication.
89    ///
90    /// To create a `Box<RawJsonValue>`, use `serde_json::value::to_raw_value`.
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub params: Option<Box<RawJsonValue>>,
93
94    /// Session key for client to use to complete authentication.
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub session: Option<String>,
97
98    /// Authentication-related errors for previous request returned by homeserver.
99    #[serde(flatten, skip_serializing_if = "Option::is_none")]
100    pub auth_error: Option<StandardErrorBody>,
101}
102
103impl UiaaInfo {
104    /// Creates a new `UiaaInfo` with the given flows.
105    pub fn new(flows: Vec<AuthFlow>) -> Self {
106        Self { flows, completed: Vec::new(), params: None, session: None, auth_error: None }
107    }
108
109    /// Get the parameters for the given [`AuthType`], if they are available in the `params` object.
110    ///
111    /// Returns `Ok(Some(_))` if the parameters for the authentication type were found and the
112    /// deserialization worked, `Ok(None)` if the parameters for the authentication type were not
113    /// found, and `Err(_)` if the parameters for the authentication type were found but their
114    /// deserialization failed.
115    ///
116    /// # Example
117    ///
118    /// ```
119    /// # use ruma_client_api::uiaa::UiaaInfo;
120    /// use ruma_client_api::uiaa::{AuthType, LoginTermsParams};
121    ///
122    /// # let uiaa_info = UiaaInfo::new(Vec::new());
123    /// let login_terms_params = uiaa_info.params::<LoginTermsParams>(&AuthType::Terms)?;
124    /// # Ok::<(), serde_json::Error>(())
125    /// ```
126    pub fn params<'a, T: Deserialize<'a>>(
127        &'a self,
128        auth_type: &AuthType,
129    ) -> Result<Option<T>, serde_json::Error> {
130        struct AuthTypeVisitor<'b, T> {
131            auth_type: &'b AuthType,
132            _phantom: PhantomData<T>,
133        }
134
135        impl<'de, T> de::Visitor<'de> for AuthTypeVisitor<'_, T>
136        where
137            T: Deserialize<'de>,
138        {
139            type Value = Option<T>;
140
141            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
142                formatter.write_str("a key-value map")
143            }
144
145            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
146            where
147                A: de::MapAccess<'de>,
148            {
149                let mut params = None;
150
151                while let Some(key) = map.next_key::<Cow<'de, str>>()? {
152                    if AuthType::from(key) == *self.auth_type {
153                        params = Some(map.next_value()?);
154                    } else {
155                        map.next_value::<de::IgnoredAny>()?;
156                    }
157                }
158
159                Ok(params)
160            }
161        }
162
163        let Some(params) = &self.params else {
164            return Ok(None);
165        };
166
167        let mut deserializer = serde_json::Deserializer::from_str(params.get());
168        deserializer.deserialize_map(AuthTypeVisitor { auth_type, _phantom: PhantomData })
169    }
170}
171
172/// Description of steps required to authenticate via the User-Interactive Authentication API.
173#[derive(Clone, Debug, Default, Deserialize, Serialize)]
174#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
175pub struct AuthFlow {
176    /// Ordered list of stages required to complete authentication.
177    #[serde(default, skip_serializing_if = "Vec::is_empty")]
178    pub stages: Vec<AuthType>,
179}
180
181impl AuthFlow {
182    /// Creates a new `AuthFlow` with the given stages.
183    ///
184    /// To create an empty `AuthFlow`, use `AuthFlow::default()`.
185    pub fn new(stages: Vec<AuthType>) -> Self {
186        Self { stages }
187    }
188}
189
190/// Contains either a User-Interactive Authentication API response body or a Matrix error.
191#[derive(Clone, Debug)]
192#[allow(clippy::exhaustive_enums)]
193pub enum UiaaResponse {
194    /// User-Interactive Authentication API response
195    AuthResponse(UiaaInfo),
196
197    /// Matrix error response
198    MatrixError(MatrixError),
199}
200
201impl fmt::Display for UiaaResponse {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        match self {
204            Self::AuthResponse(_) => write!(f, "User-Interactive Authentication required."),
205            Self::MatrixError(err) => write!(f, "{err}"),
206        }
207    }
208}
209
210impl From<MatrixError> for UiaaResponse {
211    fn from(error: MatrixError) -> Self {
212        Self::MatrixError(error)
213    }
214}
215
216impl EndpointError for UiaaResponse {
217    fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
218        if response.status() == http::StatusCode::UNAUTHORIZED
219            && let Ok(uiaa_info) = from_json_slice(response.body().as_ref())
220        {
221            return Self::AuthResponse(uiaa_info);
222        }
223
224        Self::MatrixError(MatrixError::from_http_response(response))
225    }
226}
227
228impl std::error::Error for UiaaResponse {}
229
230impl OutgoingResponse for UiaaResponse {
231    fn try_into_http_response<T: Default + BufMut>(
232        self,
233    ) -> Result<http::Response<T>, IntoHttpError> {
234        match self {
235            UiaaResponse::AuthResponse(authentication_info) => http::Response::builder()
236                .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
237                .status(http::StatusCode::UNAUTHORIZED)
238                .body(ruma_common::serde::json_to_buf(&authentication_info)?)
239                .map_err(Into::into),
240            UiaaResponse::MatrixError(error) => error.try_into_http_response(),
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::serde::JsonObject;
    use serde_json::{from_value as from_json_value, json};

    use super::{AuthType, LoginTermsParams, OAuthParams, UiaaInfo};

    #[test]
    fn uiaa_info_params() {
        let json = json!({
            "flows": [{
                "stages": ["m.login.terms", "m.login.email.identity", "local.custom.stage"],
            }],
            "params": {
                "local.custom.stage": {
                    "foo": "bar",
                },
                "m.login.terms": {
                    "policies": {
                        "privacy": {
                            "en-US": {
                                "name": "Privacy Policy",
                                "url": "http://matrix.local/en-US/privacy",
                            },
                            "fr-FR": {
                                "name": "Politique de confidentialité",
                                "url": "http://matrix.local/fr-FR/privacy",
                            },
                            "version": "1",
                        },
                    },
                }
            },
            "session": "abcdef",
        });

        let info = from_json_value::<UiaaInfo>(json).unwrap();

        assert_matches!(info.params::<JsonObject>(&AuthType::EmailIdentity), Ok(None));
        assert_matches!(
            info.params::<JsonObject>(&AuthType::from("local.custom.stage")),
            Ok(Some(_))
        );

        assert_matches!(info.params::<LoginTermsParams>(&AuthType::Terms), Ok(Some(params)));
        assert_eq!(params.policies.len(), 1);

        let policy = params.policies.get("privacy").unwrap();
        assert_eq!(policy.version, "1");
        assert_eq!(policy.translations.len(), 2);
        let translation = policy.translations.get("en-US").unwrap();
        assert_eq!(translation.name, "Privacy Policy");
        assert_eq!(translation.url, "http://matrix.local/en-US/privacy");
        let translation = policy.translations.get("fr-FR").unwrap();
        assert_eq!(translation.name, "Politique de confidentialité");
        assert_eq!(translation.url, "http://matrix.local/fr-FR/privacy");
    }

    #[test]
    fn uiaa_info_oauth_params() {
        let url = "http://auth.matrix.local/reset";
        let stable_json = json!({
            "flows": [{
                "stages": ["m.oauth"],
            }],
            "params": {
                "m.oauth": {
                    "url": url,
                }
            },
            "session": "abcdef",
        });
        let unstable_json = json!({
            "flows": [{
                "stages": ["org.matrix.cross_signing_reset"],
            }],
            "params": {
                "org.matrix.cross_signing_reset": {
                    "url": url,
                }
            },
            "session": "abcdef",
        });

        let info = from_json_value::<UiaaInfo>(stable_json).unwrap();
        assert_matches!(info.params::<OAuthParams>(&AuthType::OAuth), Ok(Some(params)));
        assert_eq!(params.url, url);

        let info = from_json_value::<UiaaInfo>(unstable_json).unwrap();
        assert_matches!(info.params::<OAuthParams>(&AuthType::OAuth), Ok(Some(params)));
        assert_eq!(params.url, url);
    }

    /// Without a `params` object, no authentication type has parameters.
    #[test]
    fn uiaa_info_without_params() {
        let info = UiaaInfo::new(Vec::new());

        assert_matches!(info.params::<JsonObject>(&AuthType::Terms), Ok(None));
        assert_matches!(info.params::<OAuthParams>(&AuthType::OAuth), Ok(None));
    }

    /// Malformed parameters are only reported for the requested authentication type.
    #[test]
    fn uiaa_info_invalid_params() {
        let json = json!({
            "flows": [{
                "stages": ["m.oauth", "local.custom.stage"],
            }],
            "params": {
                "local.custom.stage": 42,
                "m.oauth": {
                    "url": 42,
                },
            },
            "session": "abcdef",
        });

        let info = from_json_value::<UiaaInfo>(json).unwrap();

        // The entry for the requested type exists but has the wrong shape.
        assert_matches!(info.params::<OAuthParams>(&AuthType::OAuth), Err(_));
        // Malformed entries of other types are skipped without error.
        assert_matches!(info.params::<JsonObject>(&AuthType::Terms), Ok(None));
        // A non-object entry cannot be deserialized as an object.
        assert_matches!(
            info.params::<JsonObject>(&AuthType::from("local.custom.stage")),
            Err(_)
        );
    }
}