ruma_events/
poll.rs

1//! Modules for events in the `m.poll` namespace ([MSC3381]).
2//!
3//! This module also contains types shared by events in its child namespaces.
4//!
5//! [MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    ops::Deref,
10};
11
12use indexmap::IndexMap;
13use js_int::{UInt, uint};
14use ruma_common::{MilliSecondsSinceUnixEpoch, UserId};
15
16use self::{start::PollContentBlock, unstable_start::UnstablePollStartContentBlock};
17
18pub mod end;
19pub mod response;
20pub mod start;
21pub mod unstable_end;
22pub mod unstable_response;
23pub mod unstable_start;
24
25/// The data from a poll response necessary to compile poll results.
26#[derive(Debug, Clone, Copy)]
27#[allow(clippy::exhaustive_structs)]
28pub struct PollResponseData<'a> {
29    /// The sender of the response.
30    pub sender: &'a UserId,
31
32    /// The time of creation of the response on the originating server.
33    pub origin_server_ts: MilliSecondsSinceUnixEpoch,
34
35    /// The selections/answers of the response.
36    pub selections: &'a [String],
37}
38
39/// Generate the current results with the given poll and responses.
40///
41/// If the `end_timestamp` is provided, any response with an `origin_server_ts` after that timestamp
42/// is ignored. If it is not provided, `MilliSecondsSinceUnixEpoch::now()` will be used instead.
43///
44/// This method will handle invalid responses, or several response from the same user so all
45/// responses to the poll should be provided.
46///
47/// Returns a map of answer ID to a set of user IDs that voted for them. When using `.iter()` or
48/// `.into_iter()` on the map, the results are sorted from the highest number of votes to the
49/// lowest.
50pub fn compile_poll_results<'a>(
51    poll: &'a PollContentBlock,
52    responses: impl IntoIterator<Item = PollResponseData<'a>>,
53    end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
54) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
55    let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
56    let users_selections =
57        filter_selections(answer_ids, poll.max_selections, responses, end_timestamp);
58
59    aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
60}
61
62/// Generate the current results with the given unstable poll and responses.
63///
64/// If the `end_timestamp` is provided, any response with an `origin_server_ts` after that timestamp
65/// is ignored. If it is not provided, `MilliSecondsSinceUnixEpoch::now()` will be used instead.
66///
67/// This method will handle invalid responses, or several response from the same user so all
68/// responses to the poll should be provided.
69///
70/// Returns a map of answer ID to a set of user IDs that voted for them. When using `.iter()` or
71/// `.into_iter()` on the map, the results are sorted from the highest number of votes to the
72/// lowest.
73pub fn compile_unstable_poll_results<'a>(
74    poll: &'a UnstablePollStartContentBlock,
75    responses: impl IntoIterator<Item = PollResponseData<'a>>,
76    end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
77) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
78    let answer_ids = poll.answers.iter().map(|a| a.id.as_str()).collect();
79    let users_selections =
80        filter_selections(answer_ids, poll.max_selections, responses, end_timestamp);
81
82    aggregate_results(poll.answers.iter().map(|a| a.id.as_str()), users_selections)
83}
84
85/// Validate the selections of a response.
86fn validate_selections<'a>(
87    answer_ids: &BTreeSet<&str>,
88    max_selections: UInt,
89    selections: &'a [String],
90) -> Option<impl Iterator<Item = &'a str> + use<'a>> {
91    // Vote is spoiled if any answer is unknown.
92    if selections.iter().any(|s| !answer_ids.contains(s.as_str())) {
93        return None;
94    }
95
96    // Fallback to the maximum value for usize because we can't have more selections than that
97    // in memory.
98    let max_selections: usize = max_selections.try_into().unwrap_or(usize::MAX);
99
100    Some(selections.iter().take(max_selections).map(Deref::deref))
101}
102
103fn filter_selections<'a, R>(
104    answer_ids: BTreeSet<&str>,
105    max_selections: UInt,
106    responses: R,
107    end_timestamp: Option<MilliSecondsSinceUnixEpoch>,
108) -> BTreeMap<
109    &'a UserId,
110    (MilliSecondsSinceUnixEpoch, Option<impl Iterator<Item = &'a str> + use<'a, R>>),
111>
112where
113    R: IntoIterator<Item = PollResponseData<'a>>,
114{
115    responses
116        .into_iter()
117        .filter(|ev| {
118            // Filter out responses after the end_timestamp.
119            end_timestamp.is_none_or(|end_ts| ev.origin_server_ts <= end_ts)
120        })
121        .fold(BTreeMap::new(), |mut acc, data| {
122            let response =
123                acc.entry(data.sender).or_insert((MilliSecondsSinceUnixEpoch(uint!(0)), None));
124
125            // Only keep the latest selections for each user.
126            if response.0 < data.origin_server_ts {
127                *response = (
128                    data.origin_server_ts,
129                    validate_selections(&answer_ids, max_selections, data.selections),
130                );
131            }
132
133            acc
134        })
135}
136
137/// Aggregate the given selections by answer.
138fn aggregate_results<'a>(
139    answers: impl Iterator<Item = &'a str>,
140    users_selections: BTreeMap<
141        &'a UserId,
142        (MilliSecondsSinceUnixEpoch, Option<impl Iterator<Item = &'a str>>),
143    >,
144) -> IndexMap<&'a str, BTreeSet<&'a UserId>> {
145    let mut results = IndexMap::from_iter(answers.into_iter().map(|a| (a, BTreeSet::new())));
146
147    for (user, (_, selections)) in users_selections {
148        if let Some(selections) = selections {
149            for selection in selections {
150                results
151                    .get_mut(selection)
152                    .expect("validated selections should only match possible answers")
153                    .insert(user);
154            }
155        }
156    }
157
158    results.sort_by(|_, a, _, b| b.len().cmp(&a.len()));
159
160    results
161}
162
/// Generate the fallback text representation of a poll end event.
///
/// This is a sentence, in English, that lists the top answers for the given results. It is used
/// to generate a valid poll end event when using
/// `OriginalSync(Unstable)PollStartEvent::compile_results()`.
///
/// `answers` is a slice of `(answer ID, answer plain text representation)` and `results` is an
/// iterator of `(answer ID, count)` ordered by descending count.
///
/// Every answer sharing the highest count is listed. When there are no results, or when no answer
/// received any vote, the text states that there is no top answer.
///
/// # Panics
///
/// Panics if the ID of a top answer in `results` has no entry in `answers`.
fn generate_poll_end_fallback_text<'a>(
    answers: &[(&'a str, &'a str)],
    results: impl Iterator<Item = (&'a str, usize)>,
) -> String {
    let mut results = results.peekable();

    // Results are sorted by descending count, so the first one holds the highest count. An
    // answer that nobody voted for is never a top answer.
    let top_count = match results.peek() {
        Some(&(_, count)) if count > 0 => count,
        _ => return "The poll has closed with no top answer".to_owned(),
    };

    let top_answers_text = results
        .take_while(|&(_, count)| count == top_count)
        .map(|(id, _)| {
            answers
                .iter()
                .find(|(a_id, _)| *a_id == id)
                .expect("top answer ID should be a valid answer ID")
                .1
        })
        .collect::<Vec<_>>();

    // Construct the plain text representation. At least one answer has `top_count`, so the list
    // is never empty here.
    match top_answers_text.as_slice() {
        [single] => format!("The poll has closed. Top answer: {single}"),
        many => {
            let answers = many.join(", ");
            format!("The poll has closed. Top answers: {answers}")
        }
    }
}