pw_stream/
cursor.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15use core::cmp::min;
16use core::ptr;
17
18use paste::paste;
19use pw_status::{Error, Result};
20use pw_varint::{VarintDecode, VarintEncode};
21
22use super::{Read, Seek, SeekFrom, Write};
23
/// An in-memory stream over any byte container that is
/// <code>[AsRef]<[u8]></code>, providing [`Read`] and [`Seek`].
///
/// [`Write`] is additionally available when the wrapped type is also
/// <code>[AsMut]<[u8]></code>.
pub struct Cursor<T: AsRef<[u8]>> {
    /// The wrapped byte container.
    inner: T,
    /// Current IO offset into `inner`, in bytes.
    pos: usize,
}
36
37impl<T: AsRef<[u8]>> Cursor<T> {
38    /// Create a new Cursor wrapping `inner` with an initial position of 0.
39    ///
40    /// Semantics match [`std::io::Cursor::new()`].
41    pub fn new(inner: T) -> Self {
42        Self { inner, pos: 0 }
43    }
44
45    /// Consumes the cursor and returns the inner wrapped data.
46    pub fn into_inner(self) -> T {
47        self.inner
48    }
49
50    /// Returns the number of remaining bytes in the Cursor.
51    pub fn remaining(&self) -> usize {
52        self.len().saturating_sub(self.pos)
53    }
54
55    /// Returns the total length of the Cursor.
56    // Empty is ambiguous whether it should refer to len() or remaining() so
57    // we don't provide it.
58    #[allow(clippy::len_without_is_empty)]
59    pub fn len(&self) -> usize {
60        self.inner.as_ref().len()
61    }
62
63    /// Returns current IO position of the Cursor.
64    pub fn position(&self) -> usize {
65        self.pos
66    }
67
68    fn remaining_slice(&mut self) -> &[u8] {
69        &self.inner.as_ref()[self.pos..]
70    }
71}
72
73impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
74    fn remaining_mut(&mut self) -> &mut [u8] {
75        &mut self.inner.as_mut()[self.pos..]
76    }
77}
78
79// Implement `read()` as a concrete function to avoid extra monomorphization
80// overhead.
81fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
82    let remaining = inner.len() - *pos;
83    let read_len = min(remaining, buf.len());
84    buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
85    *pos += read_len;
86    Ok(read_len)
87}
88
89impl<T: AsRef<[u8]>> Read for Cursor<T> {
90    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
91        read_impl(self.inner.as_ref(), &mut self.pos, buf)
92    }
93}
94
95// Implement `write()` as a concrete function to avoid extra monomorphization
96// overhead.
97fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
98    let remaining = inner.len().checked_sub(*pos).ok_or(Error::OutOfRange)?;
99    let write_len = min(remaining, buf.len());
100    // Safety: write_len has been bounds checked on buf and inner.
101    // There can't be any overlap as inner is a &mut and buf is a &.
102    unsafe {
103        let src_ptr = buf.as_ptr();
104        let dst_ptr = inner.as_mut_ptr().add(*pos);
105        ptr::copy_nonoverlapping(src_ptr, dst_ptr, write_len);
106    }
107    // This will never saturate as pos is a private field which will
108    // never be larger than inner.len().
109    *pos = pos.saturating_add(write_len);
110    Ok(write_len)
111}
112
113impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
114    fn write(&mut self, buf: &[u8]) -> Result<usize> {
115        write_impl(self.inner.as_mut(), &mut self.pos, buf)
116    }
117
118    fn flush(&mut self) -> Result<()> {
119        // Cursor does not provide any buffering so flush() is a noop.
120        Ok(())
121    }
122}
123
124impl<T: AsRef<[u8]>> Seek for Cursor<T> {
125    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
126        let new_pos = match pos {
127            SeekFrom::Start(pos) => pos,
128            SeekFrom::Current(pos) => (self.pos as u64)
129                .checked_add_signed(pos)
130                .ok_or(Error::OutOfRange)?,
131            SeekFrom::End(pos) => (self.len() as u64)
132                .checked_add_signed(-pos)
133                .ok_or(Error::OutOfRange)?,
134        };
135
136        // Since Cursor operates on in memory buffers, it's limited by usize.
137        // Return an error if we are asked to seek beyond that limit.
138        let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
139
140        if new_pos > self.len() {
141            Err(Error::OutOfRange)
142        } else {
143            self.pos = new_pos;
144            Ok(new_pos as u64)
145        }
146    }
147
148    // Implement more efficient versions of rewind, stream_len, stream_position.
149    fn rewind(&mut self) -> Result<()> {
150        self.pos = 0;
151        Ok(())
152    }
153
154    fn stream_len(&mut self) -> Result<u64> {
155        Ok(self.len() as u64)
156    }
157
158    fn stream_position(&mut self) -> Result<u64> {
159        Ok(self.pos as u64)
160    }
161}
162
// Generates `read_<ty>_<endian>()`: decodes one fixed-width integer of type
// `$ty` from the current position using `$ty::from_<endian>_bytes` and
// advances the position past it.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
            const NUM_BYTES: usize = $ty::BITS as usize / 8;
            if self.remaining() < NUM_BYTES {
                return Err(Error::OutOfRange);
            }
            let start = self.pos;
            let bytes = self
                .inner
                .as_ref()
                .get(start..start + NUM_BYTES)
                .ok_or(Error::InvalidArgument)?;
            // Because we are code size conscious we want an infallible way to
            // turn `bytes` into a fixed sized array as opposed to using
            // something like `.try_into()?`.
            //
            // SAFETY: `bytes` was produced by a bounds-checked `get()` over a
            // range of exactly `NUM_BYTES` elements, so it points at
            // `NUM_BYTES` readable bytes.
            let array: &[u8; NUM_BYTES] = unsafe { &*bytes.as_ptr().cast::<[u8; NUM_BYTES]>() };
            let value = $ty::[<from_ $endian _bytes>](*array);

            self.pos = start + NUM_BYTES;
            Ok(value)
          }
        }
    };
}
191
// Generates the four readers for one bit width: unsigned and signed, each in
// little- and big-endian form.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
          cursor_read_type_impl!([<u $bits>], le);
          cursor_read_type_impl!([<u $bits>], be);
          cursor_read_type_impl!([<i $bits>], le);
          cursor_read_type_impl!([<i $bits>], be);
        }
    };
}
202
// Generates `write_<ty>_<endian>()`: encodes `value` with
// `$ty::to_<endian>_bytes`, stores it at the current position, and advances
// the position past it.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
            const NUM_BYTES: usize = $ty::BITS as usize / 8;
            if self.remaining() < NUM_BYTES {
                return Err(Error::OutOfRange);
            }
            let start = self.pos;
            let encoded = $ty::[<to_ $endian _bytes>](*value);
            let dest = self
                .inner
                .as_mut()
                .get_mut(start..start + NUM_BYTES)
                .ok_or(Error::InvalidArgument)?;

            // `dest` and `encoded` are both exactly `NUM_BYTES` long.
            dest.copy_from_slice(&encoded[..]);

            self.pos = start + NUM_BYTES;
            Ok(())
          }
        }
    };
}
226
// Generates the four writers for one bit width: unsigned and signed, each in
// little- and big-endian form.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
          cursor_write_type_impl!([<u $bits>], le);
          cursor_write_type_impl!([<u $bits>], be);
          cursor_write_type_impl!([<i $bits>], le);
          cursor_write_type_impl!([<i $bits>], be);
        }
    };
}
237
238impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
239    cursor_read_bits_impl!(8);
240    cursor_read_bits_impl!(16);
241    cursor_read_bits_impl!(32);
242    cursor_read_bits_impl!(64);
243    cursor_read_bits_impl!(128);
244}
245
246impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
247    cursor_write_bits_impl!(8);
248    cursor_write_bits_impl!(16);
249    cursor_write_bits_impl!(32);
250    cursor_write_bits_impl!(64);
251    cursor_write_bits_impl!(128);
252}
253
254impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
255    fn read_varint(&mut self) -> Result<u64> {
256        let (len, value) = u64::varint_decode(self.remaining_slice())?;
257        self.pos += len;
258        Ok(value)
259    }
260
261    fn read_signed_varint(&mut self) -> Result<i64> {
262        let (len, value) = i64::varint_decode(self.remaining_slice())?;
263        self.pos += len;
264        Ok(value)
265    }
266}
267
268impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
269    fn write_varint(&mut self, value: u64) -> Result<()> {
270        let encoded_len = value.varint_encode(self.remaining_mut())?;
271        self.pos += encoded_len;
272        Ok(())
273    }
274
275    fn write_signed_varint(&mut self, value: i64) -> Result<()> {
276        let encoded_len = value.varint_encode(self.remaining_mut())?;
277        self.pos += encoded_len;
278        Ok(())
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;
    use crate::{ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    // Several tests build `Cursor` with a struct literal so they can start
    // from a non-zero position without going through `seek()`.

    /// `len()` reports the size of the whole buffer regardless of position.
    #[test]
    fn cursor_len_returns_total_bytes() {
        let backing = [0u8; 64];
        let cursor = Cursor {
            inner: &backing,
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    /// `remaining()` reports the bytes between the position and the end.
    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let backing = [0u8; 64];
        let cursor = Cursor {
            inner: &backing,
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    /// `position()` reports the current offset.
    #[test]
    fn cursor_position_returns_current_position() {
        let backing = [0u8; 64];
        let cursor = Cursor {
            inner: &backing,
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    /// A read into a buffer larger than the remaining data is truncated and
    /// leaves the tail of the destination untouched.
    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let source = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let mut cursor = Cursor {
            inner: &source,
            pos: 4,
        };
        let mut dest = [0u8; 8];
        let count = cursor.read(&mut dest);
        assert_eq!(count, Ok(4));
        assert_eq!(dest, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    /// A write larger than the remaining space is truncated and leaves the
    /// bytes before the position untouched.
    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut backing = [0u8; 8];
        let mut cursor = Cursor {
            inner: &mut backing,
            pos: 4,
        };
        let payload = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let count = cursor.write(&payload);
        assert_eq!(count, Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        let cursor = Cursor::new(&[0u8; 64]);
        test_rewind_resets_position_to_zero::<64, _>(cursor);
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        let cursor = Cursor::new(&[0u8; 64]);
        test_stream_pos_reports_correct_position::<64, _>(cursor);
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        let cursor = Cursor::new(&[0u8; 64]);
        test_stream_len_reports_correct_length::<64, _>(cursor);
    }

    // Generates a test that reads the four integers (le signed, le unsigned,
    // be signed, be unsigned) of the given width from the matching
    // `integer_<bits>_bit_test_cases()` byte vector.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                  let (encoded, expected) = [<integer_ $bits _bit_test_cases>]();
                  let mut reader = Cursor::new(&encoded);

                  assert_eq!(reader.[<read_i $bits _le>](), Ok(expected.0));
                  assert_eq!(reader.[<read_u $bits _le>](), Ok(expected.1));
                  assert_eq!(reader.[<read_i $bits _be>](), Ok(expected.2));
                  assert_eq!(reader.[<read_u $bits _be>](), Ok(expected.3));
              }
            }
        };
    }

    // Generates a test that writes the four integers of the given width into
    // a zeroed buffer and compares the result with the expected bytes from
    // `integer_<bits>_bit_test_cases()`.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                  let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut writer = Cursor::new(vec![0u8; expected_bytes.len()]);
                  writer.[<write_i $bits _le>](&values.0).unwrap();
                  writer.[<write_u $bits _le>](&values.1).unwrap();
                  writer.[<write_i $bits _be>](&values.2).unwrap();
                  writer.[<write_u $bits _be>](&values.3).unwrap();

                  let actual_bytes: Vec<u8> = writer.into_inner().into();

                  assert_eq!(actual_bytes, expected_bytes);
              }
            }
        };
    }

    /// Encoded bytes and matching values for the 8 bit integer tests.
    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    /// Encoded bytes and matching values for the 16 bit integer tests.
    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    /// Encoded bytes and matching values for the 32 bit integer tests.
    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    /// Encoded bytes and matching values for the 64 bit integer tests.
    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    /// Encoded bytes and matching values for the 128 bit integer tests.
    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // Bound to a local so `#[rustfmt::skip]` can keep the byte tables in
        // their hand-aligned rows.
        #[rustfmt::skip]
        let val = (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    #[test]
    fn read_varint_unpacks_data_correctly() {
        let cases: [(Vec<u8>, u64); 2] = [
            (
                vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
                0xffff_fffe,
            ),
            (
                vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
                0xffff_ffff,
            ),
        ];
        for (encoded, expected) in cases {
            let mut cursor = Cursor::new(encoded);
            let value = cursor.read_varint().unwrap();
            assert_eq!(value, expected);
        }
    }

    #[test]
    fn read_signed_varint_unpacks_data_correctly() {
        let cases: [(Vec<u8>, i64); 2] = [
            (
                vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
                i64::from(i32::MAX),
            ),
            (
                vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
                i64::from(i32::MIN),
            ),
        ];
        for (encoded, expected) in cases {
            let mut cursor = Cursor::new(encoded);
            let value = cursor.read_signed_varint().unwrap();
            assert_eq!(value, expected);
        }
    }

    #[test]
    fn write_varint_packs_data_correctly() {
        let cases: [(u64, Vec<u8>); 2] = [
            (
                0xffff_fffe,
                vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
            ),
            (
                0xffff_ffff,
                vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
            ),
        ];
        for (value, expected) in cases {
            let mut cursor = Cursor::new(vec![0u8; 8]);
            cursor.write_varint(value).unwrap();
            let buf = cursor.into_inner();
            assert_eq!(buf, expected);
        }
    }

    #[test]
    fn write_signed_varint_packs_data_correctly() {
        let cases: [(i64, Vec<u8>); 2] = [
            (
                i64::from(i32::MAX),
                vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
            ),
            (
                i64::from(i32::MIN),
                vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0],
            ),
        ];
        for (value, expected) in cases {
            let mut cursor = Cursor::new(vec![0u8; 8]);
            cursor.write_signed_varint(value).unwrap();
            let buf = cursor.into_inner();
            assert_eq!(buf, expected);
        }
    }
}
547}