Skip to main content

ruma_state_res/
state_res.rs

1use std::{
2    borrow::Borrow,
3    cmp::{Ordering, Reverse},
4    collections::{BinaryHeap, HashMap, HashSet},
5    hash::Hash,
6    sync::OnceLock,
7};
8
9use ruma_common::{
10    EventId, MilliSecondsSinceUnixEpoch, OwnedUserId,
11    room_version_rules::{AuthorizationRules, StateResolutionV2Rules},
12};
13use ruma_events::{
14    StateEventType, TimelineEventType,
15    room::{member::MembershipState, power_levels::UserPowerLevel},
16};
17use tracing::{debug, info, instrument, trace, warn};
18
19#[cfg(test)]
20mod tests;
21
22use crate::{
23    Error, Event, Result, auth_types_for_event, check_state_dependent_auth_rules,
24    events::{
25        RoomCreateEvent, RoomMemberEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField,
26        power_levels::RoomPowerLevelsEventOptionExt,
27    },
28    utils::{RoomIdExt, event_id_map::EventIdMap, event_id_set::EventIdSet},
29};
30
/// A mapping of event type and state_key to some value `T`, usually an `EventId`.
///
/// This is the representation of what the Matrix specification calls a "room state" or a "state
/// map" during [state resolution].
///
/// The key is an `(event type, state key)` tuple. The state key is stored as an owned `String`, so
/// state events with an empty state key (such as `m.room.power_levels`) are looked up with `""`.
///
/// [state resolution]: https://spec.matrix.org/v1.18/rooms/v2/#state-resolution
pub type StateMap<T> = HashMap<(StateEventType, String), T>;
38
39/// Apply the [state resolution] algorithm introduced in room version 2 to resolve the state of a
40/// room.
41///
42/// ## Arguments
43///
44/// * `auth_rules` - The authorization rules to apply for the version of the current room.
45///
46/// * `state_res_rules` - The state resolution rules to apply for the version of the current room.
47///
48/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents a possible fork in
49///   the state of a room.
50///
51/// * `auth_chains` - The list of full recursive sets of `auth_events` for each event in the
52///   `state_maps`.
53///
54/// * `fetch_event` - Function to fetch an event in the room given its event ID.
55///
56/// * `fetch_conflicted_state_subgraph` - Function to fetch the conflicted state subgraph for the
57///   given conflicted state set, for state resolution rules that use it. If it is called and
58///   returns `None`, this function will return an error.
59///
60/// ## Invariants
61///
62/// The caller of `resolve` must ensure that all the events are from the same room.
63///
64/// ## Returns
65///
66/// The resolved room state.
67///
68/// [state resolution]: https://spec.matrix.org/v1.18/rooms/v2/#state-resolution
69#[instrument(skip_all)]
70pub fn resolve<'a, E, MapsIter>(
71    auth_rules: &AuthorizationRules,
72    state_res_rules: &StateResolutionV2Rules,
73    state_maps: impl IntoIterator<IntoIter = MapsIter>,
74    auth_chains: Vec<EventIdSet<E::Id>>,
75    fetch_event: impl Fn(&EventId) -> Option<E>,
76    fetch_conflicted_state_subgraph: impl Fn(&StateMap<Vec<E::Id>>) -> Option<EventIdSet<E::Id>>,
77) -> Result<StateMap<E::Id>>
78where
79    E: Event + Clone,
80    E::Id: 'a,
81    MapsIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
82{
83    info!("state resolution starting");
84
85    // Split the unconflicted state map and the conflicted state set.
86    let (unconflicted_state_map, conflicted_state_set) =
87        split_conflicted_state_set(state_maps.into_iter());
88
89    info!(count = unconflicted_state_map.len(), "unconflicted events");
90    trace!(map = ?unconflicted_state_map, "unconflicted events");
91
92    if conflicted_state_set.is_empty() {
93        info!("no conflicted state found");
94        return Ok(unconflicted_state_map);
95    }
96
97    info!(count = conflicted_state_set.len(), "conflicted events");
98    trace!(map = ?conflicted_state_set, "conflicted events");
99
100    // Since v12, fetch the conflicted state subgraph.
101    let conflicted_state_subgraph = if state_res_rules.consider_conflicted_state_subgraph {
102        let conflicted_state_subgraph = fetch_conflicted_state_subgraph(&conflicted_state_set)
103            .ok_or(Error::FetchConflictedStateSubgraphFailed)?;
104
105        info!(count = conflicted_state_subgraph.len(), "events in conflicted state subgraph");
106        trace!(set = ?conflicted_state_subgraph, "conflicted state subgraph");
107
108        conflicted_state_subgraph
109    } else {
110        EventIdSet::new()
111    };
112
113    // The full conflicted set is the union of the conflicted state set and the auth difference,
114    // and since v12, the conflicted state subgraph.
115    let full_conflicted_set: EventIdSet<_> = auth_difference(auth_chains)
116        .chain(conflicted_state_set.into_values().flatten())
117        .chain(conflicted_state_subgraph)
118        // Don't honor events we cannot "verify"
119        .filter(|id| fetch_event(id.borrow()).is_some())
120        .collect();
121
122    info!(count = full_conflicted_set.len(), "full conflicted set");
123    trace!(set = ?full_conflicted_set, "full conflicted set");
124
125    // 1. Select the set X of all power events that appear in the full conflicted set. For each such
126    //    power event P, enlarge X by adding the events in the auth chain of P which also belong to
127    //    the full conflicted set. Sort X into a list using the reverse topological power ordering.
128    let conflicted_power_events = full_conflicted_set
129        .iter()
130        .filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
131        .cloned()
132        .collect::<Vec<_>>();
133
134    let sorted_power_events =
135        sort_power_events(conflicted_power_events, &full_conflicted_set, auth_rules, &fetch_event)?;
136
137    debug!(count = sorted_power_events.len(), "power events");
138    trace!(list = ?sorted_power_events, "sorted power events");
139
140    // 2. Apply the iterative auth checks algorithm, starting from the unconflicted state map, to
141    //    the list of events from the previous step to get a partially resolved state.
142
143    // Since v12, begin the first phase of iterative auth checks with an empty state map.
144    let initial_state_map = if state_res_rules.begin_iterative_auth_checks_with_empty_state_map {
145        HashMap::new()
146    } else {
147        unconflicted_state_map.clone()
148    };
149
150    let partially_resolved_state =
151        iterative_auth_checks(auth_rules, &sorted_power_events, initial_state_map, &fetch_event)?;
152
153    debug!(count = partially_resolved_state.len(), "resolved power events");
154    trace!(map = ?partially_resolved_state, "resolved power events");
155
156    // 3. Take all remaining events that weren’t picked in step 1 and order them by the mainline
157    //    ordering based on the power level in the partially resolved state obtained in step 2.
158    let sorted_power_events_set = sorted_power_events.into_iter().collect::<EventIdSet<_>>();
159    let remaining_events = full_conflicted_set
160        .iter()
161        .filter(|&id| !sorted_power_events_set.contains(id.borrow()))
162        .cloned()
163        .collect::<Vec<_>>();
164
165    debug!(count = remaining_events.len(), "events left to resolve");
166    trace!(list = ?remaining_events, "events left to resolve");
167
168    // This "epochs" power level event
169    let power_event = partially_resolved_state.get(&(StateEventType::RoomPowerLevels, "".into()));
170
171    debug!(event_id = ?power_event, "power event");
172
173    let sorted_remaining_events =
174        mainline_sort(&remaining_events, power_event.cloned(), &fetch_event)?;
175
176    trace!(list = ?sorted_remaining_events, "events left, sorted");
177
178    // 4. Apply the iterative auth checks algorithm on the partial resolved state and the list of
179    //    events from the previous step.
180    let mut resolved_state = iterative_auth_checks(
181        auth_rules,
182        &sorted_remaining_events,
183        partially_resolved_state,
184        &fetch_event,
185    )?;
186
187    // 5. Update the result by replacing any event with the event with the same key from the
188    //    unconflicted state map, if such an event exists, to get the final resolved state.
189    resolved_state.extend(unconflicted_state_map);
190
191    info!("state resolution finished");
192
193    Ok(resolved_state)
194}
195
196/// Split the unconflicted state map and the conflicted state set.
197///
198/// Definition in the specification:
199///
200/// > If a given key _K_ is present in every _Si_ with the same value _V_ in each state map, then
201/// > the pair (_K_, _V_) belongs to the unconflicted state map. Otherwise, _V_ belongs to the
202/// > conflicted state set.
203///
204/// It means that, for a given (event type, state key) tuple, if all state maps have the same event
205/// ID, it lands in the unconflicted state map, otherwise the event IDs land in the conflicted state
206/// set.
207///
208/// ## Arguments
209///
210/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents a possible fork in
211///   the state of a room.
212///
213/// ## Returns
214///
215/// Returns an `(unconflicted_state_map, conflicted_state_set)` tuple.
216fn split_conflicted_state_set<'a, Id>(
217    state_maps: impl Iterator<Item = &'a StateMap<Id>>,
218) -> (StateMap<Id>, StateMap<Vec<Id>>)
219where
220    Id: Clone + Eq + Hash + 'a,
221{
222    let mut state_set_count = 0_usize;
223    let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
224
225    let state_maps = state_maps.inspect(|_| state_set_count += 1);
226    for (k, v) in state_maps.flatten() {
227        occurrences.entry(k).or_default().entry(v).and_modify(|x| *x += 1).or_insert(1);
228    }
229
230    let mut unconflicted_state_map = StateMap::new();
231    let mut conflicted_state_set = StateMap::new();
232
233    for (k, v) in occurrences {
234        for (id, occurrence_count) in v {
235            if occurrence_count == state_set_count {
236                unconflicted_state_map.insert((k.0.clone(), k.1.clone()), id.clone());
237            } else {
238                conflicted_state_set
239                    .entry((k.0.clone(), k.1.clone()))
240                    .and_modify(|x: &mut Vec<_>| x.push(id.clone()))
241                    .or_insert(vec![id.clone()]);
242            }
243        }
244    }
245
246    (unconflicted_state_map, conflicted_state_set)
247}
248
249/// Get the auth difference for the given auth chains.
250///
251/// Definition in the specification:
252///
253/// > The auth difference is calculated by first calculating the full auth chain for each state
254/// > _Si_, that is the union of the auth chains for each event in _Si_, and then taking every event
255/// > that doesn’t appear in every auth chain. If _Ci_ is the full auth chain of _Si_, then the auth
256/// > difference is ∪_Ci_ − ∩_Ci_.
257///
258/// ## Arguments
259///
260/// * `auth_chains` - The list of full recursive sets of `auth_events`.
261///
262/// ## Returns
263///
264/// Returns an iterator over all the event IDs that are not present in all the auth chains.
265fn auth_difference<Id>(auth_chains: Vec<EventIdSet<Id>>) -> impl Iterator<Item = Id>
266where
267    Id: Eq + Hash + Borrow<EventId>,
268{
269    let num_sets = auth_chains.len();
270
271    let mut id_counts: EventIdMap<Id, usize> = EventIdMap::new();
272    for id in auth_chains.into_iter().flatten() {
273        *id_counts.entry(id).or_default() += 1;
274    }
275
276    id_counts.into_iter().filter_map(move |(id, count)| (count < num_sets).then_some(id))
277}
278
279/// Enlarge the given list of conflicted power events by adding the events in their auth chain that
280/// are in the full conflicted set, and sort it using reverse topological power ordering.
281///
282/// ## Arguments
283///
284/// * `conflicted_power_events` - The list of power events in the full conflicted set.
285///
286/// * `full_conflicted_set` - The full conflicted set.
287///
288/// * `rules` - The authorization rules for the current room version.
289///
290/// * `fetch_event` - Function to fetch an event in the room given its event ID.
291///
292/// ## Returns
293///
294/// Returns the ordered list of event IDs from earliest to latest.
295#[instrument(skip_all)]
296fn sort_power_events<E: Event>(
297    conflicted_power_events: Vec<E::Id>,
298    full_conflicted_set: &EventIdSet<E::Id>,
299    rules: &AuthorizationRules,
300    fetch_event: impl Fn(&EventId) -> Option<E>,
301) -> Result<Vec<E::Id>> {
302    debug!("reverse topological sort of power events");
303
304    // A representation of the DAG, a map of event ID to its list of auth events that are in the
305    // full conflicted set.
306    let mut graph = EventIdMap::new();
307
308    // Fill the graph.
309    for event_id in conflicted_power_events {
310        add_event_and_auth_chain_to_graph(&mut graph, event_id, full_conflicted_set, &fetch_event);
311
312        // TODO: if these functions are ever made async here
313        // is a good place to yield every once in a while so other
314        // tasks can make progress
315    }
316
317    // The map of event ID to the power level of the sender of the event.
318    let mut event_to_power_level = EventIdMap::new();
319    // We need to know the creator in case of missing power levels. Given that it's the same for all
320    // the events in the room, we will just load it for the first event and reuse it.
321    let creators_lock = OnceLock::new();
322
323    // Get the power level of the sender of each event in the graph.
324    for event_id in graph.keys() {
325        let sender_power_level =
326            power_level_for_sender(event_id.borrow(), rules, &creators_lock, &fetch_event)
327                .map_err(Error::AuthEvent)?;
328        debug!(
329            event_id = event_id.borrow().as_str(),
330            power_level = ?sender_power_level,
331            "found the power level of an event's sender",
332        );
333
334        event_to_power_level.insert(event_id.clone(), sender_power_level);
335
336        // TODO: if these functions are ever made async here
337        // is a good place to yield every once in a while so other
338        // tasks can make progress
339    }
340
341    reverse_topological_power_sort(&graph, |event_id| {
342        let event = fetch_event(event_id).ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
343        let power_level = *event_to_power_level
344            .get(event_id)
345            .ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
346        Ok((power_level, event.origin_server_ts()))
347    })
348}
349
350/// Sorts the given event graph using reverse topological power ordering.
351///
352/// Definition in the specification:
353///
354/// > The reverse topological power ordering of a set of events is the lexicographically smallest
355/// > topological ordering based on the DAG formed by auth events. The reverse topological power
356/// > ordering is ordered from earliest event to latest. For comparing two topological orderings to
357/// > determine which is the lexicographically smallest, the following comparison relation on events
358/// > is used: for events x and y, x < y if
359/// >
360/// > 1. x’s sender has greater power level than y’s sender, when looking at their respective
361/// > auth_events; or
362/// > 2. the senders have the same power level, but x’s origin_server_ts is less than y’s
363/// > origin_server_ts; or
364/// > 3. the senders have the same power level and the events have the same origin_server_ts, but
365/// > x’s event_id is less than y’s event_id.
366/// >
367/// > The reverse topological power ordering can be found by sorting the events using Kahn’s
368/// > algorithm for topological sorting, and at each step selecting, among all the candidate
369/// > vertices, the smallest vertex using the above comparison relation.
370///
371/// ## Arguments
372///
373/// * `graph` - The graph to sort. A map of event ID to its auth events that are in the full
374///   conflicted set.
375///
376/// * `event_details_fn` - Function to obtain a (power level, origin_server_ts) of an event for
377///   breaking ties.
378///
379/// ## Returns
380///
381/// Returns the ordered list of event IDs from earliest to latest.
382#[instrument(skip_all)]
383pub fn reverse_topological_power_sort<Id, F>(
384    graph: &EventIdMap<Id, EventIdSet<Id>>,
385    event_details_fn: F,
386) -> Result<Vec<Id>>
387where
388    F: Fn(&EventId) -> Result<(UserPowerLevel, MilliSecondsSinceUnixEpoch)>,
389    Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
390{
391    #[derive(PartialEq, Eq)]
392    struct TieBreaker<Id> {
393        power_level: UserPowerLevel,
394        origin_server_ts: MilliSecondsSinceUnixEpoch,
395        event_id: Id,
396    }
397
398    impl<Id> Ord for TieBreaker<Id>
399    where
400        Id: Ord,
401    {
402        fn cmp(&self, other: &Self) -> Ordering {
403            // NOTE: the power level comparison is "backwards" intentionally.
404            other
405                .power_level
406                .cmp(&self.power_level)
407                .then(self.origin_server_ts.cmp(&other.origin_server_ts))
408                .then(self.event_id.cmp(&other.event_id))
409        }
410    }
411
412    impl<Id> PartialOrd for TieBreaker<Id>
413    where
414        Id: Ord,
415    {
416        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
417            Some(self.cmp(other))
418        }
419    }
420
421    // We consider that the DAG is directed from most recent events to oldest events, so an event is
422    // an incoming edge to its auth events.
423
424    // Map of event to the list of events in its auth events.
425    let mut outgoing_edges_map: EventIdMap<_, EventIdSet<_>> = EventIdMap::new();
426
427    // Map of event to the list of events that reference it in its auth events.
428    let mut incoming_edges_map: EventIdMap<_, EventIdSet<_>> = EventIdMap::new();
429
430    // Vec of events that have an outdegree of zero (no outgoing edges), i.e. the oldest events.
431    // Use a BinaryHeap to keep the events sorted.
432    let mut heap = BinaryHeap::new();
433
434    // Populate the list of events with an outdegree of zero, and the maps of incoming and outgoing
435    // edges with the graph.
436    for (event_id, outgoing_edges) in graph {
437        if outgoing_edges.is_empty() {
438            let (power_level, origin_server_ts) = event_details_fn(event_id.borrow())?;
439
440            // `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
441            // smallest -> largest.
442            heap.push(Reverse(TieBreaker {
443                power_level,
444                origin_server_ts,
445                event_id: event_id.clone(),
446            }));
447        } else {
448            for auth_event_id in outgoing_edges {
449                incoming_edges_map
450                    .entry(auth_event_id.borrow())
451                    .or_default()
452                    .insert(event_id.borrow());
453            }
454
455            outgoing_edges_map
456                .insert(event_id.clone(), outgoing_edges.iter().map(Borrow::borrow).collect());
457        }
458    }
459
460    let mut sorted = vec![];
461
462    // Apply Kahn's algorithm.
463    // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
464    while let Some(Reverse(TieBreaker { event_id, .. })) = heap.pop() {
465        for &parent_id in incoming_edges_map.get(event_id.borrow()).into_iter().flatten() {
466            let parent_has_zero_outdegrees = {
467                let outgoing_edges = outgoing_edges_map.get_mut(parent_id).expect(
468                    "outgoing edges map should have a key for all event IDs with outgoing edges",
469                );
470
471                outgoing_edges.remove(event_id.borrow());
472                outgoing_edges.is_empty()
473            };
474
475            // Push on the heap once all the outgoing edges have been removed.
476            if parent_has_zero_outdegrees {
477                let (power_level, origin_server_ts) = event_details_fn(parent_id)?;
478                // Because the parent has no more outgoing edges, we can remove its entry from the
479                // outgoing edges map to get the owned event ID used for the key.
480                let (parent_id, _) = outgoing_edges_map
481                    .remove_entry(parent_id)
482                    .expect("outgoing edges map should have a key for all event IDs");
483
484                heap.push(Reverse(TieBreaker {
485                    power_level,
486                    origin_server_ts,
487                    event_id: parent_id,
488                }));
489            }
490        }
491
492        sorted.push(event_id);
493    }
494
495    Ok(sorted)
496}
497
498/// Find the power level for the sender of the event of the given event ID or return a default value
499/// of zero.
500///
501/// We find the most recent `m.room.power_levels` by walking backwards in the auth chain of the
502/// event.
503///
504/// Do NOT use this anywhere but topological sort.
505///
506/// ## Arguments
507///
508/// * `event_id` - The event ID of the event to get the power level of the sender of.
509///
510/// * `rules` - The authorization rules for the current room version.
511///
512/// * `creator_lock` - A lock used to cache the user ID of the creator of the room. If it is empty
513///   the creator will be fetched in the auth chain and used to populate the lock.
514///
515/// * `fetch_event` - Function to fetch an event in the room given its event ID.
516///
517/// ## Returns
518///
519/// Returns the power level of the sender of the event or an `Err(_)` if one of the auth events if
520/// malformed.
521fn power_level_for_sender<E: Event>(
522    event_id: &EventId,
523    rules: &AuthorizationRules,
524    creators_lock: &OnceLock<HashSet<OwnedUserId>>,
525    fetch_event: impl Fn(&EventId) -> Option<E>,
526) -> std::result::Result<UserPowerLevel, String> {
527    let event = fetch_event(event_id);
528    let mut room_create_event = None;
529    let mut room_power_levels_event = None;
530
531    if let Some(event) = &event
532        && rules.room_create_event_id_as_room_id
533        && creators_lock.get().is_none()
534    {
535        // The m.room.create event is not in the auth events, we can get its ID via the room ID.
536        room_create_event = event
537            .room_id()
538            .and_then(|room_id| room_id.room_create_event_id().ok())
539            .and_then(|room_create_event_id| fetch_event(&room_create_event_id));
540    }
541
542    for auth_event_id in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
543        if let Some(auth_event) = fetch_event(auth_event_id.borrow()) {
544            if is_type_and_key(&auth_event, &TimelineEventType::RoomPowerLevels, "") {
545                room_power_levels_event = Some(RoomPowerLevelsEvent::new(auth_event));
546            } else if !rules.room_create_event_id_as_room_id
547                && creators_lock.get().is_none()
548                && is_type_and_key(&auth_event, &TimelineEventType::RoomCreate, "")
549            {
550                room_create_event = Some(auth_event);
551            }
552
553            if room_power_levels_event.is_some()
554                && (rules.room_create_event_id_as_room_id
555                    || creators_lock.get().is_some()
556                    || room_create_event.is_some())
557            {
558                break;
559            }
560        }
561    }
562
563    // TODO: Use OnceLock::try_or_get_init when it is stabilized.
564    let creators = if let Some(creators) = creators_lock.get() {
565        Some(creators)
566    } else if let Some(room_create_event) = room_create_event {
567        let room_create_event = RoomCreateEvent::new(room_create_event);
568        let creators = room_create_event.creators(rules)?;
569        Some(creators_lock.get_or_init(|| creators))
570    } else {
571        None
572    };
573
574    if let Some((event, creators)) = event.zip(creators) {
575        room_power_levels_event.user_power_level(event.sender(), creators, rules)
576    } else {
577        room_power_levels_event
578            .get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, rules)
579            .map(Into::into)
580    }
581}
582
583/// Perform the iterative auth checks to the given list of events.
584///
585/// Definition in the specification:
586///
587/// > The iterative auth checks algorithm takes as input an initial room state and a sorted list of
588/// > state events, and constructs a new room state by iterating through the event list and applying
589/// > the state event to the room state if the state event is allowed by the authorization rules. If
590/// > the state event is not allowed by the authorization rules, then the event is ignored. If a
591/// > (event_type, state_key) key that is required for checking the authorization rules is not
592/// > present in the state, then the appropriate state event from the event’s auth_events is used if
593/// > the auth event is not rejected.
594///
595/// ## Arguments
596///
597/// * `rules` - The authorization rules for the current room version.
598///
599/// * `events` - The sorted state events to apply to the `partial_state`.
600///
601/// * `state` - The current state that was partially resolved for the room.
602///
603/// * `fetch_event` - Function to fetch an event in the room given its event ID.
604///
605/// ## Returns
606///
607/// Returns the partially resolved state, or an `Err(_)` if one of the state events in the room has
608/// an unexpected format.
609fn iterative_auth_checks<E: Event + Clone>(
610    rules: &AuthorizationRules,
611    events: &[E::Id],
612    mut state: StateMap<E::Id>,
613    fetch_event: impl Fn(&EventId) -> Option<E>,
614) -> Result<StateMap<E::Id>> {
615    debug!("starting iterative auth checks");
616
617    trace!(list = ?events, "events to check");
618
619    for event_id in events {
620        let event = fetch_event(event_id.borrow())
621            .ok_or_else(|| Error::NotFound(event_id.borrow().to_owned()))?;
622        let state_key = event.state_key().ok_or(Error::MissingStateKey)?;
623
624        let mut auth_events = StateMap::new();
625        for auth_event_id in event.auth_events() {
626            if let Some(auth_event) = fetch_event(auth_event_id.borrow()) {
627                if !auth_event.rejected() {
628                    auth_events.insert(
629                        auth_event
630                            .event_type()
631                            .with_state_key(auth_event.state_key().ok_or(Error::MissingStateKey)?),
632                        auth_event,
633                    );
634                }
635            } else {
636                warn!(event_id = %auth_event_id.borrow(), "missing auth event");
637            }
638        }
639
640        // If the `m.room.create` event is not in the auth events, we need to add it, because it's
641        // always part of the state and required in the auth rules.
642        if rules.room_create_event_id_as_room_id
643            && *event.event_type() != TimelineEventType::RoomCreate
644        {
645            if let Some(room_create_event) = event
646                .room_id()
647                .and_then(|room_id| room_id.room_create_event_id().ok())
648                .and_then(|room_create_event_id| fetch_event(&room_create_event_id))
649            {
650                auth_events.insert((StateEventType::RoomCreate, String::new()), room_create_event);
651            } else {
652                warn!("missing m.room.create event");
653            }
654        }
655
656        let auth_types = match auth_types_for_event(
657            event.event_type(),
658            event.sender(),
659            Some(state_key),
660            event.content(),
661            rules,
662        ) {
663            Ok(auth_types) => auth_types,
664            Err(error) => {
665                warn!("failed to get list of required auth events for malformed event: {error}");
666                continue;
667            }
668        };
669
670        for key in auth_types {
671            if let Some(auth_event_id) = state.get(&key) {
672                if let Some(auth_event) = fetch_event(auth_event_id.borrow()) {
673                    if !auth_event.rejected() {
674                        auth_events.insert(key.to_owned(), auth_event);
675                    }
676                } else {
677                    warn!(event_id = %auth_event_id.borrow(), "missing auth event");
678                }
679            }
680        }
681
682        match check_state_dependent_auth_rules(rules, &event, |ty, key| {
683            auth_events.get(&ty.with_state_key(key))
684        }) {
685            Ok(()) => {
686                // Add event to the partially resolved state.
687                state.insert(event.event_type().with_state_key(state_key), event_id.clone());
688            }
689            Err(error) => {
690                // Don't add this event to the state.
691                warn!(event_id = ?event.event_id(), "event failed the authentication check: {error}");
692            }
693        }
694
695        // TODO: if these functions are ever made async here
696        // is a good place to yield every once in a while so other
697        // tasks can make progress
698    }
699
700    Ok(state)
701}
702
703/// Perform mainline ordering of the given events.
704///
705/// Definition in the spec:
706///
707/// > Given mainline positions calculated from P, the mainline ordering based on P of a set of
708/// > events is the ordering, from smallest to largest, using the following comparison relation on
709/// > events: for events x and y, x < y if
710/// >
711/// > 1. the mainline position of x is greater than the mainline position of y (i.e. the auth chain
712/// > of x is based on an earlier event in the mainline than y); or
713/// > 2. the mainline positions of the events are the same, but x’s origin_server_ts is less than
714/// > y’s origin_server_ts; or
715/// > 3. the mainline positions of the events are the same and the events have the same
716/// > origin_server_ts, but x’s event_id is less than y’s event_id.
717///
718/// ## Arguments
719///
720/// * `events` - The list of event IDs to sort.
721///
722/// * `power_level` - The power level event in the current state.
723///
724/// * `fetch_event` - Function to fetch an event in the room given its event ID.
725///
726/// ## Returns
727///
728/// Returns the sorted list of event IDs, or an `Err(_)` if one the event in the room has an
729/// unexpected format.
730fn mainline_sort<E: Event>(
731    events: &[E::Id],
732    mut power_level: Option<E::Id>,
733    fetch_event: impl Fn(&EventId) -> Option<E>,
734) -> Result<Vec<E::Id>> {
735    debug!("mainline sort of events");
736
737    // There are no events to sort, bail.
738    if events.is_empty() {
739        return Ok(vec![]);
740    }
741
742    // Populate the mainline of the power level.
743    let mut mainline = vec![];
744
745    while let Some(power_level_event_id) = power_level {
746        mainline.push(power_level_event_id.clone());
747
748        let power_level_event = fetch_event(power_level_event_id.borrow())
749            .ok_or_else(|| Error::NotFound(power_level_event_id.borrow().to_owned()))?;
750
751        power_level = None;
752
753        for auth_event_id in power_level_event.auth_events() {
754            let auth_event = fetch_event(auth_event_id.borrow())
755                .ok_or_else(|| Error::NotFound(power_level_event_id.borrow().to_owned()))?;
756            if is_type_and_key(&auth_event, &TimelineEventType::RoomPowerLevels, "") {
757                power_level = Some(auth_event_id.to_owned());
758                break;
759            }
760        }
761
762        // TODO: if these functions are ever made async here
763        // is a good place to yield every once in a while so other
764        // tasks can make progress
765    }
766
767    let mainline_map = mainline
768        .iter()
769        .rev()
770        .enumerate()
771        .map(|(idx, event_id)| ((*event_id).clone(), idx))
772        .collect::<EventIdMap<_, _>>();
773
774    let mut order_map = HashMap::new();
775    for event_id in events.iter() {
776        if let Some(event) = fetch_event(event_id.borrow())
777            && let Ok(position) = mainline_position(event, &mainline_map, &fetch_event)
778        {
779            order_map.insert(
780                event_id,
781                (
782                    position,
783                    fetch_event(event_id.borrow()).map(|event| event.origin_server_ts()),
784                    event_id,
785                ),
786            );
787        }
788
789        // TODO: if these functions are ever made async here
790        // is a good place to yield every once in a while so other
791        // tasks can make progress
792    }
793
794    let mut sorted_event_ids = order_map.keys().map(|&k| k.clone()).collect::<Vec<_>>();
795    sorted_event_ids.sort_by_key(|event_id| order_map.get(event_id).unwrap());
796
797    Ok(sorted_event_ids)
798}
799
800/// Get the mainline position of the given event from the given mainline map.
801///
802/// Definition in the spec:
803///
804/// > Let P = P0 be an m.room.power_levels event. Starting with i = 0, repeatedly fetch Pi+1, the
805/// > m.room.power_levels event in the auth_events of Pi. Increment i and repeat until Pi has no
806/// > m.room.power_levels event in its auth_events. The mainline of P0 is the list of events [P0 ,
807/// > P1, … , Pn], fetched in this way.
808/// >
809/// > Let e = e0 be another event (possibly another m.room.power_levels event). We can compute a
810/// > similar list of events [e1, …, em], where ej+1 is the m.room.power_levels event in the
811/// > auth_events of ej and where em has no m.room.power_levels event in its auth_events. (Note that
812/// > the event we started with, e0, is not included in this list. Also note that it may be empty,
813/// > because e may not cite an m.room.power_levels event in its auth_events at all.)
814/// >
815/// > Now compare these two lists as follows.
816/// >
817/// > * Find the smallest index j ≥ 1 for which ej belongs to the mainline of P.
818/// > * If such a j exists, then ej = Pi for some unique index i ≥ 0. Otherwise set i = ∞, where ∞
819/// > is a sentinel value greater than any integer.
820/// > * In both cases, the mainline position of e is i.
821///
822/// ## Arguments
823///
824/// * `event` - The event to compute the mainline position of.
825///
826/// * `mainline_map` - The mainline map of the m.room.power_levels event.
827///
828/// * `fetch_event` - Function to fetch an event in the room given its event ID.
829///
830/// ## Returns
831///
832/// Returns the mainline position of the event, or an `Err(_)` if one of the events in the auth
833/// chain of the event was not found.
834fn mainline_position<E: Event>(
835    event: E,
836    mainline_map: &EventIdMap<E::Id, usize>,
837    fetch_event: impl Fn(&EventId) -> Option<E>,
838) -> Result<usize> {
839    let mut current_event = Some(event);
840
841    while let Some(event) = current_event {
842        let event_id = event.event_id();
843        debug!(event_id = event_id.borrow().as_str(), "mainline");
844
845        // If the current event is in the mainline map, return its position.
846        if let Some(position) = mainline_map.get(event_id.borrow()) {
847            return Ok(*position);
848        }
849
850        current_event = None;
851
852        // Look for the power levels event in the auth events.
853        for auth_event_id in event.auth_events() {
854            let auth_event = fetch_event(auth_event_id.borrow())
855                .ok_or_else(|| Error::NotFound(auth_event_id.borrow().to_owned()))?;
856
857            if is_type_and_key(&auth_event, &TimelineEventType::RoomPowerLevels, "") {
858                current_event = Some(auth_event);
859                break;
860            }
861        }
862    }
863
864    // Did not find a power level event so we default to zero.
865    Ok(0)
866}
867
868/// Add the event with the given event ID and all the events in its auth chain that are in the full
869/// conflicted set to the graph.
870fn add_event_and_auth_chain_to_graph<E: Event>(
871    graph: &mut EventIdMap<E::Id, EventIdSet<E::Id>>,
872    event_id: E::Id,
873    full_conflicted_set: &EventIdSet<E::Id>,
874    fetch_event: impl Fn(&EventId) -> Option<E>,
875) {
876    let mut state = vec![event_id];
877
878    // Iterate through the auth chain of the event.
879    while let Some(event_id) = state.pop() {
880        // Add the current event to the graph.
881        graph.entry(event_id.clone()).or_default();
882
883        // Iterate through the auth events of this event.
884        for auth_event_id in fetch_event(event_id.borrow())
885            .as_ref()
886            .map(|event| event.auth_events())
887            .into_iter()
888            .flatten()
889        {
890            // If the auth event ID is in the full conflicted set…
891            if full_conflicted_set.contains(auth_event_id.borrow()) {
892                // If the auth event ID is not in the graph, we need to check its auth events later.
893                if !graph.contains_event_id(auth_event_id.borrow()) {
894                    state.push(auth_event_id.to_owned());
895                }
896
897                // Add the auth event ID to the list of incoming edges.
898                graph.get_mut(event_id.borrow()).unwrap().insert(auth_event_id.to_owned());
899            }
900        }
901    }
902}
903
904/// Whether the given event ID belongs to a power event.
905///
906/// See the docs of `is_power_event()` for the definition of a power event.
907fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
908    match fetch(event_id).as_ref() {
909        Some(state) => is_power_event(state),
910        _ => false,
911    }
912}
913
914fn is_type_and_key(event: impl Event, event_type: &TimelineEventType, state_key: &str) -> bool {
915    event.event_type() == event_type && event.state_key() == Some(state_key)
916}
917
918/// Whether the given event is a power event.
919///
920/// Definition in the spec:
921///
922/// > A power event is a state event with type `m.room.power_levels` or `m.room.join_rules`, or a
923/// > state event with type `m.room.member` where the `membership` is `leave` or `ban` and the
924/// > `sender` does not match the `state_key`. The idea behind this is that power events are events
925/// > that might remove someone’s ability to do something in the room.
926fn is_power_event(event: impl Event) -> bool {
927    match event.event_type() {
928        TimelineEventType::RoomPowerLevels
929        | TimelineEventType::RoomJoinRules
930        | TimelineEventType::RoomCreate => event.state_key() == Some(""),
931        TimelineEventType::RoomMember => {
932            let room_member_event = RoomMemberEvent::new(event);
933            if room_member_event.membership().is_ok_and(|membership| {
934                matches!(membership, MembershipState::Leave | MembershipState::Ban)
935            }) {
936                return Some(room_member_event.sender().as_str()) != room_member_event.state_key();
937            }
938
939            false
940        }
941        _ => false,
942    }
943}
944
945/// Convenience trait for adding event type plus state key to state maps.
946pub(crate) trait EventTypeExt {
947    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
948}
949
950impl EventTypeExt for StateEventType {
951    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
952        (self, state_key.into())
953    }
954}
955
956impl EventTypeExt for TimelineEventType {
957    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
958        (self.to_string().into(), state_key.into())
959    }
960}
961
962impl<T> EventTypeExt for &T
963where
964    T: EventTypeExt + Clone,
965{
966    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
967        self.to_owned().with_state_key(state_key)
968    }
969}