use cfg_if::cfg_if;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Field, Generics, Ident, ItemStruct, Token, Type,
};
use super::{
attribute::{DeriveRequestMeta, RequestMeta},
ensure_feature_presence,
};
use crate::util::{field_has_serde_flatten_attribute, import_ruma_common, PrivateField};
mod incoming;
mod outgoing;
pub fn expand_request(attr: RequestAttr, item: ItemStruct) -> TokenStream {
let ruma_common = import_ruma_common();
let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);
let error_ty = attr.0.first().map_or_else(
|| quote! { #ruma_common::api::error::MatrixError },
|DeriveRequestMeta::Error(ty)| quote! { #ty },
);
cfg_if! {
if #[cfg(feature = "__internal_macro_expand")] {
use syn::parse_quote;
let mut derive_input = item.clone();
derive_input.attrs.push(parse_quote! { #[ruma_api(error = #error_ty)] });
crate::util::cfg_expand_struct(&mut derive_input);
let extra_derive = quote! { #ruma_macros::_FakeDeriveRumaApi };
let ruma_api_attribute = quote! {};
let request_impls =
expand_derive_request(derive_input).unwrap_or_else(syn::Error::into_compile_error);
} else {
let extra_derive = quote! { #ruma_macros::Request };
let ruma_api_attribute = quote! { #[ruma_api(error = #error_ty)] };
let request_impls = quote! {};
}
}
quote! {
#maybe_feature_error
#[derive(Clone, Debug, #ruma_common::serde::_FakeDeriveSerde, #extra_derive)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#ruma_api_attribute
#item
#request_impls
}
}
/// Arguments of the request attribute macro: a comma-separated (optionally trailing-comma)
/// list of [`DeriveRequestMeta`] entries.
pub struct RequestAttr(Punctuated<DeriveRequestMeta, Token![,]>);

impl Parse for RequestAttr {
    /// Parses every remaining token of `input` as `DeriveRequestMeta` entries separated by
    /// commas.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let metas = Punctuated::parse_terminated(input)?;
        Ok(Self(metas))
    }
}
pub fn expand_derive_request(input: ItemStruct) -> syn::Result<TokenStream> {
let fields =
input.fields.into_iter().map(RequestField::try_from).collect::<syn::Result<_>>()?;
let mut error_ty = None;
for attr in input.attrs {
if !attr.path().is_ident("ruma_api") {
continue;
}
let metas =
attr.parse_args_with(Punctuated::<DeriveRequestMeta, Token![,]>::parse_terminated)?;
for meta in metas {
match meta {
DeriveRequestMeta::Error(t) => error_ty = Some(t),
}
}
}
let request = Request {
ident: input.ident,
generics: input.generics,
fields,
error_ty: error_ty.expect("missing error_ty attribute"),
};
let ruma_common = import_ruma_common();
let test = request.check(&ruma_common)?;
let types_impls = request.expand_all(&ruma_common);
Ok(quote! {
#types_impls
#[allow(deprecated)]
#[cfg(test)]
mod __request {
#test
}
})
}
struct Request {
ident: Ident,
generics: Generics,
fields: Vec<RequestField>,
error_ty: Type,
}
impl Request {
fn body_fields(&self) -> impl Iterator<Item = &Field> {
self.fields.iter().filter_map(RequestField::as_body_field)
}
fn has_body_fields(&self) -> bool {
self.fields
.iter()
.any(|f| matches!(&f.kind, RequestFieldKind::Body | RequestFieldKind::NewtypeBody))
}
fn has_newtype_body(&self) -> bool {
self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::NewtypeBody))
}
fn has_header_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Header(_)))
}
fn has_path_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Path))
}
fn has_query_fields(&self) -> bool {
self.fields.iter().any(|f| matches!(&f.kind, RequestFieldKind::Query))
}
fn header_fields(&self) -> impl Iterator<Item = (&Field, &Ident)> {
self.fields.iter().filter_map(RequestField::as_header_field)
}
fn path_fields(&self) -> impl Iterator<Item = &Field> {
self.fields.iter().filter_map(RequestField::as_path_field)
}
fn raw_body_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_raw_body_field)
}
fn query_all_field(&self) -> Option<&Field> {
self.fields.iter().find_map(RequestField::as_query_all_field)
}
fn expand_all(&self, ruma_common: &TokenStream) -> TokenStream {
let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
let serde = quote! { #ruma_common::exports::serde };
let request_body_struct = self.has_body_fields().then(|| {
let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
let fields =
self.fields.iter().filter_map(RequestField::as_body_field).map(PrivateField);
quote! {
#[cfg(any(feature = "client", feature = "server"))]
#[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
#[cfg_attr(feature = "client", derive(#serde::Serialize))]
#[cfg_attr(feature = "server", derive(#serde::Deserialize))]
#serde_attr
struct RequestBody { #(#fields),* }
}
});
let request_query_def = if let Some(f) = self.query_all_field() {
let field = Field { ident: None, colon_token: None, ..f.clone() };
let field = PrivateField(&field);
Some(quote! { (#field); })
} else if self.has_query_fields() {
let fields =
self.fields.iter().filter_map(RequestField::as_query_field).map(PrivateField);
Some(quote! { { #(#fields),* } })
} else {
None
};
let request_query_struct = request_query_def.map(|def| {
quote! {
#[cfg(any(feature = "client", feature = "server"))]
#[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
#[cfg_attr(feature = "client", derive(#serde::Serialize))]
#[cfg_attr(feature = "server", derive(#serde::Deserialize))]
struct RequestQuery #def
}
});
let outgoing_request_impl = self.expand_outgoing(ruma_common);
let incoming_request_impl = self.expand_incoming(ruma_common);
quote! {
#request_body_struct
#request_query_struct
#[allow(deprecated)]
mod __request_impls {
use super::*;
#outgoing_request_impl
#incoming_request_impl
}
}
}
pub(super) fn check(&self, ruma_common: &TokenStream) -> syn::Result<TokenStream> {
let http = quote! { #ruma_common::exports::http };
let newtype_body_fields = self.fields.iter().filter(|f| {
matches!(&f.kind, RequestFieldKind::NewtypeBody | RequestFieldKind::RawBody)
});
let has_newtype_body_field = match newtype_body_fields.count() {
0 => false,
1 => true,
_ => {
return Err(syn::Error::new_spanned(
&self.ident,
"Can't have more than one newtype body field",
))
}
};
let query_all_fields =
self.fields.iter().filter(|f| matches!(&f.kind, RequestFieldKind::QueryAll));
let has_query_all_field = match query_all_fields.count() {
0 => false,
1 => true,
_ => {
return Err(syn::Error::new_spanned(
&self.ident,
"Can't have more than one query_all field",
))
}
};
let mut body_fields =
self.fields.iter().filter(|f| matches!(f.kind, RequestFieldKind::Body));
let first_body_field = body_fields.next();
let has_body_fields = first_body_field.is_some();
if has_newtype_body_field && has_body_fields {
return Err(syn::Error::new_spanned(
&self.ident,
"Can't have both a newtype body field and regular body fields",
));
}
if let Some(first_body_field) = first_body_field {
let is_single_body_field = body_fields.next().is_none();
if is_single_body_field && field_has_serde_flatten_attribute(&first_body_field.inner) {
return Err(syn::Error::new_spanned(
first_body_field,
"Use `#[ruma_api(body)]` to represent the JSON body as a single field",
));
}
}
let has_query_fields = self.has_query_fields();
if has_query_all_field && has_query_fields {
return Err(syn::Error::new_spanned(
&self.ident,
"Can't have both a query_all field and regular query fields",
));
}
let path_fields = self.path_fields().map(|f| f.ident.as_ref().unwrap().to_string());
let mut tests = quote! {
#[::std::prelude::v1::test]
fn path_parameters() {
let path_params = super::METADATA._path_parameters();
let request_path_fields: &[&::std::primitive::str] = &[#(#path_fields),*];
::std::assert_eq!(
path_params, request_path_fields,
"Path parameters must match the `Request`'s `#[ruma_api(path)]` fields"
);
}
};
if has_body_fields || has_newtype_body_field {
tests.extend(quote! {
#[::std::prelude::v1::test]
fn request_is_not_get() {
::std::assert_ne!(
super::METADATA.method, #http::Method::GET,
"GET endpoints can't have body fields",
);
}
});
}
Ok(tests)
}
}
/// A field of a request struct together with its classification.
pub(super) struct RequestField {
    /// The underlying field. When built through `TryFrom<Field>`, its `#[ruma_api(...)]`
    /// attributes have already been stripped.
    pub(super) inner: Field,

    /// The kind of the field, derived from its `#[ruma_api(...)]` attribute.
    pub(super) kind: RequestFieldKind,
}

/// Classification of a request field, selected by its `#[ruma_api(...)]` attribute.
pub(super) enum RequestFieldKind {
    /// Regular body field (the field has no `#[ruma_api(...)]` attribute).
    Body,

    /// Header field. Holds the identifier given in the attribute.
    Header(Ident),

    /// Newtype body field.
    NewtypeBody,

    /// Raw body field.
    RawBody,

    /// Path field.
    Path,

    /// Regular query field.
    Query,

    /// `query_all` field.
    QueryAll,
}

impl RequestField {
    /// Creates a `RequestField` from `inner` and its parsed kind attribute.
    ///
    /// A missing attribute means the field is a regular body field.
    fn new(inner: Field, kind_attr: Option<RequestMeta>) -> Self {
        let kind = kind_attr.map_or(RequestFieldKind::Body, |meta| match meta {
            RequestMeta::NewtypeBody => RequestFieldKind::NewtypeBody,
            RequestMeta::RawBody => RequestFieldKind::RawBody,
            RequestMeta::Path => RequestFieldKind::Path,
            RequestMeta::Query => RequestFieldKind::Query,
            RequestMeta::QueryAll => RequestFieldKind::QueryAll,
            RequestMeta::Header(header) => RequestFieldKind::Header(header),
        });

        Self { inner, kind }
    }

    /// Returns the field if it is a regular or newtype body field.
    pub fn as_body_field(&self) -> Option<&Field> {
        matches!(self.kind, RequestFieldKind::Body | RequestFieldKind::NewtypeBody)
            .then_some(&self.inner)
    }

    /// Returns the field if it is a raw body field.
    pub fn as_raw_body_field(&self) -> Option<&Field> {
        matches!(self.kind, RequestFieldKind::RawBody).then_some(&self.inner)
    }

    /// Returns the field if it is a path field.
    pub fn as_path_field(&self) -> Option<&Field> {
        matches!(self.kind, RequestFieldKind::Path).then_some(&self.inner)
    }

    /// Returns the field if it is a regular query field.
    pub fn as_query_field(&self) -> Option<&Field> {
        matches!(self.kind, RequestFieldKind::Query).then_some(&self.inner)
    }

    /// Returns the field if it is a `query_all` field.
    pub fn as_query_all_field(&self) -> Option<&Field> {
        matches!(self.kind, RequestFieldKind::QueryAll).then_some(&self.inner)
    }

    /// Returns the field and its header identifier if it is a header field.
    pub fn as_header_field(&self) -> Option<(&Field, &Ident)> {
        if let RequestFieldKind::Header(header_name) = &self.kind {
            Some((&self.inner, header_name))
        } else {
            None
        }
    }
}
impl TryFrom<Field> for RequestField {
type Error = syn::Error;
fn try_from(mut field: Field) -> syn::Result<Self> {
let (mut api_attrs, attrs) =
field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path().is_ident("ruma_api"));
field.attrs = attrs;
let kind_attr = match api_attrs.as_slice() {
[] => None,
[_] => Some(api_attrs.pop().unwrap().parse_args::<RequestMeta>()?),
_ => {
return Err(syn::Error::new_spanned(
&api_attrs[1],
"multiple field kind attribute found, there can only be one",
));
}
};
Ok(RequestField::new(field, kind_attr))
}
}
impl Parse for RequestField {
    /// Parses a named field (`name: Type`, with attributes), then classifies it via
    /// `TryFrom<Field>`.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let field = Field::parse_named(input)?;
        Self::try_from(field)
    }
}
impl ToTokens for RequestField {
    /// Emits the underlying field unchanged; the classification produces no tokens.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        tokens.extend(self.inner.to_token_stream());
    }
}