Discord bots are cool, right? Haven't you ever wanted to make your own?
Well, today we're going to do exactly that, and this time we're using Rust! We're going to build a Discord bot that lets you set up polls and updates the vote counts in real time!
Depending on how comfortable you are with intermediate-to-advanced Rust, this workshop could take anywhere from 40 minutes to an hour to complete. Don't let that scare you, though! You'll learn a lot of new concepts about structuring complex applications in Rust along the way :)
Prerequisites
For this workshop, I recommend an intermediate understanding of low-level concepts such as memory management, as well as some experience with Rust. In particular, I recommend that you have a good understanding of these concepts:
- Lifetimes, borrowing and move semantics
- How Rust structures code (i.e. Cargo projects)
- The general idea behind macros (not the exact syntax but a good idea of their general purpose)
- Basic Rust knowledge (I recommend the Rust book for this!)
Demo
Here's a demo of the polling bot in action:
The full code can be viewed here. Alternatively, the full listing is reproduced below.
Full code
main.rs:
fn main() { std::process::Command::new("cargo").arg("run").status().unwrap(); }
Cargo.toml:
[package]
name = "polling-bot"
version = "0.1.0"
authors = ["runner"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serenity = "0.9.1"
tokio = { version = "^0.2.23", features = ["macros"] }
[[bin]]
name = "polling-bot"
path = "real_main.rs"
real_main.rs:
use serenity::async_trait;
use serenity::framework::standard::{
macros::{command, group},
Args, CommandResult, StandardFramework,
};
use serenity::model::{
channel::{Message, Reaction},
gateway::Ready,
};
use serenity::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use serenity::model::channel::ReactionType;
use serenity::model::id::{MessageId, ChannelId};
enum ReactionEvent<'a> {
Reaction(&'a Reaction),
RemoveAll(ChannelId, MessageId),
}
macro_rules! perform_reaction {
(($ctx:expr, $reaction_event:expr) $body:expr) => {
use ReactionEvent::*;
// Discard if it's our own reaction.
if let Reaction(r) = $reaction_event {
if r.user_id == Some($ctx.cache.current_user_id().await) {
println!("Reaction added by self, ignoring");
return;
}
}
let key = match $reaction_event {
Reaction(r) => (r.channel_id, r.message_id),
RemoveAll(c, m) => (c, m),
};
// Try to get poll for the given message otherwise return
{
let poll_data = $ctx.data.read().await;
let poll_map = poll_data
.get::<PollsKey>()
.expect("Failed to retrieve polls map!")
.lock()
.await;
if !poll_map.contains_key(&key) {
println!("Message not in polls map, ignoring");
return;
}
}
// reretrieve the map as writable
let mut poll_data = $ctx.data.write().await;
let mut poll_map = poll_data
.get_mut::<PollsKey>()
.expect("Failed to retrieve polls map!")
.lock()
.await;
let poll = match poll_map.get_mut(&key) {
None => {
println!("Failed to get poll for {:?}", key);
return;
}
Some(poll) => poll,
};
// nudges Rust towards the right type :)
fn get_f<F: FnOnce(&mut Poll, Option<usize>)>(f: F) -> F {
f
}
let f = get_f($body);
match $reaction_event {
Reaction(r) => match r.emoji {
ReactionType::Unicode(ref s) => {
let c = s.chars().nth(0).unwrap();
let end_char = std::char::from_u32('🇦' as u32 + poll.answers.len() as u32 - 1)
.expect("Failed to format emoji");
if c < '🇦' || c > end_char {
println!("Emoji is not regional indicator or is not in range, ignoring");
return;
}
let number = (c as u32 - '🇦' as u32) as usize;
f(poll, Some(number));
}
_ => {
println!("Unknown emoji in reaction, ignoring");
return;
}
},
RemoveAll(..) => f(poll, None),
}
let content = render_message(&poll);
key.0
.edit_message(&$ctx.http, key.1, |edit| edit.content(&content))
.await
.expect("Failed to edit message");
println!("Rerendered message");
};
}
struct Handler;
#[async_trait]
impl EventHandler for Handler {
async fn ready(&self, _: Context, ready: Ready) {
println!("Bot ready with username {}", ready.user.name);
}
async fn reaction_add(&self, ctx: Context, add_reaction: Reaction) {
println!("Reaction add");
perform_reaction! { (ctx, ReactionEvent::Reaction(&add_reaction)) |poll, number| {
poll.answerers[number.unwrap()] += 1;
}}
}
async fn reaction_remove(&self, ctx: Context, removed_reaction: Reaction) {
println!("Single reaction remove");
perform_reaction! { (ctx, ReactionEvent::Reaction(&removed_reaction)) |poll, number| {
poll.answerers[number.unwrap()] -= 1;
}}
}
async fn reaction_remove_all(&self, ctx: Context, channel_id: ChannelId, removed_from_message_id: MessageId) {
println!("All reactions removed");
perform_reaction! { (ctx, ReactionEvent::RemoveAll(channel_id, removed_from_message_id)) |poll, _| {
for answers in poll.answerers.iter_mut() {
*answers = 0;
}
}}
}
}
fn render_message(poll: &Poll) -> String {
let mut message_text = format!("**Poll:** {}\n", poll.question);
let total_answerers = poll.answerers.iter().sum::<usize>();
for (i, (answer, &num)) in poll.answers.iter().zip(poll.answerers.iter()).enumerate() {
let emoji = std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji");
message_text.push(emoji);
if total_answerers > 0 {
let percent = num as f64 / total_answerers as f64 * 100.;
message_text.push_str(&format!(" {:.0}%", percent));
}
message_text.push(' ');
message_text.push_str(answer);
message_text.push_str(&format!(" ({} votes)", num));
message_text.push('\n');
}
message_text
}
struct PollsKey;
impl TypeMapKey for PollsKey {
type Value = Arc<Mutex<PollsMap>>;
}
type PollsMap = HashMap<(ChannelId, MessageId), Poll>;
struct Poll {
pub question: String,
pub answers: Vec<String>,
pub answerers: Vec<usize>,
}
#[group]
#[commands(poll)]
struct General;
#[command]
async fn poll(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let question = args.single_quoted::<String>()?;
let answers = args
.quoted()
.iter::<String>()
.filter_map(|x| x.ok())
.collect::<Vec<_>>();
let answers_len = answers.len();
let poll = Poll {
question: question,
answerers: vec![0; answers_len],
answers: answers,
};
let message_text = render_message(&poll);
let emojis = (0..answers_len)
.map(|i| std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji"))
.collect::<Vec<_>>();
let poll_msg = msg.channel_id.say(&ctx.http, &message_text).await?;
for &emoji in &emojis {
poll_msg
.react(&ctx.http, ReactionType::Unicode(emoji.to_string()))
.await?;
}
let mut poll_data = ctx.data.write().await;
let poll_map = poll_data
.get_mut::<PollsKey>()
.expect("Failed to retrieve polls map!");
poll_map
.lock()
.await
.insert((msg.channel_id, poll_msg.id), poll);
Ok(())
}
#[tokio::main]
async fn main() {
let token = std::env::var("DISCORD_TOKEN").expect("Expected DISCORD_TOKEN to be set!");
let framework = StandardFramework::new()
.configure(|c| c.case_insensitivity(true))
.group(&GENERAL_GROUP);
let mut client = Client::builder(&token)
.event_handler(Handler)
.framework(framework)
.type_map_insert::<PollsKey>(Arc::new(Mutex::new(PollsMap::new())))
.await
.expect("Failed to build client");
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);
}
}
Getting started
We're going to host our Discord bot on repl.it.
To get started, create an account. I personally recommend you sign in with your GitHub account if you have one, but email is fine too.
Now let's create a new Rust project by going to https://repl.it/languages/rust:
Adding the library and setting up a basic bot
Right now, our Rust program isn't a Cargo project. Cargo is Rust's package manager, and without it we can't easily depend on libraries. Since our program isn't a Cargo project yet, it won't be able to use the serenity library!
Let's fix that by running cargo init --name polling-bot in the terminal. This command initializes a new Cargo project for us. You can replace polling-bot with whatever you want your program to be named. This should create two new files on the side: .gitignore and Cargo.toml. We'll use Cargo.toml to add the Serenity library:
P.S. Here's a tip for working with Repl.it: if you ever get a "disk quota exceeded" error, just delete the target directory and try again. Also, delete the target directory when you're done playing around with your bot, or else it'll take up a bunch of space!
Next, create a new file called real_main.rs and put this code in it:
fn main() {
// todo
}
Replace the contents of main.rs with this:
fn main() { std::process::Command::new("cargo").arg("run").status().unwrap(); }
Magic?!?!?!?
This line seems kind of magic, but it's just running cargo run from a Rust program. Repl.it is a bit weird in this way: it doesn't natively support Cargo projects, but it mostly works if we do this. If you're running this code locally, you can skip this and the real_main.rs file and just write all your code in main.rs.
This is kind of a hack, but it's necessary to make sure that the environment variables are passed through correctly. You'll be doing all your coding in real_main.rs.
One last thing: update the path to the source file in Cargo.toml, changing main.rs to real_main.rs. Your Cargo.toml should look like this:
What is Serenity?
Serenity is a Rust crate (or library) that helps you write Discord bots in Rust. If you've heard of Discord.py for Python, or Discord.js for JavaScript, you can kind of think of Serenity like that, except for Rust.
Adding the bot token
Now, we need to create a new bot in the Discord Developer Portal. Click "New Application" in the top right corner, highlighted in red here:
Give your bot a nice name (I'm using "Polling Bot" for this tutorial), then hit the create button!
Now, go to the "Bot" section of your application, highlighted in red here:
Click the "Add Bot" button to enable the bot for this application, highlighted in red here:
If you want to, feel free to rename the bot's username or give it an avatar. I'm skipping that here since it's up to you to add your own creative touch ✨
Alright, now that we've set up our bot, copy the bot's token by clicking the "Copy" button next to the token field, highlighted in light cyan here:
Great, you've copied your token!
Let's put it into Repl.it. Create a new .env file in Repl.it. The .env file is a special file that lets you store secrets, such as your Discord bot's token.
Inside the .env file, add a new line that looks like DISCORD_TOKEN=<token>. Replace <token> with the token that you previously copied. It should look like this (I've redacted my token):
Now, let's add the Serenity library! Rust projects manage their dependencies in Cargo.toml, so that's where we need to add Serenity. Head over to Cargo.toml and add this line under your [dependencies] section: serenity = "0.9.1". At the time of writing, the latest version of Serenity is 0.9.1. Newer versions can be found here, but since Serenity is still pre-1.0, a newer release may not be compatible with the code in this workshop.
We'll also need another library, tokio. Add tokio = { version = "^0.2.23", features = ["macros"] } to your Cargo.toml as well. Tokio is the async runtime that Serenity runs on, and the macros feature gives us the #[tokio::main] attribute we'll use later; you don't need to worry too much about it for now. Don't use the latest version of Tokio (0.3 or later), because Serenity 0.9 is built on Tokio 0.2 and mixing the two causes incompatibilities!
Your Cargo.toml should now look like this:
Inviting your bot to a server
To invite your bot to a server, you'll need to go to the OAuth2 tab in your application (highlighted in green here):
Next, select the "bot" scope for OAuth2, highlighted in pink here:
Then, scroll down and check these permissions we'll need (highlighted in orange in the picture):
- View Channels (under General Permissions)
- Send Messages (under Text Permissions)
- Read Message History (under Text Permissions)
- Add Reactions (under Text Permissions)
Finally, copy the OAuth2 link, highlighted in brown here:
Paste it into your browser and invite your bot to a server for testing! (Preferably with other people to test the polling.) You'll need the "Manage Server" permission in order to invite the bot to a server.
A basic template
Now, let's add a basic Discord bot template. At this point, you should have invited your bot to a server for testing.
Delete the contents of real_main.rs and replace them with this:
use serenity::async_trait;
use serenity::framework::standard::{
macros::{command, group},
Args, CommandResult, StandardFramework,
};
use serenity::model::{
channel::{Message, Reaction},
gateway::Ready,
};
use serenity::prelude::*;
struct Handler;
#[async_trait]
impl EventHandler for Handler {
async fn ready(&self, _: Context, ready: Ready) {
println!("Bot ready with username {}", ready.user.name);
}
}
#[group]
#[commands(ping)]
struct General;
#[command]
async fn ping(ctx: &Context, msg: &Message, mut _args: Args) -> CommandResult {
msg.channel_id.say(&ctx.http, "Pong!").await?;
Ok(())
}
#[tokio::main]
async fn main() {
let token = std::env::var("DISCORD_TOKEN").expect("Expected DISCORD_TOKEN to be set!");
let framework = StandardFramework::new()
.configure(|c| c.case_insensitivity(true))
.group(&GENERAL_GROUP);
let mut client = Client::builder(&token)
.event_handler(Handler)
.framework(framework)
.await
.expect("Failed to build client");
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);
}
}
This is just a super simple template with a ~ping command that makes the bot reply "Pong!". (We never set a prefix here; ~ is the default prefix of Serenity's standard framework.)
By the way, I highly recommend having the Serenity docs open on the side while going through this workshop! You can search for all of the functions and structs we use in there, with very detailed explanations.
Technical details
use serenity::async_trait;
use serenity::framework::standard::{
macros::{command, group},
Args, CommandResult, StandardFramework,
};
use serenity::model::{
channel::{Message, Reaction},
gateway::Ready,
};
use serenity::prelude::*;
These are just some imports that we need.
struct Handler;
This creates a new type called Handler, which has no data. We're going to implement the EventHandler trait on it so that we can handle events such as ready.
#[async_trait]
impl EventHandler for Handler {
async fn ready(&self, _: Context, ready: Ready) {
println!("Bot ready with username {}", ready.user.name);
}
}
This implements the EventHandler trait for Handler. By default, none of the event handlers do anything, but here we override the ready method so that we can print our bot's username once it's ready.
Since we're defining an async fn in the implementation, we have to use the #[async_trait] attribute (imported above) to allow it, because at the time of writing Rust did not natively support async functions in traits.
#[group]
#[commands(ping)]
struct General;
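This sets up a command group. In Serenity, commands can only be added through command groups, so we set up a General command group containing the ping command. The #[group] attribute also generates a GENERAL_GROUP static from the struct's name, which we register with the framework in main.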
#[command]
async fn ping(ctx: &Context, msg: &Message, mut _args: Args) -> CommandResult {
msg.channel_id.say(&ctx.http, "Pong!").await?;
Ok(())
}
This is the ping command. We take the command's message, get the channel it was sent in, and send "Pong!" in that channel. (In other words, we just reply with "Pong!".)
#[tokio::main]
async fn main() {
let token = std::env::var("DISCORD_TOKEN").expect("Expected DISCORD_TOKEN to be set!");
This is the start of the main function. Rust doesn't let main be an async fn on its own, so we mark it with the #[tokio::main] attribute, which starts a Tokio runtime and runs our async main on it.
Then, we retrieve the token from the DISCORD_TOKEN environment variable, panicking if it's not set.
let framework = StandardFramework::new()
.configure(|c| c.case_insensitivity(true))
.group(&GENERAL_GROUP);
Here we set up Serenity's standard command framework. We configure it to be case-insensitive, so that commands can be typed as ~poll, ~pOlL, ~poLL or any other combination. Then we add the General group of commands, which includes the ping command we defined above.
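Serenity does all of the prefix and command matching for us, so this isn't code you need to add. But to make "case insensitivity" concrete, here's a rough, hypothetical sketch of the idea: strip the ~ prefix, split off the command name, and lowercase it before comparing. parse_command is a made-up helper, not part of Serenity, and the real framework handles more cases than this.

```rust
/// Hypothetical helper (not part of Serenity): if `content` starts with
/// `prefix`, returns the lowercased command name and the remaining argument text.
fn parse_command<'a>(content: &'a str, prefix: &str) -> Option<(String, &'a str)> {
    // Messages without the prefix are not commands at all.
    let rest = content.strip_prefix(prefix)?;
    // The command name runs up to the first whitespace character.
    let (name, args) = match rest.find(char::is_whitespace) {
        Some(idx) => (&rest[..idx], rest[idx..].trim_start()),
        None => (rest, ""),
    };
    if name.is_empty() {
        return None;
    }
    // Lowercasing the name is what makes `~POLL` and `~poll` equivalent.
    Some((name.to_lowercase(), args))
}
```

So ~pOlL and ~poll both come out as the command name poll, while the argument text is left untouched for the command itself to parse.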
let mut client = Client::builder(&token)
.event_handler(Handler)
.framework(framework)
.await
.expect("Failed to build client");
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);
}
}
Now we set up our Discord bot client with the token we got earlier, our event handler (which prints the bot's username once the ready event is received), and our command framework.
Then, we start the bot, and if it stops with an error, we print the error.
That's all of the details!
Run the bot to make sure that everything works. It may take a while to build at first but subsequent builds will be faster.
Now it's time to actually add polling to the bot!
Representing polls in code
Before we add the poll command, we'll need to define some types to represent polls.
Add these lines to your imports:
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
This just imports some things we'll be using soon.
Now let's start adding the types for polls. Add this code before struct General in real_main.rs:
struct PollsKey;
impl TypeMapKey for PollsKey {
type Value = Arc<Mutex<PollsMap>>;
}
Here, we create a new type, PollsKey, which we'll use to retrieve the current polls. We also implement the TypeMapKey trait for PollsKey, which lets us use it as a key in a type map. Serenity uses type maps for storing shared data, which is why we need this type. We set the type of the key's value to Arc<Mutex<PollsMap>>. Arc is an atomically reference-counted pointer: it lets several owners, possibly on different threads, share the same object, which we need because the async runtime may run our event handlers on different threads. Mutex gives us exclusive access to the map while we change it. Arc<Mutex<...>> is a common pattern for sharing mutable data across threads.
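If the Arc<Mutex<...>> pattern is new to you, here's a small standalone sketch of it. To keep it dependency-free, it uses std::thread and std::sync::Mutex rather than the bot's tokio::sync::Mutex (whose lock() is async and must be awaited), and tally_votes is a hypothetical function, not part of the bot.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical example: `threads` threads each record `votes_per_thread`
/// votes for answer "A" in one shared map, then we read the total back.
fn tally_votes(threads: usize, votes_per_thread: usize) -> usize {
    // Arc lets every thread own a handle to the same map;
    // Mutex makes sure only one thread mutates it at a time.
    let shared: Arc<Mutex<HashMap<&'static str, usize>>> =
        Arc::new(Mutex::new(HashMap::new()));

    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let shared = Arc::clone(&shared); // bumps the reference count
            thread::spawn(move || {
                for _ in 0..votes_per_thread {
                    // The lock guard is released at the end of this statement.
                    *shared.lock().unwrap().entry("A").or_insert(0) += 1;
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
    let total = *shared.lock().unwrap().get("A").unwrap_or(&0);
    total
}
```

Every thread gets its own Arc handle and the mutex serializes the updates, so no votes are lost: tally_votes(4, 250) returns 1000.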
Next, let's define the PollsMap type:
type PollsMap = HashMap<(ChannelId, MessageId), Poll>;
PollsMap is just a type alias for a HashMap with key type (ChannelId, MessageId) (a tuple) and value type Poll.
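Tuples of hashable types are themselves hashable, so the (ChannelId, MessageId) pair works as a HashMap key with no extra code. In this sketch, plain u64 type aliases are hypothetical stand-ins for Serenity's ChannelId and MessageId (in Serenity 0.9 both are newtypes wrapping a u64), and Poll is trimmed down to its question so the example runs on its own.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for serenity's ChannelId / MessageId newtypes.
type ChannelId = u64;
type MessageId = u64;

// Trimmed-down Poll (question only) for this sketch.
struct Poll {
    question: String,
}

type PollsMap = HashMap<(ChannelId, MessageId), Poll>;

/// Looks up a poll by the channel and message it was posted in.
fn question_for(map: &PollsMap, channel: ChannelId, message: MessageId) -> Option<&str> {
    // The lookup key is just a tuple built on the spot.
    map.get(&(channel, message)).map(|poll| poll.question.as_str())
}
```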
Finally, let's define the Poll type:
struct Poll {
pub question: String,
pub answers: Vec<String>,
pub answerers: Vec<usize>,
}
The Poll type just has a question, a list of answers, and how many people picked each answer.
Now that we've defined our type key and its value type, let's actually add them to the global data map that Serenity provides. Add this in your main function:
#[tokio::main]
async fn main() {
// -- snip --
let mut client = Client::builder(&token)
.event_handler(Handler)
.framework(framework)
.type_map_insert::<PollsKey>(Arc::new(Mutex::new(PollsMap::new()))) // new!
.await
.expect("Failed to build client");
// -- snip --
}
This just inserts an empty polls map under the PollsKey type key we defined earlier.
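Serenity's type map stores one value per key type. As a rough mental model, here's a hypothetical, heavily simplified sketch built on std::any::TypeId and Any. SimpleTypeMap, the Key trait and CounterKey are illustrative names; this is not Serenity's actual implementation.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Hypothetical stand-in for serenity's TypeMapKey: a marker type whose
/// associated `Value` says what is stored under it.
trait Key: 'static {
    type Value: Any + Send + Sync;
}

/// A simplified type map: one boxed value per key *type*.
#[derive(Default)]
struct SimpleTypeMap {
    inner: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl SimpleTypeMap {
    fn insert<K: Key>(&mut self, value: K::Value) {
        // The key type's TypeId is the map key; the value is type-erased.
        self.inner.insert(TypeId::of::<K>(), Box::new(value));
    }

    fn get<K: Key>(&self) -> Option<&K::Value> {
        // downcast_ref recovers the concrete type; None if K was never inserted.
        self.inner
            .get(&TypeId::of::<K>())
            .and_then(|boxed| boxed.downcast_ref::<K::Value>())
    }

    fn get_mut<K: Key>(&mut self) -> Option<&mut K::Value> {
        self.inner
            .get_mut(&TypeId::of::<K>())
            .and_then(|boxed| boxed.downcast_mut::<K::Value>())
    }
}

/// Example key, analogous to PollsKey: the value stored under it is a u32.
struct CounterKey;

impl Key for CounterKey {
    type Value = u32;
}
```

The key type itself never holds data (just like PollsKey); it only selects which slot of the map you get, and its associated Value type tells the compiler what comes out.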
Now, let's finally create the poll command!
Creating the poll command
Let's define the poll command.
Before we do anything else, we'll need to add two new imports:
use serenity::model::channel::ReactionType;
use serenity::model::id::{MessageId, ChannelId};
We'll use these a little later.
Next, remove the ping function and replace ping in the General group's #[commands(...)] list with our new poll command. The change should look something like this:
#[group]
-#[commands(ping)]
+#[commands(poll)]
struct General;
-#[command]
-async fn ping(ctx: &Context, msg: &Message, mut _args: Args) -> CommandResult {
- msg.channel_id.say(&ctx.http, "Pong!").await?;
-
- Ok(())
-}
Now, let's create the poll command. Right under the General group, add this:
#[command]
async fn poll(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
This is the basic scaffolding used for all commands in Serenity. We take a context (which we can use to send messages, etc.), the message containing the command, and an Args object that lets us easily retrieve the command's arguments. Then we return a CommandResult, which lets us handle errors in our command easily.
What is Result<T, E>?
Rust doesn't handle errors like many other languages do, where you would throw or raise an exception and then catch it somewhere else. With exceptions, it's hard to tell which calls can fail just by reading the code, so Rust uses Result<T, E> instead. It's very simply defined like this:
pub enum Result<T, E> {
Ok(T),
Err(E),
}
We have a success type T and an error type E. It's just an enum with an Ok variant and an Err variant, which makes it easy to handle with tools like match.
But Result<T, E> has one more superpower: the ? operator. When used in a function that returns Result, the ? operator makes it super easy to handle errors and propagate them. As an example:
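// A stand-in for any I/O call that can fail (hypothetical, for illustration).
fn do_some_fallible_io_operation() -> std::io::Result<()> {
    std::fs::metadata(".").map(|_| ())
}
fn error1() -> std::io::Result<()> {
    // returns the error early if the call fails
    do_some_fallible_io_operation()?;
    Ok(()) // default case, everything ok
}
// is equivalent to:
fn error2() -> std::io::Result<()> {
    match do_some_fallible_io_operation() {
        Ok(_) => Ok(()),  // ok, then ok
        Err(e) => Err(e), // error? propagate the error
    }
}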
This makes working with Result<T, E> much more pleasant: propagating an error takes a single ? instead of a whole match.
Note: Inside a function that returns Result<T, E>, the ? operator works on any Result<U, E2> as long as E: From<E2>; ? converts the error with From::from before returning it. Also see std::ops::Try.
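The From conversion is what lets a single ? line work with different error types. Here's a small sketch with a hypothetical VoteError type (not part of the bot) that wraps std's ParseIntError; the From impl is what allows ? to convert the parse error automatically.

```rust
use std::num::ParseIntError;

/// Hypothetical error type for parsing a vote number.
#[derive(Debug, PartialEq)]
enum VoteError {
    NotANumber(ParseIntError),
    OutOfRange(u32),
}

// This impl is what lets `?` turn a ParseIntError into a VoteError.
impl From<ParseIntError> for VoteError {
    fn from(err: ParseIntError) -> Self {
        VoteError::NotANumber(err)
    }
}

/// Parses a 1-based answer number and checks it against the answer count.
fn parse_vote(input: &str, answer_count: u32) -> Result<u32, VoteError> {
    // On failure, `?` calls From::from(err) and returns early.
    let n: u32 = input.trim().parse()?;
    if n == 0 || n > answer_count {
        return Err(VoteError::OutOfRange(n));
    }
    Ok(n)
}
```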
Next, we're going to get the question, which will be the first argument:
let question = args.single_quoted::<String>()?;
The single_quoted function returns one argument of the given type (the argument can be wrapped in quotes if it contains spaces). Since we ask for a String, any text is accepted; it still returns an error if no argument was given at all, which ? passes back up. For what the ? operator means, see the section above.
Now let's get all the answers that the user provided:
let answers = args
.quoted() // 1) Enable quoting for answers with spaces
.iter::<String>() // 2) Iterate over the rest of the arguments (as Strings)
.filter_map(|x| x.ok()) // 3) Filter out any arguments that failed to parse
.collect::<Vec<_>>(); // 4) Collect all the arguments into a Vec<String>
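Serenity's Args handles the quoting for us. To illustrate roughly what quoted argument splitting means, here's a hypothetical, simplified splitter. split_quoted is not a Serenity function, and it ignores details such as escaped quotes that a real parser would need to handle.

```rust
/// Hypothetical helper (not part of Serenity): splits `input` on whitespace,
/// keeping text inside double quotes together as one argument.
fn split_quoted(input: &str) -> Vec<String> {
    let mut args = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    // Tracks whether we've started an argument, so `""` yields an empty one.
    let mut has_token = false;

    for c in input.chars() {
        match c {
            // A quote toggles quoting and is not part of the argument itself.
            '"' => {
                in_quotes = !in_quotes;
                has_token = true;
            }
            // Unquoted whitespace ends the current argument.
            c if c.is_whitespace() && !in_quotes => {
                if has_token {
                    args.push(std::mem::take(&mut current));
                    has_token = false;
                }
            }
            // Everything else (including quoted spaces) is argument text.
            c => {
                current.push(c);
                has_token = true;
            }
        }
    }
    // Flush the last argument, if any.
    if has_token {
        args.push(current);
    }
    args
}
```

With this kind of splitting, ~poll "Favorite color?" Red "Dark blue" gives the question Favorite color? and the two answers Red and Dark blue.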
Let's count the total number of answers (we'll need this later):
let answers_len = answers.len();
Now we can create our Poll struct with the data that we got:
let poll = Poll {
question: question,
// no responses yet
answerers: vec![0; answers_len],
answers: answers,
};
The vec![0; answers_len] macro is a shorthand way to create a Vec of length answers_len filled with zeros. Since we don't have any responses yet, they should all be zero.
Now we'll have to create a fancy message for the users to respond on:
// Build the message contents
let message_text = render_message(&poll);
We'll define this function later, but for now just know that it takes a Poll reference and returns a String of the message contents.
We have to accumulate all the emojis to react with, so that the user can easily click to respond. We're using the "regional indicator" section of Unicode, which looks like this in Discord:
This code creates a list of all the regional indicator characters we need. For example if we have 5 total answers, we'll need regional indicators A, B, C, D and E.
let emojis = (0..answers_len)
.map(|i| std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji"))
.collect::<Vec<_>>();
We take the range 0..answers_len (which is an Iterator) and transform it using the map function. The regional indicator letters sit next to each other in Unicode, so adding i to the code point of 🇦 gives the i-th letter: 0 becomes 🇦, 1 becomes 🇧, and so on. Then, once we have all the characters, we collect them into a Vec to iterate over a little later.
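🇦 is U+1F1E6 and 🇿 is U+1F1FF, so answer i maps to U+1F1E6 + i. Here's a hypothetical pair of helpers (the names are mine, not the bot's) that does the conversion in both directions, with bounds checks that the inline version above skips. The reverse direction is the same subtraction the bot needs when a reaction comes in.

```rust
/// Code point of REGIONAL INDICATOR SYMBOL LETTER A (U+1F1E6).
const REGIONAL_A: u32 = '🇦' as u32;

/// Hypothetical helper: answer index (0 = A) -> regional indicator emoji.
/// Returns None past Z, since there are only 26 regional indicator letters.
fn regional_indicator(i: usize) -> Option<char> {
    if i >= 26 {
        return None;
    }
    std::char::from_u32(REGIONAL_A + i as u32)
}

/// Hypothetical helper: regional indicator emoji -> answer index, if the
/// emoji is a letter within the first `answer_count` answers.
fn answer_index(c: char, answer_count: usize) -> Option<usize> {
    // checked_sub returns None for characters that come before 🇦.
    let offset = (c as u32).checked_sub(REGIONAL_A)? as usize;
    if offset < answer_count.min(26) {
        Some(offset)
    } else {
        None
    }
}
```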
Why answers_len and not answers.len()?
If you tried to use answers.len() instead of answers_len above, you'd get an error that looks something like this:
error[E0382]: borrow of moved value: `answers`
--> src\main.rs:203:22
|
187 | let answers = args
| ------- move occurs because `answers` has type `Vec<String>`, which does not implement the `Copy` trait
...
198 | answers: answers,
| ------- value moved here
...
203 | let emojis = (0..answers.len())
| ^^^^^^^ value borrowed here after move
What this means is that we're moving the data of answers into the Poll, so we can't use the answers variable anymore: the Poll now owns that Vec. Therefore, we just get the length of answers before moving it into the Poll so that we can use it later.
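There are a couple of other ways around this error. Struct fields are evaluated in the order they're written, so answers.len() can be called inside the struct literal as long as it comes before the answers field (that's also why answerers is listed first above). And after the move, the length is still available through the Poll itself. build_poll is a hypothetical function sketching both:

```rust
struct Poll {
    question: String,
    answers: Vec<String>,
    answerers: Vec<usize>,
}

/// Hypothetical helper: builds a poll and returns it with its answer count.
fn build_poll(question: String, answers: Vec<String>) -> (Poll, usize) {
    let poll = Poll {
        question,
        // Borrowing `answers` is fine here: it hasn't been moved yet.
        answerers: vec![0; answers.len()],
        answers, // `answers` is moved into the struct here
    };
    // `answers` can't be used anymore, but the Vec now lives in `poll`,
    // so we can still ask it for its length.
    let answers_len = poll.answers.len();
    (poll, answers_len)
}
```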
Now let's actually create the message with the contents we got before:
let poll_msg = msg.channel_id.say(&ctx.http, &message_text).await?;
So we're taking the channel ID of the command's message, and sending our own message in that same channel.
Now let's add reactions for each of the answers:
for &emoji in &emojis {
poll_msg
.react(&ctx.http, ReactionType::Unicode(emoji.to_string()))
.await?;
}
So for each emoji character we're going to convert it to a string so we can add that reaction (as a Unicode emoji, since we aren't using custom emojis) to our message.
Now that we've set up the message, we need to add our new poll to the polls map. First we need to retrieve the global data as writable:
let mut poll_data = ctx.data.write().await;
Next, we need to get the polls map by retrieving the PollsKey key:
let poll_map = poll_data
.get_mut::<PollsKey>()
.expect("Failed to retrieve polls map!");
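Now we can finally insert our poll, keyed by the channel and message ID. The message ID alone would already identify the poll, but keeping the channel ID in the key is handy later, because we need both to edit the poll message: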
poll_map
.lock()
.await
.insert((msg.channel_id, poll_msg.id), poll);
We succeeded! Let's return with a successful value:
Ok(())
}
And that's the end of the poll command!
But... do you remember the render_message function that we were going to define later? Let's do that.
The render_message function
The render_message function is pretty simple: it just takes a Poll reference and formats it to look nice in a message. Let's start defining it right above struct PollsKey:
fn render_message(poll: &Poll) -> String {
We're taking a Poll reference as input (we don't need to take ownership, since we're just reading it) and returning a String with the formatted message contents.
// Build the message contents
let mut message_text = format!("**Poll:** {}\n", poll.question);
let total_answerers = poll.answerers.iter().sum::<usize>();
We start the message with a bold "Poll:" label (**Poll:** is Discord's Markdown for bold text) followed by the question. We mark message_text as mut since we'll keep adding to it. We also create a total_answerers variable holding the total number of responses, which we'll use to calculate percentages.
for (i, (answer, &num)) in poll.answers.iter().zip(poll.answerers.iter()).enumerate() {
We're iterating over each answer string, how many votes it got and the number of the answer we are iterating over (used for creating emoji).
We create the emoji the same way we did in the poll command, and then add it to the message:
let emoji = std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji");
// add answerers and percent
message_text.push(emoji);
If we got at least one response in total, we add a percentage. With zero responses we'd be dividing 0 by 0, which gives NaN, so we don't show a percentage in that case:
if total_answerers > 0 {
let percent = num as f64 / total_answerers as f64 * 100.;
message_text.push_str(&format!(" {:.0}%", percent));
}
Lastly, we add the answer string and how many votes it got (as well as a newline):
message_text.push(' ');
message_text.push_str(answer);
message_text.push_str(&format!(" ({} votes)", num));
message_text.push('\n');
}
Now we just return the message that we've built up!
message_text
}
That's the end of the render_message function!
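Since render_message doesn't touch Serenity at all, it's easy to check in isolation. Here's a standalone sketch of the same logic written with write!/writeln! from std::fmt::Write instead of repeated push/push_str calls. This is a stylistic alternative I'm suggesting, not the code used in this workshop; it produces the same text.

```rust
use std::fmt::Write; // lets write!/writeln! append to a String

struct Poll {
    question: String,
    answers: Vec<String>,
    answerers: Vec<usize>,
}

/// Same output as the workshop's render_message, written with writeln!.
fn render_message(poll: &Poll) -> String {
    let mut text = format!("**Poll:** {}\n", poll.question);
    let total: usize = poll.answerers.iter().sum();
    for (i, (answer, &num)) in poll.answers.iter().zip(&poll.answerers).enumerate() {
        let emoji = std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji");
        // Writing to a String never fails, so unwrap() is safe here.
        write!(text, "{}", emoji).unwrap();
        if total > 0 {
            // Skip the percentage when there are no votes (0.0 / 0.0 is NaN).
            let percent = num as f64 / total as f64 * 100.0;
            write!(text, " {:.0}%", percent).unwrap();
        }
        writeln!(text, " {} ({} votes)", answer, num).unwrap();
    }
    text
}
```

For example, a poll asking "Best pet?" with Cat at 1 vote and Dog at 3 renders as **Poll:** Best pet? followed by the lines 🇦 25% Cat (1 votes) and 🇧 75% Dog (3 votes).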
Give yourself a pat on the back, we're halfway done!!
If you want to, feel free to run your program now and try out the poll command (the prefix is ~). Voting won't update the counts yet, but the bot should post the poll message and add its reactions.
The full code at this point can be viewed here. Alternatively, the full listing is reproduced below.
Full code at this point
main.rs:
fn main() { std::process::Command::new("cargo").arg("run").status().unwrap(); }
Cargo.toml:
[package]
name = "polling-bot"
version = "0.1.0"
authors = ["runner"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serenity = "0.9.1"
tokio = { version = "^0.2.23", features = ["macros"] }
[[bin]]
name = "polling-bot"
path = "real_main.rs"
real_main.rs:
use serenity::async_trait;
use serenity::framework::standard::{
macros::{command, group},
Args, CommandResult, StandardFramework,
};
use serenity::model::{
channel::{Message, Reaction},
gateway::Ready,
};
use serenity::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use serenity::model::channel::ReactionType;
use serenity::model::id::{MessageId, ChannelId};
struct Handler;
#[async_trait]
impl EventHandler for Handler {
async fn ready(&self, _: Context, ready: Ready) {
println!("Bot ready with username {}", ready.user.name);
}
}
fn render_message(poll: &Poll) -> String {
let mut message_text = format!("**Poll:** {}\n", poll.question);
let total_answerers = poll.answerers.iter().sum::<usize>();
for (i, (answer, &num)) in poll.answers.iter().zip(poll.answerers.iter()).enumerate() {
let emoji = std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji");
message_text.push(emoji);
if total_answerers > 0 {
let percent = num as f64 / total_answerers as f64 * 100.;
message_text.push_str(&format!(" {:.0}%", percent));
}
message_text.push(' ');
message_text.push_str(answer);
message_text.push_str(&format!(" ({} votes)", num));
message_text.push('\n');
}
message_text
}
struct PollsKey;
impl TypeMapKey for PollsKey {
type Value = Arc<Mutex<PollsMap>>;
}
type PollsMap = HashMap<(ChannelId, MessageId), Poll>;
struct Poll {
pub question: String,
pub answers: Vec<String>,
pub answerers: Vec<usize>,
}
#[group]
#[commands(poll)]
struct General;
#[command]
async fn poll(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let question = args.single_quoted::<String>()?;
let answers = args
.quoted()
.iter::<String>()
.filter_map(|x| x.ok())
.collect::<Vec<_>>();
let answers_len = answers.len();
let poll = Poll {
question: question,
answerers: vec![0; answers_len],
answers: answers,
};
let message_text = render_message(&poll);
let emojis = (0..answers_len)
.map(|i| std::char::from_u32('🇦' as u32 + i as u32).expect("Failed to format emoji"))
.collect::<Vec<_>>();
let poll_msg = msg.channel_id.say(&ctx.http, &message_text).await?;
for &emoji in &emojis {
poll_msg
.react(&ctx.http, ReactionType::Unicode(emoji.to_string()))
.await?;
}
let mut poll_data = ctx.data.write().await;
let poll_map = poll_data
.get_mut::<PollsKey>()
.expect("Failed to retrieve polls map!");
poll_map
.lock()
.await
.insert((msg.channel_id, poll_msg.id), poll);
Ok(())
}
#[tokio::main]
async fn main() {
let token = std::env::var("DISCORD_TOKEN").expect("Expected DISCORD_TOKEN to be set!");
let framework = StandardFramework::new()
.configure(|c| c.case_insensitivity(true))
.group(&GENERAL_GROUP);
let mut client = Client::builder(&token)
.event_handler(Handler)
.framework(framework)
.type_map_insert::<PollsKey>(Arc::new(Mutex::new(PollsMap::new())))
.await
.expect("Failed to build client");
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);
}
}
Alright, take a quick break and relax a little before we jump into coding the rest of our bot!
Reacting to reactions
Now that our bot sends properly formatted messages, the next step is to make it actually respond when someone adds a reaction.
Discord sends us three reaction-related events that we need to handle:
1. reaction_add, which gives us a Reaction struct containing the channel and message ID of the reaction, as well as the emoji,
2. reaction_remove, which gives us the same kind of Reaction struct as reaction_add, and
3. reaction_remove_all, which doesn't give us an emoji but gives us the channel and message ID. (An emoji doesn't make sense here, since every reaction is removed.)
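To make the three events concrete before looking at any Serenity code, here's a hypothetical, Serenity-free model of how each one changes the vote counts. VoteChange and apply are names I made up for this sketch. One deliberate difference: the full listing's reaction_remove handler uses -= 1, while the sketch uses saturating_sub, so a removal we never counted (for example, a reaction added before the poll was stored in the map) leaves the count at zero instead of underflowing.

```rust
/// Hypothetical model of the three reaction events, by answer index.
enum VoteChange {
    Add(usize),    // like reaction_add
    Remove(usize), // like reaction_remove
    RemoveAll,     // like reaction_remove_all
}

/// Applies one event to the per-answer vote counts.
fn apply(counts: &mut [usize], change: VoteChange) {
    match change {
        VoteChange::Add(i) => counts[i] += 1,
        // saturating_sub stops at 0 instead of underflowing.
        VoteChange::Remove(i) => counts[i] = counts[i].saturating_sub(1),
        VoteChange::RemoveAll => counts.iter_mut().for_each(|c| *c = 0),
    }
}
```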
Note that there will probably be a lot of common code shared between these events: for each one we will have to (1) validate the message it is referring to, (2) retrieve the corresponding poll object and (3) perform the action indicated by the event. Therefore, we are going to implement this using macros to reduce code duplication!
Given either a reaction or a (channel id, message id) pair we will need to extract the channel and message ID from it. Let's create an enum representing these two states, and put it right below your imports:
enum ReactionEvent<'a> {
Reaction(&'a Reaction),
RemoveAll(ChannelId, MessageId),
}
The first variant represents a single-reaction event (such as add or remove). The second represents the remove-all event, which only carries a channel ID and a message ID.
Now, let's start writing our macro to handle most of the shared code:
macro_rules! perform_reaction {
This starts a macro declaration. We're using declarative macros, which are defined with macro_rules! (the other kind, procedural macros, is out of scope for this tutorial).
(($ctx:expr, $reaction_event:expr) $body:expr) => {
This is a match rule. One invocation of the macro that would match this rule looks something like this:
// you can open an invocation with either { (curly bracket), [ (square bracket), or ( (parenthesis)
perform_reaction! {
/*opening parenthesis*/(
/*$ctx:expr*/ &ctx,
/*$reaction_event:expr*/ ReactionEvent::Reaction(&reaction),
/*closing parenthesis*/)
/*$body:expr*/ |poll, i| {
// stuff
}
}
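To see how a macro_rules! matcher binds its pieces, here is a minimal, self-contained sketch. The macro with_pair! is hypothetical and has nothing to do with the bot. It only reuses the same matcher shape as perform_reaction!: a parenthesized pair of expressions ($a:expr and $b:expr) followed by a $body:expr. Here the body is a closure that the expansion calls with the two values.

```rust
// A toy macro with the same matcher shape as perform_reaction!:
// a parenthesized pair of expressions followed by a body expression.
macro_rules! with_pair {
    (($a:expr, $b:expr) $body:expr) => {{
        // Bind the body (a closure here) to a name, then call it with both values.
        let f = $body;
        f($a, $b)
    }};
}

fn main() {
    // Curly-brace invocation, like perform_reaction! { (...) |poll, i| { ... } }
    let sum = with_pair! { (2, 3) |x: i32, y: i32| x + y };
    assert_eq!(sum, 5);

    // Parentheses work as well.
    let joined = with_pair!(("poll", 1) |s: &str, n: i32| format!("{}-{}", s, n));
    assert_eq!(joined, "poll-1");
    println!("{} {}", sum, joined);
}
```

The expansion pastes each argument in wherever the transcriber mentions it. An argument used several times is therefore evaluated several times. That is harmless for perform_reaction!, whose $reaction_event arguments are cheap expressions without side effects.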
In fact, most of our invocations will look just like that perform_reaction! example. Now let's look at what the macro does with its parameters:
use ReactionEvent::*;
We use the variants of this enum a lot, so we import them and can refer to them simply as Reaction and RemoveAll.
// Discard if it's our own reaction.
if let Reaction(r) = $reaction_event {
if r.user_id == Some($ctx.cache.current_user_id().await) {
println!("Reaction added by self, ignoring");
return;
}
}
Because of the if let, this check only runs when the ReactionEvent provided is a single reaction (not a remove-all). If the reaction was added by the bot itself (for example, one of the answer emojis it reacted with when creating the poll), we ignore it and return.
let key = match $reaction_event {
Reaction(r) => (r.channel_id, r.message_id),
RemoveAll(c, m) => (c, m),
};
Here we turn our ReactionEvent into the key we look up in our polls map. In Rust, match is an expression, so its result can be assigned directly to key.
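The same key extraction can be tried outside the bot. This sketch uses hypothetical stand-ins for Serenity's ChannelId, MessageId and Reaction: plain u64 wrappers with only the fields needed here, so it runs without Serenity. The ReactionEvent enum and the match mirror the code above.

```rust
// Hypothetical stand-ins for Serenity's ChannelId, MessageId and Reaction,
// reduced to the fields this example needs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ChannelId(u64);
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct MessageId(u64);

struct Reaction {
    channel_id: ChannelId,
    message_id: MessageId,
}

enum ReactionEvent<'a> {
    Reaction(&'a Reaction),
    RemoveAll(ChannelId, MessageId),
}

// `match` is an expression, so the tuple from whichever arm runs is returned.
fn poll_key(event: &ReactionEvent) -> (ChannelId, MessageId) {
    match event {
        ReactionEvent::Reaction(r) => (r.channel_id, r.message_id),
        ReactionEvent::RemoveAll(c, m) => (*c, *m),
    }
}

fn main() {
    let reaction = Reaction { channel_id: ChannelId(1), message_id: MessageId(42) };
    let single = ReactionEvent::Reaction(&reaction);
    let all = ReactionEvent::RemoveAll(ChannelId(1), MessageId(42));
    // Both kinds of event on the same message map to the same polls-map key.
    assert_eq!(poll_key(&single), poll_key(&all));
    println!("{:?}", poll_key(&single));
}
```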
// Try to get poll for the given message otherwise return
{
Here, we're starting a new scope. This is very important because otherwise we would get a deadlock.
Deadlock? Why?
Consider that to check whether a key is present in a map, we only need read access to it, but to modify the map, we need write access. And we can only tell whether we need to modify the map after reading it first. So this scope is the one that holds read access to the map. If we don't need to write to it, we return early from inside the scope. Otherwise, leaving the scope drops our read access, so that later we can take write access without deadlocking.
Sound confusing? Let me try to illustrate with an example:
// How RwLocks work:
// * multiple readers allowed concurrently
// * writers require exclusive access (no readers or other writers)
let read_access = get_read_access(); // \- read access begins here
if need_write_access(read_access) { // |
let write_access = get_write_access(); // | \- write access begins here
// | | DEADLOCK! Already have read access,
// | | so this will never complete!
modify(write_access); // | |
} else { // |
return; // /- read access ends due to return
}
To fix this, we add a scope:
{
let read_access = get_read_access(); // \- read access begins here
// By inverting the condition we |
// prevent the two accesses from |
// overlapping. |
if !need_write_access(read_access) { // |
return; // |- read access possibly ends due to return
} // |
} // /- read access ends due to scope
let write_access = get_write_access(); // \- write access begins here
// | OK! No other references!
// | No deadlock!
modify(write_access); // |
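The compiler can't detect this kind of deadlock for us, so it's up to us to keep the two accesses from overlapping. With the scope in place, read access is gone before we ask for write access, and there's no deadlock!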
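Here is a runnable sketch of the same read-then-write pattern. It uses std::sync::RwLock as a stand-in; the bot itself goes through Serenity's async lock on ctx.data plus the tokio Mutex, which this example leaves out. The function name bump_if_present is made up for illustration. Instead of actually deadlocking, it calls try_write while the read guard is still alive, to show that write access is refused at that point.

```rust
use std::collections::HashMap;
use std::sync::RwLock;

/// Adds one vote to `key`, but only if `key` is already in the map.
/// Mirrors the read-check-then-write pattern above with std's RwLock.
fn bump_if_present(lock: &RwLock<HashMap<u64, u32>>, key: u64) -> bool {
    {
        // Read access begins here and is confined to this block.
        let map = lock.read().unwrap();
        if !map.contains_key(&key) {
            return false; // read guard is dropped on return
        }
        // While `map` is alive, a writer cannot get in. Calling lock.write()
        // here may deadlock or panic; try_write() just reports the conflict.
        assert!(lock.try_write().is_err());
    } // read guard is dropped at the end of the scope

    // No reader is left, so write access is granted.
    let mut map = lock.write().unwrap();
    match map.get_mut(&key) {
        Some(count) => {
            *count += 1;
            true
        }
        // Another thread could have removed the key between the two locks.
        None => false,
    }
}

fn main() {
    let lock = RwLock::new(HashMap::new());
    lock.write().unwrap().insert(42, 0);
    assert!(bump_if_present(&lock, 42));
    assert!(!bump_if_present(&lock, 7));
    assert_eq!(lock.read().unwrap()[&42], 1);
    println!("ok");
}
```

The None arm plays the same role as the second lookup in our macro. Once read access has been dropped, nothing guarantees the map is unchanged by the time write access arrives, so the code checks again.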
Now, we check if the key is present in the polls map (if it is not, the message is not a poll):
let poll_data = $ctx.data.read().await;
let poll_map = poll_data
.get::<PollsKey>()
.expect("Failed to retrieve polls map!")
.lock()
.await;
if !poll_map.contains_key(&key) {
println!("Message not in polls map, ignoring");
return;
}
First we acquire read access to the data map (from the provided $ctx). Next, we look up the polls map using PollsKey and lock its Mutex. Then we check whether the polls map contains our message's key. If not, we return, since the message is not a poll. (This is similar to what we did in the poll command.)
}
Now that we're done reading the map, we relinquish read access by closing the scope.
Next, we retrieve the polls map again, but this time with write access:
// Re-retrieve the map, this time with write access
let mut poll_data = $ctx.data.write().await;
let mut poll_map = poll_data
.get_mut::<PollsKey>()
.expect("Failed to retrieve polls map!")
.lock()
.await;
let poll = match poll_map.get_mut(&key) {
None => {
println!("Failed to get poll for {:?}", key);
return;
}
Some(poll) => poll,
};
If we were not able to find the poll in the map (even though we checked above), we just return.
Now we need to do something a little wacky. Remember the $body:expr we declared in our rule above? That's actually going to be a closure that takes a &mut Poll (to modify the poll) and an Option<usize> indicating which answer was reacted to (None for a remove-all event). The closure's parameters have no type annotations, so Rust can't tell what poll is when it type-checks poll.answerers. We nudge it by passing the closure through a function whose generic bound spells out the signature:
// nudges Rust towards the right type :)
fn get_f<F: FnOnce(&mut Poll, Option<usize>)>(f: F) -> F {
f
}
let f = get_f($body);
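Here is a stripped-down version of the get_f trick that runs on its own. The Poll struct is a simplified stand-in that assumes the vote counts are a Vec<usize>; only the function's generic bound matters. Because the closure is passed straight into get_f, the compiler takes its parameter types from the FnOnce(&mut Poll, Option<usize>) bound.

```rust
// Simplified stand-in for the tutorial's Poll; the Vec<usize>
// vote counts are an assumption for this example.
struct Poll {
    answerers: Vec<usize>,
}

// Identity function: returns its argument unchanged. The FnOnce bound
// tells the compiler the closure's parameter types up front.
fn get_f<F: FnOnce(&mut Poll, Option<usize>)>(f: F) -> F {
    f
}

fn main() {
    // `poll` is inferred as &mut Poll and `number` as Option<usize>.
    // Written as a bare `let f = |poll, number| { ... };`, this body would be
    // rejected, because the type of `poll` is unknown at `poll.answerers`.
    let f = get_f(|poll, number| {
        poll.answerers[number.unwrap()] += 1;
    });

    let mut poll = Poll { answerers: vec![0, 0, 0] };
    f(&mut poll, Some(1));
    assert_eq!(poll.answerers, vec![0, 1, 0]);
    println!("{:?}", poll.answerers);
}
```

Annotating the closure's parameters directly (|poll: &mut Poll, number: Option<usize>|) would also work. Inside the macro, though, the closure is written by the caller, so the helper keeps every invocation short.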
Next, if the event was a reaction, we need to validate it to ensure that we should actually process the reaction:
match $reaction_event {
Reaction(r) => match r.emoji {
ReactionType::Unicode(ref s) => {
let c = s.chars().next().unwrap();
let end_char = std::char::from_u32('🇦' as u32 + poll.answers.len() as u32 - 1)
.expect("Failed to format emoji");
if c < '🇦' || c > end_char {
println!("Emoji is not regional indicator or is not in range, ignoring");
return;
}
let number = (c as u32 - '🇦' as u32) as usize;
First we ensure that the reaction's emoji is a Unicode emoji, since all regional indicators (the emojis we are using) are Unicode. Then we check that the emoji is actually a regional indicator, and is in range. Next, we figure out which answer it would be (where 0 is the first answer.)
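The range check above can also be written as a pair of small helpers. This is sketched below as a hypothetical alternative, not the tutorial's code. Regional indicators 🇦 to 🇿 are 26 consecutive code points starting at U+1F1E6, so an answer index is just the offset from 🇦. The names answer_index and answer_emoji are made up for this example.

```rust
// First code point of the regional indicator block: U+1F1E6 (🇦).
const REGIONAL_A: u32 = '🇦' as u32;

/// Maps a reaction emoji to a zero-based answer index, or None if it is not
/// one of the first `answer_count` regional indicators.
fn answer_index(emoji: &str, answer_count: usize) -> Option<usize> {
    let c = emoji.chars().next()?;
    // checked_sub returns None for code points below 🇦 instead of wrapping.
    let offset = (c as u32).checked_sub(REGIONAL_A)? as usize;
    if offset < answer_count {
        Some(offset)
    } else {
        None
    }
}

/// The inverse: the emoji the bot reacts with for answer `index`.
fn answer_emoji(index: usize) -> Option<char> {
    std::char::from_u32(REGIONAL_A + index as u32)
}

fn main() {
    assert_eq!(answer_index("🇦", 3), Some(0));
    assert_eq!(answer_index("🇨", 3), Some(2));
    assert_eq!(answer_index("🇩", 3), None); // fourth letter, only 3 answers
    assert_eq!(answer_index("A", 3), None); // plain ASCII letter
    assert_eq!(answer_emoji(1), Some('🇧'));
    println!("ok");
}
```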
Now, we can call the body and it can modify the poll:
f(poll, Some(number));
}
We also have to handle non-Unicode emojis, which we simply ignore:
_ => {
println!("Unknown emoji in reaction, ignoring");
return;
}
},
And if the event was a remove all, we just call the function, this time without an answer number:
RemoveAll(..) => f(poll, None),
}
Now that we've let the body update the poll appropriately, we need to update the message with the new value of the poll:
let content = render_message(&poll);
key.0
.edit_message(&$ctx.http, key.1, |edit| edit.content(&content))
.await
.expect("Failed to edit message");
println!("Rerendered message");
};
}
We're using the render_message function we defined earlier (this is why we made it a function 😉) to get the new contents of our message. Then we edit the message's contents.
And that's it! We're done writing our macro! The rest is gonna be pretty easy from here since it's just a few more lines!
Handling reaction events
Now we just need to write the event handlers for the reaction events, which will be super easy.
Inside the impl block where you defined the ready method on Handler, let's add the reaction_add handler as well:
async fn reaction_add(&self, ctx: Context, add_reaction: Reaction) {
println!("Reaction add");
perform_reaction! { (ctx, ReactionEvent::Reaction(&add_reaction)) |poll, number| {
poll.answerers[number.unwrap()] += 1;
}}
}
We're calling the perform_reaction macro defined above with the context, a ReactionEvent::Reaction wrapping the reaction we received, and a body that increments the vote count for the given answer by one.
We do pretty much the same thing for reaction_remove, but this time we decrement:
async fn reaction_remove(&self, ctx: Context, removed_reaction: Reaction) {
println!("Single reaction remove");
perform_reaction! { (ctx, ReactionEvent::Reaction(&removed_reaction)) |poll, number| {
poll.answerers[number.unwrap()] -= 1;
}}
}
And for reaction_remove_all, we just iterate through each vote count and set it to zero:
async fn reaction_remove_all(&self, ctx: Context, channel_id: ChannelId, removed_from_message_id: MessageId) {
println!("All reactions removed");
perform_reaction! { (ctx, ReactionEvent::RemoveAll(channel_id, removed_from_message_id)) |poll, _| {
for answers in poll.answerers.iter_mut() {
*answers = 0;
}
}}
}
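Here is a small simulation for checking the three handler bodies without connecting to Discord. Poll and Event are hypothetical stand-ins, and the vote counts are assumed to be a Vec<usize>. apply performs the same update that each handler passes to perform_reaction!.

```rust
// Stand-in for the tutorial's Poll; Vec<usize> vote counts are an assumption.
struct Poll {
    answerers: Vec<usize>,
}

// One variant per handler: reaction_add, reaction_remove, reaction_remove_all.
enum Event {
    Add(usize),
    Remove(usize),
    RemoveAll,
}

// Applies the same updates as the three handler bodies above.
fn apply(poll: &mut Poll, event: &Event) {
    match *event {
        Event::Add(i) => poll.answerers[i] += 1,
        Event::Remove(i) => poll.answerers[i] -= 1,
        Event::RemoveAll => {
            for count in poll.answerers.iter_mut() {
                *count = 0;
            }
        }
    }
}

fn main() {
    let mut poll = Poll { answerers: vec![0; 3] };
    for e in [Event::Add(0), Event::Add(0), Event::Add(2), Event::Remove(0)].iter() {
        apply(&mut poll, e);
    }
    assert_eq!(poll.answerers, vec![1, 0, 1]);
    apply(&mut poll, &Event::RemoveAll);
    assert_eq!(poll.answerers, vec![0, 0, 0]);
    println!("{:?}", poll.answerers);
}
```

If the counts really are usize, a Remove on a count that is already zero would overflow and panic in a debug build. The handlers above don't guard against that; saturating_sub(1) is one defensive option if it ever matters.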
Wrap-up
YAY!! We did it!! You're finally done with your bot!
Let's try out all of this new stuff by hitting the run button. You should be able to add and remove reactions and watch the message update as you do so. Removing all reactions should update the message properly as well.
The full code is in the Demo section.
Going further
There are many ways to improve and hack on this project further. Here are 3 examples of possible hacks you could do:
- Showing who voted for what (code)
- Reaction roulette - your vote only has a 20% chance of counting! (This is kind of annoying, and on purpose :-) (code)
- Modifying polls after they are created (code)
Here are some links if you'd like to learn more about writing Discord bots with Serenity: