Skip to main content

ruma_events/key/verification/
accept.rs

1//! Types for the [`m.key.verification.accept`] event.
2//!
3//! [`m.key.verification.accept`]: https://spec.matrix.org/v1.18/client-server-api/#mkeyverificationaccept
4
5use std::borrow::Cow;
6
7use as_variant::as_variant;
8use ruma_common::{
9    OwnedTransactionId,
10    serde::{Base64, JsonObject},
11};
12use ruma_macros::EventContent;
13use serde::{Deserialize, Deserializer, Serialize, de};
14use serde_json::{Value as JsonValue, from_value as from_json_value};
15
16use super::{
17    HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
18};
19use crate::relation::Reference;
20
/// The content of a to-device `m.key.verification.accept` event.
///
/// Accepts a previously sent `m.key.verification.start` message.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.accept", kind = ToDevice)]
pub struct ToDeviceKeyVerificationAcceptEventContent {
    /// An opaque identifier for the verification process.
    ///
    /// Must be the same as the one used for the `m.key.verification.start` message.
    pub transaction_id: OwnedTransactionId,

    /// The method specific content.
    // Flattened: the method's fields, including `method` itself, live at the top level of the
    // event content next to `transaction_id`.
    #[serde(flatten)]
    pub method: AcceptMethod,
}
37
38impl ToDeviceKeyVerificationAcceptEventContent {
39    /// Creates a new `ToDeviceKeyVerificationAcceptEventContent` with the given transaction ID and
40    /// method-specific content.
41    pub fn new(transaction_id: OwnedTransactionId, method: AcceptMethod) -> Self {
42        Self { transaction_id, method }
43    }
44}
45
/// The content of an in-room `m.key.verification.accept` event.
///
/// Accepts a previously sent `m.key.verification.start` message.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[ruma_event(type = "m.key.verification.accept", kind = MessageLike)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct KeyVerificationAcceptEventContent {
    /// The method specific content.
    // Flattened: the method's fields, including `method` itself, live at the top level of the
    // event content next to `m.relates_to`.
    #[serde(flatten)]
    pub method: AcceptMethod,

    /// Information about the related event.
    ///
    /// Serialized under the `m.relates_to` key.
    #[serde(rename = "m.relates_to")]
    pub relates_to: Reference,
}
61
62impl KeyVerificationAcceptEventContent {
63    /// Creates a new `KeyVerificationAcceptEventContent` with the given method-specific
64    /// content and reference.
65    pub fn new(method: AcceptMethod, relates_to: Reference) -> Self {
66        Self { method, relates_to }
67    }
68}
69
/// An enum representing the different method specific `m.key.verification.accept` content.
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
// Untagged serialization: each variant writes its own `method` field (`SasV1Content` through its
// serde `tag`, `_CustomAcceptMethodContent` through its `method` field). Deserialization is
// implemented by hand and dispatches on the `method` field instead.
#[serde(untagged)]
pub enum AcceptMethod {
    /// The `m.sas.v1` verification method.
    SasV1(SasV1Content),

