pw_stream/
cursor.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15use core::cmp::min;
16use core::ptr;
17
18use paste::paste;
19use pw_status::{Error, Result};
20use pw_varint::{VarintDecode, VarintEncode};
21
22use super::{Read, Seek, SeekFrom, Write};
23
/// Wraps an <code>[AsRef]<[u8]></code> in a container implementing
/// [`Read`], [`Write`], and [`Seek`].
///
/// [`Write`] support requires the inner type also implement
/// <code>[AsMut]<[u8]></code>.
///
/// The cursor tracks a byte offset into the wrapped buffer; reads and writes
/// start at that offset and advance it.
#[derive(Clone, Debug)]
pub struct Cursor<T>
where
    T: AsRef<[u8]>,
{
    /// The wrapped buffer.
    inner: T,
    /// Current IO position, as a byte offset from the start of `inner`.
    pos: usize,
}
36
37impl<T: AsRef<[u8]>> Cursor<T> {
38    /// Create a new Cursor wrapping `inner` with an initial position of 0.
39    ///
40    /// Semantics match [`std::io::Cursor::new()`].
41    pub fn new(inner: T) -> Self {
42        Self { inner, pos: 0 }
43    }
44
45    /// Consumes the cursor and returns the inner wrapped data.
46    pub fn into_inner(self) -> T {
47        self.inner
48    }
49
50    /// Returns the number of remaining bytes in the Cursor.
51    pub fn remaining(&self) -> usize {
52        self.len().saturating_sub(self.pos)
53    }
54
55    /// Returns the total length of the Cursor.
56    // Empty is ambiguous whether it should refer to len() or remaining() so
57    // we don't provide it.
58    #[allow(clippy::len_without_is_empty)]
59    pub fn len(&self) -> usize {
60        self.inner.as_ref().len()
61    }
62
63    /// Returns current IO position of the Cursor.
64    pub fn position(&self) -> usize {
65        self.pos
66    }
67
68    fn remaining_slice(&mut self) -> &[u8] {
69        &self.inner.as_ref()[self.pos..]
70    }
71}
72
73impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
74    fn remaining_mut(&mut self) -> &mut [u8] {
75        &mut self.inner.as_mut()[self.pos..]
76    }
77}
78
79// Implement `read()` as a concrete function to avoid extra monomorphization
80// overhead.
81fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
82    let remaining = inner.len() - *pos;
83    let read_len = min(remaining, buf.len());
84    buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
85    *pos += read_len;
86    Ok(read_len)
87}
88
89impl<T: AsRef<[u8]>> Read for Cursor<T> {
90    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
91        read_impl(self.inner.as_ref(), &mut self.pos, buf)
92    }
93}
94
95// Implement `write()` as a concrete function to avoid extra monomorphization
96// overhead.
97fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
98    let remaining = inner.len().checked_sub(*pos).ok_or(Error::OutOfRange)?;
99    let write_len = min(remaining, buf.len());
100    // Safety: write_len has been bounds checked on buf and inner.
101    // There can't be any overlap as inner is a &mut and buf is a &.
102    unsafe {
103        let src_ptr = buf.as_ptr();
104        let dst_ptr = inner.as_mut_ptr().add(*pos);
105        ptr::copy_nonoverlapping(src_ptr, dst_ptr, write_len);
106    }
107    // This will never saturate as pos is a private field which will
108    // never be larger than inner.len().
109    *pos = pos.saturating_add(write_len);
110    Ok(write_len)
111}
112
113impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
114    fn write(&mut self, buf: &[u8]) -> Result<usize> {
115        write_impl(self.inner.as_mut(), &mut self.pos, buf)
116    }
117
118    fn flush(&mut self) -> Result<()> {
119        // Cursor does not provide any buffering so flush() is a noop.
120        Ok(())
121    }
122}
123
124impl<T: AsRef<[u8]>> Seek for Cursor<T> {
125    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
126        let new_pos = match pos {
127            SeekFrom::Start(pos) => pos,
128            SeekFrom::Current(pos) => (self.pos as u64)
129                .checked_add_signed(pos)
130                .ok_or(Error::OutOfRange)?,
131            SeekFrom::End(pos) => (self.len() as u64)
132                .checked_add_signed(-pos)
133                .ok_or(Error::OutOfRange)?,
134        };
135
136        // Since Cursor operates on in memory buffers, it's limited by usize.
137        // Return an error if we are asked to seek beyond that limit.
138        let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
139
140        if new_pos > self.len() {
141            Err(Error::OutOfRange)
142        } else {
143            self.pos = new_pos;
144            Ok(new_pos as u64)
145        }
146    }
147
148    // Implement more efficient versions of rewind, stream_len, stream_position.
149    fn rewind(&mut self) -> Result<()> {
150        self.pos = 0;
151        Ok(())
152    }
153
154    fn stream_len(&mut self) -> Result<u64> {
155        Ok(self.len() as u64)
156    }
157
158    fn stream_position(&mut self) -> Result<u64> {
159        Ok(self.pos as u64)
160    }
161}
162
// Generates a `read_<ty>_<endian>()` method that decodes one `$ty` from the
// wrapped buffer at the current position and advances the position past it.
// `$endian` is `le` or `be` and selects the matching `from_*_bytes`
// constructor.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
            const NUM_BYTES: usize = $ty::BITS as usize / 8;
            if self.remaining() < NUM_BYTES {
                return Err(Error::OutOfRange);
            }

            let start = self.pos;
            let end = start + NUM_BYTES;
            let sub_slice = self
                .inner
                .as_ref()
                .get(start..end)
                .ok_or(Error::InvalidArgument)?;

            // Because we are code size conscious we want an infallible way to
            // turn `sub_slice` into a fixed sized array as opposed to using
            // something like `.try_into()?`.
            //
            // Safety: `sub_slice` was obtained from a checked `get()` over a
            // range of exactly `NUM_BYTES` bytes, so it is valid to view it
            // as a `[u8; NUM_BYTES]`.
            let sub_array: &[u8; NUM_BYTES] =
                unsafe { &*sub_slice.as_ptr().cast::<[u8; NUM_BYTES]>() };
            let value = $ty::[<from_ $endian _bytes>](*sub_array);

            self.pos = end;
            Ok(value)
          }
        }
    };
}
191
// Expands to the four read methods for integers that are `$bits` wide: the
// signed and unsigned types, each in little endian (`le`) and big endian
// (`be`) form. For example `8` produces `read_i8_le`, `read_u8_le`,
// `read_i8_be`, and `read_u8_be`.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
          cursor_read_type_impl!([<i $bits>], le);
          cursor_read_type_impl!([<u $bits>], le);
          cursor_read_type_impl!([<i $bits>], be);
          cursor_read_type_impl!([<u $bits>], be);
        }
    };
}
202
// Generates a `write_<ty>_<endian>()` method that encodes one `$ty` into the
// wrapped buffer at the current position and advances the position past it.
// `$endian` is `le` or `be` and selects the matching `to_*_bytes` encoder.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
            const NUM_BYTES: usize = $ty::BITS as usize / 8;
            if self.remaining() < NUM_BYTES {
                return Err(Error::OutOfRange);
            }

            let start = self.pos;
            let end = start + NUM_BYTES;
            let encoded = $ty::[<to_ $endian _bytes>](*value);

            // The destination is exactly `NUM_BYTES` long, which matches the
            // length of `encoded`, so `copy_from_slice` cannot panic.
            let dest = self
                .inner
                .as_mut()
                .get_mut(start..end)
                .ok_or(Error::InvalidArgument)?;
            dest.copy_from_slice(&encoded[..]);

            self.pos = end;
            Ok(())
          }
        }
    };
}
226
// Expands to the four write methods for integers that are `$bits` wide: the
// signed and unsigned types, each in little endian (`le`) and big endian
// (`be`) form. For example `8` produces `write_i8_le`, `write_u8_le`,
// `write_i8_be`, and `write_u8_be`.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
          cursor_write_type_impl!([<i $bits>], le);
          cursor_write_type_impl!([<u $bits>], le);
          cursor_write_type_impl!([<i $bits>], be);
          cursor_write_type_impl!([<u $bits>], be);
        }
    };
}
237
// Fixed-width integer reads for any readable Cursor. Each invocation below
// generates the signed/unsigned, little/big endian read methods for one bit
// width.
impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
    cursor_read_bits_impl!(8);
    cursor_read_bits_impl!(16);
    cursor_read_bits_impl!(32);
    cursor_read_bits_impl!(64);
    cursor_read_bits_impl!(128);
}
245
// Fixed-width integer writes for any Cursor whose inner buffer is mutable.
// Each invocation below generates the signed/unsigned, little/big endian
// write methods for one bit width.
impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
    cursor_write_bits_impl!(8);
    cursor_write_bits_impl!(16);
    cursor_write_bits_impl!(32);
    cursor_write_bits_impl!(64);
    cursor_write_bits_impl!(128);
}
253
254impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
255    fn read_varint(&mut self) -> Result<u64> {
256        let (len, value) = u64::varint_decode(self.remaining_slice())?;
257        self.pos += len;
258        Ok(value)
259    }
260
261    fn read_signed_varint(&mut self) -> Result<i64> {
262        let (len, value) = i64::varint_decode(self.remaining_slice())?;
263        self.pos += len;
264        Ok(value)
265    }
266}
267
268impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
269    fn write_varint(&mut self, value: u64) -> Result<()> {
270        let encoded_len = value.varint_encode(self.remaining_mut())?;
271        self.pos += encoded_len;
272        Ok(())
273    }
274
275    fn write_signed_varint(&mut self, value: i64) -> Result<()> {
276        let encoded_len = value.varint_encode(self.remaining_mut())?;
277        self.pos += encoded_len;
278        Ok(())
279    }
280}
281
/// Unit tests for [`Cursor`]: accessors, raw reads/writes, seeking, fixed
/// width integer helpers, and varint helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    #[test]
    fn cursor_len_returns_total_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        // Only the four bytes after the position are available.
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        // Only four bytes of space remain after the position.
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_seek_within_bounds_updates_position() {
        let mut cursor = Cursor::new(&[0u8; 8]);
        assert_eq!(cursor.seek(SeekFrom::Start(5)), Ok(5));
        assert_eq!(cursor.position(), 5);
        assert_eq!(cursor.seek(SeekFrom::Current(-2)), Ok(3));
        assert_eq!(cursor.position(), 3);
        // Seeking to the very end of the buffer is allowed.
        assert_eq!(cursor.seek(SeekFrom::Start(8)), Ok(8));
        assert_eq!(cursor.position(), 8);
    }

    #[test]
    fn cursor_seek_out_of_bounds_returns_error_and_keeps_position() {
        let mut cursor = Cursor::new(&[0u8; 8]);
        assert_eq!(cursor.seek(SeekFrom::Start(4)), Ok(4));
        assert_eq!(cursor.seek(SeekFrom::Start(9)), Err(Error::OutOfRange));
        assert_eq!(cursor.seek(SeekFrom::Current(-5)), Err(Error::OutOfRange));
        assert_eq!(cursor.seek(SeekFrom::Current(5)), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 4);
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    // Generates a test that reads the four `$bits` wide integers (le signed,
    // le unsigned, be signed, be unsigned) out of the matching
    // `integer_<bits>_bit_test_cases()` byte vector.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                  let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut cursor = Cursor::new(&bytes);

                  assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                  assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                  assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                  assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
              }
            }
        };
    }

    // Generates a test that writes the four `$bits` wide integers from the
    // matching `integer_<bits>_bit_test_cases()` and compares the resulting
    // buffer against the expected bytes.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                  let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                  cursor.[<write_i $bits _le>](&values.0).unwrap();
                  cursor.[<write_u $bits _le>](&values.1).unwrap();
                  cursor.[<write_i $bits _be>](&values.2).unwrap();
                  cursor.[<write_u $bits _be>](&values.3).unwrap();

                  let result_bytes = cursor.into_inner();

                  assert_eq!(result_bytes, expected_bytes);
              }
            }
        };
    }

    // Each `integer_<bits>_bit_test_cases()` returns the encoded bytes and
    // the corresponding (le signed, le unsigned, be signed, be unsigned)
    // values.
    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // Bound to a local so `#[rustfmt::skip]` can be applied to the
        // statement and the byte tables keep their 8-per-row layout.
        #[rustfmt::skip]
        let val = (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    #[test]
    fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    #[test]
    fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    #[test]
    fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    #[test]
    fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}
546}