1use std::ops::Not;
2
3use cfg_if::cfg_if;
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{
7 parse::{Parse, ParseStream},
8 punctuated::Punctuated,
9 visit::Visit,
10 Field, Generics, Ident, ItemStruct, Lifetime, Token, Type,
11};
12
13use super::{
14 attribute::{DeriveResponseMeta, ResponseMeta},
15 ensure_feature_presence,
16};
17use crate::util::{field_has_serde_flatten_attribute, import_ruma_common, PrivateField};
18
19mod incoming;
20mod outgoing;
21
22pub fn expand_response(attr: ResponseAttr, item: ItemStruct) -> TokenStream {
23 let ruma_common = import_ruma_common();
24 let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
25
26 let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
27
28 let error_ty = attr
29 .0
30 .iter()
31 .find_map(|a| match a {
32 DeriveResponseMeta::Error(ty) => Some(quote! { #ty }),
33 _ => None,
34 })
35 .unwrap_or_else(|| quote! { #ruma_common::api::error::MatrixError });
36 let status_ident = attr
37 .0
38 .iter()
39 .find_map(|a| match a {
40 DeriveResponseMeta::Status(ident) => Some(quote! { #ident }),
41 _ => None,
42 })
43 .unwrap_or_else(|| quote! { OK });
44
45 cfg_if! {
46 if #[cfg(feature = "__internal_macro_expand")] {
47 use syn::parse_quote;
48
49 let mut derive_input = item.clone();
50 derive_input.attrs.push(parse_quote! {
51 #[ruma_api(error = #error_ty, status = #status_ident)]
52 });
53 crate::util::cfg_expand_struct(&mut derive_input);
54
55 let extra_derive = quote! { #ruma_macros::_FakeDeriveRumaApi };
56 let ruma_api_attribute = quote! {};
57 let response_impls =
58 expand_derive_response(derive_input).unwrap_or_else(syn::Error::into_compile_error);
59 } else {
60 let extra_derive = quote! { #ruma_macros::Response };
61 let ruma_api_attribute = quote! {
62 #[ruma_api(error = #error_ty, status = #status_ident)]
63 };
64 let response_impls = quote! {};
65 }
66 }
67
68 quote! {
69 #maybe_feature_error
70
71 #[derive(Clone, Debug, #ruma_common::serde::_FakeDeriveSerde, #extra_derive)]
72 #[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
73 #ruma_api_attribute
74 #item
75
76 #response_impls
77 }
78}
79
80pub struct ResponseAttr(Punctuated<DeriveResponseMeta, Token![,]>);
81
82impl Parse for ResponseAttr {
83 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84 Punctuated::<DeriveResponseMeta, Token![,]>::parse_terminated(input).map(Self)
85 }
86}
87
88pub fn expand_derive_response(input: ItemStruct) -> syn::Result<TokenStream> {
89 let fields =
90 input.fields.into_iter().map(ResponseField::try_from).collect::<syn::Result<_>>()?;
91 let mut manual_body_serde = false;
92 let mut error_ty = None;
93 let mut status_ident = None;
94 for attr in input.attrs {
95 if !attr.path().is_ident("ruma_api") {
96 continue;
97 }
98
99 let metas =
100 attr.parse_args_with(Punctuated::<DeriveResponseMeta, Token![,]>::parse_terminated)?;
101 for meta in metas {
102 match meta {
103 DeriveResponseMeta::ManualBodySerde => manual_body_serde = true,
104 DeriveResponseMeta::Error(t) => error_ty = Some(t),
105 DeriveResponseMeta::Status(t) => status_ident = Some(t),
106 }
107 }
108 }
109
110 let response = Response {
111 ident: input.ident,
112 generics: input.generics,
113 fields,
114 manual_body_serde,
115 error_ty: error_ty.expect("missing error_ty attribute"),
116 status_ident: status_ident.expect("missing status_ident attribute"),
117 };
118
119 response.check()?;
120 Ok(response.expand_all())
121}
122
123struct Response {
124 ident: Ident,
125 generics: Generics,
126 fields: Vec<ResponseField>,
127 manual_body_serde: bool,
128 error_ty: Type,
129 status_ident: Ident,
130}
131
132impl Response {
133 fn has_body_fields(&self) -> bool {
135 self.fields
136 .iter()
137 .any(|f| matches!(&f.kind, ResponseFieldKind::Body | &ResponseFieldKind::NewtypeBody))
138 }
139
140 fn has_newtype_body(&self) -> bool {
142 self.fields.iter().any(|f| matches!(&f.kind, ResponseFieldKind::NewtypeBody))
143 }
144
145 fn has_raw_body(&self) -> bool {
147 self.fields.iter().any(|f| matches!(&f.kind, ResponseFieldKind::RawBody))
148 }
149
150 fn has_header_fields(&self) -> bool {
152 self.fields.iter().any(|f| matches!(&f.kind, &ResponseFieldKind::Header(_)))
153 }
154
155 fn expand_all(&self) -> TokenStream {
156 let ruma_common = import_ruma_common();
157 let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
158 let serde = quote! { #ruma_common::exports::serde };
159
160 let response_body_struct = (!self.has_raw_body()).then(|| {
161 let serde_derives = self.manual_body_serde.not().then(|| {
162 quote! {
163 #[cfg_attr(feature = "client", derive(#serde::Deserialize))]
164 #[cfg_attr(feature = "server", derive(#serde::Serialize))]
165 }
166 });
167
168 let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
169 let fields =
170 self.fields.iter().filter_map(ResponseField::as_body_field).map(PrivateField);
171
172 quote! {
173 #[cfg(any(feature = "client", feature = "server"))]
175 #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
176 #serde_derives
177 #serde_attr
178 struct ResponseBody { #(#fields),* }
179 }
180 });
181
182 let outgoing_response_impl = self.expand_outgoing(&self.status_ident, &ruma_common);
183 let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_common);
184
185 quote! {
186 #response_body_struct
187
188 #outgoing_response_impl
189 #incoming_response_impl
190 }
191 }
192
193 pub fn check(&self) -> syn::Result<()> {
194 assert!(
197 self.generics.params.is_empty() && self.generics.where_clause.is_none(),
198 "This macro doesn't support generic types"
199 );
200
201 let newtype_body_fields = self.fields.iter().filter(|f| {
202 matches!(&f.kind, ResponseFieldKind::NewtypeBody | ResponseFieldKind::RawBody)
203 });
204
205 let has_newtype_body_field = match newtype_body_fields.count() {
206 0 => false,
207 1 => true,
208 _ => {
209 return Err(syn::Error::new_spanned(
210 &self.ident,
211 "Can't have more than one newtype body field",
212 ))
213 }
214 };
215
216 let mut body_fields =
217 self.fields.iter().filter(|f| matches!(f.kind, ResponseFieldKind::Body));
218 let first_body_field = body_fields.next();
219 let has_body_fields = first_body_field.is_some();
220
221 if has_newtype_body_field && has_body_fields {
222 return Err(syn::Error::new_spanned(
223 &self.ident,
224 "Can't have both a newtype body field and regular body fields",
225 ));
226 }
227
228 if let Some(first_body_field) = first_body_field {
229 let is_single_body_field = body_fields.next().is_none();
230
231 if is_single_body_field && field_has_serde_flatten_attribute(&first_body_field.inner) {
232 return Err(syn::Error::new_spanned(
233 first_body_field,
234 "Use `#[ruma_api(body)]` to represent the JSON body as a single field",
235 ));
236 }
237 }
238
239 Ok(())
240 }
241}
242
243struct ResponseField {
245 inner: Field,
246 kind: ResponseFieldKind,
247}
248
249enum ResponseFieldKind {
251 Body,
253
254 Header(Ident),
256
257 NewtypeBody,
259
260 RawBody,
262}
263
264impl ResponseField {
265 fn new(inner: Field, kind_attr: Option<ResponseMeta>) -> Self {
267 let kind = match kind_attr {
268 Some(ResponseMeta::NewtypeBody) => ResponseFieldKind::NewtypeBody,
269 Some(ResponseMeta::RawBody) => ResponseFieldKind::RawBody,
270 Some(ResponseMeta::Header(header)) => ResponseFieldKind::Header(header),
271 None => ResponseFieldKind::Body,
272 };
273
274 Self { inner, kind }
275 }
276
277 fn as_body_field(&self) -> Option<&Field> {
279 match &self.kind {
280 ResponseFieldKind::Body | ResponseFieldKind::NewtypeBody => Some(&self.inner),
281 _ => None,
282 }
283 }
284
285 fn as_raw_body_field(&self) -> Option<&Field> {
287 match &self.kind {
288 ResponseFieldKind::RawBody => Some(&self.inner),
289 _ => None,
290 }
291 }
292
293 fn as_header_field(&self) -> Option<(&Field, &Ident)> {
295 match &self.kind {
296 ResponseFieldKind::Header(ident) => Some((&self.inner, ident)),
297 _ => None,
298 }
299 }
300}
301
302impl TryFrom<Field> for ResponseField {
303 type Error = syn::Error;
304
305 fn try_from(mut field: Field) -> syn::Result<Self> {
306 if has_lifetime(&field.ty) {
307 return Err(syn::Error::new_spanned(
308 field.ident,
309 "Lifetimes on Response fields cannot be supported until GAT are stable",
310 ));
311 }
312
313 let (mut api_attrs, attrs) =
314 field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path().is_ident("ruma_api"));
315 field.attrs = attrs;
316
317 let kind_attr = match api_attrs.as_slice() {
318 [] => None,
319 [_] => Some(api_attrs.pop().unwrap().parse_args::<ResponseMeta>()?),
320 _ => {
321 return Err(syn::Error::new_spanned(
322 &api_attrs[1],
323 "multiple field kind attribute found, there can only be one",
324 ));
325 }
326 };
327
328 Ok(ResponseField::new(field, kind_attr))
329 }
330}
331
332impl Parse for ResponseField {
333 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
334 input.call(Field::parse_named)?.try_into()
335 }
336}
337
338impl ToTokens for ResponseField {
339 fn to_tokens(&self, tokens: &mut TokenStream) {
340 self.inner.to_tokens(tokens);
341 }
342}
343
344fn has_lifetime(ty: &Type) -> bool {
345 struct Visitor {
346 found_lifetime: bool,
347 }
348
349 impl<'ast> Visit<'ast> for Visitor {
350 fn visit_lifetime(&mut self, _lt: &'ast Lifetime) {
351 self.found_lifetime = true;
352 }
353 }
354
355 let mut vis = Visitor { found_lifetime: false };
356 vis.visit_type(ty);
357 vis.found_lifetime
358}