1#![allow(clippy::exhaustive_structs)]
5
6use as_variant::as_variant;
7use http::{HeaderMap, header};
8use serde::Deserialize;
9
10pub trait AuthScheme: Sized {
12 type Input<'a>;
14
15 type AddAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
17
18 type Output;
20
21 type ExtractAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;
23
24 fn add_authentication<T: AsRef<[u8]>>(
29 request: &mut http::Request<T>,
30 input: Self::Input<'_>,
31 ) -> Result<(), Self::AddAuthenticationError>;
32
33 fn extract_authentication<T: AsRef<[u8]>>(
38 request: &http::Request<T>,
39 ) -> Result<Self::Output, Self::ExtractAuthenticationError>;
40}
41
42#[derive(Debug, Clone, Copy, Default)]
44pub struct NoAuthentication;
45
46impl AuthScheme for NoAuthentication {
47 type Input<'a> = ();
48 type AddAuthenticationError = std::convert::Infallible;
49 type Output = ();
50 type ExtractAuthenticationError = std::convert::Infallible;
51
52 fn add_authentication<T: AsRef<[u8]>>(
53 _request: &mut http::Request<T>,
54 _input: (),
55 ) -> Result<(), Self::AddAuthenticationError> {
56 Ok(())
57 }
58
59 fn extract_authentication<T: AsRef<[u8]>>(
61 _request: &http::Request<T>,
62 ) -> Result<(), Self::ExtractAuthenticationError> {
63 Ok(())
64 }
65}
66
67#[derive(Debug, Clone, Copy, Default)]
72pub struct NoAccessToken;
73
74impl AuthScheme for NoAccessToken {
75 type Input<'a> = SendAccessToken<'a>;
76 type AddAuthenticationError = header::InvalidHeaderValue;
77 type Output = ();
78 type ExtractAuthenticationError = std::convert::Infallible;
79
80 fn add_authentication<T: AsRef<[u8]>>(
81 request: &mut http::Request<T>,
82 access_token: SendAccessToken<'_>,
83 ) -> Result<(), Self::AddAuthenticationError> {
84 if let Some(access_token) = access_token.get_not_required_for_endpoint() {
85 add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
86 }
87
88 Ok(())
89 }
90
91 fn extract_authentication<T: AsRef<[u8]>>(
93 _request: &http::Request<T>,
94 ) -> Result<(), Self::ExtractAuthenticationError> {
95 Ok(())
96 }
97}
98
99#[derive(Debug, Clone, Copy, Default)]
104pub struct AccessToken;
105
106impl AuthScheme for AccessToken {
107 type Input<'a> = SendAccessToken<'a>;
108 type AddAuthenticationError = AddRequiredTokenError;
109 type Output = String;
111 type ExtractAuthenticationError = ExtractTokenError;
112
113 fn add_authentication<T: AsRef<[u8]>>(
114 request: &mut http::Request<T>,
115 access_token: SendAccessToken<'_>,
116 ) -> Result<(), Self::AddAuthenticationError> {
117 let token = access_token
118 .get_required_for_endpoint()
119 .ok_or(AddRequiredTokenError::MissingAccessToken)?;
120 Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
121 }
122
123 fn extract_authentication<T: AsRef<[u8]>>(
124 request: &http::Request<T>,
125 ) -> Result<String, Self::ExtractAuthenticationError> {
126 extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
127 }
128}
129
130#[derive(Debug, Clone, Copy, Default)]
135pub struct AccessTokenOptional;
136
137impl AuthScheme for AccessTokenOptional {
138 type Input<'a> = SendAccessToken<'a>;
139 type AddAuthenticationError = header::InvalidHeaderValue;
140 type Output = Option<String>;
142 type ExtractAuthenticationError = ExtractTokenError;
143
144 fn add_authentication<T: AsRef<[u8]>>(
145 request: &mut http::Request<T>,
146 access_token: SendAccessToken<'_>,
147 ) -> Result<(), Self::AddAuthenticationError> {
148 if let Some(access_token) = access_token.get_required_for_endpoint() {
149 add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
150 }
151
152 Ok(())
153 }
154
155 fn extract_authentication<T: AsRef<[u8]>>(
156 request: &http::Request<T>,
157 ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
158 extract_bearer_or_query_token(request)
159 }
160}
161
162#[derive(Debug, Clone, Copy, Default)]
168pub struct AppserviceToken;
169
170impl AuthScheme for AppserviceToken {
171 type Input<'a> = SendAccessToken<'a>;
172 type AddAuthenticationError = AddRequiredTokenError;
173 type Output = String;
175 type ExtractAuthenticationError = ExtractTokenError;
176
177 fn add_authentication<T: AsRef<[u8]>>(
178 request: &mut http::Request<T>,
179 access_token: SendAccessToken<'_>,
180 ) -> Result<(), Self::AddAuthenticationError> {
181 let token = access_token
182 .get_required_for_appservice()
183 .ok_or(AddRequiredTokenError::MissingAccessToken)?;
184 Ok(add_access_token_as_authorization_header(request.headers_mut(), token)?)
185 }
186
187 fn extract_authentication<T: AsRef<[u8]>>(
188 request: &http::Request<T>,
189 ) -> Result<String, Self::ExtractAuthenticationError> {
190 extract_bearer_or_query_token(request)?.ok_or(ExtractTokenError::MissingAccessToken)
191 }
192}
193
194#[derive(Debug, Clone, Copy, Default)]
200pub struct AppserviceTokenOptional;
201
202impl AuthScheme for AppserviceTokenOptional {
203 type Input<'a> = SendAccessToken<'a>;
204 type AddAuthenticationError = header::InvalidHeaderValue;
205 type Output = Option<String>;
207 type ExtractAuthenticationError = ExtractTokenError;
208
209 fn add_authentication<T: AsRef<[u8]>>(
210 request: &mut http::Request<T>,
211 access_token: SendAccessToken<'_>,
212 ) -> Result<(), Self::AddAuthenticationError> {
213 if let Some(access_token) = access_token.get_required_for_appservice() {
214 add_access_token_as_authorization_header(request.headers_mut(), access_token)?;
215 }
216
217 Ok(())
218 }
219
220 fn extract_authentication<T: AsRef<[u8]>>(
221 request: &http::Request<T>,
222 ) -> Result<Option<String>, Self::ExtractAuthenticationError> {
223 extract_bearer_or_query_token(request)
224 }
225}
226
227fn add_access_token_as_authorization_header(
229 headers: &mut HeaderMap,
230 token: &str,
231) -> Result<(), header::InvalidHeaderValue> {
232 headers.insert(header::AUTHORIZATION, format!("Bearer {token}").try_into()?);
233 Ok(())
234}
235
236fn extract_bearer_or_query_token<T>(
239 request: &http::Request<T>,
240) -> Result<Option<String>, ExtractTokenError> {
241 if let Some(token) = extract_bearer_token_from_authorization_header(request.headers())? {
242 return Ok(Some(token));
243 }
244
245 if let Some(query) = request.uri().query() {
246 Ok(extract_access_token_from_query(query)?)
247 } else {
248 Ok(None)
249 }
250}
251
252fn extract_bearer_token_from_authorization_header(
254 headers: &HeaderMap,
255) -> Result<Option<String>, ExtractTokenError> {
256 const EXPECTED_START: &str = "bearer ";
257
258 let Some(value) = headers.get(header::AUTHORIZATION) else {
259 return Ok(None);
260 };
261
262 let value = value.to_str()?;
263
264 if value.len() < EXPECTED_START.len()
265 || !value[..EXPECTED_START.len()].eq_ignore_ascii_case(EXPECTED_START)
266 {
267 return Err(ExtractTokenError::InvalidAuthorizationScheme);
268 }
269
270 Ok(Some(value[EXPECTED_START.len()..].to_owned()))
271}
272
273fn extract_access_token_from_query(
275 query: &str,
276) -> Result<Option<String>, serde_html_form::de::Error> {
277 #[derive(Deserialize)]
278 struct AccessTokenDeHelper {
279 access_token: Option<String>,
280 }
281
282 serde_html_form::from_str::<AccessTokenDeHelper>(query).map(|helper| helper.access_token)
283}
284
/// The access token a client hands to an [`AuthScheme`], together with the policy for which
/// endpoints it should be sent to.
#[derive(Debug, Clone, Copy)]
#[allow(clippy::exhaustive_enums)]
pub enum SendAccessToken<'a> {
    /// Send the token to endpoints that require one or accept one optionally
    /// ([`AccessToken`], [`AccessTokenOptional`]). It is not sent by [`NoAccessToken`] or the
    /// appservice schemes.
    IfRequired(&'a str),

    /// Send the token with every scheme that accepts a [`SendAccessToken`], including
    /// [`NoAccessToken`].
    Always(&'a str),

    /// Send the token to endpoints that require one or accept one optionally, including the
    /// appservice schemes ([`AppserviceToken`], [`AppserviceTokenOptional`]), but not with
    /// [`NoAccessToken`].
    Appservice(&'a str),

    /// Send no token. Schemes that require one fail with
    /// [`AddRequiredTokenError::MissingAccessToken`].
    None,
}
305
306impl<'a> SendAccessToken<'a> {
307 pub fn get_required_for_endpoint(self) -> Option<&'a str> {
311 as_variant!(self, Self::IfRequired | Self::Appservice | Self::Always)
312 }
313
314 pub fn get_not_required_for_endpoint(self) -> Option<&'a str> {
318 as_variant!(self, Self::Always)
319 }
320
321 pub fn get_required_for_appservice(self) -> Option<&'a str> {
326 as_variant!(self, Self::Appservice | Self::Always)
327 }
328}
329
330#[derive(Debug, thiserror::Error)]
332#[non_exhaustive]
333pub enum AddRequiredTokenError {
334 #[error("no access token provided, but this endpoint requires one")]
336 MissingAccessToken,
337
338 #[error(transparent)]
340 IntoHeader(#[from] header::InvalidHeaderValue),
341}
342
343#[derive(Debug, thiserror::Error)]
345#[non_exhaustive]
346pub enum ExtractTokenError {
347 #[error("no access token found, but this endpoint requires one")]
349 MissingAccessToken,
350
351 #[error(transparent)]
353 FromHeader(#[from] header::ToStrError),
354
355 #[error("invalid authorization header scheme")]
357 InvalidAuthorizationScheme,
358
359 #[error("failed to deserialize query string: {0}")]
361 FromQuery(#[from] serde_html_form::de::Error),
362}