diff --git a/src/cmd/subscribe.rs b/src/cmd/subscribe.rs index aa20170..1b53e4d 100644 --- a/src/cmd/subscribe.rs +++ b/src/cmd/subscribe.rs @@ -4,6 +4,7 @@ use crate::{Command, Connection, Db, Frame, Shutdown}; use bytes::Bytes; use tokio::select; use tokio::stream::{StreamExt, StreamMap}; +use tokio::sync::broadcast; /// Subscribes the client to one or more channels. /// @@ -112,21 +113,8 @@ impl Subscribe { // `self.channels` is used to track additional channels to subscribe // to. When new `SUBSCRIBE` commands are received during the // execution of `apply`, the new channels are pushed onto this vec. - for channel in self.channels.drain(..) { - // Build response frame to respond to the client with. - let mut response = Frame::array(); - response.push_bulk(Bytes::from_static(b"subscribe")); - response.push_bulk(Bytes::copy_from_slice(channel.as_bytes())); - response.push_int(subscriptions.len().saturating_add(1) as u64); - - // Subscribe to channel - let rx = db.subscribe(channel.clone()); - - // Track subscription in this client's subscription set. - subscriptions.insert(channel, rx); - - // Respond with the successful subscription - dst.write_frame(&response).await?; + for channel_name in self.channels.drain(..) { + subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?; } // Wait for one of the following to happen: @@ -136,7 +124,7 @@ impl Subscribe { // - A server shutdown signal. select! { // Receive messages from subscribed channels - Some((channel, msg)) = subscriptions.next() => { + Some((channel_name, msg)) = subscriptions.next() => { use tokio::sync::broadcast::RecvError; let msg = match msg { @@ -145,60 +133,22 @@ impl Subscribe { Err(RecvError::Closed) => unreachable!(), }; - let mut response = Frame::array(); - response.push_bulk(Bytes::from_static(b"message")); - response.push_bulk(Bytes::copy_from_slice(channel.as_bytes())); - response.push_bulk(msg); - - dst.write_frame(&response).await?; + dst.write_frame(&make_message_frame(channel_name, msg)).await?; } res = dst.read_frame() => { let frame = match res? { Some(frame) => frame, - // How to handle remote client closing write half? + // This happens if the remote client has disconnected. None => return Ok(()) }; - // A command has been received from the client. - // - // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted - // in this context. - match Command::from_frame(frame)? { - Command::Subscribe(subscribe) => { - // Subscribe to the channels on next iteration - self.channels.extend(subscribe.channels.into_iter()); - } - Command::Unsubscribe(mut unsubscribe) => { - // If no channels are specified, this requests - // unsubscribing from **all** channels. To implement - // this, the `unsubscribe.channels` vec is populated - // with the list of channels currently subscribed - // to. - if unsubscribe.channels.is_empty() { - unsubscribe.channels = subscriptions - .keys() - .map(|channel| channel.to_string()) - .collect(); - } - - for channel in unsubscribe.channels.drain(..) { - subscriptions.remove(&channel); - - let mut response = Frame::array(); - response.push_bulk(Bytes::from_static(b"unsubscribe")); - response.push_bulk(Bytes::copy_from_slice(channel.as_bytes())); - response.push_int(subscriptions.len() as u64); - - dst.write_frame(&response).await?; - } - } - command => { - let cmd = Unknown::new(command.get_name()); - cmd.apply(dst).await?; - } - } + handle_command( + frame, + &mut self.channels, + &mut subscriptions, + dst, + ).await?; } - // Receive additional commands from the client _ = shutdown.recv() => { return Ok(()); } @@ -220,6 +170,106 @@ impl Subscribe { } } +async fn subscribe_to_channel( + channel_name: String, + subscriptions: &mut StreamMap>, + db: &Db, + dst: &mut Connection, +) -> crate::Result<()> { + // Subscribe to the channel. + let rx = db.subscribe(channel_name.clone()); + + // Track subscription in this client's subscription set. + subscriptions.insert(channel_name.clone(), rx); + + // Respond with the successful subscription + let response = make_subscribe_frame(channel_name, subscriptions.len()); + dst.write_frame(&response).await?; + + Ok(()) +} + +/// Handle a command received while inside `Subscribe::apply`. Only subscribe +/// and unsubscribe commands are permitted in this context. +/// +/// Any new subscriptions are appended to `subscribe_to` instead of modifying +/// `subscriptions`. +async fn handle_command( + frame: Frame, + subscribe_to: &mut Vec, + subscriptions: &mut StreamMap>, + dst: &mut Connection, +) -> crate::Result<()> { + // A command has been received from the client. + // + // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted + // in this context. + match Command::from_frame(frame)? { + Command::Subscribe(subscribe) => { + // The `apply` method will subscribe to the channels we add to this + // vector. + subscribe_to.extend(subscribe.channels.into_iter()); + } + Command::Unsubscribe(mut unsubscribe) => { + // If no channels are specified, this requests unsubscribing from + // **all** channels. To implement this, the `unsubscribe.channels` + // vec is populated with the list of channels currently subscribed + // to. + if unsubscribe.channels.is_empty() { + unsubscribe.channels = subscriptions + .keys() + .map(|channel_name| channel_name.to_string()) + .collect(); + } + + for channel_name in unsubscribe.channels { + subscriptions.remove(&channel_name); + + let response = make_unsubscribe_frame(channel_name, subscriptions.len()); + dst.write_frame(&response).await?; + } + } + command => { + let cmd = Unknown::new(command.get_name()); + cmd.apply(dst).await?; + } + } + Ok(()) +} + +/// Creates the response to a subcribe request. +/// +/// All of these functions take the `channel_name` as a `String` instead of +/// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and +/// taking a `&str` would require copying the data. This allows the caller to +/// decide whether to clone the channel name or not. +fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame { + let mut response = Frame::array(); + response.push_bulk(Bytes::from_static(b"subscribe")); + response.push_bulk(Bytes::from(channel_name)); + response.push_int(num_subs as u64); + response +} + +/// Creates the response to an unsubcribe request. +fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame { + let mut response = Frame::array(); + response.push_bulk(Bytes::from_static(b"unsubscribe")); + response.push_bulk(Bytes::from(channel_name)); + response.push_int(num_subs as u64); + response +} + +/// Creates a message informing the client about a new message on a channel that +/// the client subscribes to. +fn make_message_frame(channel_name: String, msg: Bytes) -> Frame { + let mut response = Frame::array(); + response.push_bulk(Bytes::from_static(b"message")); + response.push_bulk(Bytes::from(channel_name)); + response.push_bulk(msg); + response +} + impl Unsubscribe { /// Create a new `Unsubscribe` command with the given `channels`. pub(crate) fn new(channels: &[String]) -> Unsubscribe {