ruma_federation_api/serde/
pdu_process_response.rs

1use std::{collections::BTreeMap, fmt};
2
3use ruma_common::OwnedEventId;
4use serde::{
5    de::{Deserializer, MapAccess, Visitor},
6    ser::{SerializeMap, Serializer},
7    Deserialize, Serialize,
8};
9
10#[derive(Deserialize, Serialize)]
11struct WrappedError {
12    #[serde(skip_serializing_if = "Option::is_none")]
13    error: Option<String>,
14}
15
16pub(crate) fn serialize<S>(
17    response: &BTreeMap<OwnedEventId, Result<(), String>>,
18    serializer: S,
19) -> Result<S::Ok, S::Error>
20where
21    S: Serializer,
22{
23    let mut map = serializer.serialize_map(Some(response.len()))?;
24    for (key, value) in response {
25        let wrapped_error = WrappedError { error: value.clone().err() };
26        map.serialize_entry(&key, &wrapped_error)?;
27    }
28    map.end()
29}
30
31#[allow(clippy::type_complexity)]
32pub(crate) fn deserialize<'de, D>(
33    deserializer: D,
34) -> Result<BTreeMap<OwnedEventId, Result<(), String>>, D::Error>
35where
36    D: Deserializer<'de>,
37{
38    struct PduProcessResponseVisitor;
39
40    impl<'de> Visitor<'de> for PduProcessResponseVisitor {
41        type Value = BTreeMap<OwnedEventId, Result<(), String>>;
42
43        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
44            formatter.write_str("A map of EventIds to a map of optional errors")
45        }
46
47        fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
48        where
49            M: MapAccess<'de>,
50        {
51            let mut map = BTreeMap::new();
52
53            while let Some((key, value)) = access.next_entry::<OwnedEventId, WrappedError>()? {
54                let v = match value.error {
55                    None => Ok(()),
56                    Some(error) => Err(error),
57                };
58                map.insert(key, v);
59            }
60            Ok(map)
61        }
62    }
63
64    deserializer.deserialize_map(PduProcessResponseVisitor)
65}
66
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use ruma_common::{event_id, owned_event_id, OwnedEventId};
    use serde_json::{json, value::Serializer as JsonSerializer};

    use super::{deserialize, serialize};

    /// Builds a response map holding `outcome` for the single event used throughout the tests.
    fn response_with(outcome: Result<(), String>) -> BTreeMap<OwnedEventId, Result<(), String>> {
        let mut response = BTreeMap::new();
        response.insert(owned_event_id!("$someevent:matrix.org"), outcome);
        response
    }

    /// A failed event is written as an object carrying the error message.
    #[test]
    fn serialize_error() {
        let response = response_with(Err("Some processing error.".into()));

        let expected = json!({
            "$someevent:matrix.org": { "error": "Some processing error." }
        });
        assert_eq!(serialize(&response, JsonSerializer).unwrap(), expected);
    }

    /// A successful event is written as an empty object.
    #[test]
    fn serialize_ok() {
        let response = response_with(Ok(()));

        let expected = json!({
            "$someevent:matrix.org": {}
        });
        assert_eq!(serialize(&response, JsonSerializer).unwrap(), expected);
    }

    /// An `error` string is read back as `Err` with the same message.
    #[test]
    fn deserialize_error() {
        let input = json!({
            "$someevent:matrix.org": { "error": "Some processing error." }
        });

        let response = deserialize(input).unwrap();

        assert_eq!(
            response.get(event_id!("$someevent:matrix.org")),
            Some(&Err("Some processing error.".to_owned()))
        );
    }

    /// An explicit `null` error is treated the same as success.
    #[test]
    fn deserialize_null_error_is_ok() {
        let input = json!({
            "$someevent:matrix.org": { "error": null }
        });

        let response = deserialize(input).unwrap();

        assert_eq!(response.get(event_id!("$someevent:matrix.org")), Some(&Ok(())));
    }

    /// An empty error string still counts as a failure.
    #[test]
    fn deserialize_empty_error_is_err() {
        let input = json!({
            "$someevent:matrix.org": { "error": "" }
        });

        let response = deserialize(input).unwrap();

        assert_eq!(response.get(event_id!("$someevent:matrix.org")), Some(&Err(String::new())));
    }

    /// An empty object is read back as success.
    #[test]
    fn deserialize_ok() {
        let input = json!({
            "$someevent:matrix.org": {}
        });

        let response = deserialize(input).unwrap();

        assert_eq!(response.get(event_id!("$someevent:matrix.org")), Some(&Ok(())));
    }
}