pw_stream/
lib.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15//! `pw_stream` provides `no_std` versions of Rust's [`std::io::Read`],
16//! [`std::io::Write`], and [`std::io::Seek`] traits as well as a simplified
17//! version of [`std::io::Cursor`].  One notable difference is that
18//! [`pw_status::Error`] is used to avoid needing to do error conversion or
19//! encapsulation.
20#![deny(missing_docs)]
21// Allows docs to reference `std`
22#![cfg_attr(feature = "no_std", no_std)]
23
24use pw_status::{Error, Result};
25
26#[doc(hidden)]
27mod cursor;
28mod integer;
29
30pub use cursor::Cursor;
31pub use integer::{ReadInteger, ReadVarint, WriteInteger, WriteVarint};
32
33/// A trait for objects that provide streaming read capability.
34pub trait Read {
35    /// Read from a stream into a buffer.
36    ///
37    /// Semantics match [`std::io::Read::read()`].
38    fn read(&mut self, buf: &mut [u8]) -> Result<usize>;
39
40    /// Read exactly enough bytes to fill the buffer.
41    ///
42    /// Semantics match [`std::io::Read::read_exact()`].
43    fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<()> {
44        while !buf.is_empty() {
45            let len = self.read(buf)?;
46
47            // End of stream
48            if len == 0 {
49                break;
50            }
51
52            buf = &mut buf[len..];
53        }
54
55        if !buf.is_empty() {
56            Err(Error::OutOfRange)
57        } else {
58            Ok(())
59        }
60    }
61}
62
63/// A trait for objects that provide streaming write capability.
64pub trait Write {
65    /// Write a buffer to a stream.
66    ///
67    /// Semantics match [`std::io::Write::write()`].
68    fn write(&mut self, buf: &[u8]) -> Result<usize>;
69
70    /// Commit any outstanding buffered writes to underlying storage.
71    ///
72    /// Semantics match [`std::io::Write::flush()`].
73    fn flush(&mut self) -> Result<()>;
74
75    /// Writes entire buffer to stream.
76    ///
77    /// Semantics match [`std::io::Write::write_all()`].
78    fn write_all(&mut self, mut buf: &[u8]) -> Result<()> {
79        while !buf.is_empty() {
80            let len = self.write(buf)?;
81
82            // End of stream
83            if len == 0 {
84                break;
85            }
86
87            buf = &buf[len..];
88        }
89
90        if !buf.is_empty() {
91            Err(Error::OutOfRange)
92        } else {
93            Ok(())
94        }
95    }
96}
97
/// A description of a seek operation in a stream.
///
/// While `pw_stream` targets embedded platforms which are often natively
/// 32 bit, we believe that seek operations are relatively rare and the added
/// overhead of using 64 bit values for seeks is balanced by the ability
/// to support objects and operations over 4 GiB.
///
/// Like [`std::io::SeekFrom`], this type is `Copy`, `Debug` and comparable,
/// so a seek target can be stored, reused, logged, and checked in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeekFrom {
    /// Seek to an absolute offset from the start of the stream.
    Start(u64),

    /// Seek relative to the end of the stream.
    ///
    /// As with [`std::io::SeekFrom::End`], the offset is added to the stream
    /// length, so offsets that stay within the stream are zero or negative.
    End(i64),

    /// Seek relative to the current position of the stream.
    Current(i64),
}
114
115/// A trait for objects that provide the ability to seek withing a stream.
116pub trait Seek {
117    /// Adjust the current position of the stream.
118    ///
119    /// Semantics match [`std::io::Seek::seek()`].
120    fn seek(&mut self, pos: SeekFrom) -> Result<u64>;
121
122    /// Set the current position of the stream to its beginning.
123    ///
124    /// Semantics match [`std::io::Seek::rewind()`].
125    fn rewind(&mut self) -> Result<()> {
126        self.seek(SeekFrom::Start(0)).map(|_| ())
127    }
128
129    /// Returns the length of the stream.
130    ///
131    /// Semantics match [`std::io::Seek::stream_len()`].
132    fn stream_len(&mut self) -> Result<u64> {
133        // Save original position.
134        let orig_pos = self.seek(SeekFrom::Current(0))?;
135
136        // Seed to the end to discover stream length.
137        let end_pos = self.seek(SeekFrom::End(0))?;
138
139        // Go back to original position.
140        self.seek(SeekFrom::Start(orig_pos))?;
141
142        Ok(end_pos)
143    }
144
145    /// Returns the current position of the stream.
146    ///
147    /// Semantics match [`std::io::Seek::stream_position()`].
148    fn stream_position(&mut self) -> Result<u64> {
149        self.seek(SeekFrom::Current(0))
150    }
151}
152
#[cfg(test)]
pub(crate) mod test_utils {
    use super::{Seek, SeekFrom};

