ruma_macros/api/response/incoming.rs
1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Type;
4
5use super::{Response, ResponseFieldKind};
6
7impl Response {
8 pub fn expand_incoming(&self, error_ty: &Type, ruma_common: &TokenStream) -> TokenStream {
9 let http = quote! { #ruma_common::exports::http };
10 let serde_json = quote! { #ruma_common::exports::serde_json };
11
12 let extract_response_headers = self.has_header_fields().then(|| {
13 quote! {
14 let mut headers = response.headers().clone();
15 }
16 });
17
18 let typed_response_body_decl = self.has_body_fields().then(|| {
19 quote! {
20 let response_body: ResponseBody = {
21 let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
22 response.body(),
23 );
24
25 #serde_json::from_slice(match body {
26 [] => b"{}",
30 b => b,
31 })?
32 };
33 }
34 });
35
36 let response_init_fields = {
37 let mut fields = vec![];
38 let mut raw_body = None;
39
40 for response_field in &self.fields {
41 let field = &response_field.inner;
42 let field_name =
43 field.ident.as_ref().expect("expected field to have an identifier");
44 let cfg_attrs =
45 field.attrs.iter().filter(|a| a.path().is_ident("cfg")).collect::<Vec<_>>();
46
47 fields.push(match &response_field.kind {
48 ResponseFieldKind::Body | ResponseFieldKind::NewtypeBody => {
49 quote! {
50 #( #cfg_attrs )*
51 #field_name: response_body.#field_name
52 }
53 }
54 ResponseFieldKind::Header(header_name) => {
55 let optional_header = match &field.ty {
56 Type::Path(syn::TypePath {
57 path: syn::Path { segments, .. }, ..
58 }) if segments.last().unwrap().ident == "Option" => {
59 let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
60 args: option_args, ..
61 }) = &segments.last().unwrap().arguments else {
62 panic!("Option should use angle brackets");
63 };
64 let syn::GenericArgument::Type(field_type) = option_args.first().unwrap() else {
65 panic!("Option brackets should contain type");
66 };
67 quote! {
68 #( #cfg_attrs )*
69 #field_name: {
70 headers.remove(#header_name)
71 .and_then(|h| { h.to_str().ok()?.parse::<#field_type>().ok() })
72 }
73 }
74 }
75 _ => {
76 let field_type = &field.ty;
77 quote! {
78 #( #cfg_attrs )*
79 #field_name: {
80 headers.remove(#header_name)
81 .ok_or_else(|| #ruma_common::api::error::HeaderDeserializationError::MissingHeader(
82 #header_name.to_string()
83 ))?
84 .to_str()?
85 .parse::<#field_type>()
86 .map_err(|e| #ruma_common::api::error::HeaderDeserializationError::InvalidHeader(e.into()))?
87 }
88 }
89 }
90 };
91 quote! { #optional_header }
92 }
93 ResponseFieldKind::RawBody => {
97 raw_body = Some(quote! {
98 #( #cfg_attrs )*
99 #field_name: {
100 ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
101 response.body(),
102 )
103 .to_vec()
104 }
105 });
106 continue;
108 }
109 });
110 }
111
112 fields.extend(raw_body);
113
114 quote! {
115 #(#fields,)*
116 }
117 };
118
119 quote! {
120 #[automatically_derived]
121 #[cfg(feature = "client")]
122 impl #ruma_common::api::IncomingResponse for Response {
123 type EndpointError = #error_ty;
124
125 fn try_from_http_response<T: ::std::convert::AsRef<[::std::primitive::u8]>>(
126 response: #http::Response<T>,
127 ) -> ::std::result::Result<
128 Self,
129 #ruma_common::api::error::FromHttpResponseError<#error_ty>,
130 > {
131 if response.status().as_u16() < 400 {
132 #extract_response_headers
133 #typed_response_body_decl
134
135 ::std::result::Result::Ok(Self {
136 #response_init_fields
137 })
138 } else {
139 Err(#ruma_common::api::error::FromHttpResponseError::Server(
140 <#error_ty as #ruma_common::api::EndpointError>::from_http_response(
141 response,
142 )
143 ))
144 }
145 }
146 }
147 }
148 }
149}