1use core::cmp::min;
16use core::ptr;
17
18use paste::paste;
19use pw_status::{Error, Result};
20use pw_varint::{VarintDecode, VarintEncode};
21
22use super::{Read, Seek, SeekFrom, Write};
23
/// A read/write/seek cursor over an in-memory byte buffer.
///
/// `inner` may be borrowed (`&[u8]`, `&mut [u8]`) or owned (`Vec<u8>`); any
/// type that can be viewed as `[u8]` works. `pos` is the byte offset, measured
/// from the start of `inner`, where the next read or write begins.
pub struct Cursor<T: AsRef<[u8]>> {
    inner: T,
    pos: usize,
}
36
37impl<T: AsRef<[u8]>> Cursor<T> {
38 pub fn new(inner: T) -> Self {
42 Self { inner, pos: 0 }
43 }
44
45 pub fn into_inner(self) -> T {
47 self.inner
48 }
49
50 pub fn remaining(&self) -> usize {
52 self.len().saturating_sub(self.pos)
53 }
54
55 #[allow(clippy::len_without_is_empty)]
59 pub fn len(&self) -> usize {
60 self.inner.as_ref().len()
61 }
62
63 pub fn position(&self) -> usize {
65 self.pos
66 }
67
68 fn remaining_slice(&mut self) -> &[u8] {
69 &self.inner.as_ref()[self.pos..]
70 }
71}
72
73impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
74 fn remaining_mut(&mut self) -> &mut [u8] {
75 &mut self.inner.as_mut()[self.pos..]
76 }
77}
78
79fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
82 let remaining = inner.len() - *pos;
83 let read_len = min(remaining, buf.len());
84 buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
85 *pos += read_len;
86 Ok(read_len)
87}
88
89impl<T: AsRef<[u8]>> Read for Cursor<T> {
90 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
91 read_impl(self.inner.as_ref(), &mut self.pos, buf)
92 }
93}
94
/// Copies as many bytes of `buf` as fit into `inner`, starting at `*pos`, then
/// advances `*pos` by the number of bytes copied.
///
/// Returns the number of bytes written. This is less than `buf.len()` when
/// `inner` lacks room for all of `buf`, and 0 when `*pos == inner.len()`.
///
/// # Errors
///
/// Returns `Error::OutOfRange` if `*pos` is beyond the end of `inner`. In that
/// case nothing is written and `*pos` is unchanged.
fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
    // Space left after `*pos`; `checked_sub` rejects a position past the end.
    let remaining = inner.len().checked_sub(*pos).ok_or(Error::OutOfRange)?;
    let write_len = min(remaining, buf.len());
    // SAFETY: `*pos <= inner.len()` (checked above), so `dst_ptr` is in bounds
    // or one past the end of `inner`. `write_len <= remaining` keeps the
    // destination range `*pos..*pos + write_len` inside `inner`, and
    // `write_len <= buf.len()` keeps the source range inside `buf`. `inner` is
    // an exclusive borrow and `buf` a shared one, so the two regions cannot
    // overlap, as `copy_nonoverlapping` requires.
    unsafe {
        let src_ptr = buf.as_ptr();
        let dst_ptr = inner.as_mut_ptr().add(*pos);
        ptr::copy_nonoverlapping(src_ptr, dst_ptr, write_len);
    }
    // NOTE(review): this cannot actually saturate, since
    // `*pos + write_len <= inner.len()`.
    *pos = pos.saturating_add(write_len);
    Ok(write_len)
}
112
113impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
114 fn write(&mut self, buf: &[u8]) -> Result<usize> {
115 write_impl(self.inner.as_mut(), &mut self.pos, buf)
116 }
117
118 fn flush(&mut self) -> Result<()> {
119 Ok(())
121 }
122}
123
124impl<T: AsRef<[u8]>> Seek for Cursor<T> {
125 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
126 let new_pos = match pos {
127 SeekFrom::Start(pos) => pos,
128 SeekFrom::Current(pos) => (self.pos as u64)
129 .checked_add_signed(pos)
130 .ok_or(Error::OutOfRange)?,
131 SeekFrom::End(pos) => (self.len() as u64)
132 .checked_add_signed(-pos)
133 .ok_or(Error::OutOfRange)?,
134 };
135
136 let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
139
140 if new_pos > self.len() {
141 Err(Error::OutOfRange)
142 } else {
143 self.pos = new_pos;
144 Ok(new_pos as u64)
145 }
146 }
147
148 fn rewind(&mut self) -> Result<()> {
150 self.pos = 0;
151 Ok(())
152 }
153
154 fn stream_len(&mut self) -> Result<u64> {
155 Ok(self.len() as u64)
156 }
157
158 fn stream_position(&mut self) -> Result<u64> {
159 Ok(self.pos as u64)
160 }
161}
162
// Generates a `read_<ty>_<endian>` method (e.g. `read_u32_le`) for use inside a
// `Cursor` impl. The method decodes a `$ty` from the next `$ty::BITS / 8` bytes
// with `$ty::from_<endian>_bytes` and advances the position past them. If
// fewer bytes remain, it returns `Error::OutOfRange` and leaves the position
// unchanged.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                // The check above guarantees `self.pos + NUM_BYTES <= len()`, so
                // the `InvalidArgument` error is purely defensive.
                let sub_slice = self
                    .inner
                    .as_ref()
                    .get(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;
                // SAFETY: `sub_slice` covers exactly `self.pos..self.pos +
                // NUM_BYTES`, so it is `NUM_BYTES` long. `[u8; NUM_BYTES]` has
                // the same alignment (1) as `u8`, so reading it through a
                // `*const [u8; NUM_BYTES]` is valid for the lifetime of
                // `sub_slice`.
                // NOTE(review): presumably preferred over `sub_slice.try_into()`
                // to avoid an extra fallible conversion path; confirm.
                let sub_array: &[u8; NUM_BYTES] =
                    unsafe { &*(sub_slice.as_ptr() as *const [u8; NUM_BYTES]) };
                let value = $ty::[<from_ $endian _bytes>](*sub_array);

                self.pos += NUM_BYTES;
                Ok(value)
            }
        }
    };
}
191
// Expands to the four reader methods for one integer width `$bits`:
// `read_i{bits}_le`, `read_u{bits}_le`, `read_i{bits}_be` and `read_u{bits}_be`.
// `paste!` joins the `i`/`u` prefix with `$bits` to form the type identifier
// (e.g. `i16`) passed to `cursor_read_type_impl!`.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
            cursor_read_type_impl!([<i $bits>], le);
            cursor_read_type_impl!([<u $bits>], le);
            cursor_read_type_impl!([<i $bits>], be);
            cursor_read_type_impl!([<u $bits>], be);
        }
    };
}
202
// Generates a `write_<ty>_<endian>` method (e.g. `write_u32_be`) for use inside
// a `Cursor` impl. The method stores `*value` as the `$ty::BITS / 8` bytes
// produced by `$ty::to_<endian>_bytes`, starting at the current position, and
// advances past them. If fewer bytes remain, it returns `Error::OutOfRange`
// without writing anything or moving the position.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let value_bytes = $ty::[<to_ $endian _bytes>](*value);
                // The check above guarantees `self.pos + NUM_BYTES <= len()`, so
                // the `InvalidArgument` error is purely defensive.
                let sub_slice = self
                    .inner
                    .as_mut()
                    .get_mut(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;

                // `value_bytes` is a `[u8; NUM_BYTES]`, the same length as
                // `sub_slice`.
                sub_slice.copy_from_slice(&value_bytes);

                self.pos += NUM_BYTES;
                Ok(())
            }
        }
    };
}
226
// Expands to the four writer methods for one integer width `$bits`:
// `write_i{bits}_le`, `write_u{bits}_le`, `write_i{bits}_be` and
// `write_u{bits}_be`. `paste!` joins the `i`/`u` prefix with `$bits` to form
// the type identifier (e.g. `u64`) passed to `cursor_write_type_impl!`.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
            cursor_write_type_impl!([<i $bits>], le);
            cursor_write_type_impl!([<u $bits>], le);
            cursor_write_type_impl!([<i $bits>], be);
            cursor_write_type_impl!([<u $bits>], be);
        }
    };
}
237
238impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
239 cursor_read_bits_impl!(8);
240 cursor_read_bits_impl!(16);
241 cursor_read_bits_impl!(32);
242 cursor_read_bits_impl!(64);
243 cursor_read_bits_impl!(128);
244}
245
246impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
247 cursor_write_bits_impl!(8);
248 cursor_write_bits_impl!(16);
249 cursor_write_bits_impl!(32);
250 cursor_write_bits_impl!(64);
251 cursor_write_bits_impl!(128);
252}
253
254impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
255 fn read_varint(&mut self) -> Result<u64> {
256 let (len, value) = u64::varint_decode(self.remaining_slice())?;
257 self.pos += len;
258 Ok(value)
259 }
260
261 fn read_signed_varint(&mut self) -> Result<i64> {
262 let (len, value) = i64::varint_decode(self.remaining_slice())?;
263 self.pos += len;
264 Ok(value)
265 }
266}
267
268impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
269 fn write_varint(&mut self, value: u64) -> Result<()> {
270 let encoded_len = value.varint_encode(self.remaining_mut())?;
271 self.pos += encoded_len;
272 Ok(())
273 }
274
275 fn write_signed_varint(&mut self, value: i64) -> Result<()> {
276 let encoded_len = value.varint_encode(self.remaining_mut())?;
277 self.pos += encoded_len;
278 Ok(())
279 }
280}
281
#[cfg(test)]
mod tests {
    // Unit tests for `Cursor`. Several tests build a `Cursor` literal directly
    // so they can start from a chosen position. The generic `test_*` helpers
    // come from `crate::test_utils`.
    use super::*;
    use crate::test_utils::*;
    use crate::{ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    #[test]
    fn cursor_len_returns_total_bytes() {
        // `len()` ignores the position and reports the whole buffer size.
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        // 64 bytes total minus 31 already passed leaves 33.
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        // Only 4 bytes remain, so an 8-byte read returns 4 and leaves the tail
        // of `buf` untouched.
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        // Only 4 bytes of space remain, so just the first 4 bytes of `buf` are
        // stored.
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    // Generates `cursor_read_<bits>_bit_integers_unpacks_data_correctly`. The
    // test reads four consecutive values from the bytes of
    // `integer_<bits>_bit_test_cases()` in the order signed-LE, unsigned-LE,
    // signed-BE, unsigned-BE, and compares them with the expected tuple.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                    let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(&bytes);

                    assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                    assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                    assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                    assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
                }
            }
        };
    }

    // Generates `cursor_write_<bits>_bit_integers_packs_data_correctly`, the
    // inverse of the read test: it writes the four expected values in the same
    // order into a zeroed buffer of matching length and checks that the result
    // equals the test-case bytes.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                    let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                    cursor.[<write_i $bits _le>](&values.0).unwrap();
                    cursor.[<write_u $bits _le>](&values.1).unwrap();
                    cursor.[<write_i $bits _be>](&values.2).unwrap();
                    cursor.[<write_u $bits _be>](&values.3).unwrap();

                    let result_bytes: Vec<u8> = cursor.into_inner().into();

                    assert_eq!(result_bytes, expected_bytes);
                }
            }
        };
    }

    // Each `integer_<bits>_bit_test_cases` returns `(bytes, (i LE, u LE, i BE,
    // u BE))`. Every value occupies `bits / 8` consecutive bytes in that order.
    // Most multi-byte groups carry a 0x80 (or 0x8f) in their most significant
    // byte, so a byte-order mix-up or a sign-handling error changes the decoded
    // value.
    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // i8, `_le` reader
                0x1, // u8, `_le` reader
                0x2, // i8, `_be` reader
                0x3, // u8, `_be` reader
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // i16 little-endian
                0x1, 0x80, // u16 little-endian
                0x80, 0x2, // i16 big-endian
                0x80, 0x3, // u16 big-endian
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // i32 little-endian
                0x3, 0x4, 0x5, 0x80, // u32 little-endian
                0x80, 0x6, 0x7, 0x8, // i32 big-endian
                0x80, 0x9, 0xa, 0xb, // u32 big-endian
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                // i64 little-endian
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80,
                // u64 little-endian
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80,
                // i64 big-endian
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
                // u64 big-endian
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d,
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // `rustfmt::skip` keeps the 8-bytes-per-row layout; the attribute needs a
        // statement to attach to, hence the `let val = ...; val` form.
        #[rustfmt::skip]
        let val = (
            vec![
                // i128 little-endian
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // u128 little-endian
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // i128 big-endian
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // u128 big-endian
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    // Varint fixtures: `fe ff ff ff 0f` decodes to 0xffff_fffe unsigned and
    // `ff ff ff ff 0f` to 0xffff_ffff. The trailing zero bytes are padding that
    // the decoder must not consume.
    #[test]
    pub fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    // The same byte patterns read as signed varints map to i32::MAX and
    // i32::MIN. This is consistent with a zigzag signed encoding; see the
    // pw_varint crate for the exact scheme.
    #[test]
    pub fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    // Writing the values from the read tests reproduces the same bytes and
    // leaves the rest of the zeroed buffer untouched.
    #[test]
    pub fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    #[test]
    pub fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}