diff options
author | evuez <julien@mulga.net> | 2022-11-26 15:38:06 -0500 |
---|---|---|
committer | evuez <julien@mulga.net> | 2024-04-03 22:44:12 +0200 |
commit | 86098797034cbc7eb6db0cee54e17f8dcaedbc5d (patch) | |
tree | 29b6225ead843eb9022296a54657bbadfa1c4da0 /src | |
download | blom-main.tar.gz |
Diffstat (limited to 'src')
-rw-r--r-- | src/common/db.rs | 89 | ||||
-rw-r--r-- | src/common/expr.rs | 364 | ||||
-rw-r--r-- | src/common/ip.rs | 195 | ||||
-rw-r--r-- | src/common/mod.rs | 7 | ||||
-rw-r--r-- | src/common/query.rs | 252 | ||||
-rw-r--r-- | src/common/sized_buffer.rs | 113 | ||||
-rw-r--r-- | src/common/term.rs | 517 | ||||
-rw-r--r-- | src/common/trie.rs | 151 | ||||
-rw-r--r-- | src/help.rs | 43 | ||||
-rw-r--r-- | src/main.rs | 37 | ||||
-rw-r--r-- | src/query.rs | 97 | ||||
-rw-r--r-- | src/repl.rs | 134 | ||||
-rw-r--r-- | src/server.rs | 291 | ||||
-rw-r--r-- | src/server/conn.rs | 230 | ||||
-rw-r--r-- | src/server/parser.rs | 118 |
15 files changed, 2638 insertions, 0 deletions
diff --git a/src/common/db.rs b/src/common/db.rs new file mode 100644 index 0000000..77a1752 --- /dev/null +++ b/src/common/db.rs @@ -0,0 +1,89 @@ +#![allow(clippy::from_over_into)] +use crate::common::ip::Addr; +use crate::common::ip::Block; +use crate::common::ip::Meta; +use crate::common::trie; +use std::collections::HashMap; + +pub struct Database { + pub trie: trie::Trie, + pub meta: HashMap<String, String>, +} + +pub struct BlockInfo { + exists: bool, + parent: Option<Block>, + child_count: usize, + start: Addr, + end: Addr, + size: u128, +} + +impl Database { + pub fn new() -> Self { + Database { + trie: trie::Trie::new(), + meta: HashMap::new(), + } + } + + pub fn get(&self, block: &Block) -> Option<Block> { + self.trie.find(block) + } + + /// Sets `meta` on the given `block`. Replaces the metadata if `block` already exists. + /// Creates the block if it doesn't exist. + pub fn set(&mut self, block: Block, meta: Meta) { + self.trie.insert(block, meta); + } + + /// Returns the parent of the given block. + /// + /// Returns `None` if the block has no parents. + pub fn parent(&self, block: &Block) -> Option<Block> { + self.trie.parent(block) + } + + /// Returns the children of the given block. + /// + /// Returns an empty `Vec` if the given block has no children. + pub fn children(&self, block: &Block) -> Vec<Block> { + self.trie.children(block) + } + + pub fn info(&self, block: &Block) -> BlockInfo { + BlockInfo { + exists: self.exists(block), + parent: self.trie.parent(block), + child_count: self.trie.children(block).len(), + start: block.start(), + end: block.end(), + size: block.size(), + } + } + + pub fn del(&self, _block: &Block) -> Option<()> { + None + } + + pub fn exists(&self, block: &Block) -> bool { + self.trie.find(block).is_some() + } +} + +// Into + +impl Into<Vec<(&str, String)>> for BlockInfo { + fn into(self) -> Vec<(&'static str, String)> { + vec![ + ("EXISTS", self.exists.to_string()), + ( + "PARENT", + self.parent.map(|x| x.to_string()).unwrap_or_default(), + ), + ("CHILD-COUNT", self.child_count.to_string()), + ("RANGE", format!("{} - {}", self.start, self.end)), + ("SIZE", self.size.to_string()), + ] + } +} diff --git a/src/common/expr.rs b/src/common/expr.rs new file mode 100644 index 0000000..0a7880b --- /dev/null +++ b/src/common/expr.rs @@ -0,0 +1,364 @@ +#![allow(clippy::from_over_into)] +use crate::common::ip; +use std::convert::Into; +use std::fmt; +use std::io::BufRead; +use std::io::BufReader; +use std::io::Read; +use std::net::TcpStream; +use std::num::ParseIntError; + +#[derive(Debug, Clone, Hash)] +pub enum Expr { + Array(Vec<Expr>), + Blob(Vec<u8>), + Bool(bool), + Error { code: Vec<u8>, msg: Vec<u8> }, + Line(Vec<u8>), + Number(u128), + Table(Vec<(Expr, Expr)>), + Null, +} + +#[derive(Debug)] +pub enum Error { + ConnectionClosed, + InvalidChar(u8), + NotANumber(Vec<u8>), + InvalidTag(u8), + InvalidHeaderSize, + NotEnoughBytes, + ExpectedLF(u8), +} + +impl Expr { + pub fn from_bytes(reader: &mut BufReader<&TcpStream>) -> Result<Self, Error> { + let mut header = vec![]; + + reader + .read_until(b'\n', &mut header) + .map_err(|_| Error::ConnectionClosed)?; + + if header.is_empty() { + return Err(Error::ConnectionClosed); + } + + match header[0] { + b'!' => parse_error(header, reader), + b'#' => parse_bool(header, reader), + b'$' => parse_blob(header, reader), + b'*' => parse_array(header, reader), + b'+' => parse_line(header, reader), + b':' => parse_number(header, reader), + b'_' => Ok(Expr::Null), + _ => Err(Error::InvalidTag(header[0])), + } + } + + pub fn from_query(query: &str) -> Self { + let args = query + .split(' ') + .map(|x| Expr::Blob(x.as_bytes().to_vec())) + .collect(); + Expr::Array(args) + } + + pub fn encode(&self) -> Vec<u8> { + match self { + Expr::Array(arr) => encode_array(arr.clone()), + Expr::Blob(blob) => encode_blob(blob.to_vec()), + Expr::Bool(val) => encode_bool(*val), + Expr::Error { code, msg } => encode_error(code.to_vec(), msg.to_vec()), + Expr::Line(line) => encode_line(line.to_vec()), + Expr::Number(number) => encode_number(*number), + Expr::Null => b"_\n".to_vec(), + Expr::Table(table) => encode_table(table), + } + } +} + +// Display + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Expr::Array(array) => { + format!( + "[{}]", + array + .iter() + .map(|x| x.to_string()) + .collect::<Vec<String>>() + .join(", ") + ) + } + Expr::Bool(value) => { + if *value { + "true".to_string() + } else { + "false".to_string() + } + } + Expr::Blob(blob) => format!("{:?}", String::from_utf8_lossy(blob).into_owned()), + Expr::Number(number) => format!("{:?}", number), + Expr::Error { code, msg } => format!( + "!{} {:?}", + String::from_utf8_lossy(code).into_owned(), + String::from_utf8_lossy(msg).into_owned() + ), + Expr::Line(line) => format!("{:?}", String::from_utf8_lossy(line).into_owned()), + Expr::Null => "(null)".to_string(), + Expr::Table(table) => { + format!( + "[{}]", + table + .iter() + .map(|(k, v)| format!("{}: {}", k, v)) + .collect::<Vec<String>>() + .join(", ") + ) + } + }; + + write!(f, "{}", res) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::ConnectionClosed => "Connection closed.".to_string(), + Error::InvalidChar(got) => format!("Invalid char: {:?}.", *got as char), + Error::NotANumber(got) => format!("Not a number: {:?}.", got), + Error::ExpectedLF(got) => format!("Expected LF, got: {:?}.", *got as char), + Error::InvalidHeaderSize => "Invalid header size.".to_string(), + Error::InvalidTag(got) => format!("Invalid tag {:?}.", *got as char), + Error::NotEnoughBytes => "Not enough bytes.".to_string(), + }; + + write!(f, "{}", res) + } +} + +// Encoders + +fn encode_array(array: Vec<Expr>) -> Vec<u8> { + let mut res = array.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'*'); + res.push(b'\n'); + + for expr in array { + res.append(&mut expr.encode()); + } + + res +} +fn encode_bool(val: bool) -> Vec<u8> { + if val { + b"#t\n".to_vec() + } else { + b"#f\n".to_vec() + } +} +fn encode_blob(mut blob: Vec<u8>) -> Vec<u8> { + let mut res = blob.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'$'); + res.push(b'\n'); + res.append(&mut blob); + res.push(b'\n'); + + res +} + +fn encode_line(mut line: Vec<u8>) -> Vec<u8> { + line.insert(0, b'+'); + line.push(b'\n'); + + line +} + +fn encode_number(number: u128) -> Vec<u8> { + let mut res = vec![b':']; + res.append(&mut number.to_string().as_bytes().to_vec()); + res.push(b'\n'); + + res +} + +fn encode_error(mut code: Vec<u8>, mut msg: Vec<u8>) -> Vec<u8> { + let mut res = (code.len() + msg.len()).to_string().as_bytes().to_vec(); + + res.insert(0, b'!'); + res.push(b'\n'); + res.append(&mut code); + res.push(b' '); + res.append(&mut msg); + res.push(b'\n'); + + res +} + +fn encode_table(table: &Vec<(Expr, Expr)>) -> Vec<u8> { + let mut res = table.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'%'); + res.push(b'\n'); + + for (key, val) in table { + res.append(&mut key.encode()); + res.append(&mut val.encode()); + } + + res +} + +// Parsers + +fn parse_line(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + Ok(Expr::Line(header[1..header.len() - 1].to_vec())) +} + +fn parse_number(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + let digits = &header[1..header.len() - 1]; + + Ok(Expr::Number( + as_u128(digits).map_err(|_| Error::NotANumber(digits.to_vec()))?, + )) +} + +fn parse_blob(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let mut body = vec![0u8; size]; + reader + .read_exact(&mut body) + .map_err(|_| Error::NotEnoughBytes)?; + + skip_lf(reader)?; + + Ok(Expr::Blob(body)) +} + +fn parse_error(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let mut body = vec![0u8; size]; + reader + .read_exact(&mut body) + .map_err(|_| Error::NotEnoughBytes)?; + + skip_lf(reader)?; + + let parts: Vec<&[u8]> = body.splitn(2, |x| *x == b' ').collect(); + if parts.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + Ok(Expr::Error { + code: parts[0].to_vec(), + msg: parts[1].to_vec(), + }) +} + +fn parse_array(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let body: Vec<Expr> = (0..size) + .map(|_| Expr::from_bytes(reader)) + .collect::<Result<Vec<Expr>, _>>()?; + + Ok(Expr::Array(body)) +} + +fn parse_bool(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + match header[1] { + b't' => Ok(Expr::Bool(true)), + b'f' => Ok(Expr::Bool(false)), + other => Err(Error::InvalidChar(other)), + } +} + +// Into + +impl Into<Expr> for ip::Block { + fn into(self) -> Expr { + Expr::Line(self.into()) + } +} + +impl Into<Expr> for &ip::Block { + fn into(self) -> Expr { + Expr::Line(self.into()) + } +} + +impl Into<Expr> for &str { + fn into(self) -> Expr { + Expr::Blob(self.into()) + } +} + +impl Into<Expr> for String { + fn into(self) -> Expr { + Expr::Blob(self.into()) + } +} + +impl Into<Expr> for u8 { + fn into(self) -> Expr { + Expr::Number(self as u128) + } +} + +impl Into<Expr> for Vec<(&str, String)> { + fn into(self) -> Expr { + Expr::Table( + self.into_iter() + .map(|(x, y)| (x.into(), y.into())) + .collect(), + ) + } +} + +impl<T: Into<Expr>> Into<Expr> for Vec<T> { + fn into(self) -> Expr { + Expr::Array(self.into_iter().map(|x| x.into()).collect()) + } +} + +// Helpers + +fn skip_lf(reader: &mut BufReader<&TcpStream>) -> Result<(), Error> { + let mut lf = vec![0u8; 1]; + + reader + .read_exact(&mut lf) + .map_err(|_| Error::NotEnoughBytes)?; + + if lf != vec![b'\n'] { + return Err(Error::ExpectedLF(lf[0])); + } + Ok(()) +} + +fn as_usize(array: &[u8]) -> Result<usize, ParseIntError> { + std::str::from_utf8(array).unwrap().parse::<usize>() +} + +fn as_u128(array: &[u8]) -> Result<u128, ParseIntError> { + std::str::from_utf8(array).unwrap().parse::<u128>() +} diff --git a/src/common/ip.rs b/src/common/ip.rs new file mode 100644 index 0000000..8c1e6b6 --- /dev/null +++ b/src/common/ip.rs @@ -0,0 +1,195 @@ +#![allow(clippy::from_over_into)] +use crate::common::expr::Expr; +use memchr::memchr; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt; +use std::fmt::Debug; +use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; +use std::str::FromStr; + +pub const BITMASK: u128 = 0x8000_0000_0000_0000_0000_0000_0000_0000; +pub const BITCOUNT: u8 = 128; + +pub type Addr = IpAddr; + +#[derive(Debug)] +pub enum Error { + InvalidAddr(String), + InvalidBlock(String), + InvalidMaskLength(String), + InvalidNibblesCount(usize), + MaskLengthTooLarge(u8), +} + +#[derive(Copy, Clone, Eq)] +pub struct Block { + pub addr: u128, + pub mlen: u8, +} + +pub type Meta = HashMap<String, Expr>; + +impl Block { + fn new(addr: IpAddr, mlen: u8) -> Self { + let (ipv6_addr, ipv6_mlen) = match addr { + IpAddr::V4(addr) => { + assert!(mlen <= 32); + (addr.to_ipv6_mapped(), mlen + 96) + } + IpAddr::V6(addr) => { + assert!(mlen <= 128); + (addr, mlen) + } + }; + + let segments = ipv6_addr.segments(); + let addr_bits = ((segments[0] as u128) << 112) + | ((segments[1] as u128) << 96) + | ((segments[2] as u128) << 80) + | ((segments[3] as u128) << 64) + | ((segments[4] as u128) << 48) + | ((segments[5] as u128) << 32) + | ((segments[6] as u128) << 16) + | (segments[7] as u128); + + Self { + addr: addr_bits, + mlen: ipv6_mlen, + } + } + + pub fn from_bytestring(bytestring: &[u8]) -> Result<Self, Error> { + match memchr(b'/', bytestring) { + Some(pos) => { + let addr = IpAddr::parse_ascii(&bytestring[0..pos]).map_err(|_| { + Error::InvalidAddr( + std::str::from_utf8(&bytestring[0..pos]) + .unwrap() + .to_string(), + ) + })?; + + let mut mlen = 0; + for (i, byte) in bytestring[pos + 1..bytestring.len()] + .iter() + .rev() + .enumerate() + { + if *byte < 48 || *byte > 57 { + return Err(Error::InvalidBlock("Invalid1".to_string())); + } + mlen += (byte - 48) * 10u8.pow(i as u32); + } + + Ok(Self::new(addr, mlen)) + } + None => Err(Error::InvalidBlock( + std::str::from_utf8(bytestring).unwrap().to_string(), + )), + } + } + + fn addr(&self) -> IpAddr { + addr_from_bits(self.addr) + } + + pub fn start(&self) -> IpAddr { + let mask = 0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff ^ ((1 << (128 - self.mlen)) - 1); + addr_from_bits(self.addr & mask) + } + pub fn end(&self) -> IpAddr { + let mask = !(0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff ^ ((1 << (128 - self.mlen)) - 1)); + addr_from_bits(self.addr | mask) + } + + pub fn size(&self) -> u128 { + 2u128.pow(128 - self.mlen as u32) + } +} + +fn addr_from_bits(bits: u128) -> IpAddr { + if bits & 0xffff_ffff_ffff_ffff_ffff_ffff_0000_0000 == 0xffff_0000_0000 { + IpAddr::V4(Ipv4Addr::from((bits & 0xff_ff_ff_ff) as u32)) + } else { + IpAddr::V6(Ipv6Addr::from(bits)) + } +} + +impl FromStr for Block { + type Err = Error; + fn from_str(s: &str) -> Result<Block, Error> { + let parts: Vec<&str> = s.splitn(2, '/').collect(); + if parts.len() < 2 { + return Err(Error::InvalidBlock(s.to_string())); + } + + let addr = + IpAddr::from_str(parts[0]).map_err(|_| Error::InvalidAddr(parts[0].to_string()))?; + let mlen = parts[1] + .parse::<u8>() + .map_err(|_| Error::InvalidMaskLength(parts[1].to_string()))?; + + Ok(Self::new(addr, mlen)) + } +} + +impl fmt::Display for Block { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let addr = self.addr(); + let mlen = if addr.is_ipv4() { + self.mlen - 96 + } else { + self.mlen + }; + write!(f, "{}/{}", addr, mlen) + } +} + +impl fmt::Debug for Block { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let addr = self.addr(); + let mlen = if addr.is_ipv4() { + self.mlen - 96 + } else { + self.mlen + }; + write!(f, "{}/{}", addr, mlen) + } +} + +impl Ord for Block { + fn cmp(&self, other: &Self) -> Ordering { + (self.addr >> (BITCOUNT - self.mlen) as usize) + .cmp(&(other.addr >> ((BITCOUNT - other.mlen) % 32) as usize)) + } +} + +impl PartialEq for Block { + fn eq(&self, other: &Self) -> bool { + self.addr == other.addr && self.mlen == other.mlen + } +} + +impl PartialOrd for Block { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some( + (self.addr >> (BITCOUNT - self.mlen) as usize) + .cmp(&(other.addr >> ((BITCOUNT - other.mlen) % 32) as usize)), + ) + } +} + +impl Into<Vec<u8>> for Block { + fn into(self) -> Vec<u8> { + (&self).into() + } +} + +impl Into<Vec<u8>> for &Block { + fn into(self) -> Vec<u8> { + self.to_string().into() + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..d5ad193 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,7 @@ +pub mod db; +pub mod expr; +pub mod ip; +pub mod query; +pub mod sized_buffer; +pub mod term; +pub mod trie; diff --git a/src/common/query.rs b/src/common/query.rs new file mode 100644 index 0000000..00d79b2 --- /dev/null +++ b/src/common/query.rs @@ -0,0 +1,252 @@ +use crate::common::db::Database; +use crate::common::expr::Expr; +use crate::common::ip; +use std::fmt; + +#[derive(Debug)] +pub enum Error { + Empty, + Invalid, + IP(ip::Error), + ExpectedArray(Expr), + ExpectedBlob(Expr), + InvalidArgsCount { + cmd: String, + expected: usize, + actual: usize, + }, +} + +#[derive(Debug)] +pub enum Query { + Children { ip_block: ip::Block }, + Del { ip_block: ip::Block }, + Echo { msg: Vec<u8> }, + Exists { ip_block: ip::Block }, + Get { ip_block: ip::Block }, + Info { ip_block: ip::Block }, + Parent { ip_block: ip::Block }, + Set { ip_block: ip::Block, value: Expr }, +} + +impl Query { + pub fn from_expr(expr: Expr) -> Result<Self, Error> { + let mut xs = if let Expr::Array(xs) = expr { + xs + } else { + return Err(Error::ExpectedArray(expr)); + }; + + if xs.is_empty() { + return Err(Error::Empty); + } + let cmd_expr = xs.remove(0); + let cmd = if let Expr::Blob(cmd) = cmd_expr { + cmd + } else { + return Err(Error::ExpectedBlob(cmd_expr)); + }; + + match &*cmd { + b"ECHO" => build_echo(xs), // ECHO str + b"SET" => build_set(xs), // SET ::/32 Expr + b"GET" => build_get(xs), // GET ::/32 + b"DEL" => build_del(xs), // DEL ::/32 + b"EXISTS" => build_exists(xs), // EXISTS ::/32 + b"INFO" => build_info(xs), // EXISTS ::/32 + b"PARENT" => build_parent(xs), // SUP ::/32 + b"CHILDREN" => build_children(xs), // SUB ::/32 + _ => Err(Error::Invalid), + } + } + + pub fn exec(&self, db: &mut Database) -> Expr { + match self { + Self::Echo { msg } => Expr::Blob(msg.clone()), + Self::Set { ip_block, value: _ } => { + db.set(*ip_block, std::collections::HashMap::new()); + Expr::Blob(b"OK".to_vec()) + } + Self::Get { ip_block } => db.get(ip_block).map(|x| x.into()).unwrap_or(Expr::Null), + Self::Del { ip_block } => db + .del(ip_block) + .map_or(Expr::Null, |_| Expr::Blob(b"OK".to_vec())), + Self::Exists { ip_block } => Expr::Bool(db.exists(ip_block)), + Self::Info { ip_block } => { + let info: Vec<(&str, String)> = db.info(ip_block).into(); + info.into() + } + Self::Parent { ip_block } => { + db.parent(ip_block).map(|x| x.into()).unwrap_or(Expr::Null) + } + Self::Children { ip_block } => db.children(ip_block).into(), + } + } +} + +// Display + +impl fmt::Display for Query { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Query::Set { ip_block, value } => { + format!("SET {} {}", ip_block, value) + } + Query::Get { ip_block } => { + format!("GET {}", ip_block) + } + Query::Echo { msg } => { + format!("ECHO {:?}", String::from_utf8_lossy(msg).into_owned()) + } + Query::Del { ip_block } => { + format!("DEL {}", ip_block) + } + Query::Exists { ip_block } => { + format!("EXISTS {}", ip_block) + } + Query::Info { ip_block } => { + format!("INFO {}", ip_block) + } + Query::Parent { ip_block } => { + format!("PARENT {}", ip_block) + } + Query::Children { ip_block } => { + format!("CHILDREN {}", ip_block) + } + }; + + write!(f, "{}", res) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::Empty => "Received an empty query.".to_string(), + Error::ExpectedArray(got) => format!("Expected an ARRAY, got: {:?}.", got), + Error::ExpectedBlob(got) => format!("Expected a BLOB, got: {:?}.", got), + Error::IP(e) => format!("IP error: {e:?}"), + Error::Invalid => "Received an invalid query.".to_string(), + Error::InvalidArgsCount { + cmd, + expected, + actual, + } => format!( + "Invalid number of arguments for {:?}. Expected {} but got {}.", + cmd, expected, actual + ), + }; + + write!(f, "{}", res) + } +} + +// Builders + +fn build_echo(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("ECHO", &args, 1)?; + + let expr = args.remove(0); + let msg = if let Expr::Blob(k) = expr { + k + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Echo { msg }) +} + +fn build_get(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("GET", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Get { ip_block }) +} +fn build_set(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("SET", &args, 2)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Set { + ip_block, + value: args.remove(0), + }) +} +fn build_del(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("DEL", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Del { ip_block }) +} +fn build_exists(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("EXISTS", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Exists { ip_block }) +} +fn build_info(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("INFO", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Info { ip_block }) +} +fn build_parent(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("PARENT", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Parent { ip_block }) +} +fn build_children(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("CHILDREN", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Children { ip_block }) +} + +// Helpers + +fn count_args<T>(cmd: &str, args: &Vec<T>, expected_count: usize) -> Result<(), Error> { + let args_count = args.len(); + if args_count != expected_count { + return Err(Error::InvalidArgsCount { + cmd: cmd.to_string(), + expected: expected_count, + actual: args_count, + }); + } + + Ok(()) +} diff --git a/src/common/sized_buffer.rs b/src/common/sized_buffer.rs new file mode 100644 index 0000000..ebfcab5 --- /dev/null +++ b/src/common/sized_buffer.rs @@ -0,0 +1,113 @@ +use memchr::memchr; +use std::io; +use std::io::Write; + +pub struct SizedBuffer<T> { + len: usize, + pos: usize, + pub inner: T, +} + +#[derive(Debug)] +pub enum Error { + IO(io::Error), + ReadPastEnd, + WrongByte { expected: u8, got: u8 }, +} + +impl<const N: usize> SizedBuffer<[u8; N]> { + pub fn new() -> Self { + Self { + len: 0, + pos: 0, + inner: [0; N], + } + } + + pub fn append(&mut self, bytes: &mut [u8]) -> Result<(), Error> { + self.write_all(bytes).map_err(Error::IO) + } + + pub fn peek(&self) -> u8 { + self.inner[self.pos] + } + + pub fn read(&mut self) -> Result<u8, Error> { + self.seek(1)?; + Ok(self.inner[self.pos - 1]) + } + + /* Reads n bytes from the buffer, and moves the cursor right after the last byte read. + * + * Example with n=3: + * + * pos = 2 + * v + * [0,1,2,3,4,5,6,7,8,9] + * |-----|^ + * read new pos + */ + pub fn read_count(&mut self, n: usize) -> Result<&[u8], Error> { + self.seek(n)?; + Ok(&self.inner[self.pos - n..self.pos]) + } + + pub fn read_count_unchecked(&mut self, n: usize) -> &[u8] { + self.seek_unchecked(n); + &self.inner[self.pos - n..self.pos] + } + + pub fn skip_byte(&mut self, byte: u8) -> Result<(), Error> { + let x = self.read()?; + if x != byte { + return Err(Error::WrongByte { + expected: byte, + got: x, + }); + } + Ok(()) + } + + pub fn seek(&mut self, n: usize) -> Result<(), Error> { + /* Allows to seek 1 past the end of the buffer. Subsequent reads will still fail but this + * allows to read the last byte. Otherwise trying to read the last byte would fail after + * reading it, when trying to advance the position of the cursor. */ + if self.pos + n > self.len { + return Err(Error::ReadPastEnd); + } + self.pos += n; + + Ok(()) + } + + pub fn seek_unchecked(&mut self, n: usize) { + self.pos += n; + } + + pub fn find_offset(&mut self, x: u8) -> Option<usize> { + memchr(x, &self.inner[self.pos..self.len]) + } + + pub fn empty(&self) -> bool { + self.len == self.pos + } + + pub fn reset(&mut self) { + self.len = 0; + self.pos = 0; + } +} + +impl<const N: usize> Write for SizedBuffer<[u8; N]> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let amt = (&mut self.inner[self.len..]).write(buf)?; + self.len += amt; + Ok(amt) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/src/common/term.rs b/src/common/term.rs new file mode 100644 index 0000000..44fd819 --- /dev/null +++ b/src/common/term.rs @@ -0,0 +1,517 @@ +#![allow(clippy::manual_range_contains)] +#![allow(clippy::type_complexity)] +use libc::BRKINT; +use libc::CS8; +use libc::ECHO; +use libc::ICANON; +use libc::ICRNL; +use libc::IEXTEN; +use libc::INPCK; +use libc::ISIG; +use libc::ISTRIP; +use libc::IXON; +use libc::OPOST; +use libc::TCSAFLUSH; +use libc::TIOCGWINSZ; +use libc::VMIN; +use libc::VTIME; +use log::trace; +use std::cmp::max; +use std::cmp::min; +use std::fs::File; +use std::io::Read; +use std::io::Write; +use std::mem; +use std::os::unix::io::AsRawFd; + +enum Key { + Null = 0, + CtrlA = 1, + CtrlB = 2, + CtrlC = 3, + CtrlD = 4, + CtrlE = 5, + CtrlF = 6, + CtrlH = 8, + Tab = 9, + CtrlK = 11, + CtrlL = 12, + Enter = 13, + CtrlN = 14, + CtrlP = 16, + CtrlT = 20, + CtrlU = 21, + CtrlW = 23, + Esc = 27, + Backspace = 127, +} + +pub struct Term { + initial_termios: Option<libc::termios>, + tty_in: File, + tty_out: File, + history: Vec<Vec<u8>>, + history_cursor: usize, + column: usize, + line: Vec<u8>, + line_cursor: usize, + prompt: Option<Vec<u8>>, + hints_callback: Option<fn(&[u8]) -> &[u8]>, + columns: usize, + max_rows: usize, +} + +#[derive(Debug)] +struct Position { + column: usize, + #[allow(dead_code)] + row: usize, +} + +impl Term { + pub fn new() -> Self { + let tty_out = File::options() + .read(true) + .write(true) + .open("/dev/tty") + .unwrap(); + let tty_in = tty_out.try_clone().unwrap(); + + Self { + initial_termios: None, + tty_out, + tty_in, + history: Vec::new(), + history_cursor: 0, + column: 0, + line: Vec::new(), + line_cursor: 0, + prompt: None, + hints_callback: None, + columns: 0, + max_rows: 0, + } + } + + pub fn setup_hints(&mut self, hints_callback: fn(&[u8]) -> &[u8]) { + self.hints_callback = Some(hints_callback); + } + + pub fn setup_prompt(&mut self, prompt: &[u8]) { + self.prompt = Some(prompt.to_vec()); + } + + pub fn edit(&mut self) -> Option<Vec<u8>> { + self.enable_raw_mode(); + + self.columns = self.get_columns(); + self.column = self.get_cursor_position().unwrap().column; + + self.refresh_line(); + + loop { + let mut buffer: [u8; 1] = [0]; + + if self.tty_in.read(&mut buffer).unwrap_or(0) != 1 { + break; + } + + match buffer[0] { + x if (x == Key::Tab as u8) => { + todo!("PRESSED TAB"); + } + x if (x == Key::CtrlC as u8) => { + self.tty_out.write_all(b"\r\n").unwrap(); + self.disable_raw_mode(); + std::process::exit(1); + } + x if (x == Key::Enter as u8) => { + let line = self.line.clone(); + + self.history_cursor = self.history.len(); + self.history.push(self.line.clone()); + + self.line.clear(); + self.line_cursor = 0; + + self.tty_out.write_all(b"\r\n").unwrap(); + self.disable_raw_mode(); + + return Some(line); + } + x if (x == Key::Backspace as u8) => { + if self.line_cursor < 1 { + continue; + } + + self.line_cursor -= 1; + self.line.remove(self.line_cursor); + + self.refresh_line(); + } + x if (x == Key::Null as u8) => {} + x if (x == Key::CtrlD as u8) => { + if !self.line.is_empty() { + self.delete_right(); + } else { + self.tty_out.write_all(b"\r\n").unwrap(); + self.disable_raw_mode(); + std::process::exit(1); + } + } + x if (x == Key::CtrlA as u8) => self.move_to_start(), + x if (x == Key::CtrlE as u8) => self.move_to_end(), + x if (x == Key::CtrlB as u8) => self.move_left(), + x if (x == Key::CtrlF as u8) => self.move_right(), + x if (x == Key::CtrlH as u8) => self.delete_left(), + x if (x == Key::CtrlK as u8) => self.delete_to_end(), + x if (x == Key::CtrlL as u8) => { + self.clear_screen(); + self.refresh_line(); + } + // Next line in history + x if (x == Key::CtrlN as u8) => self.history_next(), + // Previous line in history + x if (x == Key::CtrlP as u8) => self.history_prev(), + x if (x == Key::CtrlT as u8) => self.swap_one(), + x if (x == Key::CtrlU as u8) => self.delete_to_start(), + x if (x == Key::CtrlW as u8) => self.delete_word_left(), + x if (x == Key::Esc as u8) => { + let mut seq: [u8; 1] = [0]; + if self.tty_in.read(&mut seq).unwrap_or(0) != 1 { + continue; + } + if seq[0] != b'[' { + // FIXME: Handle Esc+Key / "Esc 0" sequences + continue; + } + if self.tty_in.read(&mut seq).unwrap_or(0) != 1 { + break; + } + if seq[0] >= b'0' && seq[0] <= b'9' { + // FIXME: Handle extended escape sequence + continue; + } + match seq[0] { + // Up + b'A' => self.history_prev(), + // Down + b'B' => self.history_next(), + // Right + b'C' => self.move_right(), + // Left + b'D' => self.move_left(), + // Home + b'H' => self.move_to_start(), + // End + b'F' => self.move_to_end(), + _ => break, + } + } + x => { + self.line.insert(self.line_cursor, x); + self.line_cursor += 1; + self.refresh_line(); + } + } + + self.tty_out.flush().unwrap(); + } + + self.disable_raw_mode(); + + None + } + + fn history_next(&mut self) { + if self.history.is_empty() { + return; + } + + self.line = self.history[self.history_cursor].clone(); + self.line_cursor = self.line.len(); + + self.refresh_line(); + + self.history_cursor = min(self.history_cursor + 1, self.history.len() - 1); + } + + fn history_prev(&mut self) { + if self.history.is_empty() { + return; + } + + // FIXME: Save current line before replacing it with line for the history. + + self.line = self.history[self.history_cursor].clone(); + self.line_cursor = self.line.len(); + + self.refresh_line(); + + self.history_cursor = self.history_cursor.saturating_sub(1); + } + + fn move_right(&mut self) { + self.line_cursor = min(self.line_cursor + 1, self.line.len()); + self.refresh_line(); + } + + fn move_left(&mut self) { + self.line_cursor = self.line_cursor.saturating_sub(1); + self.refresh_line(); + } + + fn move_to_start(&mut self) { + self.line_cursor = 0; + self.refresh_line(); + } + fn move_to_end(&mut self) { + self.line_cursor = self.line.len(); + self.refresh_line(); + } + + fn delete_word_left(&mut self) { + if self.line_cursor == 0 { + return; + } + + let cursor = self.line_cursor; + + while self.line_cursor > 0 && self.line[self.line_cursor - 1] == b' ' { + self.line_cursor -= 1; + } + while self.line_cursor > 0 && self.line[self.line_cursor - 1] != b' ' { + self.line_cursor -= 1; + } + + self.line.drain(self.line_cursor..cursor); + self.refresh_line(); + } + + fn delete_left(&mut self) { + if self.line_cursor == 0 { + return; + } + + self.line.remove(self.line_cursor - 1); + self.line_cursor = self.line_cursor.saturating_sub(1); + self.refresh_line(); + } + + fn delete_right(&mut self) { + if self.line_cursor >= self.line.len() { + return; + } + + self.line.remove(self.line_cursor); + self.refresh_line(); + } + + fn delete_to_start(&mut self) { + if self.line_cursor == 0 { + return; + } + + self.line.drain(0..self.line_cursor); + self.line_cursor = 0; + self.refresh_line(); + } + + fn delete_to_end(&mut self) { + if self.line_cursor >= self.line.len() { + return; + } + + self.line.truncate(self.line_cursor); + self.refresh_line(); + } + + fn swap_one(&mut self) { + if self.line_cursor == 0 { + return; + } + + self.line_cursor = min(self.line_cursor + 1, self.line.len()); + self.line.swap(self.line_cursor - 2, self.line_cursor - 1); + self.refresh_line(); + } + + fn refresh_line(&mut self) { + let prompt_len = self.prompt.as_ref().map(|x| x.len()).unwrap_or(0); + + let rows = (self.line.len() + prompt_len + self.columns - 1) / self.columns; + let prev_rows = self.max_rows; + let curr_row = (self.line_cursor + prompt_len + self.columns - 1) / self.columns; + self.column = (prompt_len + self.line_cursor) % self.columns; + + self.max_rows = max(self.max_rows, rows); + + // + // Clear rows + // + + // Go down to the last row if the cursor is not already there. + if curr_row < prev_rows { + trace!("Go down {}", prev_rows - curr_row); + write!(self.tty_out, "\x1b[{}B", prev_rows - curr_row).unwrap(); + } + + // Clear each row and go up. + for _ in 1..prev_rows { + trace!("Clear row, go up"); + self.tty_out.write_all(b"\r\x1b[0K\x1b[1A").unwrap(); + } + + // Clear the top row. + trace!("Clear top row"); + self.tty_out.write_all(b"\r\x1b[0K").unwrap(); + + // + // Rewrite prompt and line + // + + if let Some(prompt) = &self.prompt { + self.tty_out.write_all(prompt).unwrap(); + } + + self.tty_out.write_all(&self.line).unwrap(); + + // + // Display hints if any + // + + if self.hints_callback.is_some() && !self.line.is_empty() { + let hints = self.hints_callback.unwrap()(&self.line); + + if !hints.is_empty() { + // Column for the last character on the current row. + let max_column = (prompt_len + self.line.len()) % (self.columns + 1); + let max_len = min(self.columns - max_column, hints.len()); + + self.tty_out.write_all(b"\x1b[2m").unwrap(); + self.tty_out.write_all(&hints[..max_len]).unwrap(); + self.tty_out.write_all(b"\x1b[22m").unwrap(); + } + } + + // + // Move cursor to the right position. + // + + if self.column == 0 + && self.line.len() == self.line_cursor + && (prompt_len + self.line_cursor) > 0 + { + // We reached the end of a row, insert a new line and go back to the first row. + trace!("Insert linefeed"); + self.max_rows = max(self.max_rows, rows + 1); + self.tty_out.write_all(b"\r\n").unwrap(); + } else if self.column == 0 { + self.tty_out.write_all(b"\r").unwrap(); + } else { + // Move to the right column. + write!(self.tty_out, "\r\x1b[{}C", self.column).unwrap(); + + if rows > curr_row { + // Move up to the current row. + write!(self.tty_out, "\x1b[{}A", rows - curr_row).unwrap(); + } + } + } + + fn enable_raw_mode(&mut self) { + let mut termios = mem::MaybeUninit::uninit(); + + if unsafe { libc::tcgetattr(self.tty_out.as_raw_fd(), termios.as_mut_ptr()) } == -1 { + panic!("Failed to enable raw mode"); + } + + self.initial_termios = Some(unsafe { termios.assume_init() }); + let mut raw = self.initial_termios.unwrap(); + + raw.c_iflag &= !(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + raw.c_oflag &= !OPOST; + raw.c_cflag |= CS8; + raw.c_lflag &= !(ECHO | ICANON | IEXTEN | ISIG); + raw.c_cc[VMIN] = 1; + raw.c_cc[VTIME] = 0; + + if unsafe { libc::tcsetattr(self.tty_out.as_raw_fd(), TCSAFLUSH, &raw) } < 0 { + panic!("Failed to enable raw mode"); + } + } + + fn disable_raw_mode(&mut self) { + if self.initial_termios.is_some() + && unsafe { + libc::tcsetattr( + self.tty_out.as_raw_fd(), + TCSAFLUSH, + &self.initial_termios.unwrap(), + ) + } != -1 + { + self.initial_termios = None; + } + } + + fn get_columns(&mut self) -> usize { + let ws = libc::winsize { + ws_row: 0, + ws_col: 0, + ws_xpixel: 0, + ws_ypixel: 0, + }; + + if unsafe { libc::ioctl(self.tty_out.as_raw_fd(), TIOCGWINSZ, &ws) } < 0 || ws.ws_col == 0 { + todo!("Query the terminal for columns count."); + } + + ws.ws_col as usize + } + + fn get_cursor_position(&mut self) -> Option<Position> { + self.enable_raw_mode(); + + self.tty_out.write_all(b"\x1b[6n").unwrap(); + self.tty_out.flush().unwrap(); + + let mut buffer: [u8; 32] = [0; 32]; + + let mut sep_index = None; + let mut index = 0; + + let mut bytes = (&self.tty_in).bytes(); + while let Some(Ok(byte)) = bytes.next() { + if byte == b'R' || index >= buffer.len() { + break; + } + if byte == b';' { + sep_index = Some(index); + } + + buffer[index] = byte; + index += 1; + } + + if buffer[0] != Key::Esc as u8 || buffer[1] != b'[' { + return None; + } + + sep_index.map(|sep_index| Position { + row: as_usize(&buffer[2..sep_index]), + column: as_usize(&buffer[sep_index + 1..index]), + }) + } + + fn clear_screen(&mut self) { + self.tty_out.write_all(b"\x1b[H\x1b[2J").unwrap(); + } +} + +fn as_usize(array: &[u8]) -> usize { + std::str::from_utf8(array) + .unwrap() + .parse::<usize>() + .unwrap() +} diff --git a/src/common/trie.rs b/src/common/trie.rs new file mode 100644 index 0000000..381d2b5 --- /dev/null +++ b/src/common/trie.rs @@ -0,0 +1,151 @@ +#![allow(dead_code)] +use crate::common::ip::Block; +use crate::common::ip::Meta; +use crate::common::ip::BITMASK; + +pub struct Trie { + root: Node, +} + +#[derive(Debug)] +struct Node { + block: Option<Block>, + left: Option<Box<Node>>, + right: Option<Box<Node>>, +} + +impl Node { + pub const NULL: Node = Self { + block: None, + left: None, + right: None, + }; + + fn new(block: Option<Block>) -> Self { + Node { + block, + left: None, + right: None, + } + } + + fn is_leaf(&self) -> bool { + self.left.is_none() && self.right.is_none() + } +} + +impl Trie { + pub fn new() -> Self { + Trie { root: Node::NULL } + } + + pub fn insert(&mut self, block: Block, _meta: Meta) { + let mut node = &mut self.root; + let mut addr = block.addr; + + for _ in 0..block.mlen { + if addr & BITMASK == 0 { + if node.left.is_none() { + node.left = Some(Box::new(Node::NULL)); + } + node = node.left.as_deref_mut().unwrap(); + } else { + if node.right.is_none() { + node.right = Some(Box::new(Node::NULL)); + } + node = node.right.as_deref_mut().unwrap(); + } + addr <<= 1; + } + + node.block = Some(block); + } + + pub fn find(&self, needle: &Block) -> Option<Block> { + self.find_block_node(needle)?.block + } + + pub fn parent(&self, needle: &Block) -> Option<Block> { + let mut node = &self.root; + let mut addr = needle.addr; + let mut parent = None; + + for _ in 0..needle.mlen { + if node.block.is_some() { + parent = node.block; + } + + if addr & BITMASK == 0 { + if node.left.is_some() { + node = node.left.as_deref().unwrap(); + } else { + break; + } + } else if node.right.is_some() { + node = node.right.as_deref().unwrap(); + } else { + break; + } + addr <<= 1; + } + + parent + } + + pub fn children(&self, parent: &Block) -> Vec<Block> { + let root: &Node = match self.find_block_node(parent) { + Some(root) => root, + None => return vec![], + }; + + let mut children: Vec<Block> = Vec::new(); + let mut stack: Vec<&Node> = vec![root]; + + while let Some(node) = stack.pop() { + if node.is_leaf() && node.block.is_some() && node.block.unwrap() != root.block.unwrap() + { + children.push(node.block.unwrap()); + continue; + } + + if node.left.is_some() { + stack.push(node.left.as_deref().unwrap()); + } + if node.right.is_some() { + stack.push(node.right.as_deref().unwrap()); + } + } + + children + } + + fn find_block_node(&self, needle: &Block) -> Option<&Node> { + let mut node = &self.root; + let mut addr = needle.addr; + let mut match_ = None; + + for _ in 0..(needle.mlen + 1) { + if node.block.is_some() { + match_ = Some(node); + } + + if addr & BITMASK == 0 { + if node.left.is_some() { + node = node.left.as_deref().unwrap(); + } else { + break; + } + } else if node.right.is_some() { + node = node.right.as_deref().unwrap(); + } else { + break; + } + addr <<= 1; + } + + match match_?.block { + Some(block) if block == *needle => match_, + _ => None, + } + } +} diff --git a/src/help.rs b/src/help.rs new file mode 100644 index 0000000..188f475 --- /dev/null +++ b/src/help.rs @@ -0,0 +1,43 @@ +use std::process::ExitCode; + +use crate::query; +use crate::repl; +use crate::server; + +pub const USAGE: &str = " +Usage: blom [command] [options] + +Commands: + + server Start the blom server. + + repl Start the blom repl. + query <query> Send <query> to the blom server. + + help [<command>] Print usage. + +General options: + + -h, --help Print usage. +"; + +pub fn cmd(args: &[String]) -> ExitCode { + if args.is_empty() { + println!("{}", USAGE); + return ExitCode::from(0); + } + + match args[0].as_str() { + "server" => println!("{}", server::USAGE), + "repl" => println!("{}", repl::USAGE), + "query" => println!("{}", query::USAGE), + "help" => println!("{}", USAGE), + _ => { + println!("{}", USAGE); + eprintln!("Unknown command {}.", args[0]); + return ExitCode::from(1); + } + } + + ExitCode::SUCCESS +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..f39c5fc --- /dev/null +++ b/src/main.rs @@ -0,0 +1,37 @@ +#![feature(addr_parse_ascii)] + +use std::env; +use std::process::ExitCode; + +mod common; +mod help; +mod query; +mod repl; +mod server; + +fn main() -> ExitCode { + env_logger::init(); + + let args: Vec<String> = env::args().collect(); + + if args.len() <= 1 { + println!("{}", help::USAGE); + eprintln!("Expected command argument."); + return ExitCode::from(1); + } + + let cmd = &args[1]; + let cmd_args = &args[2..]; + + match cmd.as_str() { + "server" => return server::cmd(cmd_args), + "repl" => return repl::cmd(cmd_args), + "query" => return query::cmd(cmd_args), + "help" | "-h" | "--help" => return help::cmd(cmd_args), + _ => { + println!("{}", help::USAGE); + eprintln!("Unknown command {}.", cmd); + return ExitCode::from(1); + } + } +} diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 0000000..1f3b4c8 --- /dev/null +++ b/src/query.rs @@ -0,0 +1,97 @@ +use crate::common::expr::Expr; +use std::io::Read; +use std::io::Write; +use std::net::TcpStream; +use std::process::ExitCode; +use std::str; + +pub const USAGE: &str = " +Usage: blom query <query> + + Sends <query> to the blom server and prints the response + to stdout. + +Examples: + + $ blom query 'CHILDREN 1.2.3.0/24' + 1.2.3.4/26 + 1.2.3.4/32 + + $ blom query --raw 'CHILDREN 1.2.3.0/24' + +1.2.3.4/26\n + +1.2.3.4/32\n + + $ blom query 'INFO 1.2.3.0/24' + Start: 1.2.3.0 + End: 1.2.3.255 + Size: 256 + Parents: 3 + Children: 364 + Meta: \"{\\\"owner\\\": \\\"acme\\\"}\" + +Options: + + -r, --raw Do not parse the response and prints the raw + output from the server instead. + + -v, --verbose Print the request send to the server in addition + to the response. + + -b, --bind Bind to the given ADDRESS:PORT. Default is + 0.0.0.0:4902 + +"; + +pub fn cmd(args: &[String]) -> ExitCode { + let mut addr = "0.0.0.0:4902"; + let mut query: Option<&str> = None; + + let mut arg_index: usize = 0; + while arg_index < args.len() { + match args[arg_index].as_str() { + "--bind" | "-b" => { + addr = args + .get(arg_index + 1) + .expect("Missing address after --bind.") + .as_str(); + arg_index += 1; + } + q => { + if query.is_some() { + println!("Too many arguments."); + return ExitCode::FAILURE; + } + query = Some(q); + } + } + + arg_index += 1; + } + + if let Some(q) = query { + handle_query(addr, q); + return ExitCode::SUCCESS; + } + + println!("Missing query."); + ExitCode::FAILURE +} + +fn handle_query(addr: &str, query: &str) { + let expr = Expr::from_query(query).encode(); + + match TcpStream::connect(addr) { + Ok(mut stream) => { + let _ = stream.write(&expr).unwrap(); + let mut data = [0u8; 64]; + match stream.read(&mut data) { + Ok(_) => { + let resp = str::from_utf8(&data).unwrap(); + println!("{resp:?}"); + } + Err(e) => println!("Read failure: {e:?}"), + } + } + Err(e) => println!("Connection failure: {e:?}"), + } +} diff --git a/src/repl.rs b/src/repl.rs new file mode 100644 index 0000000..339405e --- /dev/null +++ b/src/repl.rs @@ -0,0 +1,134 @@ +use crate::common::expr::Error as ProtoError; +use crate::common::expr::Expr; +use crate::common::term::Term; +use std::io::BufReader; +use std::io::Write; +use std::net::TcpStream; +use std::process::ExitCode; + +pub const USAGE: &str = " +Usage: blom start + + Starts the blom repl. + + -b, --bind Bind to the given ADDRESS:PORT. Default is + 0.0.0.0:4902 + +"; + +enum Error { + Connect, +} + +pub fn cmd(args: &[String]) -> ExitCode { + let mut addr = "0.0.0.0:4902"; + + let mut arg_index: usize = 0; + while arg_index < args.len() { + match args[arg_index].as_str() { + "--addr" | "-a" => { + addr = args + .get(arg_index + 1) + .expect("Missing address after --addr.") + .as_str(); + arg_index += 1; + } + _ => { + println!("Too many arguments."); + return ExitCode::FAILURE; + } + } + + arg_index += 1; + } + + match start(addr) { + Ok(()) => ExitCode::SUCCESS, + Err(_) => ExitCode::FAILURE, + } +} + +fn start(addr: &str) -> Result<(), Error> { + let mut term = Term::new(); + + term.setup_prompt(format!("{addr}> ").as_bytes()); + term.setup_hints(hints); + + while let Some(query) = input(&mut term) { + if query.is_empty() { + continue; + } + + let mut stream = match try_connect(addr) { + Ok(stream) => stream, + Err(_) => continue, + }; + + let expr = Expr::from_query(&query); + + if stream.write(&expr.encode()).is_err() { + println!("Couldn't connect to blom at {addr}"); + continue; + } + + let mut resp = BufReader::new(&stream); + + match Expr::from_bytes(&mut resp) { + Ok(resp) => println!("{resp}"), + Err(ProtoError::ConnectionClosed) => { + println!("Couldn't connect to blom at {addr}"); + continue; + } + Err(e) => { + println!("{e}"); + break; + } + } + } + + Ok(()) +} + +fn try_connect(addr: &str) -> Result<TcpStream, Error> { + match TcpStream::connect(addr) { + Ok(stream) => Ok(stream), + Err(_) => { + println!("Couldn't connect to blom at {addr}."); + Err(Error::Connect) + } + } +} + +fn input(term: &mut Term) -> Option<String> { + let line = term.edit()?; + let query = std::str::from_utf8(&line).unwrap().trim(); + + if query == "exit" { + return None; + } + + Some(query.to_string()) +} + +// TODO: Refactor +fn hints(line: &[u8]) -> &[u8] { + match line { + b"ECHO" | b"echo" => b" MESSAGE", + b"SET" | b"set" => b" BLOCK META", + b"GET" | b"get" => b" BLOCK", + b"DEL" | b"del" => b" BLOCK", + b"EXISTS" | b"exists" => b" BLOCK", + b"INFO" | b"info" => b" BLOCK", + b"PARENT" | b"parent" => b" BLOCK", + b"CHILDREN" | b"children" => b" BLOCK", + b"ECHO " | b"echo " => b"MESSAGE", + b"SET " | b"set " => b"BLOCK META", + b"GET " | b"get " => b"BLOCK", + b"DEL " | b"del " => b"BLOCK", + b"EXISTS " | b"exists " => b"BLOCK", + b"INFO " | b"info " => b"BLOCK", + b"PARENT " | b"parent " => b"BLOCK", + b"CHILDREN " | b"children " => b"BLOCK", + _ => b"", + } +} diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..c23d74a --- /dev/null +++ b/src/server.rs @@ -0,0 +1,291 @@ +mod conn; +mod parser; + +use crate::common::db::Database; +use crate::common::expr::Error as ExprError; +use crate::common::query::Error as QueryError; +use conn::Conn; +use conn::Error as ConnError; +use log::debug; +use log::error; +use log::info; +use log::trace; +use log::warn; +use mio::event::Event; +use mio::event::Events; +use mio::net::TcpListener; +use mio::Interest; +use mio::Poll; +use mio::Token; +use std::collections::HashMap; +use std::fmt; +use std::io::ErrorKind; +use std::io::Read; +use std::net::SocketAddr; +use std::process; +use std::process::ExitCode; +use std::time::Duration; + +pub const USAGE: &str = " +Usage: blom start + + Starts the blom server. + +Options + + -b, --bind Bind to the given ADDRESS:PORT. Default is + 0.0.0.0:4902 + + +"; + +pub const HEADER: &str = " + ___ ___ ___ ___ + /\\ \\ /\\__\\ /\\ \\ /\\__\\ + /..\\ \\ /./ / /..\\ \\ /..L_L_ + /..\\.\\__\\ /./__/ /./\\.\\__\\ /./L.\\__\\ + \\.\\../ / \\.\\ \\ \\.\\/./ / \\/_/./ / + \\../ / \\.\\__\\ \\../ / /./ / + \\/__/ \\/__/ \\/__/ \\/__/ +"; + +const SERVER_TOKEN: Token = Token(0); +const SERVER_BUFLEN: usize = 512; +const CONN_BUFLEN: usize = 2048; +const MAX_CONN: usize = 1000; + +#[derive(Debug)] +enum Error { + IO(std::io::Error), + Query(QueryError), + Expr(ExprError), +} + +struct Server { + addr: SocketAddr, + pid: u32, + poll: Poll, + pool: Vec<Token>, + listener: TcpListener, + conns: HashMap<Token, Conn>, + buffer: [u8; SERVER_BUFLEN], + db: Database, +} + +pub fn cmd(args: &[String]) -> ExitCode { + println!("{}\n", HEADER); + + let addr = match args.get(0).map(|x| x.as_str()) { + Some("--bind") | Some("-b") => { + Some(args.get(1).expect("Missing address for --bind.").as_str()) + } + _ => None, + } + .unwrap_or("0.0.0.0:4902"); + + let mut server = match Server::bind(addr.parse().unwrap(), MAX_CONN) { + Some(server) => server, + None => return ExitCode::FAILURE, + }; + println!( + "PID: {} | Listening on {}; press Ctrl-C to exit...\n\n", + server.pid, server.addr + ); + + match server.start() { + Ok(()) => { + info!("Received signal X, terminating."); + ExitCode::SUCCESS + } + Err(e) => { + error!("Error: {e:?}"); + ExitCode::FAILURE + } + } +} + +impl Server { + fn bind(addr: SocketAddr, max_conns: usize) -> Option<Self> { + assert!(max_conns > 0); + + let mut listener = match TcpListener::bind(addr) { + Ok(listener) => listener, + Err(ref e) if e.kind() == ErrorKind::AddrInUse => { + println!("{addr} already in use."); + return None; + } + Err(_) => { + println!("Couldn't bind to {addr}."); + return None; + } + }; + + let poll = Poll::new().unwrap(); + + poll.registry() + .register(&mut listener, SERVER_TOKEN, Interest::READABLE) + .unwrap(); + + Some(Server { + addr, + pid: process::id(), + poll, + listener, + conns: HashMap::new(), + buffer: [0; SERVER_BUFLEN], + pool: (1..max_conns + 1).map(Token).collect(), + db: Database::new(), + }) + } + + fn start(&mut self) -> Result<(), Error> { + let mut events = Events::with_capacity(512); + + loop { + match self.poll.poll(&mut events, Some(Duration::from_secs(30))) { + Ok(()) => (), + Err(ref e) if e.kind() == ErrorKind::Interrupted => break, + Err(e) => return Err(Error::IO(e)), + }; + + if events.is_empty() { + trace!("Pulse"); + } + + for event in &events { + match event.token() { + SERVER_TOKEN => self.handle_new(), + token => self.handle_conn(token, event), + } + } + } + + Ok(()) + } + + fn handle_new(&mut self) { + loop { + match self.listener.accept() { + Ok((mut socket, _)) => match self.pool.pop() { + Some(token) => { + self.poll + .registry() + .register(&mut socket, token, Interest::READABLE | Interest::WRITABLE) + .unwrap(); + + let conn = Conn::new(socket, token); + + debug!( + "[{}] Connected. {} clients connected.", + conn, + MAX_CONN - self.pool.len() + ); + + self.conns.insert(token, conn); + } + None => { + warn!( + "[{}] Connection refused: Too many clients.", + socket.peer_addr().unwrap() + ); + break; + } + }, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => break, + Err(_) => panic!("FATAL | Unexpected error"), + } + } + } + + fn handle_conn(&mut self, token: Token, event: &Event) { + let mut did_work = true; + while did_work { + did_work = false; + + let conn = self.conns.get_mut(&token).unwrap(); + + if conn.pending_exec() { + did_work = true; + if let Err(ConnError::Fatal) = conn.handle_pending(&mut self.db) { + self.close_conn(token); + break; + } + } + + if event.is_readable() && conn.read_ready() { + did_work = true; + + match conn.read(&mut self.buffer) { + Ok(0) => { + self.close_conn(token); + break; + } + Ok(len) => { + if let Err(ConnError::Fatal) = + conn.handle_msg(&mut self.db, &self.buffer[0..len]) + { + self.close_conn(token); + break; + } + } + Err(ref e) if e.kind() == ErrorKind::WouldBlock => break, + Err(ref e) if e.kind() == ErrorKind::ConnectionReset => { + self.close_conn(token); + break; + } + Err(e) => { + error!("!!! [{conn}] Fatal error: {e:?}."); + self.close_conn(token); + break; + } + } + } + + if event.is_writable() && conn.write_ready() { + did_work = true; + if let Err(e) = conn.reply() { + error!("!!! [{conn}] Fatal error: {e:?}."); + self.close_conn(token); + break; + }; + } + } + } + + fn close_conn(&mut self, token: Token) { + let conn = self.conns.remove(&token).unwrap(); + self.pool.push(token); + debug!( + "[{conn}] Closed. {} clients connected.", + MAX_CONN - self.pool.len() + ); + } +} + +// Display + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::IO(err) => err.to_string(), + Error::Query(err) => err.to_string(), + Error::Expr(err) => err.to_string(), + }; + + write!(f, "{}", res) + } +} + +// Errors + +impl From<QueryError> for Error { + fn from(err: QueryError) -> Self { + Error::Query(err) + } +} + +impl From<ExprError> for Error { + fn from(err: ExprError) -> Self { + Error::Expr(err) + } +} diff --git a/src/server/conn.rs b/src/server/conn.rs new file mode 100644 index 0000000..7ea5b0d --- /dev/null +++ b/src/server/conn.rs @@ -0,0 +1,230 @@ +use super::parser::Error as ParserError; +use super::parser::Parser; +use super::parser::State as ParserState; +use crate::common::db::Database; +use crate::common::expr::Expr; +use crate::common::query::Error as QueryError; +use crate::common::query::Query; +use crate::common::sized_buffer::SizedBuffer; +use crate::server::CONN_BUFLEN; +use log::debug; +use log::error; + +use mio::net::TcpStream; +use mio::Token; +use std::fmt; +use std::fmt::Display; +use std::fmt::Formatter; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; + +#[derive(Debug)] +pub enum Error { + IO(std::io::Error), + Parser(ParserError), + Query(QueryError), + Fatal, +} + +#[derive(Eq, PartialEq, Debug)] +#[repr(u8)] +pub enum State { + PendingExec = 0b100, + ReadReady = 0b010, + WriteReady = 0b001, +} + +pub struct Conn { + pub flags: u8, + socket: TcpStream, + token: Token, + parser: Parser, + buf_in: SizedBuffer<[u8; CONN_BUFLEN]>, + buf_out: Vec<u8>, +} + +impl Conn { + pub fn new(socket: TcpStream, token: Token) -> Self { + Self { + flags: State::ReadReady as u8, + socket, + token, + parser: Parser::new(), + buf_in: SizedBuffer::new(), + buf_out: Vec::new(), + } + } + + pub fn handle_pending(&mut self, db: &mut Database) -> Result<(), Error> { + self.handle_msg(db, &[]) + } + + pub fn handle_msg(&mut self, db: &mut Database, msg: &[u8]) -> Result<(), Error> { + if let Err(err) = self.buf_in.append(&mut msg.to_vec()) { + error!("[{self}] ! Error while writing to the buffer: {err:?}"); + return Err(Error::Fatal); + } + + match self.parser.parse(&mut self.buf_in) { + Ok(ParserState::Wait) => Ok(()), + Ok(ParserState::Done(argv)) => { + Query::from_expr(Expr::Array( + argv.iter().map(|x| Expr::Blob(x.to_vec())).collect(), + )) + .map(|query| { + debug!("[{self}] < {query}"); + let resp = query.exec(db); + + self.draft(&resp.encode()); + if self.buf_in.empty() { + self.buf_in.reset(); + self.flags &= !(State::PendingExec as u8); + } else { + self.flags &= !(State::ReadReady as u8); + self.flags |= State::PendingExec as u8; + } + + debug!("[{self}] > {resp}"); + }) + .map_err(|e| { + debug!("[{self}] ! {e:?}"); + let wrapped_err = Error::Query(e); + let resp = Expr::Error { + code: error_code(&wrapped_err), + msg: format!("{}", wrapped_err).into_bytes(), + }; + + self.draft(&resp.encode()); + + if self.buf_in.empty() { + self.buf_in.reset(); + self.flags &= !(State::PendingExec as u8); + } else { + self.flags &= !(State::ReadReady as u8); + self.flags |= State::PendingExec as u8; + } + wrapped_err + })?; + + Ok(()) + } + Err(err) => { + let wrapped_err = Error::Parser(err); + + debug!("[{self}] ! {wrapped_err}"); + + let resp = Expr::Error { + code: error_code(&wrapped_err), + msg: format!("{}", wrapped_err).into_bytes(), + }; + + self.draft(&resp.encode()); + + /* If there's a parsing error, we throw away the entire buffer. This means + * potentially throwing away valid commands appearing later in the buffer */ + self.buf_in.reset(); + self.flags &= !(State::PendingExec as u8); + + Err(wrapped_err) + } + } + } + + fn draft(&mut self, msg: &[u8]) { + self.buf_out = msg.to_vec(); + self.flags = State::WriteReady as u8; + } + + pub fn reply(&mut self) -> Result<(), Error> { + assert!(self.write_ready()); + + match self.socket.write(&self.buf_out) { + Ok(wrote_len) if wrote_len == self.buf_out.len() => { + self.flags &= !(State::WriteReady as u8); + + /* Full write, we're ready to accept more data. If there are still queries waiting + * to be processed in the buffer, wait until these are processed first. */ + if !self.pending_exec() { + self.flags |= State::ReadReady as u8; + } + + if let Err(e) = self.socket.flush() { + return Err(Error::IO(e)); + } + + Ok(()) + } + Ok(_wrote_len) => { + /* Partial write, track the remaining length */ + self.flags |= State::WriteReady as u8; + Ok(()) + } + Err(e) if e.kind() == ErrorKind::BrokenPipe => Err(Error::IO(e)), + Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(Error::IO(e)), + Err(e) => panic!("FATAL | Unexpected error: {e:?}"), + } + } + + pub fn pending_exec(&self) -> bool { + self.flags & State::PendingExec as u8 > 0 + } + pub fn read_ready(&self) -> bool { + self.flags & State::ReadReady as u8 > 0 + } + pub fn write_ready(&self) -> bool { + self.flags & State::WriteReady as u8 > 0 + } +} + +impl Read for Conn { + fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> { + assert!(self.read_ready()); + self.socket.read(buf) + } +} + +impl Write for Conn { + fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> { + self.socket.write(buf) + } + + fn flush(&mut self) -> Result<(), std::io::Error> { + self.socket.flush() + } +} + +impl Display for Conn { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self.socket.peer_addr() { + Ok(addr) => write!(f, "{addr} #{}", self.token.0), + Err(_) => write!(f, "unknown-addr #{}", self.token.0), + } + } +} + +// Errors + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::IO(err) => format!("IO error: {err:?}"), + Error::Parser(err) => err.to_string(), + Error::Query(err) => err.to_string(), + Error::Fatal => "Fatal error".to_string(), + }; + + write!(f, "{}", res) + } +} + +fn error_code(error: &Error) -> Vec<u8> { + let code = match error { + Error::IO(_) => "SERVER", + Error::Parser(_) => "PROTOCOL", + Error::Query(_) => "QUERY", + Error::Fatal => "SERVER", + }; + + code.as_bytes().to_vec() +} diff --git a/src/server/parser.rs b/src/server/parser.rs new file mode 100644 index 0000000..e9c2aad --- /dev/null +++ b/src/server/parser.rs @@ -0,0 +1,118 @@ +use crate::common::sized_buffer::Error as SizedBufferError; +use crate::common::sized_buffer::SizedBuffer; +use std::fmt; + +#[derive(Debug)] +pub enum Error { + ExpectedArray(char), + ExpectedString(char), + InvalidLength, + Buffer(SizedBufferError), +} + +#[derive(Debug)] +pub enum State { + Wait, + Done(Vec<Vec<u8>>), +} + +pub struct Parser { + pub argv: Vec<Vec<u8>>, + argc_rem: u64, + arg_len: i64, +} + +impl Parser { + pub fn new() -> Self { + Self { + argv: Vec::new(), + argc_rem: 0, + arg_len: -1, + } + } + + pub fn parse<const N: usize>( + &mut self, + buffer: &mut SizedBuffer<[u8; N]>, + ) -> Result<State, Error> { + if self.argc_rem == 0 { + assert_eq!(self.arg_len, -1); + + /* argc_rem wasn't initialized yet. The current buffer was never read, we need to + * figure out the number of args. */ + if buffer.peek() != b'*' { + return Err(Error::ExpectedArray(buffer.peek() as char)); + } + let arr_lf_offset = match buffer.find_offset(b'\n') { + Some(offset) => offset, + /* Couldn't find \n, wait for more data */ + None => return Ok(State::Wait), + }; + self.argc_rem = read_data_len(buffer, arr_lf_offset)? as u64; + buffer.skip_byte(b'\n').map_err(Error::Buffer)?; + } + + /* We already started reading a command. We can process the arguments. */ + while self.argc_rem > 0 { + /* arg_len wasn't initialized yet. We need to find the size of the string argument. */ + if self.arg_len == -1 { + if buffer.peek() != b'$' { + return Err(Error::ExpectedString(buffer.peek() as char)); + } + + let str_lf_offset = match buffer.find_offset(b'\n') { + Some(offset) => offset, + /* Couldn't find \n, wait for more data */ + None => return Ok(State::Wait), + }; + self.arg_len = read_data_len(buffer, str_lf_offset)? as i64; + buffer.skip_byte(b'\n').map_err(Error::Buffer)?; + } else { + match buffer.read_count(self.arg_len as usize) { + /* Not enough data, wait for more */ + Err(_) => return Ok(State::Wait), + /* Push the new string onto the args list */ + Ok(arg) => self.argv.push(arg.to_vec()), + } + buffer.skip_byte(b'\n').map_err(Error::Buffer)?; + + self.argc_rem -= 1; + self.arg_len = -1; + } + } + + let argv = self.argv.clone(); + self.argv.clear(); + Ok(State::Done(argv)) + } +} + +fn read_data_len<const N: usize>( + buffer: &mut SizedBuffer<[u8; N]>, + n: usize, +) -> Result<usize, Error> { + /* Read from 1 to skip the data tag */ + let mut len: usize = 0; + + for (i, x) in buffer.read_count_unchecked(n)[1..].iter().rev().enumerate() { + if *x < 48 || *x > 57 { + return Err(Error::InvalidLength); + } + len += (x - 48) as usize * 10usize.pow(i as u32); + } + + Ok(len) +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::ExpectedArray(got) => format!("Expected array, got {:?}.", *got as char), + Error::ExpectedString(got) => format!("Expected string, got {:?}.", *got as char), + Error::InvalidLength => "Invalid header size.".to_string(), + Error::Buffer(err) => format!("Error read or writing to/from the buffer: {err:?}"), + }; + + write!(f, "{}", res) + } +} |