    /// Queries the current position of `seeker`, panicking if that fails.
    fn position_of<S: Seek>(seeker: &mut S) -> u64 {
        seeker.stream_position().unwrap()
    }

    /// Checks that `rewind()` moves a seeker from mid-stream back to offset 0.
    ///
    /// `seeker` is expected to start at offset 0 of a `LEN`-byte stream.
    pub(crate) fn test_rewind_resets_position_to_zero<const LEN: u64, T>(mut seeker: T)
    where
        T: Seek,
    {
        let half_way = LEN as i64 / 2;
        seeker.seek(SeekFrom::Current(half_way)).unwrap();
        assert_eq!(position_of(&mut seeker), LEN / 2);

        seeker.rewind().unwrap();
        assert_eq!(position_of(&mut seeker), 0);
    }

    /// Checks that `stream_position()` tracks relative and end-anchored seeks.
    ///
    /// `seeker` is expected to start at offset 0 of a `LEN`-byte stream.
    pub(crate) fn test_stream_pos_reports_correct_position<const LEN: u64, T>(mut seeker: T)
    where
        T: Seek,
    {
        assert_eq!(position_of(&mut seeker), 0);

        // Each step is a seek followed by the position expected afterwards.
        let steps = [
            (SeekFrom::Current(1), 1),
            // Advance the rest of the way to the midpoint.
            (SeekFrom::Current(LEN as i64 / 2 - 1), LEN / 2),
            // A zero-length relative seek must not move the position.
            (SeekFrom::Current(0), LEN / 2),
            (SeekFrom::End(0), LEN),
        ];
        for (target, expected) in steps {
            seeker.seek(target).unwrap();
            assert_eq!(position_of(&mut seeker), expected);
        }
    }

    /// Checks that `stream_len()` reports `LEN`.
    pub(crate) fn test_stream_len_reports_correct_length<const LEN: u64, T>(mut seeker: T)
    where
        T: Seek,
    {
        let reported = seeker.stream_len().unwrap();
        assert_eq!(reported, LEN);
    }
}
180
#[cfg(test)]
mod tests {
    use core::cmp::min;

    use super::test_utils::*;
    use super::*;

    /// In-memory `Seek` implementation used to exercise the trait's default
    /// methods: `pos` is the current offset and `len` the stream length.
    struct TestSeeker {
        len: u64,
        pos: u64,
    }

    impl Seek for TestSeeker {
        fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
            let new_pos = match pos {
                SeekFrom::Start(pos) => pos,
                SeekFrom::Current(pos) => {
                    self.pos.checked_add_signed(pos).ok_or(Error::OutOfRange)?
                }
                // `Seek::seek` follows `std::io::Seek::seek`, where the `End`
                // offset is *added* to the length (in-range offsets are <= 0).
                // Adding it directly also avoids negating `i64::MIN`.
                SeekFrom::End(pos) => self.len.checked_add_signed(pos).ok_or(Error::OutOfRange)?,
            };

