ruma_common/api/
auth_scheme.rs

1//! The `AuthScheme` trait used to specify the authentication scheme used by endpoints and the types
2//! that implement it.
3
4#![allow(clippy::exhaustive_structs)]
5
6use as_variant::as_variant;
7use http::{header, HeaderMap};
8use serde::Deserialize;
9
10/// Trait implemented by types representing an authentication scheme used by an endpoint.
11pub trait AuthScheme: Sized {
12    /// The input necessary to generate the authentication.
13    type Input<'a>;
14
15    /// The error type returned from [`add_authentication()`](Self::add_authentication).
16    type AddAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
17
18    /// The authentication data that can be extracted from a request.
19    type Output;
20
21    /// The error type returned from [`extract_authentication()`](Self::extract_authentication).
22    type ExtractAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
23
24    /// Add this authentication scheme to the given outgoing request, if necessary.
25    ///
26    /// Returns an error if the endpoint requires authentication but the input doesn't provide it,
27    /// or if the input fails to serialize to the proper format.
28    fn add_authentication<T: AsRef<[u8]>>(
29        request: &mut http::Request<T>,
30        input: Self::Input<'_>,
31    ) -> Result<(), Self::AddAuthenticationError>;
32
33    /// Extract the data of this authentication scheme from the given incoming request.
34    ///
35    /// Returns an error if the endpoint requires authentication but the request doesn't provide it,
36    /// or if the output fails to deserialize to the proper format.
37    fn extract_authentication<T: AsRef<[u8]>>(
38        request: &http::Request<T>,
39    ) -> Result<Self::Output, Self::ExtractAuthenticationError>;
40}
41
/// An authentication scheme that performs no authentication.
///
/// Its input is still a [`SendAccessToken`], so that a token can be sent even though the
/// endpoint does not require one.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoAuthentication;
48
49impl AuthScheme for NoAuthentication {
50    type Input<'a> = SendAccessToken<'a>;
51    type AddAuthenticationError = header::InvalidHeaderValue;
52    type Output = ();
53    type ExtractAuthenticationError = std::convert::Infallible;
54
55    fn add_authentication<T: AsRef<[u8]>>(
56        request: &mut http::Request<T>,
57        access_token: SendAccessToken<'_>,
58    ) -> Result<(), Self::AddAuthenticationError> {
59        if let Some(access_token) = access_token.get_not_required_for_endpoint() {
60            add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
61        }
62
63        Ok(())
64    }
65
66    /// Since this endpoint doesn't expect any authentication, this is a noop.
67    fn extract_authentication<T: AsRef<[u8]>>(
68        _request: &http::Request<T>,
69    ) -> Result<(), Self::ExtractAuthenticationError> {
70        Ok(())
71    }
72}
73
/// Authentication is performed by including an access token in the `Authorization` HTTP header
/// (using the `Bearer` scheme), or in an `access_token` query parameter.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AccessToken;
80
81impl AuthScheme for AccessToken {
82    type Input<'a> = SendAccessToken<'a>;
83    type AddAuthenticationError = AddRequiredTokenError;
84    /// The access token.
85    type Output = String;
86    type ExtractAuthenticationError = ExtractTokenError;
87
88    fn add_authentication<T: AsRef<[u8]>>(
89        request: &mut http::Request<T>,
90        access_token: SendAccessToken<'_>,
91    ) -> Result<(), Self::AddAuthenticationError> {
92        let token = access_token
93            .get_required_for_endpoint()
94            .ok_or(AddRequiredTokenError::MissingAccessToken)?;
95        Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
96    }
97
98    fn extract_authentication<T: AsRef<[u8]>>(
99        request: &http::Request<T>,
100    ) -> Result<String, Self::ExtractAuthenticationError> {
101        extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
102    }
103}
104
/// Authentication is optional, and it is performed by including an access token in the
/// `Authorization` HTTP header (using the `Bearer` scheme), or in an `access_token` query
/// parameter.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AccessTokenOptional;
111
112impl AuthScheme for AccessTokenOptional {
113    type Input<'a> = SendAccessToken<'a>;
114    type AddAuthenticationError = header::InvalidHeaderValue;
115    /// The access token, if any.
116    type Output = Option<String>;
117    type ExtractAuthenticationError = ExtractTokenError;
118
119    fn add_authentication<T: AsRef<[u8]>>(
120        request: &mut http::Request<T>,
121        access_token: SendAccessToken<'_>,
122    ) -> Result<(), Self::AddAuthenticationError> {
123        if let Some(access_token) = access_token.get_required_for_endpoint() {
124            add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
125        }
126
127        Ok(())
128    }
129
130    fn extract_authentication<T: AsRef<[u8]>>(
131        request: &http::Request<T>,
132    ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
133        extract_bearer_or_query_token(request)
134    }
135}
136
/// Authentication is required, and can only be performed for appservices, by including an
/// appservice access token in the `Authorization` HTTP header (using the `Bearer` scheme), or in
/// an `access_token` query parameter.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AppserviceToken;
144
145impl AuthScheme for AppserviceToken {
146    type Input<'a> = SendAccessToken<'a>;
147    type AddAuthenticationError = AddRequiredTokenError;
148    /// The appservice token.
149    type Output = String;
150    type ExtractAuthenticationError = ExtractTokenError;
151
152    fn add_authentication<T: AsRef<[u8]>>(
153        request: &mut http::Request<T>,
154        access_token: SendAccessToken<'_>,
155    ) -> Result<(), Self::AddAuthenticationError> {
156        let token = access_token
157            .get_required_for_appservice()
158            .ok_or(AddRequiredTokenError::MissingAccessToken)?;
159        Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
160    }
161
162    fn extract_authentication<T: AsRef<[u8]>>(
163        request: &http::Request<T>,
164    ) -> Result<String, Self::ExtractAuthenticationError> {
165        extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
166    }
167}
168
/// No authentication is performed for clients, but it can be performed for appservices, by
/// including an appservice access token in the `Authorization` HTTP header (using the `Bearer`
/// scheme), or in an `access_token` query parameter.
///
/// Using the query parameter is deprecated since Matrix 1.11.
#[derive(Debug, Clone, Copy, Default)]
pub struct AppserviceTokenOptional;
176
177impl AuthScheme for AppserviceTokenOptional {
178    type Input<'a> = SendAccessToken<'a>;
179    type AddAuthenticationError = header::InvalidHeaderValue;
180    /// The appservice token, if any.
181    type Output = Option<String>;
182    type ExtractAuthenticationError = ExtractTokenError;
183
184    fn add_authentication<T: AsRef<[u8]>>(
185        request: &mut http::Request<T>,
186        access_token: SendAccessToken<'_>,
187    ) -> Result<(), Self::AddAuthenticationError> {
188        if let Some(access_token) = access_token.get_required_for_appservice() {
189            add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
190        }
191
192        Ok(())
193    }
194
195    fn extract_authentication<T: AsRef<[u8]>>(
196        request: &http::Request<T>,
197    ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
198        extract_bearer_or_query_token(request)
199    }
200}
201
202/// Add the given access token as an `Authorization` HTTP header to the given map.
203fn add_access_token_as_authorization_header(
204    headers: &mut HeaderMap,
205    token: &str,
206) -> Result<(), header::InvalidHeaderValue> {
207    headers.insert(header::AUTHORIZATION, format!("Bearer {token}").try_into()?);
208    Ok(())
209}
210
211/// Extract the access token from the `Authorization` HTTP header or the query string of the given
212/// request.
213fn extract_bearer_or_query_token<T>(
214    request: &http::Request<T>,
215) -> Result<Option<String>, ExtractTokenError> {
216    if let Some(token) = extract_bearer_token_from_authorization_header(request.headers())? {
217        return Ok(Some(token));
218    }
219
220    if let Some(query) = request.uri().query() {
221        Ok(extract_access_token_from_query(query)?)
222    } else {
223        Ok(None)
224    }
225}
226
227/// Extract the value of the `Authorization` HTTP header with a `Bearer` scheme.
228fn extract_bearer_token_from_authorization_header(
229    headers: &HeaderMap,
230) -> Result<Option<String>, ExtractTokenError> {
231    const EXPECTED_START: &str = "bearer ";
232
233    let Some(value) = headers.get(header::AUTHORIZATION) else {
234        return Ok(None);
235    };
236
237    let value = value.to_str()?;
238
239    if value.len() < EXPECTED_START.len()
240        || !value[..EXPECTED_START.len()].eq_ignore_ascii_case(EXPECTED_START)
241    {
242        return Err(ExtractTokenError::InvalidAuthorizationScheme);
243    }
244
245    Ok(Some(value[EXPECTED_START.len()..].to_owned()))
246}
247
248/// Extract the `access_token` from the given query string.
249fn extract_access_token_from_query(
250    query: &str,
251) -> Result<Option<String>, serde_html_form::de::Error> {
252    #[derive(Deserialize)]
253    struct AccessTokenDeHelper {
254        access_token: Option<String>,
255    }
256
257    serde_html_form::from_str::<AccessTokenDeHelper>(query).map(|helper| helper.access_token)
258}
259
/// Controls whether an access token is added to outgoing requests, and which kind of token.
#[derive(Clone, Copy, Debug)]
// Deliberately exhaustive: callers are expected to handle every variant.
#[allow(clippy::exhaustive_enums)]
pub enum SendAccessToken<'a> {
    /// Add the given access token only if the `METADATA` of the request requires it.
    IfRequired(&'a str),

    /// Add the given access token unconditionally.
    Always(&'a str),

    /// Add the given appservice token only if the `METADATA` of the request requires it.
    Appservice(&'a str),

    /// Don't add any access token.
    ///
    /// This leads to an error if the endpoint of the request requires authentication.
    None,
}
280
281impl<'a> SendAccessToken<'a> {
282    /// Get the access token for an endpoint that requires one.
283    ///
284    /// Returns `Some(_)` if `self` contains an access token.
285    pub fn get_required_for_endpoint(self) -> Option<&'a str> {
286        as_variant!(self, Self::IfRequired | Self::Appservice | Self::Always)
287    }
288
289    /// Get the access token for an endpoint that should not require one.
290    ///
291    /// Returns `Some(_)` only if `self` is `SendAccessToken::Always(_)`.
292    pub fn get_not_required_for_endpoint(self) -> Option<&'a str> {
293        as_variant!(self, Self::Always)
294    }
295
296    /// Gets the access token for an endpoint that requires one for appservices.
297    ///
298    /// Returns `Some(_)` if `self` is either `SendAccessToken::Appservice(_)`
299    /// or `SendAccessToken::Always(_)`
300    pub fn get_required_for_appservice(self) -> Option<&'a str> {
301        as_variant!(self, Self::Appservice | Self::Always)
302    }
303}
304
305/// An error that can occur when adding an [`AuthScheme`] that requires an access token.
306#[derive(Debug, thiserror::Error)]
307#[non_exhaustive]
308pub enum AddRequiredTokenError {
309    /// No access token was provided, but the endpoint requires one.
310    #[error("no access token provided, but this endpoint requires one")]
311    MissingAccessToken,
312
313    /// Failed to convert the authentication to a header value.
314    #[error(transparent)]
315    IntoHeader(#[from] header::InvalidHeaderValue),
316}
317
318/// An error that can occur when extracting an [`AuthScheme`] that expects an access token.
319#[derive(Debug, thiserror::Error)]
320#[non_exhaustive]
321pub enum ExtractTokenError {
322    /// No access token was found, but the endpoint requires one.
323    #[error("no access token found, but this endpoint requires one")]
324    MissingAccessToken,
325
326    /// Failed to convert the header value to a UTF-8 string.
327    #[error(transparent)]
328    FromHeader(#[from] header::ToStrError),
329
330    /// The scheme of the Authorization HTTP header is invalid.
331    #[error("invalid authorization header scheme")]
332    InvalidAuthorizationScheme,
333
334    /// Failed to deserialize the query string.
335    #[error("failed to deserialize query string: {0}")]
336    FromQuery(#[from] serde_html_form::de::Error),
337}