ruma_common/serde/
raw.rs

1use std::{
2    clone::Clone,
3    fmt::{self, Debug},
4    marker::PhantomData,
5    mem,
6};
7
8use serde::{
9    de::{self, Deserialize, DeserializeSeed, Deserializer, IgnoredAny, MapAccess, Visitor},
10    ser::{Serialize, Serializer},
11};
12use serde_json::value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue};
13
14/// A wrapper around `Box<RawValue>` with a generic parameter for the expected Rust type.
15///
16/// Ruma offers the `Raw` wrapper to enable passing around JSON text that is only partially
17/// validated. This is useful when a client receives events that do not follow the spec perfectly
18/// or a server needs to generate reference hashes with the original canonical JSON string.
19/// All structs and enums representing event types implement `Deserialize`, therefore they can be
20/// used with `Raw`. Since `Raw` does not change the JSON string, it should be used to pass around
21/// events in a lossless way.
22///
23/// ```no_run
24/// # use serde::Deserialize;
25/// # use ruma_common::serde::Raw;
26/// # #[derive(Deserialize)]
27/// # struct AnyTimelineEvent;
28///
29/// let json = r#"{ "type": "imagine a full event", "content": {...} }"#;
30///
31/// let deser = serde_json::from_str::<Raw<AnyTimelineEvent>>(json)
32///     .unwrap() // the first Result from serde_json::from_str, will not fail
33///     .deserialize() // deserialize to the inner type
34///     .unwrap(); // finally get to the AnyTimelineEvent
35/// ```
36#[repr(transparent)]
37pub struct Raw<T> {
38    json: Box<RawJsonValue>,
39    _ev: PhantomData<T>,
40}
41
42impl<T> Raw<T> {
43    /// Create a `Raw` by serializing the given `T`.
44    ///
45    /// Shorthand for `serde_json::value::to_raw_value(val).map(Raw::from_json)`, but specialized to
46    /// `T`.
47    ///
48    /// # Errors
49    ///
50    /// Fails if `T`s [`Serialize`] implementation fails.
51    pub fn new(val: &T) -> serde_json::Result<Self>
52    where
53        T: Serialize,
54    {
55        to_raw_json_value(val).map(Self::from_json)
56    }
57
58    /// Create a `Raw` from a boxed `RawValue`.
59    pub fn from_json(json: Box<RawJsonValue>) -> Self {
60        Self { json, _ev: PhantomData }
61    }
62
63    /// Convert an owned `String` of JSON data to `Raw<T>`.
64    ///
65    /// This function is equivalent to `serde_json::from_str::<Raw<T>>` except that an allocation
66    /// and copy is avoided if both of the following are true:
67    ///
68    /// * the input has no leading or trailing whitespace, and
69    /// * the input has capacity equal to its length.
70    pub fn from_json_string(json: String) -> serde_json::Result<Self> {
71        RawJsonValue::from_string(json).map(Self::from_json)
72    }
73
74    /// Access the underlying json value.
75    pub fn json(&self) -> &RawJsonValue {
76        &self.json
77    }
78
79    /// Convert `self` into the underlying json value.
80    pub fn into_json(self) -> Box<RawJsonValue> {
81        self.json
82    }
83
84    /// Try to access a given field inside this `Raw`, assuming it contains an object.
85    ///
86    /// Returns `Err(_)` when the contained value is not an object, or the field exists but is fails
87    /// to deserialize to the expected type.
88    ///
89    /// Returns `Ok(None)` when the field doesn't exist or is `null`.
90    ///
91    /// # Example
92    ///
93    /// ```no_run
94    /// # type CustomMatrixEvent = ();
95    /// # fn foo() -> serde_json::Result<()> {
96    /// # let raw_event: ruma_common::serde::Raw<()> = todo!();
97    /// if raw_event.get_field::<String>("type")?.as_deref() == Some("org.custom.matrix.event") {
98    ///     let event = raw_event.deserialize_as::<CustomMatrixEvent>()?;
99    ///     // ...
100    /// }
101    /// # Ok(())
102    /// # }
103    /// ```
104    pub fn get_field<'a, U>(&'a self, field_name: &str) -> serde_json::Result<Option<U>>
105    where
106        U: Deserialize<'a>,
107    {
108        struct FieldVisitor<'b>(&'b str);
109
110        impl Visitor<'_> for FieldVisitor<'_> {
111            type Value = bool;
112
113            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
114                write!(formatter, "`{}`", self.0)
115            }
116
117            fn visit_str<E>(self, value: &str) -> Result<bool, E>
118            where
119                E: de::Error,
120            {
121                Ok(value == self.0)
122            }
123        }
124
125        struct Field<'b>(&'b str);
126
127        impl<'de> DeserializeSeed<'de> for Field<'_> {
128            type Value = bool;
129
130            fn deserialize<D>(self, deserializer: D) -> Result<bool, D::Error>
131            where
132                D: Deserializer<'de>,
133            {
134                deserializer.deserialize_identifier(FieldVisitor(self.0))
135            }
136        }
137
138        struct SingleFieldVisitor<'b, T> {
139            field_name: &'b str,
140            _phantom: PhantomData<T>,
141        }
142
143        impl<'b, T> SingleFieldVisitor<'b, T> {
144            fn new(field_name: &'b str) -> Self {
145                Self { field_name, _phantom: PhantomData }
146            }
147        }
148
149        impl<'de, T> Visitor<'de> for SingleFieldVisitor<'_, T>
150        where
151            T: Deserialize<'de>,
152        {
153            type Value = Option<T>;
154
155            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
156                formatter.write_str("a string")
157            }
158
159            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
160            where
161                A: MapAccess<'de>,
162            {
163                let mut res = None;
164                while let Some(is_right_field) = map.next_key_seed(Field(self.field_name))? {
165                    if is_right_field {
166                        res = Some(map.next_value()?);
167                    } else {
168                        map.next_value::<IgnoredAny>()?;
169                    }
170                }
171
172                Ok(res)
173            }
174        }
175
176        let mut deserializer = serde_json::Deserializer::from_str(self.json().get());
177        deserializer.deserialize_map(SingleFieldVisitor::new(field_name))
178    }
179
180    /// Try to deserialize the JSON as the expected type.
181    pub fn deserialize<'a>(&'a self) -> serde_json::Result<T>
182    where
183        T: Deserialize<'a>,
184    {
185        serde_json::from_str(self.json.get())
186    }
187
188    /// Try to deserialize the JSON as a custom type.
189    pub fn deserialize_as<'a, U>(&'a self) -> serde_json::Result<U>
190    where
191        U: Deserialize<'a>,
192    {
193        serde_json::from_str(self.json.get())
194    }
195
196    /// Turns `Raw<T>` into `Raw<U>` without changing the underlying JSON.
197    ///
198    /// This is useful for turning raw specific event types into raw event enum types.
199    pub fn cast<U>(self) -> Raw<U> {
200        Raw::from_json(self.into_json())
201    }
202
203    /// Turns `&Raw<T>` into `&Raw<U>` without changing the underlying JSON.
204    ///
205    /// This is useful for turning raw specific event types into raw event enum types.
206    pub fn cast_ref<U>(&self) -> &Raw<U> {
207        unsafe { mem::transmute(self) }
208    }
209}
210
211impl<T> Clone for Raw<T> {
212    fn clone(&self) -> Self {
213        Self::from_json(self.json.clone())
214    }
215}
216
217impl<T> Debug for Raw<T> {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        use std::any::type_name;
220        f.debug_struct(&format!("Raw::<{}>", type_name::<T>())).field("json", &self.json).finish()
221    }
222}
223
224impl<'de, T> Deserialize<'de> for Raw<T> {
225    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
226    where
227        D: Deserializer<'de>,
228    {
229        Box::<RawJsonValue>::deserialize(deserializer).map(Self::from_json)
230    }
231}
232
233impl<T> Serialize for Raw<T> {
234    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
235    where
236        S: Serializer,
237    {
238        self.json.serialize(serializer)
239    }
240}
241
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue};

    use super::Raw;

    #[test]
    fn get_field() -> serde_json::Result<()> {
        // Target type that borrows its strings from the source JSON.
        #[derive(Debug, PartialEq, Deserialize)]
        struct A<'a> {
            #[serde(borrow)]
            b: Vec<&'a str>,
        }

        const OBJ: &str = r#"{ "a": { "b": [  "c"] }, "z": 5 }"#;
        let raw: Raw<()> = from_json_str(OBJ)?;

        // A present field deserializes into the requested type.
        let z = raw.get_field::<u8>("z")?;
        assert_eq!(z, Some(5));

        // Borrowing as a raw value keeps the original text, whitespace included.
        let nested = raw.get_field::<&RawJsonValue>("a")?.expect("field `a` is present");
        assert_eq!(nested.get(), r#"{ "b": [  "c"] }"#);

        let typed = raw.get_field::<A<'_>>("a")?;
        assert_eq!(typed, Some(A { b: vec!["c"] }));

        // `b` only exists nested inside `a`, so at the top level it is missing.
        assert!(raw.get_field::<u8>("b")?.is_none());

        // A present field of the wrong type is an error.
        assert!(raw.get_field::<u8>("a").is_err());

        Ok(())
    }
}