aboutsummaryrefslogtreecommitdiff
path: root/src/common/sized_buffer.rs
blob: ebfcab5e2ef7a1457e91ced7c0672e62e283195d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use memchr::memchr;
use std::io;
use std::io::Write;

/// Byte buffer that tracks how much of its storage has been written and how much of that has
/// already been read.
///
/// The methods in this file are implemented for `T = [u8; N]`, i.e. a fixed-capacity array.
pub struct SizedBuffer<T> {
    /// Number of bytes written so far; the valid data is `inner[..len]`.
    len: usize,
    /// Read cursor: index into `inner` of the next byte to be read.
    pos: usize,
    /// Backing storage. Public, so callers can reach the raw bytes directly.
    pub inner: T,
}

/// Errors reported by [`SizedBuffer`] operations.
#[derive(Debug)]
pub enum Error {
    /// Writing into the buffer failed (for example because it is full).
    IO(io::Error),
    /// The cursor was asked to move beyond the data written so far.
    ReadPastEnd,
    /// A specific byte was expected at the cursor but a different one was found.
    WrongByte { expected: u8, got: u8 },
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::IO(err) => write!(f, "I/O error: {}", err),
            Error::ReadPastEnd => f.write_str("read past the end of the buffer"),
            Error::WrongByte { expected, got } => write!(
                f,
                "wrong byte: expected {:#04x}, got {:#04x}",
                expected, got
            ),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the underlying I/O error, if any, so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IO(err) => Some(err),
            Error::ReadPastEnd | Error::WrongByte { .. } => None,
        }
    }
}

impl<const N: usize> SizedBuffer<[u8; N]> {
    pub fn new() -> Self {
        Self {
            len: 0,
            pos: 0,
            inner: [0; N],
        }
    }

    pub fn append(&mut self, bytes: &mut [u8]) -> Result<(), Error> {
        self.write_all(bytes).map_err(Error::IO)
    }

    pub fn peek(&self) -> u8 {
        self.inner[self.pos]
    }

    pub fn read(&mut self) -> Result<u8, Error> {
        self.seek(1)?;
        Ok(self.inner[self.pos - 1])
    }

    /* Reads n bytes from the buffer, and moves the cursor right after the last byte read.
     *
     * Example with n=3:
     *
     *     pos = 2
     *      v
     * [0,1,2,3,4,5,6,7,8,9]
     *     |-----|^
     *      read  new pos
     */
    pub fn read_count(&mut self, n: usize) -> Result<&[u8], Error> {
        self.seek(n)?;
        Ok(&self.inner[self.pos - n..self.pos])
    }

    pub fn read_count_unchecked(&mut self, n: usize) -> &[u8] {
        self.seek_unchecked(n);
        &self.inner[self.pos - n..self.pos]
    }

    pub fn skip_byte(&mut self, byte: u8) -> Result<(), Error> {
        let x = self.read()?;
        if x != byte {
            return Err(Error::WrongByte {
                expected: byte,
                got: x,
            });
        }
        Ok(())
    }

    pub fn seek(&mut self, n: usize) -> Result<(), Error> {
        /* Allows to seek 1 past the end of the buffer. Subsequent reads will still fail but this
         * allows to read the last byte. Otherwise trying to read the last byte would fail after
         * reading it, when trying to advance the position of the cursor. */
        if self.pos + n > self.len {
            return Err(Error::ReadPastEnd);
        }
        self.pos += n;

        Ok(())
    }

    pub fn seek_unchecked(&mut self, n: usize) {
        self.pos += n;
    }

    pub fn find_offset(&mut self, x: u8) -> Option<usize> {
        memchr(x, &self.inner[self.pos..self.len])
    }

    pub fn empty(&self) -> bool {
        self.len == self.pos
    }

    pub fn reset(&mut self) {
        self.len = 0;
        self.pos = 0;
    }
}

impl<const N: usize> Write for SizedBuffer<[u8; N]> {
    /// Copies as much of `buf` as fits into the unused tail of the backing array and returns
    /// the number of bytes copied (`0` once the array is full).
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let free = &mut self.inner[self.len..];
        let count = buf.len().min(free.len());
        free[..count].copy_from_slice(&buf[..count]);
        self.len += count;
        Ok(count)
    }

    /// Nothing is buffered beyond the array itself, so flushing is a no-op.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}