1use std::{borrow::Cow, collections::BTreeMap};
6
7use ruma_common::{
8 serde::{from_raw_json_value, ignore_invalid_vec_items},
9 space::SpaceRoomJoinRule,
10 OwnedRoomId,
11};
12use ruma_macros::EventContent;
13use serde::{
14 de::{Deserializer, Error},
15 Deserialize, Serialize,
16};
17use serde_json::{value::RawValue as RawJsonValue, Value as JsonValue};
18
19use crate::{EmptyStateKey, PrivOwnedStr};
20
/// The content of an `m.room.join_rules` state event.
///
/// The event is declared with `kind = State` and a state key of type `EmptyStateKey`.
#[derive(Clone, Debug, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.room.join_rules", kind = State, state_key_type = EmptyStateKey)]
pub struct RoomJoinRulesEventContent {
    /// The join rule of the room.
    ///
    /// `skip_redaction` keeps this field in the redacted form of the event, and `flatten`
    /// places the join rule's own fields (`join_rule`, plus `allow` for the restricted
    /// variants) directly in the content object.
    ///
    /// Deserialization of this type is implemented manually rather than derived.
    /// NOTE(review): presumably because `flatten` cannot drive `JoinRule`'s raw-JSON-buffering
    /// `Deserialize` impl — confirm before changing.
    #[ruma_event(skip_redaction)]
    #[serde(flatten)]
    pub join_rule: JoinRule,
}
33
34impl RoomJoinRulesEventContent {
35 pub fn new(join_rule: JoinRule) -> Self {
37 Self { join_rule }
38 }
39
40 pub fn restricted(allow: Vec<AllowRule>) -> Self {
43 Self { join_rule: JoinRule::Restricted(Restricted::new(allow)) }
44 }
45
46 pub fn knock_restricted(allow: Vec<AllowRule>) -> Self {
49 Self { join_rule: JoinRule::KnockRestricted(Restricted::new(allow)) }
50 }
51}
52
53impl<'de> Deserialize<'de> for RoomJoinRulesEventContent {
54 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55 where
56 D: Deserializer<'de>,
57 {
58 let join_rule = JoinRule::deserialize(deserializer)?;
59 Ok(RoomJoinRulesEventContent { join_rule })
60 }
61}
62
63impl RoomJoinRulesEvent {
64 pub fn join_rule(&self) -> &JoinRule {
66 match self {
67 Self::Original(ev) => &ev.content.join_rule,
68 Self::Redacted(ev) => &ev.content.join_rule,
69 }
70 }
71}
72
73impl SyncRoomJoinRulesEvent {
74 pub fn join_rule(&self) -> &JoinRule {
76 match self {
77 Self::Original(ev) => &ev.content.join_rule,
78 Self::Redacted(ev) => &ev.content.join_rule,
79 }
80 }
81}
82
/// The rule for users wanting to join a room.
///
/// Serialized as an object whose `join_rule` key holds the snake_case variant name. The
/// restricted variants additionally carry the fields of [`Restricted`] (e.g. `allow`) in the
/// same object. Deserialization is implemented manually so that unknown rule strings map to a
/// hidden custom variant instead of failing.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "join_rule", rename_all = "snake_case")]
pub enum JoinRule {
    /// The `invite` join rule.
    Invite,

    /// The `knock` join rule.
    Knock,

    /// The `private` join rule.
    Private,

    /// The `restricted` join rule, together with its allow list.
    Restricted(Restricted),

    /// The `knock_restricted` join rule, together with its allow list.
    KnockRestricted(Restricted),

    /// The `public` join rule.
    Public,