    /// Any unknown accept method.
    #[doc(hidden)]
    _Custom(_CustomAcceptMethodContent),
}
82
83impl AcceptMethod {
84    /// The value of the `method` field.
85    pub fn method(&self) -> &str {
86        match self {
87            Self::SasV1(_) => "m.sas.v1",
88            Self::_Custom(c) => &c.method,
89        }
90    }
91
92    /// The data of this `AcceptMethod`.
93    ///
94    /// The returned JSON object won't contain the `method` field, use [`.method()`][Self::method]
95    /// to access it.
96    ///
97    /// Prefer to use the public variants of `AcceptMethod` where possible; this method is meant to
98    /// be used for custom methods only.
99    pub fn data(&self) -> Cow<'_, JsonObject> {
100        fn serialize<T: Serialize>(obj: T) -> JsonObject {
101            match serde_json::to_value(obj).expect("accept method serialization to succeed") {
102                JsonValue::Object(mut obj) => {
103                    obj.remove("method");
104                    obj
105                }
106                _ => panic!("all accept method variants must serialize to objects"),
107            }
108        }
109
110        match self {
111            Self::SasV1(c) => Cow::Owned(serialize(c)),
112            Self::_Custom(c) => Cow::Borrowed(&c.data),
113        }
114    }
115}
116
117impl<'de> Deserialize<'de> for AcceptMethod {
118    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119    where
120        D: Deserializer<'de>,
121    {
122        let mut data = JsonObject::deserialize(deserializer)?;
123
124        let method = data
125            .get("method")
126            .and_then(|value| as_variant!(value, JsonValue::String))
127            .ok_or_else(|| de::Error::missing_field("method"))?;
128
129        match method.as_ref() {
130            "m.sas.v1" => from_json_value(data.into()).map(Self::SasV1),
131            _ => {
132                let method = as_variant!(
133                    data.remove("method")
134                        .expect("we already checked that the method field is present"),
135                    JsonValue::String
136                )
137                .expect("we already checked that the method is a string");
138
139                Ok(Self::_Custom(_CustomAcceptMethodContent { method, data }))
140            }
141        }
142        .map_err(de::Error::custom)
143    }
144}
145
/// Method specific content of an unknown key verification method.
///
/// Created by the `AcceptMethod` deserializer for any `method` other than `m.sas.v1`. Its
/// fields are private; read them through [`AcceptMethod::method`] and [`AcceptMethod::data`].
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct _CustomAcceptMethodContent {
    /// The name of the method.
    method: String,

    /// The additional fields that the method contains.
    ///
    /// Does not contain the `method` key. It is flattened back into the top level on
    /// serialization.
    #[serde(flatten)]
    data: JsonObject,
}
157
/// The payload of an `m.key.verification.accept` event using the `m.sas.v1` method.
///
/// Can be created from a [`SasV1ContentInit`] via `From`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
// Serialized with an extra `method: "m.sas.v1"` field, produced by the serde tag below.
#[serde(rename = "m.sas.v1", tag = "method")]
pub struct SasV1Content {
    /// The key agreement protocol the device is choosing to use, out of the
    /// options in the `m.key.verification.start` message.
    pub key_agreement_protocol: KeyAgreementProtocol,

    /// The hash method the device is choosing to use, out of the options in the
    /// `m.key.verification.start` message.
    pub hash: HashAlgorithm,

    /// The message authentication code the device is choosing to use, out of
    /// the options in the `m.key.verification.start` message.
    pub message_authentication_code: MessageAuthenticationCode,

    /// The SAS methods both devices involved in the verification process
    /// understand.
    ///
    /// Must be a subset of the options in the `m.key.verification.start`
    /// message.
    pub short_authentication_string: Vec<ShortAuthenticationString>,

    /// The hash (encoded as unpadded base64) of the concatenation of the
    /// device's ephemeral public key (encoded as unpadded base64) and the
    /// canonical JSON representation of the `m.key.verification.start` message.
    pub commitment: Base64,
}
187
188/// Mandatory initial set of fields for creating an accept `SasV1Content`.
189#[derive(Debug)]
190#[allow(clippy::exhaustive_structs)]
191pub struct SasV1ContentInit {
192    /// The key agreement protocol the device is choosing to use, out of the
193    /// options in the `m.key.verification.start` message.
194    pub key_agreement_protocol: KeyAgreementProtocol,
195
196    /// The hash method the device is choosing to use, out of the options in the
197    /// `m.key.verification.start` message.
198    pub hash: HashAlgorithm,
199
200    /// The message authentication codes that the accepting device understands.
201    pub message_authentication_code: MessageAuthenticationCode,
202
203    /// The SAS methods both devices involved in the verification process
204    /// understand.
205    ///
206    /// Must be a subset of the options in the `m.key.verification.start`
207    /// message.
208    pub short_authentication_string: Vec<ShortAuthenticationString>,
209
210    /// The hash (encoded as unpadded base64) of the concatenation of the
211    /// device's ephemeral public key (encoded as unpadded base64) and the
212    /// canonical JSON representation of the `m.key.verification.start` message.
213    pub commitment: Base64,
214}
215
216impl From<SasV1ContentInit> for SasV1Content {
217    /// Creates a new `SasV1Content` from the given init struct.
218    fn from(init: SasV1ContentInit) -> Self {
219        SasV1Content {
220            hash: init.hash,
221            key_agreement_protocol: init.key_agreement_protocol,
222            message_authentication_code: init.message_authentication_code,
223            short_authentication_string: init.short_authentication_string,
224            commitment: init.commitment,
225        }
226    }
227}
228
#[cfg(test)]
mod tests {
    use assert_matches2::{assert_let, assert_matches};
    use ruma_common::{
        canonical_json::assert_to_canonical_json_eq,
        event_id,
        serde::{Base64, Raw},
    };
    use serde_json::{Value as JsonValue, from_value as from_json_value, json};

