1use std::{borrow::Cow, fmt, marker::PhantomData};
6
7use bytes::BufMut;
8use ruma_common::{
9 api::{EndpointError, OutgoingResponse, error::IntoHttpError},
10 serde::StringEnum,
11};
12use serde::{Deserialize, Deserializer, Serialize, de};
13use serde_json::{from_slice as from_json_slice, value::RawValue as RawJsonValue};
14
15use crate::{
16 PrivOwnedStr,
17 error::{Error as MatrixError, StandardErrorBody},
18};
19
20mod auth_data;
21mod auth_params;
22pub mod get_uiaa_fallback_page;
23
24pub use self::{auth_data::*, auth_params::*};
25
/// The type of a User-Interactive Authentication stage.
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
    /// Password stage, serialized as `m.login.password`.
    #[ruma_enum(rename = "m.login.password")]
    Password,

    /// ReCaptcha stage, serialized as `m.login.recaptcha`.
    #[ruma_enum(rename = "m.login.recaptcha")]
    ReCaptcha,

    /// Email identity stage, serialized as `m.login.email.identity`.
    #[ruma_enum(rename = "m.login.email.identity")]
    EmailIdentity,

    /// MSISDN (phone number) stage, serialized as `m.login.msisdn`.
    #[ruma_enum(rename = "m.login.msisdn")]
    Msisdn,

    /// Single sign-on stage, serialized as `m.login.sso`.
    #[ruma_enum(rename = "m.login.sso")]
    Sso,

    /// Dummy stage, serialized as `m.login.dummy`.
    #[ruma_enum(rename = "m.login.dummy")]
    Dummy,

    /// Registration token stage, serialized as `m.login.registration_token`.
    #[ruma_enum(rename = "m.login.registration_token")]
    RegistrationToken,

    /// Terms stage, serialized as `m.login.terms`.
    ///
    /// Its parameters can be read with `UiaaInfo::params::<LoginTermsParams>`.
    #[ruma_enum(rename = "m.login.terms")]
    Terms,

    /// OAuth stage, serialized as `m.oauth`.
    ///
    /// The unstable name `org.matrix.cross_signing_reset` is also accepted as an alias when
    /// parsing, so both names compare equal to this variant.
    #[ruma_enum(rename = "m.oauth", alias = "org.matrix.cross_signing_reset")]
    OAuth,

    // Catch-all for stage types not listed above (e.g. `local.custom.stage`); hidden from
    // the public docs and not meant to be matched on directly.
    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
75
/// Information about the available authentication flows and their progress.
///
/// This is the body sent with a `401 Unauthorized` status when User-Interactive
/// Authentication is required.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UiaaInfo {
    /// The authentication flows offered; each flow is a list of stages.
    pub flows: Vec<AuthFlow>,

    /// The stages that have already been completed.
    ///
    /// Defaults to empty when absent and is omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completed: Vec<AuthType>,

    /// Raw per-stage parameters, keyed by stage type.
    ///
    /// Use [`UiaaInfo::params`] to extract and deserialize the parameters of one stage.
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Box<RawJsonValue>>,

    /// The session identifier, if any. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session: Option<String>,

    /// Optional error information.
    ///
    /// Its fields are flattened into the top-level JSON object rather than nested, and
    /// nothing is written when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub auth_error: Option<StandardErrorBody>,
}
102
103impl UiaaInfo {
104 pub fn new(flows: Vec<AuthFlow>) -> Self {
106 Self { flows, completed: Vec::new(), params: None, session: None, auth_error: None }
107 }
108
109 pub fn params<'a, T: Deserialize<'a>>(
127 &'a self,
128 auth_type: &AuthType,
129 ) -> Result<Option<T>, serde_json::Error> {
130 struct AuthTypeVisitor<'b, T> {
131 auth_type: &'b AuthType,
132 _phantom: PhantomData<T>,
133 }
134
135 impl<'de, T> de::Visitor<'de> for AuthTypeVisitor<'_, T>
136 where
137 T: Deserialize<'de>,
138 {
139 type Value = Option<T>;
140
141 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
142 formatter.write_str("a key-value map")
143 }
144
145 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
146 where
147 A: de::MapAccess<'de>,
148 {
149 let mut params = None;
150
151 while let Some(key) = map.next_key::<Cow<'de, str>>()? {
152 if AuthType::from(key) == *self.auth_type {
153 params = Some(map.next_value()?);
154 } else {
155 map.next_value::<de::IgnoredAny>()?;
156 }
157 }
158
159 Ok(params)
160 }
161 }
162
163 let Some(params) = &self.params else {
164 return Ok(None);
165 };
166
167 let mut deserializer = serde_json::Deserializer::from_str(params.get());
168 deserializer.deserialize_map(AuthTypeVisitor { auth_type, _phantom: PhantomData })
169 }
170}
171
/// One authentication flow: an ordered list of stages.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct AuthFlow {
    /// The stages of this flow.
    ///
    /// Defaults to empty when absent and is omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stages: Vec<AuthType>,
}
180
181impl AuthFlow {
182 pub fn new(stages: Vec<AuthType>) -> Self {
186 Self { stages }
187 }
188}
189
/// Error type for endpoints that may require User-Interactive Authentication.
///
/// Incoming `401 Unauthorized` responses whose body parses as [`UiaaInfo`] become
/// [`UiaaResponse::AuthResponse`]; all other error responses become
/// [`UiaaResponse::MatrixError`].
#[derive(Clone, Debug)]
// NOTE(review): exhaustiveness looks intentional (the two variants cover both response shapes
// handled in this module), but the reason is not stated here — confirm.
#[allow(clippy::exhaustive_enums)]
pub enum UiaaResponse {
    /// User-Interactive Authentication is required; carries the flows and progress info.
    AuthResponse(UiaaInfo),

