1use std::borrow::Cow;
6
7use as_variant::as_variant;
8use ruma_common::{
9 OwnedTransactionId,
10 serde::{Base64, JsonObject},
11};
12use ruma_macros::EventContent;
13use serde::{Deserialize, Deserializer, Serialize, de};
14use serde_json::{Value as JsonValue, from_value as from_json_value};
15
16use super::{
17 HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
18};
19use crate::relation::Reference;
20
/// The content of a to-device `m.key.verification.accept` event.
///
/// The event type string and event kind are declared by the `#[ruma_event]` attribute below.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.accept", kind = ToDevice)]
pub struct ToDeviceKeyVerificationAcceptEventContent {
    /// Identifier of the verification flow this event belongs to.
    ///
    /// NOTE(review): presumably must match the `transaction_id` used by the other events of the
    /// same verification flow — confirm against the Matrix spec.
    pub transaction_id: OwnedTransactionId,

    /// The method-specific content of the accept event.
    ///
    /// Flattened: its fields, including `method`, appear at the top level of the content next to
    /// `transaction_id`.
    #[serde(flatten)]
    pub method: AcceptMethod,
}
37
38impl ToDeviceKeyVerificationAcceptEventContent {
39 pub fn new(transaction_id: OwnedTransactionId, method: AcceptMethod) -> Self {
42 Self { transaction_id, method }
43 }
44}
45
/// The content of a message-like (in-room) `m.key.verification.accept` event.
///
/// The event type string and event kind are declared by the `#[ruma_event]` attribute below.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[ruma_event(type = "m.key.verification.accept", kind = MessageLike)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct KeyVerificationAcceptEventContent {
    /// The method-specific content of the accept event.
    ///
    /// Flattened: its fields, including `method`, appear at the top level of the content.
    #[serde(flatten)]
    pub method: AcceptMethod,

    /// Reference relation to another event, serialized under the `m.relates_to` key.
    ///
    /// NOTE(review): presumably points at the event that started this verification flow —
    /// confirm against the Matrix spec.
    #[serde(rename = "m.relates_to")]
    pub relates_to: Reference,
}
61
62impl KeyVerificationAcceptEventContent {
63 pub fn new(method: AcceptMethod, relates_to: Reference) -> Self {
66 Self { method, relates_to }
67 }
68}
69
/// The method-specific part of an `m.key.verification.accept` event.
///
/// Serialized `untagged`: each variant's own serialization emits the `method` field.
/// `SasV1Content` adds it through its serde `tag`, and the custom variant stores it as a field.
/// `Deserialize` is implemented manually and dispatches on the value of `method`.
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(untagged)]
pub enum AcceptMethod {
    /// The `m.sas.v1` verification method.
    SasV1(SasV1Content),

    /// Any method other than `m.sas.v1`, kept as its method name plus raw JSON fields.
    #[doc(hidden)]
    _Custom(_CustomAcceptMethodContent),
}
82
83impl AcceptMethod {
84 pub fn method(&self) -> &str {
86 match self {
87 Self::SasV1(_) => "m.sas.v1",
88 Self::_Custom(c) => &c.method,
89 }
90 }
91
92 pub fn data(&self) -> Cow<'_, JsonObject> {
100 fn serialize<T: Serialize>(obj: T) -> JsonObject {
101 match serde_json::to_value(obj).expect("accept method serialization to succeed") {
102 JsonValue::Object(mut obj) => {
103 obj.remove("method");
104 obj
105 }
106 _ => panic!("all accept method variants must serialize to objects"),
107 }
108 }
109
110 match self {
111 Self::SasV1(c) => Cow::Owned(serialize(c)),
112 Self::_Custom(c) => Cow::Borrowed(&c.data),
113 }
114 }
115}
116
117impl<'de> Deserialize<'de> for AcceptMethod {
118 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119 where
120 D: Deserializer<'de>,
121 {
122 let mut data = JsonObject::deserialize(deserializer)?;
123
124 let method = data
125 .get("method")
126 .and_then(|value| as_variant!(value, JsonValue::String))
127 .ok_or_else(|| de::Error::missing_field("method"))?;
128
129 match method.as_ref() {
130 "m.sas.v1" => from_json_value(data.into()).map(Self::SasV1),
131 _ => {
132 let method = as_variant!(
133 data.remove("method")
134 .expect("we already checked that the method field is present"),
135 JsonValue::String
136 )
137 .expect("we already checked that the method is a string");
138
139 Ok(Self::_Custom(_CustomAcceptMethodContent { method, data }))
140 }
141 }
142 .map_err(de::Error::custom)
143 }
144}
145
/// The payload of an [`AcceptMethod`] whose `method` is not one of the known methods.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct _CustomAcceptMethodContent {
    /// The value of the `method` field.
    method: String,

    /// The other fields of the content.
    ///
    /// `AcceptMethod`'s `Deserialize` impl populates this without the `method` key. It is
    /// flattened back next to `method` on serialization.
    #[serde(flatten)]
    data: JsonObject,
}
157
/// The method-specific content of an accept event using the `m.sas.v1` method.
///
/// Because of the serde `tag`/`rename` attributes below, serialization adds a
/// `"method": "m.sas.v1"` field alongside the struct's own fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(rename = "m.sas.v1", tag = "method")]
pub struct SasV1Content {
    /// The chosen key agreement protocol (e.g. `curve25519`).
    pub key_agreement_protocol: KeyAgreementProtocol,

    /// The chosen hash algorithm (e.g. `sha256`).
    pub hash: HashAlgorithm,

    /// The chosen message authentication code (e.g. `hkdf-hmac-sha256.v2`).
    pub message_authentication_code: MessageAuthenticationCode,

