Skip to main content

pw_base64/
lib.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14#![cfg_attr(not(feature = "std"), no_std)]
15#![deny(missing_docs)]
16
//! `pw_base64` provides simple encoding of data into base64, as well as
//! validation and decoding of base64 data back into raw bytes.
18//!
19//! ```
20//! const INPUT: &'static [u8] = "I 💖 Pigweed".as_bytes();
21//!
22//! // [`encoded_size`] can be used to calculate the size of the output buffer.
23//! let mut output = [0u8; pw_base64::encoded_size(INPUT.len())];
24//!
25//! // Data can be encoded to a `&mut [u8]`.
26//! let output_size = pw_base64::encode(INPUT, &mut output).unwrap();
27//! assert_eq!(&output[0..output_size], b"SSDwn5KWIFBpZ3dlZWQ=");
28//!
29//! // The output buffer can also be automatically converted to a `&str`.
30//! let output_str = pw_base64::encode_str(INPUT, &mut output).unwrap();
31//! assert_eq!(output_str, "SSDwn5KWIFBpZ3dlZWQ=");
32//! ```
33
34use pw_status::{Error, Result};
35use pw_stream::{Cursor, ReadInteger, Seek, Write};
36
// Helper macro that keeps the base64 encode table declaration concise: it
// expands to an expression yielding the first byte of the stringified token
// it is given, so `b!(A)` produces the same value as `b'A'`.
macro_rules! b {
    ($symbol:tt) => {
        stringify!($symbol).as_bytes()[0]
    };
}
43
44// We use `u8`s in our encoding table instead of `char`s in order to avoid the
45// overhead of 1) storing each entry as 4 bytes and 2) overhead of converting
46// from `char` to `u8` while building the output.
47//
48// When constructing this table, the `b!` macro makes the assumption that
49// all the characters are a single byte in utf8.  This is true as base64
50// only outputs ASCII characters.
51#[rustfmt::skip]
52const BASE64_ENCODE_TABLE: [u8; 64] = [
53    b!(A), b!(B), b!(C), b!(D), b!(E), b!(F), b!(G), b!(H),
54    b!(I), b!(J), b!(K), b!(L), b!(M), b!(N), b!(O), b!(P),
55    b!(Q), b!(R), b!(S), b!(T), b!(U), b!(V), b!(W), b!(X),
56    b!(Y), b!(Z), b!(a), b!(b), b!(c), b!(d), b!(e), b!(f),
57    b!(g), b!(h), b!(i), b!(j), b!(k), b!(l), b!(m), b!(n),
58    b!(o), b!(p), b!(q), b!(r), b!(s), b!(t), b!(u), b!(v),
59    b!(w), b!(x), b!(y), b!(z), b!(0), b!(1), b!(2), b!(3),
60    b!(4), b!(5), b!(6), b!(7), b!(8), b!(9), b!(+), b!(/),
61];
62const BASE64_PADDING: u8 = b!(=);
63
// Bounds of the contiguous ASCII range covered by `BASE64_DECODE_TABLE`:
// `+` is the lowest and `z` the highest byte value that decodes to anything.
const MIN_VALID_CHAR: u8 = b'+';
const MAX_VALID_CHAR: u8 = b'z';
// Marker returned for bytes that do not encode a 6-bit value.  This includes
// the `=` padding character.
const INVALID_CHAR: u8 = 0xff;
// Short alias that keeps the columns of the decode table aligned.
const IVLD: u8 = INVALID_CHAR;

// Maps `c - MIN_VALID_CHAR` to the 6-bit value encoded by the byte `c`, or to
// `INVALID_CHAR`.  Values 62 and 63 are accepted in both their standard
// (`+`, `/`) and URL-safe (`-`, `_`) spellings.
//
// The length is derived from the range constants so the table cannot silently
// fall out of sync with them.
#[rustfmt::skip]
const BASE64_DECODE_TABLE: [u8; (MAX_VALID_CHAR - MIN_VALID_CHAR + 1) as usize] = [
    0x3e, IVLD, 0x3e, IVLD, 0x3f, 0x34, 0x35, 0x36, 0x37, 0x38,
    0x39, 0x3a, 0x3b, 0x3c, 0x3d, IVLD, IVLD, IVLD, IVLD, IVLD,
    IVLD, IVLD, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
    0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, IVLD, IVLD,
    IVLD, IVLD, 0x3f, IVLD, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
    0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29,
    0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33,
];

