Skip to main content

pw_stream/
lib.rs

1// Copyright 2023 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15//! `pw_stream` provides `no_std` versions of Rust's [`std::io::Read`],
16//! [`std::io::Write`], and [`std::io::Seek`] traits as well as a simplified
17//! version of [`std::io::Cursor`].  One notable difference is that
18//! [`pw_status::Error`] is used to avoid needing to do error conversion or
19//! encapsulation.
20#![deny(missing_docs)]
21// Allows docs to reference `std`
22#![cfg_attr(feature = "no_std", no_std)]
23
24use pw_status::{Error, Result};
25
26#[doc(hidden)]
27mod cursor;
28mod integer;
29
30pub use cursor::Cursor;
31pub use integer::{ReadInteger, ReadVarint, WriteInteger, WriteVarint};
32
33/// A trait for objects that provide streaming read capability.
34pub trait Read {
35    /// Read from a stream into a buffer.
36    ///
37    /// Semantics match [`std::io::Read::read()`].
38    fn read(&mut self, buf: &mut [u8]) -> Result<usize>;
39
40    /// Read exactly enough bytes to fill the buffer.
41    ///
42    /// Semantics match [`std::io::Read::read_exact()`].
43    fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<()> {
44        while !buf.is_empty() {
45            let len = self.read(buf)?;
46
47            // End of stream
48            if len == 0 {
49                break;
50            }
51
52            buf = &mut buf[len..];
53        }
54
55        if !buf.is_empty() {
56            Err(Error::OutOfRange)
57        } else {
58            Ok(())
59        }
60    }
61}
62
63/// A trait for objects that provide streaming write capability.
64pub trait Write {
65    /// Write a buffer to a stream.
66    ///
67    /// Semantics match [`std::io::Write::write()`].
68    fn write(&mut self, buf: &[u8]) -> Result<usize>;
69
70    /// Commit any outstanding buffered writes to underlying storage.
71    ///
72    /// Semantics match [`std::io::Write::flush()`].
73    fn flush(&mut self) -> Result<()>;
74
75    /// Writes entire buffer to stream.
76    ///
77    /// Semantics match [`std::io::Write::write_all()`].
78    fn write_all(&mut self, mut buf: &[u8]) -> Result<()> {
79        while !buf.is_empty() {
80            let len = self.write(buf)?;
81
82            // End of stream
83            if len == 0 {
84                break;
85            }
86
87            buf = buf.get(len..).ok_or(Error::OutOfRange)?;
88        }
89
90        if !buf.is_empty() {
91            Err(Error::OutOfRange)
92        } else {
93            Ok(())
94        }
95    }
96}
97
/// A description of a seek operation in a stream.
///
/// While `pw_stream` targets embedded platforms which are often natively
/// 32 bit, we believe that seek operations are relatively rare and the added
/// overhead of using 64 bit values for seeks is balanced by the ability
/// to support objects and operations over 4 GiB.
///
/// Like [`std::io::SeekFrom`], this type is `Copy` and can be compared and
/// `Debug`-printed (e.g. in `assert_eq!` or diagnostics).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SeekFrom {
    /// Seek to an absolute byte offset from the start of the stream.
    Start(u64),

    /// Seek relative to the end of the stream.
    ///
    /// As with [`std::io::SeekFrom::End`], the offset is added to the stream
    /// length, so negative values address bytes before the end.
    End(i64),

    /// Seek relative to the current position of the stream.
    Current(i64),
}
114
115/// A trait for objects that provide the ability to seek withing a stream.
116pub trait Seek {
117    /// Adjust the current position of the stream.
118    ///
119    /// Semantics match [`std::io::Seek::seek()`].
120    fn seek(&mut self, pos: SeekFrom) -> Result<u64>;
121
122    /// Set the current position of the stream to its beginning.
123    ///
124    /// Semantics match [`std::io::Seek::rewind()`].
125    fn rewind(&mut self) -> Result<()> {
126        self.seek(SeekFrom::Start(0)).map(|_| ())
127    }
128
129    /// Returns the length of the stream.
130    ///
131    /// Semantics match [`std::io::Seek::stream_len()`].
132    fn stream_len(&mut self) -> Result<u64> {
133        // Save original position.
134        let orig_pos = self.seek(SeekFrom::Current(0))?;
135
136        // Seed to the end to discover stream length.
137        let end_pos = self.seek(SeekFrom::End(0))?;
138
139        // Go back to original position.
140        self.seek(SeekFrom::Start(orig_pos))?;
141
142        Ok(end_pos)
143    }
144
145    /// Returns the current position of the stream.
146    ///
147    /// Semantics match [`std::io::Seek::stream_position()`].
148    fn stream_position(&mut self) -> Result<u64> {
149        self.seek(SeekFrom::Current(0))
150    }
151}
152
#[cfg(test)]
pub(crate) mod test_utils {
    //! Reusable checks for [`Seek`] implementations.
    //!
    //! Each helper takes ownership of a seeker over a stream of `LEN` bytes and
    //! panics (via `unwrap`/`assert_eq!`) if the seeker misbehaves.
    use super::{Seek, SeekFrom};