    /// A join rule unknown to this crate, holding the raw `join_rule` string.
    ///
    /// Marked `skip_serializing`: serde treats an attempt to serialize this variant as an
    /// error.
    #[doc(hidden)]
    #[serde(skip_serializing)]
    _Custom(PrivOwnedStr),
}
118
119impl JoinRule {
120 pub fn as_str(&self) -> &str {
122 match self {
123 JoinRule::Invite => "invite",
124 JoinRule::Knock => "knock",
125 JoinRule::Private => "private",
126 JoinRule::Restricted(_) => "restricted",
127 JoinRule::KnockRestricted(_) => "knock_restricted",
128 JoinRule::Public => "public",
129 JoinRule::_Custom(rule) => &rule.0,
130 }
131 }
132}
133
134impl<'de> Deserialize<'de> for JoinRule {
135 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136 where
137 D: Deserializer<'de>,
138 {
139 let json: Box<RawJsonValue> = Box::deserialize(deserializer)?;
140
141 #[derive(Deserialize)]
142 struct ExtractType<'a> {
143 #[serde(borrow)]
144 join_rule: Option<Cow<'a, str>>,
145 }
146
147 let join_rule = serde_json::from_str::<ExtractType<'_>>(json.get())
148 .map_err(Error::custom)?
149 .join_rule
150 .ok_or_else(|| D::Error::missing_field("join_rule"))?;
151
152 match join_rule.as_ref() {
153 "invite" => Ok(Self::Invite),
154 "knock" => Ok(Self::Knock),
155 "private" => Ok(Self::Private),
156 "restricted" => from_raw_json_value(&json).map(Self::Restricted),
157 "knock_restricted" => from_raw_json_value(&json).map(Self::KnockRestricted),
158 "public" => Ok(Self::Public),
159 _ => Ok(Self::_Custom(PrivOwnedStr(join_rule.into()))),
160 }
161 }
162}
163
164impl From<JoinRule> for SpaceRoomJoinRule {
165 fn from(value: JoinRule) -> Self {
166 value.as_str().into()
167 }
168}
169
/// Configuration shared by the `restricted` and `knock_restricted` join rules.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct Restricted {
    /// The allow rules of this join rule.
    ///
    /// Defaults to an empty list when the `allow` field is absent. Entries that fail to
    /// deserialize are skipped instead of failing the whole list (`ignore_invalid_vec_items`).
    #[serde(default, deserialize_with = "ignore_invalid_vec_items")]
    pub allow: Vec<AllowRule>,
}
178
179impl Restricted {
180 pub fn new(allow: Vec<AllowRule>) -> Self {
182 Self { allow }
183 }
184}
185
/// A single entry in the `allow` list of a restricted join rule.
///
/// Serialized untagged, i.e. exactly as the inner value's own representation. Deserialization
/// is implemented manually and dispatches on the object's `type` field.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(untagged)]
pub enum AllowRule {
    /// An allow rule of type `m.room_membership`.
    RoomMembership(RoomMembership),

