// ruma_macros/api/response/incoming.rs

use proc_macro2::TokenStream;
use quote::quote;
use syn::Type;

use super::{Response, ResponseFieldKind};

7impl Response {
8    pub fn expand_incoming(&self, error_ty: &Type, ruma_common: &TokenStream) -> TokenStream {
9        let http = quote! { #ruma_common::exports::http };
10        let serde_json = quote! { #ruma_common::exports::serde_json };
11
12        let extract_response_headers = self.has_header_fields().then(|| {
13            quote! {
14                let mut headers = response.headers().clone();
15            }
16        });
17
18        let typed_response_body_decl = self.has_body_fields().then(|| {
19            quote! {
20                let response_body: ResponseBody = {
21                    let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
22                        response.body(),
23                    );
24
25                    #serde_json::from_slice(match body {
26                        // If the response body is completely empty, pretend it is an empty
27                        // JSON object instead. This allows responses with only optional body
28                        // parameters to be deserialized in that case.
29                        [] => b"{}",
30                        b => b,
31                    })?
32                };
33            }
34        });
35
36        let response_init_fields = {
37            let mut fields = vec![];
38            let mut raw_body = None;
39
40            for response_field in &self.fields {
41                let field = &response_field.inner;
42                let field_name =
43                    field.ident.as_ref().expect("expected field to have an identifier");
44                let cfg_attrs =
45                    field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
46
47                fields.push(match &response_field.kind {
48                    ResponseFieldKind::Body | ResponseFieldKind::NewtypeBody => {
49                        quote! {
50                            #( #cfg_attrs )*
51                            #field_name: response_body.#field_name
52                        }
53                    }
54                    ResponseFieldKind::Header(header_name) => {
55                        let optional_header = match &field.ty {
56                            Type::Path(syn::TypePath {
57                                path: syn::Path { segments, .. }, ..
58                            }) if segments.last().unwrap().ident == "Option" => {
59                                let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
60                                    args: option_args, ..
61                                }) = &segments.last().unwrap().arguments else {
62                                    panic!("Option should use angle brackets");
63                                };
64                                let syn::GenericArgument::Type(field_type) = option_args.first().unwrap() else {
65                                    panic!("Option brackets should contain type");
66                                };
67                                quote! {
68                                    #( #cfg_attrs )*
69                                    #field_name: {
70                                        headers.remove(#header_name)
71                                            .and_then(|h| { h.to_str().ok()?.parse::<#field_type>().ok() })
72                                    }
73                                }
74                            }
75                            _ => {
76                                let field_type = &field.ty;
77                                quote! {
78                                    #( #cfg_attrs )*
79                                    #field_name: {
80                                        headers.remove(#header_name)
81                                            .ok_or_else(|| #ruma_common::api::error::HeaderDeserializationError::MissingHeader(
82                                                #header_name.to_string()
83                                            ))?
84                                            .to_str()?
85                                            .parse::<#field_type>()
86                                            .map_err(|e| #ruma_common::api::error::HeaderDeserializationError::InvalidHeader(e.into()))?
87                                    }
88                                }
89                            }
90                        };
91                        quote! { #optional_header }
92                    }
93                    // This field must be instantiated last to avoid `use of move value` error.
94                    // We are guaranteed only one new body field because of a check in
95                    // `parse_response`.
96                    ResponseFieldKind::RawBody => {
97                        raw_body = Some(quote! {
98                            #( #cfg_attrs )*
99                            #field_name: {
100                                ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
101                                    response.body(),
102                                )
103                                .to_vec()
104                            }
105                        });
106                        // skip adding to the vec
107                        continue;
108                    }
109                });
110            }
111
112            fields.extend(raw_body);
113
114            quote! {
115                #(#fields,)*
116            }
117        };
118
119        quote! {
120            #[automatically_derived]
121            #[cfg(feature = "client")]
122            impl #ruma_common::api::IncomingResponse for Response {
123                type EndpointError = #error_ty;
124
125                fn try_from_http_response<T: ::std::convert::AsRef<[::std::primitive::u8]>>(
126                    response: #http::Response<T>,
127                ) -> ::std::result::Result<
128                    Self,
129                    #ruma_common::api::error::FromHttpResponseError<#error_ty>,
130                > {
131                    if response.status().as_u16() < 400 {
132                        #extract_response_headers
133                        #typed_response_body_decl
134
135                        ::std::result::Result::Ok(Self {
136                            #response_init_fields
137                        })
138                    } else {
139                        Err(#ruma_common::api::error::FromHttpResponseError::Server(
140                            <#error_ty as #ruma_common::api::EndpointError>::from_http_response(
141                                response,
142                            )
143                        ))
144                    }
145                }
146            }
147        }
148    }
149}