ruma_macros/events/
event.rs
1use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed};
6
7use super::{
8 event_parse::{to_kind_variation, EventKind, EventKindVariation},
9 util::is_non_stripped_room_event,
10};
11use crate::{import_ruma_events, util::to_camel_case};
12
13pub fn expand_event(input: DeriveInput) -> syn::Result<TokenStream> {
15 let ruma_events = import_ruma_events();
16
17 let ident = &input.ident;
18 let (kind, var) = to_kind_variation(ident).ok_or_else(|| {
19 syn::Error::new_spanned(ident, "not a valid ruma event struct identifier")
20 })?;
21
22 let fields: Vec<_> = if let Data::Struct(DataStruct {
23 fields: Fields::Named(FieldsNamed { named, .. }),
24 ..
25 }) = &input.data
26 {
27 if !named.iter().any(|f| f.ident.as_ref().unwrap() == "content") {
28 return Err(syn::Error::new(
29 Span::call_site(),
30 "struct must contain a `content` field",
31 ));
32 }
33
34 named.iter().cloned().collect()
35 } else {
36 return Err(syn::Error::new_spanned(
37 input.ident,
38 "the `Event` derive only supports structs with named fields",
39 ));
40 };
41
42 let mut res = TokenStream::new();
43
44 res.extend(
45 expand_deserialize_event(&input, var, &fields, &ruma_events)
46 .unwrap_or_else(syn::Error::into_compile_error),
47 );
48
49 if var.is_sync() {
50 res.extend(expand_sync_from_into_full(&input, kind, var, &fields, &ruma_events));
51 }
52
53 if is_non_stripped_room_event(kind, var) {
54 res.extend(expand_eq_ord_event(&input));
55 }
56
57 Ok(res)
58}
59
60fn expand_deserialize_event(
61 input: &DeriveInput,
62 var: EventKindVariation,
63 fields: &[Field],
64 ruma_events: &TokenStream,
65) -> syn::Result<TokenStream> {
66 let serde = quote! { #ruma_events::exports::serde };
67 let serde_json = quote! { #ruma_events::exports::serde_json };
68
69 let ident = &input.ident;
70 let content_type = &fields
72 .iter()
73 .find(|f| f.ident.as_ref().unwrap() == "content")
75 .unwrap()
76 .ty;
77
78 let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl();
79 let is_generic = !input.generics.params.is_empty();
80
81 let enum_variants: Vec<_> = fields
82 .iter()
83 .map(|field| {
84 let name = field.ident.as_ref().unwrap();
85 to_camel_case(name)
86 })
87 .collect();
88
89 let deserialize_var_types: Vec<_> = fields
90 .iter()
91 .map(|field| {
92 let name = field.ident.as_ref().unwrap();
93 if name == "content" {
94 if is_generic {
95 quote! { ::std::boxed::Box<#serde_json::value::RawValue> }
96 } else {
97 quote! { #content_type }
98 }
99 } else if name == "state_key" && var == EventKindVariation::Initial {
100 quote! { ::std::string::String }
101 } else {
102 let ty = &field.ty;
103 quote! { #ty }
104 }
105 })
106 .collect();
107
108 let ok_or_else_fields: Vec<_> = fields
109 .iter()
110 .map(|field| {
111 let name = field.ident.as_ref().unwrap();
112 Ok(if name == "content" && is_generic {
113 quote! {
114 let content = {
115 let json = content
116 .ok_or_else(|| #serde::de::Error::missing_field("content"))?;
117 C::from_parts(&event_type, &json).map_err(#serde::de::Error::custom)?
118 };
119 }
120 } else if name == "unsigned" && !var.is_redacted() {
121 quote! {
122 let unsigned = unsigned.unwrap_or_default();
123 }
124 } else if name == "state_key" && var == EventKindVariation::Initial {
125 let ty = &field.ty;
126 quote! {
127 let state_key: ::std::string::String = state_key.unwrap_or_default();
128 let state_key: #ty = <#ty as #serde::de::Deserialize>::deserialize(
129 #serde::de::IntoDeserializer::<A::Error>::into_deserializer(state_key),
130 )?;
131 }
132 } else {
133 quote! {
134 let #name = #name.ok_or_else(|| {
135 #serde::de::Error::missing_field(stringify!(#name))
136 })?;
137 }
138 })
139 })
140 .collect::<syn::Result<_>>()?;
141
142 let field_names: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect();
143
144 let deserialize_impl_gen = if is_generic {
145 let gen = &input.generics.params;
146 quote! { <'de, #gen> }
147 } else {
148 quote! { <'de> }
149 };
150 let deserialize_phantom_type = if is_generic {
151 quote! { ::std::marker::PhantomData }
152 } else {
153 quote! {}
154 };
155 let where_clause = if is_generic {
156 let predicate = parse_quote! { C: #ruma_events::EventContentFromType };
157 if let Some(mut where_clause) = where_clause.cloned() {
158 where_clause.predicates.push(predicate);
159 Some(where_clause)
160 } else {
161 Some(parse_quote! { where #predicate })
162 }
163 } else {
164 where_clause.cloned()
165 };
166
167 Ok(quote! {
168 #[automatically_derived]
169 impl #deserialize_impl_gen #serde::de::Deserialize<'de> for #ident #ty_gen #where_clause {
170 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
171 where
172 D: #serde::de::Deserializer<'de>,
173 {
174 #[derive(#serde::Deserialize)]
175 #[serde(field_identifier, rename_all = "snake_case")]
176 enum Field {
177 Type,
180 #( #enum_variants, )*
181 #[serde(other)]
182 Unknown,
183 }
184
185 struct EventVisitor #impl_generics (#deserialize_phantom_type #ty_gen);
188
189 #[automatically_derived]
190 impl #deserialize_impl_gen #serde::de::Visitor<'de>
191 for EventVisitor #ty_gen #where_clause
192 {
193 type Value = #ident #ty_gen;
194
195 fn expecting(
196 &self,
197 formatter: &mut ::std::fmt::Formatter<'_>,
198 ) -> ::std::fmt::Result {
199 write!(formatter, "struct implementing {}", stringify!(#content_type))
200 }
201
202 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
203 where
204 A: #serde::de::MapAccess<'de>,
205 {
206 let mut event_type: Option<String> = None;
207 #( let mut #field_names: Option<#deserialize_var_types> = None; )*
208
209 while let Some(key) = map.next_key()? {
210 match key {
211 Field::Unknown => {
212 let _: #serde::de::IgnoredAny = map.next_value()?;
213 },
214 Field::Type => {
215 if event_type.is_some() {
216 return Err(#serde::de::Error::duplicate_field("type"));
217 }
218 event_type = Some(map.next_value()?);
219 }
220 #(
221 Field::#enum_variants => {
222 if #field_names.is_some() {
223 return Err(#serde::de::Error::duplicate_field(
224 stringify!(#field_names),
225 ));
226 }
227 #field_names = Some(map.next_value()?);
228 }
229 )*
230 }
231 }
232
233 let event_type =
234 event_type.ok_or_else(|| #serde::de::Error::missing_field("type"))?;
235 #( #ok_or_else_fields )*
236
237 Ok(#ident {
238 #( #field_names ),*
239 })
240 }
241 }
242
243 deserializer.deserialize_map(EventVisitor(#deserialize_phantom_type))
244 }
245 }
246 })
247}
248
249fn expand_sync_from_into_full(
250 input: &DeriveInput,
251 kind: EventKind,
252 var: EventKindVariation,
253 fields: &[Field],
254 ruma_events: &TokenStream,
255) -> syn::Result<TokenStream> {
256 let ident = &input.ident;
257 let full_struct = kind.to_event_ident(var.to_full())?;
258 let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl();
259 let fields: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect();
260
261 Ok(quote! {
262 #[automatically_derived]
263 impl #impl_generics ::std::convert::From<#full_struct #ty_gen>
264 for #ident #ty_gen #where_clause
265 {
266 fn from(event: #full_struct #ty_gen) -> Self {
267 let #full_struct { #( #fields, )* .. } = event;
268 Self { #( #fields, )* }
269 }
270 }
271
272 #[automatically_derived]
273 impl #impl_generics #ident #ty_gen #where_clause {
274 pub fn into_full_event(
276 self,
277 room_id: #ruma_events::exports::ruma_common::OwnedRoomId,
278 ) -> #full_struct #ty_gen {
279 let Self { #( #fields, )* } = self;
280 #full_struct {
281 #( #fields, )*
282 room_id,
283 }
284 }
285 }
286 })
287}
288
289fn expand_eq_ord_event(input: &DeriveInput) -> TokenStream {
290 let ident = &input.ident;
291 let (impl_gen, ty_gen, where_clause) = input.generics.split_for_impl();
292
293 quote! {
294 #[automatically_derived]
295 impl #impl_gen ::std::cmp::PartialEq for #ident #ty_gen #where_clause {
296 fn eq(&self, other: &Self) -> ::std::primitive::bool {
298 self.event_id == other.event_id
299 }
300 }
301
302 #[automatically_derived]
303 impl #impl_gen ::std::cmp::Eq for #ident #ty_gen #where_clause {}
304
305 #[automatically_derived]
306 impl #impl_gen ::std::cmp::PartialOrd for #ident #ty_gen #where_clause {
307 fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
309 self.event_id.partial_cmp(&other.event_id)
310 }
311 }
312
313 #[automatically_derived]
314 impl #impl_gen ::std::cmp::Ord for #ident #ty_gen #where_clause {
315 fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
317 self.event_id.cmp(&other.event_id)
318 }
319 }
320 }
321}