ruma_macros/api/
response.rs

1use std::ops::Not;
2
3use cfg_if::cfg_if;
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{
7    parse::{Parse, ParseStream},
8    punctuated::Punctuated,
9    visit::Visit,
10    Field, Generics, Ident, ItemStruct, Lifetime, Token, Type,
11};
12
13use super::{
14    attribute::{DeriveResponseMeta, ResponseMeta},
15    ensure_feature_presence,
16};
17use crate::util::{field_has_serde_flatten_attribute, import_ruma_common, PrivateField};
18
19mod incoming;
20mod outgoing;
21
22pub fn expand_response(attr: ResponseAttr, item: ItemStruct) -> TokenStream {
23    let ruma_common = import_ruma_common();
24    let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
25
26    let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
27
28    let error_ty = attr
29        .0
30        .iter()
31        .find_map(|a| match a {
32            DeriveResponseMeta::Error(ty) => Some(quote! { #ty }),
33            _ => None,
34        })
35        .unwrap_or_else(|| quote! { #ruma_common::api::error::MatrixError });
36    let status_ident = attr
37        .0
38        .iter()
39        .find_map(|a| match a {
40            DeriveResponseMeta::Status(ident) => Some(quote! { #ident }),
41            _ => None,
42        })
43        .unwrap_or_else(|| quote! { OK });
44
45    cfg_if! {
46        if #[cfg(feature = "__internal_macro_expand")] {
47            use syn::parse_quote;
48
49            let mut derive_input = item.clone();
50            derive_input.attrs.push(parse_quote! {
51                #[ruma_api(error = #error_ty, status = #status_ident)]
52            });
53            crate::util::cfg_expand_struct(&mut derive_input);
54
55            let extra_derive = quote! { #ruma_macros::_FakeDeriveRumaApi };
56            let ruma_api_attribute = quote! {};
57            let response_impls =
58                expand_derive_response(derive_input).unwrap_or_else(syn::Error::into_compile_error);
59        } else {
60            let extra_derive = quote! { #ruma_macros::Response };
61            let ruma_api_attribute = quote! {
62                #[ruma_api(error = #error_ty, status = #status_ident)]
63            };
64            let response_impls = quote! {};
65        }
66    }
67
68    quote! {
69        #maybe_feature_error
70
71        #[derive(Clone, Debug, #ruma_common::serde::_FakeDeriveSerde, #extra_derive)]
72        #[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
73        #ruma_api_attribute
74        #item
75
76        #response_impls
77    }
78}
79
80pub struct ResponseAttr(Punctuated<DeriveResponseMeta, Token![,]>);
81
82impl Parse for ResponseAttr {
83    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84        Punctuated::<DeriveResponseMeta, Token![,]>::parse_terminated(input).map(Self)
85    }
86}
87
88pub fn expand_derive_response(input: ItemStruct) -> syn::Result<TokenStream> {
89    let fields =
90        input.fields.into_iter().map(ResponseField::try_from).collect::<syn::Result<_>>()?;
91    let mut manual_body_serde = false;
92    let mut error_ty = None;
93    let mut status_ident = None;
94    for attr in input.attrs {
95        if !attr.path().is_ident("ruma_api") {
96            continue;
97        }
98
99        let metas =
100            attr.parse_args_with(Punctuated::<DeriveResponseMeta, Token![,]>::parse_terminated)?;
101        for meta in metas {
102            match meta {
103                DeriveResponseMeta::ManualBodySerde => manual_body_serde = true,
104                DeriveResponseMeta::Error(t) => error_ty = Some(t),
105                DeriveResponseMeta::Status(t) => status_ident = Some(t),
106            }
107        }
108    }
109
110    let response = Response {
111        ident: input.ident,
112        generics: input.generics,
113        fields,
114        manual_body_serde,
115        error_ty: error_ty.expect("missing error_ty attribute"),
116        status_ident: status_ident.expect("missing status_ident attribute"),
117    };
118
119    response.check()?;
120    Ok(response.expand_all())
121}
122
123struct Response {
124    ident: Ident,
125    generics: Generics,
126    fields: Vec<ResponseField>,
127    manual_body_serde: bool,
128    error_ty: Type,
129    status_ident: Ident,
130}
131
132impl Response {
133    /// Whether or not this request has any data in the HTTP body.
134    fn has_body_fields(&self) -> bool {
135        self.fields
136            .iter()
137            .any(|f| matches!(&f.kind, ResponseFieldKind::Body | &ResponseFieldKind::NewtypeBody))
138    }
139
140    /// Whether or not this request has a single newtype body field.
141    fn has_newtype_body(&self) -> bool {
142        self.fields.iter().any(|f| matches!(&f.kind, ResponseFieldKind::NewtypeBody))
143    }
144
145    /// Whether or not this request has a single raw body field.
146    fn has_raw_body(&self) -> bool {
147        self.fields.iter().any(|f| matches!(&f.kind, ResponseFieldKind::RawBody))
148    }
149
150    /// Whether or not this request has any data in the URL path.
151    fn has_header_fields(&self) -> bool {
152        self.fields.iter().any(|f| matches!(&f.kind, &ResponseFieldKind::Header(_)))
153    }
154
155    fn expand_all(&self) -> TokenStream {
156        let ruma_common = import_ruma_common();
157        let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
158        let serde = quote! { #ruma_common::exports::serde };
159
160        let response_body_struct = (!self.has_raw_body()).then(|| {
161            let serde_derives = self.manual_body_serde.not().then(|| {
162                quote! {
163                    #[cfg_attr(feature = "client", derive(#serde::Deserialize))]
164                    #[cfg_attr(feature = "server", derive(#serde::Serialize))]
165                }
166            });
167
168            let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
169            let fields =
170                self.fields.iter().filter_map(ResponseField::as_body_field).map(PrivateField);
171
172            quote! {
173                /// Data in the response body.
174                #[cfg(any(feature = "client", feature = "server"))]
175                #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
176                #serde_derives
177                #serde_attr
178                struct ResponseBody { #(#fields),* }
179            }
180        });
181
182        let outgoing_response_impl = self.expand_outgoing(&self.status_ident, &ruma_common);
183        let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_common);
184
185        quote! {
186            #response_body_struct
187
188            #outgoing_response_impl
189            #incoming_response_impl
190        }
191    }
192
193    pub fn check(&self) -> syn::Result<()> {
194        // TODO: highlight problematic fields
195
196        assert!(
197            self.generics.params.is_empty() && self.generics.where_clause.is_none(),
198            "This macro doesn't support generic types"
199        );
200
201        let newtype_body_fields = self.fields.iter().filter(|f| {
202            matches!(&f.kind, ResponseFieldKind::NewtypeBody | ResponseFieldKind::RawBody)
203        });
204
205        let has_newtype_body_field = match newtype_body_fields.count() {
206            0 => false,
207            1 => true,
208            _ => {
209                return Err(syn::Error::new_spanned(
210                    &self.ident,
211                    "Can't have more than one newtype body field",
212                ))
213            }
214        };
215
216        let mut body_fields =
217            self.fields.iter().filter(|f| matches!(f.kind, ResponseFieldKind::Body));
218        let first_body_field = body_fields.next();
219        let has_body_fields = first_body_field.is_some();
220
221        if has_newtype_body_field && has_body_fields {
222            return Err(syn::Error::new_spanned(
223                &self.ident,
224                "Can't have both a newtype body field and regular body fields",
225            ));
226        }
227
228        if let Some(first_body_field) = first_body_field {
229            let is_single_body_field = body_fields.next().is_none();
230
231            if is_single_body_field && field_has_serde_flatten_attribute(&first_body_field.inner) {
232                return Err(syn::Error::new_spanned(
233                    first_body_field,
234                    "Use `#[ruma_api(body)]` to represent the JSON body as a single field",
235                ));
236            }
237        }
238
239        Ok(())
240    }
241}
242
243/// A field of the response struct.
244struct ResponseField {
245    inner: Field,
246    kind: ResponseFieldKind,
247}
248
249/// The kind of a response field.
250enum ResponseFieldKind {
251    /// JSON data in the body of the response.
252    Body,
253
254    /// Data in an HTTP header.
255    Header(Ident),
256
257    /// A specific data type in the body of the response.
258    NewtypeBody,
259
260    /// Arbitrary bytes in the body of the response.
261    RawBody,
262}
263
264impl ResponseField {
265    /// Creates a new `ResponseField`.
266    fn new(inner: Field, kind_attr: Option<ResponseMeta>) -> Self {
267        let kind = match kind_attr {
268            Some(ResponseMeta::NewtypeBody) => ResponseFieldKind::NewtypeBody,
269            Some(ResponseMeta::RawBody) => ResponseFieldKind::RawBody,
270            Some(ResponseMeta::Header(header)) => ResponseFieldKind::Header(header),
271            None => ResponseFieldKind::Body,
272        };
273
274        Self { inner, kind }
275    }
276
277    /// Return the contained field if this response field is a body kind.
278    fn as_body_field(&self) -> Option<&Field> {
279        match &self.kind {
280            ResponseFieldKind::Body | ResponseFieldKind::NewtypeBody => Some(&self.inner),
281            _ => None,
282        }
283    }
284
285    /// Return the contained field if this response field is a raw body kind.
286    fn as_raw_body_field(&self) -> Option<&Field> {
287        match &self.kind {
288            ResponseFieldKind::RawBody => Some(&self.inner),
289            _ => None,
290        }
291    }
292
293    /// Return the contained field and HTTP header ident if this response field is a header kind.
294    fn as_header_field(&self) -> Option<(&Field, &Ident)> {
295        match &self.kind {
296            ResponseFieldKind::Header(ident) => Some((&self.inner, ident)),
297            _ => None,
298        }
299    }
300}
301
302impl TryFrom<Field> for ResponseField {
303    type Error = syn::Error;
304
305    fn try_from(mut field: Field) -> syn::Result<Self> {
306        if has_lifetime(&field.ty) {
307            return Err(syn::Error::new_spanned(
308                field.ident,
309                "Lifetimes on Response fields cannot be supported until GAT are stable",
310            ));
311        }
312
313        let (mut api_attrs, attrs) =
314            field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path().is_ident("ruma_api"));
315        field.attrs = attrs;
316
317        let kind_attr = match api_attrs.as_slice() {
318            [] => None,
319            [_] => Some(api_attrs.pop().unwrap().parse_args::<ResponseMeta>()?),
320            _ => {
321                return Err(syn::Error::new_spanned(
322                    &api_attrs[1],
323                    "multiple field kind attribute found, there can only be one",
324                ));
325            }
326        };
327
328        Ok(ResponseField::new(field, kind_attr))
329    }
330}
331
332impl Parse for ResponseField {
333    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
334        input.call(Field::parse_named)?.try_into()
335    }
336}
337
338impl ToTokens for ResponseField {
339    fn to_tokens(&self, tokens: &mut TokenStream) {
340        self.inner.to_tokens(tokens);
341    }
342}
343
344fn has_lifetime(ty: &Type) -> bool {
345    struct Visitor {
346        found_lifetime: bool,
347    }
348
349    impl<'ast> Visit<'ast> for Visitor {
350        fn visit_lifetime(&mut self, _lt: &'ast Lifetime) {
351            self.found_lifetime = true;
352        }
353    }
354
355    let mut vis = Visitor { found_lifetime: false };
356    vis.visit_type(ty);
357    vis.found_lifetime
358}