Fix race condition in shutdown of background task (#81)
This commit is contained in:
@@ -24,7 +24,9 @@ pub async fn main() -> mini_redis::Result<()> {
|
||||
// Bind a TCP listener
|
||||
let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;
|
||||
|
||||
server::run(listener, signal::ctrl_c()).await
|
||||
server::run(listener, signal::ctrl_c()).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(StructOpt, Debug)]
|
||||
|
||||
66
src/db.rs
66
src/db.rs
@@ -4,6 +4,17 @@ use tokio::time::{self, Duration, Instant};
|
||||
use bytes::Bytes;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tracing::debug;
|
||||
|
||||
/// A wrapper around a `Db` instance. This exists to allow orderly cleanup
|
||||
/// of the `Db` by signalling the background purge task to shut down when
|
||||
/// this struct is dropped.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct DbDropGuard {
|
||||
/// The `Db` instance that will be shut down when this `DbHolder` struct
|
||||
/// is dropped.
|
||||
db: Db,
|
||||
}
|
||||
|
||||
/// Server state shared across all connections.
|
||||
///
|
||||
@@ -92,6 +103,27 @@ struct Entry {
|
||||
expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl DbDropGuard {
|
||||
/// Create a new `DbHolder`, wrapping a `Db` instance. When this is dropped
|
||||
/// the `Db`'s purge task will be shut down.
|
||||
pub(crate) fn new() -> DbDropGuard {
|
||||
DbDropGuard { db: Db::new() }
|
||||
}
|
||||
|
||||
/// Get the shared database. Internally, this is an
|
||||
/// `Arc`, so a clone only increments the ref count.
|
||||
pub(crate) fn db(&self) -> Db {
|
||||
self.db.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DbDropGuard {
|
||||
fn drop(&mut self) {
|
||||
// Signal the 'Db' instance to shut down the task that purges expired keys
|
||||
self.db.shutdown_purge_task();
|
||||
}
|
||||
}
|
||||
|
||||
impl Db {
|
||||
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
|
||||
/// background task to manage key expiration.
|
||||
@@ -244,28 +276,20 @@ impl Db {
|
||||
// subscribers. In this case, return `0`.
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Db {
|
||||
fn drop(&mut self) {
|
||||
// If this is the last active `Db` instance, the background task must be
|
||||
// notified to shut down.
|
||||
//
|
||||
// First, determine if this is the last `Db` instance. This is done by
|
||||
// checking `strong_count`. The count will be 2. One for this `Db`
|
||||
// instance and one for the handle held by the background task.
|
||||
if Arc::strong_count(&self.shared) == 2 {
|
||||
// The background task must be signaled to shutdown. This is done by
|
||||
// setting `State::shutdown` to `true` and signalling the task.
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
state.shutdown = true;
|
||||
/// Signals the purge background task to shut down. This is called by the
|
||||
/// `DbShutdown`s `Drop` implementation.
|
||||
fn shutdown_purge_task(&self) {
|
||||
// The background task must be signaled to shut down. This is done by
|
||||
// setting `State::shutdown` to `true` and signalling the task.
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
state.shutdown = true;
|
||||
|
||||
// Drop the lock before signalling the background task. This helps
|
||||
// reduce lock contention by ensuring the background task doesn't
|
||||
// wake up only to be unable to acquire the mutex.
|
||||
drop(state);
|
||||
self.shared.background_task.notify_one();
|
||||
}
|
||||
// Drop the lock before signalling the background task. This helps
|
||||
// reduce lock contention by ensuring the background task doesn't
|
||||
// wake up only to be unable to acquire the mutex.
|
||||
drop(state);
|
||||
self.shared.background_task.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,4 +373,6 @@ async fn purge_expired_tasks(shared: Arc<Shared>) {
|
||||
shared.background_task.notified().await;
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Purge background task shut down")
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ pub use frame::Frame;
|
||||
|
||||
mod db;
|
||||
use db::Db;
|
||||
use db::DbDropGuard;
|
||||
|
||||
mod parse;
|
||||
use parse::{Parse, ParseError};
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! Provides an async `run` function that listens for inbound connections,
|
||||
//! spawning a task per connection.
|
||||
|
||||
use crate::{Command, Connection, Db, Shutdown};
|
||||
use crate::{Command, Connection, Db, DbDropGuard, Shutdown};
|
||||
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
@@ -21,9 +21,9 @@ struct Listener {
|
||||
/// Contains the key / value store as well as the broadcast channels for
|
||||
/// pub/sub.
|
||||
///
|
||||
/// This is a wrapper around an `Arc`. This enables `db` to be cloned and
|
||||
/// passed into the per connection state (`Handler`).
|
||||
db: Db,
|
||||
/// This holds a wrapper around an `Arc`. The internal `Db` can be
|
||||
/// retrieved and passed into the per connection state (`Handler`).
|
||||
db_holder: DbDropGuard,
|
||||
|
||||
/// TCP listener supplied by the `run` caller.
|
||||
listener: TcpListener,
|
||||
@@ -128,7 +128,7 @@ const MAX_CONNECTIONS: usize = 250;
|
||||
///
|
||||
/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
|
||||
/// listen for a SIGINT signal.
|
||||
pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
|
||||
pub async fn run(listener: TcpListener, shutdown: impl Future) {
|
||||
// When the provided `shutdown` future completes, we must send a shutdown
|
||||
// message to all active connections. We use a broadcast channel for this
|
||||
// purpose. The call below ignores the receiver of the broadcast pair, and when
|
||||
@@ -140,7 +140,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
|
||||
// Initialize the listener state
|
||||
let mut server = Listener {
|
||||
listener,
|
||||
db: Db::new(),
|
||||
db_holder: DbDropGuard::new(),
|
||||
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
|
||||
notify_shutdown,
|
||||
shutdown_complete_tx,
|
||||
@@ -193,6 +193,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
|
||||
notify_shutdown,
|
||||
..
|
||||
} = server;
|
||||
|
||||
// When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
|
||||
// receive the shutdown signal and can exit
|
||||
drop(notify_shutdown);
|
||||
@@ -204,8 +205,6 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
|
||||
// `Sender` instances are held by connection handler tasks. When those drop,
|
||||
// the `mpsc` channel will close and `recv()` will return `None`.
|
||||
let _ = shutdown_complete_rx.recv().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
@@ -250,9 +249,8 @@ impl Listener {
|
||||
|
||||
// Create the necessary per-connection handler state.
|
||||
let mut handler = Handler {
|
||||
// Get a handle to the shared database. Internally, this is an
|
||||
// `Arc`, so a clone only increments the ref count.
|
||||
db: self.db.clone(),
|
||||
// Get a handle to the shared database.
|
||||
db: self.db_holder.db(),
|
||||
|
||||
// Initialize the connection state. This allocates read/write
|
||||
// buffers to perform redis protocol frame parsing.
|
||||
|
||||
@@ -20,7 +20,7 @@ async fn pool_key_value_get_set() {
|
||||
assert_eq!(b"world", &value[..])
|
||||
}
|
||||
|
||||
async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
|
||||
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ async fn unsubscribes_from_channels() {
|
||||
assert_eq!(subscriber.get_subscribed().len(), 0);
|
||||
}
|
||||
|
||||
async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
|
||||
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user