mostly docs, some code tweaks as well (#31)
Db background tasks never shutdown o_O
This commit is contained in:
@@ -112,6 +112,11 @@ the server to update the active subscriptions.
|
||||
[broadcast]: https://docs.rs/tokio/*/tokio/sync/broadcast/index.html
|
||||
[`StreamMap`]: https://docs.rs/tokio/*/tokio/stream/struct.StreamMap.html
|
||||
|
||||
### Using a `std::sync::Mutex` in an async application
|
||||
|
||||
The server uses a `std::sync::Mutex` and **not** a Tokio mutex to synchronize
|
||||
access to shared state. See [`db.rs`](src/db.rs) for more details.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions to `mini-redis` are welcome. Keep in mind, the goal of the project
|
||||
|
||||
@@ -47,12 +47,7 @@ impl Client {
|
||||
/// Set the value of a key to `value`.
|
||||
#[instrument(skip(self))]
|
||||
pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> {
|
||||
self.set_cmd(Set {
|
||||
key: key.to_string(),
|
||||
value: value,
|
||||
expire: None,
|
||||
})
|
||||
.await
|
||||
self.set_cmd(Set::new(key, value, None)).await
|
||||
}
|
||||
|
||||
/// publish `message` on the `channel`
|
||||
@@ -88,12 +83,7 @@ impl Client {
|
||||
value: Bytes,
|
||||
expiration: Duration,
|
||||
) -> crate::Result<()> {
|
||||
self.set_cmd(Set {
|
||||
key: key.to_string(),
|
||||
value: value.into(),
|
||||
expire: Some(expiration),
|
||||
})
|
||||
.await
|
||||
self.set_cmd(Set::new(key, value, Some(expiration))).await
|
||||
}
|
||||
|
||||
async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> {
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
use crate::{Connection, Db, Frame, Parse, ParseError};
|
||||
use crate::{Connection, Db, Frame, Parse};
|
||||
|
||||
use bytes::Bytes;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
/// Get the value of key.
|
||||
///
|
||||
/// If the key does not exist the special value nil is returned. An error is
|
||||
/// returned if the value stored at key is not a string, because GET only
|
||||
/// handles string values.
|
||||
#[derive(Debug)]
|
||||
pub struct Get {
|
||||
/// Name of the key to get
|
||||
key: String,
|
||||
}
|
||||
|
||||
@@ -14,36 +20,63 @@ impl Get {
|
||||
Get { key: key.to_string() }
|
||||
}
|
||||
|
||||
// instrumenting functions will log all of the arguments passed to the function
|
||||
// with their debug implementations
|
||||
// see https://docs.rs/tracing/0.1.13/tracing/attr.instrument.html
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Get, ParseError> {
|
||||
/// Parse a `Get` instance from received data.
|
||||
///
|
||||
/// The `Parse` argument provides a cursor like API to read fields from a
|
||||
/// received `Frame`. At this point, the data has already been received from
|
||||
/// the socket.
|
||||
///
|
||||
/// The `GET` string has already been consumed.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns the `Get` value on success. If the frame is malformed, `Err` is
|
||||
/// returned.
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// Expects an array frame containing two entries.
|
||||
///
|
||||
/// ```text
|
||||
/// GET key
|
||||
/// ```
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Get> {
|
||||
// The `GET` string has already been consumed. The next value is the
|
||||
// name of the key to get. If the next value is not a string or the
|
||||
// input is fully consumed, then an error is returned.
|
||||
let key = parse.next_string()?;
|
||||
|
||||
// adding this debug event allows us to see what key is parsed
|
||||
// the ? sigil tells `tracing` to use the `Debug` implementation
|
||||
// get parse events can be filtered by running
|
||||
// RUST_LOG=mini_redis::cmd::get[parse_frames]=debug cargo run --bin server
|
||||
// see https://docs.rs/tracing/0.1.13/tracing/#recording-fields
|
||||
debug!(?key);
|
||||
|
||||
Ok(Get { key })
|
||||
}
|
||||
|
||||
/// Apply the `Get` command to the specified `Db` instance.
|
||||
///
|
||||
/// The response is written to `dst`. This is called by the server in order
|
||||
/// to execute a received command.
|
||||
#[instrument(skip(self, db, dst))]
|
||||
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> {
|
||||
// Get the value from the shared database state
|
||||
let response = if let Some(value) = db.get(&self.key) {
|
||||
// If a value is present, it is written to the client in "bulk"
|
||||
// format.
|
||||
Frame::Bulk(value)
|
||||
} else {
|
||||
// If there is no value, `Null` is written.
|
||||
Frame::Null
|
||||
};
|
||||
|
||||
debug!(?response);
|
||||
|
||||
// Write the response back to the client
|
||||
dst.write_frame(&response).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Converts the command into an equivalent `Frame`.
|
||||
///
|
||||
/// This is called by the client when encoding a `Get` command to send to
|
||||
/// the server.
|
||||
pub(crate) fn into_frame(self) -> Frame {
|
||||
let mut frame = Frame::array();
|
||||
frame.push_bulk(Bytes::from("get".as_bytes()));
|
||||
|
||||
@@ -15,6 +15,9 @@ pub use unknown::Unknown;
|
||||
|
||||
use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown};
|
||||
|
||||
/// Enumeration of supported Redis commands.
|
||||
///
|
||||
/// Methods called on `Command` are delegated to the command implementation.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum Command {
|
||||
Get(Get),
|
||||
@@ -26,11 +29,29 @@ pub(crate) enum Command {
|
||||
}
|
||||
|
||||
impl Command {
|
||||
/// Parse a command from a received frame.
|
||||
///
|
||||
/// The `Frame` must represent a Redis command supported by `mini-redis` and
|
||||
/// be the array variant.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// On success, the command value is returned, otherwise, `Err` is returned.
|
||||
pub(crate) fn from_frame(frame: Frame) -> crate::Result<Command> {
|
||||
// The frame value is decorated with `Parse`. `Parse` provides a
|
||||
// "cursor" like API which makes parsing the command easier.
|
||||
//
|
||||
// The frame value must be an array variant. Any other frame variants
|
||||
// result in an error being returned.
|
||||
let mut parse = Parse::new(frame)?;
|
||||
|
||||
// All redis commands begin with the command name as a string. The name
|
||||
// is read and converted to lower casae in order to do case sensitive
|
||||
// matching.
|
||||
let command_name = parse.next_string()?.to_lowercase();
|
||||
|
||||
// Match the command name, delegating the rest of the parsing to the
|
||||
// specific command.
|
||||
let command = match &command_name[..] {
|
||||
"get" => Command::Get(Get::parse_frames(&mut parse)?),
|
||||
"publish" => Command::Publish(Publish::parse_frames(&mut parse)?),
|
||||
@@ -38,15 +59,29 @@ impl Command {
|
||||
"subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?),
|
||||
"unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?),
|
||||
_ => {
|
||||
parse.next_string()?;
|
||||
Command::Unknown(Unknown::new(command_name))
|
||||
// The command is not recognized and an Unknown command is
|
||||
// returned.
|
||||
//
|
||||
// `return` is called here to skip the `finish()` call below. As
|
||||
// the command is not recognized, there is most likely
|
||||
// unconsumed fields remaining in the `Parse` instance.
|
||||
return Ok(Command::Unknown(Unknown::new(command_name)));
|
||||
},
|
||||
};
|
||||
|
||||
// Check if there is any remaining unconsumed fields in the `Parse`
|
||||
// value. If fields remain, this indicates an unexpected frame format
|
||||
// and an error is returned.
|
||||
parse.finish()?;
|
||||
|
||||
// The command has been successfully parsed
|
||||
Ok(command)
|
||||
}
|
||||
|
||||
/// Apply the command to the specified `Db` instance.
|
||||
///
|
||||
/// The response is written to `dst`. This is called by the server in order
|
||||
/// to execute a received command.
|
||||
pub(crate) async fn apply(
|
||||
self,
|
||||
db: &Db,
|
||||
@@ -63,10 +98,11 @@ impl Command {
|
||||
Unknown(cmd) => cmd.apply(dst).await,
|
||||
// `Unsubscribe` cannot be applied. It may only be received from the
|
||||
// context of a `Subscribe` command.
|
||||
Unsubscribe(_) => unimplemented!(),
|
||||
Unsubscribe(_) => Err("`Unsubscribe` is unsupported in this context".into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the command name
|
||||
pub(crate) fn get_name(&self) -> &str {
|
||||
match self {
|
||||
Command::Get(_) => "get",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Connection, Db, Frame, Parse, ParseError};
|
||||
use crate::{Connection, Db, Frame, Parse};
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
@@ -9,13 +9,17 @@ pub struct Publish {
|
||||
}
|
||||
|
||||
impl Publish {
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Publish, ParseError> {
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Publish> {
|
||||
let channel = parse.next_string()?;
|
||||
let message = parse.next_bytes()?;
|
||||
|
||||
Ok(Publish { channel, message })
|
||||
}
|
||||
|
||||
/// Apply the `Publish` command to the specified `Db` instance.
|
||||
///
|
||||
/// The response is written to `dst`. This is called by the server in order
|
||||
/// to execute a received command.
|
||||
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> {
|
||||
// Set the value
|
||||
let num_subscribers = db.publish(&self.channel, self.message);
|
||||
|
||||
@@ -5,57 +5,127 @@ use bytes::Bytes;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
/// Set `key` to hold the string `value`.
|
||||
///
|
||||
/// If `key` already holds a value, it is overwritten, regardless of its type.
|
||||
/// Any previous time to live associated with the key is discarded on successful
|
||||
/// SET operation.
|
||||
///
|
||||
/// # Options
|
||||
///
|
||||
/// Currently, the following options are supported:
|
||||
///
|
||||
/// * EX `seconds` -- Set the specified expire time, in seconds.
|
||||
/// * PX `milliseconds` -- Set the specified expire time, in milliseconds.
|
||||
#[derive(Debug)]
|
||||
pub struct Set {
|
||||
/// the lookup key
|
||||
pub(crate) key: String,
|
||||
key: String,
|
||||
|
||||
/// the value to be stored
|
||||
pub(crate) value: Bytes,
|
||||
value: Bytes,
|
||||
|
||||
/// When to expire the key
|
||||
pub(crate) expire: Option<Duration>,
|
||||
expire: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Set {
|
||||
#[instrument]
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Set, ParseError> {
|
||||
/// Create a new `Set` command which sets `key` to `value`.
|
||||
///
|
||||
/// If `expire` is `Some`, the value should expire after the specified
|
||||
/// duration.
|
||||
pub(crate) fn new(key: impl ToString, value: Bytes, expire: Option<Duration>) -> Set {
|
||||
Set {
|
||||
key: key.to_string(),
|
||||
value,
|
||||
expire,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a `Set` instance from received data.
|
||||
///
|
||||
/// The `Parse` argument provides a cursor like API to read fields from a
|
||||
/// received `Frame`. At this point, the data has already been received from
|
||||
/// the socket.
|
||||
///
|
||||
/// The `SET` string has already been consumed.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns the `Set` value on success. If the frame is malformed, `Err` is
|
||||
/// returned.
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// Expects an array frame containing at least 3 entries.
|
||||
///
|
||||
/// ```text
|
||||
/// SET key value [EX seconds|PX milliseconds]
|
||||
/// ```
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Set> {
|
||||
use ParseError::EndOfStream;
|
||||
|
||||
// Read the key to set. This is a required field
|
||||
let key = parse.next_string()?;
|
||||
|
||||
// Read the value to set. This is a required field.
|
||||
let value = parse.next_bytes()?;
|
||||
|
||||
// The expiration is optional. If nothing else follows, then it is
|
||||
// `None`.
|
||||
let mut expire = None;
|
||||
|
||||
// Attempt to parse another string.
|
||||
match parse.next_string() {
|
||||
Ok(s) if s == "EX" => {
|
||||
// An expiration is specified in seconds. The next value is an
|
||||
// integer.
|
||||
let secs = parse.next_int()?;
|
||||
expire = Some(Duration::from_secs(secs));
|
||||
}
|
||||
Ok(s) if s == "PX" => {
|
||||
// An expiration is specified in milliseconds. The next value is
|
||||
// an integer.
|
||||
let ms = parse.next_int()?;
|
||||
expire = Some(Duration::from_millis(ms));
|
||||
}
|
||||
Ok(_) => unimplemented!(),
|
||||
// Currently, mini-redis does not support any of the other SET
|
||||
// options. An error here results in the connection being
|
||||
// terminated. Other connections will continue to operate normally.
|
||||
Ok(_) => return Err("currently `SET` only supports the expiration option".into()),
|
||||
// The `EndOfStream` error indicates there is no further data to
|
||||
// parse. In this case, it is a normal run time situation and
|
||||
// indicates there are no specified `SET` options.
|
||||
Err(EndOfStream) => {}
|
||||
Err(err) => return Err(err),
|
||||
// All other errors are bubbled up, resulting in the connection
|
||||
// being terminated.
|
||||
Err(err) => return Err(err.into()),
|
||||
}
|
||||
|
||||
debug!(?key, ?value, ?expire);
|
||||
|
||||
Ok(Set { key, value, expire })
|
||||
}
|
||||
|
||||
#[instrument(skip(db))]
|
||||
/// Apply the `Get` command to the specified `Db` instace.
|
||||
///
|
||||
/// The response is written to `dst`. This is called by the server in order
|
||||
/// to execute a received command.
|
||||
#[instrument(skip(self, db, dst))]
|
||||
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> {
|
||||
// Set the value
|
||||
// Set the value in the shared database state.
|
||||
db.set(self.key, self.value, self.expire);
|
||||
|
||||
// Create a success response and write it to `dst`.
|
||||
let response = Frame::Simple("OK".to_string());
|
||||
debug!(?response);
|
||||
dst.write_frame(&response).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Converts the command into an equivalent `Frame`.
|
||||
///
|
||||
/// This is called by the client when encoding a `Set` command to send to
|
||||
/// the server.
|
||||
pub(crate) fn into_frame(self) -> Frame {
|
||||
let mut frame = Frame::array();
|
||||
frame.push_bulk(Bytes::from("set".as_bytes()));
|
||||
|
||||
@@ -16,7 +16,7 @@ pub struct Unsubscribe {
|
||||
}
|
||||
|
||||
impl Subscribe {
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> Result<Subscribe, ParseError> {
|
||||
pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result<Subscribe> {
|
||||
use ParseError::EndOfStream;
|
||||
|
||||
// There must be at least one channel
|
||||
@@ -26,15 +26,14 @@ impl Subscribe {
|
||||
match parse.next_string() {
|
||||
Ok(s) => channels.push(s),
|
||||
Err(EndOfStream) => break,
|
||||
Err(err) => return Err(err),
|
||||
Err(err) => return Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Subscribe { channels })
|
||||
}
|
||||
|
||||
/// Implements the "subscribe" half of Redis' Pub/Sub feature documented
|
||||
/// [here].
|
||||
/// Apply the `Subscribe` command to the specified `Db` instance.
|
||||
///
|
||||
/// This function is the entry point and includes the initial list of
|
||||
/// channels to subscribe to. Additional `subscribe` and `unsubscribe`
|
||||
|
||||
172
src/db.rs
172
src/db.rs
@@ -5,32 +5,77 @@ use bytes::Bytes;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Server state shared across all connections.
|
||||
///
|
||||
/// `Db` contains a `HashMap` storing the key/value data and all
|
||||
/// `broadcast::Sender` values for active pub/sub channels.
|
||||
///
|
||||
/// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and
|
||||
/// only incurs an atomic ref count increment.
|
||||
///
|
||||
/// When a `Db` value is created, a background task is spawned. This task is
|
||||
/// used to expire values after the requested duration has elapsed. The task
|
||||
/// runs until all instances of `Db` are dropped, at which point the task
|
||||
/// terminates.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Db {
|
||||
/// Handle to shared state. The background task will also have an
|
||||
/// `Arc<Shared>`.
|
||||
shared: Arc<Shared>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Shared {
|
||||
/// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and
|
||||
/// not a Tokio mutex. This is because there are no asynchronous operations
|
||||
/// being performed while holding the mutex. Additionally, the critical
|
||||
/// sections are very small.
|
||||
///
|
||||
/// A Tokio mutex is mostly intended to be used when locks need to be held
|
||||
/// across `.await` yield points. All other cases are **usually** best
|
||||
/// served by a std mutex. If the critical section does not include any
|
||||
/// async operations but is long (CPU intensive or performing blocking
|
||||
/// operations), then the entire operation, including waiting for the mutex,
|
||||
/// is considered a "blocking" operation and `tokio::task::spawn_blocking`
|
||||
/// should be used.
|
||||
state: Mutex<State>,
|
||||
|
||||
/// Notifies the task handling entry expiration
|
||||
expire_task: Notify,
|
||||
/// Notifies the background task handling entry expiration. The background
|
||||
/// task waits on this to be notified, then checks for expired values or the
|
||||
/// shutdown signal.
|
||||
background_task: Notify,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// The key-value data
|
||||
/// The key-value data. We are not trying to do anything fancy so a
|
||||
/// `std::collections::HashMap` works fine.
|
||||
entries: HashMap<String, Entry>,
|
||||
|
||||
/// The pub/sub key-space
|
||||
/// The pub/sub key-space. Redis uses a **separate** key space for key-value
|
||||
/// and pub/sub. `mini-redis` handles this by using a separate `HashMap`.
|
||||
pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
|
||||
|
||||
/// Tracks key TTLs.
|
||||
///
|
||||
/// A `BTreeMap` is used to maintain expirations sorted by when they expire.
|
||||
/// This allows the background task to iterate this map to find the value
|
||||
/// expiring next.
|
||||
///
|
||||
/// While highly unlikely, it is possible for more than one expiration to be
|
||||
/// created for the same instant. Because of this, the `Instant` is
|
||||
/// insufficient for the key. A unique expiration identifier (`u64`) is used
|
||||
/// to break these ties.
|
||||
expirations: BTreeMap<(Instant, u64), String>,
|
||||
|
||||
/// Identifier to use for the next expiration.
|
||||
/// Identifier to use for the next expiration. Each expiration is associated
|
||||
/// with a unique identifier. See above for why.
|
||||
next_id: u64,
|
||||
|
||||
/// True when the Db instance is shutting down. This happens when all `Db`
|
||||
/// values drop. Setting this to `true` signals to the background task to
|
||||
/// exit.
|
||||
shutdown: bool,
|
||||
}
|
||||
|
||||
/// Entry in the key-value store
|
||||
@@ -48,6 +93,8 @@ struct Entry {
|
||||
}
|
||||
|
||||
impl Db {
|
||||
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
|
||||
/// background task to manage key expiration.
|
||||
pub(crate) fn new() -> Db {
|
||||
let shared = Arc::new(Shared {
|
||||
state: Mutex::new(State {
|
||||
@@ -55,8 +102,9 @@ impl Db {
|
||||
pub_sub: HashMap::new(),
|
||||
expirations: BTreeMap::new(),
|
||||
next_id: 0,
|
||||
shutdown: false,
|
||||
}),
|
||||
expire_task: Notify::new(),
|
||||
background_task: Notify::new(),
|
||||
});
|
||||
|
||||
// Start the background task.
|
||||
@@ -65,22 +113,41 @@ impl Db {
|
||||
Db { shared }
|
||||
}
|
||||
|
||||
/// Get the value associated with a key.
|
||||
///
|
||||
/// Returns `None` if there is no value associated with the key. This may be
|
||||
/// due to never having assigned a value to the key or a previously assigned
|
||||
/// value expired.
|
||||
pub(crate) fn get(&self, key: &str) -> Option<Bytes> {
|
||||
// Acquire the lock, get the entry and clone the value.
|
||||
//
|
||||
// Because data is stored using `Bytes`, a clone here is a shallow
|
||||
// clone. Data is not copied.
|
||||
let state = self.shared.state.lock().unwrap();
|
||||
state.entries.get(key).map(|entry| entry.data.clone())
|
||||
}
|
||||
|
||||
/// Set the value associated with a key along with an optional expiration
|
||||
/// Duration.
|
||||
///
|
||||
/// If a value is already associated with the key, it is removed.
|
||||
pub(crate) fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
|
||||
// Get and increment the next insertion ID.
|
||||
// Get and increment the next insertion ID. Guarded by the lock, this
|
||||
// ensures a unique identifier is associated with each `set` operation.
|
||||
let id = state.next_id;
|
||||
state.next_id += 1;
|
||||
|
||||
// By default, no notification is needed
|
||||
// If this `set` becomes the key that expires **next**, the background
|
||||
// task needs to be notified so it can update its state.
|
||||
//
|
||||
// Whether or not the task needs to be notified is computed during the
|
||||
// `set` routine.
|
||||
let mut notify = false;
|
||||
|
||||
let expires_at = expire.map(|duration| {
|
||||
// `Instant` at which the key expires.
|
||||
let when = Instant::now() + duration;
|
||||
|
||||
// Only notify the worker task if the newly inserted expiration is the
|
||||
@@ -91,11 +158,12 @@ impl Db {
|
||||
.map(|expiration| expiration > when)
|
||||
.unwrap_or(true);
|
||||
|
||||
// Track the expiration.
|
||||
state.expirations.insert((when, id), key.clone());
|
||||
when
|
||||
});
|
||||
|
||||
// Insert the entry.
|
||||
// Insert the entry into the `HashMap`.
|
||||
let prev = state.entries.insert(
|
||||
key,
|
||||
Entry {
|
||||
@@ -105,6 +173,9 @@ impl Db {
|
||||
},
|
||||
);
|
||||
|
||||
// If there was a value previously associated with the key **and** it
|
||||
// had an expiration time. The associated entry in the `expirations` map
|
||||
// must also be removed. This avoids leaking data.
|
||||
if let Some(prev) = prev {
|
||||
if let Some(when) = prev.expires_at {
|
||||
// clear expiration
|
||||
@@ -112,22 +183,45 @@ impl Db {
|
||||
}
|
||||
}
|
||||
|
||||
// Release the mutex before notifying the background task. This helps
|
||||
// reduce contention by avoiding the background task waking up only to
|
||||
// be unable to acquire the mutex due to this function still holding it.
|
||||
drop(state);
|
||||
|
||||
if notify {
|
||||
self.shared.expire_task.notify();
|
||||
// Finally, only notify the background task if it needs to update
|
||||
// its state to reflect a new expiration.
|
||||
self.shared.background_task.notify();
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a `Receiver` for the requested channel.
|
||||
///
|
||||
/// The returned `Receiver` is used to receive values broadcast by `PUBLISH`
|
||||
/// commands.
|
||||
pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
|
||||
use std::collections::hash_map::Entry;
|
||||
|
||||
// Acquire the mutex
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
|
||||
// If there is no entry for the requested channel, then create a new
|
||||
// broadcast channel and associate it with the key. If one already
|
||||
// exists, return an associated receiver.
|
||||
match state.pub_sub.entry(key) {
|
||||
Entry::Occupied(e) => e.get().subscribe(),
|
||||
Entry::Vacant(e) => {
|
||||
let (tx, rx) = broadcast::channel(1028);
|
||||
// No broadcast channel exists yet, so create one.
|
||||
//
|
||||
// The channel is created with a capacity of `1024` messages. A
|
||||
// message is stored in the channel until **all** subscribers
|
||||
// have seen it. This means that a slow subscriber could result
|
||||
// in messages being held indefinitely.
|
||||
//
|
||||
// When the channel's capacity fills up, publishing will result
|
||||
// in old messages being dropped. This prevents slow consumers
|
||||
// from blocking the entire system.
|
||||
let (tx, rx) = broadcast::channel(1024);
|
||||
e.insert(tx);
|
||||
rx
|
||||
}
|
||||
@@ -152,10 +246,41 @@ impl Db {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Db {
|
||||
fn drop(&mut self) {
|
||||
// If this is the last active `Db` instance, the background task must be
|
||||
// notified to shut down.
|
||||
//
|
||||
// First, determine if this is the last `Db` instance. This is done by
|
||||
// checking `strong_count`. The count will be 2. One for this `Db`
|
||||
// intance and one for the handle held by the background task.
|
||||
if Arc::strong_count(&self.shared) == 2 {
|
||||
// The background task must be signaled to shutdown. This is done by
|
||||
// setting `State::shutdown` to `true` and signalling the task.
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
state.shutdown = true;
|
||||
|
||||
// Drop the lock before signalling the background task. This helps
|
||||
// reduce lock contention by ensuring the background task doesn't
|
||||
// wake up only to be unable to acquire the mutex.
|
||||
drop(state);
|
||||
self.shared.background_task.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Shared {
|
||||
/// Purge all expired keys and return the `Instant` at which the **next**
|
||||
/// key will expire. The background task will sleep until this instant.
|
||||
fn purge_expired_keys(&self) -> Option<Instant> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
if state.shutdown {
|
||||
// The database is shutting down. All handles to the shared state
|
||||
// have dropped. The background task should exit.
|
||||
return None;
|
||||
}
|
||||
|
||||
// This is needed to make the borrow checker happy. In short, `lock()`
|
||||
// returns a `MutexGuard` and not a `&mut State`. The borrow checker is
|
||||
// not able to see "through" the mutex guard and determine that it is
|
||||
@@ -180,6 +305,14 @@ impl Shared {
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns `true` if the database is shutting down
|
||||
///
|
||||
/// The `shutdown` flag is set when all `Db` values have dropped, indicating
|
||||
/// that the shared state can no longer be accessed.
|
||||
fn is_shutdown(&self) -> bool {
|
||||
self.state.lock().unwrap().shutdown
|
||||
}
|
||||
}
|
||||
|
||||
impl State {
|
||||
@@ -191,18 +324,29 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
/// Routine executed by the background task.
|
||||
///
|
||||
/// Wait to be notified. On notification, purge any expired keys from the shared
|
||||
/// state handle. If `shutdown` is set, terminate the task.
|
||||
async fn purge_expired_tasks(shared: Arc<Shared>) {
|
||||
loop {
|
||||
// If the shutdown flag is set, then the task should exit.
|
||||
while !shared.is_shutdown() {
|
||||
// Purge all keys that are expired. The function returns the instant at
|
||||
// which the **next** key will expire. The worker should wait until the
|
||||
// instant has passed then purge again.
|
||||
if let Some(when) = shared.purge_expired_keys() {
|
||||
// Wait until the next key expires **or** until the background task
|
||||
// is notified. If the task is notified, then it must reload its
|
||||
// state as new keys have been set to expire early. This is done by
|
||||
// looping.
|
||||
tokio::select! {
|
||||
_ = time::delay_until(when) => {}
|
||||
_ = shared.expire_task.notified() => {}
|
||||
_ = shared.background_task.notified() => {}
|
||||
}
|
||||
} else {
|
||||
shared.expire_task.notified().await;
|
||||
// There are no keys expiring in the future. Wait until the task is
|
||||
// notified.
|
||||
shared.background_task.notified().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user