1use std::{
2 borrow::Borrow,
3 cmp::{Ordering, Reverse},
4 collections::{BinaryHeap, HashMap, HashSet},
5 hash::Hash,
6 sync::OnceLock,
7};
8
9use js_int::Int;
10use ruma_common::{
11 room_version_rules::AuthorizationRules, EventId, MilliSecondsSinceUnixEpoch, OwnedUserId,
12};
13use ruma_events::{room::member::MembershipState, StateEventType, TimelineEventType};
14use tracing::{debug, info, instrument, trace, warn};
15
16mod error;
17pub mod event_auth;
18pub mod events;
19#[cfg(test)]
20mod test_utils;
21
22use self::events::{
23 member::RoomMemberEvent, power_levels::RoomPowerLevelsEventOptionExt, RoomCreateEvent,
24 RoomPowerLevelsEvent,
25};
26pub use self::{
27 error::{Error, Result},
28 event_auth::{auth_check, auth_types_for_event},
29 events::Event,
30};
31
/// A map keyed by `(event type, state key)`.
///
/// In this module the value is usually an event id (`StateMap<E::Id>`), or a list of candidate
/// event ids for conflicted state (`StateMap<Vec<Id>>`).
pub type StateMap<T> = HashMap<(StateEventType, String), T>;
34
/// Resolve the given state sets into a single state map.
///
/// The flow appears to follow the Matrix "state resolution v2" algorithm:
///
/// 1. Split `state_sets` into unconflicted state and conflicted state. A key is unconflicted when
///    every set maps it to the same event. If nothing is conflicted, the unconflicted state is
///    returned as-is.
/// 2. Build the full conflicted set: the auth chain difference of `auth_chain_sets` plus every
///    conflicted event id. Only events that `fetch_event` can return are kept.
/// 3. Select the power events from that set (see `is_power_event`) and sort them with
///    `reverse_topological_power_sort`. Then auth-check them, in that order, on top of the
///    unconflicted state.
/// 4. Mainline-sort every event that did not come out of step 3 against the resolved
///    `m.room.power_levels` event. Then auth-check them on top of the result of step 3.
/// 5. Overlay the unconflicted state on the result.
///
/// ## Arguments
///
/// * `rules` - The authorization rules of the room version.
/// * `state_sets` - The state maps to resolve.
/// * `auth_chain_sets` - One set of auth event ids per state set. Presumably it is in the same
///   order as `state_sets`; TODO confirm against callers.
/// * `fetch_event` - Looks up an event by id. Conflicted events it cannot return are ignored.
///
/// ## Errors
///
/// Propagates errors from the sorting and auth-check phases, for example when an event that must
/// be inspected cannot be fetched.
#[instrument(skip(rules, state_sets, auth_chain_sets, fetch_event))]
pub fn resolve<'a, E, SetIter>(
    rules: &AuthorizationRules,
    state_sets: impl IntoIterator<IntoIter = SetIter>,
    auth_chain_sets: Vec<HashSet<E::Id>>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<E::Id>>
