1use std::borrow::Cow;
6
7use ruma_common::{
8 EventEncryptionAlgorithm, OwnedRoomId,
9 serde::{Base64, JsonObject, from_raw_json_value},
10};
11use ruma_macros::{EventContent, StringEnum};
12use serde::{Deserialize, Serialize, de};
13use serde_json::value::RawValue as RawJsonValue;
14
15use crate::PrivOwnedStr;
16
/// The content of an `m.room_key.withheld` to-device event.
///
/// On the wire this is a single flat JSON object: the `code` field and any code-specific fields
/// of [`RoomKeyWithheldCodeInfo`] are flattened next to `algorithm`, `reason` and `sender_key`.
/// `Deserialize` is implemented by hand rather than derived.
#[derive(Clone, Debug, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.room_key.withheld", kind = ToDevice)]
pub struct ToDeviceRoomKeyWithheldEventContent {
    /// The encryption algorithm of the withheld key.
    pub algorithm: EventEncryptionAlgorithm,

    /// The code explaining why the key was withheld, plus any data that code carries.
    ///
    /// Flattened, so its `code` field and data fields appear at the top level of the JSON.
    #[serde(flatten)]
    pub code: RoomKeyWithheldCodeInfo,

    /// An optional free-form reason; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,

    /// The sender key, serialized as a Base64 string.
    pub sender_key: Base64,
}
44
45impl ToDeviceRoomKeyWithheldEventContent {
46 pub fn new(
49 algorithm: EventEncryptionAlgorithm,
50 code: RoomKeyWithheldCodeInfo,
51 sender_key: Base64,
52 ) -> Self {
53 Self { algorithm, code, reason: None, sender_key }
54 }
55}
56
57impl<'de> Deserialize<'de> for ToDeviceRoomKeyWithheldEventContent {
58 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
59 where
60 D: de::Deserializer<'de>,
61 {
62 #[derive(Deserialize)]
63 struct ToDeviceRoomKeyWithheldEventContentDeHelper {
64 algorithm: EventEncryptionAlgorithm,
65 reason: Option<String>,
66 sender_key: Base64,
67 }
68
69 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
70
71 let ToDeviceRoomKeyWithheldEventContentDeHelper { algorithm, reason, sender_key } =
72 from_raw_json_value(&json)?;
73 let code = from_raw_json_value(&json)?;
74
75 Ok(Self { algorithm, code, reason, sender_key })
76 }
77}
78
/// The code explaining why a room key was withheld, together with the data attached to it.
///
/// Serialized as an internally tagged enum: the variant's wire name goes into a `code` field and
/// the variant's data fields sit next to it in the same object. `Deserialize` is implemented by
/// hand, dispatching on the value of the `code` field.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "code")]
pub enum RoomKeyWithheldCodeInfo {
    /// `m.blacklisted`, with the room and session it applies to.
    #[serde(rename = "m.blacklisted")]
    Blacklisted(Box<RoomKeyWithheldSessionData>),

    /// `m.unverified`, with the room and session it applies to.
    #[serde(rename = "m.unverified")]
    Unverified(Box<RoomKeyWithheldSessionData>),

    /// `m.unauthorised` (note the British spelling on the wire), with the room and session it
    /// applies to.
    #[serde(rename = "m.unauthorised")]
    Unauthorized(Box<RoomKeyWithheldSessionData>),

    /// `m.unavailable`, with the room and session it applies to.
    #[serde(rename = "m.unavailable")]
    Unavailable(Box<RoomKeyWithheldSessionData>),

    /// `m.no_olm`; carries no session data.
    #[serde(rename = "m.no_olm")]
    NoOlm,

