Skip to main content

ruma_common/api/
auth_scheme.rs

1//! The `AuthScheme` trait used to specify the authentication scheme used by endpoints and the types
2//! that implement it.
3
4#![allow(clippy::exhaustive_structs)]
5
6use as_variant::as_variant;
7use http::{HeaderMap, header};
8use serde::Deserialize;
9
/// Trait implemented by types representing an authentication scheme used by an endpoint.
///
/// A scheme works in two directions: [`add_authentication()`](Self::add_authentication) writes
/// the authentication data into an outgoing request, and
/// [`extract_authentication()`](Self::extract_authentication) reads it back from an incoming
/// request.
pub trait AuthScheme: Sized {
    /// The input necessary to generate the authentication.
    ///
    /// This is the type of the `input` argument of
    /// [`add_authentication()`](Self::add_authentication).
    type Input<'a>;

    /// The error type returned from [`add_authentication()`](Self::add_authentication).
    type AddAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;

    /// The authentication data that can be extracted from a request.
    ///
    /// This is the success type of
    /// [`extract_authentication()`](Self::extract_authentication).
    type Output;

    /// The error type returned from [`extract_authentication()`](Self::extract_authentication).
    type ExtractAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;

    /// Add this authentication scheme to the given outgoing request, if necessary.
    ///
    /// # Errors
    ///
    /// Returns an error if the endpoint requires authentication but the input doesn't provide it,
    /// or if the input fails to serialize to the proper format.
    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        input: Self::Input<'_>,
    ) -> Result<(), Self::AddAuthenticationError>;

    /// Extract the data of this authentication scheme from the given incoming request.
    ///
    /// # Errors
    ///
    /// Returns an error if the endpoint requires authentication but the request doesn't provide it,
    /// or if the output fails to deserialize to the proper format.
    fn extract_authentication<T: AsRef<[u8]>>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError>;
}
41
/// No authentication is performed.
///
/// Both adding authentication to a request and extracting it from a request are no-ops that
/// cannot fail.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoAuthentication;
45
46impl AuthScheme for NoAuthentication {
47    type Input<'a> = ();
48    type AddAuthenticationError = std::convert::Infallible;
49    type Output = ();
50    type ExtractAuthenticationError = std::convert::Infallible;
51
52    fn add_authentication<T: AsRef<[u8]>>(
53        _request: &mut http::Request<T>,
54        _input: (),
55    ) -> Result<(), Self::AddAuthenticationError> {
56        Ok(())
57    }
58
59    /// Since this endpoint doesn't expect any authentication, this is a noop.
60    fn extract_authentication<T: AsRef<[u8]>>(
61        _request: &http::Request<T>,
62    ) -> Result<(), Self::ExtractAuthenticationError> {
63        Ok(())
64    }
65}
66
/// No authentication is performed on an API that usually relies on access tokens.
///
/// Contrary to [`NoAuthentication`], this type accepts a [`SendAccessToken`] as input to be able to
/// send it regardless of whether it is required. The token is only added to the request when the
/// input is [`SendAccessToken::Always`]; nothing is extracted from incoming requests.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoAccessToken;
73
74impl AuthScheme for NoAccessToken {
75    type Input<'a> = SendAccessToken<'a>;
76    type AddAuthenticationError = header::InvalidHeaderValue;
77    type Output = ();
78    type ExtractAuthenticationError = std::convert::Infallible;
79
80    fn add_authentication<T: AsRef<[u8]>>(
81        request: &mut http::Request<T>,
82        access_token: SendAccessToken<'_>,
83    ) -> Result<(), Self::AddAuthenticationError> {
84        if let Some(access_token) = access_token.get_not_required_for_endpoint() {
85            add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
86        }
87
88        Ok(())
89    }
90
91    /// Since this endpoint doesn't expect any authentication, this is a noop.
92    fn extract_authentication<T: AsRef<[u8]>>(
93        _request: &http::Request<T>,
94    ) -> Result<(), Self::ExtractAuthenticationError> {
95        Ok(())
96    }
97}
98
/// Authentication is performed by including an access token in the `Authorization` HTTP
/// header, or an `access_token` query parameter.
///
/// Outgoing requests always use the `Authorization` header. Incoming requests are checked for
/// the header first and for the query parameter second.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AccessToken;
105
106impl AuthScheme for AccessToken {
107    type Input<'a> = SendAccessToken<'a>;
108    type AddAuthenticationError = AddRequiredTokenError;
109    /// The access token.
110    type Output = String;
111    type ExtractAuthenticationError = ExtractTokenError;
112
113    fn add_authentication<T: AsRef<[u8]>>(
114        request: &mut http::Request<T>,
115        access_token: SendAccessToken<'_>,
116    ) -> Result<(), Self::AddAuthenticationError> {
117        let token = access_token
118            .get_required_for_endpoint()
119            .ok_or(AddRequiredTokenError::MissingAccessToken)?;
120        Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
121    }
122
123    fn extract_authentication<T: AsRef<[u8]>>(
124        request: &http::Request<T>,
125    ) -> Result<String, Self::ExtractAuthenticationError> {
126        extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
127    }
128}
129
/// Authentication is optional, and it is performed by including an access token in the
/// `Authorization` HTTP header, or an `access_token` query parameter.
///
/// A missing token is not an error for this scheme: nothing is added to outgoing requests when
/// the input holds no token, and extraction yields `None` when the request carries none.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AccessTokenOptional;
136
137impl AuthScheme for AccessTokenOptional {
138    type Input<'a> = SendAccessToken<'a>;
139    type AddAuthenticationError = header::InvalidHeaderValue;
140    /// The access token, if any.
141    type Output = Option<String>;
142    type ExtractAuthenticationError = ExtractTokenError;
143
144    fn add_authentication<T: AsRef<[u8]>>(
145        request: &mut http::Request<T>,
146        access_token: SendAccessToken<'_>,
147    ) -> Result<(), Self::AddAuthenticationError> {
148        if let Some(access_token) = access_token.get_required_for_endpoint() {
149            add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
150        }
151
152        Ok(())
153    }
154
155    fn extract_authentication<T: AsRef<[u8]>>(
156        request: &http::Request<T>,
157    ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
158        extract_bearer_or_query_token(request)
159    }
160}
161
/// Authentication is required, and can only be performed for appservices, by including an
/// appservice access token in the `Authorization` HTTP header, or `access_token` query
/// parameter.
///
/// Outgoing requests only receive a token when the input is [`SendAccessToken::Appservice`] or
/// [`SendAccessToken::Always`]; any other input is reported as a missing token.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AppserviceToken;
169
170impl AuthScheme for AppserviceToken {
171    type Input<'a> = SendAccessToken<'a>;
172    type AddAuthenticationError = AddRequiredTokenError;
173    /// The appservice token.
174    type Output = String;
175    type ExtractAuthenticationError = ExtractTokenError;
176
177    fn add_authentication<T: AsRef<[u8]>>(
178        request: &mut http::Request<T>,
179        access_token: SendAccessToken<'_>,
180    ) -> Result<(), Self::AddAuthenticationError> {
181        let token = access_token
182            .get_required_for_appservice()
183            .ok_or(AddRequiredTokenError::MissingAccessToken)?;
184        Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
185    }
186
187    fn extract_authentication<T: AsRef<[u8]>>(
188        request: &http::Request<T>,
189    ) -> Result<String, Self::ExtractAuthenticationError> {
190        extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
191    }
192}
193
/// No authentication is performed for clients, but it can be performed for appservices, by
/// including an appservice access token in the `Authorization` HTTP header, or an
/// `access_token` query parameter.
///
/// Outgoing requests only receive a token when the input is [`SendAccessToken::Appservice`] or
/// [`SendAccessToken::Always`]; extraction yields `None` when the request carries no token.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AppserviceTokenOptional;
201
202impl AuthScheme for AppserviceTokenOptional {
203    type Input<'a> = SendAccessToken<'a>;
204    type AddAuthenticationError = header::InvalidHeaderValue;
205    /// The appservice token, if any.
206    type Output = Option<String>;
207    type ExtractAuthenticationError = ExtractTokenError;
208
209    fn add_authentication<T: AsRef<[u8]>>(
210        request: &mut http::Request<T>,
211        access_token: SendAccessToken<'_>,
212    ) -> Result<(), Self::AddAuthenticationError> {
213        if let Some(access_token) = access_token.get_required_for_appservice() {
214            add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
215        }
216
217        Ok(())
218    }
219
220    fn extract_authentication<T: AsRef<[u8]>>(
221        request: &http::Request<T>,
222    ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
223        extract_bearer_or_query_token(request)
224    }
225}
226
227/// Add the given access token as an `Authorization` HTTP header to the given map.
228fn add_access_token_as_authorization_header(
229    headers: &mut HeaderMap,
230    token: &str,
231) -> Result<(), header::InvalidHeaderValue> {
232    headers.insert(header::AUTHORIZATION, format!("Bearer {token}").try_into()?);
233    Ok(())
234}
235
236/// Extract the access token from the `Authorization` HTTP header or the query string of the given
237/// request.
238fn extract_bearer_or_query_token<T>(
239    request: &http::Request<T>,
240) -> Result<Option<String>, ExtractTokenError> {
241    if let Some(token) = extract_bearer_token_from_authorization_header(request.headers())? {
242        return Ok(Some(token));
243    }
244
245    if let Some(query) = request.uri().query() {
246        Ok(extract_access_token_from_query(query)?)
247    } else {
248        Ok(None)
249    }
250}
251
252/// Extract the value of the `Authorization` HTTP header with a `Bearer` scheme.
253fn extract_bearer_token_from_authorization_header(
254    headers: &HeaderMap,
255) -> Result<Option<String>, ExtractTokenError> {
256    const EXPECTED_START: &str = "bearer ";
257
258    let Some(value) = headers.get(header::AUTHORIZATION) else {
259        return Ok(None);
260    };
261
262    let value = value.to_str()?;
263
264    if value.len() < EXPECTED_START.len()
265        || !value[..EXPECTED_START.len()].eq_ignore_ascii_case(EXPECTED_START)
266    {
267        return Err(ExtractTokenError::InvalidAuthorizationScheme);
268    }
269
270    Ok(Some(value[EXPECTED_START.len()..].to_owned()))
271}
272
273/// Extract the `access_token` from the given query string.
274fn extract_access_token_from_query(
275    query: &str,
276) -> Result<Option<String>, serde_html_form::de::Error> {
277    #[derive(Deserialize)]
278    struct AccessTokenDeHelper {
279        access_token: Option<String>,
280    }
281
282    serde_html_form::from_str::<AccessTokenDeHelper>(query).map(|helper| helper.access_token)
283}
284
/// An enum to control whether an access token should be added to outgoing requests.
///
/// This is the input of the [`AuthScheme`]s that work with access tokens. Each scheme picks one
/// of the `get_*` methods of this type to decide whether the token it holds must be sent.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::exhaustive_enums)]
pub enum SendAccessToken<'a> {
    /// Add the given access token to the request only if the `METADATA` on the request requires
    /// it.
    IfRequired(&'a str),

    /// Always add the access token.
    Always(&'a str),

    /// Add the given appservice token to the request only if the `METADATA` on the request
    /// requires it.
    Appservice(&'a str),

    /// Don't add an access token.
    ///
    /// This will lead to an error if the request endpoint requires authentication.
    None,
}
305
306impl<'a> SendAccessToken<'a> {
307    /// Get the access token for an endpoint that requires one.
308    ///
309    /// Returns `Some(_)` if `self` contains an access token.
310    pub fn get_required_for_endpoint(self) -> Option<&'a str> {
311        as_variant!(self, Self::IfRequired | Self::Appservice | Self::Always)
312    }
313
314    /// Get the access token for an endpoint that should not require one.
315    ///
316    /// Returns `Some(_)` only if `self` is `SendAccessToken::Always(_)`.
317    pub fn get_not_required_for_endpoint(self) -> Option<&'a str> {
318        as_variant!(self, Self::Always)
319    }
320
321    /// Gets the access token for an endpoint that requires one for appservices.
322    ///
323    /// Returns `Some(_)` if `self` is either `SendAccessToken::Appservice(_)`
324    /// or `SendAccessToken::Always(_)`
325    pub fn get_required_for_appservice(self) -> Option<&'a str> {
326        as_variant!(self, Self::Appservice | Self::Always)
327    }
328}
329
/// An error that can occur when adding an [`AuthScheme`] that requires an access token.
///
/// This is the error type of `add_authentication()` for [`AccessToken`] and
/// [`AppserviceToken`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum AddRequiredTokenError {
    /// No access token was provided, but the endpoint requires one.
    #[error("no access token provided, but this endpoint requires one")]
    MissingAccessToken,

    /// Failed to convert the authentication to a header value.
    #[error(transparent)]
    IntoHeader(#[from] header::InvalidHeaderValue),
}
342
/// An error that can occur when extracting an [`AuthScheme`] that expects an access token.
///
/// This is the error type of `extract_authentication()` for every scheme in this module that
/// reads a token from the `Authorization` HTTP header or the query string.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ExtractTokenError {
    /// No access token was found, but the endpoint requires one.
    #[error("no access token found, but this endpoint requires one")]
    MissingAccessToken,

    /// Failed to convert the header value to a UTF-8 string.
    #[error(transparent)]
    FromHeader(#[from] header::ToStrError),

    /// The scheme of the Authorization HTTP header is invalid.
    ///
    /// Only the `Bearer` scheme is accepted.
    #[error("invalid authorization header scheme")]
    InvalidAuthorizationScheme,

    /// Failed to deserialize the query string.
    #[error("failed to deserialize query string: {0}")]
    FromQuery(#[from] serde_html_form::de::Error),
}
358
359    /// Failed to deserialize the query string.
360    #[error("failed to deserialize query string: {0}")]
361    FromQuery(#[from] serde_html_form::de::Error),
362}