Skip to main content

ruma_common/serde/
base64.rs

1//! Transparent base64 encoding / decoding as part of (de)serialization.
2
3use std::{fmt, marker::PhantomData};
4
5use base64::{
6    Engine,
7    engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig, general_purpose},
8};
9use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
10use zeroize::Zeroize;
11
12/// A wrapper around `B` (usually `Vec<u8>`) that (de)serializes from / to a base64 string.
13///
14/// The base64 character set (and miscellaneous other encoding / decoding options) can be customized
15/// through the generic parameter `C`.
16#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
17pub struct Base64<C = Standard, B = Vec<u8>> {
18    bytes: B,
19    // Invariant PhantomData, Send + Sync
20    _phantom_conf: PhantomData<fn(C) -> C>,
21}
22
23impl<C, B> Zeroize for Base64<C, B>
24where
25    B: Zeroize,
26{
27    fn zeroize(&mut self) {
28        self.bytes.zeroize();
29    }
30}
31
32/// Config used for the [`Base64`] type.
33pub trait Base64Config {
34    /// The config as a constant.
35    ///
36    /// Opaque so our interface is not tied to the base64 crate version.
37    #[doc(hidden)]
38    const CONF: Conf;
39}
40
41#[doc(hidden)]
42pub struct Conf(base64::alphabet::Alphabet);
43
/// Standard base64 character set.
///
/// Encodes without padding. Decoding accepts both padded and unpadded input and allows trailing
/// bits, for maximum compatibility.
#[non_exhaustive]
// `Clone`, `PartialEq`, `Eq`, `PartialOrd` and `Ord` satisfy the bounds that `#[derive]` puts on
// `C` for `Base64`, which is easier than implementing those traits for `Base64` manually.
// `Copy`, `Debug` and `Hash` are the common traits expected of a public marker type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Standard;
51
52impl Base64Config for Standard {
53    const CONF: Conf = Conf(base64::alphabet::STANDARD);
54}
55
/// Url-safe base64 character set.
///
/// Encodes without padding. Decoding accepts both padded and unpadded input and allows trailing
/// bits, for maximum compatibility.
#[non_exhaustive]
// `Clone`, `PartialEq`, `Eq`, `PartialOrd` and `Ord` satisfy the bounds that `#[derive]` puts on
// `C` for `Base64`, which is easier than implementing those traits for `Base64` manually.
// `Copy`, `Debug` and `Hash` are the common traits expected of a public marker type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UrlSafe;
63
64impl Base64Config for UrlSafe {
65    const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
66}
67
68impl<C: Base64Config, B> Base64<C, B> {
69    const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
70        // See https://github.com/matrix-org/matrix-spec/issues/838
71        .with_decode_allow_trailing_bits(true)
72        .with_decode_padding_mode(DecodePaddingMode::Indifferent);
73    const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
74}
75
76impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
77    /// Create a `Base64` instance from raw bytes, to be base64-encoded in serialization.
78    pub fn new(bytes: B) -> Self {
79        Self { bytes, _phantom_conf: PhantomData }
80    }
81
82    /// Get a reference to the raw bytes held by this `Base64` instance.
83    pub fn as_bytes(&self) -> &[u8] {
84        self.bytes.as_ref()
85    }
86
87    /// Encode the bytes contained in this `Base64` instance to unpadded base64.
88    pub fn encode(&self) -> String {
89        Self::ENGINE.encode(self.as_bytes())
90    }
91}
92
93impl<C, B> Base64<C, B> {
94    /// Get a reference to the raw bytes held by this `Base64` instance.
95    pub fn as_inner(&self) -> &B {
96        &self.bytes
97    }
98
99    /// Get the raw bytes held by this `Base64` instance.
100    pub fn into_inner(self) -> B {
101        self.bytes
102    }
103}
104
105impl<C: Base64Config> Base64<C> {
106    /// Create a `Base64` instance containing an empty `Vec<u8>`.
107    pub fn empty() -> Self {
108        Self::new(Vec::new())
109    }
110}
111
112impl<C: Base64Config, B: TryFromBase64DecodedBytes> Base64<C, B> {
113    /// Parse some base64-encoded data to create a `Base64` instance.
114    pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
115        let decoded = Self::ENGINE.decode(encoded).map_err(Base64DecodeError::base64)?;
116        B::try_from_bytes(decoded).map(Self::new)
117    }
118}
119
120impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        self.encode().fmt(f)
123    }
124}
125
126impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        self.encode().fmt(f)
129    }
130}
131
132impl<'de, C: Base64Config, B: TryFromBase64DecodedBytes> Deserialize<'de> for Base64<C, B> {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: Deserializer<'de>,
136    {
137        let encoded = super::deserialize_cow_str(deserializer)?;
138        Self::parse(&*encoded).map_err(de::Error::custom)
139    }
140}
141
142impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
143    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
144    where
145        S: Serializer,
146    {
147        serializer.serialize_str(&self.encode())
148    }
149}
150
151/// Marker trait for indicating which inner `B` "bytes" type can be converted from decoded base64
152/// bytes.
153///
154/// This is used as a bound in [`Base64::parse()`] to provide a more helpful error message than
155/// using a `TryFrom<Vec<u8>>` implementation.
156pub trait TryFromBase64DecodedBytes: Sized + AsRef<[u8]> {
157    /// Convert the given bytes to this type.
158    #[doc(hidden)]
159    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError>;
160}
161
162impl TryFromBase64DecodedBytes for Vec<u8> {
163    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError> {
164        Ok(bytes)
165    }
166}
167
168impl<const N: usize> TryFromBase64DecodedBytes for [u8; N] {
169    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError> {
170        Self::try_from(bytes)
171            .map_err(|bytes| Base64DecodeError::invalid_decoded_length(bytes.len(), N))
172    }
173}
174
175/// An error that occurred while decoding a base64 string.
176#[derive(Clone)]
177pub struct Base64DecodeError(Base64DecodeErrorInner);
178
179impl Base64DecodeError {
180    /// Construct a `Base64DecodeError` from an invalid base64 encoding error.
181    fn base64(error: base64::DecodeError) -> Self {
182        Self(Base64DecodeErrorInner::Base64(error))
183    }
184
185    /// Construct a `Base64DecodeError` from an invalid decoded bytes length error.
186    fn invalid_decoded_length(len: usize, expected: usize) -> Self {
187        Self(Base64DecodeErrorInner::InvalidDecodedLength { len, expected })
188    }
189}
190
191impl fmt::Debug for Base64DecodeError {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        self.0.fmt(f)
194    }
195}
196
197impl fmt::Display for Base64DecodeError {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        match &self.0 {
200            Base64DecodeErrorInner::Base64(error) => write!(f, "invalid base64 encoding: {error}"),
201            Base64DecodeErrorInner::InvalidDecodedLength { len, expected } => {
202                write!(f, "invalid decoded base64 bytes length: {len}, expected {expected}")
203            }
204        }
205    }
206}
207
208impl std::error::Error for Base64DecodeError {}
209
210/// An error that occurred while decoding a base64 string.
211#[derive(Debug, Clone)]
212enum Base64DecodeErrorInner {
213    /// The base64 encoding is invalid.
214    Base64(base64::DecodeError),
215
216    /// The decoded bytes have the wrong length to fit into an array of fixed length.
217    InvalidDecodedLength {
218        /// The length of the input.
219        len: usize,
220        /// The expected length.
221        expected: usize,
222    },
223}
224
#[cfg(test)]
mod tests {
    use super::{Base64, Standard};

    /// Unpadded input with trailing bits; decodes to 64 bytes.
    const UNPADDED: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
        MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
    /// Padded input; decodes to 16 bytes.
    const PADDED: &str = "im9+knCkMNQNh9o6sbdcZw==";

    #[test]
    fn parse_base64() {
        // A `Vec<u8>` accepts both unpadded and padded input.
        assert!(Base64::<Standard>::parse(UNPADDED).is_ok());
        assert!(
            Base64::<Standard>::parse(PADDED).is_ok(),
            "We should be able to decode padded Base64"
        );

        // A fixed-size array only accepts input that decodes to exactly its length.
        assert!(Base64::<Standard, [u8; 32]>::parse(UNPADDED).is_err());
        assert!(Base64::<Standard, [u8; 64]>::parse(UNPADDED).is_ok());
        assert!(Base64::<Standard, [u8; 32]>::parse(PADDED).is_err());
        assert!(Base64::<Standard, [u8; 16]>::parse(PADDED).is_ok());
    }
}