    use super::{
        AcceptMethod, HashAlgorithm, KeyAgreementProtocol, KeyVerificationAcceptEventContent,
        MessageAuthenticationCode, SasV1Content, ShortAuthenticationString,
        ToDeviceKeyVerificationAcceptEventContent,
    };
    use crate::{ToDeviceEvent, relation::Reference};

    #[test]
    fn to_device_serialization() {
        let key_verification_accept_content = ToDeviceKeyVerificationAcceptEventContent {
            transaction_id: "456".into(),
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "transaction_id": "456",
                "method": "m.sas.v1",
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
            }),
        );
    }

    #[test]
    fn in_room_serialization() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let key_verification_accept_content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "method": "m.sas.v1",
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
                "m.relates_to": {
                    "rel_type": "m.reference",
                    "event_id": event_id,
                },
            }),
        );
    }

    #[test]
    fn to_device_deserialization() {
        let json = json!({
            "transaction_id": "456",
            "commitment": "aGVsbG8",
            "method": "m.sas.v1",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"]
        });

        // Deserialize the content struct on its own first, to check that its `Deserialize`
        // implementation works outside of a full event.
        let content = from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.transaction_id, "456");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        let json = json!({
            "content": {
                "commitment": "aGVsbG8",
                "transaction_id": "456",
                "method": "m.sas.v1",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"]
            },
            "type": "m.key.verification.accept",
            "sender": "@example:localhost",
        });

        let ev = from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationAcceptEventContent>>(json)
            .unwrap();
        assert_eq!(ev.content.transaction_id, "456");
        assert_eq!(ev.sender, "@example:localhost");

        assert_matches!(ev.content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    #[test]
    fn in_room_deserialization() {
        let json = json!({
            "commitment": "aGVsbG8",
            "method": "m.sas.v1",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"],
            "m.relates_to": {
                "rel_type": "m.reference",
                "event_id": "$1598361704261elfgc:localhost",
            }
        });

        // Deserialize the content struct on its own, to check that its `Deserialize`
        // implementation works outside of a full event.
        let content = from_json_value::<KeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.relates_to.event_id, "$1598361704261elfgc:localhost");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    #[test]
    fn in_room_serialization_roundtrip() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        let json_content = Raw::new(&content).unwrap();
        let deser_content = json_content.deserialize().unwrap();

        assert_matches!(deser_content.method, AcceptMethod::SasV1(_));
        assert_eq!(deser_content.relates_to.event_id, event_id);
    }

    #[test]
    fn custom_to_device_serialization_roundtrip() {
        let json = json!({
            "transaction_id": "456",
            "method": "m.sas.custom",
            "test": "field",
        });

        let content =
            from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json.clone()).unwrap();

        assert_eq!(content.transaction_id, "456");
        assert_eq!(content.method.method(), "m.sas.custom");
        let data = &*content.method.data();
        assert_eq!(data.len(), 1);
        assert_let!(Some(JsonValue::String(value)) = data.get("test"));
        assert_eq!(value, "field");

        assert_to_canonical_json_eq!(content, json);
    }

    /// `data()` on a known variant must expose its fields without the `method` tag.
    #[test]
    fn sas_v1_data_omits_method() {
        let method = AcceptMethod::SasV1(SasV1Content {
            hash: HashAlgorithm::Sha256,
            key_agreement_protocol: KeyAgreementProtocol::Curve25519,
            message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
            short_authentication_string: vec![ShortAuthenticationString::Decimal],
            commitment: Base64::new(b"hello".to_vec()),
        });

        assert_eq!(method.method(), "m.sas.v1");
        let data = &*method.data();
        assert!(!data.contains_key("method"));
        assert_eq!(data.len(), 5);
        assert_let!(Some(JsonValue::String(commitment)) = data.get("commitment"));
        assert_eq!(commitment, "aGVsbG8");
    }

    /// Content without a `method` field cannot be deserialized.
    #[test]
    fn missing_method_is_rejected() {
        let json = json!({
            "transaction_id": "456",
            "test": "field",
        });

        assert!(from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).is_err());
    }

    /// Content whose `method` field is not a string cannot be deserialized.
    #[test]
    fn non_string_method_is_rejected() {
        let json = json!({
            "transaction_id": "456",
            "method": 1,
        });

        assert!(from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).is_err());
    }
}