ruma_macros/api/request/
outgoing.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Field;
4
5use super::{Request, RequestField};
6
7impl Request {
8    pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream {
9        let bytes = quote! { #ruma_common::exports::bytes };
10        let http = quote! { #ruma_common::exports::http };
11        let serde_html_form = quote! { #ruma_common::exports::serde_html_form };
12
13        let error_ty = &self.error_ty;
14
15        let path_fields =
16            self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name"));
17
18        let request_query_string = if let Some(field) = self.query_all_field() {
19            let field_name = field.ident.as_ref().expect("expected field to have identifier");
20
21            quote! {{
22                let request_query = RequestQuery(self.#field_name);
23
24                &#serde_html_form::to_string(request_query)?
25            }}
26        } else if self.has_query_fields() {
27            let request_query_init_fields = struct_init_fields(
28                self.fields.iter().filter_map(RequestField::as_query_field),
29                quote! { self },
30            );
31
32            quote! {{
33                let request_query = RequestQuery {
34                    #request_query_init_fields
35                };
36
37                &#serde_html_form::to_string(request_query)?
38            }}
39        } else {
40            quote! { "" }
41        };
42
43        // If there are no body fields, the request body will be empty (not `{}`), so the
44        // `application/json` content-type would be wrong. It may also cause problems with CORS
45        // policies that don't allow the `Content-Type` header (for things such as `.well-known`
46        // that are commonly handled by something else than a homeserver).
47        let mut header_kvs = if self.raw_body_field().is_some() || self.has_body_fields() {
48            quote! {
49                req_headers.insert(
50                    #http::header::CONTENT_TYPE,
51                    #http::header::HeaderValue::from_static("application/json"),
52                );
53            }
54        } else {
55            TokenStream::new()
56        };
57
58        header_kvs.extend(self.header_fields().map(|(field, header_name)| {
59            let field_name = &field.ident;
60
61            match &field.ty {
62                syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
63                    if segments.last().unwrap().ident == "Option" =>
64                {
65                    quote! {
66                        if let Some(header_val) = self.#field_name.as_ref() {
67                            req_headers.insert(
68                                #header_name,
69                                #http::header::HeaderValue::from_str(&header_val.to_string())?,
70                            );
71                        }
72                    }
73                }
74                _ => quote! {
75                    req_headers.insert(
76                        #header_name,
77                        #http::header::HeaderValue::from_str(&self.#field_name.to_string())?,
78                    );
79                },
80            }
81        }));
82
83        header_kvs.extend(quote! {
84            req_headers.extend(METADATA.authorization_header(access_token)?);
85        });
86
87        let request_body = if let Some(field) = self.raw_body_field() {
88            let field_name = field.ident.as_ref().expect("expected field to have an identifier");
89            quote! { #ruma_common::serde::slice_to_buf(&self.#field_name) }
90        } else if self.has_body_fields() {
91            let initializers = struct_init_fields(self.body_fields(), quote! { self });
92
93            quote! {
94                #ruma_common::serde::json_to_buf(&RequestBody { #initializers })?
95            }
96        } else {
97            quote! { METADATA.empty_request_body::<T>() }
98        };
99
100        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
101
102        quote! {
103            #[automatically_derived]
104            #[cfg(feature = "client")]
105            impl #impl_generics #ruma_common::api::OutgoingRequest for Request #ty_generics #where_clause {
106                type EndpointError = #error_ty;
107                type IncomingResponse = Response;
108
109                const METADATA: #ruma_common::api::Metadata = METADATA;
110
111                fn try_into_http_request<T: ::std::default::Default + #bytes::BufMut>(
112                    self,
113                    base_url: &::std::primitive::str,
114                    access_token: #ruma_common::api::SendAccessToken<'_>,
115                    considering_versions: &'_ [#ruma_common::api::MatrixVersion],
116                ) -> ::std::result::Result<#http::Request<T>, #ruma_common::api::error::IntoHttpError> {
117                    let mut req_builder = #http::Request::builder()
118                        .method(METADATA.method)
119                        .uri(METADATA.make_endpoint_url(
120                            considering_versions,
121                            base_url,
122                            &[ #( &self.#path_fields ),* ],
123                            #request_query_string,
124                        )?);
125
126                    if let Some(mut req_headers) = req_builder.headers_mut() {
127                        #header_kvs
128                    }
129
130                    let http_request = req_builder.body(#request_body)?;
131
132                    Ok(http_request)
133                }
134            }
135        }
136    }
137}
138
139/// Produces code for a struct initializer for the given field kind to be accessed through the
140/// given variable name.
141fn struct_init_fields<'a>(
142    fields: impl IntoIterator<Item = &'a Field>,
143    src: TokenStream,
144) -> TokenStream {
145    fields
146        .into_iter()
147        .map(|field| {
148            let field_name = field.ident.as_ref().expect("expected field to have an identifier");
149            let cfg_attrs =
150                field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
151
152            quote! {
153                #( #cfg_attrs )*
154                #field_name: #src.#field_name,
155            }
156        })
157        .collect()
158}