aboutsummaryrefslogtreecommitdiff
path: root/src/server/conn.rs
diff options
context:
space:
mode:
authorevuez <julien@mulga.net>2022-11-26 15:38:06 -0500
committerevuez <julien@mulga.net>2024-04-03 22:44:12 +0200
commit86098797034cbc7eb6db0cee54e17f8dcaedbc5d (patch)
tree29b6225ead843eb9022296a54657bbadfa1c4da0 /src/server/conn.rs
downloadblom-main.tar.gz
Initial commitHEADmain
Diffstat (limited to 'src/server/conn.rs')
-rw-r--r--src/server/conn.rs230
1 files changed, 230 insertions, 0 deletions
diff --git a/src/server/conn.rs b/src/server/conn.rs
new file mode 100644
index 0000000..7ea5b0d
--- /dev/null
+++ b/src/server/conn.rs
@@ -0,0 +1,230 @@
+use super::parser::Error as ParserError;
+use super::parser::Parser;
+use super::parser::State as ParserState;
+use crate::common::db::Database;
+use crate::common::expr::Expr;
+use crate::common::query::Error as QueryError;
+use crate::common::query::Query;
+use crate::common::sized_buffer::SizedBuffer;
+use crate::server::CONN_BUFLEN;
+use log::debug;
+use log::error;
+
+use mio::net::TcpStream;
+use mio::Token;
+use std::fmt;
+use std::fmt::Display;
+use std::fmt::Formatter;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::io::Write;
+
+#[derive(Debug)]
+pub enum Error {
+ IO(std::io::Error),
+ Parser(ParserError),
+ Query(QueryError),
+ Fatal,
+}
+
+#[derive(Eq, PartialEq, Debug)]
+#[repr(u8)]
+pub enum State {
+ PendingExec = 0b100,
+ ReadReady = 0b010,
+ WriteReady = 0b001,
+}
+
+pub struct Conn {
+ pub flags: u8,
+ socket: TcpStream,
+ token: Token,
+ parser: Parser,
+ buf_in: SizedBuffer<[u8; CONN_BUFLEN]>,
+ buf_out: Vec<u8>,
+}
+
+impl Conn {
+ pub fn new(socket: TcpStream, token: Token) -> Self {
+ Self {
+ flags: State::ReadReady as u8,
+ socket,
+ token,
+ parser: Parser::new(),
+ buf_in: SizedBuffer::new(),
+ buf_out: Vec::new(),
+ }
+ }
+
+ pub fn handle_pending(&mut self, db: &mut Database) -> Result<(), Error> {
+ self.handle_msg(db, &[])
+ }
+
+ pub fn handle_msg(&mut self, db: &mut Database, msg: &[u8]) -> Result<(), Error> {
+ if let Err(err) = self.buf_in.append(&mut msg.to_vec()) {
+ error!("[{self}] ! Error while writing to the buffer: {err:?}");
+ return Err(Error::Fatal);
+ }
+
+ match self.parser.parse(&mut self.buf_in) {
+ Ok(ParserState::Wait) => Ok(()),
+ Ok(ParserState::Done(argv)) => {
+ Query::from_expr(Expr::Array(
+ argv.iter().map(|x| Expr::Blob(x.to_vec())).collect(),
+ ))
+ .map(|query| {
+ debug!("[{self}] < {query}");
+ let resp = query.exec(db);
+
+ self.draft(&resp.encode());
+ if self.buf_in.empty() {
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+ } else {
+ self.flags &= !(State::ReadReady as u8);
+ self.flags |= State::PendingExec as u8;
+ }
+
+ debug!("[{self}] > {resp}");
+ })
+ .map_err(|e| {
+ debug!("[{self}] ! {e:?}");
+ let wrapped_err = Error::Query(e);
+ let resp = Expr::Error {
+ code: error_code(&wrapped_err),
+ msg: format!("{}", wrapped_err).into_bytes(),
+ };
+
+ self.draft(&resp.encode());
+
+ if self.buf_in.empty() {
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+ } else {
+ self.flags &= !(State::ReadReady as u8);
+ self.flags |= State::PendingExec as u8;
+ }
+ wrapped_err
+ })?;
+
+ Ok(())
+ }
+ Err(err) => {
+ let wrapped_err = Error::Parser(err);
+
+ debug!("[{self}] ! {wrapped_err}");
+
+ let resp = Expr::Error {
+ code: error_code(&wrapped_err),
+ msg: format!("{}", wrapped_err).into_bytes(),
+ };
+
+ self.draft(&resp.encode());
+
+ /* If there's a parsing error, we throw away the entire buffer. This means
+ * potentially throwing away valid commands appearing later in the buffer */
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+
+ Err(wrapped_err)
+ }
+ }
+ }
+
+ fn draft(&mut self, msg: &[u8]) {
+ self.buf_out = msg.to_vec();
+ self.flags = State::WriteReady as u8;
+ }
+
+ pub fn reply(&mut self) -> Result<(), Error> {
+ assert!(self.write_ready());
+
+ match self.socket.write(&self.buf_out) {
+ Ok(wrote_len) if wrote_len == self.buf_out.len() => {
+ self.flags &= !(State::WriteReady as u8);
+
+ /* Full write, we're ready to accept more data. If there are still queries waiting
+ * to be processed in the buffer, wait until these are processed first. */
+ if !self.pending_exec() {
+ self.flags |= State::ReadReady as u8;
+ }
+
+ if let Err(e) = self.socket.flush() {
+ return Err(Error::IO(e));
+ }
+
+ Ok(())
+ }
+ Ok(_wrote_len) => {
+ /* Partial write, track the remaining length */
+ self.flags |= State::WriteReady as u8;
+ Ok(())
+ }
+ Err(e) if e.kind() == ErrorKind::BrokenPipe => Err(Error::IO(e)),
+ Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(Error::IO(e)),
+ Err(e) => panic!("FATAL | Unexpected error: {e:?}"),
+ }
+ }
+
+ pub fn pending_exec(&self) -> bool {
+ self.flags & State::PendingExec as u8 > 0
+ }
+ pub fn read_ready(&self) -> bool {
+ self.flags & State::ReadReady as u8 > 0
+ }
+ pub fn write_ready(&self) -> bool {
+ self.flags & State::WriteReady as u8 > 0
+ }
+}
+
+impl Read for Conn {
+ fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
+ assert!(self.read_ready());
+ self.socket.read(buf)
+ }
+}
+
+impl Write for Conn {
+ fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
+ self.socket.write(buf)
+ }
+
+ fn flush(&mut self) -> Result<(), std::io::Error> {
+ self.socket.flush()
+ }
+}
+
+impl Display for Conn {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match self.socket.peer_addr() {
+ Ok(addr) => write!(f, "{addr} #{}", self.token.0),
+ Err(_) => write!(f, "unknown-addr #{}", self.token.0),
+ }
+ }
+}
+
+// Errors
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::IO(err) => format!("IO error: {err:?}"),
+ Error::Parser(err) => err.to_string(),
+ Error::Query(err) => err.to_string(),
+ Error::Fatal => "Fatal error".to_string(),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+fn error_code(error: &Error) -> Vec<u8> {
+ let code = match error {
+ Error::IO(_) => "SERVER",
+ Error::Parser(_) => "PROTOCOL",
+ Error::Query(_) => "QUERY",
+ Error::Fatal => "SERVER",
+ };
+
+ code.as_bytes().to_vec()
+}