ruma_macros/events/
event_type.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{parse_quote, Ident, LitStr};
4
5use super::event_parse::{EventEnumEntry, EventEnumInput, EventKind};
6
7pub fn expand_event_type_enum(
8    input: EventEnumInput,
9    ruma_events: TokenStream,
10) -> syn::Result<TokenStream> {
11    let mut timeline: Vec<&Vec<EventEnumEntry>> = vec![];
12    let mut state: Vec<&Vec<EventEnumEntry>> = vec![];
13    let mut message: Vec<&Vec<EventEnumEntry>> = vec![];
14    let mut ephemeral: Vec<&Vec<EventEnumEntry>> = vec![];
15    let mut room_account: Vec<&Vec<EventEnumEntry>> = vec![];
16    let mut global_account: Vec<&Vec<EventEnumEntry>> = vec![];
17    let mut to_device: Vec<&Vec<EventEnumEntry>> = vec![];
18    for event in &input.enums {
19        match event.kind {
20            EventKind::GlobalAccountData => global_account.push(&event.events),
21            EventKind::RoomAccountData => room_account.push(&event.events),
22            EventKind::Ephemeral => ephemeral.push(&event.events),
23            EventKind::MessageLike => {
24                message.push(&event.events);
25                timeline.push(&event.events);
26            }
27            EventKind::State => {
28                state.push(&event.events);
29                timeline.push(&event.events);
30            }
31            EventKind::ToDevice => to_device.push(&event.events),
32            EventKind::RoomRedaction
33            | EventKind::Presence
34            | EventKind::Decrypted
35            | EventKind::HierarchySpaceChild => {}
36        }
37    }
38    let presence = vec![EventEnumEntry {
39        attrs: vec![],
40        aliases: vec![],
41        ev_type: LitStr::new("m.presence", Span::call_site()),
42        ev_path: parse_quote! { #ruma_events::presence },
43        ident: None,
44    }];
45    let mut all = input.enums.iter().map(|e| &e.events).collect::<Vec<_>>();
46    all.push(&presence);
47
48    let mut res = TokenStream::new();
49
50    res.extend(
51        generate_enum("TimelineEventType", &timeline, &ruma_events)
52            .unwrap_or_else(syn::Error::into_compile_error),
53    );
54    res.extend(
55        generate_enum("StateEventType", &state, &ruma_events)
56            .unwrap_or_else(syn::Error::into_compile_error),
57    );
58    res.extend(
59        generate_enum("MessageLikeEventType", &message, &ruma_events)
60            .unwrap_or_else(syn::Error::into_compile_error),
61    );
62    res.extend(
63        generate_enum("EphemeralRoomEventType", &ephemeral, &ruma_events)
64            .unwrap_or_else(syn::Error::into_compile_error),
65    );
66    res.extend(
67        generate_enum("RoomAccountDataEventType", &room_account, &ruma_events)
68            .unwrap_or_else(syn::Error::into_compile_error),
69    );
70    res.extend(
71        generate_enum("GlobalAccountDataEventType", &global_account, &ruma_events)
72            .unwrap_or_else(syn::Error::into_compile_error),
73    );
74    res.extend(
75        generate_enum("ToDeviceEventType", &to_device, &ruma_events)
76            .unwrap_or_else(syn::Error::into_compile_error),
77    );
78
79    Ok(res)
80}
81
82fn generate_enum(
83    ident: &str,
84    input: &[&Vec<EventEnumEntry>],
85    ruma_common: &TokenStream,
86) -> syn::Result<TokenStream> {
87    let serde = quote! { #ruma_common::exports::serde };
88    let enum_doc = format!("The type of `{}` this is.", ident.strip_suffix("Type").unwrap());
89
90    let ident = Ident::new(ident, Span::call_site());
91
92    let mut deduped: Vec<&EventEnumEntry> = vec![];
93    for item in input.iter().copied().flatten() {
94        if let Some(idx) = deduped.iter().position(|e| e.ev_type == item.ev_type) {
95            // If there is a variant without config attributes use that
96            if deduped[idx].attrs != item.attrs && item.attrs.is_empty() {
97                deduped[idx] = item;
98            }
99        } else {
100            deduped.push(item);
101        }
102    }
103
104    let event_types = deduped.iter().map(|e| &e.ev_type);
105
106    let variants: Vec<_> = deduped
107        .iter()
108        .map(|e| {
109            let start = e.to_variant()?.decl();
110            let data = e.has_type_fragment().then(|| quote! { (::std::string::String) });
111
112            Ok(quote! {
113                #start #data
114            })
115        })
116        .collect::<syn::Result<_>>()?;
117
118    let to_cow_str_match_arms: Vec<_> = deduped
119        .iter()
120        .map(|e| {
121            let v = e.to_variant()?;
122            let start = v.match_arm(quote! { Self });
123            let ev_type = &e.ev_type;
124
125            Ok(if let Some(prefix) = ev_type.value().strip_suffix(".*") {
126                let fstr = prefix.to_owned() + ".{}";
127                quote! { #start(_s) => ::std::borrow::Cow::Owned(::std::format!(#fstr, _s)) }
128            } else {
129                quote! { #start => ::std::borrow::Cow::Borrowed(#ev_type) }
130            })
131        })
132        .collect::<syn::Result<_>>()?;
133
134    let mut from_str_match_arms = TokenStream::new();
135    for event in &deduped {
136        let v = event.to_variant()?;
137        let ctor = v.ctor(quote! { Self });
138        let ev_types = event.aliases.iter().chain([&event.ev_type]);
139        let attrs = &event.attrs;
140
141        if event.ev_type.value().ends_with(".*") {
142            for ev_type in ev_types {
143                let name = ev_type.value();
144                let prefix = name
145                    .strip_suffix('*')
146                    .expect("aliases have already been checked to have the same suffix");
147
148                from_str_match_arms.extend(quote! {
149                    #(#attrs)*
150                    // Use if-let guard once available
151                    _s if _s.starts_with(#prefix) => {
152                        #ctor(::std::convert::From::from(_s.strip_prefix(#prefix).unwrap()))
153                    }
154                });
155            }
156        } else {
157            from_str_match_arms.extend(quote! { #(#attrs)* #(#ev_types)|* => #ctor, });
158        }
159    }
160
161    let from_ident_for_timeline = if ident == "StateEventType" || ident == "MessageLikeEventType" {
162        let match_arms: Vec<_> = deduped
163            .iter()
164            .map(|e| {
165                let v = e.to_variant()?;
166                let ident_var = v.match_arm(quote! { #ident });
167                let timeline_var = v.ctor(quote! { Self });
168
169                Ok(if e.has_type_fragment() {
170                    quote! { #ident_var (_s) => #timeline_var (_s) }
171                } else {
172                    quote! { #ident_var => #timeline_var }
173                })
174            })
175            .collect::<syn::Result<_>>()?;
176
177        Some(quote! {
178            #[allow(deprecated)]
179            impl ::std::convert::From<#ident> for TimelineEventType {
180                fn from(s: #ident) -> Self {
181                    match s {
182                        #(#match_arms,)*
183                        #ident ::_Custom(_s) => Self::_Custom(_s),
184                    }
185                }
186            }
187        })
188    } else {
189        None
190    };
191
192    Ok(quote! {
193        #[doc = #enum_doc]
194        ///
195        /// This type can hold an arbitrary string. To build events with a custom type, convert it
196        /// from a string with `::from()` / `.into()`. To check for events that are not available as a
197        /// documented variant here, use its string representation, obtained through `.to_string()`.
198        #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
199        #[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
200        pub enum #ident {
201            #(
202                #[doc = #event_types]
203                #variants,
204            )*
205            #[doc(hidden)]
206            _Custom(crate::PrivOwnedStr),
207        }
208
209        #[allow(deprecated)]
210        impl #ident {
211            fn to_cow_str(&self) -> ::std::borrow::Cow<'_, ::std::primitive::str> {
212                match self {
213                    #(#to_cow_str_match_arms,)*
214                    Self::_Custom(crate::PrivOwnedStr(s)) => ::std::borrow::Cow::Borrowed(s),
215                }
216            }
217        }
218
219        #[allow(deprecated)]
220        impl ::std::fmt::Display for #ident {
221            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
222                self.to_cow_str().fmt(f)
223            }
224        }
225
226        #[allow(deprecated)]
227        impl ::std::fmt::Debug for #ident {
228            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
229                <str as ::std::fmt::Debug>::fmt(&self.to_cow_str(), f)
230            }
231        }
232
233        #[allow(deprecated)]
234        impl ::std::convert::From<&::std::primitive::str> for #ident {
235            fn from(s: &::std::primitive::str) -> Self {
236                match s {
237                    #from_str_match_arms
238                    _ => Self::_Custom(crate::PrivOwnedStr(::std::convert::From::from(s))),
239                }
240            }
241        }
242
243        #[allow(deprecated)]
244        impl ::std::convert::From<::std::string::String> for #ident {
245            fn from(s: ::std::string::String) -> Self {
246                ::std::convert::From::from(s.as_str())
247            }
248        }
249
250        #[allow(deprecated)]
251        impl<'de> #serde::Deserialize<'de> for #ident {
252            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
253            where
254                D: #serde::Deserializer<'de>
255            {
256                let s = #ruma_common::serde::deserialize_cow_str(deserializer)?;
257                Ok(::std::convert::From::from(&s[..]))
258            }
259        }
260
261        #[allow(deprecated)]
262        impl #serde::Serialize for #ident {
263            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
264            where
265                S: #serde::Serializer,
266            {
267                self.to_cow_str().serialize(serializer)
268            }
269        }
270
271        #from_ident_for_timeline
272    })
273}