use std::{
borrow::Borrow,
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
hash::Hash,
};
use js_int::{int, Int};
use ruma_common::{EventId, MilliSecondsSinceUnixEpoch, RoomVersionId};
use ruma_events::{
room::member::{MembershipState, RoomMemberEventContent},
StateEventType, TimelineEventType,
};
use serde_json::from_str as from_json_str;
use tracing::{debug, info, instrument, trace, warn};
mod error;
pub mod event_auth;
mod power_levels;
pub mod room_version;
mod state_event;
#[cfg(test)]
mod test_utils;
pub use error::{Error, Result};
pub use event_auth::{auth_check, auth_types_for_event};
use power_levels::PowerLevelsContentFields;
pub use room_version::RoomVersion;
pub use state_event::Event;
/// A map keyed by `(event type, state key)` pairs.
///
/// Within this file the value type `T` is instantiated with event IDs, with vectors of
/// event IDs (conflicted state) and with whole events (auth-event lookups).
pub type StateMap<T> = HashMap<(StateEventType, String), T>;
/// Resolves several (possibly conflicting) state sets of a room into a single state map.
///
/// The steps performed are:
///
/// 1. Split the input into unconflicted and conflicted state with `separate`; if nothing
///    conflicts the unconflicted state is returned as-is.
/// 2. Build the full conflicted set: the auth chain difference plus all conflicted event
///    IDs, keeping only IDs that `fetch_event` can actually produce.
/// 3. Sort the power events of that set with `reverse_topological_power_sort` and run
///    them through `iterative_auth_check`, starting from the unconflicted state.
/// 4. Sort the remaining events with `mainline_sort` (relative to the resolved
///    power-levels event) and run them through `iterative_auth_check` as well.
/// 5. Re-apply the unconflicted state on top so it always takes precedence.
///
/// # Arguments
///
/// * `room_version` - room version ID, converted with `RoomVersion::new`.
/// * `state_sets` - the state maps to merge.
/// * `auth_chain_sets` - one set of auth chain event IDs per state set.
/// * `fetch_event` - lookup from event ID to event; `None` means "unknown event".
///
/// # Errors
///
/// Propagates errors from sorting, from the auth checks and from `RoomVersion::new`.
#[instrument(skip(state_sets, auth_chain_sets, fetch_event))]
pub fn resolve<'a, E, SetIter>(
    room_version: &RoomVersionId,
    state_sets: impl IntoIterator<IntoIter = SetIter>,
    auth_chain_sets: Vec<HashSet<E::Id>>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<E::Id>>
