1use std::{borrow::Cow, fmt};
6
7use as_variant::as_variant;
8use ruma_common::{
9 OwnedDeviceId, OwnedTransactionId,
10 serde::{Base64, JsonObject},
11};
12use ruma_macros::EventContent;
13use serde::{Deserialize, Deserializer, Serialize, de};
14use serde_json::{Value as JsonValue, from_value as from_json_value};
15
16use super::{
17 HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
18};
19use crate::relation::Reference;
20
/// The content of a to-device `m.key.verification.start` event.
///
/// The method-specific fields are flattened into the same JSON object as `from_device` and
/// `transaction_id`.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.start", kind = ToDevice)]
pub struct ToDeviceKeyVerificationStartEventContent {
    /// The `from_device` field of the event: the ID of the device the event originates from.
    pub from_device: OwnedDeviceId,

    /// The `transaction_id` field of the event.
    ///
    /// NOTE(review): per the Matrix specification this identifies the verification process the
    /// event belongs to; nothing in this file enforces that.
    pub transaction_id: OwnedTransactionId,

    /// The method-specific content, selected by the `method` field of the JSON object.
    #[serde(flatten)]
    pub method: StartMethod,
}
42
43impl ToDeviceKeyVerificationStartEventContent {
44 pub fn new(
47 from_device: OwnedDeviceId,
48 transaction_id: OwnedTransactionId,
49 method: StartMethod,
50 ) -> Self {
51 Self { from_device, transaction_id, method }
52 }
53}
54
/// The content of an in-room (message-like) `m.key.verification.start` event.
///
/// Unlike the to-device variant, this type has no `transaction_id` field; it carries an
/// `m.relates_to` reference instead.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.start", kind = MessageLike)]
pub struct KeyVerificationStartEventContent {
    /// The `from_device` field of the event: the ID of the device the event originates from.
    pub from_device: OwnedDeviceId,

    /// The method-specific content, selected by the `method` field of the JSON object.
    #[serde(flatten)]
    pub method: StartMethod,

    /// The reference relation, (de)serialized under the `m.relates_to` key.
    ///
    /// NOTE(review): per the Matrix specification this should point at the
    /// `m.key.verification.request` event of the flow; confirm against callers.
    #[serde(rename = "m.relates_to")]
    pub relates_to: Reference,
}
73
74impl KeyVerificationStartEventContent {
75 pub fn new(from_device: OwnedDeviceId, method: StartMethod, relates_to: Reference) -> Self {
78 Self { from_device, method, relates_to }
79 }
80}
81
/// The method-specific content of an `m.key.verification.start` event.
///
/// Serialization is untagged: each variant serializes exactly as its inner content does, and the
/// inner content types are responsible for emitting the `method` field themselves.
/// Deserialization is implemented by hand and dispatches on the `method` string.
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(untagged)]
pub enum StartMethod {
    /// Content for the `m.sas.v1` method.
    SasV1(SasV1Content),

    /// Content for the `m.reciprocate.v1` method.
    ReciprocateV1(ReciprocateV1Content),

