ruma_macros/api/request/
incoming.rs
1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Field;
4
5use super::{Request, RequestField};
6
7impl Request {
8 pub fn expand_incoming(&self, ruma_common: &TokenStream) -> TokenStream {
9 let http = quote! { #ruma_common::exports::http };
10 let serde = quote! { #ruma_common::exports::serde };
11 let serde_html_form = quote! { #ruma_common::exports::serde_html_form };
12 let serde_json = quote! { #ruma_common::exports::serde_json };
13
14 let error_ty = &self.error_ty;
15
16 let (parse_request_path, path_vars) = if self.has_path_fields() {
19 let path_vars: Vec<_> = self.path_fields().filter_map(|f| f.ident.as_ref()).collect();
20
21 let parse_request_path = quote! {
22 let (#(#path_vars,)*) = #serde::Deserialize::deserialize(
23 #serde::de::value::SeqDeserializer::<_, #serde::de::value::Error>::new(
24 path_args.iter().map(::std::convert::AsRef::as_ref)
25 )
26 )?;
27 };
28
29 (parse_request_path, quote! { #(#path_vars,)* })
30 } else {
31 (TokenStream::new(), TokenStream::new())
32 };
33
34 let (parse_query, query_vars) = if let Some(field) = self.query_all_field() {
35 let cfg_attrs =
36 field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
37 let field_name = field.ident.as_ref().expect("expected field to have an identifier");
38 let parse = quote! {
39 #( #cfg_attrs )*
40 let #field_name =
41 #serde_html_form::from_str(&request.uri().query().unwrap_or(""))?;
42 };
43
44 (
45 parse,
46 quote! {
47 #( #cfg_attrs )*
48 #field_name,
49 },
50 )
51 } else if self.has_query_fields() {
52 let (decls, names) = vars(
53 self.fields.iter().filter_map(RequestField::as_query_field),
54 quote! { request_query },
55 );
56
57 let parse = quote! {
58 let request_query: RequestQuery =
59 #serde_html_form::from_str(&request.uri().query().unwrap_or(""))?;
60
61 #decls
62 };
63
64 (parse, names)
65 } else {
66 (TokenStream::new(), TokenStream::new())
67 };
68
69 let (parse_headers, header_vars) = if self.has_header_fields() {
70 let (decls, names): (TokenStream, Vec<_>) = self
71 .header_fields()
72 .map(|(field, header_name)| {
73 let cfg_attrs =
74 field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
75
76 let field_name = &field.ident;
77 let header_name_string = header_name.to_string();
78
79 let (some_case, none_case) = match &field.ty {
80 syn::Type::Path(syn::TypePath {
81 path: syn::Path { segments, .. }, ..
82 }) if segments.last().unwrap().ident == "Option" => {
83 let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
84 args: option_args, ..
85 }) = &segments.last().unwrap().arguments else {
86 panic!("Option should use angle brackets");
87 };
88 let syn::GenericArgument::Type(field_type) = option_args.first().unwrap() else {
89 panic!("Option brackets should contain type");
90 };
91 (
92 quote! {
93 str_value.parse::<#field_type>().ok()
94 },
95 quote! { None }
96 )
97 }
98 _ => {
99 let field_type = &field.ty;
100 (
101 quote! {
102 str_value
103 .parse::<#field_type>()
104 .map_err(|e| #ruma_common::api::error::HeaderDeserializationError::InvalidHeader(e.into()))?
105 },
106 quote! {
107 return Err(
108 #ruma_common::api::error::HeaderDeserializationError::MissingHeader(
109 #header_name_string.into()
110 ).into(),
111 )
112 },
113 )
114 }
115 };
116
117 let decl = quote! {
118 #( #cfg_attrs )*
119 let #field_name = match headers.get(#header_name) {
120 Some(header_value) => {
121 let str_value = header_value.to_str()?;
122 #some_case
123 }
124 None => #none_case,
125 };
126 };
127
128 (
129 decl,
130 quote! {
131 #( #cfg_attrs )*
132 #field_name
133 },
134 )
135 })
136 .unzip();
137
138 let parse = quote! {
139 let headers = request.headers();
140
141 #decls
142 };
143
144 (parse, quote! { #(#names,)* })
145 } else {
146 (TokenStream::new(), TokenStream::new())
147 };
148
149 let extract_body = self.has_body_fields().then(|| {
150 quote! {
151 let request_body: RequestBody = {
152 let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
153 request.body(),
154 );
155
156 #serde_json::from_slice(match body {
157 [] => b"{}",
161 b => b,
162 })?
163 };
164 }
165 });
166
167 let (parse_body, body_vars) = if let Some(field) = self.raw_body_field() {
168 let field_name = field.ident.as_ref().expect("expected field to have an identifier");
169 let parse = quote! {
170 let #field_name =
171 ::std::convert::AsRef::<[u8]>::as_ref(request.body()).to_vec();
172 };
173
174 (parse, quote! { #field_name, })
175 } else {
176 vars(self.body_fields(), quote! { request_body })
177 };
178
179 quote! {
180 #[automatically_derived]
181 #[cfg(feature = "server")]
182 impl #ruma_common::api::IncomingRequest for Request {
183 type EndpointError = #error_ty;
184 type OutgoingResponse = Response;
185
186 const METADATA: #ruma_common::api::Metadata = METADATA;
187
188 fn try_from_http_request<B, S>(
189 request: #http::Request<B>,
190 path_args: &[S],
191 ) -> ::std::result::Result<Self, #ruma_common::api::error::FromHttpRequestError>
192 where
193 B: ::std::convert::AsRef<[::std::primitive::u8]>,
194 S: ::std::convert::AsRef<::std::primitive::str>,
195 {
196 if !(request.method() == METADATA.method
197 || request.method() == #http::Method::HEAD
198 && METADATA.method == #http::Method::GET)
199 {
200 return Err(#ruma_common::api::error::FromHttpRequestError::MethodMismatch {
201 expected: METADATA.method,
202 received: request.method().clone(),
203 });
204 }
205
206 #parse_request_path
207 #parse_query
208 #parse_headers
209
210 #extract_body
211 #parse_body
212
213 ::std::result::Result::Ok(Self {
214 #path_vars
215 #query_vars
216 #header_vars
217 #body_vars
218 })
219 }
220 }
221 }
222 }
223}
224
225fn vars<'a>(
226 fields: impl IntoIterator<Item = &'a Field>,
227 src: TokenStream,
228) -> (TokenStream, TokenStream) {
229 fields
230 .into_iter()
231 .map(|field| {
232 let field_name = field.ident.as_ref().expect("expected field to have an identifier");
233 let cfg_attrs =
234 field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
235
236 let decl = quote! {
237 #( #cfg_attrs )*
238 let #field_name = #src.#field_name;
239 };
240
241 (
242 decl,
243 quote! {
244 #( #cfg_attrs )*
245 #field_name,
246 },
247 )
248 })
249 .unzip()
250}