where
    E: Event + Clone,
    E::Id: 'a,
    SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
{
    info!("state resolution starting");

    // Split non-conflicting and conflicting state.
    let (clean, conflicting) = separate(state_sets.into_iter());

    info!(count = clean.len(), "non-conflicting events");
    trace!(map = ?clean, "non-conflicting events");

    if conflicting.is_empty() {
        info!("no conflicting state found");
        return Ok(clean);
    }

    info!(count = conflicting.len(), "conflicting events");
    trace!(map = ?conflicting, "conflicting events");

    // Full conflicted set: auth chain difference plus all conflicted ids, de-duplicated by the
    // `HashSet`.
    let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets)
        .chain(conflicting.into_values().flatten())
        // Events we cannot fetch cannot be checked, so they are not considered at all.
        .filter(|id| fetch_event(id.borrow()).is_some())
        .collect();

    info!(count = all_conflicted.len(), "full conflicted set");
    trace!(set = ?all_conflicted, "full conflicted set");

    // Control (power) events: see `is_power_event`.
    let control_events = all_conflicted
        .iter()
        .filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
        .cloned()
        .collect::<Vec<_>>();

    // The result also includes the power events' auth events that lie inside `all_conflicted`.
    let sorted_control_levels =
        reverse_topological_power_sort(control_events, &all_conflicted, rules, &fetch_event)?;

    debug!(count = sorted_control_levels.len(), "power events");
    trace!(list = ?sorted_control_levels, "sorted power events");

    // Auth-check the sorted control events on top of the unconflicted state.
    let resolved_control =
        iterative_auth_check(rules, &sorted_control_levels, clean.clone(), &fetch_event)?;

    debug!(count = resolved_control.len(), "resolved power events");
    trace!(map = ?resolved_control, "resolved power events");

    // Everything that went through the control pass is excluded from the second pass, whether or
    // not it passed the auth check.
    let deduped_power_ev = sorted_control_levels.into_iter().collect::<HashSet<_>>();

    let events_to_resolve = all_conflicted
        .iter()
        .filter(|&id| !deduped_power_ev.contains(id.borrow()))
        .cloned()
        .collect::<Vec<_>>();

    debug!(count = events_to_resolve.len(), "events left to resolve");
    trace!(list = ?events_to_resolve, "events left to resolve");

    // The power levels event that won the control pass anchors the mainline.
    let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, "".into()));

    debug!(event_id = ?power_event, "power event");

    let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?;

    trace!(list = ?sorted_left_events, "events left, sorted");

    let mut resolved_state = iterative_auth_check(
        rules,
        &sorted_left_events,
        // The resolved control state is the starting point of the second pass.
        resolved_control,
        &fetch_event,
    )?;

    // Unconflicted state takes precedence over anything resolved above.
    resolved_state.extend(clean);

    info!("state resolution finished");

    Ok(resolved_state)
}
159
160fn separate<'a, Id>(
168 state_sets_iter: impl Iterator<Item = &'a StateMap<Id>>,
169) -> (StateMap<Id>, StateMap<Vec<Id>>)
170where
171 Id: Clone + Eq + Hash + 'a,
172{
173 let mut state_set_count = 0_usize;
174 let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
175
176 let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1);
177 for (k, v) in state_sets_iter.flatten() {
178 occurrences.entry(k).or_default().entry(v).and_modify(|x| *x += 1).or_insert(1);
179 }
180
181 let mut unconflicted_state = StateMap::new();
182 let mut conflicted_state = StateMap::new();
183
184 for (k, v) in occurrences {
185 for (id, occurrence_count) in v {
186 if occurrence_count == state_set_count {
187 unconflicted_state.insert((k.0.clone(), k.1.clone()), id.clone());
188 } else {
189 conflicted_state
190 .entry((k.0.clone(), k.1.clone()))
191 .and_modify(|x: &mut Vec<_>| x.push(id.clone()))
192 .or_insert(vec![id.clone()]);
193 }
194 }
195 }
196
197 (unconflicted_state, conflicted_state)
198}
199
/// Return the ids that appear in some, but not all, of `auth_chain_sets`.
///
/// This is the "auth difference": the union of the sets minus their intersection. With zero or
/// one set the difference is empty. The iteration order of the result is unspecified.
fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
where
    Id: Eq + Hash,
{
    let total = auth_chain_sets.len();

    // Count how many sets each id belongs to; sets cannot contain duplicates themselves.
    let membership = auth_chain_sets.into_iter().flatten().fold(
        HashMap::<Id, usize>::new(),
        |mut counts, id| {
            *counts.entry(id).or_insert(0) += 1;
            counts
        },
    );

    membership.into_iter().filter(move |&(_, seen)| seen < total).map(|(id, _)| id)
}
214
215#[instrument(skip_all)]
223fn reverse_topological_power_sort<E: Event>(
224 events_to_sort: Vec<E::Id>,
225 auth_diff: &HashSet<E::Id>,
226 rules: &AuthorizationRules,
227 fetch_event: impl Fn(&EventId) -> Option<E>,
228) -> Result<Vec<E::Id>> {
229 debug!("reverse topological sort of power events");
230
231 let mut graph = HashMap::new();
232 for event_id in events_to_sort {
233 add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event);
234
235 }
239
240 let mut event_to_pl = HashMap::new();
242 let creator_lock = OnceLock::new();
245
246 for event_id in graph.keys() {
247 let pl = get_power_level_for_sender(event_id.borrow(), rules, &creator_lock, &fetch_event)
248 .map_err(Error::AuthEvent)?;
249 debug!(
250 event_id = event_id.borrow().as_str(),
251 power_level = i64::from(pl),
252 "found the power level of an event's sender",
253 );
254
255 event_to_pl.insert(event_id.clone(), pl);
256
257 }
261
262 lexicographical_topological_sort(&graph, |event_id| {
263 let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
264 let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
265 Ok((pl, ev.origin_server_ts()))
266 })
267}
268
269#[instrument(skip_all)]
274pub fn lexicographical_topological_sort<Id, F>(
275 graph: &HashMap<Id, HashSet<Id>>,
276 key_fn: F,
277) -> Result<Vec<Id>>
278where
279 F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>,
280 Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
281{
282 #[derive(PartialEq, Eq)]
283 struct TieBreaker<'a, Id> {
284 power_level: Int,
285 origin_server_ts: MilliSecondsSinceUnixEpoch,
286 event_id: &'a Id,
287 }
288
289 impl<Id> Ord for TieBreaker<'_, Id>
290 where
291 Id: Ord,
292 {
293 fn cmp(&self, other: &Self) -> Ordering {
294 other
302 .power_level
303 .cmp(&self.power_level)
304 .then(self.origin_server_ts.cmp(&other.origin_server_ts))
305 .then(self.event_id.cmp(other.event_id))
306 }
307 }
308
309 impl<Id> PartialOrd for TieBreaker<'_, Id>
310 where
311 Id: Ord,
312 {
313 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
314 Some(self.cmp(other))
315 }
316 }
317
318 let mut outdegree_map = graph.clone();
328
329 let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
332
333 let mut zero_outdegree = Vec::new();
335
336 for (node, edges) in graph {
337 if edges.is_empty() {
338 let (power_level, origin_server_ts) = key_fn(node.borrow())?;
339 zero_outdegree.push(Reverse(TieBreaker {
342 power_level,
343 origin_server_ts,
344 event_id: node,
345 }));
346 }
347
348 reverse_graph.entry(node).or_default();
349 for edge in edges {
350 reverse_graph.entry(edge).or_default().insert(node);
351 }
352 }
353
354 let mut heap = BinaryHeap::from(zero_outdegree);
355
356 let mut sorted = vec![];
358 while let Some(Reverse(item)) = heap.pop() {
360 let node = item.event_id;
361
362 for &parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") {
363 let out = outdegree_map
365 .get_mut(parent.borrow())
366 .expect("outdegree_map knows of all referenced EventIds");
367
368 out.remove(node.borrow());
370 if out.is_empty() {
371 let (power_level, origin_server_ts) = key_fn(parent.borrow())?;
372 heap.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: parent }));
373 }
374 }
375
376 sorted.push(node.clone());
378 }
379
380 Ok(sorted)
381}
382
/// Get the power level of the sender of the event with id `event_id`.
///
/// The `m.room.power_levels` and `m.room.create` events (empty state key) are looked up among
/// the event's own auth events.
///
/// `creator_lock` caches the room creator across calls:
/// - once it holds a value, the create event is no longer searched for;
/// - otherwise, the creator read from the create event is stored in it.
///
/// If the event itself or the room creator cannot be found, the power levels `users_default`
/// value is returned instead (via `get_as_int_or_default`). Presumably this falls back to the
/// spec default when there is no power levels event; confirm in the extension trait.
///
/// # Errors
///
/// Returns the error string produced while reading the creator from the create event, or the
/// power level from the power levels event. Presumably these arise for malformed content.
fn get_power_level_for_sender<E: Event>(
    event_id: &EventId,
    rules: &AuthorizationRules,
    creator_lock: &OnceLock<OwnedUserId>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> std::result::Result<Int, String> {
    let event = fetch_event(event_id);
    let mut room_create_event = None;
    let mut room_power_levels_event = None;

    // If the event itself is missing, this loop has nothing to iterate.
    for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
        if let Some(aev) = fetch_event(aid.borrow()) {
            if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
                room_power_levels_event = Some(RoomPowerLevelsEvent::new(aev));
            } else if creator_lock.get().is_none()
                && is_type_and_key(&aev, &TimelineEventType::RoomCreate, "")
            {
                // The create event is only needed while the creator is not cached yet.
                room_create_event = Some(RoomCreateEvent::new(aev));
            }

            // Stop once everything needed has been found.
            if room_power_levels_event.is_some()
                && (creator_lock.get().is_some() || room_create_event.is_some())
            {
                break;
            }
        }
    }

    // Prefer the cached creator; otherwise read it from the create event and cache it.
    let creator = if let Some(creator) = creator_lock.get() {
        Some(creator)
    } else if let Some(room_create_event) = room_create_event {
        let creator = room_create_event.creator(rules)?;
        Some(creator_lock.get_or_init(|| creator.into_owned()))
    } else {
        None
    };

    if let Some((event, creator)) = event.zip(creator) {
        room_power_levels_event.user_power_level(event.sender(), creator, rules)
    } else {
        room_power_levels_event
            .get_as_int_or_default(events::RoomPowerLevelsIntField::UsersDefault, rules)
    }
}
433
/// Auth-check each event in `events_to_check`, in order, against a progressively built state.
///
/// Starting from `unconflicted_state`, each event goes through these steps:
/// 1. Its auth events are assembled from its own `auth_events`, keyed by
///    `(event type, state key)`.
/// 2. For every type returned by `auth_types_for_event`, the entry from the current resolved
///    state overrides the event's own auth event.
/// 3. The event is checked with `auth_check`. If it passes, it is inserted into the resolved
///    state, replacing any previous event for the same `(event type, state key)`.
///
/// Events that fail the check, or whose required auth types cannot be computed, are logged and
/// skipped. Missing auth events are logged and ignored.
///
/// # Errors
///
/// - `Error::NotFound` if an event in `events_to_check` cannot be fetched.
/// - `Error::MissingStateKey` if that event, or one of its fetched auth events, has no state key.
fn iterative_auth_check<E: Event + Clone>(
    rules: &AuthorizationRules,
    events_to_check: &[E::Id],
    unconflicted_state: StateMap<E::Id>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<E::Id>> {
    debug!("starting iterative auth check");

    trace!(list = ?events_to_check, "events to check");

    let mut resolved_state = unconflicted_state;

    for event_id in events_to_check {
        let event = fetch_event(event_id.borrow())
            .ok_or_else(|| Error::NotFound(event_id.borrow().to_owned()))?;
        let state_key = event.state_key().ok_or(Error::MissingStateKey)?;

        // Start from the event's own auth events.
        let mut auth_events = StateMap::new();
        for aid in event.auth_events() {
            if let Some(ev) = fetch_event(aid.borrow()) {
                auth_events.insert(
                    ev.event_type().with_state_key(ev.state_key().ok_or(Error::MissingStateKey)?),
                    ev,
                );
            } else {
                warn!(event_id = aid.borrow().as_str(), "missing auth event");
            }
        }

        let auth_types = match auth_types_for_event(
            event.event_type(),
            event.sender(),
            Some(state_key),
            event.content(),
            rules,
        ) {
            Ok(auth_types) => auth_types,
            Err(error) => {
                warn!("failed to get list of required auth events for malformed event: {error}");
                continue;
            }
        };

        // For the auth types this event needs, the currently resolved state wins over the
        // event's own auth events.
        for key in auth_types {
            if let Some(ev_id) = resolved_state.get(&key) {
                if let Some(event) = fetch_event(ev_id.borrow()) {
                    auth_events.insert(key.to_owned(), event);
                }
            }
        }

        match auth_check(rules, &event, |ty, key| auth_events.get(&ty.with_state_key(key))) {
            Ok(()) => {
                // Accepted: the event becomes the current state for its (type, state key).
                resolved_state
                    .insert(event.event_type().with_state_key(state_key), event_id.clone());
            }
            Err(error) => {
                // Rejected events are not added to the resolved state.
                warn!("event failed the authentication check: {error}");
            }
        }
    }
    Ok(resolved_state)
}
515
516fn mainline_sort<E: Event>(
524 to_sort: &[E::Id],
525 resolved_power_level: Option<E::Id>,
526 fetch_event: impl Fn(&EventId) -> Option<E>,
527) -> Result<Vec<E::Id>> {
528 debug!("mainline sort of events");
529
530 if to_sort.is_empty() {
532 return Ok(vec![]);
533 }
534
535 let mut mainline = vec![];
536 let mut pl = resolved_power_level;
537 while let Some(p) = pl {
538 mainline.push(p.clone());
539
540 let event =
541 fetch_event(p.borrow()).ok_or_else(|| Error::NotFound(p.borrow().to_owned()))?;
542 pl = None;
543 for aid in event.auth_events() {
544 let ev =
545 fetch_event(aid.borrow()).ok_or_else(|| Error::NotFound(p.borrow().to_owned()))?;
546 if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
547 pl = Some(aid.to_owned());
548 break;
549 }
550 }
551 }
555
556 let mainline_map = mainline
557 .iter()
558 .rev()
559 .enumerate()
560 .map(|(idx, eid)| ((*eid).clone(), idx))
561 .collect::<HashMap<_, _>>();
562
563 let mut order_map = HashMap::new();
564 for ev_id in to_sort.iter() {
565 if let Some(event) = fetch_event(ev_id.borrow()) {
566 if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) {
567 order_map.insert(
568 ev_id,
569 (depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id),
570 );
571 }
572 }
573
574 }
578
579 let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::<Vec<_>>();
582 sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap());
583
584 Ok(sort_event_ids)
585}
586
587fn get_mainline_depth<E: Event>(
590 mut event: Option<E>,
591 mainline_map: &HashMap<E::Id, usize>,
592 fetch_event: impl Fn(&EventId) -> Option<E>,
593) -> Result<usize> {
594 while let Some(sort_ev) = event {
595 debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
596 let id = sort_ev.event_id();
597 if let Some(depth) = mainline_map.get(id.borrow()) {
598 return Ok(*depth);
599 }
600
601 event = None;
602 for aid in sort_ev.auth_events() {
603 let aev = fetch_event(aid.borrow())
604 .ok_or_else(|| Error::NotFound(aid.borrow().to_owned()))?;
605 if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
606 event = Some(aev);
607 break;
608 }
609 }
610 }
611 Ok(0)
613}
614
615fn add_event_and_auth_chain_to_graph<E: Event>(
616 graph: &mut HashMap<E::Id, HashSet<E::Id>>,
617 event_id: E::Id,
618 auth_diff: &HashSet<E::Id>,
619 fetch_event: impl Fn(&EventId) -> Option<E>,
620) {
621 let mut state = vec![event_id];
622 while let Some(eid) = state.pop() {
623 graph.entry(eid.clone()).or_default();
624 for aid in
626 fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
627 {
628 if auth_diff.contains(aid.borrow()) {
629 if !graph.contains_key(aid.borrow()) {
630 state.push(aid.to_owned());
631 }
632
633 graph.get_mut(eid.borrow()).unwrap().insert(aid.to_owned());
635 }
636 }
637 }
638}
639
640fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
641 match fetch(event_id).as_ref() {
642 Some(state) => is_power_event(state),
643 _ => false,
644 }
645}
646
647fn is_type_and_key(ev: impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool {
648 ev.event_type() == ev_type && ev.state_key() == Some(state_key)
649}
650
651fn is_power_event(event: impl Event) -> bool {
652 match event.event_type() {
653 TimelineEventType::RoomPowerLevels
654 | TimelineEventType::RoomJoinRules
655 | TimelineEventType::RoomCreate => event.state_key() == Some(""),
656 TimelineEventType::RoomMember => {
657 let room_member_event = RoomMemberEvent::new(event);
658 if room_member_event.membership().is_ok_and(|membership| {
659 matches!(membership, MembershipState::Leave | MembershipState::Ban)
660 }) {
661 return Some(room_member_event.sender().as_str()) != room_member_event.state_key();
662 }
663
664 false
665 }
666 _ => false,
667 }
668}
669
670trait EventTypeExt {
672 fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
673}
674
675impl EventTypeExt for StateEventType {
676 fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
677 (self, state_key.into())
678 }
679}
680
681impl EventTypeExt for TimelineEventType {
682 fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
683 (self.to_string().into(), state_key.into())
684 }
685}
686
687impl<T> EventTypeExt for &T
688where
689 T: EventTypeExt + Clone,
690{
691 fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
692 self.to_owned().with_state_key(state_key)
693 }
694}
695
696#[cfg(test)]
697mod tests {
698 use std::{
699 collections::{HashMap, HashSet},
700 sync::Arc,
701 };
702
703 use js_int::{int, uint};
704 use maplit::{hashmap, hashset};
705 use rand::seq::SliceRandom;
706 use ruma_common::{
707 room_version_rules::AuthorizationRules, MilliSecondsSinceUnixEpoch, OwnedEventId,
708 };
709 use ruma_events::{
710 room::join_rules::{JoinRule, RoomJoinRulesEventContent},
711 StateEventType, TimelineEventType,
712 };
713 use serde_json::{json, value::to_raw_value as to_raw_json_value};
714 use tracing::debug;
715
716 use crate::{
717 is_power_event,
718 test_utils::{
719 alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join,
720 room_id, to_init_pdu_event, to_pdu_event, zara, PduEvent, TestStore, INITIAL_EVENTS,
721 },
722 Event, EventTypeExt, StateMap,
723 };
724
725 fn test_event_sort() {
726 let _ =
727 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
728 let events = INITIAL_EVENTS();
729
730 let event_map = events
731 .values()
732 .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
733 .collect::<StateMap<_>>();
734
735 let auth_chain: HashSet<OwnedEventId> = HashSet::new();
736
737 let power_events = event_map
738 .values()
739 .filter(|&pdu| is_power_event(&**pdu))
740 .map(|pdu| pdu.event_id.clone())
741 .collect::<Vec<_>>();
742
743 let sorted_power_events = crate::reverse_topological_power_sort(
744 power_events,
745 &auth_chain,
746 &AuthorizationRules::V6,
747 |id| events.get(id).cloned(),
748 )
749 .unwrap();
750
751 let resolved_power = crate::iterative_auth_check(
752 &AuthorizationRules::V6,
753 &sorted_power_events,
754 HashMap::new(), |id| events.get(id).cloned(),
756 )
757 .expect("iterative auth check failed on resolved events");
758
759 let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
761
762 events_to_sort.shuffle(&mut rand::thread_rng());
763
764 let power_level =
765 resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
766
767 let sorted_event_ids =
768 crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned())
769 .unwrap();
770
771 assert_eq!(
772 vec![
773 "$CREATE:foo",
774 "$IMA:foo",
775 "$IPOWER:foo",
776 "$IJR:foo",
777 "$IMB:foo",
778 "$IMC:foo",
779 "$START:foo",
780 "$END:foo"
781 ],
782 sorted_event_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>()
783 );
784 }
785
786 #[test]
787 fn test_sort() {
788 for _ in 0..20 {
789 test_event_sort();
792 }
793 }
794
795 #[test]
796 fn ban_vs_power_level() {
797 let _ =
798 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
799
800 let events = &[
801 to_init_pdu_event(
802 "PA",
803 alice(),
804 TimelineEventType::RoomPowerLevels,
805 Some(""),
806 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
807 ),
808 to_init_pdu_event(
809 "MA",
810 alice(),
811 TimelineEventType::RoomMember,
812 Some(alice().to_string().as_str()),
813 member_content_join(),
814 ),
815 to_init_pdu_event(
816 "MB",
817 alice(),
818 TimelineEventType::RoomMember,
819 Some(bob().to_string().as_str()),
820 member_content_ban(),
821 ),
822 to_init_pdu_event(
823 "PB",
824 bob(),
825 TimelineEventType::RoomPowerLevels,
826 Some(""),
827 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
828 ),
829 ];
830
831 let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]]
832 .into_iter()
833 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
834 .collect::<Vec<_>>();
835
836 let expected_state_ids =
837 vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
838
839 do_check(events, edges, expected_state_ids);
840 }
841
842 #[test]
843 fn topic_basic() {
844 let _ =
845 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
846
847 let events = &[
848 to_init_pdu_event(
849 "T1",
850 alice(),
851 TimelineEventType::RoomTopic,
852 Some(""),
853 to_raw_json_value(&json!({})).unwrap(),
854 ),
855 to_init_pdu_event(
856 "PA1",
857 alice(),
858 TimelineEventType::RoomPowerLevels,
859 Some(""),
860 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
861 ),
862 to_init_pdu_event(
863 "T2",
864 alice(),
865 TimelineEventType::RoomTopic,
866 Some(""),
867 to_raw_json_value(&json!({})).unwrap(),
868 ),
869 to_init_pdu_event(
870 "PA2",
871 alice(),
872 TimelineEventType::RoomPowerLevels,
873 Some(""),
874 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
875 ),
876 to_init_pdu_event(
877 "PB",
878 bob(),
879 TimelineEventType::RoomPowerLevels,
880 Some(""),
881 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
882 ),
883 to_init_pdu_event(
884 "T3",
885 bob(),
886 TimelineEventType::RoomTopic,
887 Some(""),
888 to_raw_json_value(&json!({})).unwrap(),
889 ),
890 ];
891
892 let edges =
893 vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]]
894 .into_iter()
895 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
896 .collect::<Vec<_>>();
897
898 let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>();
899
900 do_check(events, edges, expected_state_ids);
901 }
902
903 #[test]
904 fn topic_reset() {
905 let _ =
906 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
907
908 let events = &[
909 to_init_pdu_event(
910 "T1",
911 alice(),
912 TimelineEventType::RoomTopic,
913 Some(""),
914 to_raw_json_value(&json!({})).unwrap(),
915 ),
916 to_init_pdu_event(
917 "PA",
918 alice(),
919 TimelineEventType::RoomPowerLevels,
920 Some(""),
921 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
922 ),
923 to_init_pdu_event(
924 "T2",
925 bob(),
926 TimelineEventType::RoomTopic,
927 Some(""),
928 to_raw_json_value(&json!({})).unwrap(),
929 ),
930 to_init_pdu_event(
931 "MB",
932 alice(),
933 TimelineEventType::RoomMember,
934 Some(bob().to_string().as_str()),
935 member_content_ban(),
936 ),
937 ];
938
939 let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]]
940 .into_iter()
941 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
942 .collect::<Vec<_>>();
943
944 let expected_state_ids =
945 vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>();
946
947 do_check(events, edges, expected_state_ids);
948 }
949
950 #[test]
951 fn join_rule_evasion() {
952 let _ =
953 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
954
955 let events = &[
956 to_init_pdu_event(
957 "JR",
958 alice(),
959 TimelineEventType::RoomJoinRules,
960 Some(""),
961 to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(),
962 ),
963 to_init_pdu_event(
964 "ME",
965 ella(),
966 TimelineEventType::RoomMember,
967 Some(ella().to_string().as_str()),
968 member_content_join(),
969 ),
970 ];
971
972 let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]]
973 .into_iter()
974 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
975 .collect::<Vec<_>>();
976
977 let expected_state_ids = vec![event_id("JR")];
978
979 do_check(events, edges, expected_state_ids);
980 }
981
982 #[test]
983 fn offtopic_power_level() {
984 let _ =
985 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
986
987 let events = &[
988 to_init_pdu_event(
989 "PA",
990 alice(),
991 TimelineEventType::RoomPowerLevels,
992 Some(""),
993 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
994 ),
995 to_init_pdu_event(
996 "PB",
997 bob(),
998 TimelineEventType::RoomPowerLevels,
999 Some(""),
1000 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } }))
1001 .unwrap(),
1002 ),
1003 to_init_pdu_event(
1004 "PC",
1005 charlie(),
1006 TimelineEventType::RoomPowerLevels,
1007 Some(""),
1008 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } }))
1009 .unwrap(),
1010 ),
1011 ];
1012
1013 let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]]
1014 .into_iter()
1015 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1016 .collect::<Vec<_>>();
1017
1018 let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
1019
1020 do_check(events, edges, expected_state_ids);
1021 }
1022
1023 #[test]
1024 fn topic_setting() {
1025 let _ =
1026 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1027
1028 let events = &[
1029 to_init_pdu_event(
1030 "T1",
1031 alice(),
1032 TimelineEventType::RoomTopic,
1033 Some(""),
1034 to_raw_json_value(&json!({})).unwrap(),
1035 ),
1036 to_init_pdu_event(
1037 "PA1",
1038 alice(),
1039 TimelineEventType::RoomPowerLevels,
1040 Some(""),
1041 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1042 ),
1043 to_init_pdu_event(
1044 "T2",
1045 alice(),
1046 TimelineEventType::RoomTopic,
1047 Some(""),
1048 to_raw_json_value(&json!({})).unwrap(),
1049 ),
1050 to_init_pdu_event(
1051 "PA2",
1052 alice(),
1053 TimelineEventType::RoomPowerLevels,
1054 Some(""),
1055 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
1056 ),
1057 to_init_pdu_event(
1058 "PB",
1059 bob(),
1060 TimelineEventType::RoomPowerLevels,
1061 Some(""),
1062 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1063 ),
1064 to_init_pdu_event(
1065 "T3",
1066 bob(),
1067 TimelineEventType::RoomTopic,
1068 Some(""),
1069 to_raw_json_value(&json!({})).unwrap(),
1070 ),
1071 to_init_pdu_event(
1072 "MZ1",
1073 zara(),
1074 TimelineEventType::RoomTopic,
1075 Some(""),
1076 to_raw_json_value(&json!({})).unwrap(),
1077 ),
1078 to_init_pdu_event(
1079 "T4",
1080 alice(),
1081 TimelineEventType::RoomTopic,
1082 Some(""),
1083 to_raw_json_value(&json!({})).unwrap(),
1084 ),
1085 ];
1086
1087 let edges = vec![
1088 vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
1089 vec!["END", "MZ1", "T3", "PB", "PA1"],
1090 ]
1091 .into_iter()
1092 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1093 .collect::<Vec<_>>();
1094
1095 let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>();
1096
1097 do_check(events, edges, expected_state_ids);
1098 }
1099
1100 #[test]
1101 fn test_event_map_none() {
1102 let _ =
1103 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1104
1105 let mut store = TestStore::<PduEvent>(hashmap! {});
1106
1107 let (state_at_bob, state_at_charlie, expected) = store.set_up();
1109
1110 let ev_map = store.0.clone();
1111 let state_sets = [state_at_bob, state_at_charlie];
1112 let resolved = match crate::resolve(
1113 &AuthorizationRules::V1,
1114 &state_sets,
1115 state_sets
1116 .iter()
1117 .map(|map| {
1118 store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
1119 })
1120 .collect(),
1121 |id| ev_map.get(id).cloned(),
1122 ) {
1123 Ok(state) => state,
1124 Err(e) => panic!("{e}"),
1125 };
1126
1127 assert_eq!(expected, resolved);
1128 }
1129
1130 #[test]
1131 fn test_lexicographical_sort() {
1132 let _ =
1133 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1134
1135 let graph = hashmap! {
1136 event_id("l") => hashset![event_id("o")],
1137 event_id("m") => hashset![event_id("n"), event_id("o")],
1138 event_id("n") => hashset![event_id("o")],
1139 event_id("o") => hashset![], event_id("p") => hashset![event_id("o")],
1141 };
1142
1143 let res = crate::lexicographical_topological_sort(&graph, |_id| {
1144 Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
1145 })
1146 .unwrap();
1147
1148 assert_eq!(
1149 vec!["o", "l", "n", "m", "p"],
1150 res.iter()
1151 .map(ToString::to_string)
1152 .map(|s| s.replace('$', "").replace(":foo", ""))
1153 .collect::<Vec<_>>()
1154 );
1155 }
1156
1157 #[test]
1158 fn ban_with_auth_chains() {
1159 let _ =
1160 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1161 let ban = BAN_STATE_SET();
1162
1163 let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]]
1164 .into_iter()
1165 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1166 .collect::<Vec<_>>();
1167
1168 let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
1169
1170 do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
1171 }
1172
1173 #[test]
1174 fn ban_with_auth_chains2() {
1175 let _ =
1176 tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1177 let init = INITIAL_EVENTS();
1178 let ban = BAN_STATE_SET();
1179
1180 let mut inner = init.clone();
1181 inner.extend(ban);
1182 let store = TestStore(inner.clone());
1183
1184 let state_set_a = [
1185 inner.get(&event_id("CREATE")).unwrap(),
1186 inner.get(&event_id("IJR")).unwrap(),
1187 inner.get(&event_id("IMA")).unwrap(),
1188 inner.get(&event_id("IMB")).unwrap(),
1189 inner.get(&event_id("IMC")).unwrap(),
1190 inner.get(&event_id("MB")).unwrap(),
1191 inner.get(&event_id("PA")).unwrap(),
1192 ]
1193 .iter()
1194 .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone()))
1195 .collect::<StateMap<_>>();
1196
1197 let state_set_b = [
1198 inner.get(&event_id("CREATE")).unwrap(),
1199 inner.get(&event_id("IJR")).unwrap(),
1200 inner.get(&event_id("IMA")).unwrap(),
1201 inner.get(&event_id("IMB")).unwrap(),
1202 inner.get(&event_id("IMC")).unwrap(),
1203 inner.get(&event_id("IME")).unwrap(),
1204 inner.get(&event_id("PA")).unwrap(),
1205 ]
1206 .iter()
1207 .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone()))
1208 .collect::<StateMap<_>>();
1209
1210 let ev_map = &store.0;
1211 let state_sets = [state_set_a, state_set_b];
1212 let resolved = match crate::resolve(
1213 &AuthorizationRules::V6,
1214 &state_sets,
1215 state_sets
1216 .iter()
1217 .map(|map| {
1218 store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
1219 })
1220 .collect(),
1221 |id| ev_map.get(id).cloned(),
1222 ) {
1223 Ok(state) => state,
1224 Err(e) => panic!("{e}"),
1225 };
1226
1227 debug!(
1228 resolved = ?resolved
1229 .iter()
1230 .map(|((ty, key), id)| format!("(({ty}{key:?}), {id})"))
1231 .collect::<Vec<_>>(),
1232 "resolved state",
1233 );
1234
1235 let expected =
1236 ["$CREATE:foo", "$IJR:foo", "$PA:foo", "$IMA:foo", "$IMB:foo", "$IMC:foo", "$MB:foo"];
1237
1238 for id in expected.iter().map(|i| event_id(i)) {
1239 assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}");
1241 }
1242 assert_eq!(expected.len(), resolved.len());
1243 }
1244
1245 #[test]
1246 fn join_rule_with_auth_chain() {
1247 let join_rule = JOIN_RULE();
1248
1249 let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
1250 .into_iter()
1251 .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1252 .collect::<Vec<_>>();
1253
1254 let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
1255
1256 do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
1257 }
1258
1259 #[allow(non_snake_case)]
1260 fn BAN_STATE_SET() -> HashMap<OwnedEventId, Arc<PduEvent>> {
1261 vec![
1262 to_pdu_event(
1263 "PA",
1264 alice(),
1265 TimelineEventType::RoomPowerLevels,
1266 Some(""),
1267 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1268 &["CREATE", "IMA", "IPOWER"], &["START"], ),
1271 to_pdu_event(
1272 "PB",
1273 alice(),
1274 TimelineEventType::RoomPowerLevels,
1275 Some(""),
1276 to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1277 &["CREATE", "IMA", "IPOWER"],
1278 &["END"],
1279 ),
1280 to_pdu_event(
1281 "MB",
1282 alice(),
1283 TimelineEventType::RoomMember,
1284 Some(ella().as_str()),
1285 member_content_ban(),
1286 &["CREATE", "IMA", "PB"],
1287 &["PA"],
1288 ),
1289 to_pdu_event(
1290 "IME",
1291 ella(),
1292 TimelineEventType::RoomMember,
1293 Some(ella().as_str()),
1294 member_content_join(),
1295 &["CREATE", "IJR", "PA"],
1296 &["MB"],
1297 ),
1298 ]
1299 .into_iter()
1300 .map(|ev| (ev.event_id.clone(), ev))
1301 .collect()
1302 }
1303
1304 #[allow(non_snake_case)]
1305 fn JOIN_RULE() -> HashMap<OwnedEventId, Arc<PduEvent>> {
1306 vec![
1307 to_pdu_event(
1308 "JR",
1309 alice(),
1310 TimelineEventType::RoomJoinRules,
1311 Some(""),
1312 to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(),
1313 &["CREATE", "IMA", "IPOWER"],
1314 &["START"],
1315 ),
1316 to_pdu_event(
1317 "IMZ",
1318 zara(),
1319 TimelineEventType::RoomPowerLevels,
1320 Some(zara().as_str()),
1321 member_content_join(),
1322 &["CREATE", "JR", "IPOWER"],
1323 &["START"],
1324 ),
1325 ]
1326 .into_iter()
1327 .map(|ev| (ev.event_id.clone(), ev))
1328 .collect()
1329 }
1330
1331 macro_rules! state_set {
1332 ($($kind:expr => $key:expr => $id:expr),* $(,)?) => {{
1333 #[allow(unused_mut)]
1334 let mut x = StateMap::new();
1335 $(
1336 x.insert(($kind, $key.to_owned()), $id);
1337 )*
1338 x
1339 }};
1340 }
1341
1342 #[test]
1343 fn separate_unique_conflicted() {
1344 let (unconflicted, conflicted) = super::separate(
1345 [
1346 state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1347 state_set![StateEventType::RoomMember => "@b:hs1" => 1],
1348 state_set![StateEventType::RoomMember => "@c:hs1" => 2],
1349 ]
1350 .iter(),
1351 );
1352
1353 assert_eq!(unconflicted, StateMap::new());
1354 assert_eq!(
1355 conflicted,
1356 state_set![
1357 StateEventType::RoomMember => "@a:hs1" => vec![0],
1358 StateEventType::RoomMember => "@b:hs1" => vec![1],
1359 StateEventType::RoomMember => "@c:hs1" => vec![2],
1360 ],
1361 );
1362 }
1363
1364 #[test]
1365 fn separate_conflicted() {
1366 let (unconflicted, mut conflicted) = super::separate(
1367 [
1368 state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1369 state_set![StateEventType::RoomMember => "@a:hs1" => 1],
1370 state_set![StateEventType::RoomMember => "@a:hs1" => 2],
1371 ]
1372 .iter(),
1373 );
1374
1375 for v in conflicted.values_mut() {
1377 v.sort_unstable();
1378 }
1379
1380 assert_eq!(unconflicted, StateMap::new());
1381 assert_eq!(
1382 conflicted,
1383 state_set![
1384 StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2],
1385 ],
1386 );
1387 }
1388
1389 #[test]
1390 fn separate_unconflicted() {
1391 let (unconflicted, conflicted) = super::separate(
1392 [
1393 state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1394 state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1395 state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1396 ]
1397 .iter(),
1398 );
1399
1400 assert_eq!(
1401 unconflicted,
1402 state_set![
1403 StateEventType::RoomMember => "@a:hs1" => 0,
1404 ],
1405 );
1406 assert_eq!(conflicted, StateMap::new());
1407 }
1408
1409 #[test]
1410 fn separate_mixed() {
1411 let (unconflicted, conflicted) = super::separate(
1412 [
1413 state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1414 state_set![
1415 StateEventType::RoomMember => "@a:hs1" => 0,
1416 StateEventType::RoomMember => "@b:hs1" => 1,
1417 ],
1418 state_set![
1419 StateEventType::RoomMember => "@a:hs1" => 0,
1420 StateEventType::RoomMember => "@c:hs1" => 2,
1421 ],
1422 ]
1423 .iter(),
1424 );
1425
1426 assert_eq!(
1427 unconflicted,
1428 state_set![
1429 StateEventType::RoomMember => "@a:hs1" => 0,
1430 ],
1431 );
1432 assert_eq!(
1433 conflicted,
1434 state_set![
1435 StateEventType::RoomMember => "@b:hs1" => vec![1],
1436 StateEventType::RoomMember => "@c:hs1" => vec![2],
1437 ],
1438 );
1439 }
1440}