pw_stream/
cursor.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15use core::cmp::min;
16
17use paste::paste;
18use pw_status::{Error, Result};
19use pw_varint::{VarintDecode, VarintEncode};
20
21use super::{Read, Seek, SeekFrom, Write};
22
/// An in-memory stream backed by any <code>[AsRef]<[u8]></code> container.
///
/// `Cursor` pairs the wrapped bytes with a current IO position and implements
/// [`Read`] and [`Seek`] on top of them.
///
/// [`Write`] is only available when the wrapped type additionally implements
/// <code>[AsMut]<[u8]></code>.
pub struct Cursor<T: AsRef<[u8]>> {
    /// The wrapped byte container.
    inner: T,
    /// Offset into `inner` at which the next read or write takes place.
    pos: usize,
}
35
36impl<T: AsRef<[u8]>> Cursor<T> {
37    /// Create a new Cursor wrapping `inner` with an initial position of 0.
38    ///
39    /// Semantics match [`std::io::Cursor::new()`].
40    pub fn new(inner: T) -> Self {
41        Self { inner, pos: 0 }
42    }
43
44    /// Consumes the cursor and returns the inner wrapped data.
45    pub fn into_inner(self) -> T {
46        self.inner
47    }
48
49    /// Returns the number of remaining bytes in the Cursor.
50    pub fn remaining(&self) -> usize {
51        self.len() - self.pos
52    }
53
54    /// Returns the total length of the Cursor.
55    // Empty is ambiguous whether it should refer to len() or remaining() so
56    // we don't provide it.
57    #[allow(clippy::len_without_is_empty)]
58    pub fn len(&self) -> usize {
59        self.inner.as_ref().len()
60    }
61
62    /// Returns current IO position of the Cursor.
63    pub fn position(&self) -> usize {
64        self.pos
65    }
66
67    fn remaining_slice(&mut self) -> &[u8] {
68        &self.inner.as_ref()[self.pos..]
69    }
70}
71
72impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
73    fn remaining_mut(&mut self) -> &mut [u8] {
74        &mut self.inner.as_mut()[self.pos..]
75    }
76}
77
78// Implement `read()` as a concrete function to avoid extra monomorphization
79// overhead.
80fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
81    let remaining = inner.len() - *pos;
82    let read_len = min(remaining, buf.len());
83    buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
84    *pos += read_len;
85    Ok(read_len)
86}
87
88impl<T: AsRef<[u8]>> Read for Cursor<T> {
89    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
90        read_impl(self.inner.as_ref(), &mut self.pos, buf)
91    }
92}
93
94// Implement `write()` as a concrete function to avoid extra monomorphization
95// overhead.
96fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
97    let remaining = inner.len() - *pos;
98    let write_len = min(remaining, buf.len());
99    inner[*pos..(*pos + write_len)].copy_from_slice(&buf[0..write_len]);
100    *pos += write_len;
101    Ok(write_len)
102}
103
104impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
105    fn write(&mut self, buf: &[u8]) -> Result<usize> {
106        write_impl(self.inner.as_mut(), &mut self.pos, buf)
107    }
108
109    fn flush(&mut self) -> Result<()> {
110        // Cursor does not provide any buffering so flush() is a noop.
111        Ok(())
112    }
113}
114
115impl<T: AsRef<[u8]>> Seek for Cursor<T> {
116    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
117        let new_pos = match pos {
118            SeekFrom::Start(pos) => pos,
119            SeekFrom::Current(pos) => (self.pos as u64)
120                .checked_add_signed(pos)
121                .ok_or(Error::OutOfRange)?,
122            SeekFrom::End(pos) => (self.len() as u64)
123                .checked_add_signed(-pos)
124                .ok_or(Error::OutOfRange)?,
125        };
126
127        // Since Cursor operates on in memory buffers, it's limited by usize.
128        // Return an error if we are asked to seek beyond that limit.
129        let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
130
131        if new_pos > self.len() {
132            Err(Error::OutOfRange)
133        } else {
134            self.pos = new_pos;
135            Ok(new_pos as u64)
136        }
137    }
138
139    // Implement more efficient versions of rewind, stream_len, stream_position.
140    fn rewind(&mut self) -> Result<()> {
141        self.pos = 0;
142        Ok(())
143    }
144
145    fn stream_len(&mut self) -> Result<u64> {
146        Ok(self.len() as u64)
147    }
148
149    fn stream_position(&mut self) -> Result<u64> {
150        Ok(self.pos as u64)
151    }
152}
153
// Generates a `read_<ty>_<endian>()` method which decodes one fixed-width
// integer of type `$ty` at the current position and then advances past it.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
            const WIDTH: usize = $ty::BITS as usize / 8;
            if self.remaining() < WIDTH {
                return Err(Error::OutOfRange);
            }
            let start = self.pos;
            let encoded = self
                .inner
                .as_ref()
                .get(start..start + WIDTH)
                .ok_or(Error::InvalidArgument)?;
            // Code size matters here, so the slice is reinterpreted as a
            // fixed-size array directly instead of going through a fallible
            // conversion such as `.try_into()?`.
            //
            // SAFETY: `encoded` was produced by `get()` with a range of
            // exactly `WIDTH` bytes, so it is in bounds and `WIDTH` bytes
            // long.
            let encoded: &[u8; WIDTH] = unsafe { &*encoded.as_ptr().cast::<[u8; WIDTH]>() };

            self.pos = start + WIDTH;
            Ok($ty::[<from_ $endian _bytes>](*encoded))
          }
        }
    };
}
182
// Generates the four readers (signed/unsigned x little/big endian) for the
// integer types that are `$bits` wide.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
          // Unsigned readers.
          cursor_read_type_impl!([<u $bits>], le);
          cursor_read_type_impl!([<u $bits>], be);
          // Signed readers.
          cursor_read_type_impl!([<i $bits>], le);
          cursor_read_type_impl!([<i $bits>], be);
        }
    };
}
193
// Generates a `write_<ty>_<endian>()` method which encodes one fixed-width
// integer of type `$ty` at the current position and then advances past it.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
            const WIDTH: usize = $ty::BITS as usize / 8;
            if self.remaining() < WIDTH {
                return Err(Error::OutOfRange);
            }
            let encoded = $ty::[<to_ $endian _bytes>](*value);
            let start = self.pos;
            let dest = self
                .inner
                .as_mut()
                .get_mut(start..start + WIDTH)
                .ok_or(Error::InvalidArgument)?;

            // `dest` is exactly `WIDTH` bytes long, matching `encoded`.
            dest.copy_from_slice(&encoded);

            self.pos = start + WIDTH;
            Ok(())
          }
        }
    };
}
217
// Generates the four writers (signed/unsigned x little/big endian) for the
// integer types that are `$bits` wide.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
          // Unsigned writers.
          cursor_write_type_impl!([<u $bits>], le);
          cursor_write_type_impl!([<u $bits>], be);
          // Signed writers.
          cursor_write_type_impl!([<i $bits>], le);
          cursor_write_type_impl!([<i $bits>], be);
        }
    };
}
228
229impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
230    cursor_read_bits_impl!(8);
231    cursor_read_bits_impl!(16);
232    cursor_read_bits_impl!(32);
233    cursor_read_bits_impl!(64);
234    cursor_read_bits_impl!(128);
235}
236
237impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
238    cursor_write_bits_impl!(8);
239    cursor_write_bits_impl!(16);
240    cursor_write_bits_impl!(32);
241    cursor_write_bits_impl!(64);
242    cursor_write_bits_impl!(128);
243}
244
245impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
246    fn read_varint(&mut self) -> Result<u64> {
247        let (len, value) = u64::varint_decode(self.remaining_slice())?;
248        self.pos += len;
249        Ok(value)
250    }
251
252    fn read_signed_varint(&mut self) -> Result<i64> {
253        let (len, value) = i64::varint_decode(self.remaining_slice())?;
254        self.pos += len;
255        Ok(value)
256    }
257}
258
259impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
260    fn write_varint(&mut self, value: u64) -> Result<()> {
261        let encoded_len = value.varint_encode(self.remaining_mut())?;
262        self.pos += encoded_len;
263        Ok(())
264    }
265
266    fn write_signed_varint(&mut self, value: i64) -> Result<()> {
267        let encoded_len = value.varint_encode(self.remaining_mut())?;
268        self.pos += encoded_len;
269        Ok(())
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    // The accessor tests build `Cursor` values field by field so that they
    // can start from a non-zero position without going through `seek()`.

    #[test]
    fn cursor_len_returns_total_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        // Only four bytes remain, so the tail of `buf` must stay untouched.
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        // Only four bytes of room remain, so only half of `buf` is written.
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    // Generates a test that reads the four integers (le signed, le unsigned,
    // be signed, be unsigned) of the given width from the matching
    // `integer_<bits>_bit_test_cases()` fixture.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                  let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut cursor = Cursor::new(&bytes);

                  assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                  assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                  assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                  assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
              }
            }
        };
    }

    // Generates a test that writes the four integers of the given width and
    // compares the resulting bytes with the matching fixture.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                  let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                  cursor.[<write_i $bits _le>](&values.0).unwrap();
                  cursor.[<write_u $bits _le>](&values.1).unwrap();
                  cursor.[<write_i $bits _be>](&values.2).unwrap();
                  cursor.[<write_u $bits _be>](&values.3).unwrap();

                  // The cursor wraps a `Vec<u8>`, so no conversion is needed.
                  let result_bytes: Vec<u8> = cursor.into_inner();

                  assert_eq!(result_bytes, expected_bytes);
              }
            }
        };
    }

    // Each fixture returns the encoded bytes together with the values they
    // decode to, in the order: le signed, le unsigned, be signed, be unsigned.

    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // Bound to a local so that `#[rustfmt::skip]` can be attached to the
        // statement and the byte tables keep their manual layout.
        #[rustfmt::skip]
        let val = (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    // The varint tests use buffers that are longer than the encoded value so
    // that trailing bytes are shown to be ignored (reads) or left untouched
    // (writes).

    #[test]
    fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    #[test]
    fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    #[test]
    fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    #[test]
    fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}