ruma_macros/api/request/
outgoing.rs1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Field;
4
5use super::{Request, RequestField};
6
7impl Request {
8 pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream {
9 let bytes = quote! { #ruma_common::exports::bytes };
10 let http = quote! { #ruma_common::exports::http };
11 let serde_html_form = quote! { #ruma_common::exports::serde_html_form };
12
13 let error_ty = &self.error_ty;
14
15 let path_fields =
16 self.path_fields().map(|f| f.ident.as_ref().expect("path fields have a name"));
17
18 let request_query_string = if let Some(field) = self.query_all_field() {
19 let field_name = field.ident.as_ref().expect("expected field to have identifier");
20
21 quote! {{
22 let request_query = RequestQuery(self.#field_name);
23
24 &#serde_html_form::to_string(request_query)?
25 }}
26 } else if self.has_query_fields() {
27 let request_query_init_fields = struct_init_fields(
28 self.fields.iter().filter_map(RequestField::as_query_field),
29 quote! { self },
30 );
31
32 quote! {{
33 let request_query = RequestQuery {
34 #request_query_init_fields
35 };
36
37 &#serde_html_form::to_string(request_query)?
38 }}
39 } else {
40 quote! { "" }
41 };
42
43 let mut header_kvs = if self.raw_body_field().is_some() || self.has_body_fields() {
48 quote! {
49 req_headers.insert(
50 #http::header::CONTENT_TYPE,
51 #http::header::HeaderValue::from_static("application/json"),
52 );
53 }
54 } else {
55 TokenStream::new()
56 };
57
58 header_kvs.extend(self.header_fields().map(|(field, header_name)| {
59 let field_name = &field.ident;
60
61 match &field.ty {
62 syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
63 if segments.last().unwrap().ident == "Option" =>
64 {
65 quote! {
66 if let Some(header_val) = self.#field_name.as_ref() {
67 req_headers.insert(
68 #header_name,
69 #http::header::HeaderValue::from_str(&header_val.to_string())?,
70 );
71 }
72 }
73 }
74 _ => quote! {
75 req_headers.insert(
76 #header_name,
77 #http::header::HeaderValue::from_str(&self.#field_name.to_string())?,
78 );
79 },
80 }
81 }));
82
83 header_kvs.extend(quote! {
84 req_headers.extend(METADATA.authorization_header(access_token)?);
85 });
86
87 let request_body = if let Some(field) = self.raw_body_field() {
88 let field_name = field.ident.as_ref().expect("expected field to have an identifier");
89 quote! { #ruma_common::serde::slice_to_buf(&self.#field_name) }
90 } else if self.has_body_fields() {
91 let initializers = struct_init_fields(self.body_fields(), quote! { self });
92
93 quote! {
94 #ruma_common::serde::json_to_buf(&RequestBody { #initializers })?
95 }
96 } else {
97 quote! { METADATA.empty_request_body::<T>() }
98 };
99
100 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
101
102 quote! {
103 #[automatically_derived]
104 #[cfg(feature = "client")]
105 impl #impl_generics #ruma_common::api::OutgoingRequest for Request #ty_generics #where_clause {
106 type EndpointError = #error_ty;
107 type IncomingResponse = Response;
108
109 const METADATA: #ruma_common::api::Metadata = METADATA;
110
111 fn try_into_http_request<T: ::std::default::Default + #bytes::BufMut>(
112 self,
113 base_url: &::std::primitive::str,
114 access_token: #ruma_common::api::SendAccessToken<'_>,
115 considering_versions: &'_ [#ruma_common::api::MatrixVersion],
116 ) -> ::std::result::Result<#http::Request<T>, #ruma_common::api::error::IntoHttpError> {
117 let mut req_builder = #http::Request::builder()
118 .method(METADATA.method)
119 .uri(METADATA.make_endpoint_url(
120 considering_versions,
121 base_url,
122 &[ #( &self.#path_fields ),* ],
123 #request_query_string,
124 )?);
125
126 if let Some(mut req_headers) = req_builder.headers_mut() {
127 #header_kvs
128 }
129
130 let http_request = req_builder.body(#request_body)?;
131
132 Ok(http_request)
133 }
134 }
135 }
136 }
137}
138
139fn struct_init_fields<'a>(
142 fields: impl IntoIterator<Item = &'a Field>,
143 src: TokenStream,
144) -> TokenStream {
145 fields
146 .into_iter()
147 .map(|field| {
148 let field_name = field.ident.as_ref().expect("expected field to have an identifier");
149 let cfg_attrs =
150 field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
151
152 quote! {
153 #( #cfg_attrs )*
154 #field_name: #src.#field_name,
155 }
156 })
157 .collect()
158}