1use std::borrow::Cow;
6
7use ruma_common::{
8 serde::{from_raw_json_value, Base64, JsonObject},
9 EventEncryptionAlgorithm, OwnedRoomId,
10};
11use ruma_macros::{EventContent, StringEnum};
12use serde::{de, Deserialize, Serialize};
13use serde_json::value::RawValue as RawJsonValue;
14
15use crate::PrivOwnedStr;
16
/// The content of an `m.room_key.withheld` to-device event.
///
/// Serialized as a single flat JSON object: the fields of [`RoomKeyWithheldCodeInfo`] (its
/// `code` tag plus any data attached to that code) sit alongside `algorithm`, `reason` and
/// `sender_key`. `Deserialize` is implemented by hand rather than derived.
#[derive(Clone, Debug, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.room_key.withheld", kind = ToDevice)]
pub struct ToDeviceRoomKeyWithheldEventContent {
    /// The encryption algorithm of the withheld key.
    pub algorithm: EventEncryptionAlgorithm,

    /// The code explaining why the key was withheld, with any data specific to that code.
    ///
    /// Flattened, so its `code` field and payload appear at the top level of the content.
    #[serde(flatten)]
    pub code: RoomKeyWithheldCodeInfo,

    /// An optional free-form explanation; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,

    /// The sender's key, base64-encoded.
    pub sender_key: Base64,
}
44
45impl ToDeviceRoomKeyWithheldEventContent {
46 pub fn new(
49 algorithm: EventEncryptionAlgorithm,
50 code: RoomKeyWithheldCodeInfo,
51 sender_key: Base64,
52 ) -> Self {
53 Self { algorithm, code, reason: None, sender_key }
54 }
55}
56
57impl<'de> Deserialize<'de> for ToDeviceRoomKeyWithheldEventContent {
58 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
59 where
60 D: de::Deserializer<'de>,
61 {
62 #[derive(Deserialize)]
63 struct ToDeviceRoomKeyWithheldEventContentDeHelper {
64 algorithm: EventEncryptionAlgorithm,
65 reason: Option<String>,
66 sender_key: Base64,
67 }
68
69 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
70
71 let ToDeviceRoomKeyWithheldEventContentDeHelper { algorithm, reason, sender_key } =
72 from_raw_json_value(&json)?;
73 let code = from_raw_json_value(&json)?;
74
75 Ok(Self { algorithm, code, reason, sender_key })
76 }
77}
78
/// The code explaining why a room key was withheld, together with the data attached to it.
///
/// Serialized as an internally tagged object: the `code` field names the variant and the
/// variant's payload (if any) is written next to it. `Deserialize` is implemented by hand.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "code")]
pub enum RoomKeyWithheldCodeInfo {
    /// Code `m.blacklisted`; carries the room and session IDs of the withheld key.
    #[serde(rename = "m.blacklisted")]
    Blacklisted(Box<RoomKeyWithheldSessionData>),

    /// Code `m.unverified`; carries the room and session IDs of the withheld key.
    #[serde(rename = "m.unverified")]
    Unverified(Box<RoomKeyWithheldSessionData>),

    /// Code `m.unauthorised` (British spelling on the wire, despite the variant name);
    /// carries the room and session IDs of the withheld key.
    #[serde(rename = "m.unauthorised")]
    Unauthorized(Box<RoomKeyWithheldSessionData>),

    /// Code `m.unavailable`; carries the room and session IDs of the withheld key.
    #[serde(rename = "m.unavailable")]
    Unavailable(Box<RoomKeyWithheldSessionData>),

    /// Code `m.no_olm`; carries no additional data.
    #[serde(rename = "m.no_olm")]
    NoOlm,

