1use cfg_if::cfg_if;
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{
5 parse::{Parse, ParseStream},
6 punctuated::Punctuated,
7 Field, Generics, Ident, ItemStruct, Token, Type,
8};
9
10use super::{
11 attribute::{DeriveRequestMeta, RequestMeta},
12 ensure_feature_presence,
13};
14use crate::util::{field_has_serde_flatten_attribute, import_ruma_common, PrivateField};
15
16mod incoming;
17mod outgoing;
18
19pub fn expand_request(attr: RequestAttr, item: ItemStruct) -> TokenStream {
20 let ruma_common = import_ruma_common();
21 let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
22
23 let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
24
25 let error_ty = attr.0.first().map_or_else(
26 || quote! { #ruma_common::api::error::MatrixError },
27 |DeriveRequestMeta::Error(ty)| quote! { #ty },
28 );
29
30 cfg_if! {
31 if #[cfg(feature = "__internal_macro_expand")] {
32 use syn::parse_quote;
33
34 let mut derive_input = item.clone();
35 derive_input.attrs.push(parse_quote! { #[ruma_api(error = #error_ty)] });
36 crate::util::cfg_expand_struct(&mut derive_input);
37
38 let extra_derive = quote! { #ruma_macros::_FakeDeriveRumaApi };
39 let ruma_api_attribute = quote! {};
40 let request_impls =
41 expand_derive_request(derive_input).unwrap_or_else(syn::Error::into_compile_error);
42 } else {
43 let extra_derive = quote! { #ruma_macros::Request };
44 let ruma_api_attribute = quote! { #[ruma_api(error = #error_ty)] };
45 let request_impls = quote! {};
46 }
47 }
48
49 quote! {
50 #maybe_feature_error
51
52 #[derive(Clone, Debug, #ruma_common::serde::_FakeDeriveSerde, #extra_derive)]
53 #[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
54 #ruma_api_attribute
55 #item
56
57 #request_impls
58 }
59}
60
61pub struct RequestAttr(Punctuated<DeriveRequestMeta, Token![,]>);
62
63impl Parse for RequestAttr {
64 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
65 Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated(input).map(Self)
66 }
67}
68
69pub fn expand_derive_request(input: ItemStruct) -> syn::Result<TokenStream> {
70 let fields =
71 input.fields.into_iter().map(RequestField::try_from).collect::<syn::Result<_>>()?;
72
73 let mut error_ty = None;
74
75 for attr in input.attrs {
76 if !attr.path().is_ident("ruma_api") {
77 continue;
78 }
79
80 let metas =
81 attr.parse_args_with(Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated)?;
82 for meta in metas {
83 match meta {
84 DeriveRequestMeta::Error(t) => error_ty = Some(t),
85 }
86 }
87 }
88
89 let request = Request {
90 ident: input.ident,
91 generics: input.generics,
92 fields,
93 error_ty: error_ty.expect("missing error_ty attribute"),
94 };
95
96 let ruma_common = import_ruma_common();
97 let test = request.check(&ruma_common)?;
98 let types_impls = request.expand_all(&ruma_common);
99
100 Ok(quote! {
101 #types_impls
102
103 #[allow(deprecated)]
104 #[cfg(test)]
105 mod __request {
106 #test
107 }
108 })
109}
110
111struct Request {
112 ident: Ident,
113 generics: Generics,
114 fields: Vec<RequestField>,
115
116 error_ty: Type,
117}
118
119impl Request {
120 fn body_fields(&self) -> impl Iterator<Item = &Field> {
121 self.fields.iter().filter_map(RequestField::as_body_field)
122 }
123
124 fn has_body_fields(&self) -> bool {
125 self.fields
126 .iter()
127 .any(|f| matches!(&f.kind, RequestFieldKind::Body | RequestFieldKind::NewtypeBody))
128 }
129
130 fn has_newtype_body(&self) -> bool {
131 self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::NewtypeBody))
132 }
133
134 fn has_header_fields(&self) -> bool {
135 self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Header(_)))
136 }
137
138 fn has_path_fields(&self) -> bool {
139 self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Path))
140 }
141
142 fn has_query_fields(&self) -> bool {
143 self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Query))
144 }
145
146 fn header_fields(&self) -> impl Iterator<Item = (&Field, &Ident)> {
147 self.fields.iter().filter_map(RequestField::as_header_field)
148 }
149
150 fn path_fields(&self) -> impl Iterator<Item = &Field> {
151 self.fields.iter().filter_map(RequestField::as_path_field)
152 }
153
154 fn raw_body_field(&self) -> Option<&Field> {
155 self.fields.iter().find_map(RequestField::as_raw_body_field)
156 }
157
158 fn query_all_field(&self) -> Option<&Field> {
159 self.fields.iter().find_map(RequestField::as_query_all_field)
160 }
161
162 fn expand_all(&self, ruma_common: &TokenStream) -> TokenStream {
163 let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
164 let serde = quote! { #ruma_common::exports::serde };
165
166 let request_body_struct = self.has_body_fields().then(|| {
167 let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
168 let fields =
169 self.fields.iter().filter_map(RequestField::as_body_field).map(PrivateField);
170
171 quote! {
172 #[cfg(any(feature = "client", feature = "server"))]
174 #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
175 #[cfg_attr(feature = "client", derive(#serde::Serialize))]
176 #[cfg_attr(feature = "server", derive(#serde::Deserialize))]
177 #serde_attr
178 struct RequestBody { #(#fields),* }
179 }
180 });
181
182 let request_query_def = if let Some(f) = self.query_all_field() {
183 let field = Field { ident: None, colon_token: None, ..f.clone() };
184 let field = PrivateField(&field);
185 Some(quote! { (#field); })
186 } else if self.has_query_fields() {
187 let fields =
188 self.fields.iter().filter_map(RequestField::as_query_field).map(PrivateField);
189 Some(quote! { { #(#fields),* } })
190 } else {
191 None
192 };
193
194 let request_query_struct = request_query_def.map(|def| {
195 quote! {
196 #[cfg(any(feature = "client", feature = "server"))]
198 #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
199 #[cfg_attr(feature = "client", derive(#serde::Serialize))]
200 #[cfg_attr(feature = "server", derive(#serde::Deserialize))]
201 struct RequestQuery #def
202 }
203 });
204
205 let outgoing_request_impl = self.expand_outgoing(ruma_common);
206 let incoming_request_impl = self.expand_incoming(ruma_common);
207
208 quote! {
209 #request_body_struct
210 #request_query_struct
211
212 #[allow(deprecated)]
213 mod __request_impls {
214 use super::*;
215 #outgoing_request_impl
216 #incoming_request_impl
217 }
218 }
219 }
220
221 pub(super) fn check(&self, ruma_common: &TokenStream) -> syn::Result<TokenStream> {
222 let http = quote! { #ruma_common::exports::http };
223
224 let newtype_body_fields = self.fields.iter().filter(|f| {
227 matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody)
228 });
229
230 let has_newtype_body_field = match newtype_body_fields.count() {
231 0 => false,
232 1 => true,
233 _ => {
234 return Err(syn::Error::new_spanned(
235 &self.ident,
236 "Can't have more than one newtype body field",
237 ))
238 }
239 };
240
241 let query_all_fields =
242 self.fields.iter().filter(|f| matches!(&f.kind, RequestFieldKind::QueryAll));
243 let has_query_all_field = match query_all_fields.count() {
244 0 => false,
245 1 => true,
246 _ => {
247 return Err(syn::Error::new_spanned(
248 &self.ident,
249 "Can't have more than one query_all field",
250 ))
251 }
252 };
253
254 let mut body_fields =
255 self.fields.iter().filter(|f| matches!(f.kind, RequestFieldKind::Body));
256 let first_body_field = body_fields.next();
257 let has_body_fields = first_body_field.is_some();
258
259 if has_newtype_body_field && has_body_fields {
260 return Err(syn::Error::new_spanned(
261 &self.ident,
262 "Can't have both a newtype body field and regular body fields",
263 ));
264 }
265
266 if let Some(first_body_field) = first_body_field {
267 let is_single_body_field = body_fields.next().is_none();
268
269 if is_single_body_field && field_has_serde_flatten_attribute(&first_body_field.inner) {
270 return Err(syn::Error::new_spanned(
271 first_body_field,
272 "Use `#[ruma_api(body)]` to represent the JSON body as a single field",
273 ));
274 }
275 }
276
277 let has_query_fields = self.has_query_fields();
278 if has_query_all_field && has_query_fields {
279 return Err(syn::Error::new_spanned(
280 &self.ident,
281 "Can't have both a query_all field and regular query fields",
282 ));
283 }
284
285 let path_fields = self.path_fields().map(|f| f.ident.as_ref().unwrap().to_string());
286 let mut tests = quote! {
287 #[::std::prelude::v1::test]
288 fn path_parameters() {
289 let path_params = super::METADATA._path_parameters();
290 let request_path_fields: &[&::std::primitive::str] = &[#(#path_fields),*];
291 ::std::assert_eq!(
292 path_params, request_path_fields,
293 "Path parameters must match the `Request`'s `#[ruma_api(path)]` fields"
294 );
295 }
296 };
297
298 if has_body_fields || has_newtype_body_field {
299 tests.extend(quote! {
300 #[::std::prelude::v1::test]
301 fn request_is_not_get() {
302 ::std::assert_ne!(
303 super::METADATA.method, #http::Method::GET,
304 "GET endpoints can't have body fields",
305 );
306 }
307 });
308 }
309
310 Ok(tests)
311 }
312}
313
314pub(super) struct RequestField {
316 pub(super) inner: Field,
317 pub(super) kind: RequestFieldKind,
318}
319
320pub(super) enum RequestFieldKind {
322 Body,
324
325 Header(Ident),
327
328 NewtypeBody,
330
331 RawBody,
333
334 Path,
336
337 Query,
339
340 QueryAll,
342}
343
344impl RequestField {
345 fn new(inner: Field, kind_attr: Option<RequestMeta>) -> Self {
347 let kind = match kind_attr {
348 Some(RequestMeta::NewtypeBody) => RequestFieldKind::NewtypeBody,
349 Some(RequestMeta::RawBody) => RequestFieldKind::RawBody,
350 Some(RequestMeta::Path) => RequestFieldKind::Path,
351 Some(RequestMeta::Query) => RequestFieldKind::Query,
352 Some(RequestMeta::QueryAll) => RequestFieldKind::QueryAll,
353 Some(RequestMeta::Header(header)) => RequestFieldKind::Header(header),
354 None => RequestFieldKind::Body,
355 };
356
357 Self { inner, kind }
358 }
359
360 pub fn as_body_field(&self) -> Option<&Field> {
362 match &self.kind {
363 RequestFieldKind::Body | RequestFieldKind::NewtypeBody => Some(&self.inner),
364 _ => None,
365 }
366 }
367
368 pub fn as_raw_body_field(&self) -> Option<&Field> {
370 match &self.kind {
371 RequestFieldKind::RawBody => Some(&self.inner),
372 _ => None,
373 }
374 }
375
376 pub fn as_path_field(&self) -> Option<&Field> {
378 match &self.kind {
379 RequestFieldKind::Path => Some(&self.inner),
380 _ => None,
381 }
382 }
383
384 pub fn as_query_field(&self) -> Option<&Field> {
386 match &self.kind {
387 RequestFieldKind::Query => Some(&self.inner),
388 _ => None,
389 }
390 }
391
392 pub fn as_query_all_field(&self) -> Option<&Field> {
394 match &self.kind {
395 RequestFieldKind::QueryAll => Some(&self.inner),
396 _ => None,
397 }
398 }
399
400 pub fn as_header_field(&self) -> Option<(&Field, &Ident)> {
402 match &self.kind {
403 RequestFieldKind::Header(header_name) => Some((&self.inner, header_name)),
404 _ => None,
405 }
406 }
407}
408
409impl TryFrom<Field> for RequestField {
410 type Error = syn::Error;
411
412 fn try_from(mut field: Field) -> syn::Result<Self> {
413 let (mut api_attrs, attrs) =
414 field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path().is_ident("ruma_api"));
415 field.attrs = attrs;
416
417 let kind_attr = match api_attrs.as_slice() {
418 [] => None,
419 [_] => Some(api_attrs.pop().unwrap().parse_args::<RequestMeta>()?),
420 _ => {
421 return Err(syn::Error::new_spanned(
422 &api_attrs[1],
423 "multiple field kind attribute found, there can only be one",
424 ));
425 }
426 };
427
428 Ok(RequestField::new(field, kind_attr))
429 }
430}
431
432impl Parse for RequestField {
433 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
434 input.call(Field::parse_named)?.try_into()
435 }
436}
437
438impl ToTokens for RequestField {
439 fn to_tokens(&self, tokens: &mut TokenStream) {
440 self.inner.to_tokens(tokens);
441 }
442}