From 86098797034cbc7eb6db0cee54e17f8dcaedbc5d Mon Sep 17 00:00:00 2001 From: evuez Date: Sat, 26 Nov 2022 15:38:06 -0500 Subject: Initial commit --- src/server/conn.rs | 230 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/server/parser.rs | 118 ++++++++++++++++++++++++++ 2 files changed, 348 insertions(+) create mode 100644 src/server/conn.rs create mode 100644 src/server/parser.rs (limited to 'src/server') diff --git a/src/server/conn.rs b/src/server/conn.rs new file mode 100644 index 0000000..7ea5b0d --- /dev/null +++ b/src/server/conn.rs @@ -0,0 +1,230 @@ +use super::parser::Error as ParserError; +use super::parser::Parser; +use super::parser::State as ParserState; +use crate::common::db::Database; +use crate::common::expr::Expr; +use crate::common::query::Error as QueryError; +use crate::common::query::Query; +use crate::common::sized_buffer::SizedBuffer; +use crate::server::CONN_BUFLEN; +use log::debug; +use log::error; + +use mio::net::TcpStream; +use mio::Token; +use std::fmt; +use std::fmt::Display; +use std::fmt::Formatter; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; + +#[derive(Debug)] +pub enum Error { + IO(std::io::Error), + Parser(ParserError), + Query(QueryError), + Fatal, +} + +#[derive(Eq, PartialEq, Debug)] +#[repr(u8)] +pub enum State { + PendingExec = 0b100, + ReadReady = 0b010, + WriteReady = 0b001, +} + +pub struct Conn { + pub flags: u8, + socket: TcpStream, + token: Token, + parser: Parser, + buf_in: SizedBuffer<[u8; CONN_BUFLEN]>, + buf_out: Vec, +} + +impl Conn { + pub fn new(socket: TcpStream, token: Token) -> Self { + Self { + flags: State::ReadReady as u8, + socket, + token, + parser: Parser::new(), + buf_in: SizedBuffer::new(), + buf_out: Vec::new(), + } + } + + pub fn handle_pending(&mut self, db: &mut Database) -> Result<(), Error> { + self.handle_msg(db, &[]) + } + + pub fn handle_msg(&mut self, db: &mut Database, msg: &[u8]) -> Result<(), Error> { + if let Err(err) = self.buf_in.append(&mut msg.to_vec()) { + error!("[{self}] ! Error while writing to the buffer: {err:?}"); + return Err(Error::Fatal); + } + + match self.parser.parse(&mut self.buf_in) { + Ok(ParserState::Wait) => Ok(()), + Ok(ParserState::Done(argv)) => { + Query::from_expr(Expr::Array( + argv.iter().map(|x| Expr::Blob(x.to_vec())).collect(), + )) + .map(|query| { + debug!("[{self}] < {query}"); + let resp = query.exec(db); + + self.draft(&resp.encode()); + if self.buf_in.empty() { + self.buf_in.reset(); + self.flags &= !(State::PendingExec as u8); + } else { + self.flags &= !(State::ReadReady as u8); + self.flags |= State::PendingExec as u8; + } + + debug!("[{self}] > {resp}"); + }) + .map_err(|e| { + debug!("[{self}] ! {e:?}"); + let wrapped_err = Error::Query(e); + let resp = Expr::Error { + code: error_code(&wrapped_err), + msg: format!("{}", wrapped_err).into_bytes(), + }; + + self.draft(&resp.encode()); + + if self.buf_in.empty() { + self.buf_in.reset(); + self.flags &= !(State::PendingExec as u8); + } else { + self.flags &= !(State::ReadReady as u8); + self.flags |= State::PendingExec as u8; + } + wrapped_err + })?; + + Ok(()) + } + Err(err) => { + let wrapped_err = Error::Parser(err); + + debug!("[{self}] ! {wrapped_err}"); + + let resp = Expr::Error { + code: error_code(&wrapped_err), + msg: format!("{}", wrapped_err).into_bytes(), + }; + + self.draft(&resp.encode()); + + /* If there's a parsing error, we throw away the entire buffer. This means + * potentially throwing away valid commands appearing later in the buffer */ + self.buf_in.reset(); + self.flags &= !(State::PendingExec as u8); + + Err(wrapped_err) + } + } + } + + fn draft(&mut self, msg: &[u8]) { + self.buf_out = msg.to_vec(); + self.flags = State::WriteReady as u8; + } + + pub fn reply(&mut self) -> Result<(), Error> { + assert!(self.write_ready()); + + match self.socket.write(&self.buf_out) { + Ok(wrote_len) if wrote_len == self.buf_out.len() => { + self.flags &= !(State::WriteReady as u8); + + /* Full write, we're ready to accept more data. If there are still queries waiting + * to be processed in the buffer, wait until these are processed first. */ + if !self.pending_exec() { + self.flags |= State::ReadReady as u8; + } + + if let Err(e) = self.socket.flush() { + return Err(Error::IO(e)); + } + + Ok(()) + } + Ok(_wrote_len) => { + /* Partial write, track the remaining length */ + self.flags |= State::WriteReady as u8; + Ok(()) + } + Err(e) if e.kind() == ErrorKind::BrokenPipe => Err(Error::IO(e)), + Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(Error::IO(e)), + Err(e) => panic!("FATAL | Unexpected error: {e:?}"), + } + } + + pub fn pending_exec(&self) -> bool { + self.flags & State::PendingExec as u8 > 0 + } + pub fn read_ready(&self) -> bool { + self.flags & State::ReadReady as u8 > 0 + } + pub fn write_ready(&self) -> bool { + self.flags & State::WriteReady as u8 > 0 + } +} + +impl Read for Conn { + fn read(&mut self, buf: &mut [u8]) -> Result { + assert!(self.read_ready()); + self.socket.read(buf) + } +} + +impl Write for Conn { + fn write(&mut self, buf: &[u8]) -> Result { + self.socket.write(buf) + } + + fn flush(&mut self) -> Result<(), std::io::Error> { + self.socket.flush() + } +} + +impl Display for Conn { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self.socket.peer_addr() { + Ok(addr) => write!(f, "{addr} #{}", self.token.0), + Err(_) => write!(f, "unknown-addr #{}", self.token.0), + } + } +} + +// Errors + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::IO(err) => format!("IO error: {err:?}"), + Error::Parser(err) => err.to_string(), + Error::Query(err) => err.to_string(), + Error::Fatal => "Fatal error".to_string(), + }; + + write!(f, "{}", res) + } +} + +fn error_code(error: &Error) -> Vec { + let code = match error { + Error::IO(_) => "SERVER", + Error::Parser(_) => "PROTOCOL", + Error::Query(_) => "QUERY", + Error::Fatal => "SERVER", + }; + + code.as_bytes().to_vec() +} diff --git a/src/server/parser.rs b/src/server/parser.rs new file mode 100644 index 0000000..e9c2aad --- /dev/null +++ b/src/server/parser.rs @@ -0,0 +1,118 @@ +use crate::common::sized_buffer::Error as SizedBufferError; +use crate::common::sized_buffer::SizedBuffer; +use std::fmt; + +#[derive(Debug)] +pub enum Error { + ExpectedArray(char), + ExpectedString(char), + InvalidLength, + Buffer(SizedBufferError), +} + +#[derive(Debug)] +pub enum State { + Wait, + Done(Vec>), +} + +pub struct Parser { + pub argv: Vec>, + argc_rem: u64, + arg_len: i64, +} + +impl Parser { + pub fn new() -> Self { + Self { + argv: Vec::new(), + argc_rem: 0, + arg_len: -1, + } + } + + pub fn parse( + &mut self, + buffer: &mut SizedBuffer<[u8; N]>, + ) -> Result { + if self.argc_rem == 0 { + assert_eq!(self.arg_len, -1); + + /* argc_rem wasn't initialized yet. The current buffer was never read, we need to + * figure out the number of args. */ + if buffer.peek() != b'*' { + return Err(Error::ExpectedArray(buffer.peek() as char)); + } + let arr_lf_offset = match buffer.find_offset(b'\n') { + Some(offset) => offset, + /* Couldn't find \n, wait for more data */ + None => return Ok(State::Wait), + }; + self.argc_rem = read_data_len(buffer, arr_lf_offset)? as u64; + buffer.skip_byte(b'\n').map_err(Error::Buffer)?; + } + + /* We already started reading a command. We can process the arguments. */ + while self.argc_rem > 0 { + /* arg_len wasn't initialized yet. We need to find the size of the string argument. */ + if self.arg_len == -1 { + if buffer.peek() != b'$' { + return Err(Error::ExpectedString(buffer.peek() as char)); + } + + let str_lf_offset = match buffer.find_offset(b'\n') { + Some(offset) => offset, + /* Couldn't find \n, wait for more data */ + None => return Ok(State::Wait), + }; + self.arg_len = read_data_len(buffer, str_lf_offset)? as i64; + buffer.skip_byte(b'\n').map_err(Error::Buffer)?; + } else { + match buffer.read_count(self.arg_len as usize) { + /* Not enough data, wait for more */ + Err(_) => return Ok(State::Wait), + /* Push the new string onto the args list */ + Ok(arg) => self.argv.push(arg.to_vec()), + } + buffer.skip_byte(b'\n').map_err(Error::Buffer)?; + + self.argc_rem -= 1; + self.arg_len = -1; + } + } + + let argv = self.argv.clone(); + self.argv.clear(); + Ok(State::Done(argv)) + } +} + +fn read_data_len( + buffer: &mut SizedBuffer<[u8; N]>, + n: usize, +) -> Result { + /* Read from 1 to skip the data tag */ + let mut len: usize = 0; + + for (i, x) in buffer.read_count_unchecked(n)[1..].iter().rev().enumerate() { + if *x < 48 || *x > 57 { + return Err(Error::InvalidLength); + } + len += (x - 48) as usize * 10usize.pow(i as u32); + } + + Ok(len) +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::ExpectedArray(got) => format!("Expected array, got {:?}.", *got as char), + Error::ExpectedString(got) => format!("Expected string, got {:?}.", *got as char), + Error::InvalidLength => "Invalid header size.".to_string(), + Error::Buffer(err) => format!("Error read or writing to/from the buffer: {err:?}"), + }; + + write!(f, "{}", res) + } +} -- cgit v1.2.3