    /// Any other Matrix error.
    MatrixError(MatrixError),
}
200
201impl fmt::Display for UiaaResponse {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 match self {
204 Self::AuthResponse(_) => write!(f, "User-Interactive Authentication required."),
205 Self::MatrixError(err) => write!(f, "{err}"),
206 }
207 }
208}
209
210impl From<MatrixError> for UiaaResponse {
211 fn from(error: MatrixError) -> Self {
212 Self::MatrixError(error)
213 }
214}
215
216impl EndpointError for UiaaResponse {
217 fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
218 if response.status() == http::StatusCode::UNAUTHORIZED
219 && let Ok(uiaa_info) = from_json_slice(response.body().as_ref())
220 {
221 return Self::AuthResponse(uiaa_info);
222 }
223
224 Self::MatrixError(MatrixError::from_http_response(response))
225 }
226}
227
228impl std::error::Error for UiaaResponse {}
229
230impl OutgoingResponse for UiaaResponse {
231 fn try_into_http_response<T: Default + BufMut>(
232 self,
233 ) -> Result<http::Response<T>, IntoHttpError> {
234 match self {
235 UiaaResponse::AuthResponse(authentication_info) => http::Response::builder()
236 .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
237 .status(http::StatusCode::UNAUTHORIZED)
238 .body(ruma_common::serde::json_to_buf(&authentication_info)?)
239 .map_err(Into::into),
240 UiaaResponse::MatrixError(error) => error.try_into_http_response(),
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::serde::JsonObject;
    use serde_json::{from_value as from_json_value, json};

    use super::{AuthType, LoginTermsParams, OAuthParams, UiaaInfo};

    /// `UiaaInfo::params` returns `None` for a stage with no entry in `params`, finds entries
    /// for custom stage types, and deserializes `m.login.terms` into `LoginTermsParams`.
    #[test]
    fn uiaa_info_params() {
        let json = json!({
            "flows": [{
                "stages": ["m.login.terms", "m.login.email.identity", "local.custom.stage"],
            }],
            "params": {
                "local.custom.stage": {
                    "foo": "bar",
                },
                "m.login.terms": {
                    "policies": {
                        "privacy": {
                            "en-US": {
                                "name": "Privacy Policy",
                                "url": "http://matrix.local/en-US/privacy",
                            },
                            "fr-FR": {
                                "name": "Politique de confidentialité",
                                "url": "http://matrix.local/fr-FR/privacy",
                            },
                            "version": "1",
                        },
                    },
                }
            },
            "session": "abcdef",
        });

        let info = from_json_value::<UiaaInfo>(json).unwrap();

        // `m.login.email.identity` is listed in a flow but has no entry in `params`.
        assert_matches!(info.params::<JsonObject>(&AuthType::EmailIdentity), Ok(None));
        assert_matches!(
            info.params::<JsonObject>(&AuthType::from("local.custom.stage")),
            Ok(Some(_))
        );

        assert_matches!(info.params::<LoginTermsParams>(&AuthType::Terms), Ok(Some(params)));
        assert_eq!(params.policies.len(), 1);

        // `version` is a policy field; the remaining keys are per-language translations.
        let policy = params.policies.get("privacy").unwrap();
        assert_eq!(policy.version, "1");
        assert_eq!(policy.translations.len(), 2);
        let translation = policy.translations.get("en-US").unwrap();
        assert_eq!(translation.name, "Privacy Policy");
        assert_eq!(translation.url, "http://matrix.local/en-US/privacy");
        let translation = policy.translations.get("fr-FR").unwrap();
        assert_eq!(translation.name, "Politique de confidentialité");
        assert_eq!(translation.url, "http://matrix.local/fr-FR/privacy");
    }

    /// Parameters for `AuthType::OAuth` are found under both the stable `m.oauth` key and the
    /// unstable `org.matrix.cross_signing_reset` alias.
    #[test]
    fn uiaa_info_oauth_params() {
        let url = "http://auth.matrix.local/reset";
        let stable_json = json!({
            "flows": [{
                "stages": ["m.oauth"],
            }],
            "params": {
                "m.oauth": {
                    "url": url,
                }
            },
            "session": "abcdef",
        });
        let unstable_json = json!({
            "flows": [{
                "stages": ["org.matrix.cross_signing_reset"],
            }],
            "params": {
                "org.matrix.cross_signing_reset": {
                    "url": url,
                }
            },
            "session": "abcdef",
        });

        let info = from_json_value::<UiaaInfo>(stable_json).unwrap();
        assert_matches!(info.params::<OAuthParams>(&AuthType::OAuth), Ok(Some(params)));
        assert_eq!(params.url, url);

        let info = from_json_value::<UiaaInfo>(unstable_json).unwrap();
        assert_matches!(info.params::<OAuthParams>(&AuthType::OAuth), Ok(Some(params)));
        assert_eq!(params.url, url);
    }
}