ruma_macros/events/
event.rs

1//! Implementation of the top level `*Event` derive macro.
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed};
6
7use super::{
8    event_parse::{to_kind_variation, EventKind, EventKindVariation},
9    util::is_non_stripped_room_event,
10};
11use crate::{import_ruma_events, util::to_camel_case};
12
13/// Derive `Event` macro code generation.
14pub fn expand_event(input: DeriveInput) -> syn::Result<TokenStream> {
15    let ruma_events = import_ruma_events();
16
17    let ident = &input.ident;
18    let (kind, var) = to_kind_variation(ident).ok_or_else(|| {
19        syn::Error::new_spanned(ident, "not a valid ruma event struct identifier")
20    })?;
21
22    let fields: Vec<_> = if let Data::Struct(DataStruct {
23        fields: Fields::Named(FieldsNamed { named, .. }),
24        ..
25    }) = &input.data
26    {
27        if !named.iter().any(|f| f.ident.as_ref().unwrap() == "content") {
28            return Err(syn::Error::new(
29                Span::call_site(),
30                "struct must contain a `content` field",
31            ));
32        }
33
34        named.iter().cloned().collect()
35    } else {
36        return Err(syn::Error::new_spanned(
37            input.ident,
38            "the `Event` derive only supports structs with named fields",
39        ));
40    };
41
42    let mut res = TokenStream::new();
43
44    res.extend(
45        expand_deserialize_event(&input, var, &fields, &ruma_events)
46            .unwrap_or_else(syn::Error::into_compile_error),
47    );
48
49    if var.is_sync() {
50        res.extend(expand_sync_from_into_full(&input, kind, var, &fields, &ruma_events));
51    }
52
53    if is_non_stripped_room_event(kind, var) {
54        res.extend(expand_eq_ord_event(&input));
55    }
56
57    Ok(res)
58}
59
60fn expand_deserialize_event(
61    input: &DeriveInput,
62    var: EventKindVariation,
63    fields: &[Field],
64    ruma_events: &TokenStream,
65) -> syn::Result<TokenStream> {
66    let serde = quote! { #ruma_events::exports::serde };
67    let serde_json = quote! { #ruma_events::exports::serde_json };
68
69    let ident = &input.ident;
70    // we know there is a content field already
71    let content_type = &fields
72        .iter()
73        // we also know that the fields are named and have an ident
74        .find(|f| f.ident.as_ref().unwrap() == "content")
75        .unwrap()
76        .ty;
77
78    let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl();
79    let is_generic = !input.generics.params.is_empty();
80
81    let enum_variants: Vec<_> = fields
82        .iter()
83        .map(|field| {
84            let name = field.ident.as_ref().unwrap();
85            to_camel_case(name)
86        })
87        .collect();
88
89    let deserialize_var_types: Vec<_> = fields
90        .iter()
91        .map(|field| {
92            let name = field.ident.as_ref().unwrap();
93            if name == "content" {
94                if is_generic {
95                    quote! { ::std::boxed::Box<#serde_json::value::RawValue> }
96                } else {
97                    quote! { #content_type }
98                }
99            } else if name == "state_key" && var == EventKindVariation::Initial {
100                quote! { ::std::string::String }
101            } else {
102                let ty = &field.ty;
103                quote! { #ty }
104            }
105        })
106        .collect();
107
108    let ok_or_else_fields: Vec<_> = fields
109        .iter()
110        .map(|field| {
111            let name = field.ident.as_ref().unwrap();
112            Ok(if name == "content" && is_generic {
113                quote! {
114                    let content = {
115                        let json = content
116                            .ok_or_else(|| #serde::de::Error::missing_field("content"))?;
117                        C::from_parts(&event_type, &json).map_err(#serde::de::Error::custom)?
118                    };
119                }
120            } else if name == "unsigned" && !var.is_redacted() {
121                quote! {
122                    let unsigned = unsigned.unwrap_or_default();
123                }
124            } else if name == "state_key" && var == EventKindVariation::Initial {
125                let ty = &field.ty;
126                quote! {
127                    let state_key: ::std::string::String = state_key.unwrap_or_default();
128                    let state_key: #ty = <#ty as #serde::de::Deserialize>::deserialize(
129                        #serde::de::IntoDeserializer::<A::Error>::into_deserializer(state_key),
130                    )?;
131                }
132            } else {
133                quote! {
134                    let #name = #name.ok_or_else(|| {
135                        #serde::de::Error::missing_field(stringify!(#name))
136                    })?;
137                }
138            })
139        })
140        .collect::<syn::Result<_>>()?;
141
142    let field_names: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect();
143
144    let deserialize_impl_gen = if is_generic {
145        let gen = &input.generics.params;
146        quote! { <'de, #gen> }
147    } else {
148        quote! { <'de> }
149    };
150    let deserialize_phantom_type = if is_generic {
151        quote! { ::std::marker::PhantomData }
152    } else {
153        quote! {}
154    };
155    let where_clause = if is_generic {
156        let predicate = parse_quote! { C: #ruma_events::EventContentFromType };
157        if let Some(mut where_clause) = where_clause.cloned() {
158            where_clause.predicates.push(predicate);
159            Some(where_clause)
160        } else {
161            Some(parse_quote! { where #predicate })
162        }
163    } else {
164        where_clause.cloned()
165    };
166
167    Ok(quote! {
168        #[automatically_derived]
169        impl #deserialize_impl_gen #serde::de::Deserialize<'de> for #ident #ty_gen #where_clause {
170            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
171            where
172                D: #serde::de::Deserializer<'de>,
173            {
174                #[derive(#serde::Deserialize)]
175                #[serde(field_identifier, rename_all = "snake_case")]
176                enum Field {
177                    // since this is represented as an enum we have to add it so the JSON picks it
178                    // up
179                    Type,
180                    #( #enum_variants, )*
181                    #[serde(other)]
182                    Unknown,
183                }
184
185                /// Visits the fields of an event struct to handle deserialization of
186                /// the `content` and `prev_content` fields.
187                struct EventVisitor #impl_generics (#deserialize_phantom_type #ty_gen);
188
189                #[automatically_derived]
190                impl #deserialize_impl_gen #serde::de::Visitor<'de>
191                    for EventVisitor #ty_gen #where_clause
192                {
193                    type Value = #ident #ty_gen;
194
195                    fn expecting(
196                        &self,
197                        formatter: &mut ::std::fmt::Formatter<'_>,
198                    ) -> ::std::fmt::Result {
199                        write!(formatter, "struct implementing {}", stringify!(#content_type))
200                    }
201
202                    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
203                    where
204                        A: #serde::de::MapAccess<'de>,
205                    {
206                        let mut event_type: Option<String> = None;
207                        #( let mut #field_names: Option<#deserialize_var_types> = None; )*
208
209                        while let Some(key) = map.next_key()? {
210                            match key {
211                                Field::Unknown => {
212                                    let _: #serde::de::IgnoredAny = map.next_value()?;
213                                },
214                                Field::Type => {
215                                    if event_type.is_some() {
216                                        return Err(#serde::de::Error::duplicate_field("type"));
217                                    }
218                                    event_type = Some(map.next_value()?);
219                                }
220                                #(
221                                    Field::#enum_variants => {
222                                        if #field_names.is_some() {
223                                            return Err(#serde::de::Error::duplicate_field(
224                                                stringify!(#field_names),
225                                            ));
226                                        }
227                                        #field_names = Some(map.next_value()?);
228                                    }
229                                )*
230                            }
231                        }
232
233                        let event_type =
234                            event_type.ok_or_else(|| #serde::de::Error::missing_field("type"))?;
235                        #( #ok_or_else_fields )*
236
237                        Ok(#ident {
238                            #( #field_names ),*
239                        })
240                    }
241                }
242
243                deserializer.deserialize_map(EventVisitor(#deserialize_phantom_type))
244            }
245        }
246    })
247}
248
249fn expand_sync_from_into_full(
250    input: &DeriveInput,
251    kind: EventKind,
252    var: EventKindVariation,
253    fields: &[Field],
254    ruma_events: &TokenStream,
255) -> syn::Result<TokenStream> {
256    let ident = &input.ident;
257    let full_struct = kind.to_event_ident(var.to_full())?;
258    let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl();
259    let fields: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect();
260
261    Ok(quote! {
262        #[automatically_derived]
263        impl #impl_generics ::std::convert::From<#full_struct #ty_gen>
264            for #ident #ty_gen #where_clause
265        {
266            fn from(event: #full_struct #ty_gen) -> Self {
267                let #full_struct { #( #fields, )* .. } = event;
268                Self { #( #fields, )* }
269            }
270        }
271
272        #[automatically_derived]
273        impl #impl_generics #ident #ty_gen #where_clause {
274            /// Convert this sync event into a full event, one with a room_id field.
275            pub fn into_full_event(
276                self,
277                room_id: #ruma_events::exports::ruma_common::OwnedRoomId,
278            ) -> #full_struct #ty_gen {
279                let Self { #( #fields, )* } = self;
280                #full_struct {
281                    #( #fields, )*
282                    room_id,
283                }
284            }
285        }
286    })
287}
288
289fn expand_eq_ord_event(input: &DeriveInput) -> TokenStream {
290    let ident = &input.ident;
291    let (impl_gen, ty_gen, where_clause) = input.generics.split_for_impl();
292
293    quote! {
294        #[automatically_derived]
295        impl #impl_gen ::std::cmp::PartialEq for #ident #ty_gen #where_clause {
296            /// Checks if two `EventId`s are equal.
297            fn eq(&self, other: &Self) -> ::std::primitive::bool {
298                self.event_id == other.event_id
299            }
300        }
301
302        #[automatically_derived]
303        impl #impl_gen ::std::cmp::Eq for #ident #ty_gen #where_clause {}
304
305        #[automatically_derived]
306        impl #impl_gen ::std::cmp::PartialOrd for #ident #ty_gen #where_clause {
307            /// Compares `EventId`s and orders them lexicographically.
308            fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
309                self.event_id.partial_cmp(&other.event_id)
310            }
311        }
312
313        #[automatically_derived]
314        impl #impl_gen ::std::cmp::Ord for #ident #ty_gen #where_clause {
315            /// Compares `EventId`s and orders them lexicographically.
316            fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
317                self.event_id.cmp(&other.event_id)
318            }
319        }
320    }
321}