1use core::cmp::min;
16use core::ptr;
17
18use paste::paste;
19use pw_status::{Error, Result};
20use pw_varint::{VarintDecode, VarintEncode};
21
22use super::{Read, Seek, SeekFrom, Write};
23
/// A byte buffer paired with a current position.
///
/// `T` may be any type that can be viewed as a byte slice through
/// `AsRef<[u8]>`.  Operations that modify the buffer additionally require
/// `AsMut<[u8]>` on `T`.
pub struct Cursor<T: AsRef<[u8]>> {
    /// The wrapped buffer.
    inner: T,
    /// Byte offset into `inner` at which the next operation starts.
    pos: usize,
}
36
37impl<T: AsRef<[u8]>> Cursor<T> {
38 pub fn new(inner: T) -> Self {
42 Self { inner, pos: 0 }
43 }
44
45 pub fn into_inner(self) -> T {
47 self.inner
48 }
49
50 pub fn remaining(&self) -> usize {
52 self.len().saturating_sub(self.pos)
53 }
54
55 #[allow(clippy::len_without_is_empty)]
59 pub fn len(&self) -> usize {
60 self.inner.as_ref().len()
61 }
62
63 pub fn position(&self) -> usize {
65 self.pos
66 }
67
68 fn remaining_slice(&mut self) -> &[u8] {
69 &self.inner.as_ref()[self.pos..]
70 }
71}
72
73impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
74 fn remaining_mut(&mut self) -> &mut [u8] {
75 &mut self.inner.as_mut()[self.pos..]
76 }
77}
78
79fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
82 let remaining = inner.len() - *pos;
83 let read_len = min(remaining, buf.len());
84 buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
85 *pos += read_len;
86 Ok(read_len)
87}
88
89impl<T: AsRef<[u8]>> Read for Cursor<T> {
90 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
91 read_impl(self.inner.as_ref(), &mut self.pos, buf)
92 }
93}
94
95fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
98 let remaining = inner.len().checked_sub(*pos).ok_or(Error::OutOfRange)?;
99 let write_len = min(remaining, buf.len());
100 unsafe {
103 let src_ptr = buf.as_ptr();
104 let dst_ptr = inner.as_mut_ptr().add(*pos);
105 ptr::copy_nonoverlapping(src_ptr, dst_ptr, write_len);
106 }
107 *pos = pos.saturating_add(write_len);
110 Ok(write_len)
111}
112
113impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
114 fn write(&mut self, buf: &[u8]) -> Result<usize> {
115 write_impl(self.inner.as_mut(), &mut self.pos, buf)
116 }
117
118 fn flush(&mut self) -> Result<()> {
119 Ok(())
121 }
122}
123
124impl<T: AsRef<[u8]>> Seek for Cursor<T> {
125 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
126 let new_pos = match pos {
127 SeekFrom::Start(pos) => pos,
128 SeekFrom::Current(pos) => (self.pos as u64)
129 .checked_add_signed(pos)
130 .ok_or(Error::OutOfRange)?,
131 SeekFrom::End(pos) => (self.len() as u64)
132 .checked_add_signed(-pos)
133 .ok_or(Error::OutOfRange)?,
134 };
135
136 let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
139
140 if new_pos > self.len() {
141 Err(Error::OutOfRange)
142 } else {
143 self.pos = new_pos;
144 Ok(new_pos as u64)
145 }
146 }
147
148 fn rewind(&mut self) -> Result<()> {
150 self.pos = 0;
151 Ok(())
152 }
153
154 fn stream_len(&mut self) -> Result<u64> {
155 Ok(self.len() as u64)
156 }
157
158 fn stream_position(&mut self) -> Result<u64> {
159 Ok(self.pos as u64)
160 }
161}
162
// Expands to a `read_<ty>_<endian>` method that decodes one fixed-width
// integer of type `$ty` from the cursor's current position using
// `$ty::from_<endian>_bytes`, then advances the position by the integer's
// size in bytes.
//
// The generated method returns `Error::OutOfRange`, without moving the
// position, when fewer than that many bytes remain.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                // The check above guarantees `pos + NUM_BYTES <= len`, so the
                // addition cannot overflow and `get` is expected to succeed.
                let sub_slice = self
                    .inner
                    .as_ref()
                    .get(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;
                // SAFETY: `sub_slice` was produced by `get` over a range of
                // exactly `NUM_BYTES` elements, so its pointer is valid for
                // reads of `NUM_BYTES` bytes.  `[u8; NUM_BYTES]` has the same
                // layout and alignment as that run of `u8`s, and the array
                // reference is only used while `sub_slice` is still borrowed.
                let sub_array: &[u8; NUM_BYTES] =
                    unsafe { &*(sub_slice.as_ptr() as *const [u8; NUM_BYTES]) };
                let value = $ty::[<from_ $endian _bytes>](*sub_array);

                self.pos += NUM_BYTES;
                Ok(value)
            }
        }
    };
}
191
// For a bit width `N`, expands to the four readers `read_iN_le`,
// `read_iN_be`, `read_uN_le`, and `read_uN_be` by invoking
// `cursor_read_type_impl!` once per signedness/endianness combination.
macro_rules! cursor_read_bits_impl {
    ($width:literal) => {
        paste! {
            // Signed variants.
            cursor_read_type_impl!([<i $width>], le);
            cursor_read_type_impl!([<i $width>], be);
            // Unsigned variants.
            cursor_read_type_impl!([<u $width>], le);
            cursor_read_type_impl!([<u $width>], be);
        }
    };
}
202
// Expands to a `write_<ty>_<endian>` method that encodes `*value` with
// `$ty::to_<endian>_bytes`, copies the bytes into the cursor's buffer at the
// current position, and advances the position by the integer's size in
// bytes.
//
// The generated method returns `Error::OutOfRange`, without modifying the
// buffer or the position, when fewer than that many bytes remain.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let value_bytes = $ty::[<to_ $endian _bytes>](*value);
                // The check above guarantees `pos + NUM_BYTES <= len`, so the
                // addition cannot overflow and `get_mut` is expected to
                // succeed.
                let sub_slice = self
                    .inner
                    .as_mut()
                    .get_mut(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;

                // Source and destination are both exactly `NUM_BYTES` long.
                sub_slice.copy_from_slice(&value_bytes[..]);

                self.pos += NUM_BYTES;
                Ok(())
            }
        }
    };
}
226
// For a bit width `N`, expands to the four writers `write_iN_le`,
// `write_iN_be`, `write_uN_le`, and `write_uN_be` by invoking
// `cursor_write_type_impl!` once per signedness/endianness combination.
macro_rules! cursor_write_bits_impl {
    ($width:literal) => {
        paste! {
            // Signed variants.
            cursor_write_type_impl!([<i $width>], le);
            cursor_write_type_impl!([<i $width>], be);
            // Unsigned variants.
            cursor_write_type_impl!([<u $width>], le);
            cursor_write_type_impl!([<u $width>], be);
        }
    };
}
237
// Fixed-width integer reads.  Each invocation below expands to the signed
// and unsigned, little- and big-endian `read_*` methods for one bit width
// (8 through 128 bits).
impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
    cursor_read_bits_impl!(8);
    cursor_read_bits_impl!(16);
    cursor_read_bits_impl!(32);
    cursor_read_bits_impl!(64);
    cursor_read_bits_impl!(128);
}
245
// Fixed-width integer writes.  Each invocation below expands to the signed
// and unsigned, little- and big-endian `write_*` methods for one bit width
// (8 through 128 bits).  Requires `AsMut<[u8]>` because the buffer is
// modified.
impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
    cursor_write_bits_impl!(8);
    cursor_write_bits_impl!(16);
    cursor_write_bits_impl!(32);
    cursor_write_bits_impl!(64);
    cursor_write_bits_impl!(128);
}
253
254impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
255 fn read_varint(&mut self) -> Result<u64> {
256 let (len, value) = u64::varint_decode(self.remaining_slice())?;
257 self.pos += len;
258 Ok(value)
259 }
260
261 fn read_signed_varint(&mut self) -> Result<i64> {
262 let (len, value) = i64::varint_decode(self.remaining_slice())?;
263 self.pos += len;
264 Ok(value)
265 }
266}
267
268impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
269 fn write_varint(&mut self, value: u64) -> Result<()> {
270 let encoded_len = value.varint_encode(self.remaining_mut())?;
271 self.pos += encoded_len;
272 Ok(())
273 }
274
275 fn write_signed_varint(&mut self, value: i64) -> Result<()> {
276 let encoded_len = value.varint_encode(self.remaining_mut())?;
277 self.pos += encoded_len;
278 Ok(())
279 }
280}
281
// Unit tests for `Cursor`: accessors, partial reads/writes, seeking (via the
// shared helpers in `crate::test_utils`), fixed-width integer round trips,
// and varint round trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    // `len()` reports the whole buffer, not just the part after `pos`.
    #[test]
    fn cursor_len_returns_total_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    // 64 total bytes with the position at 31 leaves 33 bytes.
    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    // Reading into a destination larger than what remains fills only the
    // front of the destination and reports the shorter count.
    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    // Writing more than fits is truncated to the space after `pos`; bytes
    // before `pos` are untouched.
    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    // The next three tests delegate to generic helpers from
    // `crate::test_utils`, instantiated for a 64-byte stream.
    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    // Generates a test that reads four integers of the given bit width from
    // the `integer_<bits>_bit_test_cases()` byte table, in the order
    // little-endian signed, little-endian unsigned, big-endian signed,
    // big-endian unsigned, and compares them with the expected values.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                    let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(&bytes);

                    assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                    assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                    assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                    assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
                }
            }
        };
    }

    // Generates a test that writes the four expected values of the given bit
    // width into a zeroed buffer, in the same order as the read test, and
    // compares the resulting bytes with the byte table.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                    let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                    cursor.[<write_i $bits _le>](&values.0).unwrap();
                    cursor.[<write_u $bits _le>](&values.1).unwrap();
                    cursor.[<write_i $bits _be>](&values.2).unwrap();
                    cursor.[<write_u $bits _be>](&values.3).unwrap();

                    let result_bytes: Vec<u8> = cursor.into_inner().into();

                    assert_eq!(result_bytes, expected_bytes);
                }
            }
        };
    }

    // Each `integer_<bits>_bit_test_cases` function returns the encoded bytes
    // together with the tuple (le signed, le unsigned, be signed, be unsigned)
    // those bytes represent.
    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // The table is bound to a local so `#[rustfmt::skip]` can be attached
        // to the statement and keep the 8-bytes-per-row layout.
        #[rustfmt::skip]
        let val = (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    // The encoded value occupies the leading bytes of each 8-byte buffer; the
    // trailing zero bytes are padding.
    #[test]
    pub fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    // The same byte patterns as the unsigned test decode to `i32::MAX` and
    // `i32::MIN` when read as signed varints.
    #[test]
    pub fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    // Writing into a zeroed 8-byte buffer fills the first five bytes and
    // leaves the remaining three at zero.
    #[test]
    pub fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    // Signed counterpart of the test above: `i32::MAX` and `i32::MIN` encode
    // to the same byte patterns the signed read test decodes.
    #[test]
    pub fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}