Skip to main content

pw_base64/
lib.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14#![cfg_attr(not(feature = "std"), no_std)]
15#![deny(missing_docs)]
16
//! `pw_base64` provides simple encoding of data into base64, along with decoding and validation of base64 data.
18//!
19//! ```
20//! const INPUT: &'static [u8] = "I 💖 Pigweed".as_bytes();
21//!
22//! // [`encoded_size`] can be used to calculate the size of the output buffer.
23//! let mut output = [0u8; pw_base64::encoded_size(INPUT.len())];
24//!
25//! // Data can be encoded to a `&mut [u8]`.
26//! let output_size = pw_base64::encode(INPUT, &mut output).unwrap();
27//! assert_eq!(&output[0..output_size], b"SSDwn5KWIFBpZ3dlZWQ=");
28//!
29//! // The output buffer can also be automatically converted to a `&str`.
30//! let output_str = pw_base64::encode_str(INPUT, &mut output).unwrap();
31//! assert_eq!(output_str, "SSDwn5KWIFBpZ3dlZWQ=");
32//! ```
33
34use pw_status::{Error, Result};
35use pw_stream::{Cursor, ReadInteger, Seek, Write};
36
// Expands to the first byte of the source text of a single token, so `b!(A)`
// is `b'A'` and `b!(=)` is `b'='`. Only meaningful for single-character ASCII
// tokens; it keeps the base64 tables readable.
macro_rules! b {
    ($tok:tt) => {{
        let text: &[u8] = stringify!($tok).as_bytes();
        text[0]
    }};
}
43
// Maps a 6-bit value (0..=63) to its base64 character.
//
// The table holds `u8`s rather than `char`s. This avoids storing 4 bytes per
// entry and avoids converting `char` -> `u8` for every output byte. A
// byte-string literal is sufficient because the base64 alphabet is pure ASCII.
const BASE64_ENCODE_TABLE: [u8; 64] =
    *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
