ruma_common/serde/
base64.rs
1use std::{fmt, marker::PhantomData};
4
5use base64::{
6 engine::{general_purpose, DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
7 Engine,
8};
9use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
10
11#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
16pub struct Base64<C = Standard, B = Vec<u8>> {
17 bytes: B,
18 _phantom_conf: PhantomData<fn(C) -> C>,
20}
21
22pub trait Base64Config {
24 #[doc(hidden)]
28 const CONF: Conf;
29}
30
31#[doc(hidden)]
32pub struct Conf(base64::alphabet::Alphabet);
33
34#[non_exhaustive]
38#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
40pub struct Standard;
41
42impl Base64Config for Standard {
43 const CONF: Conf = Conf(base64::alphabet::STANDARD);
44}
45
46#[non_exhaustive]
50#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
52pub struct UrlSafe;
53
54impl Base64Config for UrlSafe {
55 const CONF: Conf = Conf(base64::alphabet::URL_SAFE);
56}
57
58impl<C: Base64Config, B> Base64<C, B> {
59 const CONFIG: GeneralPurposeConfig = general_purpose::NO_PAD
60 .with_decode_allow_trailing_bits(true)
62 .with_decode_padding_mode(DecodePaddingMode::Indifferent);
63 const ENGINE: GeneralPurpose = GeneralPurpose::new(&C::CONF.0, Self::CONFIG);
64}
65
66impl<C: Base64Config, B: AsRef<[u8]>> Base64<C, B> {
67 pub fn new(bytes: B) -> Self {
69 Self { bytes, _phantom_conf: PhantomData }
70 }
71
72 pub fn as_bytes(&self) -> &[u8] {
74 self.bytes.as_ref()
75 }
76
77 pub fn encode(&self) -> String {
79 Self::ENGINE.encode(self.as_bytes())
80 }
81}
82
83impl<C, B> Base64<C, B> {
84 pub fn into_inner(self) -> B {
86 self.bytes
87 }
88}
89
90impl<C: Base64Config> Base64<C> {
91 pub fn empty() -> Self {
93 Self::new(Vec::new())
94 }
95
96 pub fn parse(encoded: impl AsRef<[u8]>) -> Result<Self, Base64DecodeError> {
98 Self::ENGINE.decode(encoded).map(Self::new).map_err(Base64DecodeError)
99 }
100}
101
102impl<C: Base64Config, B: AsRef<[u8]>> fmt::Debug for Base64<C, B> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 self.encode().fmt(f)
105 }
106}
107
108impl<C: Base64Config, B: AsRef<[u8]>> fmt::Display for Base64<C, B> {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 self.encode().fmt(f)
111 }
112}
113
114impl<'de, C: Base64Config> Deserialize<'de> for Base64<C> {
115 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116 where
117 D: Deserializer<'de>,
118 {
119 let encoded = super::deserialize_cow_str(deserializer)?;
120 Self::parse(&*encoded).map_err(de::Error::custom)
121 }
122}
123
124impl<C: Base64Config, B: AsRef<[u8]>> Serialize for Base64<C, B> {
125 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126 where
127 S: Serializer,
128 {
129 serializer.serialize_str(&self.encode())
130 }
131}
132
133#[derive(Clone)]
135pub struct Base64DecodeError(base64::DecodeError);
136
137impl fmt::Debug for Base64DecodeError {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 self.0.fmt(f)
140 }
141}
142
143impl fmt::Display for Base64DecodeError {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 self.0.fmt(f)
146 }
147}
148
149impl std::error::Error for Base64DecodeError {}
150
#[cfg(test)]
mod tests {
    use super::{Base64, Standard, UrlSafe};

    /// Lenient decoding: non-zero trailing bits and explicit padding are both accepted.
    #[test]
    fn slightly_malformed_base64() {
        // 86 symbols whose final symbol carries non-zero trailing bits.
        const INPUT: &str = "3UmJnEIzUr2xWyaUnJg5fXwRybwG5FVC6Gq\
                             MHverEUn0ztuIsvVxX89JXX2pvdTsOBbLQx+4TVL02l4Cp5wPCm";
        const INPUT_WITH_PADDING: &str = "im9+knCkMNQNh9o6sbdcZw==";

        let decoded = Base64::<Standard>::parse(INPUT).unwrap();
        // 84 symbols -> 63 bytes, plus 2 trailing symbols -> 1 byte.
        assert_eq!(decoded.as_bytes().len(), 64);

        Base64::<Standard>::parse(INPUT_WITH_PADDING)
            .expect("We should be able to decode padded Base64");
    }

    /// Padded and unpadded forms decode to the same bytes, and encoding never emits padding.
    #[test]
    fn padded_input_reencodes_without_padding() {
        let padded = Base64::<Standard>::parse("im9+knCkMNQNh9o6sbdcZw==").unwrap();
        let unpadded = Base64::<Standard>::parse("im9+knCkMNQNh9o6sbdcZw").unwrap();

        assert_eq!(padded.as_bytes().len(), 16);
        assert_eq!(padded.as_bytes(), unpadded.as_bytes());
        assert_eq!(padded.encode(), "im9+knCkMNQNh9o6sbdcZw");
    }

    /// `UrlSafe` uses `-` where `Standard` uses `+`, and `Standard` rejects `-`.
    #[test]
    fn url_safe_alphabet() {
        let standard = Base64::<Standard>::parse("im9+knCkMNQNh9o6sbdcZw").unwrap();
        let url_safe = Base64::<UrlSafe>::parse("im9-knCkMNQNh9o6sbdcZw").unwrap();

        assert_eq!(standard.as_bytes(), url_safe.as_bytes());
        assert_eq!(url_safe.encode(), "im9-knCkMNQNh9o6sbdcZw");
        assert!(Base64::<Standard>::parse("im9-knCkMNQNh9o6sbdcZw").is_err());
    }

    /// The empty value encodes to the empty string, and the empty string decodes to no bytes.
    #[test]
    fn empty_round_trip() {
        assert_eq!(Base64::<Standard>::empty().encode(), "");
        assert!(Base64::<Standard>::parse("").unwrap().as_bytes().is_empty());
    }
}