ruma_events/
poll.rs

1//! Modules for events in the `m.poll` namespace ([MSC3381]).
2//!
3//! This module also contains types shared by events in its child namespaces.
4//!
5//! [MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    ops::Deref,
10};
11
12use indexmap::IndexMap;
13use js_int::{uint, UInt};
14use ruma_common::{MilliSecondsSinceUnixEpoch, UserId};
15
16use self::{start::PollContentBlock, unstable_start::UnstablePollStartContentBlock};
17
18pub mod end;
19pub mod response;
20pub mod start;
21pub mod unstable_end;
22pub mod unstable_response;
23pub mod unstable_start;
24
25/// The data from a poll response necessary to compile poll results.
26#[derive(Debug, Clone, Copy)]
27#[allow(clippy::exhaustive_structs)]
28pub struct PollResponseData<'a> {
29    /// The sender of the response.
30    pub sender: &'a UserId,
31
32    /// The time of creation of the response on the originating server.
33    pub origin_server_ts: MilliSecondsSinceUnixEpoch,
34
35    /// The selections/answers of the response.
36    pub selections: &'a [String],
37}
38
39/// Generate the current results with the given poll and responses.
40///
41/// If the `end_timestamp` is provided, any response with an `origin_server_ts` after that timestamp
42/// is ignored. If it is not provided, `MilliSecondsSinceUnixEpoch::now()` will be used instead.
43///
44/// This method will handle invalid responses, or several response from the same user so all
45/// responses to the poll should be provided.
46///
47/// Returns a map of answer ID to a set of user IDs that voted for them. When using `.iter()` or
48/// `.into_iter()` on the map, the results are sorted from the highest number of votes to the
49/// lowest.
50pub fn compile_poll_results<'a>(
51    poll: &'a PollContentBlock,
52    responses: impl IntoIterator<Item = PollResponseData<'a>>,
53    end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
54) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
55    let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
56    let users_selections =
57        filter_selections(answer_ids, poll.max_selections, responses, end_timestamp);
58
59    aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
60}
61
62/// Generate the current results with the given unstable poll and responses.
63///
64/// If the `end_timestamp` is provided, any response with an `origin_server_ts` after that timestamp
65/// is ignored. If it is not provided, `MilliSecondsSinceUnixEpoch::now()` will be used instead.
66///
67/// This method will handle invalid responses, or several response from the same user so all
68/// responses to the poll should be provided.
69///
70/// Returns a map of answer ID to a set of user IDs that voted for them. When using `.iter()` or
71/// `.into_iter()` on the map, the results are sorted from the highest number of votes to the
72/// lowest.
73pub fn compile_unstable_poll_results<'a>(
74    poll: &'a UnstablePollStartContentBlock,
75    responses: impl IntoIterator<Item = PollResponseData<'a>>,
76    end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
77) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
78    let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
79    let users_selections =
80        filter_selections(answer_ids, poll.max_selections, responses, end_timestamp);
81
82    aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
83}
84
85/// Validate the selections of a response.
86fn validate_selections<'a>(
87    answer_ids: &BTreeSet<&str>,
88    max_selections: UInt,
89    selections: &'a [String],
90) -> Option<impl Iterator<Item = &'a str>> {
91    // Vote is spoiled if any answer is unknown.
92    if selections.iter().any(|s| !answer_ids.contains(s.as_str())) {
93        return None;
94    }
95
96    // Fallback to the maximum value for usize because we can't have more selections than that
97    // in memory.
98    let max_selections: usize = max_selections.try_into().unwrap_or(usize::MAX);
99
100    Some(selections.iter().take(max_selections).map(Deref::deref))
101}
102
103fn filter_selections<'a>(
104    answer_ids: BTreeSet<&str>,
105    max_selections: UInt,
106    responses: impl IntoIterator<Item = PollResponseData<'a>>,
107    end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
108) -> BTreeMap<&'a UserId, (MilliSecondsSinceUnixEpoch, Option<impl Iterator<Item = &'a str>>)> {
109    responses
110        .into_iter()
111        .filter(|ev| {
112            // Filter out responses after the end_timestamp.
113            end_timestamp.map_or(true, |end_ts| ev.origin_server_ts <= end_ts)
114        })
115        .fold(BTreeMap::new(), |mut acc, data| {
116            let response =
117                acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
118
119            // Only keep the latest selections for each user.
120            if response.0 < data.origin_server_ts {
121                *response = (
122                    data.origin_server_ts,
123                    validate_selections(&answer_ids, max_selections, data.selections),
124                );
125            }
126
127            acc
128        })
129}
130
131/// Aggregate the given selections by answer.
132fn aggregate_results<'a>(
133    answers: impl Iterator<Item = &'a str>,
134    users_selections: BTreeMap<
135        &'a UserId,
136        (MilliSecondsSinceUnixEpoch, Option<impl Iterator<Item = &'a str>>),
137    >,
138) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
139    let mut results = IndexMap::from_iter(answers.into_iter().map(|a| (a, BTreeSet::new())));
140
141    for (user, (_, selections)) in users_selections {
142        if let Some(selections) = selections {
143            for selection in selections {
144                results
145                    .get_mut(selection)
146                    .expect("validated selections should only match possible answers")
147                    .insert(user);
148            }
149        }
150    }
151
152    results.sort_by(|_, a, _, b| b.len().cmp(&a.len()));
153
154    results
155}
156
/// Generate the fallback text representation of a poll end event.
///
/// This is a sentence, in English, that lists the top answers for the given results. It is used to
/// generate a valid poll end event when using
/// `OriginalSync(Unstable)PollStartEvent::compile_results()`.
///
/// `answers` is a slice of `(answer ID, answer plain text representation)` and `results` is an
/// iterator of `(answer ID, count)`, usually ordered in descending order of count.
///
/// The top answers are all the answers sharing the highest count, listed in the order in which
/// they appear in `results`. The order of `results` only matters for ties. Answers without any
/// vote are never top answers, so results with no vote at all produce the "no top answer" text.
///
/// # Panics
///
/// Panics if the ID of a top answer in `results` is missing from `answers`.
fn generate_poll_end_fallback_text<'a>(
    answers: &[(&'a str, &'a str)],
    results: impl Iterator<Item = (&'a str, usize)>,
) -> String {
    let mut top_answers = Vec::new();
    let mut top_count = 0;

    for (id, count) in results {
        // Scanning the whole iterator, instead of stopping at the first lower count, keeps the
        // result correct even if `results` is not sorted.
        if count == 0 || count < top_count {
            continue;
        }

        if count > top_count {
            // A new highest count: previous candidates are no longer top answers.
            top_answers.clear();
            top_count = count;
        }

        top_answers.push(id);
    }

    let top_answers_text = top_answers
        .into_iter()
        .map(|id| {
            answers
                .iter()
                .find(|(a_id, _)| *a_id == id)
                .expect("top answer ID should be a valid answer ID")
                .1
        })
        .collect::<Vec<_>>();

    // Construct the plain text representation.
    match top_answers_text.as_slice() {
        [] => "The poll has closed with no top answer".to_owned(),
        [answer] => format!("The poll has closed. Top answer: {answer}"),
        several => format!("The poll has closed. Top answers: {}", several.join(", ")),
    }
}