// Byte used to pad the final encoded group out to four characters.
const BASE64_PADDING: u8 = b'=';
63
// Lowest byte that has an entry in `BASE64_DECODE_TABLE`.
const MIN_VALID_CHAR: u8 = b'+';
// Highest byte that has an entry in `BASE64_DECODE_TABLE`.
const MAX_VALID_CHAR: u8 = b'z';
// Decode-table sentinel for bytes that are not part of the base64 alphabet.
const INVALID_CHAR: u8 = u8::MAX;
// Short alias for `INVALID_CHAR` that keeps the decode table's columns aligned.
const IVLD: u8 = INVALID_CHAR;
68
69#[rustfmt::skip]
70const BASE64_DECODE_TABLE: [u8; 80] = [
71    0x3e, IVLD, 0x3e, IVLD, 0x3f, 0x34, 0x35, 0x36, 0x37, 0x38,
72    0x39, 0x3a, 0x3b, 0x3c, 0x3d, IVLD, IVLD, IVLD, IVLD, IVLD,
73    IVLD, IVLD, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
74    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
75    0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, IVLD, IVLD,
76    IVLD, IVLD, 0x3f, IVLD, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
77    0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29,
78    0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33,
79
80];
81
82const fn char_to_bits(c: u8) -> u8 {
83    BASE64_DECODE_TABLE[(c - MIN_VALID_CHAR) as usize]
84}
85
// First decoded byte of a group: `char0` shifted into the top six bits, then
// the top two bits of the 6-bit value `char1`.
const fn bits_0_1(char0: u8, char1: u8) -> u8 {
    let high = char0 << 2;
    let low = (char1 & 0x30) >> 4;
    high | low
}
89
// Second decoded byte of a group: the low four bits of `char1` in the high
// nibble, then the top four bits of the 6-bit value `char2`.
const fn bits_1_2(char1: u8, char2: u8) -> u8 {
    let high = (char1 & 0x0f) << 4;
    let low = (char2 & 0x3c) >> 2;
    high | low
}
93
// Third decoded byte of a group: the low two bits of `char2` in the top two
// bits, OR'd with the 6-bit value `char3`.
const fn bits_2_3(char2: u8, char3: u8) -> u8 {
    let high = (char2 & 0x03) << 6;
    high | char3
}
97
/// Returns the number of base64 characters produced by encoding `input_size`
/// bytes, i.e. the minimum output buffer size for [`encode`].
///
/// Every started group of three input bytes becomes four output characters,
/// so the result is always a multiple of four.
#[must_use]
pub const fn encoded_size(input_size: usize) -> usize {
    let groups = input_size.div_ceil(3);
    4 * groups
}
104
/// Returns the largest number of bytes that decoding `encoded_size` base64
/// characters can produce (three bytes per four-character group).
///
/// Returns 0 when `encoded_size` is not a multiple of four, since such input is
/// never valid base64.
#[must_use]
pub const fn max_decoded_size(encoded_size: usize) -> usize {
    if !encoded_size.is_multiple_of(4) {
        return 0;
    }
    3 * (encoded_size / 4)
}
114
115/// Returns the exact size of the decoded output for a valid Base64 string.
116#[must_use]
117pub fn decoded_size(encoded: &[u8]) -> usize {
118    if !encoded.len().is_multiple_of(4) || encoded.is_empty() {
119        return 0;
120    }
121    let max_bytes = max_decoded_size(encoded.len());
122    let mut padding = 0;
123    if encoded[encoded.len() - 2] == BASE64_PADDING {
124        padding = 2;
125    } else if encoded[encoded.len() - 1] == BASE64_PADDING {
126        padding = 1;
127    }
128    max_bytes - padding
129}
130
131/// Returns true if the provided character is a valid Base64 character.
132#[must_use]
133pub fn is_valid_char(c: char) -> bool {
134    if !c.is_ascii() {
135        return false;
136    }
137    let val = c as u8;
138    (MIN_VALID_CHAR..=MAX_VALID_CHAR).contains(&val) && char_to_bits(val) != INVALID_CHAR
139}
140
141/// Returns true if the provided data is valid Base64.
142#[must_use]
143pub fn is_valid(encoded: &[u8]) -> bool {
144    if encoded.is_empty() {
145        return true;
146    }
147    if !encoded.len().is_multiple_of(4) {
148        return false;
149    }
150    // Check all characters except the last two.
151    for &byte in encoded.iter().take(encoded.len() - 2) {
152        if !is_valid_char(byte as char) {
153            return false;
154        }
155    }
156    // Check the last two characters.
157    let penultimate = encoded[encoded.len() - 2];
158    let last = encoded[encoded.len() - 1];
159
160    if penultimate == BASE64_PADDING {
161        last == BASE64_PADDING
162    } else if is_valid_char(penultimate as char) {
163        is_valid_char(last as char) || last == BASE64_PADDING
164    } else {
165        false
166    }
167}
168
169/// Decodes the provided Base64 data into raw binary.
170///
171/// Returns the number of bytes written to `output` on success, or an error.
172pub fn decode(encoded: &[u8], output: &mut [u8]) -> Result<usize> {
173    if encoded.is_empty() {
174        return Ok(0);
175    }
176    if !is_valid(encoded) {
177        return Err(Error::InvalidArgument);
178    }
179    if output.len() < max_decoded_size(encoded.len()) {
180        return Err(Error::OutOfRange);
181    }
182
183    let mut binary_len = 0;
184    let mut ch = 0;
185    while ch < encoded.len() - 4 {
186        let char0 = char_to_bits(encoded[ch]);
187        let char1 = char_to_bits(encoded[ch + 1]);
188        let char2 = char_to_bits(encoded[ch + 2]);
189        let char3 = char_to_bits(encoded[ch + 3]);
190
191        output[binary_len] = bits_0_1(char0, char1);
192        output[binary_len + 1] = bits_1_2(char1, char2);
193        output[binary_len + 2] = bits_2_3(char2, char3);
194
195        binary_len += 3;
196        ch += 4;
197    }
198
199    // Decode the final group, which may include padding.
200    let char0 = char_to_bits(encoded[ch]);
201    let char1 = char_to_bits(encoded[ch + 1]);
202    let char2 = char_to_bits(encoded[ch + 2]);
203    let char3 = char_to_bits(encoded[ch + 3]);
204
205    output[binary_len] = bits_0_1(char0, char1);
206    binary_len += 1;
207
208    if encoded[ch + 2] != BASE64_PADDING {
209        output[binary_len] = bits_1_2(char1, char2);
210        binary_len += 1;
211        if encoded[ch + 3] != BASE64_PADDING {
212            output[binary_len] = bits_2_3(char2, char3);
213            binary_len += 1;
214        }
215    }
216
217    Ok(binary_len)
218}
219
// Base 64 encoding represents every 3 bytes with 4 ASCII characters.  Each
// of these 4 ASCII characters represents 6 bits of data from the 3 bytes of
// input.  The helpers below calculate each of the 4 characters from the 3 bytes
// of input.
224const fn char_0(b: &[u8; 3]) -> u8 {
225    BASE64_ENCODE_TABLE[((b[0] & 0b11111100) >> 2) as usize]
226}
227
228const fn char_1(b: &[u8; 3]) -> u8 {
229    BASE64_ENCODE_TABLE[(((b[0] & 0b00000011) << 4) | ((b[1] & 0b11110000) >> 4)) as usize]
230}
231
232const fn char_2(b: &[u8; 3]) -> u8 {
233    BASE64_ENCODE_TABLE[(((b[1] & 0b00001111) << 2) | ((b[2] & 0b11000000) >> 6)) as usize]
234}
235
236const fn char_3(b: &[u8; 3]) -> u8 {
237    BASE64_ENCODE_TABLE[(b[2] & 0b00111111) as usize]
238}
239
240/// Encode `input` as base64 into the `output_buffer`.
241///
242/// Returns the number of bytes written to `output_buffer` on success or
243/// `Error::OutOfRange` if `output_buffer` is not large enough.
244pub fn encode(input: &[u8], output: &mut [u8]) -> Result<usize> {
245    if output.len() < encoded_size(input.len()) {
246        return Err(Error::OutOfRange);
247    }
248    let mut input = Cursor::new(input);
249    let mut output = Cursor::new(output);
250
251    let mut remaining_bytes = input.len();
252    while remaining_bytes > 0 {
253        let bytes = [
254            input.read_u8_le().unwrap_or(0),
255            input.read_u8_le().unwrap_or(0),
256            input.read_u8_le().unwrap_or(0),
257        ];
258
259        output.write(&[
260            char_0(&bytes),
261            char_1(&bytes),
262            if remaining_bytes > 1 {
263                char_2(&bytes)
264            } else {
265                BASE64_PADDING
266            },
267            if remaining_bytes > 2 {
268                char_3(&bytes)
269            } else {
270                BASE64_PADDING
271            },
272        ])?;
273        remaining_bytes = remaining_bytes.saturating_add_signed(-3);
274    }
275
276    let len = output.stream_position()?;
277    usize::try_from(len).map_err(|_| Error::OutOfRange)
278}
279
280/// Encode `input` as base64 into `output_buffer` and interprets it as a
281/// string.
282///
283/// Returns a `&str` referencing the `output_buffer` buffer on success or
284/// `Error::OutOfRange` if `output_buffer` is not large enough.
285///
286/// Using this method avoids having to do unicode checking as it can guarantee
287/// that the data written to `output_buffer` is only valid ASCII bytes.
288pub fn encode_str<'a>(input: &[u8], output_buffer: &'a mut [u8]) -> Result<&'a str> {
289    let encode_len = encode(input, output_buffer)?;
290    // Safety: Since we are building the output buffer strictly from ASCII
291    // characters, it is guaranteed to be valid UTF-8.
292    // encode_len has already been checked to be less than output_buffer
293    // in the encode() call.
294    unsafe {
295        Ok(core::str::from_utf8_unchecked(
296            output_buffer.get(0..encode_len).unwrap_unchecked(),
297        ))
298    }
299}
300
301#[cfg(test)]
302mod tests;