ruma_common/serde/
base64.rs1use std::{fmt, marker::PhantomData};
4
5use base64::{
6 Engine,
7 engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig, general_purpose},
8};
9use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
10use zeroize::Zeroize;
11
12#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
17pub struct Base64<C = Standard, B = Vec<u8>> {
18 bytes: B,
19 _phantom_conf: PhantomData<fn(C) -> C>,
21}
22
23impl<C, B> Zeroize for Base64<C, B>
24where
25 B: Zeroize,
26{
27 fn zeroize(&mut self) {
28 self.bytes.zeroize();
29 }
30}
31
/// A configuration for [`Base64`], selecting the alphabet used to encode and decode.
pub trait Base64Config {
    /// The alphabet for this configuration, wrapped in the opaque `Conf` type.
    ///
    /// Hidden from the documentation (`#[doc(hidden)]`); treat it as an implementation detail.
    #[doc(hidden)]
    const CONF: Conf;
}

/// Opaque wrapper around a `base64` crate alphabet, used as the type of `Base64Config::CONF`.
#[doc(hidden)]
pub struct Conf(base64::alphabet::Alphabet);
43
/// Standard base64 character set, encoded without padding.
///
/// Decoding accepts both padded and unpadded input and tolerates non-zero trailing bits, for
/// maximum compatibility.
#[non_exhaustive]
// `Base64` derives `Clone`, `PartialEq`, `Eq`, `PartialOrd` and `Ord`, which adds the matching
// `C: Trait` bounds, so configuration types must implement them too. `Debug` is provided so this
// public type can be inspected like any other; `Base64`'s own `Debug` impl does not require it.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Standard;
51
// `Standard` uses the `base64` crate's `STANDARD` alphabet.
impl Base64Config for Standard {
    const CONF: Conf = Conf(base64::alphabet::STANDARD);
}
55
/// URL-safe base64 character set, encoded without padding.
///
/// Decoding accepts both padded and unpadded input and tolerates non-zero trailing bits, for
/// maximum compatibility.
#[non_exhaustive]
// `Base64` derives `Clone`, `PartialEq`, `Eq`, `PartialOrd` and `Ord`, which adds the matching
// `C: Trait` bounds, so configuration types must implement them too. `Debug` is provided so this
// public type can be inspected like any other; `Base64`'s own `Debug` impl does not require it.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct UrlSafe;
63
// `UrlSafe` uses the `base64` crate's `URL_SAFE` alphabet.
impl Base64Config for UrlSafe {
    const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
}
67
68impl<C: Base64Config, B> Base64<C, B> {
69 const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
70 .with_decode_allow_trailing_bits(true)
72 .with_decode_padding_mode(DecodePaddingMode::Indifferent);
73 const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
74}
75
76impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
77 pub fn new(bytes: B) -> Self {
79 Self { bytes, _phantom_conf: PhantomData }
80 }
81
82 pub fn as_bytes(&self) -> &[u8] {
84 self.bytes.as_ref()
85 }
86
87 pub fn encode(&self) -> String {
89 Self::ENGINE.encode(self.as_bytes())
90 }
91}
92
93impl<C, B> Base64<C, B> {
94 pub fn as_inner(&self) -> &B {
96 &self.bytes
97 }
98
99 pub fn into_inner(self) -> B {
101 self.bytes
102 }
103}
104
105impl<C: Base64Config> Base64<C> {
106 pub fn empty() -> Self {
108 Self::new(Vec::new())
109 }
110}
111
112impl<C: Base64Config, B: TryFromBase64DecodedBytes> Base64<C, B> {
113 pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
115 let decoded = Self::ENGINE.decode(encoded).map_err(Base64DecodeError::base64)?;
116 B::try_from_bytes(decoded).map(Self::new)
117 }
118}
119
120impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 self.encode().fmt(f)
123 }
124}
125
126impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 self.encode().fmt(f)
129 }
130}
131
132impl<'de, C: Base64Config, B: TryFromBase64DecodedBytes> Deserialize<'de> for Base64<C, B> {
133 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134 where
135 D: Deserializer<'de>,
136 {
137 let encoded = super::deserialize_cow_str(deserializer)?;
138 Self::parse(&*encoded).map_err(de::Error::custom)
139 }
140}
141
142impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
143 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
144 where
145 S: Serializer,
146 {
147 serializer.serialize_str(&self.encode())
148 }
149}
150
/// Types that can be built from the bytes produced by decoding base64, such as `Vec<u8>` or
/// `[u8; N]`.
pub trait TryFromBase64DecodedBytes: Sized + AsRef<[u8]> {
    /// Converts the freshly decoded bytes into `Self`, failing if they are unsuitable
    /// (e.g. the wrong length for a fixed-size type).
    #[doc(hidden)]
    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError>;
}

impl TryFromBase64DecodedBytes for Vec<u8> {
    // The decoded buffer already has the target type, so it is passed through unchanged;
    // this conversion never fails.
    fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError> {
        Ok(bytes)
    }
}
167
168impl<const N: usize> TryFromBase64DecodedBytes for [u8; N] {
169 fn try_from_bytes(bytes: Vec<u8>) -> Result<Self, Base64DecodeError> {
170 Self::try_from(bytes)
171 .map_err(|bytes| Base64DecodeError::invalid_decoded_length(bytes.len(), N))
172 }
173}
174
175#[derive(Clone)]
177pub struct Base64DecodeError(Base64DecodeErrorInner);
178
179impl Base64DecodeError {
180 fn base64(error: base64::DecodeError) -> Self {
182 Self(Base64DecodeErrorInner::Base64(error))
183 }
184
185 fn invalid_decoded_length(len: usize, expected: usize) -> Self {
187 Self(Base64DecodeErrorInner::InvalidDecodedLength { len, expected })
188 }
189}
190
191impl fmt::Debug for Base64DecodeError {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 self.0.fmt(f)
194 }
195}
196
197impl fmt::Display for Base64DecodeError {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 match &self.0 {
200 Base64DecodeErrorInner::Base64(error) => write!(f, "invalid base64 encoding: {error}"),
201 Base64DecodeErrorInner::InvalidDecodedLength { len, expected } => {
202 write!(f, "invalid decoded base64 bytes length: {len}, expected {expected}")
203 }
204 }
205 }
206}
207
208impl std::error::Error for Base64DecodeError {}
209
210#[derive(Debug, Clone)]
212enum Base64DecodeErrorInner {
213 Base64(base64::DecodeError),
215
216 InvalidDecodedLength {
218 len: usize,
220 expected: usize,
222 },
223}
224
#[cfg(test)]
mod tests {
    use super::{Base64, Standard};

    #[test]
    fn parse_base64() {
        // 86 characters without padding; decodes to 64 bytes.
        const INPUT: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
            MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
        // 24 characters including `==` padding; decodes to 16 bytes.
        const INPUT_WITH_PADDING: &str = "im9+knCkMNQNh9o6sbdcZw==";

        // Unpadded input: a `Vec` accepts any length, arrays only the exact one.
        Base64::<Standard>::parse(INPUT).unwrap();
        Base64::<Standard, [u8; 64]>::parse(INPUT).unwrap();
        Base64::<Standard, [u8; 32]>::parse(INPUT).unwrap_err();

        // Padded input is accepted too, with the same length rules.
        Base64::<Standard>::parse(INPUT_WITH_PADDING)
            .expect("We should be able to decode padded Base64");
        Base64::<Standard, [u8; 16]>::parse(INPUT_WITH_PADDING).unwrap();
        Base64::<Standard, [u8; 32]>::parse(INPUT_WITH_PADDING).unwrap_err();
    }
}