    /// Content for any other `method` value; keeps the method name and the remaining fields.
    #[doc(hidden)]
    _Custom(_CustomStartMethodContent),
}
101
102impl StartMethod {
103 pub fn method(&self) -> &str {
105 match self {
106 Self::SasV1(_) => "m.sas.v1",
107 Self::ReciprocateV1(_) => "m.reciprocate.v1",
108 Self::_Custom(c) => &c.method,
109 }
110 }
111
112 pub fn data(&self) -> Cow<'_, JsonObject> {
120 fn serialize<T: Serialize>(obj: T) -> JsonObject {
121 match serde_json::to_value(obj).expect("start method serialization to succeed") {
122 JsonValue::Object(mut obj) => {
123 obj.remove("method");
124 obj
125 }
126 _ => panic!("all start method variants must serialize to objects"),
127 }
128 }
129
130 match self {
131 Self::SasV1(c) => Cow::Owned(serialize(c)),
132 Self::ReciprocateV1(c) => Cow::Owned(serialize(c)),
133 Self::_Custom(c) => Cow::Borrowed(&c.data),
134 }
135 }
136}
137
138impl<'de> Deserialize<'de> for StartMethod {
139 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140 where
141 D: Deserializer<'de>,
142 {
143 let mut data = JsonObject::deserialize(deserializer)?;
144
145 let method = data
146 .get("method")
147 .and_then(|value| as_variant!(value, JsonValue::String))
148 .ok_or_else(|| de::Error::missing_field("method"))?;
149
150 match method.as_ref() {
151 "m.sas.v1" => from_json_value(data.into()).map(Self::SasV1),
152 "m.reciprocate.v1" => from_json_value(data.into()).map(Self::ReciprocateV1),
153 _ => {
154 let method = as_variant!(
155 data.remove("method")
156 .expect("we already checked that the method field is present"),
157 JsonValue::String
158 )
159 .expect("we already checked that the method is a string");
160
161 Ok(Self::_Custom(_CustomStartMethodContent { method, data }))
162 }
163 }
164 .map_err(de::Error::custom)
165 }
166}
167
/// Method-specific content for a `method` value that has no dedicated type in this file.
///
/// Serializes as a single JSON object: the `method` string plus the flattened entries of `data`.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct _CustomStartMethodContent {
    /// The value of the `method` field.
    method: String,

    /// The remaining fields of the content, flattened next to `method` when serialized.
    #[serde(flatten)]
    data: JsonObject,
}
179
/// The method-specific content for the `m.reciprocate.v1` start method.
///
/// The serde attributes make the serialized object carry `"method": "m.reciprocate.v1"`
/// alongside the fields below. `Debug` is not derived for this type.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(rename = "m.reciprocate.v1", tag = "method")]
pub struct ReciprocateV1Content {
    /// The `secret` field of the event, carried as base64 in JSON.
    pub secret: Base64,
}
188
189impl ReciprocateV1Content {
190 pub fn new(secret: Base64) -> Self {
194 Self { secret }
195 }
196}
197
198impl fmt::Debug for ReciprocateV1Content {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 f.debug_struct("ReciprocateV1Content").finish_non_exhaustive()
201 }
202}
203
/// The method-specific content for the `m.sas.v1` start method.
///
/// The serde attributes make the serialized object carry `"method": "m.sas.v1"` alongside the
/// fields below. Outside of this crate the type is non-exhaustive by default, so it is built by
/// converting a `SasV1ContentInit`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(rename = "m.sas.v1", tag = "method")]
pub struct SasV1Content {
    /// The `key_agreement_protocols` list of the event.
    pub key_agreement_protocols: Vec<KeyAgreementProtocol>,

    /// The `hashes` list of the event.
    pub hashes: Vec<HashAlgorithm>,

    /// The `message_authentication_codes` list of the event.
    pub message_authentication_codes: Vec<MessageAuthenticationCode>,

    /// The `short_authentication_string` list of the event.
    pub short_authentication_string: Vec<ShortAuthenticationString>,
}
234
/// Initial set of fields of `SasV1Content`.
///
/// This struct is exhaustive so that it can be written as a struct literal and then converted
/// into a `SasV1Content` through the `From` implementation.
#[derive(Debug)]
// Exhaustiveness is intentional here: the type exists to be constructed with a struct literal.
#[allow(clippy::exhaustive_structs)]
pub struct SasV1ContentInit {
    /// Value for `SasV1Content::key_agreement_protocols`.
    pub key_agreement_protocols: Vec<KeyAgreementProtocol>,

    /// Value for `SasV1Content::hashes`.
    pub hashes: Vec<HashAlgorithm>,

    /// Value for `SasV1Content::message_authentication_codes`.
    pub message_authentication_codes: Vec<MessageAuthenticationCode>,

    /// Value for `SasV1Content::short_authentication_string`.
    pub short_authentication_string: Vec<ShortAuthenticationString>,
}
264
265impl From<SasV1ContentInit> for SasV1Content {
266 fn from(init: SasV1ContentInit) -> Self {
268 Self {
269 key_agreement_protocols: init.key_agreement_protocols,
270 hashes: init.hashes,
271 message_authentication_codes: init.message_authentication_codes,
272 short_authentication_string: init.short_authentication_string,
273 }
274 }
275}
276
// Serialization and deserialization tests for the to-device and in-room
// `m.key.verification.start` content types.
#[cfg(test)]
mod tests {
    use assert_matches2::{assert_let, assert_matches};
    use ruma_common::{canonical_json::assert_to_canonical_json_eq, event_id, serde::Base64};
    use serde_json::{Value as JsonValue, from_value as from_json_value, json};

    use super::{
        HashAlgorithm, KeyAgreementProtocol, KeyVerificationStartEventContent,
        MessageAuthenticationCode, ReciprocateV1Content, SasV1ContentInit,
        ShortAuthenticationString, StartMethod, ToDeviceKeyVerificationStartEventContent,
    };
    use crate::{ToDeviceEvent, relation::Reference};

