1use std::{fmt, str::FromStr};
4
5use headers::authorization::Credentials;
6use http::HeaderValue;
7use http_auth::ChallengeParser;
8use ruma_common::{
9 http_headers::quote_ascii_string_if_required,
10 serde::{Base64, Base64DecodeError},
11 IdParseError, OwnedServerName, OwnedServerSigningKeyId,
12};
13use thiserror::Error;
14use tracing::debug;
15
16#[derive(Clone)]
21#[non_exhaustive]
22pub struct XMatrix {
23 pub origin: OwnedServerName,
25 pub destination: Option<OwnedServerName>,
32 pub key: OwnedServerSigningKeyId,
35 pub sig: Base64,
37}
38
39impl XMatrix {
40 pub fn new(
42 origin: OwnedServerName,
43 destination: OwnedServerName,
44 key: OwnedServerSigningKeyId,
45 sig: Base64,
46 ) -> Self {
47 Self { origin, destination: Some(destination), key, sig }
48 }
49
50 pub fn parse(s: impl AsRef<str>) -> Result<Self, XMatrixParseError> {
52 let parser = ChallengeParser::new(s.as_ref());
53 let mut xmatrix = None;
54
55 for challenge in parser {
56 let challenge = challenge?;
57
58 if challenge.scheme.eq_ignore_ascii_case(XMatrix::SCHEME) {
59 xmatrix = Some(challenge);
60 break;
61 }
62 }
63
64 let Some(xmatrix) = xmatrix else {
65 return Err(XMatrixParseError::NotFound);
66 };
67
68 let mut origin = None;
69 let mut destination = None;
70 let mut key = None;
71 let mut sig = None;
72
73 for (name, value) in xmatrix.params {
74 if name.eq_ignore_ascii_case("origin") {
75 if origin.is_some() {
76 return Err(XMatrixParseError::DuplicateParameter("origin".to_owned()));
77 } else {
78 origin = Some(OwnedServerName::try_from(value.to_unescaped())?);
79 }
80 } else if name.eq_ignore_ascii_case("destination") {
81 if destination.is_some() {
82 return Err(XMatrixParseError::DuplicateParameter("destination".to_owned()));
83 } else {
84 destination = Some(OwnedServerName::try_from(value.to_unescaped())?);
85 }
86 } else if name.eq_ignore_ascii_case("key") {
87 if key.is_some() {
88 return Err(XMatrixParseError::DuplicateParameter("key".to_owned()));
89 } else {
90 key = Some(OwnedServerSigningKeyId::try_from(value.to_unescaped())?);
91 }
92 } else if name.eq_ignore_ascii_case("sig") {
93 if sig.is_some() {
94 return Err(XMatrixParseError::DuplicateParameter("sig".to_owned()));
95 } else {
96 sig = Some(Base64::parse(value.to_unescaped())?);
97 }
98 } else {
99 debug!("Unknown parameter {name} in X-Matrix Authorization header");
100 }
101 }
102
103 Ok(Self {
104 origin: origin
105 .ok_or_else(|| XMatrixParseError::MissingParameter("origin".to_owned()))?,
106 destination,
107 key: key.ok_or_else(|| XMatrixParseError::MissingParameter("key".to_owned()))?,
108 sig: sig.ok_or_else(|| XMatrixParseError::MissingParameter("sig".to_owned()))?,
109 })
110 }
111}
112
113impl fmt::Debug for XMatrix {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 f.debug_struct("XMatrix")
116 .field("origin", &self.origin)
117 .field("destination", &self.destination)
118 .field("key", &self.key)
119 .finish_non_exhaustive()
120 }
121}
122
123impl fmt::Display for XMatrix {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 let Self { origin, destination, key, sig } = self;
126
127 let origin = quote_ascii_string_if_required(origin.as_str());
128 let key = quote_ascii_string_if_required(key.as_str());
129 let sig = sig.encode();
130 let sig = quote_ascii_string_if_required(&sig);
131
132 write!(f, r#"{} "#, Self::SCHEME)?;
133
134 if let Some(destination) = destination {
135 let destination = quote_ascii_string_if_required(destination.as_str());
136 write!(f, r#"destination={destination},"#)?;
137 }
138
139 write!(f, "key={key},origin={origin},sig={sig}")
140 }
141}
142
143impl FromStr for XMatrix {
144 type Err = XMatrixParseError;
145
146 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 Self::parse(s)
148 }
149}
150
151impl TryFrom<&HeaderValue> for XMatrix {
152 type Error = XMatrixParseError;
153
154 fn try_from(value: &HeaderValue) -> Result<Self, Self::Error> {
155 Self::parse(value.to_str()?)
156 }
157}
158
159impl From<&XMatrix> for HeaderValue {
160 fn from(value: &XMatrix) -> Self {
161 value.to_string().try_into().expect("header format is static")
162 }
163}
164
165impl Credentials for XMatrix {
166 const SCHEME: &'static str = "X-Matrix";
167
168 fn decode(value: &HeaderValue) -> Option<Self> {
169 value.try_into().ok()
170 }
171
172 fn encode(&self) -> HeaderValue {
173 self.into()
174 }
175}
176
177#[derive(Debug, Error)]
179#[non_exhaustive]
180pub enum XMatrixParseError {
181 #[error(transparent)]
183 ToStr(#[from] http::header::ToStrError),
184
185 #[error("{0}")]
187 ParseStr(String),
188
189 #[error("X-Matrix credentials not found")]
191 NotFound,
192
193 #[error(transparent)]
195 ParseId(#[from] IdParseError),
196
197 #[error(transparent)]
199 ParseBase64(#[from] Base64DecodeError),
200
201 #[error("missing parameter '{0}'")]
203 MissingParameter(String),
204
205 #[error("duplicate parameter '{0}'")]
207 DuplicateParameter(String),
208}
209
210impl<'a> From<http_auth::parser::Error<'a>> for XMatrixParseError {
211 fn from(value: http_auth::parser::Error<'a>) -> Self {
212 Self::ParseStr(value.to_string())
213 }
214}
215
#[cfg(test)]
mod tests {
    use headers::{authorization::Credentials, HeaderValue};
    use ruma_common::{serde::Base64, OwnedServerName};

    use super::{XMatrix, XMatrixParseError};

    #[test]
    fn xmatrix_auth_pre_1_3() {
        let header = HeaderValue::from_static(
            "X-Matrix origin=\"origin.hs.example.com\",key=\"ed25519:key1\",sig=\"dGVzdA==\"",
        );
        let origin = "origin.hs.example.com".try_into().unwrap();
        let key = "ed25519:key1".try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::try_from(&header).unwrap();
        assert_eq!(credentials.origin, origin);
        assert_eq!(credentials.destination, None);
        assert_eq!(credentials.key, key);
        assert_eq!(credentials.sig, sig);

        let credentials = XMatrix { origin, destination: None, key, sig };

        assert_eq!(
            credentials.encode(),
            "X-Matrix key=\"ed25519:key1\",origin=origin.hs.example.com,sig=dGVzdA"
        );
    }

    #[test]
    fn xmatrix_auth_1_3() {
        let header = HeaderValue::from_static("X-Matrix origin=\"origin.hs.example.com\",destination=\"destination.hs.example.com\",key=\"ed25519:key1\",sig=\"dGVzdA==\"");
        let origin: OwnedServerName = "origin.hs.example.com".try_into().unwrap();
        let destination: OwnedServerName = "destination.hs.example.com".try_into().unwrap();
        let key = "ed25519:key1".try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::try_from(&header).unwrap();
        assert_eq!(credentials.origin, origin);
        assert_eq!(credentials.destination, Some(destination.clone()));
        assert_eq!(credentials.key, key);
        assert_eq!(credentials.sig, sig);

        let credentials = XMatrix::new(origin, destination, key, sig);

        assert_eq!(credentials.encode(), "X-Matrix destination=destination.hs.example.com,key=\"ed25519:key1\",origin=origin.hs.example.com,sig=dGVzdA");
    }

    #[test]
    fn xmatrix_quoting() {
        let header = HeaderValue::from_static(
            r#"X-Matrix origin="example.com:1234",key="abc\"def\\:ghi",sig=dGVzdA,"#,
        );

        let origin: OwnedServerName = "example.com:1234".try_into().unwrap();
        let key = r#"abc"def\:ghi"#.try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::try_from(&header).unwrap();
        assert_eq!(credentials.origin, origin);
        assert_eq!(credentials.destination, None);
        assert_eq!(credentials.key, key);
        assert_eq!(credentials.sig, sig);

        let credentials = XMatrix { origin, destination: None, key, sig };

        assert_eq!(
            credentials.encode(),
            r#"X-Matrix key="abc\"def\\:ghi",origin="example.com:1234",sig=dGVzdA"#
        );
    }

    #[test]
    fn xmatrix_auth_1_3_with_extra_spaces() {
        let header = HeaderValue::from_static("X-Matrix origin=\"origin.hs.example.com\" , destination=\"destination.hs.example.com\",key=\"ed25519:key1\", sig=\"dGVzdA\"");
        let credentials = XMatrix::try_from(&header).unwrap();
        let sig = Base64::new(b"test".to_vec());

        assert_eq!(credentials.origin, "origin.hs.example.com");
        assert_eq!(credentials.destination.unwrap(), "destination.hs.example.com");
        assert_eq!(credentials.key, "ed25519:key1");
        assert_eq!(credentials.sig, sig);
    }

    // The scheme and the parameter names are both compared ASCII case-insensitively.
    #[test]
    fn xmatrix_case_insensitive_scheme_and_names() {
        let credentials = XMatrix::parse(
            "x-matrix ORIGIN=\"origin.hs.example.com\",Key=\"ed25519:key1\",SIG=\"dGVzdA\"",
        )
        .unwrap();

        assert_eq!(credentials.origin, "origin.hs.example.com");
        assert_eq!(credentials.destination, None);
        assert_eq!(credentials.key, "ed25519:key1");
        assert_eq!(credentials.sig, Base64::new(b"test".to_vec()));
    }

    #[test]
    fn xmatrix_duplicate_parameter() {
        let err = XMatrix::parse(
            "X-Matrix origin=\"origin.hs.example.com\",origin=\"other.hs.example.com\",key=\"ed25519:key1\",sig=\"dGVzdA\"",
        )
        .unwrap_err();

        assert!(
            matches!(err, XMatrixParseError::DuplicateParameter(ref name) if name == "origin"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn xmatrix_missing_parameter() {
        let err = XMatrix::parse("X-Matrix origin=\"origin.hs.example.com\",key=\"ed25519:key1\"")
            .unwrap_err();

        assert!(
            matches!(err, XMatrixParseError::MissingParameter(ref name) if name == "sig"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn xmatrix_not_found() {
        let err = XMatrix::parse("Basic realm=\"example\"").unwrap_err();
        assert!(matches!(err, XMatrixParseError::NotFound), "unexpected error: {err:?}");

        let header = HeaderValue::from_static("Basic realm=\"example\"");
        assert!(XMatrix::decode(&header).is_none());
    }

    #[test]
    fn xmatrix_invalid_origin() {
        let err = XMatrix::parse(
            "X-Matrix origin=\"not a server name\",key=\"ed25519:key1\",sig=\"dGVzdA\"",
        )
        .unwrap_err();

        assert!(matches!(err, XMatrixParseError::ParseId(_)), "unexpected error: {err:?}");
    }

    // Output of `Display` must be accepted by `parse` and yield the same field values.
    #[test]
    fn xmatrix_display_round_trip() {
        let origin: OwnedServerName = "origin.hs.example.com".try_into().unwrap();
        let destination: OwnedServerName = "destination.hs.example.com".try_into().unwrap();
        let key = "ed25519:key1".try_into().unwrap();
        let sig = Base64::new(b"test".to_vec());
        let credentials = XMatrix::new(origin.clone(), destination.clone(), key, sig.clone());

        let parsed: XMatrix = credentials.to_string().parse().unwrap();

        assert_eq!(parsed.origin, origin);
        assert_eq!(parsed.destination, Some(destination));
        assert_eq!(parsed.key, credentials.key);
        assert_eq!(parsed.sig, sig);
    }
}