    // Any other code. Serialized untagged, so the custom info writes its own `code` field.
    #[doc(hidden)]
    #[serde(untagged)]
    _Custom(Box<CustomRoomKeyWithheldCodeInfo>),
}
122
123impl RoomKeyWithheldCodeInfo {
124 pub fn code(&self) -> RoomKeyWithheldCode {
126 match self {
127 Self::Blacklisted(_) => RoomKeyWithheldCode::Blacklisted,
128 Self::Unverified(_) => RoomKeyWithheldCode::Unverified,
129 Self::Unauthorized(_) => RoomKeyWithheldCode::Unauthorized,
130 Self::Unavailable(_) => RoomKeyWithheldCode::Unavailable,
131 Self::NoOlm => RoomKeyWithheldCode::NoOlm,
132 Self::_Custom(info) => info.code.as_str().into(),
133 }
134 }
135}
136
137impl<'de> Deserialize<'de> for RoomKeyWithheldCodeInfo {
138 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139 where
140 D: de::Deserializer<'de>,
141 {
142 #[derive(Debug, Deserialize)]
143 struct ExtractCode<'a> {
144 #[serde(borrow)]
145 code: Cow<'a, str>,
146 }
147
148 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
149 let ExtractCode { code } = from_raw_json_value(&json)?;
150
151 Ok(match code.as_ref() {
152 "m.blacklisted" => Self::Blacklisted(from_raw_json_value(&json)?),
153 "m.unverified" => Self::Unverified(from_raw_json_value(&json)?),
154 "m.unauthorised" => Self::Unauthorized(from_raw_json_value(&json)?),
155 "m.unavailable" => Self::Unavailable(from_raw_json_value(&json)?),
156 "m.no_olm" => Self::NoOlm,
157 _ => Self::_Custom(from_raw_json_value(&json)?),
158 })
159 }
160}
161
/// The room and session identifiers attached to most withheld codes.
///
/// Used by every [`RoomKeyWithheldCodeInfo`] variant except `NoOlm`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct RoomKeyWithheldSessionData {
    /// The ID of the room the withheld key's session belongs to.
    pub room_id: OwnedRoomId,

    /// The ID of the session whose key was withheld.
    pub session_id: String,
}
172
173impl RoomKeyWithheldSessionData {
174 pub fn new(room_id: OwnedRoomId, session_id: String) -> Self {
176 Self { room_id, session_id }
177 }
178}
179
/// The payload of a withheld code that this crate does not know about.
///
/// Kept so that unknown codes survive a deserialize/serialize round trip.
#[doc(hidden)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CustomRoomKeyWithheldCodeInfo {
    /// The raw value of the `code` field.
    code: String,

    /// The object's fields other than `code`, captured verbatim through `#[serde(flatten)]`.
    #[serde(flatten)]
    data: JsonObject,
}
191
192#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
194#[derive(Clone, StringEnum)]
195#[ruma_enum(rename_all(prefix = "m.", rule = "snake_case"))]
196#[non_exhaustive]
197pub enum RoomKeyWithheldCode {
198 Blacklisted,
202
203 Unverified,
208
209 Unauthorized,
215
216 Unavailable,
221
222 NoOlm,
226
227 #[doc(hidden)]
228 _Custom(PrivOwnedStr),
229}
230
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::{
        EventEncryptionAlgorithm, canonical_json::assert_to_canonical_json_eq, owned_room_id,
        serde::Base64,
    };
    use serde_json::{from_value as from_json_value, json, to_value as to_json_value};

    use super::{
        RoomKeyWithheldCodeInfo, RoomKeyWithheldSessionData, ToDeviceRoomKeyWithheldEventContent,
    };

    /// Raw sender key bytes shared by all tests.
    const PUBLIC_KEY: &[u8] = b"key";
    /// Base64 encoding of [`PUBLIC_KEY`], as it appears in the JSON.
    const BASE64_ENCODED_PUBLIC_KEY: &str = "a2V5";

    /// Wraps [`PUBLIC_KEY`] in a [`Base64`] value.
    fn test_sender_key() -> Base64 {
        Base64::new(PUBLIC_KEY.to_owned())
    }

    #[test]
    fn serialization_no_olm() {
        let event = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::NoOlm,
            test_sender_key(),
        );
        let expected = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.no_olm",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
        });

        assert_to_canonical_json_eq!(event, expected);
    }

    #[test]
    fn serialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let session = RoomKeyWithheldSessionData::new(room_id.clone(), "unique_id".to_owned());
        let event = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::Blacklisted(Box::new(session)),
            test_sender_key(),
        );
        let expected = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.blacklisted",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        assert_to_canonical_json_eq!(event, expected);
    }

    #[test]
    fn deserialization_no_olm() {
        let input = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.no_olm",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "reason": "Could not find an olm session",
        });

        let event: ToDeviceRoomKeyWithheldEventContent = from_json_value(input).unwrap();

        assert_eq!(event.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(event.sender_key, test_sender_key());
        assert_eq!(event.reason.as_deref(), Some("Could not find an olm session"));
        assert_matches!(event.code, RoomKeyWithheldCodeInfo::NoOlm);
    }

    #[test]
    fn deserialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let input = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.blacklisted",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        let event: ToDeviceRoomKeyWithheldEventContent = from_json_value(input).unwrap();

        assert_eq!(event.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(event.sender_key, test_sender_key());
        assert!(event.reason.is_none());
        assert_matches!(event.code, RoomKeyWithheldCodeInfo::Blacklisted(session));
        assert_eq!(session.room_id, room_id);
        assert_eq!(session.session_id, "unique_id");
    }

    #[test]
    fn custom_room_key_withheld_code_info_round_trip() {
        let room_id = owned_room_id!("!roomid:localhost");
        let original = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "dev.ruma.custom_code",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "key": "value",
        });

        let event: ToDeviceRoomKeyWithheldEventContent =
            from_json_value(original.clone()).unwrap();
        assert_eq!(event.code.code().as_str(), "dev.ruma.custom_code");

        let reserialized = to_json_value(&event).unwrap();
        assert_eq!(reserialized, original);
    }
}