Fix race condition in shutdown of background task (#81)
This commit is contained in:
@@ -24,7 +24,9 @@ pub async fn main() -> mini_redis::Result<()> {
|
|||||||
// Bind a TCP listener
|
// Bind a TCP listener
|
||||||
let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;
|
let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?;
|
||||||
|
|
||||||
server::run(listener, signal::ctrl_c()).await
|
server::run(listener, signal::ctrl_c()).await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(StructOpt, Debug)]
|
#[derive(StructOpt, Debug)]
|
||||||
|
|||||||
66
src/db.rs
66
src/db.rs
@@ -4,6 +4,17 @@ use tokio::time::{self, Duration, Instant};
|
|||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use std::collections::{BTreeMap, HashMap};
|
use std::collections::{BTreeMap, HashMap};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
/// A wrapper around a `Db` instance. This exists to allow orderly cleanup
|
||||||
|
/// of the `Db` by signalling the background purge task to shut down when
|
||||||
|
/// this struct is dropped.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct DbDropGuard {
|
||||||
|
/// The `Db` instance that will be shut down when this `DbHolder` struct
|
||||||
|
/// is dropped.
|
||||||
|
db: Db,
|
||||||
|
}
|
||||||
|
|
||||||
/// Server state shared across all connections.
|
/// Server state shared across all connections.
|
||||||
///
|
///
|
||||||
@@ -92,6 +103,27 @@ struct Entry {
|
|||||||
expires_at: Option<Instant>,
|
expires_at: Option<Instant>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl DbDropGuard {
|
||||||
|
/// Create a new `DbHolder`, wrapping a `Db` instance. When this is dropped
|
||||||
|
/// the `Db`'s purge task will be shut down.
|
||||||
|
pub(crate) fn new() -> DbDropGuard {
|
||||||
|
DbDropGuard { db: Db::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the shared database. Internally, this is an
|
||||||
|
/// `Arc`, so a clone only increments the ref count.
|
||||||
|
pub(crate) fn db(&self) -> Db {
|
||||||
|
self.db.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DbDropGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// Signal the 'Db' instance to shut down the task that purges expired keys
|
||||||
|
self.db.shutdown_purge_task();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Db {
|
impl Db {
|
||||||
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
|
/// Create a new, empty, `Db` instance. Allocates shared state and spawns a
|
||||||
/// background task to manage key expiration.
|
/// background task to manage key expiration.
|
||||||
@@ -244,28 +276,20 @@ impl Db {
|
|||||||
// subscribers. In this case, return `0`.
|
// subscribers. In this case, return `0`.
|
||||||
.unwrap_or(0)
|
.unwrap_or(0)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for Db {
|
/// Signals the purge background task to shut down. This is called by the
|
||||||
fn drop(&mut self) {
|
/// `DbShutdown`s `Drop` implementation.
|
||||||
// If this is the last active `Db` instance, the background task must be
|
fn shutdown_purge_task(&self) {
|
||||||
// notified to shut down.
|
// The background task must be signaled to shut down. This is done by
|
||||||
//
|
// setting `State::shutdown` to `true` and signalling the task.
|
||||||
// First, determine if this is the last `Db` instance. This is done by
|
let mut state = self.shared.state.lock().unwrap();
|
||||||
// checking `strong_count`. The count will be 2. One for this `Db`
|
state.shutdown = true;
|
||||||
// instance and one for the handle held by the background task.
|
|
||||||
if Arc::strong_count(&self.shared) == 2 {
|
|
||||||
// The background task must be signaled to shutdown. This is done by
|
|
||||||
// setting `State::shutdown` to `true` and signalling the task.
|
|
||||||
let mut state = self.shared.state.lock().unwrap();
|
|
||||||
state.shutdown = true;
|
|
||||||
|
|
||||||
// Drop the lock before signalling the background task. This helps
|
// Drop the lock before signalling the background task. This helps
|
||||||
// reduce lock contention by ensuring the background task doesn't
|
// reduce lock contention by ensuring the background task doesn't
|
||||||
// wake up only to be unable to acquire the mutex.
|
// wake up only to be unable to acquire the mutex.
|
||||||
drop(state);
|
drop(state);
|
||||||
self.shared.background_task.notify_one();
|
self.shared.background_task.notify_one();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,4 +373,6 @@ async fn purge_expired_tasks(shared: Arc<Shared>) {
|
|||||||
shared.background_task.notified().await;
|
shared.background_task.notified().await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug!("Purge background task shut down")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ pub use frame::Frame;
|
|||||||
|
|
||||||
mod db;
|
mod db;
|
||||||
use db::Db;
|
use db::Db;
|
||||||
|
use db::DbDropGuard;
|
||||||
|
|
||||||
mod parse;
|
mod parse;
|
||||||
use parse::{Parse, ParseError};
|
use parse::{Parse, ParseError};
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
//! Provides an async `run` function that listens for inbound connections,
|
//! Provides an async `run` function that listens for inbound connections,
|
||||||
//! spawning a task per connection.
|
//! spawning a task per connection.
|
||||||
|
|
||||||
use crate::{Command, Connection, Db, Shutdown};
|
use crate::{Command, Connection, Db, DbDropGuard, Shutdown};
|
||||||
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -21,9 +21,9 @@ struct Listener {
|
|||||||
/// Contains the key / value store as well as the broadcast channels for
|
/// Contains the key / value store as well as the broadcast channels for
|
||||||
/// pub/sub.
|
/// pub/sub.
|
||||||
///
|
///
|
||||||
/// This is a wrapper around an `Arc`. This enables `db` to be cloned and
|
/// This holds a wrapper around an `Arc`. The internal `Db` can be
|
||||||
/// passed into the per connection state (`Handler`).
|
/// retrieved and passed into the per connection state (`Handler`).
|
||||||
db: Db,
|
db_holder: DbDropGuard,
|
||||||
|
|
||||||
/// TCP listener supplied by the `run` caller.
|
/// TCP listener supplied by the `run` caller.
|
||||||
listener: TcpListener,
|
listener: TcpListener,
|
||||||
@@ -128,7 +128,7 @@ const MAX_CONNECTIONS: usize = 250;
|
|||||||
///
|
///
|
||||||
/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
|
/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will
|
||||||
/// listen for a SIGINT signal.
|
/// listen for a SIGINT signal.
|
||||||
pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
|
pub async fn run(listener: TcpListener, shutdown: impl Future) {
|
||||||
// When the provided `shutdown` future completes, we must send a shutdown
|
// When the provided `shutdown` future completes, we must send a shutdown
|
||||||
// message to all active connections. We use a broadcast channel for this
|
// message to all active connections. We use a broadcast channel for this
|
||||||
// purpose. The call below ignores the receiver of the broadcast pair, and when
|
// purpose. The call below ignores the receiver of the broadcast pair, and when
|
||||||
@@ -140,7 +140,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
|
|||||||
// Initialize the listener state
|
// Initialize the listener state
|
||||||
let mut server = Listener {
|
let mut server = Listener {
|
||||||
listener,
|
listener,
|
||||||
db: Db::new(),
|
db_holder: DbDropGuard::new(),
|
||||||
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
|
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
|
||||||
notify_shutdown,
|
notify_shutdown,
|
||||||
shutdown_complete_tx,
|
shutdown_complete_tx,
|
||||||
@@ -193,6 +193,7 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
|
|||||||
notify_shutdown,
|
notify_shutdown,
|
||||||
..
|
..
|
||||||
} = server;
|
} = server;
|
||||||
|
|
||||||
// When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
|
// When `notify_shutdown` is dropped, all tasks which have `subscribe`d will
|
||||||
// receive the shutdown signal and can exit
|
// receive the shutdown signal and can exit
|
||||||
drop(notify_shutdown);
|
drop(notify_shutdown);
|
||||||
@@ -204,8 +205,6 @@ pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<
|
|||||||
// `Sender` instances are held by connection handler tasks. When those drop,
|
// `Sender` instances are held by connection handler tasks. When those drop,
|
||||||
// the `mpsc` channel will close and `recv()` will return `None`.
|
// the `mpsc` channel will close and `recv()` will return `None`.
|
||||||
let _ = shutdown_complete_rx.recv().await;
|
let _ = shutdown_complete_rx.recv().await;
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Listener {
|
impl Listener {
|
||||||
@@ -250,9 +249,8 @@ impl Listener {
|
|||||||
|
|
||||||
// Create the necessary per-connection handler state.
|
// Create the necessary per-connection handler state.
|
||||||
let mut handler = Handler {
|
let mut handler = Handler {
|
||||||
// Get a handle to the shared database. Internally, this is an
|
// Get a handle to the shared database.
|
||||||
// `Arc`, so a clone only increments the ref count.
|
db: self.db_holder.db(),
|
||||||
db: self.db.clone(),
|
|
||||||
|
|
||||||
// Initialize the connection state. This allocates read/write
|
// Initialize the connection state. This allocates read/write
|
||||||
// buffers to perform redis protocol frame parsing.
|
// buffers to perform redis protocol frame parsing.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ async fn pool_key_value_get_set() {
|
|||||||
assert_eq!(b"world", &value[..])
|
assert_eq!(b"world", &value[..])
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
|
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ async fn unsubscribes_from_channels() {
|
|||||||
assert_eq!(subscriber.get_subscribed().len(), 0);
|
assert_eq!(subscriber.get_subscribed().len(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_server() -> (SocketAddr, JoinHandle<mini_redis::Result<()>>) {
|
async fn start_server() -> (SocketAddr, JoinHandle<()>) {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user