use super::parser::Error as ParserError; use super::parser::Parser; use super::parser::State as ParserState; use crate::common::db::Database; use crate::common::expr::Expr; use crate::common::query::Error as QueryError; use crate::common::query::Query; use crate::common::sized_buffer::SizedBuffer; use crate::server::CONN_BUFLEN; use log::debug; use log::error; use mio::net::TcpStream; use mio::Token; use std::fmt; use std::fmt::Display; use std::fmt::Formatter; use std::io::ErrorKind; use std::io::Read; use std::io::Write; #[derive(Debug)] pub enum Error { IO(std::io::Error), Parser(ParserError), Query(QueryError), Fatal, } #[derive(Eq, PartialEq, Debug)] #[repr(u8)] pub enum State { PendingExec = 0b100, ReadReady = 0b010, WriteReady = 0b001, } pub struct Conn { pub flags: u8, socket: TcpStream, token: Token, parser: Parser, buf_in: SizedBuffer<[u8; CONN_BUFLEN]>, buf_out: Vec, } impl Conn { pub fn new(socket: TcpStream, token: Token) -> Self { Self { flags: State::ReadReady as u8, socket, token, parser: Parser::new(), buf_in: SizedBuffer::new(), buf_out: Vec::new(), } } pub fn handle_pending(&mut self, db: &mut Database) -> Result<(), Error> { self.handle_msg(db, &[]) } pub fn handle_msg(&mut self, db: &mut Database, msg: &[u8]) -> Result<(), Error> { if let Err(err) = self.buf_in.append(&mut msg.to_vec()) { error!("[{self}] ! Error while writing to the buffer: {err:?}"); return Err(Error::Fatal); } match self.parser.parse(&mut self.buf_in) { Ok(ParserState::Wait) => Ok(()), Ok(ParserState::Done(argv)) => { Query::from_expr(Expr::Array( argv.iter().map(|x| Expr::Blob(x.to_vec())).collect(), )) .map(|query| { debug!("[{self}] < {query}"); let resp = query.exec(db); self.draft(&resp.encode()); if self.buf_in.empty() { self.buf_in.reset(); self.flags &= !(State::PendingExec as u8); } else { self.flags &= !(State::ReadReady as u8); self.flags |= State::PendingExec as u8; } debug!("[{self}] > {resp}"); }) .map_err(|e| { debug!("[{self}] ! {e:?}"); let wrapped_err = Error::Query(e); let resp = Expr::Error { code: error_code(&wrapped_err), msg: format!("{}", wrapped_err).into_bytes(), }; self.draft(&resp.encode()); if self.buf_in.empty() { self.buf_in.reset(); self.flags &= !(State::PendingExec as u8); } else { self.flags &= !(State::ReadReady as u8); self.flags |= State::PendingExec as u8; } wrapped_err })?; Ok(()) } Err(err) => { let wrapped_err = Error::Parser(err); debug!("[{self}] ! {wrapped_err}"); let resp = Expr::Error { code: error_code(&wrapped_err), msg: format!("{}", wrapped_err).into_bytes(), }; self.draft(&resp.encode()); /* If there's a parsing error, we throw away the entire buffer. This means * potentially throwing away valid commands appearing later in the buffer */ self.buf_in.reset(); self.flags &= !(State::PendingExec as u8); Err(wrapped_err) } } } fn draft(&mut self, msg: &[u8]) { self.buf_out = msg.to_vec(); self.flags = State::WriteReady as u8; } pub fn reply(&mut self) -> Result<(), Error> { assert!(self.write_ready()); match self.socket.write(&self.buf_out) { Ok(wrote_len) if wrote_len == self.buf_out.len() => { self.flags &= !(State::WriteReady as u8); /* Full write, we're ready to accept more data. If there are still queries waiting * to be processed in the buffer, wait until these are processed first. */ if !self.pending_exec() { self.flags |= State::ReadReady as u8; } if let Err(e) = self.socket.flush() { return Err(Error::IO(e)); } Ok(()) } Ok(_wrote_len) => { /* Partial write, track the remaining length */ self.flags |= State::WriteReady as u8; Ok(()) } Err(e) if e.kind() == ErrorKind::BrokenPipe => Err(Error::IO(e)), Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(Error::IO(e)), Err(e) => panic!("FATAL | Unexpected error: {e:?}"), } } pub fn pending_exec(&self) -> bool { self.flags & State::PendingExec as u8 > 0 } pub fn read_ready(&self) -> bool { self.flags & State::ReadReady as u8 > 0 } pub fn write_ready(&self) -> bool { self.flags & State::WriteReady as u8 > 0 } } impl Read for Conn { fn read(&mut self, buf: &mut [u8]) -> Result { assert!(self.read_ready()); self.socket.read(buf) } } impl Write for Conn { fn write(&mut self, buf: &[u8]) -> Result { self.socket.write(buf) } fn flush(&mut self) -> Result<(), std::io::Error> { self.socket.flush() } } impl Display for Conn { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.socket.peer_addr() { Ok(addr) => write!(f, "{addr} #{}", self.token.0), Err(_) => write!(f, "unknown-addr #{}", self.token.0), } } } // Errors impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = match self { Error::IO(err) => format!("IO error: {err:?}"), Error::Parser(err) => err.to_string(), Error::Query(err) => err.to_string(), Error::Fatal => "Fatal error".to_string(), }; write!(f, "{}", res) } } fn error_code(error: &Error) -> Vec { let code = match error { Error::IO(_) => "SERVER", Error::Parser(_) => "PROTOCOL", Error::Query(_) => "QUERY", Error::Fatal => "SERVER", }; code.as_bytes().to_vec() }