    /// Checks that `rewind()` moves the position back to zero.
    ///
    /// Expects `seeker` to start at position 0, so that the initial relative
    /// seek lands at `LEN / 2`.
    pub(crate) fn test_rewind_resets_position_to_zero<const LEN: u64, T: Seek>(mut seeker: T) {
        seeker
            .seek(SeekFrom::Current((LEN / 2).cast_signed()))
            .unwrap();
        assert_eq!(seeker.stream_position().unwrap(), LEN / 2);
        seeker.rewind().unwrap();
        assert_eq!(seeker.stream_position().unwrap(), 0);
    }

    /// Checks `stream_position()` after a sequence of current-relative and
    /// end-relative seeks.
    ///
    /// Expects `seeker` to start at position 0.
    pub(crate) fn test_stream_pos_reports_correct_position<const LEN: u64, T: Seek>(mut seeker: T) {
        assert_eq!(seeker.stream_position().unwrap(), 0);
        seeker.seek(SeekFrom::Current(1)).unwrap();
        assert_eq!(seeker.stream_position().unwrap(), 1);
        seeker
            .seek(SeekFrom::Current((LEN / 2).cast_signed() - 1))
            .unwrap();
        assert_eq!(seeker.stream_position().unwrap(), LEN / 2);
        seeker.seek(SeekFrom::Current(0)).unwrap();
        assert_eq!(seeker.stream_position().unwrap(), LEN / 2);
        seeker.seek(SeekFrom::End(0)).unwrap();
        assert_eq!(seeker.stream_position().unwrap(), LEN);
    }

    /// Checks that `stream_len()` reports `LEN` and leaves the current
    /// position unchanged, from whatever position `seeker` starts at.
    pub(crate) fn test_stream_len_reports_correct_length<const LEN: u64, T: Seek>(mut seeker: T) {
        let pos_before = seeker.stream_position().unwrap();
        assert_eq!(seeker.stream_len().unwrap(), LEN);
        // `stream_len()` must not move the stream, matching
        // `std::io::Seek::stream_len()`.
        assert_eq!(seeker.stream_position().unwrap(), pos_before);
    }
}
184
#[cfg(test)]
mod tests {
    use core::cmp::min;

    use super::test_utils::*;
    use super::*;

    /// An in-memory [`Seek`] implementation over a stream of `len` bytes.
    struct TestSeeker {
        len: u64,
        pos: u64,
    }

    impl Seek for TestSeeker {
        fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
            let new_pos = match pos {
                SeekFrom::Start(pos) => pos,
                SeekFrom::Current(pos) => {
                    self.pos.checked_add_signed(pos).ok_or(Error::OutOfRange)?
                }
                // `End` offsets are added to the stream length (negative values
                // address bytes before the end), matching `std::io::SeekFrom::End`.
                // Adding directly also avoids the overflow of negating `i64::MIN`.
                SeekFrom::End(pos) => self.len.checked_add_signed(pos).ok_or(Error::OutOfRange)?,
            };

            if new_pos > self.len {
                Err(Error::OutOfRange)
            } else {
                self.pos = new_pos;
                Ok(new_pos)
            }
        }
    }

    /// A stream wrapper that limits reads and writes to a maximum chunk size
    /// and counts how many read/write calls were made.
    struct ChunkedStreamAdapter<S> {
        inner: S,
        chunk_size: usize,
        num_reads: u32,
        num_writes: u32,
    }

    impl<S> ChunkedStreamAdapter<S> {
        fn new(inner: S, chunk_size: usize) -> Self {
            Self {
                inner,
                chunk_size,
                num_reads: 0,
                num_writes: 0,
            }
        }
    }

    impl<S: Read> Read for ChunkedStreamAdapter<S> {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            let read_len = min(self.chunk_size, buf.len());
            self.num_reads += 1;
            self.inner.read(&mut buf[..read_len])
        }
    }

    impl<S: Write> Write for ChunkedStreamAdapter<S> {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            let write_len = min(self.chunk_size, buf.len());
            self.num_writes += 1;
            self.inner.write(&buf[..write_len])
        }

        fn flush(&mut self) -> Result<()> {
            self.inner.flush()
        }
    }

    /// A stream whose every operation fails with `error`.
    struct ErrorStream {
        error: Error,
    }

    impl Read for ErrorStream {
        fn read(&mut self, _buf: &mut [u8]) -> Result<usize> {
            Err(self.error)
        }
    }

    impl Write for ErrorStream {
        fn write(&mut self, _buf: &[u8]) -> Result<usize> {
            Err(self.error)
        }

        fn flush(&mut self) -> Result<()> {
            Err(self.error)
        }
    }

    #[test]
    fn default_rewind_impl_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(TestSeeker { len: 64, pos: 0 });
    }

    #[test]
    fn default_stream_pos_impl_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(TestSeeker { len: 64, pos: 0 });
    }

    #[test]
    fn default_stream_len_impl_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(TestSeeker { len: 64, pos: 32 });
    }

    #[test]
    fn test_seeker_seek_from_end_is_relative_to_length() {
        let mut seeker = TestSeeker { len: 64, pos: 0 };
        assert_eq!(seeker.seek(SeekFrom::End(-4)), Ok(60));
        // Past the end is rejected and leaves the position untouched.
        assert_eq!(seeker.seek(SeekFrom::End(1)), Err(Error::OutOfRange));
        assert_eq!(seeker.stream_position(), Ok(60));
        // Extreme offsets report an error instead of overflowing.
        assert_eq!(seeker.seek(SeekFrom::End(i64::MIN)), Err(Error::OutOfRange));
    }

    #[test]
    fn read_exact_reads_full_buffer_on_short_reads() {
        let cursor = Cursor::new((0x0..=0xff).collect::<Vec<u8>>());
        // Limit reads to 10 bytes per read.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let mut read_buffer = vec![0u8; 256];

        wrapper.read_exact(&mut read_buffer).unwrap();

        // Ensure that the correct bytes were read.
        assert_eq!(wrapper.inner.into_inner(), read_buffer);

        // Verify that the read was broken up into the correct number of reads.
        assert_eq!(wrapper.num_reads, 26);
    }

    #[test]
    fn read_exact_returns_error_on_too_little_data() {
        let cursor = Cursor::new((0x0..=0x7f).collect::<Vec<u8>>());
        // Limit reads to 10 bytes per read.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let mut read_buffer = vec![0u8; 256];

        assert_eq!(wrapper.read_exact(&mut read_buffer), Err(Error::OutOfRange));
    }

    #[test]
    fn read_exact_propagates_read_errors() {
        let mut error_stream = ErrorStream {
            error: Error::Internal,
        };
        let mut read_buffer = vec![0u8; 256];
        assert_eq!(
            error_stream.read_exact(&mut read_buffer),
            Err(Error::Internal)
        );
    }

    #[test]
    fn write_all_writes_full_buffer_on_short_writes() {
        let cursor = Cursor::new(vec![0u8; 256]);
        // Limit writes to 10 bytes per write.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();

        wrapper.write_all(&write_buffer).unwrap();

        // Ensure that the correct bytes were written.
        assert_eq!(wrapper.inner.into_inner(), write_buffer);

        // Verify that the write was broken up into the correct number of writes.
        assert_eq!(wrapper.num_writes, 26);
    }

    #[test]
    fn write_all_returns_error_on_too_little_data() {
        // The destination holds only 128 bytes, too little room for 256.
        let cursor = Cursor::new(vec![0u8; 128]);
        // Limit writes to 10 bytes per write.
        let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
        let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();

        assert_eq!(wrapper.write_all(&write_buffer), Err(Error::OutOfRange));
    }

    #[test]
    fn write_all_propagates_write_errors() {
        let mut error_stream = ErrorStream {
            error: Error::Internal,
        };
        let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();
        assert_eq!(error_stream.write_all(&write_buffer), Err(Error::Internal));
    }
}