diff --git a/src/bin/server.rs b/src/bin/server.rs
index ca85288..2f76ad7 100644
--- a/src/bin/server.rs
+++ b/src/bin/server.rs
@@ -24,7 +24,9 @@ pub async fn main() -> mini_redis::Result<()> {
     // Bind a TCP listener
     let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;
 
-    server::run(listener, signal::ctrl_c()).await
+    server::run(listener, signal::ctrl_c()).await;
+
+    Ok(())
 }
 
 #[derive(StructOpt, Debug)]
diff --git a/src/db.rs b/src/db.rs
index 2e0a0c7..07e33a2 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -4,6 +4,17 @@ use tokio::time::{self, Duration, Instant};
 use bytes::Bytes;
 use std::collections::{BTreeMap, HashMap};
 use std::sync::{Arc, Mutex};
+use tracing::debug;
+
+/// A wrapper around a `Db` instance. This exists to allow orderly cleanup
+/// of the `Db` by signalling the background purge task to shut down when
+/// this struct is dropped.
+#[derive(Debug)]
+pub(crate) struct DbDropGuard {
+    /// The `Db` instance that will be shut down when this `DbDropGuard`
+    /// struct is dropped.
+    db: Db,
+}
 
 /// Server state shared across all connections.
 ///
@@ -92,6 +103,27 @@ struct Entry {
     expires_at: Option<Instant>,
 }
 
+impl DbDropGuard {
+    /// Create a new `DbDropGuard`, wrapping a `Db` instance. When this is
+    /// dropped the `Db`'s purge task will be shut down.
+    pub(crate) fn new() -> DbDropGuard {
+        DbDropGuard { db: Db::new() }
+    }
+
+    /// Get the shared database. Internally, this is an
+    /// `Arc`, so a clone only increments the ref count.
+    pub(crate) fn db(&self) -> Db {
+        self.db.clone()
+    }
+}
+
+impl Drop for DbDropGuard {
+    fn drop(&mut self) {
+        // Signal the `Db` instance to shut down the task that purges expired keys
+        self.db.shutdown_purge_task();
+    }
+}
+
 impl Db {
     /// Create a new, empty, `Db` instance. Allocates shared state and spawns a
     /// background task to manage key expiration.
@@ -244,28 +276,20 @@ impl Db {
             // subscribers. In this case, return `0`.
             .unwrap_or(0)
     }
-}
 
-impl Drop for Db {
-    fn drop(&mut self) {
-        // If this is the last active `Db` instance, the background task must be
-        // notified to shut down.
-        //
-        // First, determine if this is the last `Db` instance. This is done by
-        // checking `strong_count`. The count will be 2. One for this `Db`
-        // instance and one for the handle held by the background task.
-        if Arc::strong_count(&self.shared) == 2 {
-            // The background task must be signaled to shutdown. This is done by
-            // setting `State::shutdown` to `true` and signalling the task.
-            let mut state = self.shared.state.lock().unwrap();
-            state.shutdown = true;
+    /// Signals the purge background task to shut down. This is called by the
+    /// `DbDropGuard`'s `Drop` implementation.
+    fn shutdown_purge_task(&self) {
+        // The background task must be signaled to shut down. This is done by
+        // setting `State::shutdown` to `true` and signalling the task.
+        let mut state = self.shared.state.lock().unwrap();
+        state.shutdown = true;
 
-            // Drop the lock before signalling the background task. This helps
-            // reduce lock contention by ensuring the background task doesn't
-            // wake up only to be unable to acquire the mutex.
-            drop(state);
-            self.shared.background_task.notify_one();
-        }
+        // Drop the lock before signalling the background task. This helps
+        // reduce lock contention by ensuring the background task doesn't
+        // wake up only to be unable to acquire the mutex.
+        drop(state);
+        self.shared.background_task.notify_one();
     }
 }
@@ -349,4 +373,6 @@ async fn purge_expired_tasks(shared: Arc<Shared>) {
             shared.background_task.notified().await;
         }
     }
+
+    debug!("Purge background task shut down")
 }
diff --git a/src/lib.rs b/src/lib.rs
index e12588e..48c472b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -39,6 +39,7 @@ pub use frame::Frame;
 
 mod db;
 use db::Db;
+use db::DbDropGuard;
 
 mod parse;
 use parse::{Parse, ParseError};
diff --git a/src/server.rs b/src/server.rs
index e71dbcd..05dbed4 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -3,7 +3,7 @@
 //! Provides an async `run` function that listens for inbound connections,
 //! spawning a task per connection.
 
-use crate::{Command, Connection, Db, Shutdown};
+use crate::{Command, Connection, Db, DbDropGuard, Shutdown};
 use std::future::Future;
 use std::sync::Arc;
 
@@ -21,9 +21,9 @@ struct Listener {
     /// Contains the key / value store as well as the broadcast channels for
     /// pub/sub.
     ///
-    /// This is a wrapper around an `Arc`. This enables `db` to be cloned and
-    /// passed into the per connection state (`Handler`).
-    db: Db,
+    /// This holds a wrapper around an `Arc`. The internal `Db` can be
+    /// retrieved and passed into the per connection state (`Handler`).
+    db_holder: DbDropGuard,
 
     /// TCP listener supplied by the `run` caller.
     listener: TcpListener,
@@ -128,7 +128,7 @@ const MAX_CONNECTIONS: usize = 250;
 ///
 /// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
 /// listen for a SIGINT signal.
-pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
+pub async fn run(listener: TcpListener, shutdown: impl Future) {
     // When the provided `shutdown` future completes, we must send a shutdown
     // message to all active connections. We use a broadcast channel for this
     // purpose. The call below ignores the receiver of the broadcast pair, and when
@@ -140,7 +140,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
     // Initialize the listener state
     let mut server = Listener {
         listener,
-        db: Db::new(),
+        db_holder: DbDropGuard::new(),
         limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
         notify_shutdown,
         shutdown_complete_tx,
@@ -193,6 +193,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) {
         notify_shutdown,
         ..
     } = server;
+
     // When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
     // receive the shutdown signal and can exit
     drop(notify_shutdown);
@@ -204,8 +205,6 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) {
     // `Sender` instances are held by connection handler tasks. When those drop,
     // the `mpsc` channel will close and `recv()` will return `None`.
     let _ = shutdown_complete_rx.recv().await;
-
-    Ok(())
 }
 
 impl Listener {
@@ -250,9 +249,8 @@ impl Listener {
 
         // Create the necessary per-connection handler state.
         let mut handler = Handler {
-            // Get a handle to the shared database. Internally, this is an
-            // `Arc`, so a clone only increments the ref count.
-            db: self.db.clone(),
+            // Get a handle to the shared database.
+            db: self.db_holder.db(),
 
             // Initialize the connection state. This allocates read/write
             // buffers to perform redis protocol frame parsing.
diff --git a/tests/buffer.rs b/tests/buffer.rs
index 7b0d852..823b720 100644
--- a/tests/buffer.rs
+++ b/tests/buffer.rs
@@ -20,7 +20,7 @@ async fn pool_key_value_get_set() {
     assert_eq!(b"world", &value[..])
 }
 
-async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
+async fn start_server() -> (SocketAddr, JoinHandle<()>) {
     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
     let addr = listener.local_addr().unwrap();
 
diff --git a/tests/client.rs b/tests/client.rs
index fb19a5a..e2e7b42 100644
--- a/tests/client.rs
+++ b/tests/client.rs
@@ -82,7 +82,7 @@ async fn unsubscribes_from_channels() {
     assert_eq!(subscriber.get_subscribed().len(), 0);
 }
 
-async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
+async fn start_server() -> (SocketAddr, JoinHandle<()>) {
     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
     let addr = listener.local_addr().unwrap();
 
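The core of this patch is the move from an implicit shutdown (a `Drop` impl on `Db` that compared `Arc::strong_count` against a magic number) to an explicit `DbDropGuard` owned by the `Listener`. The strong-count heuristic was fragile: it assumed exactly two handles (the dropping `Db` plus the purge task's), so any extra long-lived clone would make the check never fire and the purge task would leak. Below is a minimal, self-contained sketch of the drop-guard pattern. It mirrors the patch's `shutdown_purge_task` and `background_task` names, but `Shared`, `State`, `purge_task`, and the `println!` logging are simplified illustrations, not mini-redis's actual code.

use std::sync::{Arc, Mutex};
use tokio::sync::Notify;

// Shared state: a shutdown flag behind a mutex, plus a `Notify` used to
// wake the background task when the flag changes.
struct Shared {
    state: Mutex<State>,
    background_task: Notify,
}

struct State {
    shutdown: bool,
}

// Cheaply cloneable handle; a clone only bumps the `Arc` ref count.
#[derive(Clone)]
struct Db {
    shared: Arc<Shared>,
}

impl Db {
    fn new() -> Db {
        let shared = Arc::new(Shared {
            state: Mutex::new(State { shutdown: false }),
            background_task: Notify::new(),
        });

        // The background task holds its own handle to the shared state.
        tokio::spawn(purge_task(shared.clone()));

        Db { shared }
    }

    fn shutdown_purge_task(&self) {
        let mut state = self.shared.state.lock().unwrap();
        state.shutdown = true;

        // Release the lock before notifying, as in the patched method, so
        // the woken task can acquire the mutex immediately.
        drop(state);
        self.shared.background_task.notify_one();
    }
}

// The guard owns the "primary" handle and signals shutdown on drop,
// no matter how many other `Db` clones are still alive.
struct DbDropGuard {
    db: Db,
}

impl Drop for DbDropGuard {
    fn drop(&mut self) {
        self.db.shutdown_purge_task();
    }
}

async fn purge_task(shared: Arc<Shared>) {
    // The mutex guard is a temporary inside the loop condition, so it is
    // released before the `.await` below ever runs.
    while !shared.state.lock().unwrap().shutdown {
        // `Notify` buffers one permit, so a `notify_one` sent before the
        // task reaches this point is not lost.
        shared.background_task.notified().await;
    }
    println!("purge task shut down");
}

#[tokio::main]
async fn main() {
    let holder = DbDropGuard { db: Db::new() };
    let extra = holder.db.clone(); // extra handles are cheap Arc bumps...
    drop(extra);                   // ...and may come and go freely.

    drop(holder); // signals the purge task to shut down

    // Give the purge task a moment to observe the flag (demo only).
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}

The guard makes the shutdown edge explicit: only dropping the `DbDropGuard` ends the purge task, which is why `Listener` stores `db_holder: DbDropGuard` and hands plain `Db` clones to each `Handler` via `db_holder.db()`.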
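One detail worth calling out from the comment retained in `shutdown_purge_task`: signalling after releasing the lock is safe even if the purge task is not currently parked, because `tokio::sync::Notify::notify_one` stores a single permit that the next `notified().await` consumes. This tiny standalone sketch (not from the patch) demonstrates that permit behavior.

use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());

    // Notify before anyone is waiting: the permit is stored...
    notify.notify_one();

    let waiter = {
        let notify = notify.clone();
        tokio::spawn(async move {
            // ...so this completes immediately instead of hanging.
            notify.notified().await;
        })
    };

    waiter.await.unwrap();
}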