aboutsummaryrefslogtreecommitdiff
path: root/src/common/expr.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/expr.rs')
-rw-r--r--src/common/expr.rs364
1 files changed, 364 insertions, 0 deletions
diff --git a/src/common/expr.rs b/src/common/expr.rs
new file mode 100644
index 0000000..0a7880b
--- /dev/null
+++ b/src/common/expr.rs
@@ -0,0 +1,364 @@
+#![allow(clippy::from_over_into)]
+use crate::common::ip;
+use std::convert::Into;
+use std::fmt;
+use std::io::BufRead;
+use std::io::BufReader;
+use std::io::Read;
+use std::net::TcpStream;
+use std::num::ParseIntError;
+
+#[derive(Debug, Clone, Hash)]
+pub enum Expr {
+ Array(Vec<Expr>),
+ Blob(Vec<u8>),
+ Bool(bool),
+ Error { code: Vec<u8>, msg: Vec<u8> },
+ Line(Vec<u8>),
+ Number(u128),
+ Table(Vec<(Expr, Expr)>),
+ Null,
+}
+
+#[derive(Debug)]
+pub enum Error {
+ ConnectionClosed,
+ InvalidChar(u8),
+ NotANumber(Vec<u8>),
+ InvalidTag(u8),
+ InvalidHeaderSize,
+ NotEnoughBytes,
+ ExpectedLF(u8),
+}
+
+impl Expr {
+ pub fn from_bytes(reader: &mut BufReader<&TcpStream>) -> Result<Self, Error> {
+ let mut header = vec![];
+
+ reader
+ .read_until(b'\n', &mut header)
+ .map_err(|_| Error::ConnectionClosed)?;
+
+ if header.is_empty() {
+ return Err(Error::ConnectionClosed);
+ }
+
+ match header[0] {
+ b'!' => parse_error(header, reader),
+ b'#' => parse_bool(header, reader),
+ b'$' => parse_blob(header, reader),
+ b'*' => parse_array(header, reader),
+ b'+' => parse_line(header, reader),
+ b':' => parse_number(header, reader),
+ b'_' => Ok(Expr::Null),
+ _ => Err(Error::InvalidTag(header[0])),
+ }
+ }
+
+ pub fn from_query(query: &str) -> Self {
+ let args = query
+ .split(' ')
+ .map(|x| Expr::Blob(x.as_bytes().to_vec()))
+ .collect();
+ Expr::Array(args)
+ }
+
+ pub fn encode(&self) -> Vec<u8> {
+ match self {
+ Expr::Array(arr) => encode_array(arr.clone()),
+ Expr::Blob(blob) => encode_blob(blob.to_vec()),
+ Expr::Bool(val) => encode_bool(*val),
+ Expr::Error { code, msg } => encode_error(code.to_vec(), msg.to_vec()),
+ Expr::Line(line) => encode_line(line.to_vec()),
+ Expr::Number(number) => encode_number(*number),
+ Expr::Null => b"_\n".to_vec(),
+ Expr::Table(table) => encode_table(table),
+ }
+ }
+}
+
+// Display
+
+impl fmt::Display for Expr {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Expr::Array(array) => {
+ format!(
+ "[{}]",
+ array
+ .iter()
+ .map(|x| x.to_string())
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ Expr::Bool(value) => {
+ if *value {
+ "true".to_string()
+ } else {
+ "false".to_string()
+ }
+ }
+ Expr::Blob(blob) => format!("{:?}", String::from_utf8_lossy(blob).into_owned()),
+ Expr::Number(number) => format!("{:?}", number),
+ Expr::Error { code, msg } => format!(
+ "!{} {:?}",
+ String::from_utf8_lossy(code).into_owned(),
+ String::from_utf8_lossy(msg).into_owned()
+ ),
+ Expr::Line(line) => format!("{:?}", String::from_utf8_lossy(line).into_owned()),
+ Expr::Null => "(null)".to_string(),
+ Expr::Table(table) => {
+ format!(
+ "[{}]",
+ table
+ .iter()
+ .map(|(k, v)| format!("{}: {}", k, v))
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::ConnectionClosed => "Connection closed.".to_string(),
+ Error::InvalidChar(got) => format!("Invalid char: {:?}.", *got as char),
+ Error::NotANumber(got) => format!("Not a number: {:?}.", got),
+ Error::ExpectedLF(got) => format!("Expected LF, got: {:?}.", *got as char),
+ Error::InvalidHeaderSize => "Invalid header size.".to_string(),
+ Error::InvalidTag(got) => format!("Invalid tag {:?}.", *got as char),
+ Error::NotEnoughBytes => "Not enough bytes.".to_string(),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+// Encoders
+
+fn encode_array(array: Vec<Expr>) -> Vec<u8> {
+ let mut res = array.len().to_string().as_bytes().to_vec();
+
+ res.insert(0, b'*');
+ res.push(b'\n');
+
+ for expr in array {
+ res.append(&mut expr.encode());
+ }
+
+ res
+}
+fn encode_bool(val: bool) -> Vec<u8> {
+ if val {
+ b"#t\n".to_vec()
+ } else {
+ b"#f\n".to_vec()
+ }
+}
+fn encode_blob(mut blob: Vec<u8>) -> Vec<u8> {
+ let mut res = blob.len().to_string().as_bytes().to_vec();
+
+ res.insert(0, b'$');
+ res.push(b'\n');
+ res.append(&mut blob);
+ res.push(b'\n');
+
+ res
+}
+
+fn encode_line(mut line: Vec<u8>) -> Vec<u8> {
+ line.insert(0, b'+');
+ line.push(b'\n');
+
+ line
+}
+
+fn encode_number(number: u128) -> Vec<u8> {
+ let mut res = vec![b':'];
+ res.append(&mut number.to_string().as_bytes().to_vec());
+ res.push(b'\n');
+
+ res
+}
+
+fn encode_error(mut code: Vec<u8>, mut msg: Vec<u8>) -> Vec<u8> {
+ let mut res = (code.len() + msg.len()).to_string().as_bytes().to_vec();
+
+ res.insert(0, b'!');
+ res.push(b'\n');
+ res.append(&mut code);
+ res.push(b' ');
+ res.append(&mut msg);
+ res.push(b'\n');
+
+ res
+}
+
+fn encode_table(table: &Vec<(Expr, Expr)>) -> Vec<u8> {
+ let mut res = table.len().to_string().as_bytes().to_vec();
+
+ res.insert(0, b'%');
+ res.push(b'\n');
+
+ for (key, val) in table {
+ res.append(&mut key.encode());
+ res.append(&mut val.encode());
+ }
+
+ res
+}
+
+// Parsers
+
+fn parse_line(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ if header.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ Ok(Expr::Line(header[1..header.len() - 1].to_vec()))
+}
+
+fn parse_number(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ if header.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ let digits = &header[1..header.len() - 1];
+
+ Ok(Expr::Number(
+ as_u128(digits).map_err(|_| Error::NotANumber(digits.to_vec()))?,
+ ))
+}
+
+fn parse_blob(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?;
+
+ let mut body = vec![0u8; size];
+ reader
+ .read_exact(&mut body)
+ .map_err(|_| Error::NotEnoughBytes)?;
+
+ skip_lf(reader)?;
+
+ Ok(Expr::Blob(body))
+}
+
+fn parse_error(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?;
+
+ let mut body = vec![0u8; size];
+ reader
+ .read_exact(&mut body)
+ .map_err(|_| Error::NotEnoughBytes)?;
+
+ skip_lf(reader)?;
+
+ let parts: Vec<&[u8]> = body.splitn(2, |x| *x == b' ').collect();
+ if parts.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ Ok(Expr::Error {
+ code: parts[0].to_vec(),
+ msg: parts[1].to_vec(),
+ })
+}
+
+fn parse_array(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?;
+
+ let body: Vec<Expr> = (0..size)
+ .map(|_| Expr::from_bytes(reader))
+ .collect::<Result<Vec<Expr>, _>>()?;
+
+ Ok(Expr::Array(body))
+}
+
+fn parse_bool(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ if header.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ match header[1] {
+ b't' => Ok(Expr::Bool(true)),
+ b'f' => Ok(Expr::Bool(false)),
+ other => Err(Error::InvalidChar(other)),
+ }
+}
+
+// Into
+
+impl Into<Expr> for ip::Block {
+ fn into(self) -> Expr {
+ Expr::Line(self.into())
+ }
+}
+
+impl Into<Expr> for &ip::Block {
+ fn into(self) -> Expr {
+ Expr::Line(self.into())
+ }
+}
+
+impl Into<Expr> for &str {
+ fn into(self) -> Expr {
+ Expr::Blob(self.into())
+ }
+}
+
+impl Into<Expr> for String {
+ fn into(self) -> Expr {
+ Expr::Blob(self.into())
+ }
+}
+
+impl Into<Expr> for u8 {
+ fn into(self) -> Expr {
+ Expr::Number(self as u128)
+ }
+}
+
+impl Into<Expr> for Vec<(&str, String)> {
+ fn into(self) -> Expr {
+ Expr::Table(
+ self.into_iter()
+ .map(|(x, y)| (x.into(), y.into()))
+ .collect(),
+ )
+ }
+}
+
+impl<T: Into<Expr>> Into<Expr> for Vec<T> {
+ fn into(self) -> Expr {
+ Expr::Array(self.into_iter().map(|x| x.into()).collect())
+ }
+}
+
+// Helpers
+
+fn skip_lf(reader: &mut BufReader<&TcpStream>) -> Result<(), Error> {
+ let mut lf = vec![0u8; 1];
+
+ reader
+ .read_exact(&mut lf)
+ .map_err(|_| Error::NotEnoughBytes)?;
+
+ if lf != vec![b'\n'] {
+ return Err(Error::ExpectedLF(lf[0]));
+ }
+ Ok(())
+}
+
+fn as_usize(array: &[u8]) -> Result<usize, ParseIntError> {
+ std::str::from_utf8(array).unwrap().parse::<usize>()
+}
+
+fn as_u128(array: &[u8]) -> Result<u128, ParseIntError> {
+ std::str::from_utf8(array).unwrap().parse::<u128>()
+}