ruma_common/serde/
base64.rs

1//! Transparent base64 encoding / decoding as part of (de)serialization.
2
3use std::{fmt, marker::PhantomData};
4
5use base64::{
6    engine::{general_purpose, DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
7    Engine,
8};
9use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
10use zeroize::Zeroize;
11
12/// A wrapper around `B` (usually `Vec<u8>`) that (de)serializes from / to a base64 string.
13///
14/// The base64 character set (and miscellaneous other encoding / decoding options) can be customized
15/// through the generic parameter `C`.
16#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
17pub struct Base64<C = Standard, B = Vec<u8>> {
18    bytes: B,
19    // Invariant PhantomData, Send + Sync
20    _phantom_conf: PhantomData<fn(C) -> C>,
21}
22
23impl<C, B> Zeroize for Base64<C, B>
24where
25    B: Zeroize,
26{
27    fn zeroize(&mut self) {
28        self.bytes.zeroize();
29    }
30}
31
32/// Config used for the [`Base64`] type.
33pub trait Base64Config {
34    /// The config as a constant.
35    ///
36    /// Opaque so our interface is not tied to the base64 crate version.
37    #[doc(hidden)]
38    const CONF: Conf;
39}
40
41#[doc(hidden)]
42pub struct Conf(base64::alphabet::Alphabet);
43
44/// Standard base64 character set without padding.
45///
46/// Allows trailing bits in decoding for maximum compatibility.
47#[non_exhaustive]
48// Easier than implementing these all for Base64 manually to avoid the `C: Trait` bounds.
49#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
50pub struct Standard;
51
52impl Base64Config for Standard {
53    const CONF: Conf = Conf(base64::alphabet::STANDARD);
54}
55
56/// Url-safe base64 character set without padding.
57///
58/// Allows trailing bits in decoding for maximum compatibility.
59#[non_exhaustive]
60// Easier than implementing these all for Base64 manually to avoid the `C: Trait` bounds.
61#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
62pub struct UrlSafe;
63
64impl Base64Config for UrlSafe {
65    const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
66}
67
68impl<C: Base64Config, B> Base64<C, B> {
69    const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
70        // See https://github.com/matrix-org/matrix-spec/issues/838
71        .with_decode_allow_trailing_bits(true)
72        .with_decode_padding_mode(DecodePaddingMode::Indifferent);
73    const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
74}
75
76impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
77    /// Create a `Base64` instance from raw bytes, to be base64-encoded in serialization.
78    pub fn new(bytes: B) -> Self {
79        Self { bytes, _phantom_conf: PhantomData }
80    }
81
82    /// Get a reference to the raw bytes held by this `Base64` instance.
83    pub fn as_bytes(&self) -> &[u8] {
84        self.bytes.as_ref()
85    }
86
87    /// Encode the bytes contained in this `Base64` instance to unpadded base64.
88    pub fn encode(&self) -> String {
89        Self::ENGINE.encode(self.as_bytes())
90    }
91}
92
93impl<C, B> Base64<C, B> {
94    /// Get the raw bytes held by this `Base64` instance.
95    pub fn into_inner(self) -> B {
96        self.bytes
97    }
98}
99
100impl<C: Base64Config> Base64<C> {
101    /// Create a `Base64` instance containing an empty `Vec<u8>`.
102    pub fn empty() -> Self {
103        Self::new(Vec::new())
104    }
105
106    /// Parse some base64-encoded data to create a `Base64` instance.
107    pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
108        Self::ENGINE.decode(encoded).map(Self::new).map_err(Base64DecodeError)
109    }
110}
111
112impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        self.encode().fmt(f)
115    }
116}
117
118impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        self.encode().fmt(f)
121    }
122}
123
124impl<'de, C: Base64Config> Deserialize<'de> for Base64<C> {
125    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126    where
127        D: Deserializer<'de>,
128    {
129        let encoded = super::deserialize_cow_str(deserializer)?;
130        Self::parse(&*encoded).map_err(de::Error::custom)
131    }
132}
133
134impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
135    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
136    where
137        S: Serializer,
138    {
139        serializer.serialize_str(&self.encode())
140    }
141}
142
143/// An error that occurred while decoding a base64 string.
144#[derive(Clone)]
145pub struct Base64DecodeError(base64::DecodeError);
146
147impl fmt::Debug for Base64DecodeError {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        self.0.fmt(f)
150    }
151}
152
153impl fmt::Display for Base64DecodeError {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        self.0.fmt(f)
156    }
157}
158
159impl std::error::Error for Base64DecodeError {}
160
#[cfg(test)]
mod tests {
    use super::{Base64, Standard, UrlSafe};

    #[test]
    fn slightly_malformed_base64() {
        const INPUT: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
            MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
        const INPUT_WITH_PADDING: &str = "im9+knCkMNQNh9o6sbdcZw==";

        Base64::<Standard>::parse(INPUT).unwrap();
        Base64::<Standard>::parse(INPUT_WITH_PADDING)
            .expect("We should be able to decode padded Base64");
    }

    /// Encoding must never emit `=` padding.
    #[test]
    fn encode_is_unpadded() {
        assert_eq!(Base64::<Standard>::new(b"foo".to_vec()).encode(), "Zm9v");
        assert_eq!(Base64::<Standard>::new(b"fo".to_vec()).encode(), "Zm8");
        assert_eq!(Base64::<Standard>::new(b"f".to_vec()).encode(), "Zg");
        assert_eq!(Base64::<Standard>::empty().encode(), "");
    }

    /// Decoding accepts the same data with or without padding.
    #[test]
    fn parse_accepts_padded_and_unpadded() {
        let unpadded = Base64::<Standard>::parse("Zm8").unwrap();
        let padded = Base64::<Standard>::parse("Zm8=").unwrap();
        assert_eq!(unpadded.as_bytes(), &b"fo"[..]);
        assert_eq!(padded.as_bytes(), &b"fo"[..]);
    }

    /// `Standard` and `UrlSafe` differ only in the characters used for values 62 and 63.
    #[test]
    fn alphabets_differ() {
        let bytes = [0xfb_u8, 0xff];
        assert_eq!(Base64::<Standard, _>::new(bytes).encode(), "+/8");
        assert_eq!(Base64::<UrlSafe, _>::new(bytes).encode(), "-_8");
        assert_eq!(Base64::<UrlSafe>::parse("-_8").unwrap().as_bytes(), &bytes[..]);
        assert!(Base64::<Standard>::parse("-_8").is_err());
    }

    #[test]
    fn rejects_invalid_characters() {
        assert!(Base64::<Standard>::parse("Zm9v!").is_err());
    }
}
175}