ruma_common/api/
auth_scheme.rs1#![allow(clippy::exhaustive_structs)]
5
6use as_variant::as_variant;
7use http::{header, HeaderMap};
8use serde::Deserialize;
9
10pub trait AuthScheme: Sized {
12 type Input<'a>;
14
15 type AddAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
17
18 type Output;
20
21 type ExtractAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
23
24 fn add_authentication<T: AsRef<[u8]>>(
29 request: &mut http::Request<T>,
30 input: Self::Input<'_>,
31 ) -> Result<(), Self::AddAuthenticationError>;
32
33 fn extract_authentication<T: AsRef<[u8]>>(
38 request: &http::Request<T>,
39 ) -> Result<Self::Output, Self::ExtractAuthenticationError>;
40}
41
/// Marker type for the authentication scheme of endpoints that do not need a token.
///
/// Its [`AuthScheme`] implementation only attaches a token when
/// [`SendAccessToken::get_not_required_for_endpoint()`] yields one, and extracts nothing from
/// incoming requests.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoAuthentication;
48
49impl AuthScheme for NoAuthentication {
50 type Input<'a> = SendAccessToken<'a>;
51 type AddAuthenticationError = header::InvalidHeaderValue;
52 type Output = ();
53 type ExtractAuthenticationError = std::convert::Infallible;
54
55 fn add_authentication<T: AsRef<[u8]>>(
56 request: &mut http::Request<T>,
57 access_token: SendAccessToken<'_>,
58 ) -> Result<(), Self::AddAuthenticationError> {
59 if let Some(access_token) = access_token.get_not_required_for_endpoint() {
60 add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
61 }
62
63 Ok(())
64 }
65
66 fn extract_authentication<T: AsRef<[u8]>>(
68 _request: &http::Request<T>,
69 ) -> Result<(), Self::ExtractAuthenticationError> {
70 Ok(())
71 }
72}
73
/// Marker type for the authentication scheme of endpoints that require an access token.
///
/// Its [`AuthScheme`] implementation fails with a "missing access token" error when no token is
/// available, both when building and when reading a request.
#[derive(Clone, Copy, Debug, Default)]
pub struct AccessToken;
80
81impl AuthScheme for AccessToken {
82 type Input<'a> = SendAccessToken<'a>;
83 type AddAuthenticationError = AddRequiredTokenError;
84 type Output = String;
86 type ExtractAuthenticationError = ExtractTokenError;
87
88 fn add_authentication<T: AsRef<[u8]>>(
89 request: &mut http::Request<T>,
90 access_token: SendAccessToken<'_>,
91 ) -> Result<(), Self::AddAuthenticationError> {
92 let token = access_token
93 .get_required_for_endpoint()
94 .ok_or(AddRequiredTokenError::MissingAccessToken)?;
95 Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
96 }
97
98 fn extract_authentication<T: AsRef<[u8]>>(
99 request: &http::Request<T>,
100 ) -> Result<String, Self::ExtractAuthenticationError> {
101 extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
102 }
103}
104
/// Marker type for the authentication scheme of endpoints where an access token is optional.
///
/// Its [`AuthScheme`] implementation attaches a token when one is available and treats a request
/// without a token as valid (`None`).
#[derive(Clone, Copy, Debug, Default)]
pub struct AccessTokenOptional;
111
112impl AuthScheme for AccessTokenOptional {
113 type Input<'a> = SendAccessToken<'a>;
114 type AddAuthenticationError = header::InvalidHeaderValue;
115 type Output = Option<String>;
117 type ExtractAuthenticationError = ExtractTokenError;
118
119 fn add_authentication<T: AsRef<[u8]>>(
120 request: &mut http::Request<T>,
121 access_token: SendAccessToken<'_>,
122 ) -> Result<(), Self::AddAuthenticationError> {
123 if let Some(access_token) = access_token.get_required_for_endpoint() {
124 add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
125 }
126
127 Ok(())
128 }
129
130 fn extract_authentication<T: AsRef<[u8]>>(
131 request: &http::Request<T>,
132 ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
133 extract_bearer_or_query_token(request)
134 }
135}
136
/// Marker type for the authentication scheme of endpoints that require an appservice token.
///
/// Its [`AuthScheme`] implementation only sends the token selected by
/// [`SendAccessToken::get_required_for_appservice()`] and fails with a "missing access token"
/// error when there is none.
#[derive(Clone, Copy, Debug, Default)]
pub struct AppserviceToken;
144
145impl AuthScheme for AppserviceToken {
146 type Input<'a> = SendAccessToken<'a>;
147 type AddAuthenticationError = AddRequiredTokenError;
148 type Output = String;
150 type ExtractAuthenticationError = ExtractTokenError;
151
152 fn add_authentication<T: AsRef<[u8]>>(
153 request: &mut http::Request<T>,
154 access_token: SendAccessToken<'_>,
155 ) -> Result<(), Self::AddAuthenticationError> {
156 let token = access_token
157 .get_required_for_appservice()
158 .ok_or(AddRequiredTokenError::MissingAccessToken)?;
159 Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
160 }
161
162 fn extract_authentication<T: AsRef<[u8]>>(
163 request: &http::Request<T>,
164 ) -> Result<String, Self::ExtractAuthenticationError> {
165 extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
166 }
167}
168
/// Marker type for the authentication scheme of endpoints where an appservice token is optional.
///
/// Its [`AuthScheme`] implementation attaches the token selected by
/// [`SendAccessToken::get_required_for_appservice()`] when there is one and treats a request
/// without a token as valid (`None`).
#[derive(Clone, Copy, Debug, Default)]
pub struct AppserviceTokenOptional;
176
177impl AuthScheme for AppserviceTokenOptional {
178 type Input<'a> = SendAccessToken<'a>;
179 type AddAuthenticationError = header::InvalidHeaderValue;
180 type Output = Option<String>;
182 type ExtractAuthenticationError = ExtractTokenError;
183
184 fn add_authentication<T: AsRef<[u8]>>(
185 request: &mut http::Request<T>,
186 access_token: SendAccessToken<'_>,
187 ) -> Result<(), Self::AddAuthenticationError> {
188 if let Some(access_token) = access_token.get_required_for_appservice() {
189 add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
190 }
191
192 Ok(())
193 }
194
195 fn extract_authentication<T: AsRef<[u8]>>(
196 request: &http::Request<T>,
197 ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
198 extract_bearer_or_query_token(request)
199 }
200}
201
202fn add_access_token_as_authorization_header(
204 headers: &mut HeaderMap,
205 token: &str,
206) -> Result<(), header::InvalidHeaderValue> {
207 headers.insert(header::AUTHORIZATION, format!("Bearer {token}").try_into()?);
208 Ok(())
209}
210
211fn extract_bearer_or_query_token<T>(
214 request: &http::Request<T>,
215) -> Result<Option<String>, ExtractTokenError> {
216 if let Some(token) = extract_bearer_token_from_authorization_header(request.headers())? {
217 return Ok(Some(token));
218 }
219
220 if let Some(query) = request.uri().query() {
221 Ok(extract_access_token_from_query(query)?)
222 } else {
223 Ok(None)
224 }
225}
226
227fn extract_bearer_token_from_authorization_header(
229 headers: &HeaderMap,
230) -> Result<Option<String>, ExtractTokenError> {
231 const EXPECTED_START: &str = "bearer ";
232
233 let Some(value) = headers.get(header::AUTHORIZATION) else {
234 return Ok(None);
235 };
236
237 let value = value.to_str()?;
238
239 if value.len() < EXPECTED_START.len()
240 || !value[..EXPECTED_START.len()].eq_ignore_ascii_case(EXPECTED_START)
241 {
242 return Err(ExtractTokenError::InvalidAuthorizationScheme);
243 }
244
245 Ok(Some(value[EXPECTED_START.len()..].to_owned()))
246}
247
248fn extract_access_token_from_query(
250 query: &str,
251) -> Result<Option<String>, serde_html_form::de::Error> {
252 #[derive(Deserialize)]
253 struct AccessTokenDeHelper {
254 access_token: Option<String>,
255 }
256
257 serde_html_form::from_str::<AccessTokenDeHelper>(query).map(|helper| helper.access_token)
258}
259
/// Controls whether, and with which token, an outgoing request is authenticated.
///
/// The variants differ in which of the `get_*` accessors of this type hand the token out.
// Matching on this enum exhaustively is part of the public API, hence no `#[non_exhaustive]`.
#[allow(clippy::exhaustive_enums)]
#[derive(Debug, Clone, Copy)]
pub enum SendAccessToken<'a> {
    /// Send the given token only to endpoints that require one.
    IfRequired(&'a str),

    /// Send the given token with every request, even if the endpoint does not require one.
    Always(&'a str),

    /// Send the given token to endpoints that require a token, including those that require an
    /// appservice token.
    Appservice(&'a str),

    /// Never send a token.
    None,
}
280
281impl<'a> SendAccessToken<'a> {
282 pub fn get_required_for_endpoint(self) -> Option<&'a str> {
286 as_variant!(self, Self::IfRequired | Self::Appservice | Self::Always)
287 }
288
289 pub fn get_not_required_for_endpoint(self) -> Option<&'a str> {
293 as_variant!(self, Self::Always)
294 }
295
296 pub fn get_required_for_appservice(self) -> Option<&'a str> {
301 as_variant!(self, Self::Appservice | Self::Always)
302 }
303}
304
305#[derive(Debug, thiserror::Error)]
307#[non_exhaustive]
308pub enum AddRequiredTokenError {
309 #[error("no access token provided, but this endpoint requires one")]
311 MissingAccessToken,
312
313 #[error(transparent)]
315 IntoHeader(#[from] header::InvalidHeaderValue),
316}
317
318#[derive(Debug, thiserror::Error)]
320#[non_exhaustive]
321pub enum ExtractTokenError {
322 #[error("no access token found, but this endpoint requires one")]
324 MissingAccessToken,
325
326 #[error(transparent)]
328 FromHeader(#[from] header::ToStrError),
329
330 #[error("invalid authorization header scheme")]
332 InvalidAuthorizationScheme,
333
334 #[error("failed to deserialize query string: {0}")]
336 FromQuery(#[from] serde_html_form::de::Error),
337}