diff --git a/README.md b/README.md index 41e3356..ad60786 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,20 @@ application. **Disclaimer** Don't even think about trying to use this in production... just don't. +## Why Redis + +The primary goal of this project is teaching Tokio. Doing this requires a +project with a wide range of features with a focus on implementation simplicity. +Redis, an in-memory database, provides a wide range of features and uses a +simple wire protocol. The wide range of features allows demonstrating many Tokio +patterns in a "real world" context. + +The Redis wire protocol documentation can be found [here](https://redis.io/topics/protocol). + +The set of commands Redis provides can be found +[here](https://redis.io/commands). + + ## Running The repository provides a server, client library, and some client executables @@ -117,6 +131,12 @@ the server to update the active subscriptions. The server uses a `std::sync::Mutex` and **not** a Tokio mutex to synchronize access to shared state. See [`db.rs`](src/db.rs) for more details. +### Testing asynchronous code that relies on time + +In [`tests/server.rs`](tests/server.rs), there are tests for key expiration. +These tests depend on time passing. In order to make the tests deterministic, +time is mocked out using Tokio's testing utilities. + ## Contributing Contributions to `mini-redis` are welcome. Keep in mind, the goal of the project @@ -128,6 +148,9 @@ demonstrate a new pattern. Contributions should come with extensive comments targetted to new Tokio users. +Contributions that only focus on clarifying and improving comments are very +welcome. + ## FAQ #### Should I use this in production? diff --git a/examples/sub.rs b/examples/sub.rs index ccd3cf9..761b106 100644 --- a/examples/sub.rs +++ b/examples/sub.rs @@ -29,9 +29,9 @@ pub async fn main() -> Result<()> { let mut subscriber = client.subscribe(vec!["foo".into()]).await?; // await messages on channel foo - let msg = subscriber.next_message().await? ; - println!("got message from the channel: {}; message = {:?}", msg.channel, msg.content); - + if let Some(msg) = subscriber.next_message().await? { + println!("got message from the channel: {}; message = {:?}", msg.channel, msg.content); + } Ok(()) } diff --git a/src/bin/cli.rs b/src/bin/cli.rs index dc5379a..a970584 100644 --- a/src/bin/cli.rs +++ b/src/bin/cli.rs @@ -63,6 +63,7 @@ async fn main() -> mini_redis::Result<()> { // Establish a connection let mut client = client::connect(&addr).await?; + // Process the requested command match cli.command { Command::Get { key } => { if let Some(value) = client.get(&key).await? { diff --git a/src/bin/server.rs b/src/bin/server.rs index fde506c..352d36f 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,3 +1,11 @@ +//! mini-redis server. +//! +//! This file is the entry point for the server implemented in the library. It +//! performs command line parsing and passes the arguments on to +//! `mini_redis::server`. +//! +//! The `clap` crate is used for parsing arguments. + use mini_redis::{server, DEFAULT_PORT}; use clap::Clap; diff --git a/src/client.rs b/src/client.rs index cbc61a0..4764676 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,30 +1,111 @@ +//! Minimal Redis client implementation +//! +//! Provides an async connect and methods for issuing the supported commands. + use crate::cmd::{Get, Publish, Set, Subscribe, Unsubscribe}; use crate::{Connection, Frame}; use bytes::Bytes; use std::io::{Error, ErrorKind}; -use std::iter::FromIterator; -use std::collections::HashSet; use std::time::Duration; use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::stream::Stream; use tracing::{debug, instrument}; -use async_stream::stream; +use async_stream::try_stream; -/// Mini asynchronous Redis client +/// Established connection with a Redis server. +/// +/// Backed by a single `TcpStream`, `Client` provides basic network client +/// functionality (no pooling, retrying, ...). Connections are established using +/// the [`connect`](fn@connect) function. +/// +/// Requests are issued using the various methods of `Client`. pub struct Client { - conn: Connection, + /// The TCP connection decorated with the redis protocol encoder / decoder + /// implemented using a buffered `TcpStream`. + /// + /// When `Listener` receives an inbound connection, the `TcpStream` is + /// passed to `Connection::new`, which initializes the associated buffers. + /// `Connection` allows the handler to operate at the "frame" level and keep + /// the byte level protocol parsing details encapsulated in `Connection`. + connection: Connection, } -pub async fn connect(addr: T) -> crate::Result { - let socket = TcpStream::connect(addr).await?; - let conn = Connection::new(socket); +/// A client that has entered pub/sub mode. +/// +/// Once clients subscribe to a channel, they may only perform pub/sub related +/// commands. The `Client` type is transitioned to a `Subscriber` type in order +/// to prevent non-pub/sub methods from being called. +pub struct Subscriber { + /// The subscribed client. + client: Client, - Ok(Client { conn }) + /// The set of channels to which the `Subscriber` is currently subscribed. + subscribed_channels: Vec, +} + +/// A message received on a subscribed channel. +#[derive(Debug, Clone)] +pub struct Message { + pub channel: String, + pub content: Bytes, +} + +/// Establish a connection with the Redis server located at `addr`. +/// +/// `addr` may be any type that can be asynchronously converted to a +/// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs` +/// trait is the Tokio version and not the `std` version. +/// +/// # Examples +/// +/// ```no_run +/// use mini_redis::client; +/// +/// #[tokio::main] +/// async fn main() { +/// let client = match client::connect("localhost:6379").await { +/// Ok(client) => client, +/// Err(_) => panic!("failed to establish connection"), +/// }; +/// # drop(client); +/// } +/// ``` +/// +pub async fn connect(addr: T) -> crate::Result { + // The `addr` argument is passed directly to `TcpStream::connect`. This + // performs any asynchronous DNS lookup and attempts to establish the TCP + // connection. An error at either step returns an error, which is then + // bubbled up to the caller of `mini_redis` connect. + let socket = TcpStream::connect(addr).await?; + + // Initialize the connection state. This allocates read/write buffers to + // perform redis protocol frame parsing. + let connection = Connection::new(socket); + + Ok(Client { connection }) } impl Client { - /// Get the value of a key + /// Get the value of key. + /// + /// If the key does not exist the special value `None` is returned. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::client; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// let val = client.get("foo").await.unwrap(); + /// println!("Got = {:?}", val); + /// } + /// ``` #[instrument(skip(self))] pub async fn get(&mut self, key: &str) -> crate::Result> { // Create a `Get` command for the `key` and convert it to a frame. @@ -32,10 +113,14 @@ impl Client { debug!(request = ?frame); - // Write the frame to the socket. - self.conn.write_frame(&frame).await?; + // Write the frame to the socket. This writes the full frame to the + // socket, waiting if necessary. + self.connection.write_frame(&frame).await?; - // Wait for the response. + // Wait for the response from the server + // + // Both `Simple` and `Bulk` frames are accepted. `Null` represents the + // key not being present and `None` is returned. match self.read_response().await? { Frame::Simple(value) => Ok(Some(value.into())), Frame::Bulk(value) => Ok(Some(value)), @@ -44,38 +129,80 @@ impl Client { } } - /// Set the value of a key to `value`. + /// Set `key` to hold the given `value`. + /// + /// The `value` is associated with `key` until it is overwritten by the next + /// call to `set` or it is removed. + /// + /// If key already holds a value, it is overwritten. Any previous time to + /// live associated with the key is discarded on successful SET operation. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::client; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// client.set("foo", "bar".into()).await.unwrap(); + /// + /// // Getting the value immediately works + /// let val = client.get("foo").await.unwrap().unwrap(); + /// assert_eq!(val, "bar"); + /// } + /// ``` #[instrument(skip(self))] pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { + // Create a `Set` command and pass it to `set_cmd`. A separate method is + // used to set a value with an expiration. The common parts of both + // functions are implemented by `set_cmd`. self.set_cmd(Set::new(key, value, None)).await } - /// publish `message` on the `channel` - #[instrument(skip(self))] - pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { - self.publish_cmd(Publish { - channel: channel.to_string(), - message: message, - }) - .await - } - - /// subscribe to the list of channels - /// when client sends a `SUBSCRIBE` command, server's handle for client enters a mode where only - /// `SUBSCRIBE` and `UNSUBSCRIBE` commands are allowed, so we consume client and return Subscribe type - /// which only allows `SUBSCRIBE` and `UNSUBSCRIBE` commands - #[instrument(skip(self))] - pub async fn subscribe(mut self, channels: Vec) -> crate::Result { - let channels = self.subscribe_cmd(Subscribe { channels: channels }).await?; - let subscribed_channels = HashSet::from_iter(channels); - - Ok(Subscriber { - conn: self.conn, - subscribed_channels, - }) - } - - /// Set the value of a key to `value`. The value expires after `expiration`. + /// Set `key` to hold the given `value`. The value expires after `expiration` + /// + /// The `value` is associated with `key` until one of the following: + /// - it expires. + /// - it is overwritten by the next call to `set`. + /// - it is removed. + /// + /// If key already holds a value, it is overwritten. Any previous time to + /// live associated with the key is discarded on a successful SET operation. + /// + /// # Examples + /// + /// Demonstrates basic usage. This example is not **guaranteed** to always + /// work as it relies on time based logic and assumes the client and server + /// stay relatively synchronized in time. The real world tends to not be so + /// favorable. + /// + /// ```no_run + /// use mini_redis::client; + /// use tokio::time; + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let ttl = Duration::from_millis(500); + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// client.set_expires("foo", "bar".into(), ttl).await.unwrap(); + /// + /// // Getting the value immediately works + /// let val = client.get("foo").await.unwrap().unwrap(); + /// assert_eq!(val, "bar"); + /// + /// // Wait for the TTL to expire + /// time::delay_for(ttl).await; + /// + /// let val = client.get("foo").await.unwrap(); + /// assert!(val.is_some()); + /// } + /// ``` #[instrument(skip(self))] pub async fn set_expires( &mut self, @@ -83,33 +210,61 @@ impl Client { value: Bytes, expiration: Duration, ) -> crate::Result<()> { + // Create a `Set` command and pass it to `set_cmd`. A separate method is + // used to set a value with an expiration. The common parts of both + // functions are implemented by `set_cmd`. self.set_cmd(Set::new(key, value, Some(expiration))).await } + /// The core `SET` logic, used by both `set` and `set_expires. async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> { // Convert the `Set` command into a frame let frame = cmd.into_frame(); debug!(request = ?frame); - // Write the frame to the socket - self.conn.write_frame(&frame).await?; + // Write the frame to the socket. This writes the full frame to the + // socket, waiting if necessary. + self.connection.write_frame(&frame).await?; - // Read the response + // Wait for the response from the server. On success, the server + // responds simply with `OK`. Any other response indicates an error. match self.read_response().await? { Frame::Simple(response) if response == "OK" => Ok(()), frame => Err(frame.to_error()), } } - async fn publish_cmd(&mut self, cmd: Publish) -> crate::Result { + /// Posts `message` to the given `channel`. + /// + /// Returns the number of subscribers currently listening on the channel. + /// There is no guarantee that these subscribers receive the message as they + /// may disconnect at any time. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::client; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// let val = client.publish("foo", "bar".into()).await.unwrap(); + /// println!("Got = {:?}", val); + /// } + /// ``` + #[instrument(skip(self))] + pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { // Convert the `Publish` command into a frame - let frame = cmd.into_frame(); + let frame = Publish::new(channel, message).into_frame(); debug!(request = ?frame); // Write the frame to the socket - self.conn.write_frame(&frame).await?; + self.connection.write_frame(&frame).await?; // Read the response match self.read_response().await? { @@ -118,44 +273,76 @@ impl Client { } } - async fn subscribe_cmd(&mut self, cmd: Subscribe) -> crate::Result> { + /// Subscribes the client to the specified channels. + /// + /// Once a client issues a subscribe command, it may no longer issue any + /// non-pub/sub commands. The function consumes `self` and returns a `Subscriber`. + /// + /// The `Subscriber` value is used to receive messages as well as manage the + /// list of channels the client is subscribed to. + #[instrument(skip(self))] + pub async fn subscribe(mut self, channels: Vec) -> crate::Result { + // Issue the subscribe command to the server and wait for confirmation. + // The client will then have been transitioned into the "subscriber" + // state and may only issue pub/sub commands from that point on. + self.subscribe_cmd(&channels).await?; + + // Return the `Subscriber` type + Ok(Subscriber { + client: self, + subscribed_channels: channels, + }) + } + + /// The core `SUBSCRIBE` logic, used by misc subscribe fns + async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> { // Convert the `Subscribe` command into a frame - let channels = cmd.channels.clone(); - let frame = cmd.into_frame(); + let frame = Subscribe::new(&channels).into_frame(); debug!(request = ?frame); // Write the frame to the socket - self.conn.write_frame(&frame).await?; + self.connection.write_frame(&frame).await?; - // Read the response - for channel in &channels { + // For each channel being subscribed to, the server responds with a + // message confirming subscription to that channel. + for channel in channels { + // Read the response let response = self.read_response().await?; + + // Verify it is confirmation of subscription. match response { Frame::Array(ref frame) => match frame.as_slice() { + // The server responds with an array frame in the form of: + // + // ``` + // [ "subscribe", channel, num-subscribed ] + // ``` + // + // where channel is the name of the channel and + // num-subscribed is the number of channels that the client + // is currently subscribed to. [subscribe, schannel, ..] - if subscribe.to_string() == "subscribe" - && &schannel.to_string() == channel => - { - () - } + if **subscribe == "subscribe" && **schannel == channel => {} _ => return Err(response.to_error()), }, frame => return Err(frame.to_error()), }; } - Ok(channels) + Ok(()) } - /// Reads a response frame from the socket. If an `Error` frame is read, it - /// is converted to `Err`. + /// Reads a response frame from the socket. + /// + /// If an `Error` frame is received, it is converted to `Err`. async fn read_response(&mut self) -> crate::Result { - let response = self.conn.read_frame().await?; + let response = self.connection.read_frame().await?; debug!(?response); match response { + // Error frames are converted to `Err` Some(Frame::Error(msg)) => Err(msg.into()), Some(frame) => Ok(frame), None => { @@ -170,38 +357,53 @@ impl Client { } } -pub struct Subscriber { - conn: Connection, - subscribed_channels: HashSet, -} - impl Subscriber { - - /// get the list of subscribed channels - pub fn get_subscribed(&self) -> &HashSet { + /// Returns the set of channels currently subscribed to. + pub fn get_subscribed(&self) -> &[String] { &self.subscribed_channels } - /// await for next message published on the subscribed channels - pub async fn next_message(&mut self) -> crate::Result { - match self.receive_message().await { - Some(message) => message, - None => { - // Receiving `None` here indicates the server has closed the - // connection without sending a frame. This is unexpected and is - // represented as a "connection reset by peer" error. - let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server"); + /// Receive the next message published on a subscribed channel, waiting if + /// necessary. + /// + /// `None` indicates the subscription has been terminated. + pub async fn next_message(&mut self) -> crate::Result> { + match self.client.connection.read_frame().await? { + Some(mframe) => { + debug!(?mframe); - Err(err.into()) + match mframe { + Frame::Array(ref frame) => match frame.as_slice() { + [message, channel, content] if **message == "message" => { + Ok(Some(Message { + channel: channel.to_string(), + content: Bytes::from(content.to_string()), + })) + } + _ => Err(mframe.to_error()), + }, + frame => Err(frame.to_error()), + } } + None => Ok(None), } } - /// Convert the subscriber into a Stream - /// yielding new messages published on subscribed channels + /// Convert the subscriber into a `Stream` yielding new messages published + /// on subscribed channels. + /// + /// `Subscriber` does not implement stream itself as doing so with safe code + /// is non trivial. The usage of async/await would require a manual Stream + /// implementation to use `unsafe` code. Instead, a conversion function is + /// provided and the returned stream is implemented with the help of the + /// `async-stream` crate. pub fn into_stream(mut self) -> impl Stream> { - stream! { - while let Some(message) = self.receive_message().await { + // Uses the `try_stream` macro from the `async-stream` crate. Generators + // are not stable in Rust. The crate uses a macro to simulate generators + // on top of async/await. There are limitations, so read the + // documentation there. + try_stream! { + while let Some(message) = self.next_message().await? { yield message; } } @@ -209,67 +411,55 @@ impl Subscriber { /// Subscribe to a list of new channels #[instrument(skip(self))] - pub async fn subscribe(&mut self, channels: Vec) -> crate::Result<()> { - let cmd = Subscribe { channels: channels }; + pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> { + // Issue the subscribe command + self.client.subscribe_cmd(channels).await?; - let channels = cmd.channels.clone(); - let frame = cmd.into_frame(); - - debug!(request = ?frame); - - // Write the frame to the socket - self.conn.write_frame(&frame).await?; - - // Read the response - for channel in &channels { - let response = self.read_response().await?; - match response { - Frame::Array(ref frame) => match frame.as_slice() { - [subscribe, schannel, ..] - if &subscribe.to_string() == "subscribe" - && &schannel.to_string() == channel => - { - () - } - _ => return Err(response.to_error()), - }, - frame => return Err(frame.to_error()), - }; - } - - self.subscribed_channels.extend(channels); + // Update the set of subscribed channels. + self.subscribed_channels.extend(channels.iter().map(Clone::clone)); Ok(()) } /// Unsubscribe to a list of new channels #[instrument(skip(self))] - pub async fn unsubscribe(&mut self, channels: Vec) -> crate::Result<()> { - let cmd = Unsubscribe { channels: channels }; - - let mut channels = cmd.channels.clone(); - let frame = cmd.into_frame(); + pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> { + let frame = Unsubscribe::new(&channels).into_frame(); debug!(request = ?frame); // Write the frame to the socket - self.conn.write_frame(&frame).await?; + self.client.connection.write_frame(&frame).await?; // if the input channel list is empty, server acknowledges as unsubscribing // from all subscribed channels, so we assert that the unsubscribe list received // matches the client subscribed one - if channels.is_empty() { - channels = Vec::from_iter(self.subscribed_channels.clone()); - } + let num = if channels.is_empty() { + self.subscribed_channels.len() + } else { + channels.len() + }; // Read the response - for _channel in &channels { - let response = self.read_response().await?; + for _ in 0..num { + let response = self.client.read_response().await?; + match response { Frame::Array(ref frame) => match frame.as_slice() { - [unsubscribe, uchannel, ..] if &unsubscribe.to_string() == "unsubscribe" => { - //unsubscribed channel should exist in the subscribed list at this point - if self.subscribed_channels.remove(&uchannel.to_string()) == false { + [unsubscribe, channel, ..] if **unsubscribe == "unsubscribe" => { + let len = self.subscribed_channels.len(); + + if len == 0 { + // There must be at least one channel + return Err(response.to_error()); + } + + // unsubscribed channel should exist in the subscribed list at this point + self.subscribed_channels.retain(|c| **channel != &c[..]); + + // Only a single channel should be removed from the + // liste of subscribed channels. + if self.subscribed_channels.len() != len - 1 { return Err(response.to_error()); } }, @@ -281,56 +471,4 @@ impl Subscriber { Ok(()) } - - /// Receives a frame published from server on socket and convert it to a `Message` - /// if frame is not `Frame::Array` with proper message structure return Err - async fn receive_message(&mut self) -> Option> { - match self.conn.read_frame().await { - Ok(None) => None, - Err(err) => Some(Err(err.into())), - Ok(Some(mframe)) => { - debug!(?mframe); - match mframe { - Frame::Array(ref frame) => match frame.as_slice() { - [message, channel, content] if &message.to_string() == "message" => { - Some(Ok(Message { - channel: channel.to_string(), - content: Bytes::from(content.to_string()), - })) - } - _ => Some(Err(mframe.to_error())), - }, - frame => Some(Err(frame.to_error())), - } - } - } - } - - /// Reads a response frame to a command from the socket. If an `Error` frame is read, it - /// is converted to `Err`. - async fn read_response(&mut self) -> crate::Result { - let response = self.conn.read_frame().await?; - - debug!(?response); - - match response { - Some(Frame::Error(msg)) => Err(msg.into()), - Some(frame) => Ok(frame), - None => { - // Receiving `None` here indicates the server has closed the - // connection without sending a frame. This is unexpected and is - // represented as a "connection reset by peer" error. - let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server"); - - Err(err.into()) - } - } - } -} - -/// A message received on a subscribed channel -#[derive(Debug, Clone)] -pub struct Message { - pub channel: String, - pub content: Bytes, } diff --git a/src/cmd/get.rs b/src/cmd/get.rs index c5d1e7b..573d68e 100644 --- a/src/cmd/get.rs +++ b/src/cmd/get.rs @@ -20,10 +20,10 @@ impl Get { Get { key: key.to_string() } } - /// Parse a `Get` instance from received data. + /// Parse a `Get` instance from a received frame. /// - /// The `Parse` argument provides a cursor like API to read fields from a - /// received `Frame`. At this point, the data has already been received from + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from /// the socket. /// /// The `GET` string has already been consumed. diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index c38a076..a99363d 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -110,7 +110,7 @@ impl Command { Command::Set(_) => "set", Command::Subscribe(_) => "subscribe", Command::Unsubscribe(_) => "unsubscribe", - Command::Unknown(cmd) => &cmd.command_name, + Command::Unknown(cmd) => cmd.get_name(), } } } diff --git a/src/cmd/publish.rs b/src/cmd/publish.rs index a6ccc6e..3c28b1c 100644 --- a/src/cmd/publish.rs +++ b/src/cmd/publish.rs @@ -2,15 +2,59 @@ use crate::{Connection, Db, Frame, Parse}; use bytes::Bytes; +/// Posts a message to the given channel. +/// +/// Send a message into a channel without any knowledge of individual consumers. +/// Consumers may subscribe to channels in order to receive the messages. +/// +/// Channel names have no relation to the key-value namespace. Publishing on a +/// channel named "foo" has no relation to setting the "foo" key. #[derive(Debug)] pub struct Publish { - pub(crate) channel: String, - pub(crate) message: Bytes, + /// Name of the channel on which the message should be published. + channel: String, + + /// The message to publish. + message: Bytes, } impl Publish { + /// Create a new `Publish` command which sends `message` on `channel`. + pub(crate) fn new(channel: impl ToString, message: Bytes) -> Publish { + Publish { + channel: channel.to_string(), + message, + } + } + + /// Parse a `Publish` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `PUBLISH` string has already been consumed. + /// + /// # Returns + /// + /// On success, the `Publish` value is returned. If the frame is malformed, + /// `Err` is returned. + /// + /// # Format + /// + /// Expects an array frame containing three entries. + /// + /// ```text + /// PUBLISH channel message + /// ``` pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { + // The `PUBLISH` string has already been consumed. Extract the `channel` + // and `message` values from the frame. + // + // The `channel` must be a valid string. let channel = parse.next_string()?; + + // The `message` is arbitrary bytes. let message = parse.next_bytes()?; Ok(Publish { channel, message }) @@ -21,14 +65,31 @@ impl Publish { /// The response is written to `dst`. This is called by the server in order /// to execute a received command. pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { - // Set the value + // The shared state contains the `tokio::sync::broadcast::Sender` for + // all active channels. Calling `db.publish` dispatches the message into + // the appropriate channel. + // + // The number of subscribers currently listening on the channel is + // returned. This does not mean that `num_subscriber` channels will + // receive the message. Subscribers may drop before receiving the + // message. Given this, `num_subscribers` should only be used as a + // "hint". let num_subscribers = db.publish(&self.channel, self.message); + // The number of subscribers is returned as the response to the publish + // request. let response = Frame::Integer(num_subscribers as u64); + + // Write the frame to the client. dst.write_frame(&response).await?; + Ok(()) } + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Publish` command to send + /// to the server. pub(crate) fn into_frame(self) -> Frame { let mut frame = Frame::array(); frame.push_bulk(Bytes::from("publish".as_bytes())); diff --git a/src/cmd/set.rs b/src/cmd/set.rs index 0ddfc3a..e417ce6 100644 --- a/src/cmd/set.rs +++ b/src/cmd/set.rs @@ -42,10 +42,10 @@ impl Set { } } - /// Parse a `Set` instance from received data. + /// Parse a `Set` instance from a received frame. /// - /// The `Parse` argument provides a cursor like API to read fields from a - /// received `Frame`. At this point, the data has already been received from + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from /// the socket. /// /// The `SET` string has already been consumed. diff --git a/src/cmd/subscribe.rs b/src/cmd/subscribe.rs index 9c6c431..d93db90 100644 --- a/src/cmd/subscribe.rs +++ b/src/cmd/subscribe.rs @@ -5,27 +5,75 @@ use bytes::Bytes; use tokio::select; use tokio::stream::{StreamExt, StreamMap}; +/// Subscribes the client to one or more channels. +/// +/// Once the client enters the subscribed state, it is not supposed to issue any +/// other commands, except for additional SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE, +/// PUNSUBSCRIBE, PING and QUIT commands. #[derive(Debug)] pub struct Subscribe { - pub(crate) channels: Vec, + channels: Vec, } +/// Unsubscribes the client from one or more channels. +/// +/// When no channels are specified, the client is unsubscribed from all the +/// previously subscribed channels. #[derive(Clone, Debug)] pub struct Unsubscribe { - pub(crate) channels: Vec, + channels: Vec, } impl Subscribe { + /// Creates a new `Subscribe` command to listen on the specified channels. + pub(crate) fn new(channels: &[String]) -> Subscribe { + Subscribe { channels: channels.to_vec() } + } + + /// Parse a `Subscribe` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `SUBSCRIBE` string has already been consumed. + /// + /// # Returns + /// + /// On success, the `Subscribe` value is returned. If the frame is + /// malformed, `Err` is returned. + /// + /// # Format + /// + /// Expects an array frame containing two or more entries. + /// + /// ```text + /// SUBSCRIBE channel [channel ...] + /// ``` pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { use ParseError::EndOfStream; - // There must be at least one channel + // The `SUBSCRIBE` string has already been consumed. At this point, + // there is one or more strings remaining in `parse`. These represent + // the channels to subscribe to. + // + // Extract the first string. If there is none, the the frame is + // malformed and the error is bubbled up. let mut channels = vec![parse.next_string()?]; + // Now, the remainder of the frame is consumed. Each value must be a + // string or the frame is malformed. Once all values in the frame have + // been consumed, the command is fully parsed. loop { match parse.next_string() { + // A string has been consumed from the `parse`, push it into the + // list of channels to subscribe to. Ok(s) => channels.push(s), + // The `EndOfStream` error indicates there is no further data to + // parse. Err(EndOfStream) => break, + // All other errors are bubbled up, resulting in the connection + // being terminated. Err(err) => return Err(err.into()), } } @@ -87,11 +135,18 @@ impl Subscribe { select! { // Receive messages from subscribed channels Some((channel, msg)) = subscriptions.next() => { + use tokio::sync::broadcast::RecvError; + + let msg = match msg { + Ok(msg) => msg, + Err(RecvError::Lagged(_)) => continue, + Err(RecvError::Closed) => unreachable!(), + }; + let mut response = Frame::array(); response.push_bulk(Bytes::from_static(b"message")); response.push_bulk(Bytes::copy_from_slice(channel.as_bytes())); - // TODO: handle lag error - response.push_bulk(msg.unwrap()); + response.push_bulk(msg); dst.write_frame(&response).await?; } @@ -149,6 +204,10 @@ impl Subscribe { } } + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Subscribe` command to send + /// to the server. pub(crate) fn into_frame(self) -> Frame { let mut frame = Frame::array(); frame.push_bulk(Bytes::from("subscribe".as_bytes())); @@ -160,16 +219,50 @@ impl Subscribe { } impl Unsubscribe { + /// Create a new `Unsubscribe` command with the given `channels`. + pub(crate) fn new(channels: &[String]) -> Unsubscribe { + Unsubscribe { channels: channels.to_vec() } + } + + /// Parse a `Unsubscribe` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `UNSUBSCRIBE` string has already been consumed. + /// + /// # Returns + /// + /// On success, the `Unsubscribe` value is returned. If the frame is + /// malformed, `Err` is returned. + /// + /// # Format + /// + /// Expects an array frame containing at least one entry. + /// + /// ```text + /// UNSUBSCRIBE [channel [channel ...]] + /// ``` pub(crate) fn parse_frames(parse: &mut Parse) -> Result { use ParseError::EndOfStream; - // There may be no channels listed. + // There may be no channels listed, so start with an empty vec. let mut channels = vec![]; + // Each entry in the frame must be a string or the frame is malformed. + // Once all values in the frame have been consumed, the command is fully + // parsed. loop { match parse.next_string() { + // A string has been consumed from the `parse`, push it into the + // list of channels to unsubscribe from. Ok(s) => channels.push(s), + // The `EndOfStream` error indicates there is no further data to + // parse. Err(EndOfStream) => break, + // All other errors are bubbled up, resulting in the connection + // being terminated. Err(err) => return Err(err), } } @@ -177,12 +270,18 @@ impl Unsubscribe { Ok(Unsubscribe { channels }) } + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding an `Unsubscribe` command to + /// send to the server. pub(crate) fn into_frame(self) -> Frame { let mut frame = Frame::array(); frame.push_bulk(Bytes::from("unsubscribe".as_bytes())); + for channel in self.channels { frame.push_bulk(Bytes::from(channel.into_bytes())); } + frame } } diff --git a/src/cmd/unknown.rs b/src/cmd/unknown.rs index 20437b2..fc7c5ea 100644 --- a/src/cmd/unknown.rs +++ b/src/cmd/unknown.rs @@ -2,9 +2,10 @@ use crate::{Connection, Frame}; use tracing::{debug, instrument}; +/// Represents an "unknown" command. This is not a real `Redis` command. #[derive(Debug)] pub struct Unknown { - pub command_name: String, + command_name: String, } impl Unknown { @@ -14,6 +15,14 @@ impl Unknown { Unknown { command_name: key.to_string() } } + /// Returns the command name + pub(crate) fn get_name(&self) -> &str { + &self.command_name + } + + /// Responds to the client, indicating the command is not recognized. + /// + /// This usually means the command is not yet implemented by `mini-redis`. #[instrument(skip(self, dst))] pub(crate) async fn apply(self, dst: &mut Connection) -> crate::Result<()> { let response = Frame::Error(format!("ERR unknown command '{}'", self.command_name)); diff --git a/src/frame.rs b/src/frame.rs index 3f273a9..1419940 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -169,6 +169,16 @@ impl Frame { } } +impl PartialEq<&str> for Frame { + fn eq(&self, other: &&str) -> bool { + match self { + Frame::Simple(s) => s.eq(other), + Frame::Bulk(s) => s.eq(other), + _ => false, + } + } +} + impl fmt::Display for Frame { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { use std::str; diff --git a/src/parse.rs b/src/parse.rs index 53f3356..74b26b1 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -4,18 +4,35 @@ use bytes::Bytes; use std::{fmt, str, vec}; /// Utility for parsing a command +/// +/// Commands are represented as array frames. Each entry in the frame is a +/// "token". A `Parse` is initialized with the array frame and provides a +/// cursor-like API. Each command struct includes a `parse_frame` method that +/// uses a `Parse` to extract its fields. #[derive(Debug)] pub(crate) struct Parse { + /// Array frame iterator. parts: vec::IntoIter>, } +/// Error encountered while parsing a frame. +/// +/// Only `EndOfStream` errors are handled at runtime. All other errors result in +/// the connection being terminated. #[derive(Debug)] pub(crate) enum ParseError { + /// Attempting to extract a value failed due to the frame being fully + /// consumed. EndOfStream, + + /// All other errors Other(crate::Error), } impl Parse { + /// Create a new `Parse` to parse the contents of `frame`. + /// + /// Returns `Err` if `frame` is not an array frame. pub(crate) fn new(frame: Frame) -> Result { let array = match frame { Frame::Array(array) => array, @@ -27,6 +44,8 @@ impl Parse { }) } + /// Return the next entry. Array frames are arrays of frames, so the next + /// entry is a frame. fn next(&mut self) -> Result { self.parts .next() @@ -34,8 +53,16 @@ impl Parse { .ok_or(ParseError::EndOfStream) } + /// Return the next entry as a string. + /// + /// If the next entry cannot be represented as a String, then an error is returned. pub(crate) fn next_string(&mut self) -> Result { match self.next()? { + // Both `Simple` and `Bulk` representation may be strings. Strings + // are parsed to UTF-8. + // + // While errors are stored as strings, they are considered separate + // types. Frame::Simple(s) => Ok(s), Frame::Bulk(data) => str::from_utf8(&data[..]) .map(|s| s.to_string()) @@ -44,21 +71,39 @@ impl Parse { } } + /// Return the next entry as raw bytes. + /// + /// If the next entry cannot be represented as raw bytes, an error is + /// returned. pub(crate) fn next_bytes(&mut self) -> Result { match self.next()? { + // Both `Simple` and `Bulk` representation may be raw bytes. + // + // Although errors are stored as strings and could be represented as + // raw bytes, they are considered separate types. Frame::Simple(s) => Ok(Bytes::from(s.into_bytes())), Frame::Bulk(data) => Ok(data), frame => Err(format!("protocol error; expected simple frame or bulk frame, got {:?}", frame).into()), } } + /// Return the next entry as an integer. + /// + /// This includes `Simple`, `Bulk`, and `Integer` frame types. `Simple` and + /// `Bulk` frame types are parsed. + /// + /// If the next entry cannot be represented as an integer, then an error is + /// returned. pub(crate) fn next_int(&mut self) -> Result { use atoi::atoi; const MSG: &str = "protocol error; invalid number"; match self.next()? { + // An integer frame type is already stored as an integer. Frame::Integer(v) => Ok(v), + // Simple and bulk frames must be parsed as integers. If the parsing + // fails, an error is returned. Frame::Simple(data) => atoi::(data.as_bytes()).ok_or_else(|| MSG.into()), Frame::Bulk(data) => atoi::(&data).ok_or_else(|| MSG.into()), frame => Err(format!("protocol error; expected int frame but got {:?}", frame).into()), diff --git a/src/shutdown.rs b/src/shutdown.rs index 03c9e34..bf1b1c3 100644 --- a/src/shutdown.rs +++ b/src/shutdown.rs @@ -1,12 +1,25 @@ use tokio::sync::broadcast; +/// Listens for the server shutdown signal. +/// +/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is +/// ever sent. Once a value has been sent via the broadcast channel, the server +/// should shutdown. +/// +/// The `Shutdown` struct listens for the signal and tracks that the signal has +/// been received. Callers may query for whether the shutdown signal has been +/// received or not. #[derive(Debug)] pub(crate) struct Shutdown { + /// `true` if the shutdown signal has been received shutdown: bool, + + /// The receive half of the channel used to listen for shutdown. notify: broadcast::Receiver<()>, } impl Shutdown { + /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { Shutdown { shutdown: false, @@ -14,18 +27,23 @@ impl Shutdown { } } + /// Returns `true` if the shutdown signal has been received. pub(crate) fn is_shutdown(&self) -> bool { self.shutdown } - /// Receive the shutdown notice + /// Receive the shutdown notice, waiting if necessary. pub(crate) async fn recv(&mut self) { + // If the shutdown signal has already been received, then return + // immediately. if self.shutdown { return; } // Cannot receive a "lag error" as only one value is ever sent. let _ = self.notify.recv().await; + + // Remember that the signal has been received. self.shutdown = true; } } diff --git a/tests/client.rs b/tests/client.rs index 648590a..9b1a837 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -31,7 +31,7 @@ async fn receive_message_subscribed_channel() { client.publish("hello", "world".into()).await.unwrap() }); - let message = subscriber.next_message().await.unwrap(); + let message = subscriber.next_message().await.unwrap().unwrap(); assert_eq!("hello", &message.channel); assert_eq!(b"world", &message.content[..]) } @@ -49,7 +49,7 @@ async fn receive_message_multiple_subscribed_channels() { client.publish("hello", "world".into()).await.unwrap() }); - let message1 = subscriber.next_message().await.unwrap(); + let message1 = subscriber.next_message().await.unwrap().unwrap(); assert_eq!("hello", &message1.channel); assert_eq!(b"world", &message1.content[..]); @@ -59,7 +59,7 @@ async fn receive_message_multiple_subscribed_channels() { }); - let message2 = subscriber.next_message().await.unwrap(); + let message2 = subscriber.next_message().await.unwrap().unwrap(); assert_eq!("world", &message2.channel); assert_eq!(b"howdy?", &message2.content[..]) } @@ -73,7 +73,7 @@ async fn unsubscribes_from_channels() { let client = client::connect(addr.clone()).await.unwrap(); let mut subscriber = client.subscribe(vec!["hello".into(), "world".into()]).await.unwrap(); - subscriber.unsubscribe(vec![]).await.unwrap(); + subscriber.unsubscribe(&[]).await.unwrap(); assert_eq!(subscriber.get_subscribed().len(), 0); }