Skip to main content

ruma_events/room/
encrypted_file_serde.rs

1use std::{borrow::Cow, collections::BTreeMap};
2
3use as_variant::as_variant;
4use ruma_common::serde::{Base64, JsonObject};
5use serde::{Deserialize, Deserializer, Serialize, Serializer, de, ser::SerializeMap};
6use serde_json::{Value as JsonValue, from_value as from_json_value};
7
8use super::{
9    CustomEncryptedFileHash, CustomEncryptedFileInfo, EncryptedFileHash,
10    EncryptedFileHashAlgorithm, EncryptedFileHashes, EncryptedFileInfo, V2EncryptedFileInfo,
11};
12
13impl<'de> Deserialize<'de> for EncryptedFileInfo {
14    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
15    where
16        D: Deserializer<'de>,
17    {
18        let mut data = JsonObject::deserialize(deserializer)?;
19
20        let v = data
21            .remove("v")
22            .and_then(|value| as_variant!(value, JsonValue::String))
23            .ok_or_else(|| de::Error::missing_field("v"))?;
24
25        match v.as_ref() {
26            "v2" => from_json_value(data.into()).map(Self::V2),
27            _ => Ok(Self::_Custom(CustomEncryptedFileInfo { v, data })),
28        }
29        .map_err(de::Error::custom)
30    }
31}
32
33impl<'de> Deserialize<'de> for V2EncryptedFileInfo {
34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35    where
36        D: Deserializer<'de>,
37    {
38        let V2EncryptedFileInfoSerdeHelper { key: JsonWebKey { kty, key_ops, alg, k, ext }, iv } =
39            V2EncryptedFileInfoSerdeHelper::deserialize(deserializer)?;
40
41        if kty != "oct" {
42            return Err(de::Error::custom(format!(
43                "invalid value in `kty` field: `{kty}` , expected `oct`"
44            )));
45        }
46
47        if alg != "A256CTR" {
48            return Err(de::Error::custom(format!(
49                "invalid value in `alg` field: `{alg}` , expected `A256CTR`"
50            )));
51        }
52
53        if !key_ops.iter().any(|key_op| key_op == "encrypt") {
54            return Err(de::Error::custom("missing value `encrypt` in `key_ops` field"));
55        }
56
57        if !key_ops.iter().any(|key_op| key_op == "decrypt") {
58            return Err(de::Error::custom("missing value `decrypt` in `key_ops` field"));
59        }
60
61        if !ext {
62            return Err(de::Error::custom(
63                "invalid value in `ext` field: `false` , expected `true`",
64            ));
65        }
66
67        let k = Base64::parse(k.as_ref())
68            .map_err(|error| de::Error::custom(format!("invalid value in `k` field: {error}")))?;
69        let iv = Base64::parse(iv.as_ref())
70            .map_err(|error| de::Error::custom(format!("invalid value in `iv` field: {error}")))?;
71
72        Ok(Self { k, iv })
73    }
74}
75
76impl Serialize for V2EncryptedFileInfo {
77    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
78    where
79        S: Serializer,
80    {
81        let Self { k, iv } = self;
82
83        let info = V2EncryptedFileInfoSerdeHelper {
84            key: JsonWebKey {
85                kty: Cow::Borrowed("oct"),
86                key_ops: vec![Cow::Borrowed("decrypt"), Cow::Borrowed("encrypt")],
87                alg: Cow::Borrowed("A256CTR"),
88                k: Cow::Owned(k.encode()),
89                ext: true,
90            },
91            iv: Cow::Owned(iv.encode()),
92        };
93
94        info.serialize(serializer)
95    }
96}
97
/// JSON shape of `V2EncryptedFileInfo`: a JSON Web Key plus the AES-CTR IV.
///
/// Both the `Deserialize` and the `Serialize` impl of `V2EncryptedFileInfo` in this file go
/// through this helper. It keeps the key material and the IV as encoded strings. Validation
/// of the JWK parameters and base64 decoding happen in `V2EncryptedFileInfo`'s `Deserialize`
/// impl, not here.
#[derive(Deserialize, Serialize)]
struct V2EncryptedFileInfoSerdeHelper<'a> {
    /// The key.
    // `borrow` lets the string fields borrow from the input when the deserializer can lend
    // them; otherwise they are owned.
    #[serde(borrow)]
    key: JsonWebKey<'a>,

    /// The 128-bit unique counter block used by AES-CTR, encoded as unpadded base64.
    #[serde(borrow)]
    iv: Cow<'a, str>,
}
108
/// A [JSON Web Key](https://tools.ietf.org/html/rfc7517#appendix-A.3) object.
///
/// This is a raw wire struct. Its fields are deserialized as-is, and the "Must …"
/// constraints documented below are enforced by `V2EncryptedFileInfo`'s `Deserialize` impl,
/// not by this type. Serialization in this file always fills in the required constant values.
#[derive(Deserialize, Serialize)]
struct JsonWebKey<'a> {
    /// Key type.
    ///
    /// Must be `oct`.
    #[serde(borrow)]
    kty: Cow<'a, str>,

    /// Key operations.
    ///
    /// Must at least contain `encrypt` and `decrypt`.
    #[serde(borrow)]
    key_ops: Vec<Cow<'a, str>>,

    /// Algorithm.
    ///
    /// Must be `A256CTR`.
    #[serde(borrow)]
    alg: Cow<'a, str>,

    /// The key, encoded as url-safe unpadded base64.
    ///
    /// Kept encoded here. Decoding happens when converting into `V2EncryptedFileInfo`.
    #[serde(borrow)]
    k: Cow<'a, str>,

    /// Extractable.
    ///
    /// Must be `true`. This is a
    /// [W3C extension](https://w3c.github.io/webcrypto/#iana-section-jwk).
    ext: bool,
}
140
141impl Serialize for EncryptedFileHashes {
142    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
143    where
144        S: Serializer,
145    {
146        let mut s = serializer.serialize_map(Some(self.len()))?;
147
148        for hash in self.values() {
149            match hash {
150                EncryptedFileHash::Sha256(hash) => {
151                    s.serialize_entry(&EncryptedFileHashAlgorithm::Sha256, hash)?;
152                }
153                EncryptedFileHash::_Custom(CustomEncryptedFileHash { algorithm, hash }) => {
154                    s.serialize_entry(algorithm, hash)?;
155                }
156            }
157        }
158
159        s.end()
160    }
161}
162
163impl<'de> Deserialize<'de> for EncryptedFileHashes {
164    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165    where
166        D: Deserializer<'de>,
167    {
168        let map_by_key = BTreeMap::<EncryptedFileHashAlgorithm, String>::deserialize(deserializer)?;
169
170        map_by_key
171            .into_iter()
172            .map(|(algorithm, hash)| {
173                Ok(match algorithm {
174                    EncryptedFileHashAlgorithm::Sha256 => {
175                        EncryptedFileHash::Sha256(Base64::parse(hash).map_err(de::Error::custom)?)
176                    }
177                    EncryptedFileHashAlgorithm::_Custom(s) => {
178                        EncryptedFileHash::_Custom(CustomEncryptedFileHash {
179                            algorithm: s.0.into(),
180                            hash: Base64::parse(hash).map_err(de::Error::custom)?,
181                        })
182                    }
183                })
184            })
185            .collect()
186    }
187}