    /// An allow rule with any other `type`, kept together with all of its fields.
    #[doc(hidden)]
    _Custom(Box<CustomAllowRule>),
}
197
198impl AllowRule {
199 pub fn room_membership(room_id: OwnedRoomId) -> Self {
201 Self::RoomMembership(RoomMembership::new(room_id))
202 }
203}
204
/// An allow rule of type `m.room_membership`.
///
/// `#[serde(tag = "type", rename = "m.room_membership")]` makes serialization emit
/// `"type": "m.room_membership"` alongside `room_id`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "type", rename = "m.room_membership")]
pub struct RoomMembership {
    /// The ID of the room this rule refers to.
    pub room_id: OwnedRoomId,
}
213
214impl RoomMembership {
215 pub fn new(room_id: OwnedRoomId) -> Self {
217 Self { room_id }
218 }
219}
220
/// An allow rule whose `type` is not known to this crate.
#[doc(hidden)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct CustomAllowRule {
    /// The value of the rule's `type` field.
    #[serde(rename = "type")]
    rule_type: String,
    /// Every other field of the rule, flattened so the rule serializes back unchanged.
    #[serde(flatten)]
    extra: BTreeMap<String, JsonValue>,
}
230
231impl<'de> Deserialize<'de> for AllowRule {
232 fn deserialize<D>(deserializer: D) -> Result<AllowRule, D::Error>
233 where
234 D: Deserializer<'de>,
235 {
236 let json: Box<RawJsonValue> = Box::deserialize(deserializer)?;
237
238 #[derive(Deserialize)]
240 struct ExtractType<'a> {
241 #[serde(borrow, rename = "type")]
242 rule_type: Option<Cow<'a, str>>,
243 }
244
245 let rule_type =
247 serde_json::from_str::<ExtractType<'_>>(json.get()).map_err(Error::custom)?.rule_type;
248
249 match rule_type.as_deref() {
250 Some("m.room_membership") => from_raw_json_value(&json).map(Self::RoomMembership),
251 Some(_) => from_raw_json_value(&json).map(Self::_Custom),
252 None => Err(D::Error::missing_field("type")),
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use assert_matches2::assert_matches;
    use ruma_common::owned_room_id;

    use super::{
        AllowRule, CustomAllowRule, JoinRule, OriginalSyncRoomJoinRulesEvent, Restricted,
        RoomJoinRulesEventContent, SpaceRoomJoinRule,
    };

    // A simple, non-restricted join rule deserializes to the matching unit variant.
    #[test]
    fn deserialize() {
        let content: RoomJoinRulesEventContent =
            serde_json::from_str(r#"{"join_rule": "public"}"#).unwrap();
        assert_matches!(content, RoomJoinRulesEventContent { join_rule: JoinRule::Public });
    }

    // A restricted join rule keeps its allow list in order.
    #[test]
    fn deserialize_restricted() {
        let json = r#"{
            "join_rule": "restricted",
            "allow": [
                {
                    "type": "m.room_membership",
                    "room_id": "!mods:example.org"
                },
                {
                    "type": "m.room_membership",
                    "room_id": "!users:example.org"
                }
            ]
        }"#;
        let expected = [
            AllowRule::room_membership(owned_room_id!("!mods:example.org")),
            AllowRule::room_membership(owned_room_id!("!users:example.org")),
        ];

        let content: RoomJoinRulesEventContent = serde_json::from_str(json).unwrap();
        match content.join_rule {
            JoinRule::Restricted(restricted) => assert_eq!(restricted.allow, &expected),
            rule => panic!("Deserialized to wrong variant: {rule:?}"),
        }
    }

    // A full sync event with restricted content parses successfully.
    #[test]
    fn deserialize_restricted_event() {
        let json = r#"{
            "type": "m.room.join_rules",
            "sender": "@admin:community.rs",
            "content": {
                "join_rule": "restricted",
                "allow": [
                    { "type": "m.room_membership","room_id": "!KqeUnzmXPIhHRaWMTs:mccarty.io" }
                ]
            },
            "state_key": "",
            "origin_server_ts":1630508835342,
            "unsigned": {
                "age":4165521871
            },
            "event_id": "$0ACb9KSPlT3al3kikyRYvFhMqXPP9ZcQOBrsdIuh58U"
        }"#;

        let parsed = serde_json::from_str::<OriginalSyncRoomJoinRulesEvent>(json);
        assert_matches!(parsed, Ok(_));
    }

    // An unknown allow rule type survives a deserialize/serialize round trip unchanged.
    #[test]
    fn roundtrip_custom_allow_rule() {
        let json = r#"{"type":"org.msc9000.something","foo":"bar"}"#;
        let rule: AllowRule = serde_json::from_str(json).unwrap();
        assert_matches!(&rule, AllowRule::_Custom(_));

        let reserialized = serde_json::to_string(&rule).unwrap();
        assert_eq!(reserialized, json);
    }

    // The `allow` field is optional for restricted rooms.
    #[test]
    fn restricted_room_no_allow_field() {
        let json = r#"{"join_rule":"restricted"}"#;
        let content: RoomJoinRulesEventContent = serde_json::from_str(json).unwrap();
        assert_matches!(content, RoomJoinRulesEventContent { join_rule: JoinRule::Restricted(_) });
    }

    // Invalid allow entries are dropped; valid and custom entries are kept.
    #[test]
    fn invalid_allow_items() {
        let json = r#"{
            "join_rule": "restricted",
            "allow": [
                {
                    "type": "m.room_membership",
                    "room_id": "!mods:example.org"
                },
                {
                    "type": "m.room_membership",
                    "room_id": ""
                },
                {
                    "type": "m.room_membership",
                    "room_id": "not a room id"
                },
                {
                    "type": "org.example.custom",
                    "org.example.minimum_role": "developer"
                },
                {
                    "not even close": "to being correct",
                    "any object": "passes this test",
                    "only non-objects in this array": "cause deserialization to fail"
                }
            ]
        }"#;
        let custom = CustomAllowRule {
            rule_type: "org.example.custom".into(),
            extra: BTreeMap::from([("org.example.minimum_role".into(), "developer".into())]),
        };
        let expected = [
            AllowRule::room_membership(owned_room_id!("!mods:example.org")),
            AllowRule::_Custom(Box::new(custom)),
        ];

        let content: RoomJoinRulesEventContent = serde_json::from_str(json).unwrap();
        assert_matches!(content.join_rule, JoinRule::Restricted(restricted));
        assert_eq!(restricted.allow, &expected);
    }

    // Every known join rule converts to the space-room join rule of the same name.
    #[test]
    fn join_rule_to_space_room_join_rule() {
        let cases = [
            (JoinRule::Invite, SpaceRoomJoinRule::Invite),
            (JoinRule::Knock, SpaceRoomJoinRule::Knock),
            (JoinRule::KnockRestricted(Restricted::default()), SpaceRoomJoinRule::KnockRestricted),
            (JoinRule::Public, SpaceRoomJoinRule::Public),
            (JoinRule::Private, SpaceRoomJoinRule::Private),
            (JoinRule::Restricted(Restricted::default()), SpaceRoomJoinRule::Restricted),
        ];

        for (join_rule, expected) in cases {
            assert_eq!(expected, SpaceRoomJoinRule::from(join_rule));
        }
    }
}