// ruma_macros/api/request/incoming.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Field;
4
5use super::{Request, RequestField};
6
7impl Request {
8    pub fn expand_incoming(&self, ruma_common: &TokenStream) -> TokenStream {
9        let http = quote! { #ruma_common::exports::http };
10        let serde = quote! { #ruma_common::exports::serde };
11        let serde_html_form = quote! { #ruma_common::exports::serde_html_form };
12        let serde_json = quote! { #ruma_common::exports::serde_json };
13
14        let error_ty = &self.error_ty;
15
16        // FIXME: the rest of the field initializer expansions are gated `cfg(...)` except this one.
17        // If we get errors about missing fields in Request for a path field look here.
18        let (parse_request_path, path_vars) = if self.has_path_fields() {
19            let path_vars: Vec<_> = self.path_fields().filter_map(|f| f.ident.as_ref()).collect();
20
21            let parse_request_path = quote! {
22                let (#(#path_vars,)*) = #serde::Deserialize::deserialize(
23                    #serde::de::value::SeqDeserializer::<_, #serde::de::value::Error>::new(
24                        path_args.iter().map(::std::convert::AsRef::as_ref)
25                    )
26                )?;
27            };
28
29            (parse_request_path, quote! { #(#path_vars,)* })
30        } else {
31            (TokenStream::new(), TokenStream::new())
32        };
33
34        let (parse_query, query_vars) = if let Some(field) = self.query_all_field() {
35            let cfg_attrs =
36                field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
37            let field_name = field.ident.as_ref().expect("expected field to have an identifier");
38            let parse = quote! {
39                #( #cfg_attrs )*
40                let #field_name =
41                    #serde_html_form::from_str(&request.uri().query().unwrap_or(""))?;
42            };
43
44            (
45                parse,
46                quote! {
47                    #( #cfg_attrs )*
48                    #field_name,
49                },
50            )
51        } else if self.has_query_fields() {
52            let (decls, names) = vars(
53                self.fields.iter().filter_map(RequestField::as_query_field),
54                quote! { request_query },
55            );
56
57            let parse = quote! {
58                let request_query: RequestQuery =
59                    #serde_html_form::from_str(&request.uri().query().unwrap_or(""))?;
60
61                #decls
62            };
63
64            (parse, names)
65        } else {
66            (TokenStream::new(), TokenStream::new())
67        };
68
69        let (parse_headers, header_vars) = if self.has_header_fields() {
70            let (decls, names): (TokenStream, Vec<_>) = self
71                .header_fields()
72                .map(|(field, header_name)| {
73                    let cfg_attrs =
74                        field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
75
76                    let field_name = &field.ident;
77                    let header_name_string = header_name.to_string();
78
79                    let (some_case, none_case) = match &field.ty {
80                        syn::Type::Path(syn::TypePath {
81                            path: syn::Path { segments, .. }, ..
82                        }) if segments.last().unwrap().ident == "Option" => {
83                            let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
84                                args: option_args, ..
85                            }) = &segments.last().unwrap().arguments else {
86                                panic!("Option should use angle brackets");
87                            };
88                            let syn::GenericArgument::Type(field_type) = option_args.first().unwrap() else {
89                                panic!("Option brackets should contain type");
90                            };
91                            (
92                                quote! {
93                                    str_value.parse::<#field_type>().ok()
94                                },
95                                quote! { None }
96                            )
97                        }
98                        _ => {
99                            let field_type = &field.ty;
100                            (
101                                quote! {
102                                    str_value
103                                        .parse::<#field_type>()
104                                        .map_err(|e| #ruma_common::api::error::HeaderDeserializationError::InvalidHeader(e.into()))?
105                                },
106                                quote! {
107                                    return Err(
108                                        #ruma_common::api::error::HeaderDeserializationError::MissingHeader(
109                                            #header_name_string.into()
110                                        ).into(),
111                                    )
112                                },
113                            )
114                        }
115                    };
116
117                    let decl = quote! {
118                        #( #cfg_attrs )*
119                        let #field_name = match headers.get(#header_name) {
120                            Some(header_value) => {
121                                let str_value = header_value.to_str()?;
122                                #some_case
123                            }
124                            None => #none_case,
125                        };
126                    };
127
128                    (
129                        decl,
130                        quote! {
131                            #( #cfg_attrs )*
132                            #field_name
133                        },
134                    )
135                })
136                .unzip();
137
138            let parse = quote! {
139                let headers = request.headers();
140
141                #decls
142            };
143
144            (parse, quote! { #(#names,)* })
145        } else {
146            (TokenStream::new(), TokenStream::new())
147        };
148
149        let extract_body = self.has_body_fields().then(|| {
150            quote! {
151                let request_body: RequestBody = {
152                    let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
153                        request.body(),
154                    );
155
156                    #serde_json::from_slice(match body {
157                        // If the request body is completely empty, pretend it is an empty JSON
158                        // object instead. This allows requests with only optional body parameters
159                        // to be deserialized in that case.
160                        [] => b"{}",
161                        b => b,
162                    })?
163                };
164            }
165        });
166
167        let (parse_body, body_vars) = if let Some(field) = self.raw_body_field() {
168            let field_name = field.ident.as_ref().expect("expected field to have an identifier");
169            let parse = quote! {
170                let #field_name =
171                    ::std::convert::AsRef::<[u8]>::as_ref(request.body()).to_vec();
172            };
173
174            (parse, quote! { #field_name, })
175        } else {
176            vars(self.body_fields(), quote! { request_body })
177        };
178
179        quote! {
180            #[automatically_derived]
181            #[cfg(feature = "server")]
182            impl #ruma_common::api::IncomingRequest for Request {
183                type EndpointError = #error_ty;
184                type OutgoingResponse = Response;
185
186                const METADATA: #ruma_common::api::Metadata = METADATA;
187
188                fn try_from_http_request<B, S>(
189                    request: #http::Request<B>,
190                    path_args: &[S],
191                ) -> ::std::result::Result<Self, #ruma_common::api::error::FromHttpRequestError>
192                where
193                    B: ::std::convert::AsRef<[::std::primitive::u8]>,
194                    S: ::std::convert::AsRef<::std::primitive::str>,
195                {
196                    if !(request.method() == METADATA.method
197                        || request.method() == #http::Method::HEAD
198                            && METADATA.method == #http::Method::GET)
199                    {
200                        return Err(#ruma_common::api::error::FromHttpRequestError::MethodMismatch {
201                            expected: METADATA.method,
202                            received: request.method().clone(),
203                        });
204                    }
205
206                    #parse_request_path
207                    #parse_query
208                    #parse_headers
209
210                    #extract_body
211                    #parse_body
212
213                    ::std::result::Result::Ok(Self {
214                        #path_vars
215                        #query_vars
216                        #header_vars
217                        #body_vars
218                    })
219                }
220            }
221        }
222    }
223}
224
225fn vars<'a>(
226    fields: impl IntoIterator<Item = &'a Field>,
227    src: TokenStream,
228) -> (TokenStream, TokenStream) {
229    fields
230        .into_iter()
231        .map(|field| {
232            let field_name = field.ident.as_ref().expect("expected field to have an identifier");
233            let cfg_attrs =
234                field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
235
236            let decl = quote! {
237                #( #cfg_attrs )*
238                let #field_name = #src.#field_name;
239            };
240
241            (
242                decl,
243                quote! {
244                    #( #cfg_attrs )*
245                    #field_name,
246                },
247            )
248        })
249        .unzip()
250}