// Source path: ruma_common/serde/base64.rs

use std::{fmt, marker::PhantomData};

use base64::{
    engine::{general_purpose, DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
    Engine,
};
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use zeroize::Zeroize;

12#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
17pub struct Base64<C = Standard, B = Vec<u8>> {
18 bytes: B,
19 _phantom_conf: PhantomData<fn(C) -> C>,
21}
22
23impl<C, B> Zeroize for Base64<C, B>
24where
25 B: Zeroize,
26{
27 fn zeroize(&mut self) {
28 self.bytes.zeroize();
29 }
30}
31
32pub trait Base64Config {
34 #[doc(hidden)]
38 const CONF: Conf;
39}
40
41#[doc(hidden)]
42pub struct Conf(base64::alphabet::Alphabet);
43
44#[non_exhaustive]
48#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
50pub struct Standard;
51
52impl Base64Config for Standard {
53 const CONF: Conf = Conf(base64::alphabet::STANDARD);
54}
55
56#[non_exhaustive]
60#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
62pub struct UrlSafe;
63
64impl Base64Config for UrlSafe {
65 const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
66}
67
68impl<C: Base64Config, B> Base64<C, B> {
69 const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
70 .with_decode_allow_trailing_bits(true)
72 .with_decode_padding_mode(DecodePaddingMode::Indifferent);
73 const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
74}
75
76impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
77 pub fn new(bytes: B) -> Self {
79 Self { bytes, _phantom_conf: PhantomData }
80 }
81
82 pub fn as_bytes(&self) -> &[u8] {
84 self.bytes.as_ref()
85 }
86
87 pub fn encode(&self) -> String {
89 Self::ENGINE.encode(self.as_bytes())
90 }
91}
92
93impl<C, B> Base64<C, B> {
94 pub fn into_inner(self) -> B {
96 self.bytes
97 }
98}
99
100impl<C: Base64Config> Base64<C> {
101 pub fn empty() -> Self {
103 Self::new(Vec::new())
104 }
105
106 pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
108 Self::ENGINE.decode(encoded).map(Self::new).map_err(Base64DecodeError)
109 }
110}
111
112impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 self.encode().fmt(f)
115 }
116}
117
118impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 self.encode().fmt(f)
121 }
122}
123
124impl<'de, C: Base64Config> Deserialize<'de> for Base64<C> {
125 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126 where
127 D: Deserializer<'de>,
128 {
129 let encoded = super::deserialize_cow_str(deserializer)?;
130 Self::parse(&*encoded).map_err(de::Error::custom)
131 }
132}
133
134impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
135 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
136 where
137 S: Serializer,
138 {
139 serializer.serialize_str(&self.encode())
140 }
141}
142
143#[derive(Clone)]
145pub struct Base64DecodeError(base64::DecodeError);
146
147impl fmt::Debug for Base64DecodeError {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 self.0.fmt(f)
150 }
151}
152
153impl fmt::Display for Base64DecodeError {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 self.0.fmt(f)
156 }
157}
158
159impl std::error::Error for Base64DecodeError {}
160
#[cfg(test)]
mod tests {
    use super::{Base64, Standard, UrlSafe};

    #[test]
    fn slightly_malformed_base64() {
        // Unpadded input whose final symbol carries non-zero trailing bits.
        const INPUT: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
            MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
        const INPUT_WITH_PADDING: &str = "im9+knCkMNQNh9o6sbdcZw==";

        Base64::<Standard>::parse(INPUT).unwrap();
        Base64::<Standard>::parse(INPUT_WITH_PADDING)
            .expect("We should be able to decode padded Base64");
    }

    #[test]
    fn encode_is_unpadded() {
        assert_eq!(Base64::<Standard>::new(b"hi".to_vec()).encode(), "aGk");
        assert_eq!(Base64::<Standard>::empty().encode(), "");
    }

    #[test]
    fn alphabets_differ() {
        let bytes = vec![0xfb, 0xff];
        assert_eq!(Base64::<Standard>::new(bytes.clone()).encode(), "+/8");
        assert_eq!(Base64::<UrlSafe>::new(bytes).encode(), "-_8");
    }

    #[test]
    fn parse_accepts_padded_and_unpadded() {
        assert_eq!(Base64::<Standard>::parse("aGk=").unwrap().as_bytes(), b"hi");
        assert_eq!(Base64::<Standard>::parse("aGk").unwrap().as_bytes(), b"hi");

        let url_safe = Base64::<UrlSafe>::parse("-_8").unwrap();
        assert_eq!(url_safe.as_bytes(), &[0xfb, 0xff]);
        assert_eq!(url_safe.to_string(), "-_8");
    }

    #[test]
    fn parse_rejects_invalid_input() {
        assert!(Base64::<Standard>::parse("!!!!").is_err());
        // URL-safe characters are not part of the standard alphabet.
        assert!(Base64::<Standard>::parse("-_8").is_err());
    }
}