    // Fallback for any other code; serialized untagged because the payload already
    // contains its own `code` field.
    #[doc(hidden)]
    #[serde(untagged)]
    _Custom(Box<CustomRoomKeyWithheldCodeInfo>),
}
122
123impl RoomKeyWithheldCodeInfo {
124 pub fn code(&self) -> RoomKeyWithheldCode {
126 match self {
127 Self::Blacklisted(_) => RoomKeyWithheldCode::Blacklisted,
128 Self::Unverified(_) => RoomKeyWithheldCode::Unverified,
129 Self::Unauthorized(_) => RoomKeyWithheldCode::Unauthorized,
130 Self::Unavailable(_) => RoomKeyWithheldCode::Unavailable,
131 Self::NoOlm => RoomKeyWithheldCode::NoOlm,
132 Self::_Custom(info) => info.code.as_str().into(),
133 }
134 }
135}
136
137impl<'de> Deserialize<'de> for RoomKeyWithheldCodeInfo {
138 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139 where
140 D: de::Deserializer<'de>,
141 {
142 #[derive(Debug, Deserialize)]
143 struct ExtractCode<'a> {
144 #[serde(borrow)]
145 code: Cow<'a, str>,
146 }
147
148 let json = Box::<RawJsonValue>::deserialize(deserializer)?;
149 let ExtractCode { code } = from_raw_json_value(&json)?;
150
151 Ok(match code.as_ref() {
152 "m.blacklisted" => Self::Blacklisted(from_raw_json_value(&json)?),
153 "m.unverified" => Self::Unverified(from_raw_json_value(&json)?),
154 "m.unauthorised" => Self::Unauthorized(from_raw_json_value(&json)?),
155 "m.unavailable" => Self::Unavailable(from_raw_json_value(&json)?),
156 "m.no_olm" => Self::NoOlm,
157 _ => Self::_Custom(from_raw_json_value(&json)?),
158 })
159 }
160}
161
/// The room and session that a withheld key belongs to.
///
/// Attached to the `m.blacklisted`, `m.unverified`, `m.unauthorised` and `m.unavailable`
/// variants of [`RoomKeyWithheldCodeInfo`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct RoomKeyWithheldSessionData {
    /// The ID of the room the session belongs to.
    pub room_id: OwnedRoomId,

    /// The ID of the session whose key was withheld.
    pub session_id: String,
}
172
173impl RoomKeyWithheldSessionData {
174 pub fn new(room_id: OwnedRoomId, session_id: String) -> Self {
176 Self { room_id, session_id }
177 }
178}
179
/// The payload of a withheld code that is not one of the known [`RoomKeyWithheldCodeInfo`]
/// variants.
#[doc(hidden)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CustomRoomKeyWithheldCodeInfo {
    /// The raw value of the `code` field.
    code: String,

    /// Every other key of the JSON object this was deserialized from.
    #[serde(flatten)]
    data: JsonObject,
}
191
192#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
194#[derive(Clone, PartialEq, Eq, StringEnum)]
195#[ruma_enum(rename_all = "m.snake_case")]
196#[non_exhaustive]
197pub enum RoomKeyWithheldCode {
198 Blacklisted,
202
203 Unverified,
208
209 Unauthorized,
215
216 Unavailable,
221
222 NoOlm,
226
227 #[doc(hidden)]
228 _Custom(PrivOwnedStr),
229}
230
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::{owned_room_id, serde::Base64, EventEncryptionAlgorithm};
    use serde_json::{from_value as from_json_value, json, to_value as to_json_value};

    use super::{
        RoomKeyWithheldCodeInfo, RoomKeyWithheldSessionData, ToDeviceRoomKeyWithheldEventContent,
    };

    const PUBLIC_KEY: &[u8] = b"key";
    const BASE64_ENCODED_PUBLIC_KEY: &str = "a2V5";

    #[test]
    fn serialization_no_olm() {
        let content = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::NoOlm,
            Base64::new(PUBLIC_KEY.to_owned()),
        );

        assert_eq!(
            to_json_value(content).unwrap(),
            json!({
                "algorithm": "m.megolm.v1.aes-sha2",
                "code": "m.no_olm",
                "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            })
        );
    }

    #[test]
    fn serialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let content = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::Blacklisted(
                RoomKeyWithheldSessionData::new(room_id.clone(), "unique_id".to_owned()).into(),
            ),
            Base64::new(PUBLIC_KEY.to_owned()),
        );

        assert_eq!(
            to_json_value(content).unwrap(),
            json!({
                "algorithm": "m.megolm.v1.aes-sha2",
                "code": "m.blacklisted",
                "sender_key": BASE64_ENCODED_PUBLIC_KEY,
                "room_id": room_id,
                "session_id": "unique_id",
            })
        );
    }

    #[test]
    fn deserialization_no_olm() {
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.no_olm",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "reason": "Could not find an olm session",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).unwrap();
        assert_eq!(content.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(content.sender_key, Base64::new(PUBLIC_KEY.to_owned()));
        assert_eq!(content.reason.as_deref(), Some("Could not find an olm session"));
        assert_matches!(content.code, RoomKeyWithheldCodeInfo::NoOlm);
    }

    #[test]
    fn deserialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.blacklisted",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).unwrap();
        assert_eq!(content.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(content.sender_key, Base64::new(PUBLIC_KEY.to_owned()));
        assert_eq!(content.reason, None);
        assert_matches!(content.code, RoomKeyWithheldCodeInfo::Blacklisted(session_data));
        assert_eq!(session_data.room_id, room_id);
        assert_eq!(session_data.session_id, "unique_id");
    }

    /// The wire code uses the British spelling `m.unauthorised`; check that it maps to the
    /// `Unauthorized` variant and serializes back to the same spelling.
    #[test]
    fn unauthorised_round_trip() {
        let room_id = owned_room_id!("!roomid:localhost");
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.unauthorised",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json.clone()).unwrap();
        assert_eq!(to_json_value(&content).unwrap(), json);

        assert_matches!(content.code, RoomKeyWithheldCodeInfo::Unauthorized(session_data));
        assert_eq!(session_data.room_id, room_id);
        assert_eq!(session_data.session_id, "unique_id");
    }

    /// Known codes other than `m.no_olm` require session data; its absence is an error.
    #[test]
    fn deserialization_known_code_without_session_data() {
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.unavailable",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
        });

        assert!(from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json).is_err());
    }

    #[test]
    fn custom_room_key_withheld_code_info_round_trip() {
        let room_id = owned_room_id!("!roomid:localhost");
        let json = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "dev.ruma.custom_code",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "key": "value",
        });

        let content = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(json.clone()).unwrap();
        assert_eq!(content.code.code().as_str(), "dev.ruma.custom_code");

        assert_eq!(to_json_value(&content).unwrap(), json);
    }
}