where
    E: Event + Clone,
    E::Id: 'a,
    SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
{
    info!("state resolution starting");

    // Step 1: partition into state everyone agrees on and state that differs.
    let (unconflicted, conflicted) = separate(state_sets.into_iter());

    info!(count = unconflicted.len(), "non-conflicting events");
    trace!(map = ?unconflicted, "non-conflicting events");

    if conflicted.is_empty() {
        info!("no conflicting state found");
        return Ok(unconflicted);
    }

    info!(count = conflicted.len(), "conflicting events");
    trace!(map = ?conflicted, "conflicting events");

    // Step 2: auth chain difference + conflicted events, restricted to known events.
    let full_conflicted_set = get_auth_chain_diff(auth_chain_sets)
        .chain(conflicted.into_values().flatten())
        .filter(|id| fetch_event(id.borrow()).is_some())
        .collect::<HashSet<_>>();

    info!(count = full_conflicted_set.len(), "full conflicted set");
    trace!(set = ?full_conflicted_set, "full conflicted set");

    // Step 3: pick out the power events, sort them and auth-check them in that order.
    let mut power_events = Vec::new();
    for id in &full_conflicted_set {
        if is_power_event_id(id.borrow(), &fetch_event) {
            power_events.push(id.clone());
        }
    }

    let sorted_power_events =
        reverse_topological_power_sort(power_events, &full_conflicted_set, &fetch_event)?;

    debug!(count = sorted_power_events.len(), "power events");
    trace!(list = ?sorted_power_events, "sorted power events");

    let room_version = RoomVersion::new(room_version)?;
    let resolved_power_state = iterative_auth_check(
        &room_version,
        &sorted_power_events,
        unconflicted.clone(),
        &fetch_event,
    )?;

    debug!(count = resolved_power_state.len(), "resolved power events");
    trace!(map = ?resolved_power_state, "resolved power events");

    // Step 4: everything that was not already handled as a power event.
    let handled_power_events = sorted_power_events.into_iter().collect::<HashSet<_>>();
    let mut remaining_events = Vec::new();
    for id in &full_conflicted_set {
        if !handled_power_events.contains(id.borrow()) {
            remaining_events.push(id.clone());
        }
    }

    debug!(count = remaining_events.len(), "events left to resolve");
    trace!(list = ?remaining_events, "events left to resolve");

    // The resolved power-levels event anchors the mainline ordering.
    let power_event = resolved_power_state.get(&(StateEventType::RoomPowerLevels, "".into()));

    debug!(event_id = ?power_event, "power event");

    let sorted_remaining = mainline_sort(&remaining_events, power_event.cloned(), &fetch_event)?;

    trace!(list = ?sorted_remaining, "events left, sorted");

    let mut resolved_state = iterative_auth_check(
        &room_version,
        &sorted_remaining,
        resolved_power_state,
        &fetch_event,
    )?;

    // Step 5: unconflicted state overrides whatever the resolution produced.
    resolved_state.extend(unconflicted);

    info!("state resolution finished");

    Ok(resolved_state)
}
/// Splits the given state sets into unconflicted and conflicted state.
///
/// For every `(event type, state key)` the function counts how many of the state sets
/// map that key to each value:
///
/// * a value present under the key in *every* state set goes into the first
///   (unconflicted) map;
/// * every other value is collected into the `Vec` stored under that key in the second
///   (conflicted) map. This includes values that appear in only some of the sets, even
///   when no competing value exists.
///
/// The order of the IDs inside each conflicted `Vec` follows hash-map iteration order
/// and is therefore unspecified.
fn separate<'a, Id>(
    state_sets_iter: impl Iterator<Item = &'a StateMap<Id>>,
) -> (StateMap<Id>, StateMap<Vec<Id>>)
where
    Id: Clone + Eq + Hash + 'a,
{
    let mut state_set_count = 0_usize;
    let mut occurrences = HashMap::<_, HashMap<_, usize>>::new();

    // Count the state sets while they are being consumed; the iterator is single-pass.
    let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1);
    for (k, v) in state_sets_iter.flatten() {
        *occurrences.entry(k).or_default().entry(v).or_default() += 1;
    }

    let mut unconflicted_state = StateMap::new();
    let mut conflicted_state: StateMap<Vec<Id>> = StateMap::new();

    for (k, v) in occurrences {
        for (id, occurrence_count) in v {
            if occurrence_count == state_set_count {
                unconflicted_state.insert(k.clone(), id.clone());
            } else {
                // `or_default` only allocates the Vec the first time a key is seen.
                conflicted_state.entry(k.clone()).or_default().push(id.clone());
            }
        }
    }

    (unconflicted_state, conflicted_state)
}
/// Returns the IDs that are missing from at least one of the given auth chain sets.
///
/// Each ID is tallied across all sets; an ID is yielded when its tally is lower than
/// the number of sets. The iteration order of the result is unspecified.
fn get_auth_chain_diff<Id>(auth_chain_sets: Vec<HashSet<Id>>) -> impl Iterator<Item = Id>
where
    Id: Eq + Hash,
{
    let total_sets = auth_chain_sets.len();

    // Tally how many sets contain each ID.
    let tally = auth_chain_sets.into_iter().flatten().fold(
        HashMap::<Id, usize>::new(),
        |mut seen, id| {
            *seen.entry(id).or_insert(0) += 1;
            seen
        },
    );

    tally.into_iter().filter(move |(_, hits)| *hits < total_sets).map(|(id, _)| id)
}
#[instrument(skip_all)]
fn reverse_topological_power_sort<E: Event>(
events_to_sort: Vec<E::Id>,
auth_diff: &HashSet<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<E::Id>> {
debug!("reverse topological sort of power events");
let mut graph = HashMap::new();
for event_id in events_to_sort {
add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, &fetch_event);
}
let mut event_to_pl = HashMap::new();
for event_id in graph.keys() {
let pl = get_power_level_for_sender(event_id.borrow(), &fetch_event)?;
debug!(
event_id = event_id.borrow().as_str(),
power_level = i64::from(pl),
"found the power level of an event's sender",
);
event_to_pl.insert(event_id.clone(), pl);
}
lexicographical_topological_sort(&graph, |event_id| {
let ev = fetch_event(event_id).ok_or_else(|| Error::NotFound("".into()))?;
let pl = *event_to_pl.get(event_id).ok_or_else(|| Error::NotFound("".into()))?;
Ok((pl, ev.origin_server_ts()))
})
}
/// Sorts the nodes of `graph` topologically, breaking ties with the key returned by
/// `key_fn`.
///
/// `graph` maps each node to the set of nodes it points to (its outgoing edges). Nodes
/// without outgoing edges are emitted first; a node becomes eligible once all nodes it
/// points to have been emitted.
///
/// Among the eligible nodes, the next one emitted is the one with the greatest power
/// level, then the smallest `origin_server_ts`, then the smallest ID (see `TieBreaker`).
///
/// A node whose edge set references an ID that is not a key of `graph` never becomes
/// eligible and is therefore absent from the result.
///
/// # Errors
///
/// Returns the first error produced by `key_fn`.
#[instrument(skip_all)]
pub fn lexicographical_topological_sort<Id, F>(
graph: &HashMap<Id, HashSet<Id>>,
key_fn: F,
) -> Result<Vec<Id>>
where
F: Fn(&EventId) -> Result<(Int, MilliSecondsSinceUnixEpoch)>,
Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
{
// Heap entry: the sort key of a node plus a reference to the node itself.
#[derive(PartialEq, Eq)]
struct TieBreaker<'a, Id> {
power_level: Int,
origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: &'a Id,
}
impl<Id> Ord for TieBreaker<'_, Id>
where
Id: Ord,
{
fn cmp(&self, other: &Self) -> Ordering {
// The power level comparison is deliberately reversed (`other` vs `self`), so a
// higher power level compares as "smaller"; timestamp and ID compare naturally.
other
.power_level
.cmp(&self.power_level)
.then(self.origin_server_ts.cmp(&other.origin_server_ts))
.then(self.event_id.cmp(other.event_id))
}
}
impl<Id> PartialOrd for TieBreaker<'_, Id>
where
Id: Ord,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
// Working copy of the graph: edges are removed from it as nodes get emitted.
let mut outdegree_map = graph.clone();
// For every node, the set of nodes that point to it.
let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
// Nodes that are eligible from the start.
let mut zero_outdegree = Vec::new();
for (node, edges) in graph {
if edges.is_empty() {
let (power_level, origin_server_ts) = key_fn(node.borrow())?;
// `Reverse` turns the max-heap into a min-heap with respect to `TieBreaker::cmp`.
zero_outdegree.push(Reverse(TieBreaker {
power_level,
origin_server_ts,
event_id: node,
}));
}
// Make sure every node has an entry, even if nothing points to it.
reverse_graph.entry(node).or_default();
for edge in edges {
reverse_graph.entry(edge).or_default().insert(node);
}
}
let mut heap = BinaryHeap::from(zero_outdegree);
let mut sorted = vec![];
while let Some(Reverse(item)) = heap.pop() {
let node = item.event_id;
for &parent in reverse_graph.get(node).expect("EventId in heap is also in reverse_graph") {
let out = outdegree_map
.get_mut(parent.borrow())
.expect("outdegree_map knows of all referenced EventIds");
out.remove(node.borrow());
// Once the last outgoing edge is gone the parent becomes eligible.
if out.is_empty() {
let (power_level, origin_server_ts) = key_fn(parent.borrow())?;
heap.push(Reverse(TieBreaker { power_level, origin_server_ts, event_id: parent }));
}
}
sorted.push(node.clone());
}
Ok(sorted)
}
/// Looks up the power level of the sender of `event_id`.
///
/// The power-levels event is taken from the event's own auth events: the first auth
/// event that can be fetched and is an `m.room.power_levels` event with an empty state
/// key. Auth events that cannot be fetched are skipped.
///
/// # Returns
///
/// * `0` when no such power-levels event is found (this includes the case where
///   `event_id` itself cannot be fetched);
/// * the sender's entry in `users` when there is one;
/// * `users_default` otherwise.
///
/// # Errors
///
/// Fails when the content of the power-levels event does not deserialize.
fn get_power_level_for_sender<E: Event>(
    event_id: &EventId,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> serde_json::Result<Int> {
    let event = fetch_event(event_id);

    // Lazily walk the auth events and stop at the first power-levels event.
    let power_levels_event = event
        .as_ref()
        .map(|pdu| pdu.auth_events())
        .into_iter()
        .flatten()
        .filter_map(|aid| fetch_event(aid.borrow()))
        .find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, ""));

    let content: PowerLevelsContentFields = if let Some(pl_event) = power_levels_event {
        from_json_str(pl_event.content().get())?
    } else {
        return Ok(int!(0));
    };

    let explicit_level =
        event.and_then(|ev| content.users.get(ev.sender()).map(|&user_level| user_level));

    Ok(match explicit_level {
        Some(user_level) => user_level,
        None => content.users_default,
    })
}
fn iterative_auth_check<E: Event + Clone>(
room_version: &RoomVersion,
events_to_check: &[E::Id],
unconflicted_state: StateMap<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<E::Id>> {
debug!("starting iterative auth check");
trace!(list = ?events_to_check, "events to check");
let mut resolved_state = unconflicted_state;
for event_id in events_to_check {
let event = fetch_event(event_id.borrow())
.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))?;
let state_key = event
.state_key()
.ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
let mut auth_events = StateMap::new();
for aid in event.auth_events() {
if let Some(ev) = fetch_event(aid.borrow()) {
auth_events.insert(
ev.event_type().with_state_key(ev.state_key().ok_or_else(|| {
Error::InvalidPdu("State event had no state key".to_owned())
})?),
ev,
);
} else {
warn!(event_id = aid.borrow().as_str(), "missing auth event");
}
}
for key in auth_types_for_event(
event.event_type(),
event.sender(),
Some(state_key),
event.content(),
)? {
if let Some(ev_id) = resolved_state.get(&key) {
if let Some(event) = fetch_event(ev_id.borrow()) {
auth_events.insert(key.to_owned(), event);
}
}
}
let current_third_party = auth_events.iter().find_map(|(_, pdu)| {
(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
});
if auth_check(room_version, &event, current_third_party, |ty, key| {
auth_events.get(&ty.with_state_key(key))
})? {
resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
} else {
warn!("event failed the authentication check");
}
}
Ok(resolved_state)
}
/// Sorts `to_sort` by each event's position relative to the power-level mainline.
///
/// The mainline is built by starting at `resolved_power_level` and repeatedly following
/// the first `m.room.power_levels` auth event (empty state key) of the current event.
/// Positions are assigned from the far end of that chain: the last event reached gets
/// index 0.
///
/// Every event in `to_sort` is keyed by `(mainline depth, origin_server_ts, event ID)`
/// and the IDs are returned in ascending key order. Events that cannot be fetched, or
/// whose depth lookup fails, are silently left out of the result. Duplicate IDs in
/// `to_sort` appear once.
///
/// # Errors
///
/// `Error::NotFound` when a mainline event, or one of its auth events, cannot be
/// fetched.
fn mainline_sort<E: Event>(
    to_sort: &[E::Id],
    resolved_power_level: Option<E::Id>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<E::Id>> {
    debug!("mainline sort of events");

    // There are no events to sort, bail.
    if to_sort.is_empty() {
        return Ok(vec![]);
    }

    // Walk the chain of power-levels events, newest first.
    let mut mainline = vec![];
    let mut pl = resolved_power_level;
    while let Some(p) = pl {
        mainline.push(p.clone());

        let event = fetch_event(p.borrow())
            .ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
        pl = None;
        for aid in event.auth_events() {
            let ev = fetch_event(aid.borrow())
                .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
            if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
                pl = Some(aid.to_owned());
                break;
            }
        }
    }

    // Reverse so that the oldest mainline event gets position 0.
    let mainline_map = mainline
        .iter()
        .rev()
        .enumerate()
        .map(|(idx, eid)| ((*eid).clone(), idx))
        .collect::<HashMap<_, _>>();

    let mut order_map = HashMap::new();
    for ev_id in to_sort.iter() {
        if let Some(event) = fetch_event(ev_id.borrow()) {
            // Read the timestamp before the event is moved into `get_mainline_depth`,
            // instead of fetching the same event a second time.
            let origin_server_ts = event.origin_server_ts();
            if let Ok(depth) = get_mainline_depth(Some(event), &mainline_map, &fetch_event) {
                order_map.insert(ev_id, (depth, origin_server_ts, ev_id));
            }
        }
    }

    // The unwrap cannot fail: the IDs being sorted are exactly the keys of `order_map`.
    let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::<Vec<_>>();
    sort_event_ids.sort_by_key(|sort_id| order_map.get(sort_id).unwrap());

    Ok(sort_event_ids)
}
fn get_mainline_depth<E: Event>(
mut event: Option<E>,
mainline_map: &HashMap<E::Id, usize>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<usize> {
while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
let id = sort_ev.event_id();
if let Some(depth) = mainline_map.get(id.borrow()) {
return Ok(*depth);
}
event = None;
for aid in sort_ev.auth_events() {
let aev = fetch_event(aid.borrow())
.ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
event = Some(aev);
break;
}
}
}
Ok(0)
}
/// Adds `event_id` and the part of its auth chain that lies inside `auth_diff` to
/// `graph`.
///
/// The traversal is iterative, with an explicit stack. Every visited event gets a node
/// in `graph`; for each of its auth events that is contained in `auth_diff`, an edge
/// from the event to that auth event is recorded, and the auth event is scheduled for a
/// visit if it has no node yet. Events that cannot be fetched keep a node without
/// edges.
fn add_event_and_auth_chain_to_graph<E: Event>(
    graph: &mut HashMap<E::Id, HashSet<E::Id>>,
    event_id: E::Id,
    auth_diff: &HashSet<E::Id>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) {
    let mut pending = vec![event_id];

    while let Some(current) = pending.pop() {
        graph.entry(current.clone()).or_default();

        let fetched = fetch_event(current.borrow());
        let auth_ids = fetched.as_ref().map(|ev| ev.auth_events()).into_iter().flatten();

        for aid in auth_ids {
            // Only auth events inside the difference take part in the graph.
            if !auth_diff.contains(aid.borrow()) {
                continue;
            }

            if !graph.contains_key(aid.borrow()) {
                pending.push(aid.to_owned());
            }

            // The node for `current` was inserted at the top of this iteration.
            graph.get_mut(current.borrow()).unwrap().insert(aid.to_owned());
        }
    }
}
/// Returns `true` when `event_id` can be fetched and the fetched event satisfies
/// `is_power_event`; an event that cannot be fetched yields `false`.
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
    fetch(event_id).map_or(false, |state| is_power_event(&state))
}
/// Returns `true` when `ev` has the event type `ev_type` and its state key equals
/// `state_key`. An event without a state key never matches.
fn is_type_and_key(ev: impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool {
    if ev.event_type() != ev_type {
        return false;
    }
    matches!(ev.state_key(), Some(key) if key == state_key)
}
/// Decides whether `event` counts as a power event.
///
/// * `m.room.power_levels`, `m.room.join_rules` and `m.room.create` events qualify when
///   their state key is the empty string.
/// * An `m.room.member` event qualifies when its membership is `leave` or `ban` and the
///   sender differs from the state key, i.e. the sender acts on someone else. Member
///   content that does not deserialize does not qualify.
/// * Everything else does not qualify.
fn is_power_event(event: impl Event) -> bool {
    match event.event_type() {
        TimelineEventType::RoomPowerLevels
        | TimelineEventType::RoomJoinRules
        | TimelineEventType::RoomCreate => event.state_key() == Some(""),
        TimelineEventType::RoomMember => {
            match from_json_str::<RoomMemberEventContent>(event.content().get()) {
                Ok(content)
                    if content.membership == MembershipState::Leave
                        || content.membership == MembershipState::Ban =>
                {
                    Some(event.sender().as_str()) != event.state_key()
                }
                _ => false,
            }
        }
        _ => false,
    }
}
/// Extension trait for turning an event type into a `(StateEventType, String)` pair,
/// the key shape used by `StateMap`.
trait EventTypeExt {
/// Combines `self` with `state_key`, converting the state key into an owned `String`.
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
}
impl EventTypeExt for StateEventType {
    /// Pairs this state event type with the owned form of `state_key`.
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
        let key: String = state_key.into();
        (self, key)
    }
}
impl EventTypeExt for TimelineEventType {
    /// Converts this timeline event type to a `StateEventType` by way of its string
    /// form and pairs it with the owned form of `state_key`.
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
        let event_type: StateEventType = self.to_string().into();
        (event_type, state_key.into())
    }
}
impl<T> EventTypeExt for &T
where
    T: EventTypeExt + Clone,
{
    /// Clones the referenced event type and delegates to its own implementation.
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
        let owned: T = self.to_owned();
        owned.with_state_key(state_key)
    }
}
#[cfg(test)]
mod tests {
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use js_int::{int, uint};
use maplit::{hashmap, hashset};
use rand::seq::SliceRandom;
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId};
use ruma_events::{
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
StateEventType, TimelineEventType,
};
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use tracing::debug;
use crate::{
is_power_event,
room_version::RoomVersion,
test_utils::{
alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join,
room_id, to_init_pdu_event, to_pdu_event, zara, PduEvent, TestStore, INITIAL_EVENTS,
},
Event, EventTypeExt, StateMap,
};
/// Sorts the `INITIAL_EVENTS` fixture and asserts the resulting order.
///
/// This is a helper without a `#[test]` attribute; `test_sort` calls it repeatedly.
/// The input to `mainline_sort` is shuffled randomly on every call.
fn test_event_sort() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = INITIAL_EVENTS();
// Index the fixture by (event type, state key).
let event_map = events
.values()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone()))
.collect::<StateMap<_>>();
// An empty auth difference: no auth events get pulled into the graph.
let auth_chain: HashSet<OwnedEventId> = HashSet::new();
let power_events = event_map
.values()
.filter(|&pdu| is_power_event(&**pdu))
.map(|pdu| pdu.event_id.clone())
.collect::<Vec<_>>();
let sorted_power_events =
crate::reverse_topological_power_sort(power_events, &auth_chain, |id| {
events.get(id).cloned()
})
.unwrap();
// Start the auth check from an empty state map.
let resolved_power = crate::iterative_auth_check(
&RoomVersion::V6,
&sorted_power_events,
HashMap::new(), |id| events.get(id).cloned(),
)
.expect("iterative auth check failed on resolved events");
// Every fixture event is sorted, none is removed.
let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
// Shuffle so that the expected order cannot come from the input order.
events_to_sort.shuffle(&mut rand::thread_rng());
let power_level =
resolved_power.get(&(StateEventType::RoomPowerLevels, "".to_owned())).cloned();
let sorted_event_ids =
crate::mainline_sort(&events_to_sort, power_level, |id| events.get(id).cloned())
.unwrap();
assert_eq!(
vec![
"$CREATE:foo",
"$IMA:foo",
"$IPOWER:foo",
"$IJR:foo",
"$IMB:foo",
"$IMC:foo",
"$START:foo",
"$END:foo"
],
sorted_event_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>()
);
}
/// Runs `test_event_sort` 20 times; each run shuffles its input independently.
#[test]
fn test_sort() {
    (0..20).for_each(|_| test_event_sort());
}
/// Scenario: Alice bans Bob on one branch while Bob sends a power-levels event on the
/// other. The expected resolved IDs are `PA`, `MA` and `MB`; Bob's `PB` is not among
/// them.
///
/// The events, edge lists and expected IDs are handed to `do_check` from `test_utils`
/// (not shown in this file).
#[test]
fn ban_vs_power_level() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = &[
to_init_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"MA",
alice(),
TimelineEventType::RoomMember,
Some(alice().to_string().as_str()),
member_content_join(),
),
// Alice bans Bob.
to_init_pdu_event(
"MB",
alice(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_ban(),
),
// Bob's competing power-levels event.
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
];
let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids =
vec!["PA", "MA", "MB"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
}
/// Scenario: competing topic and power-levels events from Alice and Bob on two
/// branches. The expected resolved IDs are `PA2` and `T2`.
///
/// The events, edge lists and expected IDs are handed to `do_check` from `test_utils`
/// (not shown in this file).
#[test]
fn topic_basic() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = &[
to_init_pdu_event(
"T1",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA1",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T2",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
// Alice lowers Bob's power level to 0.
to_init_pdu_event(
"PA2",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
),
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T3",
bob(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
];
let edges =
vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["PA2", "T2"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
}
/// Scenario: Bob sets a topic (`T2`) and is then banned by Alice (`MB`). The expected
/// resolved IDs are `T1`, `MB` and `PA`; Bob's `T2` is not among them.
///
/// The events, edge lists and expected IDs are handed to `do_check` from `test_utils`
/// (not shown in this file).
#[test]
fn topic_reset() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = &[
to_init_pdu_event(
"T1",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T2",
bob(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
// Alice bans Bob.
to_init_pdu_event(
"MB",
alice(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_ban(),
),
];
let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids =
vec!["T1", "MB", "PA"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
}
/// Scenario: Alice sets the join rule to private (`JR`) on one branch while Ella joins
/// (`ME`) on the other. Only `JR` is expected in the resolved state.
///
/// The events, edge lists and expected IDs are handed to `do_check` from `test_utils`
/// (not shown in this file).
#[test]
fn join_rule_evasion() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = &[
to_init_pdu_event(
"JR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(),
),
to_init_pdu_event(
"ME",
ella(),
TimelineEventType::RoomMember,
Some(ella().to_string().as_str()),
member_content_join(),
),
];
let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec![event_id("JR")];
do_check(events, edges, expected_state_ids);
}
/// Scenario: a chain of three power-levels events sent by Alice (`PA`), Bob (`PB`) and
/// Charlie (`PC`). Only `PC` is expected in the resolved state.
///
/// The events, edge lists and expected IDs are handed to `do_check` from `test_utils`
/// (not shown in this file).
#[test]
fn offtopic_power_level() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = &[
to_init_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
// Bob grants Charlie power level 50.
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } }))
.unwrap(),
),
// Charlie sets their own power level to 0.
to_init_pdu_event(
"PC",
charlie(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } }))
.unwrap(),
),
];
let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
}
/// Scenario: several topic events interleaved with power-levels changes on two
/// branches. The expected resolved IDs are `T4` and `PA2`.
///
/// The events, edge lists and expected IDs are handed to `do_check` from `test_utils`
/// (not shown in this file).
#[test]
fn topic_setting() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let events = &[
to_init_pdu_event(
"T1",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA1",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T2",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
// Alice lowers Bob's power level to 0.
to_init_pdu_event(
"PA2",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
),
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T3",
bob(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
// Despite the "MZ1" name this is a topic event sent by Zara.
to_init_pdu_event(
"MZ1",
zara(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"T4",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
];
let edges = vec![
vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
vec!["END", "MZ1", "T3", "PB", "PA1"],
]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["T4", "PA2"].into_iter().map(event_id).collect::<Vec<_>>();
do_check(events, edges, expected_state_ids);
}
/// Resolves the two state sets produced by `TestStore::set_up` with room version 2 and
/// compares the outcome with the expected state returned by `set_up`.
#[test]
fn test_event_map_none() {
    let _ =
        tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());

    let mut store = TestStore::<PduEvent>(hashmap! {});

    // `set_up` fills the store and returns both branch states plus the expected result.
    let (state_at_bob, state_at_charlie, expected) = store.set_up();

    let ev_map = store.0.clone();
    let state_sets = [state_at_bob, state_at_charlie];

    // One auth chain per state set.
    let auth_chain_sets = state_sets
        .iter()
        .map(|map| store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap())
        .collect::<Vec<_>>();

    let resolved =
        crate::resolve(&RoomVersionId::V2, &state_sets, auth_chain_sets, |id| {
            ev_map.get(id).cloned()
        })
        .unwrap_or_else(|e| panic!("{e}"));

    assert_eq!(expected, resolved);
}
/// Checks `lexicographical_topological_sort` on a small graph where every node gets
/// the same power level and timestamp, so ties are broken by event ID alone.
#[test]
fn test_lexicographical_sort() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
// "o" has no outgoing edges and is pointed to by all four other nodes.
let graph = hashmap! {
event_id("l") => hashset![event_id("o")],
event_id("m") => hashset![event_id("n"), event_id("o")],
event_id("n") => hashset![event_id("o")],
event_id("o") => hashset![], event_id("p") => hashset![event_id("o")],
};
let res = crate::lexicographical_topological_sort(&graph, |_id| {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.unwrap();
// "m" comes after "n" because it points to "n".
assert_eq!(
vec!["o", "l", "n", "m", "p"],
res.iter()
.map(ToString::to_string)
.map(|s| s.replace('$', "").replace(":foo", ""))
.collect::<Vec<_>>()
);
}
/// Feeds the `BAN_STATE_SET` fixture to `do_check` with two branches and expects `PA`
/// and `MB` in the resolved state.
#[test]
fn ban_with_auth_chains() {
    let _ =
        tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
    let ban = BAN_STATE_SET();

    // Turns a list of short names into event IDs.
    let to_ids = |names: &[&str]| names.iter().copied().map(event_id).collect::<Vec<_>>();

    let edges = vec![to_ids(&["END", "MB", "PA", "START"]), to_ids(&["END", "IME", "MB"])];
    let expected_state_ids = to_ids(&["PA", "MB"]);
    let events = ban.values().cloned().collect::<Vec<_>>();

    do_check(&events, edges, expected_state_ids);
}
/// Resolves two hand-built state sets over `INITIAL_EVENTS` plus `BAN_STATE_SET` with
/// room version 6. Set A contains the ban `MB`, set B contains Ella's join `IME`; both
/// occupy the same `(m.room.member, ella)` key.
#[test]
fn ban_with_auth_chains2() {
let _ =
tracing::subscriber::set_default(tracing_subscriber::fmt().with_test_writer().finish());
let init = INITIAL_EVENTS();
let ban = BAN_STATE_SET();
// Merge both fixtures into one event store.
let mut inner = init.clone();
inner.extend(ban);
let store = TestStore(inner.clone());
// State set A: includes the ban event "MB".
let state_set_a = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("MB")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone()))
.collect::<StateMap<_>>();
// State set B: includes the join event "IME" instead.
let state_set_b = [
inner.get(&event_id("CREATE")).unwrap(),
inner.get(&event_id("IJR")).unwrap(),
inner.get(&event_id("IMA")).unwrap(),
inner.get(&event_id("IMB")).unwrap(),
inner.get(&event_id("IMC")).unwrap(),
inner.get(&event_id("IME")).unwrap(),
inner.get(&event_id("PA")).unwrap(),
]
.iter()
.map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone()))
.collect::<StateMap<_>>();
let ev_map = &store.0;
let state_sets = [state_set_a, state_set_b];
let resolved = match crate::resolve(
&RoomVersionId::V6,
&state_sets,
state_sets
.iter()
.map(|map| {
store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap()
})
.collect(),
|id| ev_map.get(id).cloned(),
) {
Ok(state) => state,
Err(e) => panic!("{e}"),
};
debug!(
resolved = ?resolved
.iter()
.map(|((ty, key), id)| format!("(({ty}{key:?}), {id})"))
.collect::<Vec<_>>(),
"resolved state",
);
let expected =
["$CREATE:foo", "$IJR:foo", "$PA:foo", "$IMA:foo", "$IMB:foo", "$IMC:foo", "$MB:foo"];
// Each expected ID must be in the resolved state or be one of the initial events.
for id in expected.iter().map(|i| event_id(i)) {
assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}");
}
assert_eq!(expected.len(), resolved.len());
}
/// Feeds the `JOIN_RULE` fixture to `do_check` with two branches and expects only `JR`
/// in the resolved state.
#[test]
fn join_rule_with_auth_chain() {
    let join_rule = JOIN_RULE();

    // Turns a list of short names into event IDs.
    let to_ids = |names: &[&str]| names.iter().copied().map(event_id).collect::<Vec<_>>();

    let edges = vec![to_ids(&["END", "JR", "START"]), to_ids(&["END", "IMZ", "START"])];
    let expected_state_ids = to_ids(&["JR"]);
    let events = join_rule.values().cloned().collect::<Vec<_>>();

    do_check(&events, edges, expected_state_ids);
}
/// Fixture: four events (`PA`, `PB`, `MB`, `IME`) keyed by their event ID.
///
/// The two trailing slice arguments of each `to_pdu_event` call are presumably the
/// auth events and the prev events, in that order; confirm against `test_utils`.
// The upper-case name mimics a constant, hence the lint exception.
#[allow(non_snake_case)]
fn BAN_STATE_SET() -> HashMap<OwnedEventId, Arc<PduEvent>> {
vec![
to_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
&["CREATE", "IMA", "IPOWER"], &["START"], ),
to_pdu_event(
"PB",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["END"],
),
// Alice bans Ella.
to_pdu_event(
"MB",
alice(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
member_content_ban(),
&["CREATE", "IMA", "PB"],
&["PA"],
),
// Ella joins.
to_pdu_event(
"IME",
ella(),
TimelineEventType::RoomMember,
Some(ella().as_str()),
member_content_join(),
&["CREATE", "IJR", "PA"],
&["MB"],
),
]
.into_iter()
.map(|ev| (ev.event_id.clone(), ev))
.collect()
}
/// Fixture: two events (`JR`, `IMZ`) keyed by their event ID.
///
/// The two trailing slice arguments of each `to_pdu_event` call are presumably the
/// auth events and the prev events, in that order; confirm against `test_utils`.
// The upper-case name mimics a constant, hence the lint exception.
#[allow(non_snake_case)]
fn JOIN_RULE() -> HashMap<OwnedEventId, Arc<PduEvent>> {
vec![
to_pdu_event(
"JR",
alice(),
TimelineEventType::RoomJoinRules,
Some(""),
to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(),
&["CREATE", "IMA", "IPOWER"],
&["START"],
),
// NOTE(review): this event has the power-levels type but Zara's user ID as state key
// and join-member content. Verify whether that mix is intentional.
to_pdu_event(
"IMZ",
zara(),
TimelineEventType::RoomPowerLevels,
Some(zara().as_str()),
member_content_join(),
&["CREATE", "JR", "IPOWER"],
&["START"],
),
]
.into_iter()
.map(|ev| (ev.event_id.clone(), ev))
.collect()
}
// Builds a `StateMap` from `kind => key => value` triples, separated by commas (a
// trailing comma is accepted). The key is converted with `.to_owned()`.
macro_rules! state_set {
($($kind:expr => $key:expr => $id:expr),* $(,)?) => {{
// With zero triples the map is never mutated, which would trigger `unused_mut`.
#[allow(unused_mut)]
let mut x = StateMap::new();
$(
x.insert(($kind, $key.to_owned()), $id);
)*
x
}};
}
/// Three state sets with three different keys: no key is present in every set, so
/// everything ends up conflicted, each key with a single candidate.
#[test]
fn separate_unique_conflicted() {
    let state_sets = [
        state_set![StateEventType::RoomMember => "@a:hs1" => 0],
        state_set![StateEventType::RoomMember => "@b:hs1" => 1],
        state_set![StateEventType::RoomMember => "@c:hs1" => 2],
    ];

    let (unconflicted, conflicted) = super::separate(state_sets.iter());

    let expected_conflicted = state_set![
        StateEventType::RoomMember => "@a:hs1" => vec![0],
        StateEventType::RoomMember => "@b:hs1" => vec![1],
        StateEventType::RoomMember => "@c:hs1" => vec![2],
    ];

    assert!(unconflicted.is_empty());
    assert_eq!(conflicted, expected_conflicted);
}
/// Three state sets that disagree on the same key: all three values are reported as
/// conflicted under that key.
#[test]
fn separate_conflicted() {
    let state_sets = [
        state_set![StateEventType::RoomMember => "@a:hs1" => 0],
        state_set![StateEventType::RoomMember => "@a:hs1" => 1],
        state_set![StateEventType::RoomMember => "@a:hs1" => 2],
    ];

    let (unconflicted, mut conflicted) = super::separate(state_sets.iter());

    // The order inside each Vec is unspecified, so normalise it before comparing.
    conflicted.values_mut().for_each(|ids| ids.sort_unstable());

    let expected_conflicted = state_set![
        StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2],
    ];

    assert!(unconflicted.is_empty());
    assert_eq!(conflicted, expected_conflicted);
}
/// Three identical state sets: the single entry is unconflicted and nothing is
/// reported as conflicted.
#[test]
fn separate_unconflicted() {
    let state_sets = [
        state_set![StateEventType::RoomMember => "@a:hs1" => 0],
        state_set![StateEventType::RoomMember => "@a:hs1" => 0],
        state_set![StateEventType::RoomMember => "@a:hs1" => 0],
    ];

    let (unconflicted, conflicted) = super::separate(state_sets.iter());

    let expected_unconflicted = state_set![
        StateEventType::RoomMember => "@a:hs1" => 0,
    ];

    assert_eq!(unconflicted, expected_unconflicted);
    assert!(conflicted.is_empty());
}
/// One key shared by all sets with the same value (unconflicted) plus two keys that
/// each appear in only one set (conflicted).
#[test]
fn separate_mixed() {
    let state_sets = [
        state_set![StateEventType::RoomMember => "@a:hs1" => 0],
        state_set![
            StateEventType::RoomMember => "@a:hs1" => 0,
            StateEventType::RoomMember => "@b:hs1" => 1,
        ],
        state_set![
            StateEventType::RoomMember => "@a:hs1" => 0,
            StateEventType::RoomMember => "@c:hs1" => 2,
        ],
    ];

    let (unconflicted, conflicted) = super::separate(state_sets.iter());

    let expected_unconflicted = state_set![
        StateEventType::RoomMember => "@a:hs1" => 0,
    ];
    let expected_conflicted = state_set![
        StateEventType::RoomMember => "@b:hs1" => vec![1],
        StateEventType::RoomMember => "@c:hs1" => vec![2],
    ];

    assert_eq!(unconflicted, expected_unconflicted);
    assert_eq!(conflicted, expected_conflicted);
}
}