diff options
author | evuez <julien@mulga.net> | 2022-11-26 15:38:06 -0500 |
---|---|---|
committer | evuez <julien@mulga.net> | 2024-04-03 22:44:12 +0200 |
commit | 86098797034cbc7eb6db0cee54e17f8dcaedbc5d (patch) | |
tree | 29b6225ead843eb9022296a54657bbadfa1c4da0 /src/common | |
download | blom-main.tar.gz |
Diffstat (limited to 'src/common')
-rw-r--r-- | src/common/db.rs | 89 | ||||
-rw-r--r-- | src/common/expr.rs | 364 | ||||
-rw-r--r-- | src/common/ip.rs | 195 | ||||
-rw-r--r-- | src/common/mod.rs | 7 | ||||
-rw-r--r-- | src/common/query.rs | 252 | ||||
-rw-r--r-- | src/common/sized_buffer.rs | 113 | ||||
-rw-r--r-- | src/common/term.rs | 517 | ||||
-rw-r--r-- | src/common/trie.rs | 151 |
8 files changed, 1688 insertions, 0 deletions
diff --git a/src/common/db.rs b/src/common/db.rs new file mode 100644 index 0000000..77a1752 --- /dev/null +++ b/src/common/db.rs @@ -0,0 +1,89 @@ +#![allow(clippy::from_over_into)] +use crate::common::ip::Addr; +use crate::common::ip::Block; +use crate::common::ip::Meta; +use crate::common::trie; +use std::collections::HashMap; + +pub struct Database { + pub trie: trie::Trie, + pub meta: HashMap<String, String>, +} + +pub struct BlockInfo { + exists: bool, + parent: Option<Block>, + child_count: usize, + start: Addr, + end: Addr, + size: u128, +} + +impl Database { + pub fn new() -> Self { + Database { + trie: trie::Trie::new(), + meta: HashMap::new(), + } + } + + pub fn get(&self, block: &Block) -> Option<Block> { + self.trie.find(block) + } + + /// Sets `meta` on the given `block`. Replaces the metadata if `block` already exists. + /// Creates the block if it doesn't exist. + pub fn set(&mut self, block: Block, meta: Meta) { + self.trie.insert(block, meta); + } + + /// Returns the parent of the given block. + /// + /// Returns `None` if the block has no parents. + pub fn parent(&self, block: &Block) -> Option<Block> { + self.trie.parent(block) + } + + /// Returns the children of the given block. + /// + /// Returns an empty `Vec` if the given block has no children. + pub fn children(&self, block: &Block) -> Vec<Block> { + self.trie.children(block) + } + + pub fn info(&self, block: &Block) -> BlockInfo { + BlockInfo { + exists: self.exists(block), + parent: self.trie.parent(block), + child_count: self.trie.children(block).len(), + start: block.start(), + end: block.end(), + size: block.size(), + } + } + + pub fn del(&self, _block: &Block) -> Option<()> { + None + } + + pub fn exists(&self, block: &Block) -> bool { + self.trie.find(block).is_some() + } +} + +// Into + +impl Into<Vec<(&str, String)>> for BlockInfo { + fn into(self) -> Vec<(&'static str, String)> { + vec![ + ("EXISTS", self.exists.to_string()), + ( + "PARENT", + self.parent.map(|x| x.to_string()).unwrap_or_default(), + ), + ("CHILD-COUNT", self.child_count.to_string()), + ("RANGE", format!("{} - {}", self.start, self.end)), + ("SIZE", self.size.to_string()), + ] + } +} diff --git a/src/common/expr.rs b/src/common/expr.rs new file mode 100644 index 0000000..0a7880b --- /dev/null +++ b/src/common/expr.rs @@ -0,0 +1,364 @@ +#![allow(clippy::from_over_into)] +use crate::common::ip; +use std::convert::Into; +use std::fmt; +use std::io::BufRead; +use std::io::BufReader; +use std::io::Read; +use std::net::TcpStream; +use std::num::ParseIntError; + +#[derive(Debug, Clone, Hash)] +pub enum Expr { + Array(Vec<Expr>), + Blob(Vec<u8>), + Bool(bool), + Error { code: Vec<u8>, msg: Vec<u8> }, + Line(Vec<u8>), + Number(u128), + Table(Vec<(Expr, Expr)>), + Null, +} + +#[derive(Debug)] +pub enum Error { + ConnectionClosed, + InvalidChar(u8), + NotANumber(Vec<u8>), + InvalidTag(u8), + InvalidHeaderSize, + NotEnoughBytes, + ExpectedLF(u8), +} + +impl Expr { + pub fn from_bytes(reader: &mut BufReader<&TcpStream>) -> Result<Self, Error> { + let mut header = vec![]; + + reader + .read_until(b'\n', &mut header) + .map_err(|_| Error::ConnectionClosed)?; + + if header.is_empty() { + return Err(Error::ConnectionClosed); + } + + match header[0] { + b'!' => parse_error(header, reader), + b'#' => parse_bool(header, reader), + b'$' => parse_blob(header, reader), + b'*' => parse_array(header, reader), + b'+' => parse_line(header, reader), + b':' => parse_number(header, reader), + b'_' => Ok(Expr::Null), + _ => Err(Error::InvalidTag(header[0])), + } + } + + pub fn from_query(query: &str) -> Self { + let args = query + .split(' ') + .map(|x| Expr::Blob(x.as_bytes().to_vec())) + .collect(); + Expr::Array(args) + } + + pub fn encode(&self) -> Vec<u8> { + match self { + Expr::Array(arr) => encode_array(arr.clone()), + Expr::Blob(blob) => encode_blob(blob.to_vec()), + Expr::Bool(val) => encode_bool(*val), + Expr::Error { code, msg } => encode_error(code.to_vec(), msg.to_vec()), + Expr::Line(line) => encode_line(line.to_vec()), + Expr::Number(number) => encode_number(*number), + Expr::Null => b"_\n".to_vec(), + Expr::Table(table) => encode_table(table), + } + } +} + +// Display + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Expr::Array(array) => { + format!( + "[{}]", + array + .iter() + .map(|x| x.to_string()) + .collect::<Vec<String>>() + .join(", ") + ) + } + Expr::Bool(value) => { + if *value { + "true".to_string() + } else { + "false".to_string() + } + } + Expr::Blob(blob) => format!("{:?}", String::from_utf8_lossy(blob).into_owned()), + Expr::Number(number) => format!("{:?}", number), + Expr::Error { code, msg } => format!( + "!{} {:?}", + String::from_utf8_lossy(code).into_owned(), + String::from_utf8_lossy(msg).into_owned() + ), + Expr::Line(line) => format!("{:?}", String::from_utf8_lossy(line).into_owned()), + Expr::Null => "(null)".to_string(), + Expr::Table(table) => { + format!( + "[{}]", + table + .iter() + .map(|(k, v)| format!("{}: {}", k, v)) + .collect::<Vec<String>>() + .join(", ") + ) + } + }; + + write!(f, "{}", res) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::ConnectionClosed => "Connection closed.".to_string(), + Error::InvalidChar(got) => format!("Invalid char: {:?}.", *got as char), + Error::NotANumber(got) => format!("Not a number: {:?}.", got), + Error::ExpectedLF(got) => format!("Expected LF, got: {:?}.", *got as char), + Error::InvalidHeaderSize => "Invalid header size.".to_string(), + Error::InvalidTag(got) => format!("Invalid tag {:?}.", *got as char), + Error::NotEnoughBytes => "Not enough bytes.".to_string(), + }; + + write!(f, "{}", res) + } +} + +// Encoders + +fn encode_array(array: Vec<Expr>) -> Vec<u8> { + let mut res = array.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'*'); + res.push(b'\n'); + + for expr in array { + res.append(&mut expr.encode()); + } + + res +} +fn encode_bool(val: bool) -> Vec<u8> { + if val { + b"#t\n".to_vec() + } else { + b"#f\n".to_vec() + } +} +fn encode_blob(mut blob: Vec<u8>) -> Vec<u8> { + let mut res = blob.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'$'); + res.push(b'\n'); + res.append(&mut blob); + res.push(b'\n'); + + res +} + +fn encode_line(mut line: Vec<u8>) -> Vec<u8> { + line.insert(0, b'+'); + line.push(b'\n'); + + line +} + +fn encode_number(number: u128) -> Vec<u8> { + let mut res = vec![b':']; + res.append(&mut number.to_string().as_bytes().to_vec()); + res.push(b'\n'); + + res +} + +fn encode_error(mut code: Vec<u8>, mut msg: Vec<u8>) -> Vec<u8> { + let mut res = (code.len() + msg.len()).to_string().as_bytes().to_vec(); + + res.insert(0, b'!'); + res.push(b'\n'); + res.append(&mut code); + res.push(b' '); + res.append(&mut msg); + res.push(b'\n'); + + res +} + +fn encode_table(table: &Vec<(Expr, Expr)>) -> Vec<u8> { + let mut res = table.len().to_string().as_bytes().to_vec(); + + res.insert(0, b'%'); + res.push(b'\n'); + + for (key, val) in table { + res.append(&mut key.encode()); + res.append(&mut val.encode()); + } + + res +} + +// Parsers + +fn parse_line(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + Ok(Expr::Line(header[1..header.len() - 1].to_vec())) +} + +fn parse_number(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + let digits = &header[1..header.len() - 1]; + + Ok(Expr::Number( + as_u128(digits).map_err(|_| Error::NotANumber(digits.to_vec()))?, + )) +} + +fn parse_blob(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let mut body = vec![0u8; size]; + reader + .read_exact(&mut body) + .map_err(|_| Error::NotEnoughBytes)?; + + skip_lf(reader)?; + + Ok(Expr::Blob(body)) +} + +fn parse_error(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let mut body = vec![0u8; size]; + reader + .read_exact(&mut body) + .map_err(|_| Error::NotEnoughBytes)?; + + skip_lf(reader)?; + + let parts: Vec<&[u8]> = body.splitn(2, |x| *x == b' ').collect(); + if parts.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + Ok(Expr::Error { + code: parts[0].to_vec(), + msg: parts[1].to_vec(), + }) +} + +fn parse_array(header: Vec<u8>, reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + let size = as_usize(&header[1..header.len() - 1]).map_err(|_| Error::InvalidHeaderSize)?; + + let body: Vec<Expr> = (0..size) + .map(|_| Expr::from_bytes(reader)) + .collect::<Result<Vec<Expr>, _>>()?; + + Ok(Expr::Array(body)) +} + +fn parse_bool(header: Vec<u8>, _reader: &mut BufReader<&TcpStream>) -> Result<Expr, Error> { + if header.len() < 2 { + return Err(Error::NotEnoughBytes); + } + + match header[1] { + b't' => Ok(Expr::Bool(true)), + b'f' => Ok(Expr::Bool(false)), + other => Err(Error::InvalidChar(other)), + } +} + +// Into + +impl Into<Expr> for ip::Block { + fn into(self) -> Expr { + Expr::Line(self.into()) + } +} + +impl Into<Expr> for &ip::Block { + fn into(self) -> Expr { + Expr::Line(self.into()) + } +} + +impl Into<Expr> for &str { + fn into(self) -> Expr { + Expr::Blob(self.into()) + } +} + +impl Into<Expr> for String { + fn into(self) -> Expr { + Expr::Blob(self.into()) + } +} + +impl Into<Expr> for u8 { + fn into(self) -> Expr { + Expr::Number(self as u128) + } +} + +impl Into<Expr> for Vec<(&str, String)> { + fn into(self) -> Expr { + Expr::Table( + self.into_iter() + .map(|(x, y)| (x.into(), y.into())) + .collect(), + ) + } +} + +impl<T: Into<Expr>> Into<Expr> for Vec<T> { + fn into(self) -> Expr { + Expr::Array(self.into_iter().map(|x| x.into()).collect()) + } +} + +// Helpers + +fn skip_lf(reader: &mut BufReader<&TcpStream>) -> Result<(), Error> { + let mut lf = vec![0u8; 1]; + + reader + .read_exact(&mut lf) + .map_err(|_| Error::NotEnoughBytes)?; + + if lf != vec![b'\n'] { + return Err(Error::ExpectedLF(lf[0])); + } + Ok(()) +} + +fn as_usize(array: &[u8]) -> Result<usize, ParseIntError> { + std::str::from_utf8(array).unwrap().parse::<usize>() +} + +fn as_u128(array: &[u8]) -> Result<u128, ParseIntError> { + std::str::from_utf8(array).unwrap().parse::<u128>() +} diff --git a/src/common/ip.rs b/src/common/ip.rs new file mode 100644 index 0000000..8c1e6b6 --- /dev/null +++ b/src/common/ip.rs @@ -0,0 +1,195 @@ +#![allow(clippy::from_over_into)] +use crate::common::expr::Expr; +use memchr::memchr; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt; +use std::fmt::Debug; +use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; +use std::str::FromStr; + +pub const BITMASK: u128 = 0x8000_0000_0000_0000_0000_0000_0000_0000; +pub const BITCOUNT: u8 = 128; + +pub type Addr = IpAddr; + +#[derive(Debug)] +pub enum Error { + InvalidAddr(String), + InvalidBlock(String), + InvalidMaskLength(String), + InvalidNibblesCount(usize), + MaskLengthTooLarge(u8), +} + +#[derive(Copy, Clone, Eq)] +pub struct Block { + pub addr: u128, + pub mlen: u8, +} + +pub type Meta = HashMap<String, Expr>; + +impl Block { + fn new(addr: IpAddr, mlen: u8) -> Self { + let (ipv6_addr, ipv6_mlen) = match addr { + IpAddr::V4(addr) => { + assert!(mlen <= 32); + (addr.to_ipv6_mapped(), mlen + 96) + } + IpAddr::V6(addr) => { + assert!(mlen <= 128); + (addr, mlen) + } + }; + + let segments = ipv6_addr.segments(); + let addr_bits = ((segments[0] as u128) << 112) + | ((segments[1] as u128) << 96) + | ((segments[2] as u128) << 80) + | ((segments[3] as u128) << 64) + | ((segments[4] as u128) << 48) + | ((segments[5] as u128) << 32) + | ((segments[6] as u128) << 16) + | (segments[7] as u128); + + Self { + addr: addr_bits, + mlen: ipv6_mlen, + } + } + + pub fn from_bytestring(bytestring: &[u8]) -> Result<Self, Error> { + match memchr(b'/', bytestring) { + Some(pos) => { + let addr = IpAddr::parse_ascii(&bytestring[0..pos]).map_err(|_| { + Error::InvalidAddr( + std::str::from_utf8(&bytestring[0..pos]) + .unwrap() + .to_string(), + ) + })?; + + let mut mlen = 0; + for (i, byte) in bytestring[pos + 1..bytestring.len()] + .iter() + .rev() + .enumerate() + { + if *byte < 48 || *byte > 57 { + return Err(Error::InvalidBlock("Invalid1".to_string())); + } + mlen += (byte - 48) * 10u8.pow(i as u32); + } + + Ok(Self::new(addr, mlen)) + } + None => Err(Error::InvalidBlock( + std::str::from_utf8(bytestring).unwrap().to_string(), + )), + } + } + + fn addr(&self) -> IpAddr { + addr_from_bits(self.addr) + } + + pub fn start(&self) -> IpAddr { + let mask = 0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff ^ ((1 << (128 - self.mlen)) - 1); + addr_from_bits(self.addr & mask) + } + pub fn end(&self) -> IpAddr { + let mask = !(0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff ^ ((1 << (128 - self.mlen)) - 1)); + addr_from_bits(self.addr | mask) + } + + pub fn size(&self) -> u128 { + 2u128.pow(128 - self.mlen as u32) + } +} + +fn addr_from_bits(bits: u128) -> IpAddr { + if bits & 0xffff_ffff_ffff_ffff_ffff_ffff_0000_0000 == 0xffff_0000_0000 { + IpAddr::V4(Ipv4Addr::from((bits & 0xff_ff_ff_ff) as u32)) + } else { + IpAddr::V6(Ipv6Addr::from(bits)) + } +} + +impl FromStr for Block { + type Err = Error; + fn from_str(s: &str) -> Result<Block, Error> { + let parts: Vec<&str> = s.splitn(2, '/').collect(); + if parts.len() < 2 { + return Err(Error::InvalidBlock(s.to_string())); + } + + let addr = + IpAddr::from_str(parts[0]).map_err(|_| Error::InvalidAddr(parts[0].to_string()))?; + let mlen = parts[1] + .parse::<u8>() + .map_err(|_| Error::InvalidMaskLength(parts[1].to_string()))?; + + Ok(Self::new(addr, mlen)) + } +} + +impl fmt::Display for Block { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let addr = self.addr(); + let mlen = if addr.is_ipv4() { + self.mlen - 96 + } else { + self.mlen + }; + write!(f, "{}/{}", addr, mlen) + } +} + +impl fmt::Debug for Block { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let addr = self.addr(); + let mlen = if addr.is_ipv4() { + self.mlen - 96 + } else { + self.mlen + }; + write!(f, "{}/{}", addr, mlen) + } +} + +impl Ord for Block { + fn cmp(&self, other: &Self) -> Ordering { + (self.addr >> (BITCOUNT - self.mlen) as usize) + .cmp(&(other.addr >> ((BITCOUNT - other.mlen) % 32) as usize)) + } +} + +impl PartialEq for Block { + fn eq(&self, other: &Self) -> bool { + self.addr == other.addr && self.mlen == other.mlen + } +} + +impl PartialOrd for Block { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some( + (self.addr >> (BITCOUNT - self.mlen) as usize) + .cmp(&(other.addr >> ((BITCOUNT - other.mlen) % 32) as usize)), + ) + } +} + +impl Into<Vec<u8>> for Block { + fn into(self) -> Vec<u8> { + (&self).into() + } +} + +impl Into<Vec<u8>> for &Block { + fn into(self) -> Vec<u8> { + self.to_string().into() + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..d5ad193 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,7 @@ +pub mod db; +pub mod expr; +pub mod ip; +pub mod query; +pub mod sized_buffer; +pub mod term; +pub mod trie; diff --git a/src/common/query.rs b/src/common/query.rs new file mode 100644 index 0000000..00d79b2 --- /dev/null +++ b/src/common/query.rs @@ -0,0 +1,252 @@ +use crate::common::db::Database; +use crate::common::expr::Expr; +use crate::common::ip; +use std::fmt; + +#[derive(Debug)] +pub enum Error { + Empty, + Invalid, + IP(ip::Error), + ExpectedArray(Expr), + ExpectedBlob(Expr), + InvalidArgsCount { + cmd: String, + expected: usize, + actual: usize, + }, +} + +#[derive(Debug)] +pub enum Query { + Children { ip_block: ip::Block }, + Del { ip_block: ip::Block }, + Echo { msg: Vec<u8> }, + Exists { ip_block: ip::Block }, + Get { ip_block: ip::Block }, + Info { ip_block: ip::Block }, + Parent { ip_block: ip::Block }, + Set { ip_block: ip::Block, value: Expr }, +} + +impl Query { + pub fn from_expr(expr: Expr) -> Result<Self, Error> { + let mut xs = if let Expr::Array(xs) = expr { + xs + } else { + return Err(Error::ExpectedArray(expr)); + }; + + if xs.is_empty() { + return Err(Error::Empty); + } + let cmd_expr = xs.remove(0); + let cmd = if let Expr::Blob(cmd) = cmd_expr { + cmd + } else { + return Err(Error::ExpectedBlob(cmd_expr)); + }; + + match &*cmd { + b"ECHO" => build_echo(xs), // ECHO str + b"SET" => build_set(xs), // SET ::/32 Expr + b"GET" => build_get(xs), // GET ::/32 + b"DEL" => build_del(xs), // DEL ::/32 + b"EXISTS" => build_exists(xs), // EXISTS ::/32 + b"INFO" => build_info(xs), // EXISTS ::/32 + b"PARENT" => build_parent(xs), // SUP ::/32 + b"CHILDREN" => build_children(xs), // SUB ::/32 + _ => Err(Error::Invalid), + } + } + + pub fn exec(&self, db: &mut Database) -> Expr { + match self { + Self::Echo { msg } => Expr::Blob(msg.clone()), + Self::Set { ip_block, value: _ } => { + db.set(*ip_block, std::collections::HashMap::new()); + Expr::Blob(b"OK".to_vec()) + } + Self::Get { ip_block } => db.get(ip_block).map(|x| x.into()).unwrap_or(Expr::Null), + Self::Del { ip_block } => db + .del(ip_block) + .map_or(Expr::Null, |_| Expr::Blob(b"OK".to_vec())), + Self::Exists { ip_block } => Expr::Bool(db.exists(ip_block)), + Self::Info { ip_block } => { + let info: Vec<(&str, String)> = db.info(ip_block).into(); + info.into() + } + Self::Parent { ip_block } => { + db.parent(ip_block).map(|x| x.into()).unwrap_or(Expr::Null) + } + Self::Children { ip_block } => db.children(ip_block).into(), + } + } +} + +// Display + +impl fmt::Display for Query { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Query::Set { ip_block, value } => { + format!("SET {} {}", ip_block, value) + } + Query::Get { ip_block } => { + format!("GET {}", ip_block) + } + Query::Echo { msg } => { + format!("ECHO {:?}", String::from_utf8_lossy(msg).into_owned()) + } + Query::Del { ip_block } => { + format!("DEL {}", ip_block) + } + Query::Exists { ip_block } => { + format!("EXISTS {}", ip_block) + } + Query::Info { ip_block } => { + format!("INFO {}", ip_block) + } + Query::Parent { ip_block } => { + format!("PARENT {}", ip_block) + } + Query::Children { ip_block } => { + format!("CHILDREN {}", ip_block) + } + }; + + write!(f, "{}", res) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = match self { + Error::Empty => "Received an empty query.".to_string(), + Error::ExpectedArray(got) => format!("Expected an ARRAY, got: {:?}.", got), + Error::ExpectedBlob(got) => format!("Expected a BLOB, got: {:?}.", got), + Error::IP(e) => format!("IP error: {e:?}"), + Error::Invalid => "Received an invalid query.".to_string(), + Error::InvalidArgsCount { + cmd, + expected, + actual, + } => format!( + "Invalid number of arguments for {:?}. Expected {} but got {}.", + cmd, expected, actual + ), + }; + + write!(f, "{}", res) + } +} + +// Builders + +fn build_echo(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("ECHO", &args, 1)?; + + let expr = args.remove(0); + let msg = if let Expr::Blob(k) = expr { + k + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Echo { msg }) +} + +fn build_get(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("GET", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Get { ip_block }) +} +fn build_set(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("SET", &args, 2)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Set { + ip_block, + value: args.remove(0), + }) +} +fn build_del(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("DEL", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Del { ip_block }) +} +fn build_exists(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("EXISTS", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Exists { ip_block }) +} +fn build_info(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("INFO", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Info { ip_block }) +} +fn build_parent(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("PARENT", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Parent { ip_block }) +} +fn build_children(mut args: Vec<Expr>) -> Result<Query, Error> { + count_args("CHILDREN", &args, 1)?; + + let expr = args.remove(0); + let ip_block = if let Expr::Blob(ref addr) = expr { + ip::Block::from_bytestring(addr).map_err(Error::IP)? + } else { + return Err(Error::ExpectedBlob(expr)); + }; + Ok(Query::Children { ip_block }) +} + +// Helpers + +fn count_args<T>(cmd: &str, args: &Vec<T>, expected_count: usize) -> Result<(), Error> { + let args_count = args.len(); + if args_count != expected_count { + return Err(Error::InvalidArgsCount { + cmd: cmd.to_string(), + expected: expected_count, + actual: args_count, + }); + } + + Ok(()) +} diff --git a/src/common/sized_buffer.rs b/src/common/sized_buffer.rs new file mode 100644 index 0000000..ebfcab5 --- /dev/null +++ b/src/common/sized_buffer.rs @@ -0,0 +1,113 @@ +use memchr::memchr; +use std::io; +use std::io::Write; + +pub struct SizedBuffer<T> { + len: usize, + pos: usize, + pub inner: T, +} + +#[derive(Debug)] +pub enum Error { + IO(io::Error), + ReadPastEnd, + WrongByte { expected: u8, got: u8 }, +} + +impl<const N: usize> SizedBuffer<[u8; N]> { + pub fn new() -> Self { + Self { + len: 0, + pos: 0, + inner: [0; N], + } + } + + pub fn append(&mut self, bytes: &mut [u8]) -> Result<(), Error> { + self.write_all(bytes).map_err(Error::IO) + } + + pub fn peek(&self) -> u8 { + self.inner[self.pos] + } + + pub fn read(&mut self) -> Result<u8, Error> { + self.seek(1)?; + Ok(self.inner[self.pos - 1]) + } + + /* Reads n bytes from the buffer, and moves the cursor right after the last byte read. + * + * Example with n=3: + * + * pos = 2 + * v + * [0,1,2,3,4,5,6,7,8,9] + * |-----|^ + * read new pos + */ + pub fn read_count(&mut self, n: usize) -> Result<&[u8], Error> { + self.seek(n)?; + Ok(&self.inner[self.pos - n..self.pos]) + } + + pub fn read_count_unchecked(&mut self, n: usize) -> &[u8] { + self.seek_unchecked(n); + &self.inner[self.pos - n..self.pos] + } + + pub fn skip_byte(&mut self, byte: u8) -> Result<(), Error> { + let x = self.read()?; + if x != byte { + return Err(Error::WrongByte { + expected: byte, + got: x, + }); + } + Ok(()) + } + + pub fn seek(&mut self, n: usize) -> Result<(), Error> { + /* Allows to seek 1 past the end of the buffer. Subsequent reads will still fail but this + * allows to read the last byte. Otherwise trying to read the last byte would fail after + * reading it, when trying to advance the position of the cursor. */ + if self.pos + n > self.len { + return Err(Error::ReadPastEnd); + } + self.pos += n; + + Ok(()) + } + + pub fn seek_unchecked(&mut self, n: usize) { + self.pos += n; + } + + pub fn find_offset(&mut self, x: u8) -> Option<usize> { + memchr(x, &self.inner[self.pos..self.len]) + } + + pub fn empty(&self) -> bool { + self.len == self.pos + } + + pub fn reset(&mut self) { + self.len = 0; + self.pos = 0; + } +} + +impl<const N: usize> Write for SizedBuffer<[u8; N]> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let amt = (&mut self.inner[self.len..]).write(buf)?; + self.len += amt; + Ok(amt) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/src/common/term.rs b/src/common/term.rs new file mode 100644 index 0000000..44fd819 --- /dev/null +++ b/src/common/term.rs @@ -0,0 +1,517 @@ +#![allow(clippy::manual_range_contains)] +#![allow(clippy::type_complexity)] +use libc::BRKINT; +use libc::CS8; +use libc::ECHO; +use libc::ICANON; +use libc::ICRNL; +use libc::IEXTEN; +use libc::INPCK; +use libc::ISIG; +use libc::ISTRIP; +use libc::IXON; +use libc::OPOST; +use libc::TCSAFLUSH; +use libc::TIOCGWINSZ; +use libc::VMIN; +use libc::VTIME; +use log::trace; +use std::cmp::max; +use std::cmp::min; +use std::fs::File; +use std::io::Read; +use std::io::Write; +use std::mem; +use std::os::unix::io::AsRawFd; + +enum Key { + Null = 0, + CtrlA = 1, + CtrlB = 2, + CtrlC = 3, + CtrlD = 4, + CtrlE = 5, + CtrlF = 6, + CtrlH = 8, + Tab = 9, + CtrlK = 11, + CtrlL = 12, + Enter = 13, + CtrlN = 14, + CtrlP = 16, + CtrlT = 20, + CtrlU = 21, + CtrlW = 23, + Esc = 27, + Backspace = 127, +} + +pub struct Term { + initial_termios: Option<libc::termios>, + tty_in: File, + tty_out: File, + history: Vec<Vec<u8>>, + history_cursor: usize, + column: usize, + line: Vec<u8>, + line_cursor: usize, + prompt: Option<Vec<u8>>, + hints_callback: Option<fn(&[u8]) -> &[u8]>, + columns: usize, + max_rows: usize, +} + +#[derive(Debug)] +struct Position { + column: usize, + #[allow(dead_code)] + row: usize, +} + +impl Term { + pub fn new() -> Self { + let tty_out = File::options() + .read(true) + .write(true) + .open("/dev/tty") + .unwrap(); + let tty_in = tty_out.try_clone().unwrap(); + + Self { + initial_termios: None, + tty_out, + tty_in, + history: Vec::new(), + history_cursor: 0, + column: 0, + line: Vec::new(), + line_cursor: 0, + prompt: None, + hints_callback: None, + columns: 0, + max_rows: 0, + } + } + + pub fn setup_hints(&mut self, hints_callback: fn(&[u8]) -> &[u8]) { + self.hints_callback = Some(hints_callback); + } + + pub fn setup_prompt(&mut self, prompt: &[u8]) { + self.prompt = Some(prompt.to_vec()); + } + + pub fn edit(&mut self) -> Option<Vec<u8>> { + self.enable_raw_mode(); + + self.columns = self.get_columns(); + self.column = self.get_cursor_position().unwrap().column; + + self.refresh_line(); + + loop { + let mut buffer: [u8; 1] = [0]; + + if self.tty_in.read(&mut buffer).unwrap_or(0) != 1 { + break; + } + + match buffer[0] { + x if (x == Key::Tab as u8) => { + todo!("PRESSED TAB"); + } + x if (x == Key::CtrlC as u8) => { + self.tty_out.write_all(b"\r\n").unwrap(); + self.disable_raw_mode(); + std::process::exit(1); + } + x if (x == Key::Enter as u8) => { + let line = self.line.clone(); + + self.history_cursor = self.history.len(); + self.history.push(self.line.clone()); + + self.line.clear(); + self.line_cursor = 0; + + self.tty_out.write_all(b"\r\n").unwrap(); + self.disable_raw_mode(); + + return Some(line); + } + x if (x == Key::Backspace as u8) => { + if self.line_cursor < 1 { + continue; + } + + self.line_cursor -= 1; + self.line.remove(self.line_cursor); + + self.refresh_line(); + } + x if (x == Key::Null as u8) => {} + x if (x == Key::CtrlD as u8) => { + if !self.line.is_empty() { + self.delete_right(); + } else { + self.tty_out.write_all(b"\r\n").unwrap(); + self.disable_raw_mode(); + std::process::exit(1); + } + } + x if (x == Key::CtrlA as u8) => self.move_to_start(), + x if (x == Key::CtrlE as u8) => self.move_to_end(), + x if (x == Key::CtrlB as u8) => self.move_left(), + x if (x == Key::CtrlF as u8) => self.move_right(), + x if (x == Key::CtrlH as u8) => self.delete_left(), + x if (x == Key::CtrlK as u8) => self.delete_to_end(), + x if (x == Key::CtrlL as u8) => { + self.clear_screen(); + self.refresh_line(); + } + // Next line in history + x if (x == Key::CtrlN as u8) => self.history_next(), + // Previous line in history + x if (x == Key::CtrlP as u8) => self.history_prev(), + x if (x == Key::CtrlT as u8) => self.swap_one(), + x if (x == Key::CtrlU as u8) => self.delete_to_start(), + x if (x == Key::CtrlW as u8) => self.delete_word_left(), + x if (x == Key::Esc as u8) => { + let mut seq: [u8; 1] = [0]; + if self.tty_in.read(&mut seq).unwrap_or(0) != 1 { + continue; + } + if seq[0] != b'[' { + // FIXME: Handle Esc+Key / "Esc 0" sequences + continue; + } + if self.tty_in.read(&mut seq).unwrap_or(0) != 1 { + break; + } + if seq[0] >= b'0' && seq[0] <= b'9' { + // FIXME: Handle extended escape sequence + continue; + } + match seq[0] { + // Up + b'A' => self.history_prev(), + // Down + b'B' => self.history_next(), + // Right + b'C' => self.move_right(), + // Left + b'D' => self.move_left(), + // Home + b'H' => self.move_to_start(), + // End + b'F' => self.move_to_end(), + _ => break, + } + } + x => { + self.line.insert(self.line_cursor, x); + self.line_cursor += 1; + self.refresh_line(); + } + } + + self.tty_out.flush().unwrap(); + } + + self.disable_raw_mode(); + + None + } + + fn history_next(&mut self) { + if self.history.is_empty() { + return; + } + + self.line = self.history[self.history_cursor].clone(); + self.line_cursor = self.line.len(); + + self.refresh_line(); + + self.history_cursor = min(self.history_cursor + 1, self.history.len() - 1); + } + + fn history_prev(&mut self) { + if self.history.is_empty() { + return; + } + + // FIXME: Save current line before replacing it with line for the history. + + self.line = self.history[self.history_cursor].clone(); + self.line_cursor = self.line.len(); + + self.refresh_line(); + + self.history_cursor = self.history_cursor.saturating_sub(1); + } + + fn move_right(&mut self) { + self.line_cursor = min(self.line_cursor + 1, self.line.len()); + self.refresh_line(); + } + + fn move_left(&mut self) { + self.line_cursor = self.line_cursor.saturating_sub(1); + self.refresh_line(); + } + + fn move_to_start(&mut self) { + self.line_cursor = 0; + self.refresh_line(); + } + fn move_to_end(&mut self) { + self.line_cursor = self.line.len(); + self.refresh_line(); + } + + fn delete_word_left(&mut self) { + if self.line_cursor == 0 { + return; + } + + let cursor = self.line_cursor; + + while self.line_cursor > 0 && self.line[self.line_cursor - 1] == b' ' { + self.line_cursor -= 1; + } + while self.line_cursor > 0 && self.line[self.line_cursor - 1] != b' ' { + self.line_cursor -= 1; + } + + self.line.drain(self.line_cursor..cursor); + self.refresh_line(); + } + + fn delete_left(&mut self) { + if self.line_cursor == 0 { + return; + } + + self.line.remove(self.line_cursor - 1); + self.line_cursor = self.line_cursor.saturating_sub(1); + self.refresh_line(); + } + + fn delete_right(&mut self) { + if self.line_cursor >= self.line.len() { + return; + } + + self.line.remove(self.line_cursor); + self.refresh_line(); + } + + fn delete_to_start(&mut self) { + if self.line_cursor == 0 { + return; + } + + self.line.drain(0..self.line_cursor); + self.line_cursor = 0; + self.refresh_line(); + } + + fn delete_to_end(&mut self) { + if self.line_cursor >= self.line.len() { + return; + } + + self.line.truncate(self.line_cursor); + self.refresh_line(); + } + + fn swap_one(&mut self) { + if self.line_cursor == 0 { + return; + } + + self.line_cursor = min(self.line_cursor + 1, self.line.len()); + self.line.swap(self.line_cursor - 2, self.line_cursor - 1); + self.refresh_line(); + } + + fn refresh_line(&mut self) { + let prompt_len = self.prompt.as_ref().map(|x| x.len()).unwrap_or(0); + + let rows = (self.line.len() + prompt_len + self.columns - 1) / self.columns; + let prev_rows = self.max_rows; + let curr_row = (self.line_cursor + prompt_len + self.columns - 1) / self.columns; + self.column = (prompt_len + self.line_cursor) % self.columns; + + self.max_rows = max(self.max_rows, rows); + + // + // Clear rows + // + + // Go down to the last row if the cursor is not already there. + if curr_row < prev_rows { + trace!("Go down {}", prev_rows - curr_row); + write!(self.tty_out, "\x1b[{}B", prev_rows - curr_row).unwrap(); + } + + // Clear each row and go up. + for _ in 1..prev_rows { + trace!("Clear row, go up"); + self.tty_out.write_all(b"\r\x1b[0K\x1b[1A").unwrap(); + } + + // Clear the top row. + trace!("Clear top row"); + self.tty_out.write_all(b"\r\x1b[0K").unwrap(); + + // + // Rewrite prompt and line + // + + if let Some(prompt) = &self.prompt { + self.tty_out.write_all(prompt).unwrap(); + } + + self.tty_out.write_all(&self.line).unwrap(); + + // + // Display hints if any + // + + if self.hints_callback.is_some() && !self.line.is_empty() { + let hints = self.hints_callback.unwrap()(&self.line); + + if !hints.is_empty() { + // Column for the last character on the current row. + let max_column = (prompt_len + self.line.len()) % (self.columns + 1); + let max_len = min(self.columns - max_column, hints.len()); + + self.tty_out.write_all(b"\x1b[2m").unwrap(); + self.tty_out.write_all(&hints[..max_len]).unwrap(); + self.tty_out.write_all(b"\x1b[22m").unwrap(); + } + } + + // + // Move cursor to the right position. + // + + if self.column == 0 + && self.line.len() == self.line_cursor + && (prompt_len + self.line_cursor) > 0 + { + // We reached the end of a row, insert a new line and go back to the first row. + trace!("Insert linefeed"); + self.max_rows = max(self.max_rows, rows + 1); + self.tty_out.write_all(b"\r\n").unwrap(); + } else if self.column == 0 { + self.tty_out.write_all(b"\r").unwrap(); + } else { + // Move to the right column. + write!(self.tty_out, "\r\x1b[{}C", self.column).unwrap(); + + if rows > curr_row { + // Move up to the current row. + write!(self.tty_out, "\x1b[{}A", rows - curr_row).unwrap(); + } + } + } + + fn enable_raw_mode(&mut self) { + let mut termios = mem::MaybeUninit::uninit(); + + if unsafe { libc::tcgetattr(self.tty_out.as_raw_fd(), termios.as_mut_ptr()) } == -1 { + panic!("Failed to enable raw mode"); + } + + self.initial_termios = Some(unsafe { termios.assume_init() }); + let mut raw = self.initial_termios.unwrap(); + + raw.c_iflag &= !(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + raw.c_oflag &= !OPOST; + raw.c_cflag |= CS8; + raw.c_lflag &= !(ECHO | ICANON | IEXTEN | ISIG); + raw.c_cc[VMIN] = 1; + raw.c_cc[VTIME] = 0; + + if unsafe { libc::tcsetattr(self.tty_out.as_raw_fd(), TCSAFLUSH, &raw) } < 0 { + panic!("Failed to enable raw mode"); + } + } + + fn disable_raw_mode(&mut self) { + if self.initial_termios.is_some() + && unsafe { + libc::tcsetattr( + self.tty_out.as_raw_fd(), + TCSAFLUSH, + &self.initial_termios.unwrap(), + ) + } != -1 + { + self.initial_termios = None; + } + } + + fn get_columns(&mut self) -> usize { + let ws = libc::winsize { + ws_row: 0, + ws_col: 0, + ws_xpixel: 0, + ws_ypixel: 0, + }; + + if unsafe { libc::ioctl(self.tty_out.as_raw_fd(), TIOCGWINSZ, &ws) } < 0 || ws.ws_col == 0 { + todo!("Query the terminal for columns count."); + } + + ws.ws_col as usize + } + + fn get_cursor_position(&mut self) -> Option<Position> { + self.enable_raw_mode(); + + self.tty_out.write_all(b"\x1b[6n").unwrap(); + self.tty_out.flush().unwrap(); + + let mut buffer: [u8; 32] = [0; 32]; + + let mut sep_index = None; + let mut index = 0; + + let mut bytes = (&self.tty_in).bytes(); + while let Some(Ok(byte)) = bytes.next() { + if byte == b'R' || index >= buffer.len() { + break; + } + if byte == b';' { + sep_index = Some(index); + } + + buffer[index] = byte; + index += 1; + } + + if buffer[0] != Key::Esc as u8 || buffer[1] != b'[' { + return None; + } + + sep_index.map(|sep_index| Position { + row: as_usize(&buffer[2..sep_index]), + column: as_usize(&buffer[sep_index + 1..index]), + }) + } + + fn clear_screen(&mut self) { + self.tty_out.write_all(b"\x1b[H\x1b[2J").unwrap(); + } +} + +fn as_usize(array: &[u8]) -> usize { + std::str::from_utf8(array) + .unwrap() + .parse::<usize>() + .unwrap() +} diff --git a/src/common/trie.rs b/src/common/trie.rs new file mode 100644 index 0000000..381d2b5 --- /dev/null +++ b/src/common/trie.rs @@ -0,0 +1,151 @@ +#![allow(dead_code)] +use crate::common::ip::Block; +use crate::common::ip::Meta; +use crate::common::ip::BITMASK; + +pub struct Trie { + root: Node, +} + +#[derive(Debug)] +struct Node { + block: Option<Block>, + left: Option<Box<Node>>, + right: Option<Box<Node>>, +} + +impl Node { + pub const NULL: Node = Self { + block: None, + left: None, + right: None, + }; + + fn new(block: Option<Block>) -> Self { + Node { + block, + left: None, + right: None, + } + } + + fn is_leaf(&self) -> bool { + self.left.is_none() && self.right.is_none() + } +} + +impl Trie { + pub fn new() -> Self { + Trie { root: Node::NULL } + } + + pub fn insert(&mut self, block: Block, _meta: Meta) { + let mut node = &mut self.root; + let mut addr = block.addr; + + for _ in 0..block.mlen { + if addr & BITMASK == 0 { + if node.left.is_none() { + node.left = Some(Box::new(Node::NULL)); + } + node = node.left.as_deref_mut().unwrap(); + } else { + if node.right.is_none() { + node.right = Some(Box::new(Node::NULL)); + } + node = node.right.as_deref_mut().unwrap(); + } + addr <<= 1; + } + + node.block = Some(block); + } + + pub fn find(&self, needle: &Block) -> Option<Block> { + self.find_block_node(needle)?.block + } + + pub fn parent(&self, needle: &Block) -> Option<Block> { + let mut node = &self.root; + let mut addr = needle.addr; + let mut parent = None; + + for _ in 0..needle.mlen { + if node.block.is_some() { + parent = node.block; + } + + if addr & BITMASK == 0 { + if node.left.is_some() { + node = node.left.as_deref().unwrap(); + } else { + break; + } + } else if node.right.is_some() { + node = node.right.as_deref().unwrap(); + } else { + break; + } + addr <<= 1; + } + + parent + } + + pub fn children(&self, parent: &Block) -> Vec<Block> { + let root: &Node = match self.find_block_node(parent) { + Some(root) => root, + None => return vec![], + }; + + let mut children: Vec<Block> = Vec::new(); + let mut stack: Vec<&Node> = vec![root]; + + while let Some(node) = stack.pop() { + if node.is_leaf() && node.block.is_some() && node.block.unwrap() != root.block.unwrap() + { + children.push(node.block.unwrap()); + continue; + } + + if node.left.is_some() { + stack.push(node.left.as_deref().unwrap()); + } + if node.right.is_some() { + stack.push(node.right.as_deref().unwrap()); + } + } + + children + } + + fn find_block_node(&self, needle: &Block) -> Option<&Node> { + let mut node = &self.root; + let mut addr = needle.addr; + let mut match_ = None; + + for _ in 0..(needle.mlen + 1) { + if node.block.is_some() { + match_ = Some(node); + } + + if addr & BITMASK == 0 { + if node.left.is_some() { + node = node.left.as_deref().unwrap(); + } else { + break; + } + } else if node.right.is_some() { + node = node.right.as_deref().unwrap(); + } else { + break; + } + addr <<= 1; + } + + match match_?.block { + Some(block) if block == *needle => match_, + _ => None, + } + } +} |