1use core::cmp::min;
16
17use paste::paste;
18use pw_status::{Error, Result};
19use pw_varint::{VarintDecode, VarintEncode};
20
21use super::{Read, Seek, SeekFrom, Write};
22
/// An in-memory stream over a byte buffer that tracks a read/write position.
///
/// `inner` may be any type that can be viewed as a byte slice (for example `&[u8; N]`
/// or `Vec<u8>`). Reading and seeking only need `AsRef<[u8]>`; writing additionally
/// needs `AsMut<[u8]>`. The buffer never grows: writes stop at the end of `inner`.
pub struct Cursor<T: AsRef<[u8]>> {
    /// The underlying byte buffer.
    inner: T,
    /// Current byte offset into `inner`; the cursor's methods keep it in
    /// `0..=inner.as_ref().len()`.
    pos: usize,
}
35
36impl<T: AsRef<[u8]>> Cursor<T> {
37 pub fn new(inner: T) -> Self {
41 Self { inner, pos: 0 }
42 }
43
44 pub fn into_inner(self) -> T {
46 self.inner
47 }
48
49 pub fn remaining(&self) -> usize {
51 self.len() - self.pos
52 }
53
54 #[allow(clippy::len_without_is_empty)]
58 pub fn len(&self) -> usize {
59 self.inner.as_ref().len()
60 }
61
62 pub fn position(&self) -> usize {
64 self.pos
65 }
66
67 fn remaining_slice(&mut self) -> &[u8] {
68 &self.inner.as_ref()[self.pos..]
69 }
70}
71
72impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
73 fn remaining_mut(&mut self) -> &mut [u8] {
74 &mut self.inner.as_mut()[self.pos..]
75 }
76}
77
78fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
81 let remaining = inner.len() - *pos;
82 let read_len = min(remaining, buf.len());
83 buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
84 *pos += read_len;
85 Ok(read_len)
86}
87
88impl<T: AsRef<[u8]>> Read for Cursor<T> {
89 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
90 read_impl(self.inner.as_ref(), &mut self.pos, buf)
91 }
92}
93
94fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
97 let remaining = inner.len() - *pos;
98 let write_len = min(remaining, buf.len());
99 inner[*pos..(*pos + write_len)].copy_from_slice(&buf[0..write_len]);
100 *pos += write_len;
101 Ok(write_len)
102}
103
104impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
105 fn write(&mut self, buf: &[u8]) -> Result<usize> {
106 write_impl(self.inner.as_mut(), &mut self.pos, buf)
107 }
108
109 fn flush(&mut self) -> Result<()> {
110 Ok(())
112 }
113}
114
115impl<T: AsRef<[u8]>> Seek for Cursor<T> {
116 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
117 let new_pos = match pos {
118 SeekFrom::Start(pos) => pos,
119 SeekFrom::Current(pos) => (self.pos as u64)
120 .checked_add_signed(pos)
121 .ok_or(Error::OutOfRange)?,
122 SeekFrom::End(pos) => (self.len() as u64)
123 .checked_add_signed(-pos)
124 .ok_or(Error::OutOfRange)?,
125 };
126
127 let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
130
131 if new_pos > self.len() {
132 Err(Error::OutOfRange)
133 } else {
134 self.pos = new_pos;
135 Ok(new_pos as u64)
136 }
137 }
138
139 fn rewind(&mut self) -> Result<()> {
141 self.pos = 0;
142 Ok(())
143 }
144
145 fn stream_len(&mut self) -> Result<u64> {
146 Ok(self.len() as u64)
147 }
148
149 fn stream_position(&mut self) -> Result<u64> {
150 Ok(self.pos as u64)
151 }
152}
153
// Generates a `read_<ty>_<endian>` method. The method decodes a `<ty>` from the bytes
// at the current position with `<ty>::from_<endian>_bytes` and advances past them.
//
// Fails with `Error::OutOfRange`, leaving the position untouched, when fewer than
// `<ty>::BITS / 8` bytes remain.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                // The check above guarantees the range is in bounds and exactly
                // `NUM_BYTES` long, so neither `get` nor the array conversion can fail.
                // The conversion is done safely anyway rather than via a pointer cast.
                let bytes = self
                    .inner
                    .as_ref()
                    .get(self.pos..self.pos + NUM_BYTES)
                    .and_then(|sub_slice| <[u8; NUM_BYTES]>::try_from(sub_slice).ok())
                    .ok_or(Error::InvalidArgument)?;
                let value = $ty::[<from_ $endian _bytes>](bytes);

                self.pos += NUM_BYTES;
                Ok(value)
            }
        }
    };
}
182
// Expands `cursor_read_type_impl!` for the signed and unsigned `$bits`-bit integers in
// both little- and big-endian byte order, producing four `read_*` methods.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
            cursor_read_type_impl!([<i $bits>], le);
            cursor_read_type_impl!([<i $bits>], be);
            cursor_read_type_impl!([<u $bits>], le);
            cursor_read_type_impl!([<u $bits>], be);
        }
    };
}
193
// Generates a `write_<ty>_<endian>` method. The method encodes `value` with
// `<ty>::to_<endian>_bytes` at the current position and advances past it.
//
// Fails with `Error::OutOfRange`, leaving the buffer and position untouched, when fewer
// than `<ty>::BITS / 8` bytes remain.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let value_bytes = $ty::[<to_ $endian _bytes>](*value);
                // The check above guarantees this range is in bounds, so `get_mut`
                // cannot fail; report an error instead of panicking regardless.
                let sub_slice = self
                    .inner
                    .as_mut()
                    .get_mut(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;

                sub_slice.copy_from_slice(&value_bytes);

                self.pos += NUM_BYTES;
                Ok(())
            }
        }
    };
}
217
// Expands `cursor_write_type_impl!` for the signed and unsigned `$bits`-bit integers in
// both little- and big-endian byte order, producing four `write_*` methods.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
            cursor_write_type_impl!([<i $bits>], le);
            cursor_write_type_impl!([<i $bits>], be);
            cursor_write_type_impl!([<u $bits>], le);
            cursor_write_type_impl!([<u $bits>], be);
        }
    };
}
228
229impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
230 cursor_read_bits_impl!(8);
231 cursor_read_bits_impl!(16);
232 cursor_read_bits_impl!(32);
233 cursor_read_bits_impl!(64);
234 cursor_read_bits_impl!(128);
235}
236
237impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
238 cursor_write_bits_impl!(8);
239 cursor_write_bits_impl!(16);
240 cursor_write_bits_impl!(32);
241 cursor_write_bits_impl!(64);
242 cursor_write_bits_impl!(128);
243}
244
245impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
246 fn read_varint(&mut self) -> Result<u64> {
247 let (len, value) = u64::varint_decode(self.remaining_slice())?;
248 self.pos += len;
249 Ok(value)
250 }
251
252 fn read_signed_varint(&mut self) -> Result<i64> {
253 let (len, value) = i64::varint_decode(self.remaining_slice())?;
254 self.pos += len;
255 Ok(value)
256 }
257}
258
259impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
260 fn write_varint(&mut self, value: u64) -> Result<()> {
261 let encoded_len = value.varint_encode(self.remaining_mut())?;
262 self.pos += encoded_len;
263 Ok(())
264 }
265
266 fn write_signed_varint(&mut self, value: i64) -> Result<()> {
267 let encoded_len = value.varint_encode(self.remaining_mut())?;
268 self.pos += encoded_len;
269 Ok(())
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    #[test]
    fn cursor_len_returns_total_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_seek_moves_relative_to_each_origin() {
        let mut cursor = Cursor::new(&[0u8; 8]);
        assert_eq!(cursor.seek(SeekFrom::Start(3)), Ok(3));
        assert_eq!(cursor.seek(SeekFrom::Current(2)), Ok(5));
        assert_eq!(cursor.seek(SeekFrom::Current(-5)), Ok(0));
        // `End` offsets count backwards from the end of the buffer.
        assert_eq!(cursor.seek(SeekFrom::End(0)), Ok(8));
        assert_eq!(cursor.seek(SeekFrom::End(3)), Ok(5));
        assert_eq!(cursor.position(), 5);
    }

    #[test]
    fn cursor_seek_out_of_range_leaves_position_unchanged() {
        let mut cursor = Cursor::new(&[0u8; 8]);
        assert_eq!(cursor.seek(SeekFrom::Start(2)), Ok(2));
        assert_eq!(cursor.seek(SeekFrom::Start(9)), Err(Error::OutOfRange));
        assert_eq!(cursor.seek(SeekFrom::Current(-3)), Err(Error::OutOfRange));
        assert_eq!(cursor.seek(SeekFrom::End(9)), Err(Error::OutOfRange));
        assert_eq!(cursor.seek(SeekFrom::End(-1)), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 2);
    }

    #[test]
    fn cursor_read_integer_past_end_returns_out_of_range() {
        let mut cursor = Cursor::new(&[0x01u8, 0x02, 0x03]);
        assert_eq!(cursor.read_u32_le(), Err(Error::OutOfRange));
        // A failed read must not consume any bytes.
        assert_eq!(cursor.position(), 0);
        assert_eq!(cursor.read_u16_le(), Ok(0x0201));
        assert_eq!(cursor.position(), 2);
    }

    #[test]
    fn cursor_write_integer_past_end_returns_out_of_range() {
        let mut cursor = Cursor::new(vec![0u8; 3]);
        assert_eq!(cursor.write_u32_le(&0xdead_beef), Err(Error::OutOfRange));
        // A failed write must neither move the cursor nor touch the buffer.
        assert_eq!(cursor.position(), 0);
        assert_eq!(cursor.into_inner(), vec![0u8; 3]);
    }

    // Generates a test that reads the i/u `$bits`-bit little-endian values followed by
    // the big-endian ones and checks them against `integer_<bits>_bit_test_cases`.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                    let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(&bytes);

                    assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                    assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                    assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                    assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
                }
            }
        };
    }

    // Generates a test that writes the values from `integer_<bits>_bit_test_cases` in
    // the same order and checks the resulting buffer against the expected bytes.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                    let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                    cursor.[<write_i $bits _le>](&values.0).unwrap();
                    cursor.[<write_u $bits _le>](&values.1).unwrap();
                    cursor.[<write_i $bits _be>](&values.2).unwrap();
                    cursor.[<write_u $bits _be>](&values.3).unwrap();

                    let result_bytes = cursor.into_inner();

                    assert_eq!(result_bytes, expected_bytes);
                }
            }
        };
    }

    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // i8 le
                0x1, // u8 le
                0x2, // i8 be
                0x3, // u8 be
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // i16 le
                0x1, 0x80, // u16 le
                0x80, 0x2, // i16 be
                0x80, 0x3, // u16 be
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // i32 le
                0x3, 0x4, 0x5, 0x80, // u32 le
                0x80, 0x6, 0x7, 0x8, // i32 be
                0x80, 0x9, 0xa, 0xb, // u32 be
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // i64 le
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // u64 le
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // i64 be
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // u64 be
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // Bound to a local so `rustfmt::skip` (a statement attribute) keeps the
        // 8-bytes-per-row layout.
        #[rustfmt::skip]
        let val = (
            vec![
                // i128 le
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // u128 le
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // i128 be
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // u128 be
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    #[test]
    fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);
        // Only the five encoded bytes are consumed.
        assert_eq!(cursor.position(), 5);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
        assert_eq!(cursor.position(), 5);
    }

    #[test]
    fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());
        assert_eq!(cursor.position(), 5);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
        assert_eq!(cursor.position(), 5);
    }

    #[test]
    fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        assert_eq!(cursor.position(), 5);
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        assert_eq!(cursor.position(), 5);
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    #[test]
    fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        assert_eq!(cursor.position(), 5);
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        assert_eq!(cursor.position(), 5);
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}