    /// The short authentication string methods to use (e.g. `["decimal"]`).
    pub short_authentication_string: Vec<ShortAuthenticationString>,

    /// The commitment value, serialized as base64. The tests show unpadded output, e.g.
    /// `aGVsbG8` for `hello`.
    ///
    /// NOTE(review): per the Matrix spec this is a hash over the accepting device's ephemeral
    /// public key and the start event content — confirm before relying on it.
    pub commitment: Base64,
}
187
/// All fields required to build a [`SasV1Content`], convertible via `From`/`Into`.
///
/// `SasV1Content` may be `#[non_exhaustive]`, which prevents struct-literal construction
/// outside this crate. This struct is deliberately exhaustive so callers can fill it in
/// directly.
#[derive(Debug)]
// Exhaustive on purpose: this type exists so it can be constructed with a struct literal.
#[allow(clippy::exhaustive_structs)]
pub struct SasV1ContentInit {
    /// The chosen key agreement protocol.
    pub key_agreement_protocol: KeyAgreementProtocol,

    /// The chosen hash algorithm.
    pub hash: HashAlgorithm,

    /// The chosen message authentication code.
    pub message_authentication_code: MessageAuthenticationCode,

    /// The short authentication string methods to use.
    pub short_authentication_string: Vec<ShortAuthenticationString>,

    /// The commitment value (see [`SasV1Content::commitment`]).
    pub commitment: Base64,
}
215
216impl From<SasV1ContentInit> for SasV1Content {
217 fn from(init: SasV1ContentInit) -> Self {
219 SasV1Content {
220 hash: init.hash,
221 key_agreement_protocol: init.key_agreement_protocol,
222 message_authentication_code: init.message_authentication_code,
223 short_authentication_string: init.short_authentication_string,
224 commitment: init.commitment,
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use assert_matches2::{assert_let, assert_matches};
    use ruma_common::{
        canonical_json::assert_to_canonical_json_eq,
        event_id,
        serde::{Base64, Raw},
    };
    use serde_json::{Value as JsonValue, from_value as from_json_value, json};

    use super::{
        AcceptMethod, HashAlgorithm, KeyAgreementProtocol, KeyVerificationAcceptEventContent,
        MessageAuthenticationCode, SasV1Content, ShortAuthenticationString,
        ToDeviceKeyVerificationAcceptEventContent,
    };
    use crate::{ToDeviceEvent, relation::Reference};

    // To-device content: the flattened `m.sas.v1` fields, including the `method` tag, sit next
    // to `transaction_id`. `b"hello"` encodes to the unpadded base64 `aGVsbG8`.
    #[test]
    fn to_device_serialization() {
        let key_verification_accept_content = ToDeviceKeyVerificationAcceptEventContent {
            transaction_id: "456".into(),
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "transaction_id": "456",
                "method": "m.sas.v1",
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
            }),
        );
    }

    // In-room content: same flattened method fields, plus an `m.relates_to` reference.
    #[test]
    fn in_room_serialization() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let key_verification_accept_content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "method": "m.sas.v1",
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
                "m.relates_to": {
                    "rel_type": "m.reference",
                    "event_id": event_id,
                },
            }),
        );
    }

    // Deserializes the bare content first, then a full to-device event that wraps the same
    // content.
    #[test]
    fn to_device_deserialization() {
        let json = json!({
            "transaction_id": "456",
            "commitment": "aGVsbG8",
            "method": "m.sas.v1",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"]
        });

        let content = from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.transaction_id, "456");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        // The same content wrapped in a complete to-device event with `type` and `sender`.
        let json = json!({
            "content": {
                "commitment": "aGVsbG8",
                "transaction_id": "456",
                "method": "m.sas.v1",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"]
            },
            "type": "m.key.verification.accept",
            "sender": "@example:localhost",
        });

        let ev = from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationAcceptEventContent>>(json)
            .unwrap();
        assert_eq!(ev.content.transaction_id, "456");
        assert_eq!(ev.sender, "@example:localhost");

        assert_matches!(ev.content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    // In-room content deserializes both the flattened method and the `m.relates_to` reference.
    #[test]
    fn in_room_deserialization() {
        let json = json!({
            "commitment": "aGVsbG8",
            "method": "m.sas.v1",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"],
            "m.relates_to": {
                "rel_type": "m.reference",
                "event_id": "$1598361704261elfgc:localhost",
            }
        });

        let content = from_json_value::<KeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.relates_to.event_id, "$1598361704261elfgc:localhost");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    // Serializing through `Raw` and deserializing back keeps the variant and the reference.
    #[test]
    fn in_room_serialization_roundtrip() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        let json_content = Raw::new(&content).unwrap();
        let deser_content = json_content.deserialize().unwrap();

        assert_matches!(deser_content.method, AcceptMethod::SasV1(_));
        assert_eq!(deser_content.relates_to.event_id, event_id);
    }

    // An unknown method round-trips through the custom variant. `data()` exposes only the
    // extra fields (not `method`), and re-serialization reproduces the input JSON.
    #[test]
    fn custom_to_device_serialization_roundtrip() {
        let json = json!({
            "transaction_id": "456",
            "method": "m.sas.custom",
            "test": "field",
        });

        let content =
            from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json.clone()).unwrap();

        assert_eq!(content.transaction_id, "456");
        assert_eq!(content.method.method(), "m.sas.custom");
        let data = &*content.method.data();
        assert_eq!(data.len(), 1);
        assert_let!(Some(JsonValue::String(value)) = data.get("test"));
        assert_eq!(value, "field");

        assert_to_canonical_json_eq!(content, json);
    }
}