    /// To-device content serializes to a flat object for both the SAS and reciprocate methods.
    #[test]
    fn to_device_serialization() {
        // SAS content: the method fields sit next to `from_device` and `transaction_id`.
        let key_verification_start_content = ToDeviceKeyVerificationStartEventContent {
            from_device: "123".into(),
            transaction_id: "456".into(),
            method: StartMethod::SasV1(
                SasV1ContentInit {
                    hashes: vec![HashAlgorithm::Sha256],
                    key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519],
                    message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256V2],
                    short_authentication_string: vec![ShortAuthenticationString::Decimal],
                }
                .into(),
            ),
        };

        assert_to_canonical_json_eq!(
            key_verification_start_content,
            json!({
                "from_device": "123",
                "transaction_id": "456",
                "method": "m.sas.v1",
                "key_agreement_protocols": ["curve25519"],
                "hashes": ["sha256"],
                "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
                "short_authentication_string": ["decimal"],
            }),
        );

        let secret = Base64::new(b"This is a secret to everybody".to_vec());

        // Reciprocate content: only `secret` is added besides `method`.
        let key_verification_start_content = ToDeviceKeyVerificationStartEventContent {
            from_device: "123".into(),
            transaction_id: "456".into(),
            method: StartMethod::ReciprocateV1(ReciprocateV1Content::new(secret.clone())),
        };

        assert_to_canonical_json_eq!(
            key_verification_start_content,
            json!({
                "from_device": "123",
                "method": "m.reciprocate.v1",
                "secret": secret,
                "transaction_id": "456",
            }),
        );
    }

    /// In-room content serializes with an `m.relates_to` reference and no `transaction_id`.
    #[test]
    fn in_room_serialization() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        // SAS content.
        let key_verification_start_content = KeyVerificationStartEventContent {
            from_device: "123".into(),
            relates_to: Reference { event_id: event_id.to_owned() },
            method: StartMethod::SasV1(
                SasV1ContentInit {
                    hashes: vec![HashAlgorithm::Sha256],
                    key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519],
                    message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256V2],
                    short_authentication_string: vec![ShortAuthenticationString::Decimal],
                }
                .into(),
            ),
        };

        assert_to_canonical_json_eq!(
            key_verification_start_content,
            json!({
                "from_device": "123",
                "method": "m.sas.v1",
                "key_agreement_protocols": ["curve25519"],
                "hashes": ["sha256"],
                "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
                "short_authentication_string": ["decimal"],
                "m.relates_to": {
                    "rel_type": "m.reference",
                    "event_id": event_id,
                },
            }),
        );

        let secret = Base64::new(b"This is a secret to everybody".to_vec());

        // Reciprocate content.
        let key_verification_start_content = KeyVerificationStartEventContent {
            from_device: "123".into(),
            relates_to: Reference { event_id: event_id.to_owned() },
            method: StartMethod::ReciprocateV1(ReciprocateV1Content::new(secret.clone())),
        };

        assert_to_canonical_json_eq!(
            key_verification_start_content,
            json!({
                "from_device": "123",
                "method": "m.reciprocate.v1",
                "secret": secret,
                "m.relates_to": {
                    "rel_type": "m.reference",
                    "event_id": event_id,
                },
            }),
        );
    }

    /// To-device content deserializes both on its own and wrapped in a full `ToDeviceEvent`.
    #[test]
    fn to_device_deserialization() {
        // Bare content with the SAS method.
        let json = json!({
            "from_device": "123",
            "transaction_id": "456",
            "method": "m.sas.v1",
            "hashes": ["sha256"],
            "key_agreement_protocols": ["curve25519"],
            "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
            "short_authentication_string": ["decimal"]
        });

        let content = from_json_value::<ToDeviceKeyVerificationStartEventContent>(json).unwrap();
        assert_eq!(content.from_device, "123");
        assert_eq!(content.transaction_id, "456");

        assert_matches!(content.method, StartMethod::SasV1(sas));
        assert_eq!(sas.hashes, vec![HashAlgorithm::Sha256]);
        assert_eq!(sas.key_agreement_protocols, vec![KeyAgreementProtocol::Curve25519]);
        assert_eq!(
            sas.message_authentication_codes,
            vec![MessageAuthenticationCode::HkdfHmacSha256V2]
        );
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        // The same SAS content wrapped in a full to-device event.
        let json = json!({
            "content": {
                "from_device": "123",
                "transaction_id": "456",
                "method": "m.sas.v1",
                "key_agreement_protocols": ["curve25519"],
                "hashes": ["sha256"],
                "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
                "short_authentication_string": ["decimal"]
            },
            "type": "m.key.verification.start",
            "sender": "@example:localhost",
        });

        let ev = from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationStartEventContent>>(json)
            .unwrap();
        assert_eq!(ev.sender, "@example:localhost");
        assert_eq!(ev.content.from_device, "123");
        assert_eq!(ev.content.transaction_id, "456");

        assert_matches!(ev.content.method, StartMethod::SasV1(sas));
        assert_eq!(sas.hashes, vec![HashAlgorithm::Sha256]);
        assert_eq!(sas.key_agreement_protocols, vec![KeyAgreementProtocol::Curve25519]);
        assert_eq!(
            sas.message_authentication_codes,
            vec![MessageAuthenticationCode::HkdfHmacSha256V2]
        );
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        // A full to-device event using the reciprocate method.
        let json = json!({
            "content": {
                "from_device": "123",
                "method": "m.reciprocate.v1",
                "secret": "c2VjcmV0Cg",
                "transaction_id": "456",
            },
            "type": "m.key.verification.start",
            "sender": "@example:localhost",
        });

        let ev = from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationStartEventContent>>(json)
            .unwrap();
        assert_eq!(ev.sender, "@example:localhost");
        assert_eq!(ev.content.from_device, "123");
        assert_eq!(ev.content.transaction_id, "456");

        // The secret must re-encode to exactly the string it was parsed from.
        assert_matches!(ev.content.method, StartMethod::ReciprocateV1(reciprocate));
        assert_eq!(reciprocate.secret.encode(), "c2VjcmV0Cg");
    }

    /// In-room content deserializes for both the SAS and reciprocate methods.
    #[test]
    fn in_room_deserialization() {
        // SAS method.
        let json = json!({
            "from_device": "123",
            "method": "m.sas.v1",
            "hashes": ["sha256"],
            "key_agreement_protocols": ["curve25519"],
            "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
            "short_authentication_string": ["decimal"],
            "m.relates_to": {
                "rel_type": "m.reference",
                "event_id": "$1598361704261elfgc:localhost",
            }
        });

        let content = from_json_value::<KeyVerificationStartEventContent>(json).unwrap();
        assert_eq!(content.from_device, "123");
        assert_eq!(content.relates_to.event_id, "$1598361704261elfgc:localhost");

        assert_matches!(content.method, StartMethod::SasV1(sas));
        assert_eq!(sas.hashes, vec![HashAlgorithm::Sha256]);
        assert_eq!(sas.key_agreement_protocols, vec![KeyAgreementProtocol::Curve25519]);
        assert_eq!(
            sas.message_authentication_codes,
            vec![MessageAuthenticationCode::HkdfHmacSha256V2]
        );
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        // Reciprocate method.
        let json = json!({
            "from_device": "123",
            "method": "m.reciprocate.v1",
            "secret": "c2VjcmV0Cg",
            "m.relates_to": {
                "rel_type": "m.reference",
                "event_id": "$1598361704261elfgc:localhost",
            }
        });

        let content = from_json_value::<KeyVerificationStartEventContent>(json).unwrap();
        assert_eq!(content.from_device, "123");
        assert_eq!(content.relates_to.event_id, "$1598361704261elfgc:localhost");

        assert_matches!(content.method, StartMethod::ReciprocateV1(reciprocate));
        assert_eq!(reciprocate.secret.encode(), "c2VjcmV0Cg");
    }

    /// An unknown method survives a deserialize/serialize round trip with its extra fields.
    #[test]
    fn custom_to_device_serialization_roundtrip() {
        let json = json!({
            "from_device": "123",
            "transaction_id": "456",
            "method": "m.sas.custom",
            "test": "field",
        });

        let content =
            from_json_value::<ToDeviceKeyVerificationStartEventContent>(json.clone()).unwrap();

        assert_eq!(content.from_device, "123");
        assert_eq!(content.transaction_id, "456");
        assert_eq!(content.method.method(), "m.sas.custom");
        // `data()` must expose only the extra field, not `method`.
        let data = &*content.method.data();
        assert_eq!(data.len(), 1);
        assert_let!(Some(JsonValue::String(value)) = data.get("test"));
        assert_eq!(value, "field");

        assert_to_canonical_json_eq!(content, json);
    }
}