// Source file: ruma_events/room/encrypted_file_serde.rs

1use std::{borrow::Cow, collections::BTreeMap};
2
3use ruma_common::serde::Base64;
4use serde::{Deserialize, Deserializer, Serialize, Serializer, de, ser::SerializeMap};
5
6use super::{
7    CustomEncryptedFileHash, EncryptedFileHash, EncryptedFileHashAlgorithm, EncryptedFileHashes,
8    V2EncryptedFileInfo,
9};
10
11impl<'de> Deserialize<'de> for V2EncryptedFileInfo {
12    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
13    where
14        D: Deserializer<'de>,
15    {
16        let V2EncryptedFileInfoSerdeHelper { key: JsonWebKey { kty, key_ops, alg, k, ext }, iv } =
17            V2EncryptedFileInfoSerdeHelper::deserialize(deserializer)?;
18
19        if kty != "oct" {
20            return Err(de::Error::custom(format!(
21                "invalid value in `kty` field: `{kty}` , expected `oct`"
22            )));
23        }
24
25        if alg != "A256CTR" {
26            return Err(de::Error::custom(format!(
27                "invalid value in `alg` field: `{alg}` , expected `A256CTR`"
28            )));
29        }
30
31        if !key_ops.iter().any(|key_op| key_op == "encrypt") {
32            return Err(de::Error::custom("missing value `encrypt` in `key_ops` field"));
33        }
34
35        if !key_ops.iter().any(|key_op| key_op == "decrypt") {
36            return Err(de::Error::custom("missing value `decrypt` in `key_ops` field"));
37        }
38
39        if !ext {
40            return Err(de::Error::custom(
41                "invalid value in `ext` field: `false` , expected `true`",
42            ));
43        }
44
45        let k = Base64::parse(k.as_ref())
46            .map_err(|error| de::Error::custom(format!("invalid value in `k` field: {error}")))?;
47        let iv = Base64::parse(iv.as_ref())
48            .map_err(|error| de::Error::custom(format!("invalid value in `iv` field: {error}")))?;
49
50        Ok(Self { k, iv })
51    }
52}
53
54impl Serialize for V2EncryptedFileInfo {
55    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
56    where
57        S: Serializer,
58    {
59        let Self { k, iv } = self;
60
61        let info = V2EncryptedFileInfoSerdeHelper {
62            key: JsonWebKey {
63                kty: Cow::Borrowed("oct"),
64                key_ops: vec![Cow::Borrowed("decrypt"), Cow::Borrowed("encrypt")],
65                alg: Cow::Borrowed("A256CTR"),
66                k: Cow::Owned(k.encode()),
67                ext: true,
68            },
69            iv: Cow::Owned(iv.encode()),
70        };
71
72        info.serialize(serializer)
73    }
74}
75
76#[derive(Deserialize, Serialize)]
77struct V2EncryptedFileInfoSerdeHelper<'a> {
78    /// The key.
79    #[serde(borrow)]
80    key: JsonWebKey<'a>,
81
82    /// The 128-bit unique counter block used by AES-CTR, encoded as unpadded base64.
83    #[serde(borrow)]
84    iv: Cow<'a, str>,
85}
86
87/// A [JSON Web Key](https://tools.ietf.org/html/rfc7517#appendix-A.3) object.
88#[derive(Deserialize, Serialize)]
89struct JsonWebKey<'a> {
90    /// Key type.
91    ///
92    /// Must be `oct`.
93    #[serde(borrow)]
94    kty: Cow<'a, str>,
95
96    /// Key operations.
97    ///
98    /// Must at least contain `encrypt` and `decrypt`.
99    #[serde(borrow)]
100    key_ops: Vec<Cow<'a, str>>,
101
102    /// Algorithm.
103    ///
104    /// Must be `A256CTR`.
105    #[serde(borrow)]
106    alg: Cow<'a, str>,
107
108    /// The key, encoded as url-safe unpadded base64.
109    #[serde(borrow)]
110    k: Cow<'a, str>,
111
112    /// Extractable.
113    ///
114    /// Must be `true`. This is a
115    /// [W3C extension](https://w3c.github.io/webcrypto/#iana-section-jwk).
116    ext: bool,
117}
118
119impl Serialize for EncryptedFileHashes {
120    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
121    where
122        S: Serializer,
123    {
124        let mut s = serializer.serialize_map(Some(self.len()))?;
125
126        for hash in self.values() {
127            match hash {
128                EncryptedFileHash::Sha256(hash) => {
129                    s.serialize_entry(&EncryptedFileHashAlgorithm::Sha256, hash)?;
130                }
131                EncryptedFileHash::_Custom(CustomEncryptedFileHash { algorithm, hash }) => {
132                    s.serialize_entry(algorithm, hash)?;
133                }
134            }
135        }
136
137        s.end()
138    }
139}
140
141impl<'de> Deserialize<'de> for EncryptedFileHashes {
142    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
143    where
144        D: Deserializer<'de>,
145    {
146        let map_by_key = BTreeMap::<EncryptedFileHashAlgorithm, String>::deserialize(deserializer)?;
147
148        map_by_key
149            .into_iter()
150            .map(|(algorithm, hash)| {
151                Ok(match algorithm {
152                    EncryptedFileHashAlgorithm::Sha256 => {
153                        EncryptedFileHash::Sha256(Base64::parse(hash).map_err(de::Error::custom)?)
154                    }
155                    EncryptedFileHashAlgorithm::_Custom(s) => {
156                        EncryptedFileHash::_Custom(CustomEncryptedFileHash {
157                            algorithm: s.0.into(),
158                            hash: Base64::parse(hash).map_err(de::Error::custom)?,
159                        })
160                    }
161                })
162            })
163            .collect()
164    }
165}