ruma_macros/api/
request.rs

1use cfg_if::cfg_if;
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{
5    parse::{Parse, ParseStream},
6    punctuated::Punctuated,
7    Field, Generics, Ident, ItemStruct, Token, Type,
8};
9
10use super::{
11    attribute::{DeriveRequestMeta, RequestMeta},
12    ensure_feature_presence,
13};
14use crate::util::{field_has_serde_flatten_attribute, import_ruma_common, PrivateField};
15
16mod incoming;
17mod outgoing;
18
19pub fn expand_request(attr: RequestAttr, item: ItemStruct) -> TokenStream {
20    let ruma_common = import_ruma_common();
21    let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
22
23    let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
24
25    let error_ty = attr.0.first().map_or_else(
26        || quote! { #ruma_common::api::error::MatrixError },
27        |DeriveRequestMeta::Error(ty)| quote! { #ty },
28    );
29
30    cfg_if! {
31        if #[cfg(feature = "__internal_macro_expand")] {
32            use syn::parse_quote;
33
34            let mut derive_input = item.clone();
35            derive_input.attrs.push(parse_quote! { #[ruma_api(error = #error_ty)] });
36            crate::util::cfg_expand_struct(&mut derive_input);
37
38            let extra_derive = quote! { #ruma_macros::_FakeDeriveRumaApi };
39            let ruma_api_attribute = quote! {};
40            let request_impls =
41                expand_derive_request(derive_input).unwrap_or_else(syn::Error::into_compile_error);
42        } else {
43            let extra_derive = quote! { #ruma_macros::Request };
44            let ruma_api_attribute = quote! { #[ruma_api(error = #error_ty)] };
45            let request_impls = quote! {};
46        }
47    }
48
49    quote! {
50        #maybe_feature_error
51
52        #[derive(Clone, Debug, #ruma_common::serde::_FakeDeriveSerde, #extra_derive)]
53        #[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
54        #ruma_api_attribute
55        #item
56
57        #request_impls
58    }
59}
60
61pub struct RequestAttr(Punctuated<DeriveRequestMeta, Token![,]>);
62
63impl Parse for RequestAttr {
64    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
65        Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated(input).map(Self)
66    }
67}
68
69pub fn expand_derive_request(input: ItemStruct) -> syn::Result<TokenStream> {
70    let fields =
71        input.fields.into_iter().map(RequestField::try_from).collect::<syn::Result<_>>()?;
72
73    let mut error_ty = None;
74
75    for attr in input.attrs {
76        if !attr.path().is_ident("ruma_api") {
77            continue;
78        }
79
80        let metas =
81            attr.parse_args_with(Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated)?;
82        for meta in metas {
83            match meta {
84                DeriveRequestMeta::Error(t) => error_ty = Some(t),
85            }
86        }
87    }
88
89    let request = Request {
90        ident: input.ident,
91        generics: input.generics,
92        fields,
93        error_ty: error_ty.expect("missing error_ty attribute"),
94    };
95
96    let ruma_common = import_ruma_common();
97    let test = request.check(&ruma_common)?;
98    let types_impls = request.expand_all(&ruma_common);
99
100    Ok(quote! {
101        #types_impls
102
103        #[allow(deprecated)]
104        #[cfg(test)]
105        mod __request {
106            #test
107        }
108    })
109}
110
111struct Request {
112    ident: Ident,
113    generics: Generics,
114    fields: Vec<RequestField>,
115
116    error_ty: Type,
117}
118
119impl Request {
120    fn body_fields(&self) -> impl Iterator<Item = &Field> {
121        self.fields.iter().filter_map(RequestField::as_body_field)
122    }
123
124    fn has_body_fields(&self) -> bool {
125        self.fields
126            .iter()
127            .any(|f| matches!(&f.kind, RequestFieldKind::Body | RequestFieldKind::NewtypeBody))
128    }
129
130    fn has_newtype_body(&self) -> bool {
131        self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::NewtypeBody))
132    }
133
134    fn has_header_fields(&self) -> bool {
135        self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Header(_)))
136    }
137
138    fn has_path_fields(&self) -> bool {
139        self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Path))
140    }
141
142    fn has_query_fields(&self) -> bool {
143        self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Query))
144    }
145
146    fn header_fields(&self) -> impl Iterator<Item = (&Field, &Ident)> {
147        self.fields.iter().filter_map(RequestField::as_header_field)
148    }
149
150    fn path_fields(&self) -> impl Iterator<Item = &Field> {
151        self.fields.iter().filter_map(RequestField::as_path_field)
152    }
153
154    fn raw_body_field(&self) -> Option<&Field> {
155        self.fields.iter().find_map(RequestField::as_raw_body_field)
156    }
157
158    fn query_all_field(&self) -> Option<&Field> {
159        self.fields.iter().find_map(RequestField::as_query_all_field)
160    }
161
162    fn expand_all(&self, ruma_common: &TokenStream) -> TokenStream {
163        let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
164        let serde = quote! { #ruma_common::exports::serde };
165
166        let request_body_struct = self.has_body_fields().then(|| {
167            let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
168            let fields =
169                self.fields.iter().filter_map(RequestField::as_body_field).map(PrivateField);
170
171            quote! {
172                /// Data in the request body.
173                #[cfg(any(feature = "client", feature = "server"))]
174                #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
175                #[cfg_attr(feature = "client", derive(#serde::Serialize))]
176                #[cfg_attr(feature = "server", derive(#serde::Deserialize))]
177                #serde_attr
178                struct RequestBody { #(#fields),* }
179            }
180        });
181
182        let request_query_def = if let Some(f) = self.query_all_field() {
183            let field = Field { ident: None, colon_token: None, ..f.clone() };
184            let field = PrivateField(&field);
185            Some(quote! { (#field); })
186        } else if self.has_query_fields() {
187            let fields =
188                self.fields.iter().filter_map(RequestField::as_query_field).map(PrivateField);
189            Some(quote! { { #(#fields),* } })
190        } else {
191            None
192        };
193
194        let request_query_struct = request_query_def.map(|def| {
195            quote! {
196                /// Data in the request's query string.
197                #[cfg(any(feature = "client", feature = "server"))]
198                #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
199                #[cfg_attr(feature = "client", derive(#serde::Serialize))]
200                #[cfg_attr(feature = "server", derive(#serde::Deserialize))]
201                struct RequestQuery #def
202            }
203        });
204
205        let outgoing_request_impl = self.expand_outgoing(ruma_common);
206        let incoming_request_impl = self.expand_incoming(ruma_common);
207
208        quote! {
209            #request_body_struct
210            #request_query_struct
211
212            #[allow(deprecated)]
213            mod __request_impls {
214                use super::*;
215                #outgoing_request_impl
216                #incoming_request_impl
217            }
218        }
219    }
220
221    pub(super) fn check(&self, ruma_common: &TokenStream) -> syn::Result<TokenStream> {
222        let http = quote! { #ruma_common::exports::http };
223
224        // TODO: highlight problematic fields
225
226        let newtype_body_fields = self.fields.iter().filter(|f| {
227            matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody)
228        });
229
230        let has_newtype_body_field = match newtype_body_fields.count() {
231            0 => false,
232            1 => true,
233            _ => {
234                return Err(syn::Error::new_spanned(
235                    &self.ident,
236                    "Can't have more than one newtype body field",
237                ))
238            }
239        };
240
241        let query_all_fields =
242            self.fields.iter().filter(|f| matches!(&f.kind, RequestFieldKind::QueryAll));
243        let has_query_all_field = match query_all_fields.count() {
244            0 => false,
245            1 => true,
246            _ => {
247                return Err(syn::Error::new_spanned(
248                    &self.ident,
249                    "Can't have more than one query_all field",
250                ))
251            }
252        };
253
254        let mut body_fields =
255            self.fields.iter().filter(|f| matches!(f.kind, RequestFieldKind::Body));
256        let first_body_field = body_fields.next();
257        let has_body_fields = first_body_field.is_some();
258
259        if has_newtype_body_field && has_body_fields {
260            return Err(syn::Error::new_spanned(
261                &self.ident,
262                "Can't have both a newtype body field and regular body fields",
263            ));
264        }
265
266        if let Some(first_body_field) = first_body_field {
267            let is_single_body_field = body_fields.next().is_none();
268
269            if is_single_body_field && field_has_serde_flatten_attribute(&first_body_field.inner) {
270                return Err(syn::Error::new_spanned(
271                    first_body_field,
272                    "Use `#[ruma_api(body)]` to represent the JSON body as a single field",
273                ));
274            }
275        }
276
277        let has_query_fields = self.has_query_fields();
278        if has_query_all_field && has_query_fields {
279            return Err(syn::Error::new_spanned(
280                &self.ident,
281                "Can't have both a query_all field and regular query fields",
282            ));
283        }
284
285        let path_fields = self.path_fields().map(|f| f.ident.as_ref().unwrap().to_string());
286        let mut tests = quote! {
287            #[::std::prelude::v1::test]
288            fn path_parameters() {
289                let path_params = super::METADATA._path_parameters();
290                let request_path_fields: &[&::std::primitive::str] = &[#(#path_fields),*];
291                ::std::assert_eq!(
292                    path_params, request_path_fields,
293                    "Path parameters must match the `Request`'s `#[ruma_api(path)]` fields"
294                );
295            }
296        };
297
298        if has_body_fields || has_newtype_body_field {
299            tests.extend(quote! {
300                #[::std::prelude::v1::test]
301                fn request_is_not_get() {
302                    ::std::assert_ne!(
303                        super::METADATA.method, #http::Method::GET,
304                        "GET endpoints can't have body fields",
305                    );
306                }
307            });
308        }
309
310        Ok(tests)
311    }
312}
313
314/// A field of the request struct.
315pub(super) struct RequestField {
316    pub(super) inner: Field,
317    pub(super) kind: RequestFieldKind,
318}
319
320/// The kind of a request field.
321pub(super) enum RequestFieldKind {
322    /// JSON data in the body of the request.
323    Body,
324
325    /// Data in an HTTP header.
326    Header(Ident),
327
328    /// A specific data type in the body of the request.
329    NewtypeBody,
330
331    /// Arbitrary bytes in the body of the request.
332    RawBody,
333
334    /// Data that appears in the URL path.
335    Path,
336
337    /// Data that appears in the query string.
338    Query,
339
340    /// Data that represents all the query string as a single type.
341    QueryAll,
342}
343
344impl RequestField {
345    /// Creates a new `RequestField`.
346    fn new(inner: Field, kind_attr: Option<RequestMeta>) -> Self {
347        let kind = match kind_attr {
348            Some(RequestMeta::NewtypeBody) => RequestFieldKind::NewtypeBody,
349            Some(RequestMeta::RawBody) => RequestFieldKind::RawBody,
350            Some(RequestMeta::Path) => RequestFieldKind::Path,
351            Some(RequestMeta::Query) => RequestFieldKind::Query,
352            Some(RequestMeta::QueryAll) => RequestFieldKind::QueryAll,
353            Some(RequestMeta::Header(header)) => RequestFieldKind::Header(header),
354            None => RequestFieldKind::Body,
355        };
356
357        Self { inner, kind }
358    }
359
360    /// Return the contained field if this request field is a body kind.
361    pub fn as_body_field(&self) -> Option<&Field> {
362        match &self.kind {
363            RequestFieldKind::Body | RequestFieldKind::NewtypeBody => Some(&self.inner),
364            _ => None,
365        }
366    }
367
368    /// Return the contained field if this request field is a raw body kind.
369    pub fn as_raw_body_field(&self) -> Option<&Field> {
370        match &self.kind {
371            RequestFieldKind::RawBody => Some(&self.inner),
372            _ => None,
373        }
374    }
375
376    /// Return the contained field if this request field is a path kind.
377    pub fn as_path_field(&self) -> Option<&Field> {
378        match &self.kind {
379            RequestFieldKind::Path => Some(&self.inner),
380            _ => None,
381        }
382    }
383
384    /// Return the contained field if this request field is a query kind.
385    pub fn as_query_field(&self) -> Option<&Field> {
386        match &self.kind {
387            RequestFieldKind::Query => Some(&self.inner),
388            _ => None,
389        }
390    }
391
392    /// Return the contained field if this request field is a query all kind.
393    pub fn as_query_all_field(&self) -> Option<&Field> {
394        match &self.kind {
395            RequestFieldKind::QueryAll => Some(&self.inner),
396            _ => None,
397        }
398    }
399
400    /// Return the contained field and header ident if this request field is a header kind.
401    pub fn as_header_field(&self) -> Option<(&Field, &Ident)> {
402        match &self.kind {
403            RequestFieldKind::Header(header_name) => Some((&self.inner, header_name)),
404            _ => None,
405        }
406    }
407}
408
409impl TryFrom<Field> for RequestField {
410    type Error = syn::Error;
411
412    fn try_from(mut field: Field) -> syn::Result<Self> {
413        let (mut api_attrs, attrs) =
414            field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path().is_ident("ruma_api"));
415        field.attrs = attrs;
416
417        let kind_attr = match api_attrs.as_slice() {
418            [] => None,
419            [_] => Some(api_attrs.pop().unwrap().parse_args::<RequestMeta>()?),
420            _ => {
421                return Err(syn::Error::new_spanned(
422                    &api_attrs[1],
423                    "multiple field kind attribute found, there can only be one",
424                ));
425            }
426        };
427
428        Ok(RequestField::new(field, kind_attr))
429    }
430}
431
432impl Parse for RequestField {
433    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
434        input.call(Field::parse_named)?.try_into()
435    }
436}
437
438impl ToTokens for RequestField {
439    fn to_tokens(&self, tokens: &mut TokenStream) {
440        self.inner.to_tokens(tokens);
441    }
442}