aboutsummaryrefslogtreecommitdiff
path: root/src/server
diff options
context:
space:
mode:
authorevuez <julien@mulga.net>2022-11-26 15:38:06 -0500
committerevuez <julien@mulga.net>2024-04-03 22:44:12 +0200
commit86098797034cbc7eb6db0cee54e17f8dcaedbc5d (patch)
tree29b6225ead843eb9022296a54657bbadfa1c4da0 /src/server
downloadblom-86098797034cbc7eb6db0cee54e17f8dcaedbc5d.tar.gz
Initial commitHEADmain
Diffstat (limited to 'src/server')
-rw-r--r--src/server/conn.rs230
-rw-r--r--src/server/parser.rs118
2 files changed, 348 insertions, 0 deletions
diff --git a/src/server/conn.rs b/src/server/conn.rs
new file mode 100644
index 0000000..7ea5b0d
--- /dev/null
+++ b/src/server/conn.rs
@@ -0,0 +1,230 @@
+use super::parser::Error as ParserError;
+use super::parser::Parser;
+use super::parser::State as ParserState;
+use crate::common::db::Database;
+use crate::common::expr::Expr;
+use crate::common::query::Error as QueryError;
+use crate::common::query::Query;
+use crate::common::sized_buffer::SizedBuffer;
+use crate::server::CONN_BUFLEN;
+use log::debug;
+use log::error;
+
+use mio::net::TcpStream;
+use mio::Token;
+use std::fmt;
+use std::fmt::Display;
+use std::fmt::Formatter;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::io::Write;
+
+#[derive(Debug)]
+pub enum Error {
+ IO(std::io::Error),
+ Parser(ParserError),
+ Query(QueryError),
+ Fatal,
+}
+
+#[derive(Eq, PartialEq, Debug)]
+#[repr(u8)]
+pub enum State {
+ PendingExec = 0b100,
+ ReadReady = 0b010,
+ WriteReady = 0b001,
+}
+
+pub struct Conn {
+ pub flags: u8,
+ socket: TcpStream,
+ token: Token,
+ parser: Parser,
+ buf_in: SizedBuffer<[u8; CONN_BUFLEN]>,
+ buf_out: Vec<u8>,
+}
+
+impl Conn {
+ pub fn new(socket: TcpStream, token: Token) -> Self {
+ Self {
+ flags: State::ReadReady as u8,
+ socket,
+ token,
+ parser: Parser::new(),
+ buf_in: SizedBuffer::new(),
+ buf_out: Vec::new(),
+ }
+ }
+
+ pub fn handle_pending(&mut self, db: &mut Database) -> Result<(), Error> {
+ self.handle_msg(db, &[])
+ }
+
+ pub fn handle_msg(&mut self, db: &mut Database, msg: &[u8]) -> Result<(), Error> {
+ if let Err(err) = self.buf_in.append(&mut msg.to_vec()) {
+ error!("[{self}] ! Error while writing to the buffer: {err:?}");
+ return Err(Error::Fatal);
+ }
+
+ match self.parser.parse(&mut self.buf_in) {
+ Ok(ParserState::Wait) => Ok(()),
+ Ok(ParserState::Done(argv)) => {
+ Query::from_expr(Expr::Array(
+ argv.iter().map(|x| Expr::Blob(x.to_vec())).collect(),
+ ))
+ .map(|query| {
+ debug!("[{self}] < {query}");
+ let resp = query.exec(db);
+
+ self.draft(&resp.encode());
+ if self.buf_in.empty() {
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+ } else {
+ self.flags &= !(State::ReadReady as u8);
+ self.flags |= State::PendingExec as u8;
+ }
+
+ debug!("[{self}] > {resp}");
+ })
+ .map_err(|e| {
+ debug!("[{self}] ! {e:?}");
+ let wrapped_err = Error::Query(e);
+ let resp = Expr::Error {
+ code: error_code(&wrapped_err),
+ msg: format!("{}", wrapped_err).into_bytes(),
+ };
+
+ self.draft(&resp.encode());
+
+ if self.buf_in.empty() {
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+ } else {
+ self.flags &= !(State::ReadReady as u8);
+ self.flags |= State::PendingExec as u8;
+ }
+ wrapped_err
+ })?;
+
+ Ok(())
+ }
+ Err(err) => {
+ let wrapped_err = Error::Parser(err);
+
+ debug!("[{self}] ! {wrapped_err}");
+
+ let resp = Expr::Error {
+ code: error_code(&wrapped_err),
+ msg: format!("{}", wrapped_err).into_bytes(),
+ };
+
+ self.draft(&resp.encode());
+
+ /* If there's a parsing error, we throw away the entire buffer. This means
+ * potentially throwing away valid commands appearing later in the buffer */
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+
+ Err(wrapped_err)
+ }
+ }
+ }
+
+ fn draft(&mut self, msg: &[u8]) {
+ self.buf_out = msg.to_vec();
+ self.flags = State::WriteReady as u8;
+ }
+
+ pub fn reply(&mut self) -> Result<(), Error> {
+ assert!(self.write_ready());
+
+ match self.socket.write(&self.buf_out) {
+ Ok(wrote_len) if wrote_len == self.buf_out.len() => {
+ self.flags &= !(State::WriteReady as u8);
+
+ /* Full write, we're ready to accept more data. If there are still queries waiting
+ * to be processed in the buffer, wait until these are processed first. */
+ if !self.pending_exec() {
+ self.flags |= State::ReadReady as u8;
+ }
+
+ if let Err(e) = self.socket.flush() {
+ return Err(Error::IO(e));
+ }
+
+ Ok(())
+ }
+ Ok(_wrote_len) => {
+ /* Partial write, track the remaining length */
+ self.flags |= State::WriteReady as u8;
+ Ok(())
+ }
+ Err(e) if e.kind() == ErrorKind::BrokenPipe => Err(Error::IO(e)),
+ Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(Error::IO(e)),
+ Err(e) => panic!("FATAL | Unexpected error: {e:?}"),
+ }
+ }
+
+ pub fn pending_exec(&self) -> bool {
+ self.flags & State::PendingExec as u8 > 0
+ }
+ pub fn read_ready(&self) -> bool {
+ self.flags & State::ReadReady as u8 > 0
+ }
+ pub fn write_ready(&self) -> bool {
+ self.flags & State::WriteReady as u8 > 0
+ }
+}
+
+impl Read for Conn {
+ fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
+ assert!(self.read_ready());
+ self.socket.read(buf)
+ }
+}
+
+impl Write for Conn {
+ fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
+ self.socket.write(buf)
+ }
+
+ fn flush(&mut self) -> Result<(), std::io::Error> {
+ self.socket.flush()
+ }
+}
+
+impl Display for Conn {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match self.socket.peer_addr() {
+ Ok(addr) => write!(f, "{addr} #{}", self.token.0),
+ Err(_) => write!(f, "unknown-addr #{}", self.token.0),
+ }
+ }
+}
+
+// Errors
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::IO(err) => format!("IO error: {err:?}"),
+ Error::Parser(err) => err.to_string(),
+ Error::Query(err) => err.to_string(),
+ Error::Fatal => "Fatal error".to_string(),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+fn error_code(error: &Error) -> Vec<u8> {
+ let code = match error {
+ Error::IO(_) => "SERVER",
+ Error::Parser(_) => "PROTOCOL",
+ Error::Query(_) => "QUERY",
+ Error::Fatal => "SERVER",
+ };
+
+ code.as_bytes().to_vec()
+}
diff --git a/src/server/parser.rs b/src/server/parser.rs
new file mode 100644
index 0000000..e9c2aad
--- /dev/null
+++ b/src/server/parser.rs
@@ -0,0 +1,118 @@
+use crate::common::sized_buffer::Error as SizedBufferError;
+use crate::common::sized_buffer::SizedBuffer;
+use std::fmt;
+
+#[derive(Debug)]
+pub enum Error {
+ ExpectedArray(char),
+ ExpectedString(char),
+ InvalidLength,
+ Buffer(SizedBufferError),
+}
+
+#[derive(Debug)]
+pub enum State {
+ Wait,
+ Done(Vec<Vec<u8>>),
+}
+
+pub struct Parser {
+ pub argv: Vec<Vec<u8>>,
+ argc_rem: u64,
+ arg_len: i64,
+}
+
+impl Parser {
+ pub fn new() -> Self {
+ Self {
+ argv: Vec::new(),
+ argc_rem: 0,
+ arg_len: -1,
+ }
+ }
+
+ pub fn parse<const N: usize>(
+ &mut self,
+ buffer: &mut SizedBuffer<[u8; N]>,
+ ) -> Result<State, Error> {
+ if self.argc_rem == 0 {
+ assert_eq!(self.arg_len, -1);
+
+ /* argc_rem wasn't initialized yet. The current buffer was never read, we need to
+ * figure out the number of args. */
+ if buffer.peek() != b'*' {
+ return Err(Error::ExpectedArray(buffer.peek() as char));
+ }
+ let arr_lf_offset = match buffer.find_offset(b'\n') {
+ Some(offset) => offset,
+ /* Couldn't find \n, wait for more data */
+ None => return Ok(State::Wait),
+ };
+ self.argc_rem = read_data_len(buffer, arr_lf_offset)? as u64;
+ buffer.skip_byte(b'\n').map_err(Error::Buffer)?;
+ }
+
+ /* We already started reading a command. We can process the arguments. */
+ while self.argc_rem > 0 {
+ /* arg_len wasn't initialized yet. We need to find the size of the string argument. */
+ if self.arg_len == -1 {
+ if buffer.peek() != b'$' {
+ return Err(Error::ExpectedString(buffer.peek() as char));
+ }
+
+ let str_lf_offset = match buffer.find_offset(b'\n') {
+ Some(offset) => offset,
+ /* Couldn't find \n, wait for more data */
+ None => return Ok(State::Wait),
+ };
+ self.arg_len = read_data_len(buffer, str_lf_offset)? as i64;
+ buffer.skip_byte(b'\n').map_err(Error::Buffer)?;
+ } else {
+ match buffer.read_count(self.arg_len as usize) {
+ /* Not enough data, wait for more */
+ Err(_) => return Ok(State::Wait),
+ /* Push the new string onto the args list */
+ Ok(arg) => self.argv.push(arg.to_vec()),
+ }
+ buffer.skip_byte(b'\n').map_err(Error::Buffer)?;
+
+ self.argc_rem -= 1;
+ self.arg_len = -1;
+ }
+ }
+
+ let argv = self.argv.clone();
+ self.argv.clear();
+ Ok(State::Done(argv))
+ }
+}
+
+fn read_data_len<const N: usize>(
+ buffer: &mut SizedBuffer<[u8; N]>,
+ n: usize,
+) -> Result<usize, Error> {
+ /* Read from 1 to skip the data tag */
+ let mut len: usize = 0;
+
+ for (i, x) in buffer.read_count_unchecked(n)[1..].iter().rev().enumerate() {
+ if *x < 48 || *x > 57 {
+ return Err(Error::InvalidLength);
+ }
+ len += (x - 48) as usize * 10usize.pow(i as u32);
+ }
+
+ Ok(len)
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::ExpectedArray(got) => format!("Expected array, got {:?}.", *got as char),
+ Error::ExpectedString(got) => format!("Expected string, got {:?}.", *got as char),
+ Error::InvalidLength => "Invalid header size.".to_string(),
+ Error::Buffer(err) => format!("Error read or writing to/from the buffer: {err:?}"),
+ };
+
+ write!(f, "{}", res)
+ }
+}