server: implement key expiration (#13)
This commit is contained in:
201
src/db.rs
Normal file
201
src/db.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use tokio::sync::{broadcast, Notify};
|
||||
use tokio::time::{self, Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Db {
|
||||
shared: Arc<Shared>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Shared {
|
||||
state: Mutex<State>,
|
||||
|
||||
/// Notifies the task handling entry expiration
|
||||
expire_task: Notify,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// The key-value data
|
||||
entries: HashMap<String, Entry>,
|
||||
|
||||
/// The pub/sub key-space
|
||||
pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
|
||||
|
||||
/// Tracks key TTLs.
|
||||
expirations: BTreeMap<(Instant, u64), String>,
|
||||
|
||||
/// Identifier to use for the next expiration.
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
/// Entry in the key-value store
|
||||
#[derive(Debug)]
|
||||
struct Entry {
|
||||
/// Uniquely identifies this entry.
|
||||
id: u64,
|
||||
|
||||
/// Stored data
|
||||
data: Bytes,
|
||||
|
||||
/// Instant at which the entry expires and should be removed from the
|
||||
/// database.
|
||||
expires_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Db {
|
||||
pub(crate) fn new() -> Db {
|
||||
let shared = Arc::new(Shared {
|
||||
state: Mutex::new(State {
|
||||
entries: HashMap::new(),
|
||||
pub_sub: HashMap::new(),
|
||||
expirations: BTreeMap::new(),
|
||||
next_id: 0,
|
||||
}),
|
||||
expire_task: Notify::new(),
|
||||
});
|
||||
|
||||
// Start the background task.
|
||||
tokio::spawn(purge_expired_tasks(shared.clone()));
|
||||
|
||||
Db { shared }
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, key: &str) -> Option<Bytes> {
|
||||
let state = self.shared.state.lock().unwrap();
|
||||
state.entries.get(key).map(|entry| entry.data.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
|
||||
// Get and increment the next insertion ID.
|
||||
let id = state.next_id;
|
||||
state.next_id += 1;
|
||||
|
||||
// By default, no notification is needed
|
||||
let mut notify = false;
|
||||
|
||||
let expires_at = expire.map(|duration| {
|
||||
let when = Instant::now() + duration;
|
||||
|
||||
// Only notify the worker task if the newly inserted expiration is the
|
||||
// **next** key to evict. In this case, the worker needs to be woken up
|
||||
// to update its state.
|
||||
notify = state.next_expiration()
|
||||
.map(|expiration| expiration > when)
|
||||
.unwrap_or(true);
|
||||
|
||||
state.expirations.insert((when, id), key.clone());
|
||||
when
|
||||
});
|
||||
|
||||
// Insert the entry.
|
||||
let prev = state.entries.insert(key, Entry {
|
||||
id,
|
||||
data: value,
|
||||
expires_at,
|
||||
});
|
||||
|
||||
if let Some(prev) = prev {
|
||||
if let Some(when) = prev.expires_at {
|
||||
// clear expiration
|
||||
state.expirations.remove(&(when, prev.id));
|
||||
}
|
||||
}
|
||||
|
||||
drop(state);
|
||||
|
||||
if notify {
|
||||
self.shared.expire_task.notify();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
|
||||
use std::collections::hash_map::Entry;
|
||||
|
||||
let mut state = self.shared.state.lock().unwrap();
|
||||
|
||||
match state.pub_sub.entry(key) {
|
||||
Entry::Occupied(e) => e.get().subscribe(),
|
||||
Entry::Vacant(e) => {
|
||||
let (tx, rx) = broadcast::channel(1028);
|
||||
e.insert(tx);
|
||||
rx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish a message to the channel. Returns the number of subscribers
|
||||
/// listening on the channel.
|
||||
pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize {
|
||||
let state = self.shared.state.lock().unwrap();
|
||||
|
||||
state
|
||||
.pub_sub
|
||||
.get(key)
|
||||
// On a successful message send on the broadcast channel, the number
|
||||
// of subscribers is returned. An error indicates there are no
|
||||
// receivers, in which case, `0` should be returned.
|
||||
.map(|tx| tx.send(value).unwrap_or(0))
|
||||
// If there is no entry for the channel key, then there are no
|
||||
// subscribers. In this case, return `0`.
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Shared {
|
||||
fn purge_expired_keys(&self) -> Option<Instant> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
// This is needed to make the borrow checker happy. In short, `lock()`
|
||||
// returns a `MutexGuard` and not a `&mut State`. The borrow checker is
|
||||
// not able to see "through" the mutex guard and determine that it is
|
||||
// safe to access both `state.expirations` and `state.entries` mutably,
|
||||
// so we get a "real" mutable reference to `State` outside of the loop.
|
||||
let state = &mut *state;
|
||||
|
||||
// Find all keys scheduled to expire **before** now.
|
||||
let now = Instant::now();
|
||||
|
||||
while let Some((&(when, id), key)) = state.expirations.iter().next() {
|
||||
if when > now {
|
||||
// Done purging, `when` is the instant at which the next key
|
||||
// expires. The worker task will wait until this instant.
|
||||
return Some(when);
|
||||
}
|
||||
|
||||
// The key expired, remove it
|
||||
state.entries.remove(key);
|
||||
state.expirations.remove(&(when, id));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn next_expiration(&self) -> Option<Instant> {
|
||||
self.expirations.keys().next().map(|expiration| expiration.0)
|
||||
}
|
||||
}
|
||||
|
||||
async fn purge_expired_tasks(shared: Arc<Shared>) {
|
||||
loop {
|
||||
// Purge all keys that are expired. The function returns the instant at
|
||||
// which the **next** key will expire. The worker should wait until the
|
||||
// instant has passed then purge again.
|
||||
if let Some(when) = shared.purge_expired_keys() {
|
||||
tokio::select! {
|
||||
_ = time::delay_until(when) => {}
|
||||
_ = shared.expire_task.notified() => {}
|
||||
}
|
||||
} else {
|
||||
shared.expire_task.notified().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user