1use std::borrow::Cow;
6
7use as_variant::as_variant;
8use ruma_common::{
9 EventEncryptionAlgorithm, OwnedRoomId,
10 serde::{Base64, JsonObject, from_raw_json_value},
11};
12use ruma_macros::{EventContent, StringEnum};
13use serde::{Deserialize, Serialize, de};
14use serde_json::{Value as JsonValue, value::RawValue as RawJsonValue};
15
16use crate::PrivOwnedStr;
17
/// The content of an `m.room_key.withheld` to-device event.
///
/// Reports that a room key is being withheld from the recipient. It carries:
/// - the `algorithm` of the key,
/// - a machine-readable `code`, plus an optional free-form `reason`, explaining why,
/// - the `sender_key`.
#[derive(Clone, Debug, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.room_key.withheld", kind = ToDevice)]
pub struct ToDeviceRoomKeyWithheldEventContent {
    /// The encryption algorithm of the withheld key.
    pub algorithm: EventEncryptionAlgorithm,

    /// Why the key was withheld.
    ///
    /// This field is flattened. The `code` tag and any session data (`room_id`, `session_id`)
    /// appear as top-level fields of the content rather than in a nested object.
    #[serde(flatten)]
    pub code: RoomKeyWithheldCodeInfo,

    /// An optional free-form explanation. It is omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,

    /// The sender's key, base64-encoded in JSON.
    ///
    /// NOTE(review): presumably the sender's Curve25519 identity key; confirm against the spec.
    pub sender_key: Base64,
}
45
46impl ToDeviceRoomKeyWithheldEventContent {
47 pub fn new(
50 algorithm: EventEncryptionAlgorithm,
51 code: RoomKeyWithheldCodeInfo,
52 sender_key: Base64,
53 ) -> Self {
54 Self { algorithm, code, reason: None, sender_key }
55 }
56}
57
58impl<'de> Deserialize<'de> for ToDeviceRoomKeyWithheldEventContent {
59 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60 where
61 D: de::Deserializer<'de>,
62 {
63 #[derive(Deserialize)]
64 struct ToDeviceRoomKeyWithheldEventContentDeHelper {
65 algorithm: EventEncryptionAlgorithm,
66 reason: Option<String>,
67 sender_key: Base64,
68 }
69
70 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
71
72 let ToDeviceRoomKeyWithheldEventContentDeHelper { algorithm, reason, sender_key } =
73 from_raw_json_value(&json)?;
74 let code = from_raw_json_value(&json)?;
75
76 Ok(Self { algorithm, code, reason, sender_key })
77 }
78}
79
/// The machine-readable reason a room key was withheld, together with any data attached to it.
///
/// Serialized as an internally tagged enum keyed by the `code` field:
/// - Known codes other than `m.no_olm` carry a [`RoomKeyWithheldSessionData`], whose fields are
///   serialized next to `code`.
/// - Unrecognized codes are kept in a hidden custom variant.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "code")]
pub enum RoomKeyWithheldCodeInfo {
    /// Serialized with `"code": "m.blacklisted"`, plus the session data.
    #[serde(rename = "m.blacklisted")]
    Blacklisted(Box<RoomKeyWithheldSessionData>),

    /// Serialized with `"code": "m.unverified"`, plus the session data.
    #[serde(rename = "m.unverified")]
    Unverified(Box<RoomKeyWithheldSessionData>),

    /// Serialized with `"code": "m.unauthorised"` (British spelling on the wire), plus the
    /// session data.
    #[serde(rename = "m.unauthorised")]
    Unauthorized(Box<RoomKeyWithheldSessionData>),

    /// Serialized with `"code": "m.unavailable"`, plus the session data.
    #[serde(rename = "m.unavailable")]
    Unavailable(Box<RoomKeyWithheldSessionData>),

    /// Serialized with `"code": "m.no_olm"`. Carries no session data.
    #[serde(rename = "m.no_olm")]
    NoOlm,

