diff options
Diffstat (limited to 'src/common/expr.rs')
-rw-r--r-- | src/common/expr.rs | 364 |
1 files changed, 364 insertions, 0 deletions
diff --git a/src/common/expr.rs b/src/common/expr.rs new file mode 100644 index 0000000..0a7880b --- /dev/null +++ b/src/common/expr.rs @@ -0,0 +1,364 @@ +#![allow(clippy::from_over_into)] +use crate::common::ip; +use std::convert::Into; +use std::fmt; +use std::io::BufRead; +use std::io::BufReader; +use std::io::Read; +use std::net::TcpStream; +use std::num::ParseIntError; + +#[derive(Debug, Clone, Hash)] +pub enum Expr { + Array(Vec<Expr>), + Blob(Vec<u8>), + Bool(bool), + Error { code: Vec<u8>, msg: Vec<u8> }, + Line(Vec<u8>), + Number(u128), + Table(Vec<(Expr, Expr)>), + Null, +} + +#[derive(Debug)] +pub enum Error { + ConnectionClosed, + InvalidChar(u8), + NotANumber(Vec<u8>), + InvalidTag(u8), + InvalidHeaderSize, + NotEnoughBytes, + ExpectedLF(u8), +} + +impl Expr { + pub fn from_bytes(reader: &mut BufReader<&TcpStream>) -> Result<Self, Error> { + let mut header = vec![]; + + reader + .read_until(b'\n', &mut header) + .map_err(|_| Error::ConnectionClosed)?; + + if header.is_empty() { + return Err(Error::ConnectionClosed); + } + + match header[0] { + b'!' => parse_error(header, reader), + b'#' => parse_bool(header, reader), + b'$' => parse_blob(header, reader), + b'*' => parse_array(header, reader), + b'+' => parse_line(header, reader), + b':' => parse_number(header, reader), + b'_' => Ok(Expr::Null), + _ => Err(Error::InvalidTag(header[0])), + } + } + + pub fn from_query(query: &str) -> Self { + let args = query + .split(' ') + .map(|x| Expr::Blob(x.as_bytes().to_vec())) + .collect(); + Expr::Array(args) + } + + pub fn encode(&self) -> Vec<u8> { + match self { + Expr::Array(arr) => encode_array(arr.clone()), + Expr::Blob(blob) => encode_blob(blob.to_vec()), + Expr::Bool(val) => encode_bool(*val), + Expr::Error { code, msg } => encode_error(code.to_vec(), msg.to_vec()), + Expr::Line(line) => encode_line(line.to_vec()), + Expr::Number(number) => encode_number(*number), + Expr::Null => b"_\n".to_vec(), + Expr::Table(table) => encode_table(table), + } + } +} + +// Display + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Expr::Array(array) => { + format!( + "[{}]", + array + .iter() + .map(|x| x.to_string()) + .collect::<Vec<String>>() + .join(", ") + ) + } + Expr::Bool(value) => { + if *value { + "true".to_string() + } else { + "false".to_string() + } + } + Expr::Blob(blob) => format!("{:?}", String::from_utf8_lossy(blob).into_owned()), + Expr::Number(number) => format!("{:?}", number), + Expr::Error { code, msg } => format!( + "!{} {:?}", + String::from_utf8_lossy(code).into_owned(), + String::from_utf8_lossy(msg).into_owned() + ), + Expr::Line(line) => format!("{:?}", String::from_utf8_lossy(line).into_owned()), + Expr::Null => "(null)".to_string(), + Expr::Table(table) => { + format!( + "[{}]", + table + .iter() + .map(|(k, v)| format!("{}: {}", k, v)) + .collect::<Vec<String>>() + .join(", ") + ) + } + }; + + write!(f, "{}", res) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::ConnectionClosed => "Connection closed.".to_string(), + Error::InvalidChar(got) => format!("Invalid char: {:?}.", *got as char), + Error::NotANumber(got) => format!("Not a number: {:?}.", got), + Error::ExpectedLF(got) => format!("Expected LF, got: {:?}.", *got as char), + Error::InvalidHeaderSize => "Invalid header size.".to_string(), + Error::InvalidTag(got) => format!("Invalid tag {:?}.", *got as char), + Error::NotEnoughBytes => "Not enough bytes.".to_string(), + }; + + write!(f, "{}", res) + } +} + +// Encoders + +fn encode_array(array: Vec<Expr>) -> Vec<u8> { + let mut res = array.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'*'); + res.push(b'\n'); + + for expr in array { + res.append(&mut expr.encode()); + } + + res +} +fn encode_bool(val: bool) -> Vec<u8> { + if val { + b"#t\n".to_vec() + } else { + b"#f\n".to_vec() + } +} +fn encode_blob(mut blob: Vec<u8>) -> Vec<u8> { + let mut res = blob.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'$'); + res.push(b'\n'); + res.append(&mut blob); + res.push(b'\n'); + + res +} + +fn encode_line(mut line: Vec<u8>) -> Vec<u8> { + line.insert(0, b'+'); + line.push(b'\n'); + + line +} + +fn encode_number(number: u128) -> Vec<u8> { + let mut res = vec![b':']; + res.append(&mut number.to_string().as_bytes().to_vec()); + res.push(b'\n'); + + res +} + +fn encode_error(mut code: Vec<u8>, mut msg: Vec<u8>) -> Vec<u8> { + let mut res = (code.len() + msg.len()).to_string().as_bytes().to_vec(); + + res.insert(0, b'!'); + res.push(b'\n'); + res.append(&mut code); + res.push(b' '); + res.append(&mut msg); + res.push(b'\n'); + + res +} + +fn encode_table(table: &Vec<(Expr, Expr)>) -> Vec<u8> { + let mut res = table.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'%'); + res.push(b'\n'); + + for (key, val) in table { + res.append(&mut key.encode()); + res.append(&mut val.encode()); + } + + res +} + +// Parsers + +fn parse_line(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + Ok(Expr::Line(header[1..header.len() - 1].to_vec())) +} + +fn parse_number(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + let digits = &header[1..header.len() - 1]; + + Ok(Expr::Number( + as_u128(digits).map_err(|_| Error::NotANumber(digits.to_vec()))?, + )) +} + +fn parse_blob(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let mut body = vec![0u8; size]; + reader + .read_exact(&mut body) + .map_err(|_| Error::NotEnoughBytes)?; + + skip_lf(reader)?; + + Ok(Expr::Blob(body)) +} + +fn parse_error(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let mut body = vec![0u8; size]; + reader + .read_exact(&mut body) + .map_err(|_| Error::NotEnoughBytes)?; + + skip_lf(reader)?; + + let parts: Vec<&[u8]> = body.splitn(2, |x| *x == b' ').collect(); + if parts.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + Ok(Expr::Error { + code: parts[0].to_vec(), + msg: parts[1].to_vec(), + }) +} + +fn parse_array(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let body: Vec<Expr> = (0..size) + .map(|_| Expr::from_bytes(reader)) + .collect::<Result<Vec<Expr>, _>>()?; + + Ok(Expr::Array(body)) +} + +fn parse_bool(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + match header[1] { + b't' => Ok(Expr::Bool(true)), + b'f' => Ok(Expr::Bool(false)), + other => Err(Error::InvalidChar(other)), + } +} + +// Into + +impl Into<Expr> for ip::Block { + fn into(self) -> Expr { + Expr::Line(self.into()) + } +} + +impl Into<Expr> for &ip::Block { + fn into(self) -> Expr { + Expr::Line(self.into()) + } +} + +impl Into<Expr> for &str { + fn into(self) -> Expr { + Expr::Blob(self.into()) + } +} + +impl Into<Expr> for String { + fn into(self) -> Expr { + Expr::Blob(self.into()) + } +} + +impl Into<Expr> for u8 { + fn into(self) -> Expr { + Expr::Number(self as u128) + } +} + +impl Into<Expr> for Vec<(&str, String)> { + fn into(self) -> Expr { + Expr::Table( + self.into_iter() + .map(|(x, y)| (x.into(), y.into())) + .collect(), + ) + } +} + +impl<T: Into<Expr>> Into<Expr> for Vec<T> { + fn into(self) -> Expr { + Expr::Array(self.into_iter().map(|x| x.into()).collect()) + } +} + +// Helpers + +fn skip_lf(reader: &mut BufReader<&TcpStream>) -> Result<(), Error> { + let mut lf = vec![0u8; 1]; + + reader + .read_exact(&mut lf) + .map_err(|_| Error::NotEnoughBytes)?; + + if lf != vec![b'\n'] { + return Err(Error::ExpectedLF(lf[0])); + } + Ok(()) +} + +fn as_usize(array: &[u8]) -> Result<usize, ParseIntError> { + std::str::from_utf8(array).unwrap().parse::<usize>() +} + +fn as_u128(array: &[u8]) -> Result<u128, ParseIntError> { + std::str::from_utf8(array).unwrap().parse::<u128>() +} |