ruma_common/serde/
base64.rs

1//! Transparent base64 encoding / decoding as part of (de)serialization.
2
3use std::{fmt, marker::PhantomData};
4
5use base64::{
6    engine::{general_purpose, DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
7    Engine,
8};
9use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
10
11/// A wrapper around `B` (usually `Vec<u8>`) that (de)serializes from / to a base64 string.
12///
13/// The base64 character set (and miscellaneous other encoding / decoding options) can be customized
14/// through the generic parameter `C`.
15#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
16pub struct Base64<C = Standard, B = Vec<u8>> {
17    bytes: B,
18    // Invariant PhantomData, Send + Sync
19    _phantom_conf: PhantomData<fn(C) -> C>,
20}
21
22/// Config used for the [`Base64`] type.
23pub trait Base64Config {
24    /// The config as a constant.
25    ///
26    /// Opaque so our interface is not tied to the base64 crate version.
27    #[doc(hidden)]
28    const CONF: Conf;
29}
30
31#[doc(hidden)]
32pub struct Conf(base64::alphabet::Alphabet);
33
34/// Standard base64 character set without padding.
35///
36/// Allows trailing bits in decoding for maximum compatibility.
37#[non_exhaustive]
38// Easier than implementing these all for Base64 manually to avoid the `C: Trait` bounds.
39#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
40pub struct Standard;
41
42impl Base64Config for Standard {
43    const CONF: Conf = Conf(base64::alphabet::STANDARD);
44}
45
46/// Url-safe base64 character set without padding.
47///
48/// Allows trailing bits in decoding for maximum compatibility.
49#[non_exhaustive]
50// Easier than implementing these all for Base64 manually to avoid the `C: Trait` bounds.
51#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
52pub struct UrlSafe;
53
54impl Base64Config for UrlSafe {
55    const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
56}
57
58impl<C: Base64Config, B> Base64<C, B> {
59    const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
60        // See https://github.com/matrix-org/matrix-spec/issues/838
61        .with_decode_allow_trailing_bits(true)
62        .with_decode_padding_mode(DecodePaddingMode::Indifferent);
63    const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
64}
65
66impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
67    /// Create a `Base64` instance from raw bytes, to be base64-encoded in serialization.
68    pub fn new(bytes: B) -> Self {
69        Self { bytes, _phantom_conf: PhantomData }
70    }
71
72    /// Get a reference to the raw bytes held by this `Base64` instance.
73    pub fn as_bytes(&self) -> &[u8] {
74        self.bytes.as_ref()
75    }
76
77    /// Encode the bytes contained in this `Base64` instance to unpadded base64.
78    pub fn encode(&self) -> String {
79        Self::ENGINE.encode(self.as_bytes())
80    }
81}
82
83impl<C, B> Base64<C, B> {
84    /// Get the raw bytes held by this `Base64` instance.
85    pub fn into_inner(self) -> B {
86        self.bytes
87    }
88}
89
90impl<C: Base64Config> Base64<C> {
91    /// Create a `Base64` instance containing an empty `Vec<u8>`.
92    pub fn empty() -> Self {
93        Self::new(Vec::new())
94    }
95
96    /// Parse some base64-encoded data to create a `Base64` instance.
97    pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
98        Self::ENGINE.decode(encoded).map(Self::new).map_err(Base64DecodeError)
99    }
100}
101
102impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        self.encode().fmt(f)
105    }
106}
107
108impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        self.encode().fmt(f)
111    }
112}
113
114impl<'de, C: Base64Config> Deserialize<'de> for Base64<C> {
115    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116    where
117        D: Deserializer<'de>,
118    {
119        let encoded = super::deserialize_cow_str(deserializer)?;
120        Self::parse(&*encoded).map_err(de::Error::custom)
121    }
122}
123
124impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
125    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126    where
127        S: Serializer,
128    {
129        serializer.serialize_str(&self.encode())
130    }
131}
132
133/// An error that occurred while decoding a base64 string.
134#[derive(Clone)]
135pub struct Base64DecodeError(base64::DecodeError);
136
137impl fmt::Debug for Base64DecodeError {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        self.0.fmt(f)
140    }
141}
142
143impl fmt::Display for Base64DecodeError {
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        self.0.fmt(f)
146    }
147}
148
149impl std::error::Error for Base64DecodeError {}
150
#[cfg(test)]
mod tests {
    use super::{Base64, Standard, UrlSafe};

    #[test]
    fn slightly_malformed_base64() {
        // The final character has non-zero trailing bits, which must be tolerated.
        const INPUT: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
            MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
        const INPUT_WITH_PADDING: &str = "im9+knCkMNQNh9o6sbdcZw==";

        // 86 unpadded characters decode to 64 bytes.
        let decoded = Base64::<Standard>::parse(INPUT).unwrap();
        assert_eq!(decoded.as_bytes().len(), 64);

        let padded = Base64::<Standard>::parse(INPUT_WITH_PADDING)
            .expect("We should be able to decode padded Base64");
        assert_eq!(padded.as_bytes().len(), 16);
        // Encoding never emits padding.
        assert_eq!(padded.encode(), "im9+knCkMNQNh9o6sbdcZw");
    }

    #[test]
    fn unpadded_input() {
        let decoded = Base64::<Standard>::parse("im9+knCkMNQNh9o6sbdcZw").unwrap();
        assert_eq!(decoded.as_bytes().len(), 16);
        assert_eq!(decoded.encode(), "im9+knCkMNQNh9o6sbdcZw");
    }

    #[test]
    fn alphabets_differ() {
        // 0xfb 0xff maps to the 6-bit indices 62, 63 and 60.
        assert_eq!(Base64::<Standard>::new(vec![0xfb, 0xff]).encode(), "+/8");
        assert_eq!(Base64::<UrlSafe>::new(vec![0xfb, 0xff]).encode(), "-_8");

        let url_safe = Base64::<UrlSafe>::parse("-_8").unwrap();
        assert_eq!(url_safe.as_bytes(), &[0xfb_u8, 0xff][..]);
        let standard = Base64::<Standard>::parse("+/8").unwrap();
        assert_eq!(standard.as_bytes(), &[0xfb_u8, 0xff][..]);

        // Each alphabet rejects the other's special characters.
        assert!(Base64::<Standard>::parse("-_8").is_err());
        assert!(Base64::<UrlSafe>::parse("+/8").is_err());
    }

    #[test]
    fn empty() {
        assert_eq!(Base64::<Standard>::empty().encode(), "");
        assert!(Base64::<Standard>::parse("").unwrap().as_bytes().is_empty());
    }
}
165}