    // Fallback for codes not listed above. It is untagged because the inner struct serializes
    // its own `code` string alongside the remaining fields.
    #[doc(hidden)]
    #[serde(untagged)]
    _Custom(Box<CustomRoomKeyWithheldCodeInfo>),
}
123
124impl RoomKeyWithheldCodeInfo {
125 pub fn code(&self) -> RoomKeyWithheldCode {
127 match self {
128 Self::Blacklisted(_) => RoomKeyWithheldCode::Blacklisted,
129 Self::Unverified(_) => RoomKeyWithheldCode::Unverified,
130 Self::Unauthorized(_) => RoomKeyWithheldCode::Unauthorized,
131 Self::Unavailable(_) => RoomKeyWithheldCode::Unavailable,
132 Self::NoOlm => RoomKeyWithheldCode::NoOlm,
133 Self::_Custom(info) => info.code.as_str().into(),
134 }
135 }
136}
137
138impl<'de> Deserialize<'de> for RoomKeyWithheldCodeInfo {
139 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140 where
141 D: de::Deserializer<'de>,
142 {
143 #[derive(Debug, Deserialize)]
144 struct ExtractCode<'a> {
145 #[serde(borrow)]
146 code: Cow<'a, str>,
147 }
148
149 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
150 let ExtractCode { code } = from_raw_json_value(&json)?;
151
152 Ok(match code.as_ref() {
153 "m.blacklisted" => Self::Blacklisted(from_raw_json_value(&json)?),
154 "m.unverified" => Self::Unverified(from_raw_json_value(&json)?),
155 "m.unauthorised" => Self::Unauthorized(from_raw_json_value(&json)?),
156 "m.unavailable" => Self::Unavailable(from_raw_json_value(&json)?),
157 "m.no_olm" => Self::NoOlm,
158 _ => {
159 let mut data = from_raw_json_value::<JsonObject, _>(&json)?;
160
161 data.remove("algorithm");
165 data.remove("sender_key");
166 data.remove("reason");
167
168 let code = as_variant!(
169 data.remove("code").expect("we already checked that the code field is present"),
170 JsonValue::String
171 )
172 .expect("we already checked that the code is a string");
173
174 Self::_Custom(CustomRoomKeyWithheldCodeInfo { code, data }.into())
175 }
176 })
177 }
178}
179
/// Identifies the megolm session whose room key was withheld.
///
/// Carried by the `m.blacklisted`, `m.unverified`, `m.unauthorised` and `m.unavailable`
/// variants of [`RoomKeyWithheldCodeInfo`]. Its fields are flattened into the event content
/// on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct RoomKeyWithheldSessionData {
    /// The ID of the room of the withheld session.
    pub room_id: OwnedRoomId,

    /// The ID of the withheld session.
    pub session_id: String,
}
190
191impl RoomKeyWithheldSessionData {
192 pub fn new(room_id: OwnedRoomId, session_id: String) -> Self {
194 Self { room_id, session_id }
195 }
196}
197
/// The payload of a withheld code that is not one of the known `m.*` codes.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct CustomRoomKeyWithheldCodeInfo {
    /// The raw `code` string.
    code: String,

    /// The remaining fields of the content object.
    ///
    /// When deserialized, `algorithm`, `sender_key`, `reason` and `code` have been removed.
    /// On serialization these fields are flattened back next to `code`.
    #[serde(flatten)]
    data: JsonObject,
}
209
210#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
212#[derive(Clone, StringEnum)]
213#[ruma_enum(rename_all(prefix = "m.", rule = "snake_case"))]
214#[non_exhaustive]
215pub enum RoomKeyWithheldCode {
216 Blacklisted,
220
221 Unverified,
226
227 Unauthorized,
233
234 Unavailable,
239
240 NoOlm,
244
245 #[doc(hidden)]
246 _Custom(PrivOwnedStr),
247}
248
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::{
        EventEncryptionAlgorithm, canonical_json::assert_to_canonical_json_eq, owned_room_id,
        serde::Base64,
    };
    use serde_json::{from_value as from_json_value, json};

    use super::{
        RoomKeyWithheldCodeInfo, RoomKeyWithheldSessionData, ToDeviceRoomKeyWithheldEventContent,
    };

    /// Raw bytes of the sender key used throughout these tests.
    const PUBLIC_KEY: &[u8] = b"key";
    /// `PUBLIC_KEY` as it appears in JSON (base64 of `b"key"`).
    const BASE64_ENCODED_PUBLIC_KEY: &str = "a2V5";

    /// `m.no_olm` serializes without session data and without `reason` when unset.
    #[test]
    fn serialization_no_olm() {
        let content = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::NoOlm,
            Base64::new(PUBLIC_KEY.to_owned()),
        );

        assert_to_canonical_json_eq!(
            content,
            json!({
                "algorithm": "m.megolm.v1.aes-sha2",
                "code": "m.no_olm",
                "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            })
        );
    }

    /// A set `reason` is serialized as a top-level field.
    #[test]
    fn serialization_with_reason() {
        let mut content = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::NoOlm,
            Base64::new(PUBLIC_KEY.to_owned()),
        );
        content.reason = Some("Could not find an olm session".to_owned());

        assert_to_canonical_json_eq!(
            content,
            json!({
                "algorithm": "m.megolm.v1.aes-sha2",
                "code": "m.no_olm",
                "sender_key": BASE64_ENCODED_PUBLIC_KEY,
                "reason": "Could not find an olm session",
            })
        );
    }

    /// Session data of a known code is flattened next to `code`.
    #[test]
    fn serialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let content = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::Blacklisted(
                RoomKeyWithheldSessionData::new(room_id.clone(), "unique_id".to_owned()).into(),
            ),
            Base64::new(PUBLIC_KEY.to_owned()),
        );

        assert_to_canonical_json_eq!(
            content,
            json!({
                "algorithm": "m.megolm.v1.aes-sha2",
                "code": "m.blacklisted",
                "sender_key": BASE64_ENCODED_PUBLIC_KEY,
                "room_id": room_id,
                "session_id": "unique_id",
            })
        );
    }

    #[test]
    fn deserialization_no_olm() {
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.no_olm",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "reason": "Could not find an olm session",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).unwrap();
        assert_eq!(content.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(content.sender_key, Base64::new(PUBLIC_KEY.to_owned()));
        assert_eq!(content.reason.as_deref(), Some("Could not find an olm session"));
        assert_matches!(content.code, RoomKeyWithheldCodeInfo::NoOlm);
    }

    #[test]
    fn deserialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.blacklisted",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).unwrap();
        assert_eq!(content.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(content.sender_key, Base64::new(PUBLIC_KEY.to_owned()));
        assert_eq!(content.reason, None);
        assert_matches!(content.code, RoomKeyWithheldCodeInfo::Blacklisted(session_data));
        assert_eq!(session_data.room_id, room_id);
        assert_eq!(session_data.session_id, "unique_id");
    }

    /// The wire code uses the British spelling `m.unauthorised`.
    #[test]
    fn deserialization_unauthorised() {
        let room_id = owned_room_id!("!roomid:localhost");
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.unauthorised",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
            "reason": "Not allowed to have this key",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).unwrap();
        assert_eq!(content.reason.as_deref(), Some("Not allowed to have this key"));
        assert_matches!(content.code, RoomKeyWithheldCodeInfo::Unauthorized(session_data));
        assert_eq!(session_data.room_id, room_id);
        assert_eq!(session_data.session_id, "unique_id");
    }

    /// Content without a `code` field is rejected.
    #[test]
    fn deserialization_missing_code_fails() {
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
        });

        from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).unwrap_err();
    }

    /// Unknown codes keep their extra fields and serialize back to the same JSON.
    #[test]
    fn custom_room_key_withheld_code_info_round_trip() {
        let room_id = owned_room_id!("!roomid:localhost");
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "dev.ruma.custom_code",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "key": "value",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json.clone()).unwrap();
        assert_eq!(content.code.code().as_str(), "dev.ruma.custom_code");

        assert_to_canonical_json_eq!(content, json);
    }
}