// Returns the 6-bit value encoded by the byte `c`, or `INVALID_CHAR` when `c`
// is not a base64 data character.
//
// Bytes outside `MIN_VALID_CHAR..=MAX_VALID_CHAR` are rejected here rather
// than being used to index the table, so this helper is total over `u8` and
// never panics regardless of what the caller has (or has not) validated.
const fn char_to_bits(c: u8) -> u8 {
    match c {
        MIN_VALID_CHAR..=MAX_VALID_CHAR => BASE64_DECODE_TABLE[(c - MIN_VALID_CHAR) as usize],
        _ => INVALID_CHAR,
    }
}
85
// Rebuilds the first byte of a decoded group: the low six bits of `char0`
// followed by bits 5 and 4 of `char1`.
const fn bits_0_1(char0: u8, char1: u8) -> u8 {
    let high = char0 << 2;
    let low = (char1 >> 4) & 0b11;
    high | low
}
89
// Rebuilds the second byte of a decoded group: the low four bits of `char1`
// followed by bits 5 through 2 of `char2`.
const fn bits_1_2(char1: u8, char2: u8) -> u8 {
    let high = (char1 & 0b1111) << 4;
    let low = (char2 >> 2) & 0b1111;
    high | low
}
93
// Rebuilds the third byte of a decoded group: the low two bits of `char2`
// followed by `char3`, which is OR-ed in without masking.
const fn bits_2_3(char2: u8, char3: u8) -> u8 {
    let high = (char2 & 0b11) << 6;
    high | char3
}
97
/// Returns the size of the output buffer needed to encode an input buffer of
/// size `input_size`.
///
/// Every group of up to three input bytes produces four output bytes, so a
/// partial final group still counts as a whole group.
pub const fn encoded_size(input_size: usize) -> usize {
    let full_groups = input_size / 3;
    let groups = if input_size % 3 == 0 {
        full_groups
    } else {
        // A trailing partial group is padded out to a full one.
        full_groups + 1
    };
    groups * 4
}
103
/// Returns the maximum size of the decoded output for a given encoded size.
///
/// Returns `0` when `encoded_size` is not a multiple of four. Otherwise each
/// four-character group accounts for three output bytes.
pub const fn max_decoded_size(encoded_size: usize) -> usize {
    if !encoded_size.is_multiple_of(4) {
        return 0;
    }
    let groups = encoded_size / 4;
    groups * 3
}
112
113/// Returns the exact size of the decoded output for a valid Base64 string.
114pub fn decoded_size(encoded: &[u8]) -> usize {
115    if !encoded.len().is_multiple_of(4) || encoded.is_empty() {
116        return 0;
117    }
118    let max_bytes = max_decoded_size(encoded.len());
119    let mut padding = 0;
120    if encoded[encoded.len() - 2] == BASE64_PADDING {
121        padding = 2;
122    } else if encoded[encoded.len() - 1] == BASE64_PADDING {
123        padding = 1;
124    }
125    max_bytes - padding
126}
127
128/// Returns true if the provided character is a valid Base64 character.
129pub fn is_valid_char(c: char) -> bool {
130    if !c.is_ascii() {
131        return false;
132    }
133    let val = c as u8;
134    (MIN_VALID_CHAR..=MAX_VALID_CHAR).contains(&val) && char_to_bits(val) != INVALID_CHAR
135}
136
137/// Returns true if the provided data is valid Base64.
138pub fn is_valid(encoded: &[u8]) -> bool {
139    if encoded.is_empty() {
140        return true;
141    }
142    if !encoded.len().is_multiple_of(4) {
143        return false;
144    }
145    // Check all characters except the last two.
146    for &byte in encoded.iter().take(encoded.len() - 2) {
147        if !is_valid_char(byte as char) {
148            return false;
149        }
150    }
151    // Check the last two characters.
152    let penultimate = encoded[encoded.len() - 2];
153    let last = encoded[encoded.len() - 1];
154
155    if penultimate == BASE64_PADDING {
156        last == BASE64_PADDING
157    } else if is_valid_char(penultimate as char) {
158        is_valid_char(last as char) || last == BASE64_PADDING
159    } else {
160        false
161    }
162}
163
164/// Decodes the provided Base64 data into raw binary.
165///
166/// Returns the number of bytes written to `output` on success, or an error.
167pub fn decode(encoded: &[u8], output: &mut [u8]) -> Result<usize> {
168    if encoded.is_empty() {
169        return Ok(0);
170    }
171    if !is_valid(encoded) {
172        return Err(Error::InvalidArgument);
173    }
174    if output.len() < max_decoded_size(encoded.len()) {
175        return Err(Error::OutOfRange);
176    }
177
178    let mut binary_len = 0;
179    let mut ch = 0;
180    while ch < encoded.len() - 4 {
181        let char0 = char_to_bits(encoded[ch]);
182        let char1 = char_to_bits(encoded[ch + 1]);
183        let char2 = char_to_bits(encoded[ch + 2]);
184        let char3 = char_to_bits(encoded[ch + 3]);
185
186        output[binary_len] = bits_0_1(char0, char1);
187        output[binary_len + 1] = bits_1_2(char1, char2);
188        output[binary_len + 2] = bits_2_3(char2, char3);
189
190        binary_len += 3;
191        ch += 4;
192    }
193
194    // Decode the final group, which may include padding.
195    let char0 = char_to_bits(encoded[ch]);
196    let char1 = char_to_bits(encoded[ch + 1]);
197    let char2 = char_to_bits(encoded[ch + 2]);
198    let char3 = char_to_bits(encoded[ch + 3]);
199
200    output[binary_len] = bits_0_1(char0, char1);
201    binary_len += 1;
202
203    if encoded[ch + 2] != BASE64_PADDING {
204        output[binary_len] = bits_1_2(char1, char2);
205        binary_len += 1;
206        if encoded[ch + 3] != BASE64_PADDING {
207            output[binary_len] = bits_2_3(char2, char3);
208            binary_len += 1;
209        }
210    }
211
212    Ok(binary_len)
213}
214
// Base 64 encoding represents every 3 bytes with 4 ASCII characters.  Each
// of these 4 ASCII characters represents 6 bits of data from the 3 bytes of
// input.  The helpers below calculate each of the 4 characters from the 3
// bytes of input.
219const fn char_0(b: &[u8; 3]) -> u8 {
220    BASE64_ENCODE_TABLE[((b[0] & 0b11111100) >> 2) as usize]
221}
222
223const fn char_1(b: &[u8; 3]) -> u8 {
224    BASE64_ENCODE_TABLE[(((b[0] & 0b00000011) << 4) | ((b[1] & 0b11110000) >> 4)) as usize]
225}
226
227const fn char_2(b: &[u8; 3]) -> u8 {
228    BASE64_ENCODE_TABLE[(((b[1] & 0b00001111) << 2) | ((b[2] & 0b11000000) >> 6)) as usize]
229}
230
231const fn char_3(b: &[u8; 3]) -> u8 {
232    BASE64_ENCODE_TABLE[(b[2] & 0b00111111) as usize]
233}
234
235/// Encode `input` as base64 into the `output_buffer`.
236///
237/// Returns the number of bytes written to `output_buffer` on success or
238/// `Error::OutOfRange` if `output_buffer` is not large enough.
239pub fn encode(input: &[u8], output: &mut [u8]) -> Result<usize> {
240    if output.len() < encoded_size(input.len()) {
241        return Err(Error::OutOfRange);
242    }
243    let mut input = Cursor::new(input);
244    let mut output = Cursor::new(output);
245
246    let mut remaining_bytes = input.len();
247    while remaining_bytes > 0 {
248        let bytes = [
249            input.read_u8_le().unwrap_or(0),
250            input.read_u8_le().unwrap_or(0),
251            input.read_u8_le().unwrap_or(0),
252        ];
253
254        output.write(&[
255            char_0(&bytes),
256            char_1(&bytes),
257            if remaining_bytes > 1 {
258                char_2(&bytes)
259            } else {
260                BASE64_PADDING
261            },
262            if remaining_bytes > 2 {
263                char_3(&bytes)
264            } else {
265                BASE64_PADDING
266            },
267        ])?;
268        remaining_bytes = remaining_bytes.saturating_add_signed(-3);
269    }
270
271    output.stream_position().map(|len| len as usize)
272}
273
274/// Encode `input` as base64 into `output_buffer` and interprets it as a
275/// string.
276///
277/// Returns a `&str` referencing the `output_buffer` buffer on success or
278/// `Error::OutOfRange` if `output_buffer` is not large enough.
279///
280/// Using this method avoids having to do unicode checking as it can guarantee
281/// that the data written to `output_buffer` is only valid ASCII bytes.
282pub fn encode_str<'a>(input: &[u8], output_buffer: &'a mut [u8]) -> Result<&'a str> {
283    let encode_len = encode(input, output_buffer)?;
284    // Safety: Since we are building the output buffer strictly from ASCII
285    // characters, it is guaranteed to be valid UTF-8.
286    // encode_len has already been checked to be less than output_buffer
287    // in the encode() call.
288    unsafe {
289        Ok(core::str::from_utf8_unchecked(
290            output_buffer.get(0..encode_len).unwrap_unchecked(),
291        ))
292    }
293}
294
295#[cfg(test)]
296mod tests;