diff --git a/Cargo.toml b/Cargo.toml index 97b871f..6cd7cc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ name = "mini-redis" version = "0.1.0" [dependencies] +anyhow = "1.0.27" atoi = "0.3.2" bytes = "0.5.4" clap = { git = "https://github.com/clap-rs/clap/" } @@ -12,4 +13,3 @@ tokio = { git = "https://github.com/tokio-rs/tokio", features = ["full"] } tracing = "0.1.13" tracing-futures = { version = "0.2.3", features = ["tokio"] } tracing-subscriber = "0.2.2" -anyhow = "1.0.26" diff --git a/dump.rdb b/dump.rdb new file mode 100644 index 0000000..2d6c078 Binary files /dev/null and b/dump.rdb differ diff --git a/src/bin/client.rs b/src/bin/client.rs index 855b4d0..38b7b00 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -1,10 +1,9 @@ -use bytes::Bytes; use clap::Clap; -use mini_redis::{client, DEFAULT_PORT}; -use std::{io, str}; +use mini_redis::{client, cmd::Set, DEFAULT_PORT}; +use std::str; #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> Result<(), Box> { let cli = Cli::parse(); let port = cli.port.unwrap_or(DEFAULT_PORT.to_string()); let mut client = client::connect(&format!("127.0.0.1:{}", port)).await?; @@ -18,7 +17,16 @@ async fn main() -> io::Result<()> { } Ok(()) } - Client::Set { key, value } => client.set(&key, Bytes::from(value)).await, + Client::Set(opts) => match client.set_with_opts(opts).await { + Ok(_) => { + println!("OK"); + Ok(()) + } + Err(e) => { + eprintln!("{}", e); + Err(e) + } + }, } } @@ -33,8 +41,9 @@ struct Cli { #[derive(Clap, Debug)] enum Client { - #[clap(about = "Gets a value associated with a key")] + /// Gets a value associated with a key Get { key: String }, - #[clap(about = "Associates a value with a key")] - Set { key: String, value: String }, + + /// Associates a value with a key + Set(Set), } diff --git a/src/bin/server.rs b/src/bin/server.rs index 58865f3..df18fdf 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,6 +1,6 @@ +use anyhow::{anyhow, Result}; use clap::Clap; use mini_redis::{server, DEFAULT_PORT}; -use anyhow::{anyhow, Result}; #[tokio::main] pub async fn main() -> Result<()> { diff --git a/src/client.rs b/src/client.rs index 57860ea..e87fd6a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,14 @@ -use crate::Connection; +use crate::{ + cmd::{ + utils::{bytes_from_str, duration_from_ms_str}, + Set, + }, + frame::Frame, + Command, Connection, +}; use bytes::Bytes; -use std::io; +use std::io::{Error, ErrorKind}; use tokio::net::{TcpStream, ToSocketAddrs}; /// Mini asynchronous Redis client @@ -9,7 +16,7 @@ pub struct Client { conn: Connection, } -pub async fn connect(addr: T) -> io::Result { +pub async fn connect(addr: T) -> Result> { let socket = TcpStream::connect(addr).await?; let conn = Connection::new(socket); @@ -17,11 +24,53 @@ pub async fn connect(addr: T) -> io::Result { } impl Client { - pub async fn get(&mut self, key: &str) -> io::Result> { + pub async fn get(&mut self, key: &str) -> Result, Box> { unimplemented!(); } - pub async fn set(&mut self, key: &str, val: Bytes) -> io::Result<()> { - unimplemented!(); + pub async fn set(&mut self, key: &str, value: &str) -> Result<(), Box> { + let opts = Set { + key: key.to_string(), + value: bytes_from_str(value), + expire: None, + }; + self.set_with_opts(opts).await + } + + pub async fn set_with_expiration( + &mut self, + key: &str, + value: &str, + expiration: &str, + ) -> Result<(), Box> { + let opts = Set { + key: key.to_string(), + value: bytes_from_str(value), + expire: Some(duration_from_ms_str(expiration)?), + }; + self.set_with_opts(opts).await + } + + pub async fn set_with_opts(&mut self, opts: Set) -> Result<(), Box> { + let frame = Command::Set(opts).into_frame()?; + self.conn.write_frame(&frame).await?; + let response = self.conn.read_frame().await?; + if let Some(response) = response { + match response { + Frame::Simple(response) => { + if response == "OK" { + Ok(()) + } else { + Err("unexpected response from server".into()) + } + } + _ => Err("unexpected response from server".into()), + } + } else { + Err(Box::new(Error::new( + ErrorKind::ConnectionReset, + "connection reset by server", + ))) + } } } diff --git a/src/cmd/get.rs b/src/cmd/get.rs index 3a36d10..f16c0f6 100644 --- a/src/cmd/get.rs +++ b/src/cmd/get.rs @@ -1,4 +1,4 @@ -use crate::{Connection, Frame, Db, Parse, ParseError}; +use crate::{Connection, Db, Frame, Parse, ParseError}; use std::io; use tracing::{debug, instrument}; @@ -13,13 +13,13 @@ impl Get { // with their debug implementations // see https://docs.rs/tracing/0.1.13/tracing/attr.instrument.html #[instrument] - pub(crate) fn parse(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> Result { let key = parse.next_string()?; // adding this debug event allows us to see what key is parsed // the ? sigil tells `tracing` to use the `Debug` implementation // get parse events can be filtered by running - // RUST_LOG=mini_redis::cmd::get[parse]=debug cargo run --bin server + // RUST_LOG=mini_redis::cmd::get[parse_frames]=debug cargo run --bin server // see https://docs.rs/tracing/0.1.13/tracing/#recording-fields debug!(?key); diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index 8bfaf32..647e55e 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -10,9 +10,12 @@ pub use set::Set; mod subscribe; pub use subscribe::{Subscribe, Unsubscribe}; -use crate::{Connection, Frame, Db, Parse, ParseError, Shutdown}; +pub(crate) mod utils; + +use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown}; use std::io; +use tracing::instrument; #[derive(Debug)] pub(crate) enum Command { @@ -24,17 +27,18 @@ pub(crate) enum Command { } impl Command { + #[instrument] pub(crate) fn from_frame(frame: Frame) -> Result { let mut parse = Parse::new(frame)?; let command_name = parse.next_string()?.to_lowercase(); let command = match &command_name[..] { - "get" => Command::Get(Get::parse(&mut parse)?), - "publish" => Command::Publish(Publish::parse(&mut parse)?), - "set" => Command::Set(Set::parse(&mut parse)?), - "subscribe" => Command::Subscribe(Subscribe::parse(&mut parse)?), - "unsubscribe" => Command::Unsubscribe(Unsubscribe::parse(&mut parse)?), + "get" => Command::Get(Get::parse_frames(&mut parse)?), + "publish" => Command::Publish(Publish::parse_frames(&mut parse)?), + "set" => Command::Set(Set::parse_frames(&mut parse)?), + "subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?), + "unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?), _ => return Err(ParseError::UnknownCommand(command_name)), }; @@ -42,6 +46,14 @@ impl Command { Ok(command) } + pub(crate) fn into_frame(self) -> Result { + let frame = match self { + Command::Set(set) => set.into_frame(), + _ => unimplemented!(), + }; + Ok(frame) + } + pub(crate) async fn apply( self, db: &Db, diff --git a/src/cmd/publish.rs b/src/cmd/publish.rs index 4c72377..a7a760b 100644 --- a/src/cmd/publish.rs +++ b/src/cmd/publish.rs @@ -1,4 +1,4 @@ -use crate::{Connection, Frame, Db, Parse, ParseError}; +use crate::{Connection, Db, Frame, Parse, ParseError}; use bytes::Bytes; use std::io; @@ -10,7 +10,7 @@ pub struct Publish { } impl Publish { - pub(crate) fn parse(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> Result { let channel = parse.next_string()?; let message = parse.next_bytes()?; diff --git a/src/cmd/set.rs b/src/cmd/set.rs index 38e1806..ec7c13d 100644 --- a/src/cmd/set.rs +++ b/src/cmd/set.rs @@ -1,21 +1,32 @@ -use crate::cmd::{Parse, ParseError}; -use crate::{Connection, Frame, Db}; +use crate::cmd::{ + utils::{bytes_from_str, duration_from_ms_str}, + Parse, ParseError, +}; +use crate::{Connection, Db, Frame}; +use clap::Clap; use bytes::Bytes; use std::io; use std::time::Duration; use tracing::{debug, instrument}; -#[derive(Debug)] +#[derive(Clap, Debug)] pub struct Set { - key: String, - value: Bytes, - expire: Option, + /// the lookup key + pub(crate) key: String, + + /// the value to be stored + #[clap(parse(from_str = bytes_from_str))] + pub(crate) value: Bytes, + + /// duration in milliseconds + #[clap(parse(try_from_str = duration_from_ms_str))] + pub(crate) expire: Option, } impl Set { #[instrument] - pub(crate) fn parse(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> Result { use ParseError::EndOfStream; let key = parse.next_string()?; @@ -50,4 +61,12 @@ impl Set { debug!(?response); dst.write_frame(&response).await } + + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("set".as_bytes())); + frame.push_bulk(Bytes::from(self.key.into_bytes())); + frame.push_bulk(self.value); + frame + } } diff --git a/src/cmd/subscribe.rs b/src/cmd/subscribe.rs index 2ef8fad..9eed976 100644 --- a/src/cmd/subscribe.rs +++ b/src/cmd/subscribe.rs @@ -1,5 +1,5 @@ use crate::cmd::{Parse, ParseError}; -use crate::{Command, Connection, Frame, Db, Shutdown}; +use crate::{Command, Connection, Db, Frame, Shutdown}; use bytes::Bytes; use std::io; @@ -17,7 +17,7 @@ pub struct Unsubscribe { } impl Subscribe { - pub(crate) fn parse(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> Result { use ParseError::EndOfStream; // There must be at least one channel @@ -151,7 +151,7 @@ impl Subscribe { } impl Unsubscribe { - pub(crate) fn parse(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> Result { use ParseError::EndOfStream; // There may be no channels listed. diff --git a/src/cmd/utils.rs b/src/cmd/utils.rs new file mode 100644 index 0000000..8c1ccd3 --- /dev/null +++ b/src/cmd/utils.rs @@ -0,0 +1,11 @@ +use bytes::Bytes; +use std::time::Duration; + +pub(crate) fn duration_from_ms_str(src: &str) -> Result { + let millis = src.parse::()?; + Ok(Duration::from_millis(millis)) +} + +pub(crate) fn bytes_from_str(src: &str) -> Bytes { + Bytes::from(src.to_string()) +} diff --git a/src/conn.rs b/src/conn.rs index 00a730a..a8893ec 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -4,7 +4,7 @@ use bytes::{Buf, BytesMut}; use std::io::{self, Cursor}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufStream}; use tokio::net::TcpStream; - +use tracing::debug; #[derive(Debug)] pub(crate) struct Connection { stream: BufStream, @@ -23,6 +23,7 @@ impl Connection { use frame::Error::Incomplete; loop { + debug!(?self.buffer); let mut buf = Cursor::new(&self.buffer[..]); match Frame::check(&mut buf) { diff --git a/src/db.rs b/src/db.rs index 5404289..5a4858a 100644 --- a/src/db.rs +++ b/src/db.rs @@ -86,7 +86,8 @@ impl Db { // Only notify the worker task if the newly inserted expiration is the // **next** key to evict. In this case, the worker needs to be woken up // to update its state. - notify = state.next_expiration() + notify = state + .next_expiration() .map(|expiration| expiration > when) .unwrap_or(true); @@ -95,11 +96,14 @@ impl Db { }); // Insert the entry. - let prev = state.entries.insert(key, Entry { - id, - data: value, - expires_at, - }); + let prev = state.entries.insert( + key, + Entry { + id, + data: value, + expires_at, + }, + ); if let Some(prev) = prev { if let Some(when) = prev.expires_at { @@ -180,7 +184,10 @@ impl Shared { impl State { fn next_expiration(&self) -> Option { - self.expirations.keys().next().map(|expiration| expiration.0) + self.expirations + .keys() + .next() + .map(|expiration| expiration.0) } } diff --git a/src/frame.rs b/src/frame.rs index 21735cc..1d3b06d 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -6,7 +6,7 @@ use std::io::Cursor; use std::num::TryFromIntError; use std::string::FromUtf8Error; -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) enum Frame { Simple(String), Error(String), @@ -148,6 +148,23 @@ impl Frame { _ => unimplemented!(), } } + + pub(crate) fn try_as_str(&self) -> Result { + match &self { + Frame::Simple(response) => Ok(response.to_string()), + Frame::Error(response) => Err(response.to_string()), + Frame::Integer(response) => Ok(format!("{}", response)), + Frame::Bulk(response) => Ok(format!("{:?}", response)), + Frame::Null => Ok("(nil)".to_string()), + Frame::Array(response) => { + let mut msg = "".to_string(); + for item in response { + msg.push_str(&item.try_as_str()?) + } + Ok(msg) + } + } + } } fn peek_u8(src: &mut Cursor<&[u8]>) -> Result { diff --git a/src/lib.rs b/src/lib.rs index 9e7dc87..4d5a979 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ pub const DEFAULT_PORT: &str = "6379"; pub mod client; -mod cmd; +pub mod cmd; use cmd::Command; mod conn; diff --git a/src/parse.rs b/src/parse.rs index 5cfb571..d6c8125 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,7 +1,7 @@ use crate::Frame; use bytes::Bytes; -use std::{io, str, vec}; +use std::{error, fmt, io, str, vec}; /// Utility for parsing a command #[derive(Debug)] @@ -76,15 +76,23 @@ impl Parse { impl From for io::Error { fn from(src: ParseError) -> io::Error { - use ParseError::*; - - io::Error::new( - io::ErrorKind::Other, - match src { - EndOfStream => "end of stream".to_string(), - Invalid => "invalid".to_string(), - UnknownCommand(cmd) => format!("unknown command `{}`", cmd), - }, - ) + io::Error::new(io::ErrorKind::Other, format!("{}", src)) + } +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let msg = match self { + ParseError::EndOfStream => "end of stream".to_string(), + ParseError::Invalid => "invalid".to_string(), + ParseError::UnknownCommand(cmd) => format!("unknown command `{}`", cmd), + }; + write!(f, "{}", &msg) + } +} + +impl std::error::Error for ParseError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + None } }