aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorevuez <julien@mulga.net>2022-11-26 15:38:06 -0500
committerevuez <julien@mulga.net>2024-04-03 22:44:12 +0200
commit86098797034cbc7eb6db0cee54e17f8dcaedbc5d (patch)
tree29b6225ead843eb9022296a54657bbadfa1c4da0 /src
downloadblom-main.tar.gz
Initial commitHEADmain
Diffstat (limited to 'src')
-rw-r--r--src/common/db.rs89
-rw-r--r--src/common/expr.rs364
-rw-r--r--src/common/ip.rs195
-rw-r--r--src/common/mod.rs7
-rw-r--r--src/common/query.rs252
-rw-r--r--src/common/sized_buffer.rs113
-rw-r--r--src/common/term.rs517
-rw-r--r--src/common/trie.rs151
-rw-r--r--src/help.rs43
-rw-r--r--src/main.rs37
-rw-r--r--src/query.rs97
-rw-r--r--src/repl.rs134
-rw-r--r--src/server.rs291
-rw-r--r--src/server/conn.rs230
-rw-r--r--src/server/parser.rs118
15 files changed, 2638 insertions, 0 deletions
diff --git a/src/common/db.rs b/src/common/db.rs
new file mode 100644
index 0000000..77a1752
--- /dev/null
+++ b/src/common/db.rs
@@ -0,0 +1,89 @@
+#![allow(clippy::from_over_into)]
+use crate::common::ip::Addr;
+use crate::common::ip::Block;
+use crate::common::ip::Meta;
+use crate::common::trie;
+use std::collections::HashMap;
+
+pub struct Database {
+ pub trie: trie::Trie,
+ pub meta: HashMap<String, String>,
+}
+
+pub struct BlockInfo {
+ exists: bool,
+ parent: Option<Block>,
+ child_count: usize,
+ start: Addr,
+ end: Addr,
+ size: u128,
+}
+
+impl Database {
+ pub fn new() -> Self {
+ Database {
+ trie: trie::Trie::new(),
+ meta: HashMap::new(),
+ }
+ }
+
+ pub fn get(&self, block: &Block) -> Option<Block> {
+ self.trie.find(block)
+ }
+
+ /// Sets `meta` on the given `block`. Replaces the metadata if `block` already exists.
+ /// Creates the block if it doesn't exist.
+ pub fn set(&mut self, block: Block, meta: Meta) {
+ self.trie.insert(block, meta);
+ }
+
+ /// Returns the parent of the given block.
+ ///
+ /// Returns `None` if the block has no parents.
+ pub fn parent(&self, block: &Block) -> Option<Block> {
+ self.trie.parent(block)
+ }
+
+ /// Returns the children of the given block.
+ ///
+ /// Returns an empty `Vec` if the given block has no children.
+ pub fn children(&self, block: &Block) -> Vec<Block> {
+ self.trie.children(block)
+ }
+
+ pub fn info(&self, block: &Block) -> BlockInfo {
+ BlockInfo {
+ exists: self.exists(block),
+ parent: self.trie.parent(block),
+ child_count: self.trie.children(block).len(),
+ start: block.start(),
+ end: block.end(),
+ size: block.size(),
+ }
+ }
+
+ pub fn del(&self, _block: &Block) -> Option<()> {
+ None
+ }
+
+ pub fn exists(&self, block: &Block) -> bool {
+ self.trie.find(block).is_some()
+ }
+}
+
+// Into
+
+impl Into<Vec<(&str, String)>> for BlockInfo {
+ fn into(self) -> Vec<(&'static str, String)> {
+ vec![
+ ("EXISTS", self.exists.to_string()),
+ (
+ "PARENT",
+ self.parent.map(|x| x.to_string()).unwrap_or_default(),
+ ),
+ ("CHILD-COUNT", self.child_count.to_string()),
+ ("RANGE", format!("{} - {}", self.start, self.end)),
+ ("SIZE", self.size.to_string()),
+ ]
+ }
+}
diff --git a/src/common/expr.rs b/src/common/expr.rs
new file mode 100644
index 0000000..0a7880b
--- /dev/null
+++ b/src/common/expr.rs
@@ -0,0 +1,364 @@
+#![allow(clippy::from_over_into)]
+use crate::common::ip;
+use std::convert::Into;
+use std::fmt;
+use std::io::BufRead;
+use std::io::BufReader;
+use std::io::Read;
+use std::net::TcpStream;
+use std::num::ParseIntError;
+
+#[derive(Debug, Clone, Hash)]
+pub enum Expr {
+ Array(Vec<Expr>),
+ Blob(Vec<u8>),
+ Bool(bool),
+ Error { code: Vec<u8>, msg: Vec<u8> },
+ Line(Vec<u8>),
+ Number(u128),
+ Table(Vec<(Expr, Expr)>),
+ Null,
+}
+
+#[derive(Debug)]
+pub enum Error {
+ ConnectionClosed,
+ InvalidChar(u8),
+ NotANumber(Vec<u8>),
+ InvalidTag(u8),
+ InvalidHeaderSize,
+ NotEnoughBytes,
+ ExpectedLF(u8),
+}
+
+impl Expr {
+ pub fn from_bytes(reader: &mut BufReader<&TcpStream>) -> Result<Self, Error> {
+ let mut header = vec![];
+
+ reader
+ .read_until(b'\n', &mut header)
+ .map_err(|_| Error::ConnectionClosed)?;
+
+ if header.is_empty() {
+ return Err(Error::ConnectionClosed);
+ }
+
+ match header[0] {
+ b'!' => parse_error(header, reader),
+ b'#' => parse_bool(header, reader),
+ b'$' => parse_blob(header, reader),
+ b'*' => parse_array(header, reader),
+ b'+' => parse_line(header, reader),
+ b':' => parse_number(header, reader),
+ b'_' => Ok(Expr::Null),
+ _ => Err(Error::InvalidTag(header[0])),
+ }
+ }
+
+ pub fn from_query(query: &str) -> Self {
+ let args = query
+ .split(' ')
+ .map(|x| Expr::Blob(x.as_bytes().to_vec()))
+ .collect();
+ Expr::Array(args)
+ }
+
+ pub fn encode(&self) -> Vec<u8> {
+ match self {
+ Expr::Array(arr) => encode_array(arr.clone()),
+ Expr::Blob(blob) => encode_blob(blob.to_vec()),
+ Expr::Bool(val) => encode_bool(*val),
+ Expr::Error { code, msg } => encode_error(code.to_vec(), msg.to_vec()),
+ Expr::Line(line) => encode_line(line.to_vec()),
+ Expr::Number(number) => encode_number(*number),
+ Expr::Null => b"_\n".to_vec(),
+ Expr::Table(table) => encode_table(table),
+ }
+ }
+}
+
+// Display
+
+impl fmt::Display for Expr {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Expr::Array(array) => {
+ format!(
+ "[{}]",
+ array
+ .iter()
+ .map(|x| x.to_string())
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ Expr::Bool(value) => {
+ if *value {
+ "true".to_string()
+ } else {
+ "false".to_string()
+ }
+ }
+ Expr::Blob(blob) => format!("{:?}", String::from_utf8_lossy(blob).into_owned()),
+ Expr::Number(number) => format!("{:?}", number),
+ Expr::Error { code, msg } => format!(
+ "!{} {:?}",
+ String::from_utf8_lossy(code).into_owned(),
+ String::from_utf8_lossy(msg).into_owned()
+ ),
+ Expr::Line(line) => format!("{:?}", String::from_utf8_lossy(line).into_owned()),
+ Expr::Null => "(null)".to_string(),
+ Expr::Table(table) => {
+ format!(
+ "[{}]",
+ table
+ .iter()
+ .map(|(k, v)| format!("{}: {}", k, v))
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::ConnectionClosed => "Connection closed.".to_string(),
+ Error::InvalidChar(got) => format!("Invalid char: {:?}.", *got as char),
+ Error::NotANumber(got) => format!("Not a number: {:?}.", got),
+ Error::ExpectedLF(got) => format!("Expected LF, got: {:?}.", *got as char),
+ Error::InvalidHeaderSize => "Invalid header size.".to_string(),
+ Error::InvalidTag(got) => format!("Invalid tag {:?}.", *got as char),
+ Error::NotEnoughBytes => "Not enough bytes.".to_string(),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+// Encoders
+
+fn encode_array(array: Vec<Expr>) -> Vec<u8> {
+ let mut res = array.len().to_string().as_bytes().to_vec();
+
+ res.insert(0, b'*');
+ res.push(b'\n');
+
+ for expr in array {
+ res.append(&mut expr.encode());
+ }
+
+ res
+}
+fn encode_bool(val: bool) -> Vec<u8> {
+ if val {
+ b"#t\n".to_vec()
+ } else {
+ b"#f\n".to_vec()
+ }
+}
+fn encode_blob(mut blob: Vec<u8>) -> Vec<u8> {
+ let mut res = blob.len().to_string().as_bytes().to_vec();
+
+ res.insert(0, b'$');
+ res.push(b'\n');
+ res.append(&mut blob);
+ res.push(b'\n');
+
+ res
+}
+
+fn encode_line(mut line: Vec<u8>) -> Vec<u8> {
+ line.insert(0, b'+');
+ line.push(b'\n');
+
+ line
+}
+
+fn encode_number(number: u128) -> Vec<u8> {
+ let mut res = vec![b':'];
+ res.append(&mut number.to_string().as_bytes().to_vec());
+ res.push(b'\n');
+
+ res
+}
+
+fn encode_error(mut code: Vec<u8>, mut msg: Vec<u8>) -> Vec<u8> {
+ let mut res = (code.len() + msg.len()).to_string().as_bytes().to_vec();
+
+ res.insert(0, b'!');
+ res.push(b'\n');
+ res.append(&mut code);
+ res.push(b' ');
+ res.append(&mut msg);
+ res.push(b'\n');
+
+ res
+}
+
+fn encode_table(table: &Vec<(Expr, Expr)>) -> Vec<u8> {
+ let mut res = table.len().to_string().as_bytes().to_vec();
+
+ res.insert(0, b'%');
+ res.push(b'\n');
+
+ for (key, val) in table {
+ res.append(&mut key.encode());
+ res.append(&mut val.encode());
+ }
+
+ res
+}
+
+// Parsers
+
+fn parse_line(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ if header.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ Ok(Expr::Line(header[1..header.len() - 1].to_vec()))
+}
+
+fn parse_number(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ if header.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ let digits = &header[1..header.len() - 1];
+
+ Ok(Expr::Number(
+ as_u128(digits).map_err(|_| Error::NotANumber(digits.to_vec()))?,
+ ))
+}
+
+fn parse_blob(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?;
+
+ let mut body = vec![0u8; size];
+ reader
+ .read_exact(&mut body)
+ .map_err(|_| Error::NotEnoughBytes)?;
+
+ skip_lf(reader)?;
+
+ Ok(Expr::Blob(body))
+}
+
+fn parse_error(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?;
+
+ let mut body = vec![0u8; size];
+ reader
+ .read_exact(&mut body)
+ .map_err(|_| Error::NotEnoughBytes)?;
+
+ skip_lf(reader)?;
+
+ let parts: Vec<&[u8]> = body.splitn(2, |x| *x == b' ').collect();
+ if parts.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ Ok(Expr::Error {
+ code: parts[0].to_vec(),
+ msg: parts[1].to_vec(),
+ })
+}
+
+fn parse_array(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?;
+
+ let body: Vec<Expr> = (0..size)
+ .map(|_| Expr::from_bytes(reader))
+ .collect::<Result<Vec<Expr>, _>>()?;
+
+ Ok(Expr::Array(body))
+}
+
+fn parse_bool(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> {
+ if header.len() < 2 {
+ return Err(Error::NotEnoughBytes);
+ }
+
+ match header[1] {
+ b't' => Ok(Expr::Bool(true)),
+ b'f' => Ok(Expr::Bool(false)),
+ other => Err(Error::InvalidChar(other)),
+ }
+}
+
+// Into
+
+impl Into<Expr> for ip::Block {
+ fn into(self) -> Expr {
+ Expr::Line(self.into())
+ }
+}
+
+impl Into<Expr> for &ip::Block {
+ fn into(self) -> Expr {
+ Expr::Line(self.into())
+ }
+}
+
+impl Into<Expr> for &str {
+ fn into(self) -> Expr {
+ Expr::Blob(self.into())
+ }
+}
+
+impl Into<Expr> for String {
+ fn into(self) -> Expr {
+ Expr::Blob(self.into())
+ }
+}
+
+impl Into<Expr> for u8 {
+ fn into(self) -> Expr {
+ Expr::Number(self as u128)
+ }
+}
+
+impl Into<Expr> for Vec<(&str, String)> {
+ fn into(self) -> Expr {
+ Expr::Table(
+ self.into_iter()
+ .map(|(x, y)| (x.into(), y.into()))
+ .collect(),
+ )
+ }
+}
+
+impl<T: Into<Expr>> Into<Expr> for Vec<T> {
+ fn into(self) -> Expr {
+ Expr::Array(self.into_iter().map(|x| x.into()).collect())
+ }
+}
+
+// Helpers
+
+fn skip_lf(reader: &mut BufReader<&TcpStream>) -> Result<(), Error> {
+ let mut lf = vec![0u8; 1];
+
+ reader
+ .read_exact(&mut lf)
+ .map_err(|_| Error::NotEnoughBytes)?;
+
+ if lf != vec![b'\n'] {
+ return Err(Error::ExpectedLF(lf[0]));
+ }
+ Ok(())
+}
+
+fn as_usize(array: &[u8]) -> Result<usize, ParseIntError> {
+ std::str::from_utf8(array).unwrap().parse::<usize>()
+}
+
+fn as_u128(array: &[u8]) -> Result<u128, ParseIntError> {
+ std::str::from_utf8(array).unwrap().parse::<u128>()
+}
diff --git a/src/common/ip.rs b/src/common/ip.rs
new file mode 100644
index 0000000..8c1e6b6
--- /dev/null
+++ b/src/common/ip.rs
@@ -0,0 +1,195 @@
+#![allow(clippy::from_over_into)]
+use crate::common::expr::Expr;
+use memchr::memchr;
+use std::cmp::Ordering;
+use std::collections::HashMap;
+use std::fmt;
+use std::fmt::Debug;
+use std::net::IpAddr;
+use std::net::Ipv4Addr;
+use std::net::Ipv6Addr;
+use std::str::FromStr;
+
+pub const BITMASK: u128 = 0x8000_0000_0000_0000_0000_0000_0000_0000;
+pub const BITCOUNT: u8 = 128;
+
+pub type Addr = IpAddr;
+
+#[derive(Debug)]
+pub enum Error {
+ InvalidAddr(String),
+ InvalidBlock(String),
+ InvalidMaskLength(String),
+ InvalidNibblesCount(usize),
+ MaskLengthTooLarge(u8),
+}
+
+#[derive(Copy, Clone, Eq)]
+pub struct Block {
+ pub addr: u128,
+ pub mlen: u8,
+}
+
+pub type Meta = HashMap<String, Expr>;
+
+impl Block {
+ fn new(addr: IpAddr, mlen: u8) -> Self {
+ let (ipv6_addr, ipv6_mlen) = match addr {
+ IpAddr::V4(addr) => {
+ assert!(mlen <= 32);
+ (addr.to_ipv6_mapped(), mlen + 96)
+ }
+ IpAddr::V6(addr) => {
+ assert!(mlen <= 128);
+ (addr, mlen)
+ }
+ };
+
+ let segments = ipv6_addr.segments();
+ let addr_bits = ((segments[0] as u128) << 112)
+ | ((segments[1] as u128) << 96)
+ | ((segments[2] as u128) << 80)
+ | ((segments[3] as u128) << 64)
+ | ((segments[4] as u128) << 48)
+ | ((segments[5] as u128) << 32)
+ | ((segments[6] as u128) << 16)
+ | (segments[7] as u128);
+
+ Self {
+ addr: addr_bits,
+ mlen: ipv6_mlen,
+ }
+ }
+
+ pub fn from_bytestring(bytestring: &[u8]) -> Result<Self, Error> {
+ match memchr(b'/', bytestring) {
+ Some(pos) => {
+ let addr = IpAddr::parse_ascii(&bytestring[0..pos]).map_err(|_| {
+ Error::InvalidAddr(
+ std::str::from_utf8(&bytestring[0..pos])
+ .unwrap()
+ .to_string(),
+ )
+ })?;
+
+ let mut mlen = 0;
+ for (i, byte) in bytestring[pos + 1..bytestring.len()]
+ .iter()
+ .rev()
+ .enumerate()
+ {
+ if *byte < 48 || *byte > 57 {
+ return Err(Error::InvalidBlock("Invalid1".to_string()));
+ }
+ mlen += (byte - 48) * 10u8.pow(i as u32);
+ }
+
+ Ok(Self::new(addr, mlen))
+ }
+ None => Err(Error::InvalidBlock(
+ std::str::from_utf8(bytestring).unwrap().to_string(),
+ )),
+ }
+ }
+
+ fn addr(&self) -> IpAddr {
+ addr_from_bits(self.addr)
+ }
+
+ pub fn start(&self) -> IpAddr {
+ let mask = 0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff ^ ((1 << (128 - self.mlen)) - 1);
+ addr_from_bits(self.addr & mask)
+ }
+ pub fn end(&self) -> IpAddr {
+ let mask = !(0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff ^ ((1 << (128 - self.mlen)) - 1));
+ addr_from_bits(self.addr | mask)
+ }
+
+ pub fn size(&self) -> u128 {
+ 2u128.pow(128 - self.mlen as u32)
+ }
+}
+
+fn addr_from_bits(bits: u128) -> IpAddr {
+ if bits & 0xffff_ffff_ffff_ffff_ffff_ffff_0000_0000 == 0xffff_0000_0000 {
+ IpAddr::V4(Ipv4Addr::from((bits & 0xff_ff_ff_ff) as u32))
+ } else {
+ IpAddr::V6(Ipv6Addr::from(bits))
+ }
+}
+
+impl FromStr for Block {
+ type Err = Error;
+ fn from_str(s: &str) -> Result<Block, Error> {
+ let parts: Vec<&str> = s.splitn(2, '/').collect();
+ if parts.len() < 2 {
+ return Err(Error::InvalidBlock(s.to_string()));
+ }
+
+ let addr =
+ IpAddr::from_str(parts[0]).map_err(|_| Error::InvalidAddr(parts[0].to_string()))?;
+ let mlen = parts[1]
+ .parse::<u8>()
+ .map_err(|_| Error::InvalidMaskLength(parts[1].to_string()))?;
+
+ Ok(Self::new(addr, mlen))
+ }
+}
+
+impl fmt::Display for Block {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let addr = self.addr();
+ let mlen = if addr.is_ipv4() {
+ self.mlen - 96
+ } else {
+ self.mlen
+ };
+ write!(f, "{}/{}", addr, mlen)
+ }
+}
+
+impl fmt::Debug for Block {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let addr = self.addr();
+ let mlen = if addr.is_ipv4() {
+ self.mlen - 96
+ } else {
+ self.mlen
+ };
+ write!(f, "{}/{}", addr, mlen)
+ }
+}
+
+impl Ord for Block {
+ fn cmp(&self, other: &Self) -> Ordering {
+ (self.addr >> (BITCOUNT - self.mlen) as usize)
+ .cmp(&(other.addr >> ((BITCOUNT - other.mlen) % 32) as usize))
+ }
+}
+
+impl PartialEq for Block {
+ fn eq(&self, other: &Self) -> bool {
+ self.addr == other.addr && self.mlen == other.mlen
+ }
+}
+
+impl PartialOrd for Block {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(
+ (self.addr >> (BITCOUNT - self.mlen) as usize)
+ .cmp(&(other.addr >> ((BITCOUNT - other.mlen) % 32) as usize)),
+ )
+ }
+}
+
+impl Into<Vec<u8>> for Block {
+ fn into(self) -> Vec<u8> {
+ (&self).into()
+ }
+}
+
+impl Into<Vec<u8>> for &Block {
+ fn into(self) -> Vec<u8> {
+ self.to_string().into()
+ }
+}
diff --git a/src/common/mod.rs b/src/common/mod.rs
new file mode 100644
index 0000000..d5ad193
--- /dev/null
+++ b/src/common/mod.rs
@@ -0,0 +1,7 @@
+pub mod db;
+pub mod expr;
+pub mod ip;
+pub mod query;
+pub mod sized_buffer;
+pub mod term;
+pub mod trie;
diff --git a/src/common/query.rs b/src/common/query.rs
new file mode 100644
index 0000000..00d79b2
--- /dev/null
+++ b/src/common/query.rs
@@ -0,0 +1,252 @@
+use crate::common::db::Database;
+use crate::common::expr::Expr;
+use crate::common::ip;
+use std::fmt;
+
+#[derive(Debug)]
+pub enum Error {
+ Empty,
+ Invalid,
+ IP(ip::Error),
+ ExpectedArray(Expr),
+ ExpectedBlob(Expr),
+ InvalidArgsCount {
+ cmd: String,
+ expected: usize,
+ actual: usize,
+ },
+}
+
+#[derive(Debug)]
+pub enum Query {
+ Children { ip_block: ip::Block },
+ Del { ip_block: ip::Block },
+ Echo { msg: Vec<u8> },
+ Exists { ip_block: ip::Block },
+ Get { ip_block: ip::Block },
+ Info { ip_block: ip::Block },
+ Parent { ip_block: ip::Block },
+ Set { ip_block: ip::Block, value: Expr },
+}
+
+impl Query {
+ pub fn from_expr(expr: Expr) -> Result<Self, Error> {
+ let mut xs = if let Expr::Array(xs) = expr {
+ xs
+ } else {
+ return Err(Error::ExpectedArray(expr));
+ };
+
+ if xs.is_empty() {
+ return Err(Error::Empty);
+ }
+ let cmd_expr = xs.remove(0);
+ let cmd = if let Expr::Blob(cmd) = cmd_expr {
+ cmd
+ } else {
+ return Err(Error::ExpectedBlob(cmd_expr));
+ };
+
+ match &*cmd {
+ b"ECHO" => build_echo(xs), // ECHO str
+ b"SET" => build_set(xs), // SET ::/32 Expr
+ b"GET" => build_get(xs), // GET ::/32
+ b"DEL" => build_del(xs), // DEL ::/32
+ b"EXISTS" => build_exists(xs), // EXISTS ::/32
+ b"INFO" => build_info(xs), // EXISTS ::/32
+ b"PARENT" => build_parent(xs), // SUP ::/32
+ b"CHILDREN" => build_children(xs), // SUB ::/32
+ _ => Err(Error::Invalid),
+ }
+ }
+
+ pub fn exec(&self, db: &mut Database) -> Expr {
+ match self {
+ Self::Echo { msg } => Expr::Blob(msg.clone()),
+ Self::Set { ip_block, value: _ } => {
+ db.set(*ip_block, std::collections::HashMap::new());
+ Expr::Blob(b"OK".to_vec())
+ }
+ Self::Get { ip_block } => db.get(ip_block).map(|x| x.into()).unwrap_or(Expr::Null),
+ Self::Del { ip_block } => db
+ .del(ip_block)
+ .map_or(Expr::Null, |_| Expr::Blob(b"OK".to_vec())),
+ Self::Exists { ip_block } => Expr::Bool(db.exists(ip_block)),
+ Self::Info { ip_block } => {
+ let info: Vec<(&str, String)> = db.info(ip_block).into();
+ info.into()
+ }
+ Self::Parent { ip_block } => {
+ db.parent(ip_block).map(|x| x.into()).unwrap_or(Expr::Null)
+ }
+ Self::Children { ip_block } => db.children(ip_block).into(),
+ }
+ }
+}
+
+// Display
+
+impl fmt::Display for Query {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Query::Set { ip_block, value } => {
+ format!("SET {} {}", ip_block, value)
+ }
+ Query::Get { ip_block } => {
+ format!("GET {}", ip_block)
+ }
+ Query::Echo { msg } => {
+ format!("ECHO {:?}", String::from_utf8_lossy(msg).into_owned())
+ }
+ Query::Del { ip_block } => {
+ format!("DEL {}", ip_block)
+ }
+ Query::Exists { ip_block } => {
+ format!("EXISTS {}", ip_block)
+ }
+ Query::Info { ip_block } => {
+ format!("INFO {}", ip_block)
+ }
+ Query::Parent { ip_block } => {
+ format!("PARENT {}", ip_block)
+ }
+ Query::Children { ip_block } => {
+ format!("CHILDREN {}", ip_block)
+ }
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::Empty => "Received an empty query.".to_string(),
+ Error::ExpectedArray(got) => format!("Expected an ARRAY, got: {:?}.", got),
+ Error::ExpectedBlob(got) => format!("Expected a BLOB, got: {:?}.", got),
+ Error::IP(e) => format!("IP error: {e:?}"),
+ Error::Invalid => "Received an invalid query.".to_string(),
+ Error::InvalidArgsCount {
+ cmd,
+ expected,
+ actual,
+ } => format!(
+ "Invalid number of arguments for {:?}. Expected {} but got {}.",
+ cmd, expected, actual
+ ),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+// Builders
+
+fn build_echo(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("ECHO", &args, 1)?;
+
+ let expr = args.remove(0);
+ let msg = if let Expr::Blob(k) = expr {
+ k
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Echo { msg })
+}
+
+fn build_get(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("GET", &args, 1)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Get { ip_block })
+}
+fn build_set(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("SET", &args, 2)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Set {
+ ip_block,
+ value: args.remove(0),
+ })
+}
+fn build_del(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("DEL", &args, 1)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Del { ip_block })
+}
+fn build_exists(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("EXISTS", &args, 1)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Exists { ip_block })
+}
+fn build_info(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("INFO", &args, 1)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Info { ip_block })
+}
+fn build_parent(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("PARENT", &args, 1)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Parent { ip_block })
+}
+fn build_children(mut args: Vec<Expr>) -> Result<Query, Error> {
+ count_args("CHILDREN", &args, 1)?;
+
+ let expr = args.remove(0);
+ let ip_block = if let Expr::Blob(ref addr) = expr {
+ ip::Block::from_bytestring(addr).map_err(Error::IP)?
+ } else {
+ return Err(Error::ExpectedBlob(expr));
+ };
+ Ok(Query::Children { ip_block })
+}
+
+// Helpers
+
+fn count_args<T>(cmd: &str, args: &Vec<T>, expected_count: usize) -> Result<(), Error> {
+ let args_count = args.len();
+ if args_count != expected_count {
+ return Err(Error::InvalidArgsCount {
+ cmd: cmd.to_string(),
+ expected: expected_count,
+ actual: args_count,
+ });
+ }
+
+ Ok(())
+}
diff --git a/src/common/sized_buffer.rs b/src/common/sized_buffer.rs
new file mode 100644
index 0000000..ebfcab5
--- /dev/null
+++ b/src/common/sized_buffer.rs
@@ -0,0 +1,113 @@
+use memchr::memchr;
+use std::io;
+use std::io::Write;
+
+pub struct SizedBuffer<T> {
+ len: usize,
+ pos: usize,
+ pub inner: T,
+}
+
+#[derive(Debug)]
+pub enum Error {
+ IO(io::Error),
+ ReadPastEnd,
+ WrongByte { expected: u8, got: u8 },
+}
+
+impl<const N: usize> SizedBuffer<[u8; N]> {
+ pub fn new() -> Self {
+ Self {
+ len: 0,
+ pos: 0,
+ inner: [0; N],
+ }
+ }
+
+ pub fn append(&mut self, bytes: &mut [u8]) -> Result<(), Error> {
+ self.write_all(bytes).map_err(Error::IO)
+ }
+
+ pub fn peek(&self) -> u8 {
+ self.inner[self.pos]
+ }
+
+ pub fn read(&mut self) -> Result<u8, Error> {
+ self.seek(1)?;
+ Ok(self.inner[self.pos - 1])
+ }
+
+ /* Reads n bytes from the buffer, and moves the cursor right after the last byte read.
+ *
+ * Example with n=3:
+ *
+ * pos = 2
+ * v
+ * [0,1,2,3,4,5,6,7,8,9]
+ * |-----|^
+ * read new pos
+ */
+ pub fn read_count(&mut self, n: usize) -> Result<&[u8], Error> {
+ self.seek(n)?;
+ Ok(&self.inner[self.pos - n..self.pos])
+ }
+
+ pub fn read_count_unchecked(&mut self, n: usize) -> &[u8] {
+ self.seek_unchecked(n);
+ &self.inner[self.pos - n..self.pos]
+ }
+
+ pub fn skip_byte(&mut self, byte: u8) -> Result<(), Error> {
+ let x = self.read()?;
+ if x != byte {
+ return Err(Error::WrongByte {
+ expected: byte,
+ got: x,
+ });
+ }
+ Ok(())
+ }
+
+ pub fn seek(&mut self, n: usize) -> Result<(), Error> {
+ /* Allows to seek 1 past the end of the buffer. Subsequent reads will still fail but this
+ * allows to read the last byte. Otherwise trying to read the last byte would fail after
+ * reading it, when trying to advance the position of the cursor. */
+ if self.pos + n > self.len {
+ return Err(Error::ReadPastEnd);
+ }
+ self.pos += n;
+
+ Ok(())
+ }
+
+ pub fn seek_unchecked(&mut self, n: usize) {
+ self.pos += n;
+ }
+
+ pub fn find_offset(&mut self, x: u8) -> Option<usize> {
+ memchr(x, &self.inner[self.pos..self.len])
+ }
+
+ pub fn empty(&self) -> bool {
+ self.len == self.pos
+ }
+
+ pub fn reset(&mut self) {
+ self.len = 0;
+ self.pos = 0;
+ }
+}
+
+impl<const N: usize> Write for SizedBuffer<[u8; N]> {
+ #[inline]
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ let amt = (&mut self.inner[self.len..]).write(buf)?;
+ self.len += amt;
+ Ok(amt)
+ }
+
+ #[inline]
+ fn flush(&mut self) -> io::Result<()> {
+ Ok(())
+ }
+}
diff --git a/src/common/term.rs b/src/common/term.rs
new file mode 100644
index 0000000..44fd819
--- /dev/null
+++ b/src/common/term.rs
@@ -0,0 +1,517 @@
+#![allow(clippy::manual_range_contains)]
+#![allow(clippy::type_complexity)]
+use libc::BRKINT;
+use libc::CS8;
+use libc::ECHO;
+use libc::ICANON;
+use libc::ICRNL;
+use libc::IEXTEN;
+use libc::INPCK;
+use libc::ISIG;
+use libc::ISTRIP;
+use libc::IXON;
+use libc::OPOST;
+use libc::TCSAFLUSH;
+use libc::TIOCGWINSZ;
+use libc::VMIN;
+use libc::VTIME;
+use log::trace;
+use std::cmp::max;
+use std::cmp::min;
+use std::fs::File;
+use std::io::Read;
+use std::io::Write;
+use std::mem;
+use std::os::unix::io::AsRawFd;
+
+enum Key {
+ Null = 0,
+ CtrlA = 1,
+ CtrlB = 2,
+ CtrlC = 3,
+ CtrlD = 4,
+ CtrlE = 5,
+ CtrlF = 6,
+ CtrlH = 8,
+ Tab = 9,
+ CtrlK = 11,
+ CtrlL = 12,
+ Enter = 13,
+ CtrlN = 14,
+ CtrlP = 16,
+ CtrlT = 20,
+ CtrlU = 21,
+ CtrlW = 23,
+ Esc = 27,
+ Backspace = 127,
+}
+
+pub struct Term {
+ initial_termios: Option<libc::termios>,
+ tty_in: File,
+ tty_out: File,
+ history: Vec<Vec<u8>>,
+ history_cursor: usize,
+ column: usize,
+ line: Vec<u8>,
+ line_cursor: usize,
+ prompt: Option<Vec<u8>>,
+ hints_callback: Option<fn(&[u8]) -> &[u8]>,
+ columns: usize,
+ max_rows: usize,
+}
+
+#[derive(Debug)]
+struct Position {
+ column: usize,
+ #[allow(dead_code)]
+ row: usize,
+}
+
+impl Term {
+ pub fn new() -> Self {
+ let tty_out = File::options()
+ .read(true)
+ .write(true)
+ .open("/dev/tty")
+ .unwrap();
+ let tty_in = tty_out.try_clone().unwrap();
+
+ Self {
+ initial_termios: None,
+ tty_out,
+ tty_in,
+ history: Vec::new(),
+ history_cursor: 0,
+ column: 0,
+ line: Vec::new(),
+ line_cursor: 0,
+ prompt: None,
+ hints_callback: None,
+ columns: 0,
+ max_rows: 0,
+ }
+ }
+
+ pub fn setup_hints(&mut self, hints_callback: fn(&[u8]) -> &[u8]) {
+ self.hints_callback = Some(hints_callback);
+ }
+
+ pub fn setup_prompt(&mut self, prompt: &[u8]) {
+ self.prompt = Some(prompt.to_vec());
+ }
+
+ pub fn edit(&mut self) -> Option<Vec<u8>> {
+ self.enable_raw_mode();
+
+ self.columns = self.get_columns();
+ self.column = self.get_cursor_position().unwrap().column;
+
+ self.refresh_line();
+
+ loop {
+ let mut buffer: [u8; 1] = [0];
+
+ if self.tty_in.read(&mut buffer).unwrap_or(0) != 1 {
+ break;
+ }
+
+ match buffer[0] {
+ x if (x == Key::Tab as u8) => {
+ todo!("PRESSED TAB");
+ }
+ x if (x == Key::CtrlC as u8) => {
+ self.tty_out.write_all(b"\r\n").unwrap();
+ self.disable_raw_mode();
+ std::process::exit(1);
+ }
+ x if (x == Key::Enter as u8) => {
+ let line = self.line.clone();
+
+ self.history_cursor = self.history.len();
+ self.history.push(self.line.clone());
+
+ self.line.clear();
+ self.line_cursor = 0;
+
+ self.tty_out.write_all(b"\r\n").unwrap();
+ self.disable_raw_mode();
+
+ return Some(line);
+ }
+ x if (x == Key::Backspace as u8) => {
+ if self.line_cursor < 1 {
+ continue;
+ }
+
+ self.line_cursor -= 1;
+ self.line.remove(self.line_cursor);
+
+ self.refresh_line();
+ }
+ x if (x == Key::Null as u8) => {}
+ x if (x == Key::CtrlD as u8) => {
+ if !self.line.is_empty() {
+ self.delete_right();
+ } else {
+ self.tty_out.write_all(b"\r\n").unwrap();
+ self.disable_raw_mode();
+ std::process::exit(1);
+ }
+ }
+ x if (x == Key::CtrlA as u8) => self.move_to_start(),
+ x if (x == Key::CtrlE as u8) => self.move_to_end(),
+ x if (x == Key::CtrlB as u8) => self.move_left(),
+ x if (x == Key::CtrlF as u8) => self.move_right(),
+ x if (x == Key::CtrlH as u8) => self.delete_left(),
+ x if (x == Key::CtrlK as u8) => self.delete_to_end(),
+ x if (x == Key::CtrlL as u8) => {
+ self.clear_screen();
+ self.refresh_line();
+ }
+ // Next line in history
+ x if (x == Key::CtrlN as u8) => self.history_next(),
+ // Previous line in history
+ x if (x == Key::CtrlP as u8) => self.history_prev(),
+ x if (x == Key::CtrlT as u8) => self.swap_one(),
+ x if (x == Key::CtrlU as u8) => self.delete_to_start(),
+ x if (x == Key::CtrlW as u8) => self.delete_word_left(),
+ x if (x == Key::Esc as u8) => {
+ let mut seq: [u8; 1] = [0];
+ if self.tty_in.read(&mut seq).unwrap_or(0) != 1 {
+ continue;
+ }
+ if seq[0] != b'[' {
+ // FIXME: Handle Esc+Key / "Esc 0" sequences
+ continue;
+ }
+ if self.tty_in.read(&mut seq).unwrap_or(0) != 1 {
+ break;
+ }
+ if seq[0] >= b'0' && seq[0] <= b'9' {
+ // FIXME: Handle extended escape sequence
+ continue;
+ }
+ match seq[0] {
+ // Up
+ b'A' => self.history_prev(),
+ // Down
+ b'B' => self.history_next(),
+ // Right
+ b'C' => self.move_right(),
+ // Left
+ b'D' => self.move_left(),
+ // Home
+ b'H' => self.move_to_start(),
+ // End
+ b'F' => self.move_to_end(),
+ _ => break,
+ }
+ }
+ x => {
+ self.line.insert(self.line_cursor, x);
+ self.line_cursor += 1;
+ self.refresh_line();
+ }
+ }
+
+ self.tty_out.flush().unwrap();
+ }
+
+ self.disable_raw_mode();
+
+ None
+ }
+
+ fn history_next(&mut self) {
+ if self.history.is_empty() {
+ return;
+ }
+
+ self.line = self.history[self.history_cursor].clone();
+ self.line_cursor = self.line.len();
+
+ self.refresh_line();
+
+ self.history_cursor = min(self.history_cursor + 1, self.history.len() - 1);
+ }
+
+ fn history_prev(&mut self) {
+ if self.history.is_empty() {
+ return;
+ }
+
+ // FIXME: Save current line before replacing it with line for the history.
+
+ self.line = self.history[self.history_cursor].clone();
+ self.line_cursor = self.line.len();
+
+ self.refresh_line();
+
+ self.history_cursor = self.history_cursor.saturating_sub(1);
+ }
+
+ fn move_right(&mut self) {
+ self.line_cursor = min(self.line_cursor + 1, self.line.len());
+ self.refresh_line();
+ }
+
+ fn move_left(&mut self) {
+ self.line_cursor = self.line_cursor.saturating_sub(1);
+ self.refresh_line();
+ }
+
+ fn move_to_start(&mut self) {
+ self.line_cursor = 0;
+ self.refresh_line();
+ }
+ fn move_to_end(&mut self) {
+ self.line_cursor = self.line.len();
+ self.refresh_line();
+ }
+
+ fn delete_word_left(&mut self) {
+ if self.line_cursor == 0 {
+ return;
+ }
+
+ let cursor = self.line_cursor;
+
+ while self.line_cursor > 0 && self.line[self.line_cursor - 1] == b' ' {
+ self.line_cursor -= 1;
+ }
+ while self.line_cursor > 0 && self.line[self.line_cursor - 1] != b' ' {
+ self.line_cursor -= 1;
+ }
+
+ self.line.drain(self.line_cursor..cursor);
+ self.refresh_line();
+ }
+
+ fn delete_left(&mut self) {
+ if self.line_cursor == 0 {
+ return;
+ }
+
+ self.line.remove(self.line_cursor - 1);
+ self.line_cursor = self.line_cursor.saturating_sub(1);
+ self.refresh_line();
+ }
+
+ fn delete_right(&mut self) {
+ if self.line_cursor >= self.line.len() {
+ return;
+ }
+
+ self.line.remove(self.line_cursor);
+ self.refresh_line();
+ }
+
+ fn delete_to_start(&mut self) {
+ if self.line_cursor == 0 {
+ return;
+ }
+
+ self.line.drain(0..self.line_cursor);
+ self.line_cursor = 0;
+ self.refresh_line();
+ }
+
+ fn delete_to_end(&mut self) {
+ if self.line_cursor >= self.line.len() {
+ return;
+ }
+
+ self.line.truncate(self.line_cursor);
+ self.refresh_line();
+ }
+
+ fn swap_one(&mut self) {
+ if self.line_cursor == 0 {
+ return;
+ }
+
+ self.line_cursor = min(self.line_cursor + 1, self.line.len());
+ self.line.swap(self.line_cursor - 2, self.line_cursor - 1);
+ self.refresh_line();
+ }
+
+ fn refresh_line(&mut self) {
+ let prompt_len = self.prompt.as_ref().map(|x| x.len()).unwrap_or(0);
+
+ let rows = (self.line.len() + prompt_len + self.columns - 1) / self.columns;
+ let prev_rows = self.max_rows;
+ let curr_row = (self.line_cursor + prompt_len + self.columns - 1) / self.columns;
+ self.column = (prompt_len + self.line_cursor) % self.columns;
+
+ self.max_rows = max(self.max_rows, rows);
+
+ //
+ // Clear rows
+ //
+
+ // Go down to the last row if the cursor is not already there.
+ if curr_row < prev_rows {
+ trace!("Go down {}", prev_rows - curr_row);
+ write!(self.tty_out, "\x1b[{}B", prev_rows - curr_row).unwrap();
+ }
+
+ // Clear each row and go up.
+ for _ in 1..prev_rows {
+ trace!("Clear row, go up");
+ self.tty_out.write_all(b"\r\x1b[0K\x1b[1A").unwrap();
+ }
+
+ // Clear the top row.
+ trace!("Clear top row");
+ self.tty_out.write_all(b"\r\x1b[0K").unwrap();
+
+ //
+ // Rewrite prompt and line
+ //
+
+ if let Some(prompt) = &self.prompt {
+ self.tty_out.write_all(prompt).unwrap();
+ }
+
+ self.tty_out.write_all(&self.line).unwrap();
+
+ //
+ // Display hints if any
+ //
+
+ if self.hints_callback.is_some() && !self.line.is_empty() {
+ let hints = self.hints_callback.unwrap()(&self.line);
+
+ if !hints.is_empty() {
+ // Column for the last character on the current row.
+ let max_column = (prompt_len + self.line.len()) % (self.columns + 1);
+ let max_len = min(self.columns - max_column, hints.len());
+
+ self.tty_out.write_all(b"\x1b[2m").unwrap();
+ self.tty_out.write_all(&hints[..max_len]).unwrap();
+ self.tty_out.write_all(b"\x1b[22m").unwrap();
+ }
+ }
+
+ //
+ // Move cursor to the right position.
+ //
+
+ if self.column == 0
+ && self.line.len() == self.line_cursor
+ && (prompt_len + self.line_cursor) > 0
+ {
+ // We reached the end of a row, insert a new line and go back to the first row.
+ trace!("Insert linefeed");
+ self.max_rows = max(self.max_rows, rows + 1);
+ self.tty_out.write_all(b"\r\n").unwrap();
+ } else if self.column == 0 {
+ self.tty_out.write_all(b"\r").unwrap();
+ } else {
+ // Move to the right column.
+ write!(self.tty_out, "\r\x1b[{}C", self.column).unwrap();
+
+ if rows > curr_row {
+ // Move up to the current row.
+ write!(self.tty_out, "\x1b[{}A", rows - curr_row).unwrap();
+ }
+ }
+ }
+
+ fn enable_raw_mode(&mut self) {
+ let mut termios = mem::MaybeUninit::uninit();
+
+ if unsafe { libc::tcgetattr(self.tty_out.as_raw_fd(), termios.as_mut_ptr()) } == -1 {
+ panic!("Failed to enable raw mode");
+ }
+
+ self.initial_termios = Some(unsafe { termios.assume_init() });
+ let mut raw = self.initial_termios.unwrap();
+
+ raw.c_iflag &= !(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
+ raw.c_oflag &= !OPOST;
+ raw.c_cflag |= CS8;
+ raw.c_lflag &= !(ECHO | ICANON | IEXTEN | ISIG);
+ raw.c_cc[VMIN] = 1;
+ raw.c_cc[VTIME] = 0;
+
+ if unsafe { libc::tcsetattr(self.tty_out.as_raw_fd(), TCSAFLUSH, &raw) } < 0 {
+ panic!("Failed to enable raw mode");
+ }
+ }
+
+ fn disable_raw_mode(&mut self) {
+ if self.initial_termios.is_some()
+ && unsafe {
+ libc::tcsetattr(
+ self.tty_out.as_raw_fd(),
+ TCSAFLUSH,
+ &self.initial_termios.unwrap(),
+ )
+ } != -1
+ {
+ self.initial_termios = None;
+ }
+ }
+
+ fn get_columns(&mut self) -> usize {
+ let ws = libc::winsize {
+ ws_row: 0,
+ ws_col: 0,
+ ws_xpixel: 0,
+ ws_ypixel: 0,
+ };
+
+ if unsafe { libc::ioctl(self.tty_out.as_raw_fd(), TIOCGWINSZ, &ws) } < 0 || ws.ws_col == 0 {
+ todo!("Query the terminal for columns count.");
+ }
+
+ ws.ws_col as usize
+ }
+
+ fn get_cursor_position(&mut self) -> Option<Position> {
+ self.enable_raw_mode();
+
+ self.tty_out.write_all(b"\x1b[6n").unwrap();
+ self.tty_out.flush().unwrap();
+
+ let mut buffer: [u8; 32] = [0; 32];
+
+ let mut sep_index = None;
+ let mut index = 0;
+
+ let mut bytes = (&self.tty_in).bytes();
+ while let Some(Ok(byte)) = bytes.next() {
+ if byte == b'R' || index >= buffer.len() {
+ break;
+ }
+ if byte == b';' {
+ sep_index = Some(index);
+ }
+
+ buffer[index] = byte;
+ index += 1;
+ }
+
+ if buffer[0] != Key::Esc as u8 || buffer[1] != b'[' {
+ return None;
+ }
+
+ sep_index.map(|sep_index| Position {
+ row: as_usize(&buffer[2..sep_index]),
+ column: as_usize(&buffer[sep_index + 1..index]),
+ })
+ }
+
+ fn clear_screen(&mut self) {
+ self.tty_out.write_all(b"\x1b[H\x1b[2J").unwrap();
+ }
+}
+
+fn as_usize(array: &[u8]) -> usize {
+ std::str::from_utf8(array)
+ .unwrap()
+ .parse::<usize>()
+ .unwrap()
+}
diff --git a/src/common/trie.rs b/src/common/trie.rs
new file mode 100644
index 0000000..381d2b5
--- /dev/null
+++ b/src/common/trie.rs
@@ -0,0 +1,151 @@
+#![allow(dead_code)]
+use crate::common::ip::Block;
+use crate::common::ip::Meta;
+use crate::common::ip::BITMASK;
+
+pub struct Trie {
+ root: Node,
+}
+
+#[derive(Debug)]
+struct Node {
+ block: Option<Block>,
+ left: Option<Box<Node>>,
+ right: Option<Box<Node>>,
+}
+
+impl Node {
+ pub const NULL: Node = Self {
+ block: None,
+ left: None,
+ right: None,
+ };
+
+ fn new(block: Option<Block>) -> Self {
+ Node {
+ block,
+ left: None,
+ right: None,
+ }
+ }
+
+ fn is_leaf(&self) -> bool {
+ self.left.is_none() && self.right.is_none()
+ }
+}
+
+impl Trie {
+ pub fn new() -> Self {
+ Trie { root: Node::NULL }
+ }
+
+ pub fn insert(&mut self, block: Block, _meta: Meta) {
+ let mut node = &mut self.root;
+ let mut addr = block.addr;
+
+ for _ in 0..block.mlen {
+ if addr & BITMASK == 0 {
+ if node.left.is_none() {
+ node.left = Some(Box::new(Node::NULL));
+ }
+ node = node.left.as_deref_mut().unwrap();
+ } else {
+ if node.right.is_none() {
+ node.right = Some(Box::new(Node::NULL));
+ }
+ node = node.right.as_deref_mut().unwrap();
+ }
+ addr <<= 1;
+ }
+
+ node.block = Some(block);
+ }
+
+ pub fn find(&self, needle: &Block) -> Option<Block> {
+ self.find_block_node(needle)?.block
+ }
+
+ pub fn parent(&self, needle: &Block) -> Option<Block> {
+ let mut node = &self.root;
+ let mut addr = needle.addr;
+ let mut parent = None;
+
+ for _ in 0..needle.mlen {
+ if node.block.is_some() {
+ parent = node.block;
+ }
+
+ if addr & BITMASK == 0 {
+ if node.left.is_some() {
+ node = node.left.as_deref().unwrap();
+ } else {
+ break;
+ }
+ } else if node.right.is_some() {
+ node = node.right.as_deref().unwrap();
+ } else {
+ break;
+ }
+ addr <<= 1;
+ }
+
+ parent
+ }
+
+ pub fn children(&self, parent: &Block) -> Vec<Block> {
+ let root: &Node = match self.find_block_node(parent) {
+ Some(root) => root,
+ None => return vec![],
+ };
+
+ let mut children: Vec<Block> = Vec::new();
+ let mut stack: Vec<&Node> = vec![root];
+
+ while let Some(node) = stack.pop() {
+ if node.is_leaf() && node.block.is_some() && node.block.unwrap() != root.block.unwrap()
+ {
+ children.push(node.block.unwrap());
+ continue;
+ }
+
+ if node.left.is_some() {
+ stack.push(node.left.as_deref().unwrap());
+ }
+ if node.right.is_some() {
+ stack.push(node.right.as_deref().unwrap());
+ }
+ }
+
+ children
+ }
+
+ fn find_block_node(&self, needle: &Block) -> Option<&Node> {
+ let mut node = &self.root;
+ let mut addr = needle.addr;
+ let mut match_ = None;
+
+ for _ in 0..(needle.mlen + 1) {
+ if node.block.is_some() {
+ match_ = Some(node);
+ }
+
+ if addr & BITMASK == 0 {
+ if node.left.is_some() {
+ node = node.left.as_deref().unwrap();
+ } else {
+ break;
+ }
+ } else if node.right.is_some() {
+ node = node.right.as_deref().unwrap();
+ } else {
+ break;
+ }
+ addr <<= 1;
+ }
+
+ match match_?.block {
+ Some(block) if block == *needle => match_,
+ _ => None,
+ }
+ }
+}
diff --git a/src/help.rs b/src/help.rs
new file mode 100644
index 0000000..188f475
--- /dev/null
+++ b/src/help.rs
@@ -0,0 +1,43 @@
+use std::process::ExitCode;
+
+use crate::query;
+use crate::repl;
+use crate::server;
+
+pub const USAGE: &str = "
+Usage: blom [command] [options]
+
+Commands:
+
+ server Start the blom server.
+
+ repl Start the blom repl.
+ query <query> Send <query> to the blom server.
+
+ help [<command>] Print usage.
+
+General options:
+
+ -h, --help Print usage.
+";
+
+pub fn cmd(args: &[String]) -> ExitCode {
+ if args.is_empty() {
+ println!("{}", USAGE);
+ return ExitCode::from(0);
+ }
+
+ match args[0].as_str() {
+ "server" => println!("{}", server::USAGE),
+ "repl" => println!("{}", repl::USAGE),
+ "query" => println!("{}", query::USAGE),
+ "help" => println!("{}", USAGE),
+ _ => {
+ println!("{}", USAGE);
+ eprintln!("Unknown command {}.", args[0]);
+ return ExitCode::from(1);
+ }
+ }
+
+ ExitCode::SUCCESS
+}
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..f39c5fc
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,37 @@
+#![feature(addr_parse_ascii)]
+
+use std::env;
+use std::process::ExitCode;
+
+mod common;
+mod help;
+mod query;
+mod repl;
+mod server;
+
+fn main() -> ExitCode {
+ env_logger::init();
+
+ let args: Vec<String> = env::args().collect();
+
+ if args.len() <= 1 {
+ println!("{}", help::USAGE);
+ eprintln!("Expected command argument.");
+ return ExitCode::from(1);
+ }
+
+ let cmd = &args[1];
+ let cmd_args = &args[2..];
+
+ match cmd.as_str() {
+ "server" => return server::cmd(cmd_args),
+ "repl" => return repl::cmd(cmd_args),
+ "query" => return query::cmd(cmd_args),
+ "help" | "-h" | "--help" => return help::cmd(cmd_args),
+ _ => {
+ println!("{}", help::USAGE);
+ eprintln!("Unknown command {}.", cmd);
+ return ExitCode::from(1);
+ }
+ }
+}
diff --git a/src/query.rs b/src/query.rs
new file mode 100644
index 0000000..1f3b4c8
--- /dev/null
+++ b/src/query.rs
@@ -0,0 +1,97 @@
+use crate::common::expr::Expr;
+use std::io::Read;
+use std::io::Write;
+use std::net::TcpStream;
+use std::process::ExitCode;
+use std::str;
+
+pub const USAGE: &str = "
+Usage: blom query <query>
+
+ Sends <query> to the blom server and prints the response
+ to stdout.
+
+Examples:
+
+ $ blom query 'CHILDREN 1.2.3.0/24'
+ 1.2.3.4/26
+ 1.2.3.4/32
+
+ $ blom query --raw 'CHILDREN 1.2.3.0/24'
+ +1.2.3.4/26\n
+ +1.2.3.4/32\n
+
+ $ blom query 'INFO 1.2.3.0/24'
+ Start: 1.2.3.0
+ End: 1.2.3.255
+ Size: 256
+ Parents: 3
+ Children: 364
+ Meta: \"{\\\"owner\\\": \\\"acme\\\"}\"
+
+Options:
+
+ -r, --raw Do not parse the response and prints the raw
+ output from the server instead.
+
+ -v, --verbose Print the request send to the server in addition
+ to the response.
+
+ -b, --bind Bind to the given ADDRESS:PORT. Default is
+ 0.0.0.0:4902
+
+";
+
+pub fn cmd(args: &[String]) -> ExitCode {
+ let mut addr = "0.0.0.0:4902";
+ let mut query: Option<&str> = None;
+
+ let mut arg_index: usize = 0;
+ while arg_index < args.len() {
+ match args[arg_index].as_str() {
+ "--bind" | "-b" => {
+ addr = args
+ .get(arg_index + 1)
+ .expect("Missing address after --bind.")
+ .as_str();
+ arg_index += 1;
+ }
+ q => {
+ if query.is_some() {
+ println!("Too many arguments.");
+ return ExitCode::FAILURE;
+ }
+ query = Some(q);
+ }
+ }
+
+ arg_index += 1;
+ }
+
+ if let Some(q) = query {
+ handle_query(addr, q);
+ return ExitCode::SUCCESS;
+ }
+
+ println!("Missing query.");
+ ExitCode::FAILURE
+}
+
+fn handle_query(addr: &str, query: &str) {
+ let expr = Expr::from_query(query).encode();
+
+ match TcpStream::connect(addr) {
+ Ok(mut stream) => {
+ let _ = stream.write(&expr).unwrap();
+ let mut data = [0u8; 64];
+ match stream.read(&mut data) {
+ Ok(_) => {
+ let resp = str::from_utf8(&data).unwrap();
+ println!("{resp:?}");
+ }
+ Err(e) => println!("Read failure: {e:?}"),
+ }
+ }
+ Err(e) => println!("Connection failure: {e:?}"),
+ }
+}
diff --git a/src/repl.rs b/src/repl.rs
new file mode 100644
index 0000000..339405e
--- /dev/null
+++ b/src/repl.rs
@@ -0,0 +1,134 @@
+use crate::common::expr::Error as ProtoError;
+use crate::common::expr::Expr;
+use crate::common::term::Term;
+use std::io::BufReader;
+use std::io::Write;
+use std::net::TcpStream;
+use std::process::ExitCode;
+
+pub const USAGE: &str = "
+Usage: blom start
+
+ Starts the blom repl.
+
+ -b, --bind Bind to the given ADDRESS:PORT. Default is
+ 0.0.0.0:4902
+
+";
+
+enum Error {
+ Connect,
+}
+
+pub fn cmd(args: &[String]) -> ExitCode {
+ let mut addr = "0.0.0.0:4902";
+
+ let mut arg_index: usize = 0;
+ while arg_index < args.len() {
+ match args[arg_index].as_str() {
+ "--addr" | "-a" => {
+ addr = args
+ .get(arg_index + 1)
+ .expect("Missing address after --addr.")
+ .as_str();
+ arg_index += 1;
+ }
+ _ => {
+ println!("Too many arguments.");
+ return ExitCode::FAILURE;
+ }
+ }
+
+ arg_index += 1;
+ }
+
+ match start(addr) {
+ Ok(()) => ExitCode::SUCCESS,
+ Err(_) => ExitCode::FAILURE,
+ }
+}
+
+fn start(addr: &str) -> Result<(), Error> {
+ let mut term = Term::new();
+
+ term.setup_prompt(format!("{addr}> ").as_bytes());
+ term.setup_hints(hints);
+
+ while let Some(query) = input(&mut term) {
+ if query.is_empty() {
+ continue;
+ }
+
+ let mut stream = match try_connect(addr) {
+ Ok(stream) => stream,
+ Err(_) => continue,
+ };
+
+ let expr = Expr::from_query(&query);
+
+ if stream.write(&expr.encode()).is_err() {
+ println!("Couldn't connect to blom at {addr}");
+ continue;
+ }
+
+ let mut resp = BufReader::new(&stream);
+
+ match Expr::from_bytes(&mut resp) {
+ Ok(resp) => println!("{resp}"),
+ Err(ProtoError::ConnectionClosed) => {
+ println!("Couldn't connect to blom at {addr}");
+ continue;
+ }
+ Err(e) => {
+ println!("{e}");
+ break;
+ }
+ }
+ }
+
+ Ok(())
+}
+
+fn try_connect(addr: &str) -> Result<TcpStream, Error> {
+ match TcpStream::connect(addr) {
+ Ok(stream) => Ok(stream),
+ Err(_) => {
+ println!("Couldn't connect to blom at {addr}.");
+ Err(Error::Connect)
+ }
+ }
+}
+
+fn input(term: &mut Term) -> Option<String> {
+ let line = term.edit()?;
+ let query = std::str::from_utf8(&line).unwrap().trim();
+
+ if query == "exit" {
+ return None;
+ }
+
+ Some(query.to_string())
+}
+
+// TODO: Refactor
+fn hints(line: &[u8]) -> &[u8] {
+ match line {
+ b"ECHO" | b"echo" => b" MESSAGE",
+ b"SET" | b"set" => b" BLOCK META",
+ b"GET" | b"get" => b" BLOCK",
+ b"DEL" | b"del" => b" BLOCK",
+ b"EXISTS" | b"exists" => b" BLOCK",
+ b"INFO" | b"info" => b" BLOCK",
+ b"PARENT" | b"parent" => b" BLOCK",
+ b"CHILDREN" | b"children" => b" BLOCK",
+ b"ECHO " | b"echo " => b"MESSAGE",
+ b"SET " | b"set " => b"BLOCK META",
+ b"GET " | b"get " => b"BLOCK",
+ b"DEL " | b"del " => b"BLOCK",
+ b"EXISTS " | b"exists " => b"BLOCK",
+ b"INFO " | b"info " => b"BLOCK",
+ b"PARENT " | b"parent " => b"BLOCK",
+ b"CHILDREN " | b"children " => b"BLOCK",
+ _ => b"",
+ }
+}
diff --git a/src/server.rs b/src/server.rs
new file mode 100644
index 0000000..c23d74a
--- /dev/null
+++ b/src/server.rs
@@ -0,0 +1,291 @@
+mod conn;
+mod parser;
+
+use crate::common::db::Database;
+use crate::common::expr::Error as ExprError;
+use crate::common::query::Error as QueryError;
+use conn::Conn;
+use conn::Error as ConnError;
+use log::debug;
+use log::error;
+use log::info;
+use log::trace;
+use log::warn;
+use mio::event::Event;
+use mio::event::Events;
+use mio::net::TcpListener;
+use mio::Interest;
+use mio::Poll;
+use mio::Token;
+use std::collections::HashMap;
+use std::fmt;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::net::SocketAddr;
+use std::process;
+use std::process::ExitCode;
+use std::time::Duration;
+
+pub const USAGE: &str = "
+Usage: blom start
+
+ Starts the blom server.
+
+Options
+
+ -b, --bind Bind to the given ADDRESS:PORT. Default is
+ 0.0.0.0:4902
+
+
+";
+
+pub const HEADER: &str = "
+ ___ ___ ___ ___
+ /\\ \\ /\\__\\ /\\ \\ /\\__\\
+ /..\\ \\ /./ / /..\\ \\ /..L_L_
+ /..\\.\\__\\ /./__/ /./\\.\\__\\ /./L.\\__\\
+ \\.\\../ / \\.\\ \\ \\.\\/./ / \\/_/./ /
+ \\../ / \\.\\__\\ \\../ / /./ /
+ \\/__/ \\/__/ \\/__/ \\/__/
+";
+
+const SERVER_TOKEN: Token = Token(0);
+const SERVER_BUFLEN: usize = 512;
+const CONN_BUFLEN: usize = 2048;
+const MAX_CONN: usize = 1000;
+
+#[derive(Debug)]
+enum Error {
+ IO(std::io::Error),
+ Query(QueryError),
+ Expr(ExprError),
+}
+
+struct Server {
+ addr: SocketAddr,
+ pid: u32,
+ poll: Poll,
+ pool: Vec<Token>,
+ listener: TcpListener,
+ conns: HashMap<Token, Conn>,
+ buffer: [u8; SERVER_BUFLEN],
+ db: Database,
+}
+
+pub fn cmd(args: &[String]) -> ExitCode {
+ println!("{}\n", HEADER);
+
+ let addr = match args.get(0).map(|x| x.as_str()) {
+ Some("--bind") | Some("-b") => {
+ Some(args.get(1).expect("Missing address for --bind.").as_str())
+ }
+ _ => None,
+ }
+ .unwrap_or("0.0.0.0:4902");
+
+ let mut server = match Server::bind(addr.parse().unwrap(), MAX_CONN) {
+ Some(server) => server,
+ None => return ExitCode::FAILURE,
+ };
+ println!(
+ "PID: {} | Listening on {}; press Ctrl-C to exit...\n\n",
+ server.pid, server.addr
+ );
+
+ match server.start() {
+ Ok(()) => {
+ info!("Received signal X, terminating.");
+ ExitCode::SUCCESS
+ }
+ Err(e) => {
+ error!("Error: {e:?}");
+ ExitCode::FAILURE
+ }
+ }
+}
+
+impl Server {
+ fn bind(addr: SocketAddr, max_conns: usize) -> Option<Self> {
+ assert!(max_conns > 0);
+
+ let mut listener = match TcpListener::bind(addr) {
+ Ok(listener) => listener,
+ Err(ref e) if e.kind() == ErrorKind::AddrInUse => {
+ println!("{addr} already in use.");
+ return None;
+ }
+ Err(_) => {
+ println!("Couldn't bind to {addr}.");
+ return None;
+ }
+ };
+
+ let poll = Poll::new().unwrap();
+
+ poll.registry()
+ .register(&mut listener, SERVER_TOKEN, Interest::READABLE)
+ .unwrap();
+
+ Some(Server {
+ addr,
+ pid: process::id(),
+ poll,
+ listener,
+ conns: HashMap::new(),
+ buffer: [0; SERVER_BUFLEN],
+ pool: (1..max_conns + 1).map(Token).collect(),
+ db: Database::new(),
+ })
+ }
+
+ fn start(&mut self) -> Result<(), Error> {
+ let mut events = Events::with_capacity(512);
+
+ loop {
+ match self.poll.poll(&mut events, Some(Duration::from_secs(30))) {
+ Ok(()) => (),
+ Err(ref e) if e.kind() == ErrorKind::Interrupted => break,
+ Err(e) => return Err(Error::IO(e)),
+ };
+
+ if events.is_empty() {
+ trace!("Pulse");
+ }
+
+ for event in &events {
+ match event.token() {
+ SERVER_TOKEN => self.handle_new(),
+ token => self.handle_conn(token, event),
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn handle_new(&mut self) {
+ loop {
+ match self.listener.accept() {
+ Ok((mut socket, _)) => match self.pool.pop() {
+ Some(token) => {
+ self.poll
+ .registry()
+ .register(&mut socket, token, Interest::READABLE | Interest::WRITABLE)
+ .unwrap();
+
+ let conn = Conn::new(socket, token);
+
+ debug!(
+ "[{}] Connected. {} clients connected.",
+ conn,
+ MAX_CONN - self.pool.len()
+ );
+
+ self.conns.insert(token, conn);
+ }
+ None => {
+ warn!(
+ "[{}] Connection refused: Too many clients.",
+ socket.peer_addr().unwrap()
+ );
+ break;
+ }
+ },
+ Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
+ Err(_) => panic!("FATAL | Unexpected error"),
+ }
+ }
+ }
+
+ fn handle_conn(&mut self, token: Token, event: &Event) {
+ let mut did_work = true;
+ while did_work {
+ did_work = false;
+
+ let conn = self.conns.get_mut(&token).unwrap();
+
+ if conn.pending_exec() {
+ did_work = true;
+ if let Err(ConnError::Fatal) = conn.handle_pending(&mut self.db) {
+ self.close_conn(token);
+ break;
+ }
+ }
+
+ if event.is_readable() && conn.read_ready() {
+ did_work = true;
+
+ match conn.read(&mut self.buffer) {
+ Ok(0) => {
+ self.close_conn(token);
+ break;
+ }
+ Ok(len) => {
+ if let Err(ConnError::Fatal) =
+ conn.handle_msg(&mut self.db, &self.buffer[0..len])
+ {
+ self.close_conn(token);
+ break;
+ }
+ }
+ Err(ref e) if e.kind() == ErrorKind::WouldBlock => break,
+ Err(ref e) if e.kind() == ErrorKind::ConnectionReset => {
+ self.close_conn(token);
+ break;
+ }
+ Err(e) => {
+ error!("!!! [{conn}] Fatal error: {e:?}.");
+ self.close_conn(token);
+ break;
+ }
+ }
+ }
+
+ if event.is_writable() && conn.write_ready() {
+ did_work = true;
+ if let Err(e) = conn.reply() {
+ error!("!!! [{conn}] Fatal error: {e:?}.");
+ self.close_conn(token);
+ break;
+ };
+ }
+ }
+ }
+
+ fn close_conn(&mut self, token: Token) {
+ let conn = self.conns.remove(&token).unwrap();
+ self.pool.push(token);
+ debug!(
+ "[{conn}] Closed. {} clients connected.",
+ MAX_CONN - self.pool.len()
+ );
+ }
+}
+
+// Display
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::IO(err) => err.to_string(),
+ Error::Query(err) => err.to_string(),
+ Error::Expr(err) => err.to_string(),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+// Errors
+
+impl From<QueryError> for Error {
+ fn from(err: QueryError) -> Self {
+ Error::Query(err)
+ }
+}
+
+impl From<ExprError> for Error {
+ fn from(err: ExprError) -> Self {
+ Error::Expr(err)
+ }
+}
diff --git a/src/server/conn.rs b/src/server/conn.rs
new file mode 100644
index 0000000..7ea5b0d
--- /dev/null
+++ b/src/server/conn.rs
@@ -0,0 +1,230 @@
+use super::parser::Error as ParserError;
+use super::parser::Parser;
+use super::parser::State as ParserState;
+use crate::common::db::Database;
+use crate::common::expr::Expr;
+use crate::common::query::Error as QueryError;
+use crate::common::query::Query;
+use crate::common::sized_buffer::SizedBuffer;
+use crate::server::CONN_BUFLEN;
+use log::debug;
+use log::error;
+
+use mio::net::TcpStream;
+use mio::Token;
+use std::fmt;
+use std::fmt::Display;
+use std::fmt::Formatter;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::io::Write;
+
+#[derive(Debug)]
+pub enum Error {
+ IO(std::io::Error),
+ Parser(ParserError),
+ Query(QueryError),
+ Fatal,
+}
+
+#[derive(Eq, PartialEq, Debug)]
+#[repr(u8)]
+pub enum State {
+ PendingExec = 0b100,
+ ReadReady = 0b010,
+ WriteReady = 0b001,
+}
+
+pub struct Conn {
+ pub flags: u8,
+ socket: TcpStream,
+ token: Token,
+ parser: Parser,
+ buf_in: SizedBuffer<[u8; CONN_BUFLEN]>,
+ buf_out: Vec<u8>,
+}
+
+impl Conn {
+ pub fn new(socket: TcpStream, token: Token) -> Self {
+ Self {
+ flags: State::ReadReady as u8,
+ socket,
+ token,
+ parser: Parser::new(),
+ buf_in: SizedBuffer::new(),
+ buf_out: Vec::new(),
+ }
+ }
+
+ pub fn handle_pending(&mut self, db: &mut Database) -> Result<(), Error> {
+ self.handle_msg(db, &[])
+ }
+
+ pub fn handle_msg(&mut self, db: &mut Database, msg: &[u8]) -> Result<(), Error> {
+ if let Err(err) = self.buf_in.append(&mut msg.to_vec()) {
+ error!("[{self}] ! Error while writing to the buffer: {err:?}");
+ return Err(Error::Fatal);
+ }
+
+ match self.parser.parse(&mut self.buf_in) {
+ Ok(ParserState::Wait) => Ok(()),
+ Ok(ParserState::Done(argv)) => {
+ Query::from_expr(Expr::Array(
+ argv.iter().map(|x| Expr::Blob(x.to_vec())).collect(),
+ ))
+ .map(|query| {
+ debug!("[{self}] < {query}");
+ let resp = query.exec(db);
+
+ self.draft(&resp.encode());
+ if self.buf_in.empty() {
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+ } else {
+ self.flags &= !(State::ReadReady as u8);
+ self.flags |= State::PendingExec as u8;
+ }
+
+ debug!("[{self}] > {resp}");
+ })
+ .map_err(|e| {
+ debug!("[{self}] ! {e:?}");
+ let wrapped_err = Error::Query(e);
+ let resp = Expr::Error {
+ code: error_code(&wrapped_err),
+ msg: format!("{}", wrapped_err).into_bytes(),
+ };
+
+ self.draft(&resp.encode());
+
+ if self.buf_in.empty() {
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+ } else {
+ self.flags &= !(State::ReadReady as u8);
+ self.flags |= State::PendingExec as u8;
+ }
+ wrapped_err
+ })?;
+
+ Ok(())
+ }
+ Err(err) => {
+ let wrapped_err = Error::Parser(err);
+
+ debug!("[{self}] ! {wrapped_err}");
+
+ let resp = Expr::Error {
+ code: error_code(&wrapped_err),
+ msg: format!("{}", wrapped_err).into_bytes(),
+ };
+
+ self.draft(&resp.encode());
+
+ /* If there's a parsing error, we throw away the entire buffer. This means
+ * potentially throwing away valid commands appearing later in the buffer */
+ self.buf_in.reset();
+ self.flags &= !(State::PendingExec as u8);
+
+ Err(wrapped_err)
+ }
+ }
+ }
+
+ fn draft(&mut self, msg: &[u8]) {
+ self.buf_out = msg.to_vec();
+ self.flags = State::WriteReady as u8;
+ }
+
+ pub fn reply(&mut self) -> Result<(), Error> {
+ assert!(self.write_ready());
+
+ match self.socket.write(&self.buf_out) {
+ Ok(wrote_len) if wrote_len == self.buf_out.len() => {
+ self.flags &= !(State::WriteReady as u8);
+
+ /* Full write, we're ready to accept more data. If there are still queries waiting
+ * to be processed in the buffer, wait until these are processed first. */
+ if !self.pending_exec() {
+ self.flags |= State::ReadReady as u8;
+ }
+
+ if let Err(e) = self.socket.flush() {
+ return Err(Error::IO(e));
+ }
+
+ Ok(())
+ }
+ Ok(_wrote_len) => {
+ /* Partial write, track the remaining length */
+ self.flags |= State::WriteReady as u8;
+ Ok(())
+ }
+ Err(e) if e.kind() == ErrorKind::BrokenPipe => Err(Error::IO(e)),
+ Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(Error::IO(e)),
+ Err(e) => panic!("FATAL | Unexpected error: {e:?}"),
+ }
+ }
+
+ pub fn pending_exec(&self) -> bool {
+ self.flags & State::PendingExec as u8 > 0
+ }
+ pub fn read_ready(&self) -> bool {
+ self.flags & State::ReadReady as u8 > 0
+ }
+ pub fn write_ready(&self) -> bool {
+ self.flags & State::WriteReady as u8 > 0
+ }
+}
+
+impl Read for Conn {
+ fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
+ assert!(self.read_ready());
+ self.socket.read(buf)
+ }
+}
+
+impl Write for Conn {
+ fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
+ self.socket.write(buf)
+ }
+
+ fn flush(&mut self) -> Result<(), std::io::Error> {
+ self.socket.flush()
+ }
+}
+
+impl Display for Conn {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match self.socket.peer_addr() {
+ Ok(addr) => write!(f, "{addr} #{}", self.token.0),
+ Err(_) => write!(f, "unknown-addr #{}", self.token.0),
+ }
+ }
+}
+
+// Errors
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::IO(err) => format!("IO error: {err:?}"),
+ Error::Parser(err) => err.to_string(),
+ Error::Query(err) => err.to_string(),
+ Error::Fatal => "Fatal error".to_string(),
+ };
+
+ write!(f, "{}", res)
+ }
+}
+
+fn error_code(error: &Error) -> Vec<u8> {
+ let code = match error {
+ Error::IO(_) => "SERVER",
+ Error::Parser(_) => "PROTOCOL",
+ Error::Query(_) => "QUERY",
+ Error::Fatal => "SERVER",
+ };
+
+ code.as_bytes().to_vec()
+}
diff --git a/src/server/parser.rs b/src/server/parser.rs
new file mode 100644
index 0000000..e9c2aad
--- /dev/null
+++ b/src/server/parser.rs
@@ -0,0 +1,118 @@
+use crate::common::sized_buffer::Error as SizedBufferError;
+use crate::common::sized_buffer::SizedBuffer;
+use std::fmt;
+
+#[derive(Debug)]
+pub enum Error {
+ ExpectedArray(char),
+ ExpectedString(char),
+ InvalidLength,
+ Buffer(SizedBufferError),
+}
+
+#[derive(Debug)]
+pub enum State {
+ Wait,
+ Done(Vec<Vec<u8>>),
+}
+
+pub struct Parser {
+ pub argv: Vec<Vec<u8>>,
+ argc_rem: u64,
+ arg_len: i64,
+}
+
+impl Parser {
+ pub fn new() -> Self {
+ Self {
+ argv: Vec::new(),
+ argc_rem: 0,
+ arg_len: -1,
+ }
+ }
+
+ pub fn parse<const N: usize>(
+ &mut self,
+ buffer: &mut SizedBuffer<[u8; N]>,
+ ) -> Result<State, Error> {
+ if self.argc_rem == 0 {
+ assert_eq!(self.arg_len, -1);
+
+ /* argc_rem wasn't initialized yet. The current buffer was never read, we need to
+ * figure out the number of args. */
+ if buffer.peek() != b'*' {
+ return Err(Error::ExpectedArray(buffer.peek() as char));
+ }
+ let arr_lf_offset = match buffer.find_offset(b'\n') {
+ Some(offset) => offset,
+ /* Couldn't find \n, wait for more data */
+ None => return Ok(State::Wait),
+ };
+ self.argc_rem = read_data_len(buffer, arr_lf_offset)? as u64;
+ buffer.skip_byte(b'\n').map_err(Error::Buffer)?;
+ }
+
+ /* We already started reading a command. We can process the arguments. */
+ while self.argc_rem > 0 {
+ /* arg_len wasn't initialized yet. We need to find the size of the string argument. */
+ if self.arg_len == -1 {
+ if buffer.peek() != b'$' {
+ return Err(Error::ExpectedString(buffer.peek() as char));
+ }
+
+ let str_lf_offset = match buffer.find_offset(b'\n') {
+ Some(offset) => offset,
+ /* Couldn't find \n, wait for more data */
+ None => return Ok(State::Wait),
+ };
+ self.arg_len = read_data_len(buffer, str_lf_offset)? as i64;
+ buffer.skip_byte(b'\n').map_err(Error::Buffer)?;
+ } else {
+ match buffer.read_count(self.arg_len as usize) {
+ /* Not enough data, wait for more */
+ Err(_) => return Ok(State::Wait),
+ /* Push the new string onto the args list */
+ Ok(arg) => self.argv.push(arg.to_vec()),
+ }
+ buffer.skip_byte(b'\n').map_err(Error::Buffer)?;
+
+ self.argc_rem -= 1;
+ self.arg_len = -1;
+ }
+ }
+
+ let argv = self.argv.clone();
+ self.argv.clear();
+ Ok(State::Done(argv))
+ }
+}
+
+fn read_data_len<const N: usize>(
+ buffer: &mut SizedBuffer<[u8; N]>,
+ n: usize,
+) -> Result<usize, Error> {
+ /* Read from 1 to skip the data tag */
+ let mut len: usize = 0;
+
+ for (i, x) in buffer.read_count_unchecked(n)[1..].iter().rev().enumerate() {
+ if *x < 48 || *x > 57 {
+ return Err(Error::InvalidLength);
+ }
+ len += (x - 48) as usize * 10usize.pow(i as u32);
+ }
+
+ Ok(len)
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let res = match self {
+ Error::ExpectedArray(got) => format!("Expected array, got {:?}.", *got as char),
+ Error::ExpectedString(got) => format!("Expected string, got {:?}.", *got as char),
+ Error::InvalidLength => "Invalid header size.".to_string(),
+ Error::Buffer(err) => format!("Error read or writing to/from the buffer: {err:?}"),
+ };
+
+ write!(f, "{}", res)
+ }
+}