ruma_common/serde/
raw.rs

1use std::{
2    clone::Clone,
3    fmt::{self, Debug},
4    marker::PhantomData,
5    mem,
6};
7
8use serde::{
9    de::{self, Deserialize, DeserializeSeed, Deserializer, IgnoredAny, MapAccess, Visitor},
10    ser::{Serialize, Serializer},
11};
12use serde_json::value::{
13    to_raw_value as to_raw_json_value, RawValue as RawJsonValue, Value as JsonValue,
14};
15
/// A wrapper around `Box<RawValue>` with a generic parameter for the expected Rust type.
///
/// Ruma offers the `Raw` wrapper to enable passing around JSON text that is only partially
/// validated. This is useful when a client receives events that do not follow the spec perfectly
/// or a server needs to generate reference hashes with the original canonical JSON string.
/// All structs and enums representing event types implement `Deserialize`, therefore they can be
/// used with `Raw`. Since `Raw` does not change the JSON string, it should be used to pass around
/// events in a lossless way.
///
/// Deserializing into `Raw<T>` only checks that the input is valid JSON; the contents are not
/// checked against `T` until [`Raw::deserialize`] is called.
///
/// ```no_run
/// # use serde::Deserialize;
/// # use ruma_common::serde::Raw;
/// # #[derive(Deserialize)]
/// # struct AnyTimelineEvent;
///
/// let json = r#"{ "type": "imagine a full event", "content": {...} }"#;
///
/// let deser = serde_json::from_str::<Raw<AnyTimelineEvent>>(json)
///     .unwrap() // the first Result from serde_json::from_str, fails only if `json` is not valid JSON
///     .deserialize() // deserialize to the inner type
///     .unwrap(); // finally get to the AnyTimelineEvent
/// ```
// `repr(transparent)` gives every `Raw<T>` the layout of its single non-zero-sized field
// (`Box<RawJsonValue>`); `cast_ref_unchecked` relies on this to reinterpret `&Raw<T>` as `&Raw<U>`.
#[repr(transparent)]
pub struct Raw<T> {
    /// The JSON text, kept verbatim.
    json: Box<RawJsonValue>,
    /// Ties the expected Rust type `T` to this value without storing a `T`.
    _ev: PhantomData<T>,
}
43
44impl<T> Raw<T> {
45    /// Create a `Raw` by serializing the given `T`.
46    ///
47    /// Shorthand for `serde_json::value::to_raw_value(val).map(Raw::from_json)`, but specialized to
48    /// `T`.
49    ///
50    /// # Errors
51    ///
52    /// Fails if `T`s [`Serialize`] implementation fails.
53    pub fn new(val: &T) -> serde_json::Result<Self>
54    where
55        T: Serialize,
56    {
57        to_raw_json_value(val).map(Self::from_json)
58    }
59
60    /// Create a `Raw` from a boxed `RawValue`.
61    pub fn from_json(json: Box<RawJsonValue>) -> Self {
62        Self { json, _ev: PhantomData }
63    }
64
65    /// Convert an owned `String` of JSON data to `Raw<T>`.
66    ///
67    /// This function is equivalent to `serde_json::from_str::<Raw<T>>` except that an allocation
68    /// and copy is avoided if both of the following are true:
69    ///
70    /// * the input has no leading or trailing whitespace, and
71    /// * the input has capacity equal to its length.
72    pub fn from_json_string(json: String) -> serde_json::Result<Self> {
73        RawJsonValue::from_string(json).map(Self::from_json)
74    }
75
76    /// Access the underlying json value.
77    pub fn json(&self) -> &RawJsonValue {
78        &self.json
79    }
80
81    /// Convert `self` into the underlying json value.
82    pub fn into_json(self) -> Box<RawJsonValue> {
83        self.json
84    }
85
86    /// Try to access a given field inside this `Raw`, assuming it contains an object.
87    ///
88    /// Returns `Err(_)` when the contained value is not an object, or the field exists but is fails
89    /// to deserialize to the expected type.
90    ///
91    /// Returns `Ok(None)` when the field doesn't exist or is `null`.
92    ///
93    /// # Example
94    ///
95    /// ```no_run
96    /// # type CustomMatrixEvent = ();
97    /// # fn foo() -> serde_json::Result<()> {
98    /// # let raw_event: ruma_common::serde::Raw<()> = todo!();
99    /// if raw_event.get_field::<String>("type")?.as_deref() == Some("org.custom.matrix.event") {
100    ///     let event = raw_event.deserialize_as_unchecked::<CustomMatrixEvent>()?;
101    ///     // ...
102    /// }
103    /// # Ok(())
104    /// # }
105    /// ```
106    pub fn get_field<'a, U>(&'a self, field_name: &str) -> serde_json::Result<Option<U>>
107    where
108        U: Deserialize<'a>,
109    {
110        struct FieldVisitor<'b>(&'b str);
111
112        impl Visitor<'_> for FieldVisitor<'_> {
113            type Value = bool;
114
115            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
116                write!(formatter, "`{}`", self.0)
117            }
118
119            fn visit_str<E>(self, value: &str) -> Result<bool, E>
120            where
121                E: de::Error,
122            {
123                Ok(value == self.0)
124            }
125        }
126
127        struct Field<'b>(&'b str);
128
129        impl<'de> DeserializeSeed<'de> for Field<'_> {
130            type Value = bool;
131
132            fn deserialize<D>(self, deserializer: D) -> Result<bool, D::Error>
133            where
134                D: Deserializer<'de>,
135            {
136                deserializer.deserialize_identifier(FieldVisitor(self.0))
137            }
138        }
139
140        struct SingleFieldVisitor<'b, T> {
141            field_name: &'b str,
142            _phantom: PhantomData<T>,
143        }
144
145        impl<'b, T> SingleFieldVisitor<'b, T> {
146            fn new(field_name: &'b str) -> Self {
147                Self { field_name, _phantom: PhantomData }
148            }
149        }
150
151        impl<'de, T> Visitor<'de> for SingleFieldVisitor<'_, T>
152        where
153            T: Deserialize<'de>,
154        {
155            type Value = Option<T>;
156
157            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
158                formatter.write_str("a string")
159            }
160
161            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
162            where
163                A: MapAccess<'de>,
164            {
165                let mut res = None;
166                while let Some(is_right_field) = map.next_key_seed(Field(self.field_name))? {
167                    if is_right_field {
168                        res = Some(map.next_value()?);
169                    } else {
170                        map.next_value::<IgnoredAny>()?;
171                    }
172                }
173
174                Ok(res)
175            }
176        }
177
178        let mut deserializer = serde_json::Deserializer::from_str(self.json().get());
179        deserializer.deserialize_map(SingleFieldVisitor::new(field_name))
180    }
181
182    /// Try to deserialize the JSON as the expected type.
183    pub fn deserialize<'a>(&'a self) -> serde_json::Result<T>
184    where
185        T: Deserialize<'a>,
186    {
187        serde_json::from_str(self.json.get())
188    }
189
190    /// Try to deserialize the JSON as a custom type.
191    pub fn deserialize_as<'a, U>(&'a self) -> serde_json::Result<U>
192    where
193        T: JsonCastable<U>,
194        U: Deserialize<'a>,
195    {
196        self.deserialize_as_unchecked()
197    }
198
199    /// Same as [`deserialize_as`][Self::deserialize_as], but without the trait restriction.
200    pub fn deserialize_as_unchecked<'a, U>(&'a self) -> serde_json::Result<U>
201    where
202        U: Deserialize<'a>,
203    {
204        serde_json::from_str(self.json.get())
205    }
206
207    /// Turns `Raw<T>` into `Raw<U>` without changing the underlying JSON.
208    ///
209    /// This is useful for turning raw specific event types into raw event enum types.
210    pub fn cast<U>(self) -> Raw<U>
211    where
212        T: JsonCastable<U>,
213    {
214        self.cast_unchecked()
215    }
216
217    /// Turns `&Raw<T>` into `&Raw<U>` without changing the underlying JSON.
218    ///
219    /// This is useful for turning raw specific event types into raw event enum types.
220    pub fn cast_ref<U>(&self) -> &Raw<U>
221    where
222        T: JsonCastable<U>,
223    {
224        self.cast_ref_unchecked()
225    }
226
227    /// Same as [`cast`][Self::cast], but without the trait restriction.
228    pub fn cast_unchecked<U>(self) -> Raw<U> {
229        Raw::from_json(self.into_json())
230    }
231
232    /// Same as [`cast_ref`][Self::cast_ref], but without the trait restriction.
233    pub fn cast_ref_unchecked<U>(&self) -> &Raw<U> {
234        unsafe { mem::transmute(self) }
235    }
236}
237
238impl<T> Clone for Raw<T> {
239    fn clone(&self) -> Self {
240        Self::from_json(self.json.clone())
241    }
242}
243
244impl<T> Debug for Raw<T> {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        use std::any::type_name;
247        f.debug_struct(&format!("Raw::<{}>", type_name::<T>())).field("json", &self.json).finish()
248    }
249}
250
251impl<'de, T> Deserialize<'de> for Raw<T> {
252    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
253    where
254        D: Deserializer<'de>,
255    {
256        Box::<RawJsonValue>::deserialize(deserializer).map(Self::from_json)
257    }
258}
259
260impl<T> Serialize for Raw<T> {
261    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
262    where
263        S: Serializer,
264    {
265        self.json.serialize(serializer)
266    }
267}
268
269/// Marker trait for restricting the types [`Raw::deserialize_as`], [`Raw::cast`] and
270/// [`Raw::cast_ref`] can be called with.
271///
272/// Implementing this trait for a type `U` means that it is safe to cast from `U` to `T` because `T`
273/// can be deserialized from the same JSON as `U`.
274pub trait JsonCastable<T> {}
275
276impl<T> JsonCastable<JsonValue> for T {}
277
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue};

    use super::Raw;

    #[test]
    fn get_field() -> serde_json::Result<()> {
        #[derive(Debug, PartialEq, Deserialize)]
        struct A<'a> {
            #[serde(borrow)]
            b: Vec<&'a str>,
        }

        const OBJ: &str = r#"{ "a": { "b": [  "c"] }, "z": 5 }"#;
        let raw: Raw<()> = from_json_str(OBJ)?;

        assert_eq!(raw.get_field::<u8>("z")?, Some(5));
        assert_eq!(raw.get_field::<&RawJsonValue>("a")?.unwrap().get(), r#"{ "b": [  "c"] }"#);
        assert_eq!(raw.get_field::<A<'_>>("a")?, Some(A { b: vec!["c"] }));

        assert_eq!(raw.get_field::<u8>("b")?, None);
        raw.get_field::<u8>("a").unwrap_err();

        Ok(())
    }

    #[test]
    fn get_field_missing_with_option_target() -> serde_json::Result<()> {
        let raw: Raw<()> = from_json_str(r#"{ "z": 5 }"#)?;

        // A missing field is `Ok(None)` regardless of the target type.
        assert_eq!(raw.get_field::<Option<u8>>("missing")?, None);

        Ok(())
    }

    #[test]
    fn get_field_matches_escaped_key() -> serde_json::Result<()> {
        // `\u0062` is `b`, so the key is `ab` once unescaped.
        let raw: Raw<()> = from_json_str(r#"{ "a\u0062": 1 }"#)?;

        assert_eq!(raw.get_field::<u8>("ab")?, Some(1));

        Ok(())
    }

    #[test]
    fn get_field_on_non_object_is_err() -> serde_json::Result<()> {
        let number: Raw<()> = from_json_str("5")?;
        number.get_field::<u8>("z").unwrap_err();

        let array: Raw<()> = from_json_str(r#"[{ "z": 5 }]"#)?;
        array.get_field::<u8>("z").unwrap_err();

        Ok(())
    }
}