            // Reject positions past the end; a failed seek leaves `pos` untouched.
            if new_pos > self.len {
                Err(Error::OutOfRange)
            } else {
                self.pos = new_pos;
                Ok(new_pos)
            }
        }
    }

    /// A stream wrapper that limits reads and writes to a maximum chunk size
    /// and counts the calls, so tests can check the retry loops in
    /// `read_exact` / `write_all`.
    struct ChunkedStreamAdapter<S> {
        inner: S,
        chunk_size: usize,
        num_reads: u32,
        num_writes: u32,
    }

    impl<S> ChunkedStreamAdapter<S> {
        fn new(inner: S, chunk_size: usize) -> Self {
            Self {
                inner,
                chunk_size,
                num_reads: 0,
                num_writes: 0,
            }
        }
    }

    impl<S: Read> Read for ChunkedStreamAdapter<S> {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            let read_len = min(self.chunk_size, buf.len());
            self.num_reads += 1;
            self.inner.read(&mut buf[..read_len])
        }
    }

    impl<S: Write> Write for ChunkedStreamAdapter<S> {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            let write_len = min(self.chunk_size, buf.len());
            self.num_writes += 1;
            self.inner.write(&buf[..write_len])
        }

        fn flush(&mut self) -> Result<()> {
            self.inner.flush()
        }
    }

    /// A stream whose every operation fails with `error`.
    struct ErrorStream {
        error: Error,
    }

    impl Read for ErrorStream {
        fn read(&mut self, _buf: &mut [u8]) -> Result<usize> {
            Err(self.error)
        }
    }

    impl Write for ErrorStream {
        fn write(&mut self, _buf: &[u8]) -> Result<usize> {
            Err(self.error)
        }

        fn flush(&mut self) -> Result<()> {
            Err(self.error)
        }
    }

    #[test]
    fn default_rewind_impl_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(TestSeeker { len: 64, pos: 0 });
    }

    #[test]
    fn default_stream_pos_impl_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(TestSeeker { len: 64, pos: 0 });
    }

    #[test]
    fn default_stream_len_impl_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(TestSeeker { len: 64, pos: 32 });
    }

    #[test]
    fn test_seeker_end_offsets_are_relative_to_stream_length() {
        let mut seeker = TestSeeker { len: 64, pos: 0 };
        assert_eq!(seeker.seek(SeekFrom::End(-4)), Ok(60));
        assert_eq!(seeker.stream_position(), Ok(60));

        // Past the end is rejected and the position is left untouched.
        assert_eq!(seeker.seek(SeekFrom::End(1)), Err(Error::OutOfRange));
        assert_eq!(seeker.stream_position(), Ok(60));

        // Extreme offsets are rejected rather than overflowing.
        assert_eq!(seeker.seek(SeekFrom::End(i64::MIN)), Err(Error::OutOfRange));
        assert_eq!(seeker.stream_position(), Ok(60));
    }

    #[test]
    fn read_exact_reads_full_buffer_on_short_reads() {
        let cursor = Cursor::new((0x0..=0xff).collect::<Vec<u8>>());
        // Limit reads to 10 bytes per read.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let mut read_buffer = vec![0u8; 256];

        wrapper.read_exact(&mut read_buffer).unwrap();

        // Ensure that the correct bytes were read.
        assert_eq!(wrapper.inner.into_inner(), read_buffer);

        // Verify that the read was broken up into the correct number of reads.
        assert_eq!(wrapper.num_reads, 26);
    }

    #[test]
    fn read_exact_returns_error_on_too_little_data() {
        let cursor = Cursor::new((0x0..=0x7f).collect::<Vec<u8>>());
        // Limit reads to 10 bytes per read.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let mut read_buffer = vec![0u8; 256];

        assert_eq!(wrapper.read_exact(&mut read_buffer), Err(Error::OutOfRange));
    }

    #[test]
    fn read_exact_propagates_read_errors() {
        let mut error_stream = ErrorStream {
            error: Error::Internal,
        };
        let mut read_buffer = vec![0u8; 256];
        assert_eq!(
            error_stream.read_exact(&mut read_buffer),
            Err(Error::Internal)
        );
    }

    #[test]
    fn write_all_writes_full_buffer_on_short_writes() {
        let cursor = Cursor::new(vec![0u8; 256]);
        // Limit writes to 10 bytes per write.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();

        wrapper.write_all(&write_buffer).unwrap();

        // Ensure that the correct bytes were written.
        assert_eq!(wrapper.inner.into_inner(), write_buffer);

        // Verify that the write was broken up into the correct number of writes.
        assert_eq!(wrapper.num_writes, 26);
    }

    #[test]
    fn write_all_returns_error_on_too_little_space() {
        // The destination only has room for half of the data.
        let cursor = Cursor::new(vec![0u8; 128]);
        // Limit writes to 10 bytes per write.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();

        assert_eq!(wrapper.write_all(&write_buffer), Err(Error::OutOfRange));
    }

    #[test]
    fn write_all_propagates_write_errors() {
        let mut error_stream = ErrorStream {
            error: Error::Internal,
        };
        let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();
        assert_eq!(error_stream.write_all(&write_buffer), Err(Error::Internal));
    }
}