From 5752d1e0fc15c101c0397a220b5cfcdcaffd7a8b Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Mon, 13 Apr 2020 21:02:32 -0700 Subject: [PATCH] mostly docs, some code tweaks as well (#31) Db background tasks never shutdown o_O --- README.md | 5 ++ src/client.rs | 14 +--- src/cmd/get.rs | 57 +++++++++++--- src/cmd/mod.rs | 42 ++++++++++- src/cmd/publish.rs | 8 +- src/cmd/set.rs | 92 ++++++++++++++++++++--- src/cmd/subscribe.rs | 7 +- src/db.rs | 172 +++++++++++++++++++++++++++++++++++++++---- 8 files changed, 339 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index d69455a..41e3356 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,11 @@ the server to update the active subscriptions. [broadcast]: https://docs.rs/tokio/*/tokio/sync/broadcast/index.html [`StreamMap`]: https://docs.rs/tokio/*/tokio/stream/struct.StreamMap.html +### Using a `std::sync::Mutex` in an async application + +The server uses a `std::sync::Mutex` and **not** a Tokio mutex to synchronize +access to shared state. See [`db.rs`](src/db.rs) for more details. + ## Contributing Contributions to `mini-redis` are welcome. Keep in mind, the goal of the project diff --git a/src/client.rs b/src/client.rs index dbac8f2..cbc61a0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -47,12 +47,7 @@ impl Client { /// Set the value of a key to `value`. 
#[instrument(skip(self))] pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { - self.set_cmd(Set { - key: key.to_string(), - value: value, - expire: None, - }) - .await + self.set_cmd(Set::new(key, value, None)).await } /// publish `message` on the `channel` @@ -88,12 +83,7 @@ impl Client { value: Bytes, expiration: Duration, ) -> crate::Result<()> { - self.set_cmd(Set { - key: key.to_string(), - value: value.into(), - expire: Some(expiration), - }) - .await + self.set_cmd(Set::new(key, value, Some(expiration))).await } async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> { diff --git a/src/cmd/get.rs b/src/cmd/get.rs index a39fb82..c5d1e7b 100644 --- a/src/cmd/get.rs +++ b/src/cmd/get.rs @@ -1,10 +1,16 @@ -use crate::{Connection, Db, Frame, Parse, ParseError}; +use crate::{Connection, Db, Frame, Parse}; use bytes::Bytes; use tracing::{debug, instrument}; +/// Get the value of key. +/// +/// If the key does not exist the special value nil is returned. An error is +/// returned if the value stored at key is not a string, because GET only +/// handles string values. #[derive(Debug)] pub struct Get { + /// Name of the key to get key: String, } @@ -14,36 +20,63 @@ impl Get { Get { key: key.to_string() } } - // instrumenting functions will log all of the arguments passed to the function - // with their debug implementations - // see https://docs.rs/tracing/0.1.13/tracing/attr.instrument.html - pub(crate) fn parse_frames(parse: &mut Parse) -> Result { + /// Parse a `Get` instance from received data. + /// + /// The `Parse` argument provides a cursor like API to read fields from a + /// received `Frame`. At this point, the data has already been received from + /// the socket. + /// + /// The `GET` string has already been consumed. + /// + /// # Returns + /// + /// Returns the `Get` value on success. If the frame is malformed, `Err` is + /// returned. + /// + /// # Format + /// + /// Expects an array frame containing two entries. 
+ /// + /// ```text + /// GET key + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { + // The `GET` string has already been consumed. The next value is the + // name of the key to get. If the next value is not a string or the + // input is fully consumed, then an error is returned. let key = parse.next_string()?; - // adding this debug event allows us to see what key is parsed - // the ? sigil tells `tracing` to use the `Debug` implementation - // get parse events can be filtered by running - // RUST_LOG=mini_redis::cmd::get[parse_frames]=debug cargo run --bin server - // see https://docs.rs/tracing/0.1.13/tracing/#recording-fields - debug!(?key); - Ok(Get { key }) } + /// Apply the `Get` command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. #[instrument(skip(self, db, dst))] pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { + // Get the value from the shared database state let response = if let Some(value) = db.get(&self.key) { + // If a value is present, it is written to the client in "bulk" + // format. Frame::Bulk(value) } else { + // If there is no value, `Null` is written. Frame::Null }; debug!(?response); + // Write the response back to the client dst.write_frame(&response).await?; + Ok(()) } + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Get` command to send to + /// the server. pub(crate) fn into_frame(self) -> Frame { let mut frame = Frame::array(); frame.push_bulk(Bytes::from("get".as_bytes())); diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index 5055e9f..c38a076 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -15,6 +15,9 @@ pub use unknown::Unknown; use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown}; +/// Enumeration of supported Redis commands. 
+/// +/// Methods called on `Command` are delegated to the command implementation. #[derive(Debug)] pub(crate) enum Command { Get(Get), @@ -26,11 +29,29 @@ pub(crate) enum Command { } impl Command { + /// Parse a command from a received frame. + /// + /// The `Frame` must represent a Redis command supported by `mini-redis` and + /// be the array variant. + /// + /// # Returns + /// + /// On success, the command value is returned, otherwise, `Err` is returned. pub(crate) fn from_frame(frame: Frame) -> crate::Result<Command> { + // The frame value is decorated with `Parse`. `Parse` provides a + // "cursor" like API which makes parsing the command easier. + // + // The frame value must be an array variant. Any other frame variants + // result in an error being returned. let mut parse = Parse::new(frame)?; + // All redis commands begin with the command name as a string. The name + // is read and converted to lower case in order to do case insensitive + // matching. let command_name = parse.next_string()?.to_lowercase(); + // Match the command name, delegating the rest of the parsing to the + // specific command. let command = match &command_name[..] { "get" => Command::Get(Get::parse_frames(&mut parse)?), "publish" => Command::Publish(Publish::parse_frames(&mut parse)?), @@ -38,15 +59,29 @@ impl Command { "subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?), "unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?), _ => { - parse.next_string()?; - Command::Unknown(Unknown::new(command_name)) + // The command is not recognized and an Unknown command is + // returned. + // + // `return` is called here to skip the `finish()` call below. As + // the command is not recognized, there are most likely + // unconsumed fields remaining in the `Parse` instance. + return Ok(Command::Unknown(Unknown::new(command_name))); }, }; + // Check if there are any remaining unconsumed fields in the `Parse` + // value.
If fields remain, this indicates an unexpected frame format + // and an error is returned. parse.finish()?; + + // The command has been successfully parsed Ok(command) } + /// Apply the command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. pub(crate) async fn apply( self, db: &Db, @@ -63,10 +98,11 @@ impl Command { Unknown(cmd) => cmd.apply(dst).await, // `Unsubscribe` cannot be applied. It may only be received from the // context of a `Subscribe` command. - Unsubscribe(_) => unimplemented!(), + Unsubscribe(_) => Err("`Unsubscribe` is unsupported in this context".into()), } } + /// Returns the command name pub(crate) fn get_name(&self) -> &str { match self { Command::Get(_) => "get", diff --git a/src/cmd/publish.rs b/src/cmd/publish.rs index dc13a7e..a6ccc6e 100644 --- a/src/cmd/publish.rs +++ b/src/cmd/publish.rs @@ -1,4 +1,4 @@ -use crate::{Connection, Db, Frame, Parse, ParseError}; +use crate::{Connection, Db, Frame, Parse}; use bytes::Bytes; @@ -9,13 +9,17 @@ pub struct Publish { } impl Publish { - pub(crate) fn parse_frames(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { let channel = parse.next_string()?; let message = parse.next_bytes()?; Ok(Publish { channel, message }) } + /// Apply the `Publish` command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { // Set the value let num_subscribers = db.publish(&self.channel, self.message); diff --git a/src/cmd/set.rs b/src/cmd/set.rs index ac6a231..0ddfc3a 100644 --- a/src/cmd/set.rs +++ b/src/cmd/set.rs @@ -5,57 +5,127 @@ use bytes::Bytes; use std::time::Duration; use tracing::{debug, instrument}; +/// Set `key` to hold the string `value`. 
+/// +/// If `key` already holds a value, it is overwritten, regardless of its type. +/// Any previous time to live associated with the key is discarded on successful +/// SET operation. +/// +/// # Options +/// +/// Currently, the following options are supported: +/// +/// * EX `seconds` -- Set the specified expire time, in seconds. +/// * PX `milliseconds` -- Set the specified expire time, in milliseconds. #[derive(Debug)] pub struct Set { /// the lookup key - pub(crate) key: String, + key: String, /// the value to be stored - pub(crate) value: Bytes, + value: Bytes, /// When to expire the key - pub(crate) expire: Option, + expire: Option, } impl Set { - #[instrument] - pub(crate) fn parse_frames(parse: &mut Parse) -> Result { + /// Create a new `Set` command which sets `key` to `value`. + /// + /// If `expire` is `Some`, the value should expire after the specified + /// duration. + pub(crate) fn new(key: impl ToString, value: Bytes, expire: Option) -> Set { + Set { + key: key.to_string(), + value, + expire, + } + } + + /// Parse a `Set` instance from received data. + /// + /// The `Parse` argument provides a cursor like API to read fields from a + /// received `Frame`. At this point, the data has already been received from + /// the socket. + /// + /// The `SET` string has already been consumed. + /// + /// # Returns + /// + /// Returns the `Set` value on success. If the frame is malformed, `Err` is + /// returned. + /// + /// # Format + /// + /// Expects an array frame containing at least 3 entries. + /// + /// ```text + /// SET key value [EX seconds|PX milliseconds] + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { use ParseError::EndOfStream; + // Read the key to set. This is a required field let key = parse.next_string()?; + + // Read the value to set. This is a required field. let value = parse.next_bytes()?; + + // The expiration is optional. If nothing else follows, then it is + // `None`. 
let mut expire = None; + // Attempt to parse another string. match parse.next_string() { Ok(s) if s == "EX" => { + // An expiration is specified in seconds. The next value is an + // integer. let secs = parse.next_int()?; expire = Some(Duration::from_secs(secs)); } Ok(s) if s == "PX" => { + // An expiration is specified in milliseconds. The next value is + // an integer. let ms = parse.next_int()?; expire = Some(Duration::from_millis(ms)); } - Ok(_) => unimplemented!(), + // Currently, mini-redis does not support any of the other SET + // options. An error here results in the connection being + // terminated. Other connections will continue to operate normally. + Ok(_) => return Err("currently `SET` only supports the expiration option".into()), + // The `EndOfStream` error indicates there is no further data to + // parse. In this case, it is a normal run time situation and + // indicates there are no specified `SET` options. Err(EndOfStream) => {} - Err(err) => return Err(err), + // All other errors are bubbled up, resulting in the connection + // being terminated. + Err(err) => return Err(err.into()), } - debug!(?key, ?value, ?expire); - Ok(Set { key, value, expire }) } - #[instrument(skip(db))] + /// Apply the `Set` command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. + #[instrument(skip(self, db, dst))] pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { - // Set the value + // Set the value in the shared database state. db.set(self.key, self.value, self.expire); + // Create a success response and write it to `dst`. let response = Frame::Simple("OK".to_string()); debug!(?response); dst.write_frame(&response).await?; + Ok(()) } + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Set` command to send to + /// the server.
pub(crate) fn into_frame(self) -> Frame { let mut frame = Frame::array(); frame.push_bulk(Bytes::from("set".as_bytes())); diff --git a/src/cmd/subscribe.rs b/src/cmd/subscribe.rs index de85b80..9c6c431 100644 --- a/src/cmd/subscribe.rs +++ b/src/cmd/subscribe.rs @@ -16,7 +16,7 @@ pub struct Unsubscribe { } impl Subscribe { - pub(crate) fn parse_frames(parse: &mut Parse) -> Result { + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { use ParseError::EndOfStream; // There must be at least one channel @@ -26,15 +26,14 @@ impl Subscribe { match parse.next_string() { Ok(s) => channels.push(s), Err(EndOfStream) => break, - Err(err) => return Err(err), + Err(err) => return Err(err.into()), } } Ok(Subscribe { channels }) } - /// Implements the "subscribe" half of Redis' Pub/Sub feature documented - /// [here]. + /// Apply the `Subscribe` command to the specified `Db` instance. /// /// This function is the entry point and includes the initial list of /// channels to subscribe to. Additional `subscribe` and `unsubscribe` diff --git a/src/db.rs b/src/db.rs index 5a4858a..102f288 100644 --- a/src/db.rs +++ b/src/db.rs @@ -5,32 +5,77 @@ use bytes::Bytes; use std::collections::{BTreeMap, HashMap}; use std::sync::{Arc, Mutex}; +/// Server state shared across all connections. +/// +/// `Db` contains a `HashMap` storing the key/value data and all +/// `broadcast::Sender` values for active pub/sub channels. +/// +/// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and +/// only incurs an atomic ref count increment. +/// +/// When a `Db` value is created, a background task is spawned. This task is +/// used to expire values after the requested duration has elapsed. The task +/// runs until all instances of `Db` are dropped, at which point the task +/// terminates. #[derive(Debug, Clone)] pub(crate) struct Db { + /// Handle to shared state. The background task will also have an + /// `Arc`. 
shared: Arc<Shared>, } #[derive(Debug)] struct Shared { + /// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and + /// not a Tokio mutex. This is because there are no asynchronous operations + /// being performed while holding the mutex. Additionally, the critical + /// sections are very small. + /// + /// A Tokio mutex is mostly intended to be used when locks need to be held + /// across `.await` yield points. All other cases are **usually** best + /// served by a std mutex. If the critical section does not include any + /// async operations but is long (CPU intensive or performing blocking + /// operations), then the entire operation, including waiting for the mutex, + /// is considered a "blocking" operation and `tokio::task::spawn_blocking` + /// should be used. state: Mutex<State>, - /// Notifies the task handling entry expiration - expire_task: Notify, + /// Notifies the background task handling entry expiration. The background + /// task waits on this to be notified, then checks for expired values or the + /// shutdown signal. + background_task: Notify, } #[derive(Debug)] struct State { - /// The key-value data + /// The key-value data. We are not trying to do anything fancy so a + /// `std::collections::HashMap` works fine. entries: HashMap<String, Entry>, - /// The pub/sub key-space + /// The pub/sub key-space. Redis uses a **separate** key space for key-value + /// and pub/sub. `mini-redis` handles this by using a separate `HashMap`. pub_sub: HashMap<String, broadcast::Sender<Bytes>>, /// Tracks key TTLs. + /// + /// A `BTreeMap` is used to maintain expirations sorted by when they expire. + /// This allows the background task to iterate this map to find the value + /// expiring next. + /// + /// While highly unlikely, it is possible for more than one expiration to be + /// created for the same instant. Because of this, the `Instant` is + /// insufficient for the key. A unique expiration identifier (`u64`) is used + /// to break these ties.
expirations: BTreeMap<(Instant, u64), String>, - /// Identifier to use for the next expiration. + /// Identifier to use for the next expiration. Each expiration is associated + /// with a unique identifier. See above for why. next_id: u64, + + /// True when the Db instance is shutting down. This happens when all `Db` + /// values drop. Setting this to `true` signals to the background task to + /// exit. + shutdown: bool, } /// Entry in the key-value store @@ -48,6 +93,8 @@ struct Entry { } impl Db { + /// Create a new, empty, `Db` instance. Allocates shared state and spawns a + /// background task to manage key expiration. pub(crate) fn new() -> Db { let shared = Arc::new(Shared { state: Mutex::new(State { @@ -55,8 +102,9 @@ impl Db { pub_sub: HashMap::new(), expirations: BTreeMap::new(), next_id: 0, + shutdown: false, }), - expire_task: Notify::new(), + background_task: Notify::new(), }); // Start the background task. @@ -65,22 +113,41 @@ impl Db { Db { shared } } + /// Get the value associated with a key. + /// + /// Returns `None` if there is no value associated with the key. This may be + /// due to never having assigned a value to the key or a previously assigned + /// value expired. pub(crate) fn get(&self, key: &str) -> Option { + // Acquire the lock, get the entry and clone the value. + // + // Because data is stored using `Bytes`, a clone here is a shallow + // clone. Data is not copied. let state = self.shared.state.lock().unwrap(); state.entries.get(key).map(|entry| entry.data.clone()) } + /// Set the value associated with a key along with an optional expiration + /// Duration. + /// + /// If a value is already associated with the key, it is removed. pub(crate) fn set(&self, key: String, value: Bytes, expire: Option) { let mut state = self.shared.state.lock().unwrap(); - // Get and increment the next insertion ID. + // Get and increment the next insertion ID. 
Guarded by the lock, this + // ensures a unique identifier is associated with each `set` operation. let id = state.next_id; state.next_id += 1; - // By default, no notification is needed + // If this `set` becomes the key that expires **next**, the background + // task needs to be notified so it can update its state. + // + // Whether or not the task needs to be notified is computed during the + // `set` routine. let mut notify = false; let expires_at = expire.map(|duration| { + // `Instant` at which the key expires. let when = Instant::now() + duration; // Only notify the worker task if the newly inserted expiration is the @@ -91,11 +158,12 @@ impl Db { .map(|expiration| expiration > when) .unwrap_or(true); + // Track the expiration. state.expirations.insert((when, id), key.clone()); when }); - // Insert the entry. + // Insert the entry into the `HashMap`. let prev = state.entries.insert( key, Entry { @@ -105,6 +173,9 @@ impl Db { }, ); + // If there was a value previously associated with the key **and** it + // had an expiration time, the associated entry in the `expirations` map + // must also be removed. This avoids leaking data. if let Some(prev) = prev { if let Some(when) = prev.expires_at { // clear expiration @@ -112,22 +183,45 @@ impl Db { } } + // Release the mutex before notifying the background task. This helps + // reduce contention by avoiding the background task waking up only to + // be unable to acquire the mutex due to this function still holding it. drop(state); if notify { - self.shared.expire_task.notify(); + // Finally, only notify the background task if it needs to update + // its state to reflect a new expiration. + self.shared.background_task.notify(); } } + /// Returns a `Receiver` for the requested channel. + /// + /// The returned `Receiver` is used to receive values broadcast by `PUBLISH` + /// commands.
pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> { use std::collections::hash_map::Entry; + // Acquire the mutex let mut state = self.shared.state.lock().unwrap(); + // If there is no entry for the requested channel, then create a new + // broadcast channel and associate it with the key. If one already + // exists, return an associated receiver. match state.pub_sub.entry(key) { Entry::Occupied(e) => e.get().subscribe(), Entry::Vacant(e) => { - let (tx, rx) = broadcast::channel(1028); + // No broadcast channel exists yet, so create one. + // + // The channel is created with a capacity of `1024` messages. A + // message is stored in the channel until **all** subscribers + // have seen it. This means that a slow subscriber could result + // in messages being held indefinitely. + // + // When the channel's capacity fills up, publishing will result + // in old messages being dropped. This prevents slow consumers + // from blocking the entire system. + let (tx, rx) = broadcast::channel(1024); e.insert(tx); rx } @@ -152,10 +246,41 @@ impl Db { } } +impl Drop for Db { + fn drop(&mut self) { + // If this is the last active `Db` instance, the background task must be + // notified to shut down. + // + // First, determine if this is the last `Db` instance. This is done by + // checking `strong_count`. The count will be 2. One for this `Db` + // instance and one for the handle held by the background task. + if Arc::strong_count(&self.shared) == 2 { + // The background task must be signaled to shutdown. This is done by + // setting `State::shutdown` to `true` and signalling the task. + let mut state = self.shared.state.lock().unwrap(); + state.shutdown = true; + + // Drop the lock before signalling the background task. This helps + // reduce lock contention by ensuring the background task doesn't + // wake up only to be unable to acquire the mutex.
+ drop(state); + self.shared.background_task.notify(); + } + } +} + impl Shared { + /// Purge all expired keys and return the `Instant` at which the **next** + /// key will expire. The background task will sleep until this instant. fn purge_expired_keys(&self) -> Option { let mut state = self.state.lock().unwrap(); + if state.shutdown { + // The database is shutting down. All handles to the shared state + // have dropped. The background task should exit. + return None; + } + // This is needed to make the borrow checker happy. In short, `lock()` // returns a `MutexGuard` and not a `&mut State`. The borrow checker is // not able to see "through" the mutex guard and determine that it is @@ -180,6 +305,14 @@ impl Shared { None } + + /// Returns `true` if the database is shutting down + /// + /// The `shutdown` flag is set when all `Db` values have dropped, indicating + /// that the shared state can no longer be accessed. + fn is_shutdown(&self) -> bool { + self.state.lock().unwrap().shutdown + } } impl State { @@ -191,18 +324,29 @@ impl State { } } +/// Routine executed by the background task. +/// +/// Wait to be notified. On notification, purge any expired keys from the shared +/// state handle. If `shutdown` is set, terminate the task. async fn purge_expired_tasks(shared: Arc) { - loop { + // If the shutdown flag is set, then the task should exit. + while !shared.is_shutdown() { // Purge all keys that are expired. The function returns the instant at // which the **next** key will expire. The worker should wait until the // instant has passed then purge again. if let Some(when) = shared.purge_expired_keys() { + // Wait until the next key expires **or** until the background task + // is notified. If the task is notified, then it must reload its + // state as new keys have been set to expire early. This is done by + // looping. tokio::select! 
{ _ = time::delay_until(when) => {} - _ = shared.expire_task.notified() => {} + _ = shared.background_task.notified() => {} } } else { - shared.expire_task.notified().await; + // There are no keys expiring in the future. Wait until the task is + // notified. + shared.background_task.notified().await; } } }