ruma_state_res/
lib.rs

1use std::{
2    borrow::Borrow,
3    cmp::{Ordering, Reverse},
4    collections::{BinaryHeap, HashMap, HashSet},
5    hash::Hash,
6    sync::OnceLock,
7};
8
9use js_int::Int;
10use ruma_common::{
11    room_version_rules::AuthorizationRules, EventId, MilliSecondsSinceUnixEpoch, OwnedUserId,
12};
13use ruma_events::{room::member::MembershipState, StateEventType, TimelineEventType};
14use tracing::{debug, info, instrument, trace, warn};
15
16mod error;
17pub mod event_auth;
18pub mod events;
19#[cfg(test)]
20mod test_utils;
21
22use self::events::{
23    member::RoomMemberEvent, power_levels::RoomPowerLevelsEventOptionExt, RoomCreateEvent,
24    RoomPowerLevelsEvent,
25};
26pub use self::{
27    error::{Error, Result},
28    event_auth::{auth_check, auth_types_for_event},
29    events::Event,
30};
31
/// A mapping of `(event type, state key)` pairs to some value `T`, usually an `EventId`.
///
/// This is the shape used for room state in this module: each key identifies one state slot and
/// maps to the value (typically the event ID) currently filling it.
pub type StateMap<T> = HashMap<(StateEventType, String), T>;
34
35/// Resolve sets of state events as they come in.
36///
37/// Internally `StateResolution` builds a graph and an auth chain to allow for state conflict
38/// resolution.
39///
40/// ## Arguments
41///
42/// * `rules` - The rules to apply for the version of the current room.
43///
44/// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a possible fork in
45///   the state of a room.
46///
47/// * `auth_chain_sets` - The full recursive set of `auth_events` for each event in the
48///   `state_sets`.
49///
50/// * `fetch_event` - Any event not found in the `event_map` will defer to this closure to find the
51///   event.
52///
53/// ## Invariants
54///
55/// The caller of `resolve` must ensure that all the events are from the same room. Although this
56/// function takes a `RoomId` it does not check that each event is part of the same room.
57#[instrument(skip(rules, state_sets, auth_chain_sets, fetch_event))]
58pub fn resolve<'a, E, SetIter>(
59    rules: &AuthorizationRules,
60    state_sets: impl IntoIterator<IntoIter = SetIter>,
61    auth_chain_sets: Vec<HashSet<E::Id>>,
62    fetch_event: impl Fn(&EventId) -> Option<E>,
63) -> Result<StateMap<E::Id>>
64where
65    E: Event + Clone,
66    E::Id: 'a,
67    SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
68{
69    info!("state resolution starting");
70
71    // Split non-conflicting and conflicting state
72    let (clean, conflicting) = separate(state_sets.into_iter());
73
74    info!(count = clean.len(), "non-conflicting events");
75    trace!(map = ?clean, "non-conflicting events");
76
77    if conflicting.is_empty() {
78        info!("no conflicting state found");
79        return Ok(clean);
80    }
81
82    info!(count = conflicting.len(), "conflicting events");
83    trace!(map = ?conflicting, "conflicting events");
84
85    // `all_conflicted` contains unique items
86    // synapse says `full_set = {eid for eid in full_conflicted_set if eid in event_map}`
87    let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets)
88        .chain(conflicting.into_values().flatten())
89        // Don't honor events we cannot "verify"
90        .filter(|id| fetch_event(id.borrow()).is_some())
91        .collect();
92
93    info!(count = all_conflicted.len(), "full conflicted set");
94    trace!(set = ?all_conflicted, "full conflicted set");
95
96    // We used to check that all events are events from the correct room
97    // this is now a check the caller of `resolve` must make.
98
99    // Get only the control events with a state_key: "" or ban/kick event (sender != state_key)
100    let control_events = all_conflicted
101        .iter()
102        .filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
103        .cloned()
104        .collect::<Vec<_>>();
105
106    // Sort the control events based on power_level/clock/event_id and outgoing/incoming edges
107    let sorted_control_levels =
108        reverse_topological_power_sort(control_events, &all_conflicted, rules, &fetch_event)?;
109
110    debug!(count = sorted_control_levels.len(), "power events");
111    trace!(list = ?sorted_control_levels, "sorted power events");
112
113    // Sequentially auth check each control event.
114    let resolved_control =
115        iterative_auth_check(rules, &sorted_control_levels, clean.clone(), &fetch_event)?;
116
117    debug!(count = resolved_control.len(), "resolved power events");
118    trace!(map = ?resolved_control, "resolved power events");
119
120    // At this point the control_events have been resolved we now have to
121    // sort the remaining events using the mainline of the resolved power level.
122    let deduped_power_ev = sorted_control_levels.into_iter().collect::<HashSet<_>>();
123
124    // This removes the control events that passed auth and more importantly those that failed
125    // auth
126    let events_to_resolve = all_conflicted
127        .iter()
128        .filter(|&id| !deduped_power_ev.contains(id.borrow()))
129        .cloned()
130        .collect::<Vec<_>>();
131
132    debug!(count = events_to_resolve.len(), "events left to resolve");
133    trace!(list = ?events_to_resolve, "events left to resolve");
134
135    // This "epochs" power level event
136    let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, "".into()));
137
138    debug!(event_id = ?power_event, "power event");
139
140    let sorted_left_events = mainline_sort(&events_to_resolve, power_event.cloned(), &fetch_event)?;
141
142    trace!(list = ?sorted_left_events, "events left, sorted");
143
144    let mut resolved_state = iterative_auth_check(
145        rules,
146        &sorted_left_events,
147        resolved_control, // The control events are added to the final resolved state
148        &fetch_event,
149    )?;
150
151    // Add unconflicted state to the resolved state
152    // We priorities the unconflicting state
153    resolved_state.extend(clean);
154
155    info!("state resolution finished");
156
157    Ok(resolved_state)
158}
159
160/// Split the events that have no conflicts from those that are conflicting.
161///
162/// The return tuple looks like `(unconflicted, conflicted)`.
163///
164/// State is determined to be conflicting if for the given key (StateEventType, StateKey) there is
165/// not exactly one event ID. This includes missing events, if one state_set includes an event that
166/// none of the other have this is a conflicting event.
167fn separate<'a, Id>(
168    state_sets_iter: impl Iterator<Item = &'a StateMap<Id>>,
169) -> (StateMap<Id>, StateMap<Vec<Id>>)
170where
171    Id: Clone + Eq + Hash + 'a,
172{
173    let mut state_set_count = 0_usize;
174    let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
175
176    let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1);
177    for (k, v) in state_sets_iter.flatten() {
178        occurrences.entry(k).or_default().entry(v).and_modify(|x| *x += 1).or_insert(1);
179    }
180
181    let mut unconflicted_state = StateMap::new();
182    let mut conflicted_state = StateMap::new();
183
184    for (k, v) in occurrences {
185        for (id, occurrence_count) in v {
186            if occurrence_count == state_set_count {
187                unconflicted_state.insert((k.0.clone(), k.1.clone()), id.clone());
188            } else {
189                conflicted_state
190                    .entry((k.0.clone(), k.1.clone()))
191                    .and_modify(|x: &mut Vec<_>| x.push(id.clone()))
192                    .or_insert(vec![id.clone()]);
193            }
194        }
195    }
196
197    (unconflicted_state, conflicted_state)
198}
199
/// Returns an iterator over the event IDs that appear in some of the auth chains but not in all of
/// them (the "auth difference").
///
/// Each chain is a `HashSet`, so an ID's count equals the number of chains containing it and every
/// ID is yielded at most once. The iteration order is unspecified. With zero or one chain the
/// difference is empty.
fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
where
    Id: Eq + Hash,
{
    let num_sets = auth_chain_sets.len();

    // Number of auth chains each ID appears in.
    let mut id_counts: HashMap<Id, usize> = HashMap::new();
    for id in auth_chain_sets.into_iter().flatten() {
        *id_counts.entry(id).or_default() += 1;
    }

    id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id))
}
214
215/// Events are sorted from "earliest" to "latest".
216///
217/// They are compared using the negative power level (reverse topological ordering), the origin
218/// server timestamp and in case of a tie the `EventId`s are compared lexicographically.
219///
220/// The power level is negative because a higher power level is equated to an earlier (further back
221/// in time) origin server timestamp.
222#[instrument(skip_all)]
223fn reverse_topological_power_sort<E: Event>(
224    events_to_sort: Vec<E::Id>,
225    auth_diff: &HashSet<E::Id>,
226    rules: &AuthorizationRules,
227    fetch_event: impl Fn(&EventId) -> Option<E>,
228) -> Result<Vec<E::Id>> {
229    debug!("reverse topological sort of power events");
230
231    let mut graph = HashMap::new();
232    for event_id in events_to_sort {
233        add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event);
234
235        // TODO: if these functions are ever made async here
236        // is a good place to yield every once in a while so other
237        // tasks can make progress
238    }
239
240    // This is used in the `key_fn` passed to the lexico_topo_sort fn
241    let mut event_to_pl = HashMap::new();
242    // We need to know the creator in case of missing power levels. Given that it's the same for all
243    // the events in the room, we will just load it for the first event and reuse it.
244    let creator_lock = OnceLock::new();
245
246    for event_id in graph.keys() {
247        let pl = get_power_level_for_sender(event_id.borrow(), rules, &creator_lock, &fetch_event)
248            .map_err(Error::AuthEvent)?;
249        debug!(
250            event_id = event_id.borrow().as_str(),
251            power_level = i64::from(pl),
252            "found the power level of an event's sender",
253        );
254
255        event_to_pl.insert(event_id.clone(), pl);
256
257        // TODO: if these functions are ever made async here
258        // is a good place to yield every once in a while so other
259        // tasks can make progress
260    }
261
262    lexicographical_topological_sort(&graph, |event_id| {
263        let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
264        let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
265        Ok((pl, ev.origin_server_ts()))
266    })
267}
268
269/// Sorts the event graph based on number of outgoing/incoming edges.
270///
271/// `key_fn` is used as to obtain the power level and age of an event for breaking ties (together
272/// with the event ID).
273#[instrument(skip_all)]
274pub fn lexicographical_topological_sort<Id, F>(
275    graph: &HashMap<Id, HashSet<Id>>,
276    key_fn: F,
277) -> Result<Vec<Id>>
278where
279    F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>,
280    Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
281{
282    #[derive(PartialEq, Eq)]
283    struct TieBreaker<'a, Id> {
284        power_level: Int,
285        origin_server_ts: MilliSecondsSinceUnixEpoch,
286        event_id: &'a Id,
287    }
288
289    impl<Id> Ord for TieBreaker<'_, Id>
290    where
291        Id: Ord,
292    {
293        fn cmp(&self, other: &Self) -> Ordering {
294            // NOTE: the power level comparison is "backwards" intentionally.
295            // See the "Mainline ordering" section of the Matrix specification
296            // around where it says the following:
297            //
298            // > for events `x` and `y`, `x < y` if [...]
299            //
300            // <https://spec.matrix.org/latest/rooms/v11/#definitions>
301            other
302                .power_level
303                .cmp(&self.power_level)
304                .then(self.origin_server_ts.cmp(&other.origin_server_ts))
305                .then(self.event_id.cmp(other.event_id))
306        }
307    }
308
309    impl<Id> PartialOrd for TieBreaker<'_, Id>
310    where
311        Id: Ord,
312    {
313        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
314            Some(self.cmp(other))
315        }
316    }
317
318    // NOTE: an event that has no incoming edges happened most recently,
319    // and an event that has no outgoing edges happened least recently.
320
321    // NOTE: this is basically Kahn's algorithm except we look at nodes with no
322    // outgoing edges, c.f.
323    // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
324
325    // outdegree_map is an event referring to the events before it, the
326    // more outdegree's the more recent the event.
327    let mut outdegree_map = graph.clone();
328
329    // The number of events that depend on the given event (the EventId key)
330    // How many events reference this event in the DAG as a parent
331    let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
332
333    // Vec of nodes that have zero out degree, least recent events.
334    let mut zero_outdegree = Vec::new();
335
336    for (node, edges) in graph {
337        if edges.is_empty() {
338            let (power_level, origin_server_ts) = key_fn(node.borrow())?;
339            // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need
340            // smallest -> largest
341            zero_outdegree.push(Reverse(TieBreaker {
342                power_level,
343                origin_server_ts,
344                event_id: node,
345            }));
346        }
347
348        reverse_graph.entry(node).or_default();
349        for edge in edges {
350            reverse_graph.entry(edge).or_default().insert(node);
351        }
352    }
353
354    let mut heap = BinaryHeap::from(zero_outdegree);
355
356    // We remove the oldest node (most incoming edges) and check against all other
357    let mut sorted = vec![];
358    // Destructure the `Reverse` and take the smallest `node` each time
359    while let Some(Reverse(item)) = heap.pop() {
360        let node = item.event_id;
361
362        for &parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") {
363            // The number of outgoing edges this node has
364            let out = outdegree_map
365                .get_mut(parent.borrow())
366                .expect("outdegree_map knows of all referenced EventIds");
367
368            // Only push on the heap once older events have been cleared
369            out.remove(node.borrow());
370            if out.is_empty() {
371                let (power_level, origin_server_ts) = key_fn(parent.borrow())?;
372                heap.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: parent }));
373            }
374        }
375
376        // synapse yields we push then return the vec
377        sorted.push(node.clone());
378    }
379
380    Ok(sorted)
381}
382
383/// Find the power level for the sender of `event_id` or return a default value of zero.
384///
385/// Do NOT use this any where but topological sort, we find the power level for the eventId
386/// at the eventId's generation (we walk backwards to `EventId`s most recent previous power level
387/// event).
388fn get_power_level_for_sender<E: Event>(
389    event_id: &EventId,
390    rules: &AuthorizationRules,
391    creator_lock: &OnceLock<OwnedUserId>,
392    fetch_event: impl Fn(&EventId) -> Option<E>,
393) -> std::result::Result<Int, String> {
394    let event = fetch_event(event_id);
395    let mut room_create_event = None;
396    let mut room_power_levels_event = None;
397
398    for aid in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
399        if let Some(aev) = fetch_event(aid.borrow()) {
400            if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
401                room_power_levels_event = Some(RoomPowerLevelsEvent::new(aev));
402            } else if creator_lock.get().is_none()
403                && is_type_and_key(&aev, &TimelineEventType::RoomCreate, "")
404            {
405                room_create_event = Some(RoomCreateEvent::new(aev));
406            }
407
408            if room_power_levels_event.is_some()
409                && (creator_lock.get().is_some() || room_create_event.is_some())
410            {
411                break;
412            }
413        }
414    }
415
416    // TODO: Use OnceLock::try_or_get_init when it is stabilized.
417    let creator = if let Some(creator) = creator_lock.get() {
418        Some(creator)
419    } else if let Some(room_create_event) = room_create_event {
420        let creator = room_create_event.creator(rules)?;
421        Some(creator_lock.get_or_init(|| creator.into_owned()))
422    } else {
423        None
424    };
425
426    if let Some((event, creator)) = event.zip(creator) {
427        room_power_levels_event.user_power_level(event.sender(), creator, rules)
428    } else {
429        room_power_levels_event
430            .get_as_int_or_default(events::RoomPowerLevelsIntField::UsersDefault, rules)
431    }
432}
433
434/// Check the that each event is authenticated based on the events before it.
435///
436/// ## Returns
437///
438/// The `unconflicted_state` combined with the newly auth'ed events. So any event that fails the
439/// `event_auth::auth_check` will be excluded from the returned state map.
440///
441/// For each `events_to_check` event we gather the events needed to auth it from the the
442/// `fetch_event` closure and verify each event using the `event_auth::auth_check` function.
443fn iterative_auth_check<E: Event + Clone>(
444    rules: &AuthorizationRules,
445    events_to_check: &[E::Id],
446    unconflicted_state: StateMap<E::Id>,
447    fetch_event: impl Fn(&EventId) -> Option<E>,
448) -> Result<StateMap<E::Id>> {
449    debug!("starting iterative auth check");
450
451    trace!(list = ?events_to_check, "events to check");
452
453    let mut resolved_state = unconflicted_state;
454
455    for event_id in events_to_check {
456        let event = fetch_event(event_id.borrow())
457            .ok_or_else(|| Error::NotFound(event_id.borrow().to_owned()))?;
458        let state_key = event.state_key().ok_or(Error::MissingStateKey)?;
459
460        let mut auth_events = StateMap::new();
461        for aid in event.auth_events() {
462            if let Some(ev) = fetch_event(aid.borrow()) {
463                // TODO synapse check "rejected_reason" which is most likely
464                // related to soft-failing
465                auth_events.insert(
466                    ev.event_type().with_state_key(ev.state_key().ok_or(Error::MissingStateKey)?),
467                    ev,
468                );
469            } else {
470                warn!(event_id = aid.borrow().as_str(), "missing auth event");
471            }
472        }
473
474        let auth_types = match auth_types_for_event(
475            event.event_type(),
476            event.sender(),
477            Some(state_key),
478            event.content(),
479            rules,
480        ) {
481            Ok(auth_types) => auth_types,
482            Err(error) => {
483                warn!("failed to get list of required auth events for malformed event: {error}");
484                continue;
485            }
486        };
487
488        for key in auth_types {
489            if let Some(ev_id) = resolved_state.get(&key) {
490                if let Some(event) = fetch_event(ev_id.borrow()) {
491                    // TODO synapse checks `rejected_reason` is None here
492                    auth_events.insert(key.to_owned(), event);
493                }
494            }
495        }
496
497        match auth_check(rules, &event, |ty, key| auth_events.get(&ty.with_state_key(key))) {
498            Ok(()) => {
499                // Add event to resolved state.
500                resolved_state
501                    .insert(event.event_type().with_state_key(state_key), event_id.clone());
502            }
503            Err(error) => {
504                // Don't add this event to resolved_state.
505                warn!("event failed the authentication check: {error}");
506            }
507        }
508
509        // TODO: if these functions are ever made async here
510        // is a good place to yield every once in a while so other
511        // tasks can make progress
512    }
513    Ok(resolved_state)
514}
515
516/// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort using the depth of
517/// `resolved_power_level`, the server timestamp, and the eventId.
518///
519/// The depth of the given event is calculated based on the depth of it's closest "parent"
520/// power_level event. If there have been two power events the after the most recent are depth 0,
521/// the events before (with the first power level as a parent) will be marked as depth 1. depth 1 is
522/// "older" than depth 0.
523fn mainline_sort<E: Event>(
524    to_sort: &[E::Id],
525    resolved_power_level: Option<E::Id>,
526    fetch_event: impl Fn(&EventId) -> Option<E>,
527) -> Result<Vec<E::Id>> {
528    debug!("mainline sort of events");
529
530    // There are no EventId's to sort, bail.
531    if to_sort.is_empty() {
532        return Ok(vec![]);
533    }
534
535    let mut mainline = vec![];
536    let mut pl = resolved_power_level;
537    while let Some(p) = pl {
538        mainline.push(p.clone());
539
540        let event =
541            fetch_event(p.borrow()).ok_or_else(|| Error::NotFound(p.borrow().to_owned()))?;
542        pl = None;
543        for aid in event.auth_events() {
544            let ev =
545                fetch_event(aid.borrow()).ok_or_else(|| Error::NotFound(p.borrow().to_owned()))?;
546            if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
547                pl = Some(aid.to_owned());
548                break;
549            }
550        }
551        // TODO: if these functions are ever made async here
552        // is a good place to yield every once in a while so other
553        // tasks can make progress
554    }
555
556    let mainline_map = mainline
557        .iter()
558        .rev()
559        .enumerate()
560        .map(|(idx, eid)| ((*eid).clone(), idx))
561        .collect::<HashMap<_, _>>();
562
563    let mut order_map = HashMap::new();
564    for ev_id in to_sort.iter() {
565        if let Some(event) = fetch_event(ev_id.borrow()) {
566            if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) {
567                order_map.insert(
568                    ev_id,
569                    (depth, fetch_event(ev_id.borrow()).map(|ev| ev.origin_server_ts()), ev_id),
570                );
571            }
572        }
573
574        // TODO: if these functions are ever made async here
575        // is a good place to yield every once in a while so other
576        // tasks can make progress
577    }
578
579    // Sort the event_ids by their depth, timestamp and EventId
580    // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec)
581    let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::<Vec<_>>();
582    sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap());
583
584    Ok(sort_event_ids)
585}
586
587/// Get the mainline depth from the `mainline_map` or finds a power_level event that has an
588/// associated mainline depth.
589fn get_mainline_depth<E: Event>(
590    mut event: Option<E>,
591    mainline_map: &HashMap<E::Id, usize>,
592    fetch_event: impl Fn(&EventId) -> Option<E>,
593) -> Result<usize> {
594    while let Some(sort_ev) = event {
595        debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
596        let id = sort_ev.event_id();
597        if let Some(depth) = mainline_map.get(id.borrow()) {
598            return Ok(*depth);
599        }
600
601        event = None;
602        for aid in sort_ev.auth_events() {
603            let aev = fetch_event(aid.borrow())
604                .ok_or_else(|| Error::NotFound(aid.borrow().to_owned()))?;
605            if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
606                event = Some(aev);
607                break;
608            }
609        }
610    }
611    // Did not find a power level event so we default to zero
612    Ok(0)
613}
614
615fn add_event_and_auth_chain_to_graph<E: Event>(
616    graph: &mut HashMap<E::Id, HashSet<E::Id>>,
617    event_id: E::Id,
618    auth_diff: &HashSet<E::Id>,
619    fetch_event: impl Fn(&EventId) -> Option<E>,
620) {
621    let mut state = vec![event_id];
622    while let Some(eid) = state.pop() {
623        graph.entry(eid.clone()).or_default();
624        // Prefer the store to event as the store filters dedups the events
625        for aid in
626            fetch_event(eid.borrow()).as_ref().map(|ev| ev.auth_events()).into_iter().flatten()
627        {
628            if auth_diff.contains(aid.borrow()) {
629                if !graph.contains_key(aid.borrow()) {
630                    state.push(aid.to_owned());
631                }
632
633                // We just inserted this at the start of the while loop
634                graph.get_mut(eid.borrow()).unwrap().insert(aid.to_owned());
635            }
636        }
637    }
638}
639
640fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
641    match fetch(event_id).as_ref() {
642        Some(state) => is_power_event(state),
643        _ => false,
644    }
645}
646
647fn is_type_and_key(ev: impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool {
648    ev.event_type() == ev_type && ev.state_key() == Some(state_key)
649}
650
651fn is_power_event(event: impl Event) -> bool {
652    match event.event_type() {
653        TimelineEventType::RoomPowerLevels
654        | TimelineEventType::RoomJoinRules
655        | TimelineEventType::RoomCreate => event.state_key() == Some(""),
656        TimelineEventType::RoomMember => {
657            let room_member_event = RoomMemberEvent::new(event);
658            if room_member_event.membership().is_ok_and(|membership| {
659                matches!(membership, MembershipState::Leave | MembershipState::Ban)
660            }) {
661                return Some(room_member_event.sender().as_str()) != room_member_event.state_key();
662            }
663
664            false
665        }
666        _ => false,
667    }
668}
669
670/// Convenience trait for adding event type plus state key to state maps.
671trait EventTypeExt {
672    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
673}
674
675impl EventTypeExt for StateEventType {
676    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
677        (self, state_key.into())
678    }
679}
680
681impl EventTypeExt for TimelineEventType {
682    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
683        (self.to_string().into(), state_key.into())
684    }
685}
686
687impl<T> EventTypeExt for &T
688where
689    T: EventTypeExt + Clone,
690{
691    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
692        self.to_owned().with_state_key(state_key)
693    }
694}
695
696#[cfg(test)]
697mod tests {
698    use std::{
699        collections::{HashMap, HashSet},
700        sync::Arc,
701    };
702
703    use js_int::{int, uint};
704    use maplit::{hashmap, hashset};
705    use rand::seq::SliceRandom;
706    use ruma_common::{
707        room_version_rules::AuthorizationRules, MilliSecondsSinceUnixEpoch, OwnedEventId,
708    };
709    use ruma_events::{
710        room::join_rules::{JoinRule, RoomJoinRulesEventContent},
711        StateEventType, TimelineEventType,
712    };
713    use serde_json::{json, value::to_raw_value as to_raw_json_value};
714    use tracing::debug;
715
716    use crate::{
717        is_power_event,
718        test_utils::{
719            alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join,
720            room_id, to_init_pdu_event, to_pdu_event, zara, PduEvent, TestStore, INITIAL_EVENTS,
721        },
722        Event, EventTypeExt, StateMap,
723    };
724
725    fn test_event_sort() {
726        let _ =
727            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
728        let events = INITIAL_EVENTS();
729
730        let event_map = events
731            .values()
732            .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
733            .collect::<StateMap<_>>();
734
735        let auth_chain: HashSet<OwnedEventId> = HashSet::new();
736
737        let power_events = event_map
738            .values()
739            .filter(|&pdu| is_power_event(&**pdu))
740            .map(|pdu| pdu.event_id.clone())
741            .collect::<Vec<_>>();
742
743        let sorted_power_events = crate::reverse_topological_power_sort(
744            power_events,
745            &auth_chain,
746            &AuthorizationRules::V6,
747            |id| events.get(id).cloned(),
748        )
749        .unwrap();
750
751        let resolved_power = crate::iterative_auth_check(
752            &AuthorizationRules::V6,
753            &sorted_power_events,
754            HashMap::new(), // unconflicted events
755            |id| events.get(id).cloned(),
756        )
757        .expect("iterative auth check failed on resolved events");
758
759        // don't remove any events so we know it sorts them all correctly
760        let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
761
762        events_to_sort.shuffle(&mut rand::thread_rng());
763
764        let power_level =
765            resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
766
767        let sorted_event_ids =
768            crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned())
769                .unwrap();
770
771        assert_eq!(
772            vec![
773                "$CREATE:foo",
774                "$IMA:foo",
775                "$IPOWER:foo",
776                "$IJR:foo",
777                "$IMB:foo",
778                "$IMC:foo",
779                "$START:foo",
780                "$END:foo"
781            ],
782            sorted_event_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>()
783        );
784    }
785
786    #[test]
787    fn test_sort() {
788        for _ in 0..20 {
789            // since we shuffle the eventIds before we sort them introducing randomness
790            // seems like we should test this a few times
791            test_event_sort();
792        }
793    }
794
795    #[test]
796    fn ban_vs_power_level() {
797        let _ =
798            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
799
800        let events = &[
801            to_init_pdu_event(
802                "PA",
803                alice(),
804                TimelineEventType::RoomPowerLevels,
805                Some(""),
806                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
807            ),
808            to_init_pdu_event(
809                "MA",
810                alice(),
811                TimelineEventType::RoomMember,
812                Some(alice().to_string().as_str()),
813                member_content_join(),
814            ),
815            to_init_pdu_event(
816                "MB",
817                alice(),
818                TimelineEventType::RoomMember,
819                Some(bob().to_string().as_str()),
820                member_content_ban(),
821            ),
822            to_init_pdu_event(
823                "PB",
824                bob(),
825                TimelineEventType::RoomPowerLevels,
826                Some(""),
827                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
828            ),
829        ];
830
831        let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]]
832            .into_iter()
833            .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
834            .collect::<Vec<_>>();
835
836        let expected_state_ids =
837            vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
838
839        do_check(events, edges, expected_state_ids);
840    }
841
842    #[test]
843    fn topic_basic() {
844        let _ =
845            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
846
847        let events = &[
848            to_init_pdu_event(
849                "T1",
850                alice(),
851                TimelineEventType::RoomTopic,
852                Some(""),
853                to_raw_json_value(&json!({})).unwrap(),
854            ),
855            to_init_pdu_event(
856                "PA1",
857                alice(),
858                TimelineEventType::RoomPowerLevels,
859                Some(""),
860                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
861            ),
862            to_init_pdu_event(
863                "T2",
864                alice(),
865                TimelineEventType::RoomTopic,
866                Some(""),
867                to_raw_json_value(&json!({})).unwrap(),
868            ),
869            to_init_pdu_event(
870                "PA2",
871                alice(),
872                TimelineEventType::RoomPowerLevels,
873                Some(""),
874                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
875            ),
876            to_init_pdu_event(
877                "PB",
878                bob(),
879                TimelineEventType::RoomPowerLevels,
880                Some(""),
881                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
882            ),
883            to_init_pdu_event(
884                "T3",
885                bob(),
886                TimelineEventType::RoomTopic,
887                Some(""),
888                to_raw_json_value(&json!({})).unwrap(),
889            ),
890        ];
891
892        let edges =
893            vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]]
894                .into_iter()
895                .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
896                .collect::<Vec<_>>();
897
898        let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>();
899
900        do_check(events, edges, expected_state_ids);
901    }
902
903    #[test]
904    fn topic_reset() {
905        let _ =
906            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
907
908        let events = &[
909            to_init_pdu_event(
910                "T1",
911                alice(),
912                TimelineEventType::RoomTopic,
913                Some(""),
914                to_raw_json_value(&json!({})).unwrap(),
915            ),
916            to_init_pdu_event(
917                "PA",
918                alice(),
919                TimelineEventType::RoomPowerLevels,
920                Some(""),
921                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
922            ),
923            to_init_pdu_event(
924                "T2",
925                bob(),
926                TimelineEventType::RoomTopic,
927                Some(""),
928                to_raw_json_value(&json!({})).unwrap(),
929            ),
930            to_init_pdu_event(
931                "MB",
932                alice(),
933                TimelineEventType::RoomMember,
934                Some(bob().to_string().as_str()),
935                member_content_ban(),
936            ),
937        ];
938
939        let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]]
940            .into_iter()
941            .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
942            .collect::<Vec<_>>();
943
944        let expected_state_ids =
945            vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>();
946
947        do_check(events, edges, expected_state_ids);
948    }
949
950    #[test]
951    fn join_rule_evasion() {
952        let _ =
953            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
954
955        let events = &[
956            to_init_pdu_event(
957                "JR",
958                alice(),
959                TimelineEventType::RoomJoinRules,
960                Some(""),
961                to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(),
962            ),
963            to_init_pdu_event(
964                "ME",
965                ella(),
966                TimelineEventType::RoomMember,
967                Some(ella().to_string().as_str()),
968                member_content_join(),
969            ),
970        ];
971
972        let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]]
973            .into_iter()
974            .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
975            .collect::<Vec<_>>();
976
977        let expected_state_ids = vec![event_id("JR")];
978
979        do_check(events, edges, expected_state_ids);
980    }
981
982    #[test]
983    fn offtopic_power_level() {
984        let _ =
985            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
986
987        let events = &[
988            to_init_pdu_event(
989                "PA",
990                alice(),
991                TimelineEventType::RoomPowerLevels,
992                Some(""),
993                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
994            ),
995            to_init_pdu_event(
996                "PB",
997                bob(),
998                TimelineEventType::RoomPowerLevels,
999                Some(""),
1000                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } }))
1001                    .unwrap(),
1002            ),
1003            to_init_pdu_event(
1004                "PC",
1005                charlie(),
1006                TimelineEventType::RoomPowerLevels,
1007                Some(""),
1008                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } }))
1009                    .unwrap(),
1010            ),
1011        ];
1012
1013        let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]]
1014            .into_iter()
1015            .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1016            .collect::<Vec<_>>();
1017
1018        let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
1019
1020        do_check(events, edges, expected_state_ids);
1021    }
1022
1023    #[test]
1024    fn topic_setting() {
1025        let _ =
1026            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1027
1028        let events = &[
1029            to_init_pdu_event(
1030                "T1",
1031                alice(),
1032                TimelineEventType::RoomTopic,
1033                Some(""),
1034                to_raw_json_value(&json!({})).unwrap(),
1035            ),
1036            to_init_pdu_event(
1037                "PA1",
1038                alice(),
1039                TimelineEventType::RoomPowerLevels,
1040                Some(""),
1041                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1042            ),
1043            to_init_pdu_event(
1044                "T2",
1045                alice(),
1046                TimelineEventType::RoomTopic,
1047                Some(""),
1048                to_raw_json_value(&json!({})).unwrap(),
1049            ),
1050            to_init_pdu_event(
1051                "PA2",
1052                alice(),
1053                TimelineEventType::RoomPowerLevels,
1054                Some(""),
1055                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
1056            ),
1057            to_init_pdu_event(
1058                "PB",
1059                bob(),
1060                TimelineEventType::RoomPowerLevels,
1061                Some(""),
1062                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1063            ),
1064            to_init_pdu_event(
1065                "T3",
1066                bob(),
1067                TimelineEventType::RoomTopic,
1068                Some(""),
1069                to_raw_json_value(&json!({})).unwrap(),
1070            ),
1071            to_init_pdu_event(
1072                "MZ1",
1073                zara(),
1074                TimelineEventType::RoomTopic,
1075                Some(""),
1076                to_raw_json_value(&json!({})).unwrap(),
1077            ),
1078            to_init_pdu_event(
1079                "T4",
1080                alice(),
1081                TimelineEventType::RoomTopic,
1082                Some(""),
1083                to_raw_json_value(&json!({})).unwrap(),
1084            ),
1085        ];
1086
1087        let edges = vec![
1088            vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
1089            vec!["END", "MZ1", "T3", "PB", "PA1"],
1090        ]
1091        .into_iter()
1092        .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1093        .collect::<Vec<_>>();
1094
1095        let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>();
1096
1097        do_check(events, edges, expected_state_ids);
1098    }
1099
1100    #[test]
1101    fn test_event_map_none() {
1102        let _ =
1103            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1104
1105        let mut store = TestStore::<PduEvent>(hashmap! {});
1106
1107        // build up the DAG
1108        let (state_at_bob, state_at_charlie, expected) = store.set_up();
1109
1110        let ev_map = store.0.clone();
1111        let state_sets = [state_at_bob, state_at_charlie];
1112        let resolved = match crate::resolve(
1113            &AuthorizationRules::V1,
1114            &state_sets,
1115            state_sets
1116                .iter()
1117                .map(|map| {
1118                    store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
1119                })
1120                .collect(),
1121            |id| ev_map.get(id).cloned(),
1122        ) {
1123            Ok(state) => state,
1124            Err(e) => panic!("{e}"),
1125        };
1126
1127        assert_eq!(expected, resolved);
1128    }
1129
1130    #[test]
1131    fn test_lexicographical_sort() {
1132        let _ =
1133            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1134
1135        let graph = hashmap! {
1136            event_id("l") => hashset![event_id("o")],
1137            event_id("m") => hashset![event_id("n"), event_id("o")],
1138            event_id("n") => hashset![event_id("o")],
1139            event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
1140            event_id("p") => hashset![event_id("o")],
1141        };
1142
1143        let res = crate::lexicographical_topological_sort(&graph, |_id| {
1144            Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
1145        })
1146        .unwrap();
1147
1148        assert_eq!(
1149            vec!["o", "l", "n", "m", "p"],
1150            res.iter()
1151                .map(ToString::to_string)
1152                .map(|s| s.replace('$', "").replace(":foo", ""))
1153                .collect::<Vec<_>>()
1154        );
1155    }
1156
1157    #[test]
1158    fn ban_with_auth_chains() {
1159        let _ =
1160            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1161        let ban = BAN_STATE_SET();
1162
1163        let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]]
1164            .into_iter()
1165            .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1166            .collect::<Vec<_>>();
1167
1168        let expected_state_ids = vec!["PA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
1169
1170        do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
1171    }
1172
1173    #[test]
1174    fn ban_with_auth_chains2() {
1175        let _ =
1176            tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
1177        let init = INITIAL_EVENTS();
1178        let ban = BAN_STATE_SET();
1179
1180        let mut inner = init.clone();
1181        inner.extend(ban);
1182        let store = TestStore(inner.clone());
1183
1184        let state_set_a = [
1185            inner.get(&event_id("CREATE")).unwrap(),
1186            inner.get(&event_id("IJR")).unwrap(),
1187            inner.get(&event_id("IMA")).unwrap(),
1188            inner.get(&event_id("IMB")).unwrap(),
1189            inner.get(&event_id("IMC")).unwrap(),
1190            inner.get(&event_id("MB")).unwrap(),
1191            inner.get(&event_id("PA")).unwrap(),
1192        ]
1193        .iter()
1194        .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone()))
1195        .collect::<StateMap<_>>();
1196
1197        let state_set_b = [
1198            inner.get(&event_id("CREATE")).unwrap(),
1199            inner.get(&event_id("IJR")).unwrap(),
1200            inner.get(&event_id("IMA")).unwrap(),
1201            inner.get(&event_id("IMB")).unwrap(),
1202            inner.get(&event_id("IMC")).unwrap(),
1203            inner.get(&event_id("IME")).unwrap(),
1204            inner.get(&event_id("PA")).unwrap(),
1205        ]
1206        .iter()
1207        .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone()))
1208        .collect::<StateMap<_>>();
1209
1210        let ev_map = &store.0;
1211        let state_sets = [state_set_a, state_set_b];
1212        let resolved = match crate::resolve(
1213            &AuthorizationRules::V6,
1214            &state_sets,
1215            state_sets
1216                .iter()
1217                .map(|map| {
1218                    store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
1219                })
1220                .collect(),
1221            |id| ev_map.get(id).cloned(),
1222        ) {
1223            Ok(state) => state,
1224            Err(e) => panic!("{e}"),
1225        };
1226
1227        debug!(
1228            resolved = ?resolved
1229                .iter()
1230                .map(|((ty, key), id)| format!("(({ty}{key:?}), {id})"))
1231                .collect::<Vec<_>>(),
1232            "resolved state",
1233        );
1234
1235        let expected =
1236            ["$CREATE:foo", "$IJR:foo", "$PA:foo", "$IMA:foo", "$IMB:foo", "$IMC:foo", "$MB:foo"];
1237
1238        for id in expected.iter().map(|i| event_id(i)) {
1239            // make sure our resolved events are equal to the expected list
1240            assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}");
1241        }
1242        assert_eq!(expected.len(), resolved.len());
1243    }
1244
1245    #[test]
1246    fn join_rule_with_auth_chain() {
1247        let join_rule = JOIN_RULE();
1248
1249        let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
1250            .into_iter()
1251            .map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
1252            .collect::<Vec<_>>();
1253
1254        let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::<Vec<_>>();
1255
1256        do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids);
1257    }
1258
1259    #[allow(non_snake_case)]
1260    fn BAN_STATE_SET() -> HashMap<OwnedEventId, Arc<PduEvent>> {
1261        vec![
1262            to_pdu_event(
1263                "PA",
1264                alice(),
1265                TimelineEventType::RoomPowerLevels,
1266                Some(""),
1267                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1268                &["CREATE", "IMA", "IPOWER"], // auth_events
1269                &["START"],                   // prev_events
1270            ),
1271            to_pdu_event(
1272                "PB",
1273                alice(),
1274                TimelineEventType::RoomPowerLevels,
1275                Some(""),
1276                to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
1277                &["CREATE", "IMA", "IPOWER"],
1278                &["END"],
1279            ),
1280            to_pdu_event(
1281                "MB",
1282                alice(),
1283                TimelineEventType::RoomMember,
1284                Some(ella().as_str()),
1285                member_content_ban(),
1286                &["CREATE", "IMA", "PB"],
1287                &["PA"],
1288            ),
1289            to_pdu_event(
1290                "IME",
1291                ella(),
1292                TimelineEventType::RoomMember,
1293                Some(ella().as_str()),
1294                member_content_join(),
1295                &["CREATE", "IJR", "PA"],
1296                &["MB"],
1297            ),
1298        ]
1299        .into_iter()
1300        .map(|ev| (ev.event_id.clone(), ev))
1301        .collect()
1302    }
1303
1304    #[allow(non_snake_case)]
1305    fn JOIN_RULE() -> HashMap<OwnedEventId, Arc<PduEvent>> {
1306        vec![
1307            to_pdu_event(
1308                "JR",
1309                alice(),
1310                TimelineEventType::RoomJoinRules,
1311                Some(""),
1312                to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(),
1313                &["CREATE", "IMA", "IPOWER"],
1314                &["START"],
1315            ),
1316            to_pdu_event(
1317                "IMZ",
1318                zara(),
1319                TimelineEventType::RoomPowerLevels,
1320                Some(zara().as_str()),
1321                member_content_join(),
1322                &["CREATE", "JR", "IPOWER"],
1323                &["START"],
1324            ),
1325        ]
1326        .into_iter()
1327        .map(|ev| (ev.event_id.clone(), ev))
1328        .collect()
1329    }
1330
    // Builds a `StateMap` from `event_type => state_key => value` triples.
    //
    // Each `$key` is converted to an owned `String` with `to_owned()`. The map's value type is
    // inferred from the inserted `$id`s. A later triple with the same (type, key) pair overwrites
    // an earlier one, because each entry is a plain `HashMap::insert`.
    macro_rules! state_set {
        ($($kind:expr => $key:expr => $id:expr),* $(,)?) => {{
            // `mut` is unused when the macro is invoked with no triples.
            #[allow(unused_mut)]
            let mut x = StateMap::new();
            $(
                x.insert(($kind, $key.to_owned()), $id);
            )*
            x
        }};
    }
1341
1342    #[test]
1343    fn separate_unique_conflicted() {
1344        let (unconflicted, conflicted) = super::separate(
1345            [
1346                state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1347                state_set![StateEventType::RoomMember => "@b:hs1" => 1],
1348                state_set![StateEventType::RoomMember => "@c:hs1" => 2],
1349            ]
1350            .iter(),
1351        );
1352
1353        assert_eq!(unconflicted, StateMap::new());
1354        assert_eq!(
1355            conflicted,
1356            state_set![
1357                StateEventType::RoomMember => "@a:hs1" => vec![0],
1358                StateEventType::RoomMember => "@b:hs1" => vec![1],
1359                StateEventType::RoomMember => "@c:hs1" => vec![2],
1360            ],
1361        );
1362    }
1363
1364    #[test]
1365    fn separate_conflicted() {
1366        let (unconflicted, mut conflicted) = super::separate(
1367            [
1368                state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1369                state_set![StateEventType::RoomMember => "@a:hs1" => 1],
1370                state_set![StateEventType::RoomMember => "@a:hs1" => 2],
1371            ]
1372            .iter(),
1373        );
1374
1375        // HashMap iteration order is random, so sort this before asserting on it
1376        for v in conflicted.values_mut() {
1377            v.sort_unstable();
1378        }
1379
1380        assert_eq!(unconflicted, StateMap::new());
1381        assert_eq!(
1382            conflicted,
1383            state_set![
1384                StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2],
1385            ],
1386        );
1387    }
1388
1389    #[test]
1390    fn separate_unconflicted() {
1391        let (unconflicted, conflicted) = super::separate(
1392            [
1393                state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1394                state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1395                state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1396            ]
1397            .iter(),
1398        );
1399
1400        assert_eq!(
1401            unconflicted,
1402            state_set![
1403                StateEventType::RoomMember => "@a:hs1" => 0,
1404            ],
1405        );
1406        assert_eq!(conflicted, StateMap::new());
1407    }
1408
1409    #[test]
1410    fn separate_mixed() {
1411        let (unconflicted, conflicted) = super::separate(
1412            [
1413                state_set![StateEventType::RoomMember => "@a:hs1" => 0],
1414                state_set![
1415                    StateEventType::RoomMember => "@a:hs1" => 0,
1416                    StateEventType::RoomMember => "@b:hs1" => 1,
1417                ],
1418                state_set![
1419                    StateEventType::RoomMember => "@a:hs1" => 0,
1420                    StateEventType::RoomMember => "@c:hs1" => 2,
1421                ],
1422            ]
1423            .iter(),
1424        );
1425
1426        assert_eq!(
1427            unconflicted,
1428            state_set![
1429                StateEventType::RoomMember => "@a:hs1" => 0,
1430            ],
1431        );
1432        assert_eq!(
1433            conflicted,
1434            state_set![
1435                StateEventType::RoomMember => "@b:hs1" => vec![1],
1436                StateEventType::RoomMember => "@c:hs1" => vec![2],
1437            ],
1438        );
1439    }
1440}