diff --git a/src/client.rs b/src/client.rs index 45b7c16..93b7efb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -273,16 +273,16 @@ impl Subscriber { } // Read the response - for channel in &channels { + for _channel in &channels { let response = self.read_response().await?; match response { Frame::Array(ref frame) => match frame.as_slice() { - [unsubscribe, uchannel] - if &unsubscribe.to_string() == "unsubscribe" - && &uchannel.to_string() == channel => - { - self.subscribed_channels.remove(&uchannel.to_string()); - } + [unsubscribe, uchannel] if &unsubscribe.to_string() == "unsubscribe" => { + //unsubscribed channel should exist in the subscribed list at this point + if self.subscribed_channels.remove(&uchannel.to_string()) == false { + return Err(response.to_error()); + } + }, _ => return Err(response.to_error()), }, frame => return Err(frame.to_error()), diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index e83c15f..a437a78 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -10,6 +10,9 @@ pub use set::Set; mod subscribe; pub use subscribe::{Subscribe, Unsubscribe}; +mod unknown; +pub use unknown::Unknown; + use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown}; #[derive(Debug)] @@ -19,6 +22,7 @@ pub(crate) enum Command { Set(Set), Subscribe(Subscribe), Unsubscribe(Unsubscribe), + Unknown(Unknown) } impl Command { @@ -33,7 +37,10 @@ impl Command { "set" => Command::Set(Set::parse_frames(&mut parse)?), "subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?), "unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?), - _ => return Err(ParseError::UnknownCommand(command_name)), + _ => { + parse.next_string()?; + Command::Unknown(Unknown::new(command_name)) + }, }; parse.finish()?; @@ -53,9 +60,21 @@ impl Command { Publish(cmd) => cmd.apply(db, dst).await, Set(cmd) => cmd.apply(db, dst).await, Subscribe(cmd) => cmd.apply(db, dst, shutdown).await, + Unknown(cmd) => cmd.apply(dst).await, // `Unsubscribe` cannot be applied. It may only be received from the // context of a `Subscribe` command. Unsubscribe(_) => unimplemented!(), } } + + pub(crate) fn get_name(&self) -> &str { + match self { + Command::Get(_) => "get", + Command::Publish(_) => "pub", + Command::Set(_) => "set", + Command::Subscribe(_) => "subscribe", + Command::Unsubscribe(_) => "unsubscribe", + Command::Unknown(cmd) => &cmd.command_name, + } + } } diff --git a/src/cmd/subscribe.rs b/src/cmd/subscribe.rs index 7e6cdef..a2c924b 100644 --- a/src/cmd/subscribe.rs +++ b/src/cmd/subscribe.rs @@ -1,4 +1,4 @@ -use crate::cmd::{Parse, ParseError}; +use crate::cmd::{Parse, ParseError, Unknown}; use crate::{Command, Connection, Db, Frame, Shutdown}; use bytes::Bytes; @@ -134,9 +134,9 @@ impl Subscribe { dst.write_frame(&response).await?; } } - _ => { - // TODO: received invalid command - unimplemented!(); + command => { + let cmd = Unknown::new(command.get_name()); + cmd.apply(dst).await?; } } } diff --git a/src/cmd/unknown.rs b/src/cmd/unknown.rs new file mode 100644 index 0000000..20437b2 --- /dev/null +++ b/src/cmd/unknown.rs @@ -0,0 +1,26 @@ +use crate::{Connection, Frame}; + +use tracing::{debug, instrument}; + +#[derive(Debug)] +pub struct Unknown { + pub command_name: String, +} + +impl Unknown { + /// Create a new `Unknown` command which responds to unknown commands + /// issued by clients + pub(crate) fn new(key: impl ToString) -> Unknown { + Unknown { command_name: key.to_string() } + } + + #[instrument(skip(self, dst))] + pub(crate) async fn apply(self, dst: &mut Connection) -> crate::Result<()> { + let response = Frame::Error(format!("ERR unknown command '{}'", self.command_name)); + + debug!(?response); + + dst.write_frame(&response).await?; + Ok(()) + } +} diff --git a/src/parse.rs b/src/parse.rs index d6c8125..e77ab86 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -13,7 +13,6 @@ pub(crate) struct Parse { pub(crate) enum ParseError { EndOfStream, Invalid, - UnknownCommand(String), } impl Parse { @@ -85,7 +84,6 @@ impl fmt::Display for ParseError { let msg = match self { ParseError::EndOfStream => "end of stream".to_string(), ParseError::Invalid => "invalid".to_string(), - ParseError::UnknownCommand(cmd) => format!("unknown command `{}`", cmd), }; write!(f, "{}", &msg) } diff --git a/tests/server.rs b/tests/server.rs index bb94f31..f5ffa04 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -98,6 +98,86 @@ async fn key_value_timeout() { assert_eq!(b"$-1\r\n", &response); } +// In this case we test that server responds acurately to +// SUBSCRIBE and UNSUBSCRIBE commands +#[tokio::test] +async fn subscribe_unsubscribe() { + let (addr, _handle) = start_server().await; + + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // send SUBSCRIBE command + stream.write_all(b"*2\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n").await.unwrap(); + + // Read response + let mut response = [0; 30]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"*2\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n", &response); + + // send UNSUBSCRIBE command + stream.write_all(b"*2\r\n$11\r\nunsubscribe\r\n$5\r\nhello\r\n").await.unwrap(); + + let mut response = [0; 33]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"*2\r\n$11\r\nunsubscribe\r\n", &response[0..22]); + assert_eq!(b"$5\r\nhello\r\n", &response[22..33]); +} + +// In this case we test that server Responds with an Error message if a client +// sends an unknown command +#[tokio::test] +async fn send_error_unknown_command() { + let (addr, _handle) = start_server().await; + + // Establish a connection to the server + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // Get a key, data is missing + stream.write_all(b"*2\r\n$3\r\nFOO\r\n$5\r\nhello\r\n").await.unwrap(); + + let mut response = [0; 28]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"-ERR unknown command \'foo\'\r\n", &response); +} + +// In this case we test that server Responds with an Error message if a client +// sends an GET or SET command after a SUBSCRIBE +#[tokio::test] +async fn send_error_get_set_after_subscribe() { + let (addr, _handle) = start_server().await; + + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // send SUBSCRIBE command + stream.write_all(b"*2\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n").await.unwrap(); + + let mut response = [0; 30]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"*2\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n", &response); + + stream.write_all(b"*3\r\n$3\r\nSET\r\n$5\r\nhello\r\n$5\r\nworld\r\n").await.unwrap(); + + let mut response = [0; 28]; + + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"-ERR unknown command \'set\'\r\n", &response); + + stream.write_all(b"*2\r\n$3\r\nGET\r\n$5\r\nhello\r\n").await.unwrap(); + + let mut response = [0; 28]; + + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"-ERR unknown command \'get\'\r\n", &response); +} + async fn start_server() -> (SocketAddr, JoinHandle>) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap();