Nime

Multithreading in Rust

The general idea with multithreading is to do more work in a smaller amount of time. This is done by separating code into multiple parts called threads.

These threads are then executed concurrently (maybe even in parallel).

Those threads have to be (largely) independent from each other. Precautions need to be taken for parts of code that depend on something from a different thread.

That’s the reason writing multithreaded code is often seen as “very hard”. Many multithreading bugs are very subtle, and hard to track down.

Lucky for Rust programmers, one of the major goals of the language is making concurrent programming safe and efficient. The Rust language makes many of these bugs impossible to write. Incorrect code will refuse to compile and present an error explaining the problem.

This post goes over a few ways to write multithreaded code by solving an Exercism programming exercise.

Creating threads

A thread is created by calling std::thread::spawn. It’s a function that takes a closure.

The closure contains the code that will be run in the thread.

The moment that thread is created, it is “detached” from the thread that created it. That means it is totally independent and can outlive the thread that spawned it. The exception is the main thread: when that stops, everything stops.

Because the thread may outlive its parent, everything the closure captures must have a 'static lifetime: it must either be owned by the closure or be a reference that stays valid for the rest of the program. This ensures everything in the thread remains valid even when the thread that spawned it no longer exists.

In practice, this means you want the closure to take ownership of every variable it uses. This is done by using the move keyword in front of the parameter list of the closure.

It’s possible to make a parent thread wait for the completion of a thread it spawned.

A call to std::thread::spawn returns a JoinHandle. That handle has a join method that blocks the current thread until the spawned thread has finished, and only then continues executing code.

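use std::thread;

fn main() {
    let handle = thread::spawn(move || {
        // some work here, running in the spawned thread
    });
    // some work here, running in the main thread
    // block until the spawned thread has finished; join returns an Err if it panicked
    handle.join().unwrap();
}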
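To make the ownership rule concrete, here is a minimal sketch. The helper name sum_on_thread is made up for illustration; the only APIs it relies on are thread::spawn and JoinHandle::join.

```rust
use std::thread;

// Hypothetical helper (not part of the exercise): sums a vector on a spawned thread.
fn sum_on_thread(numbers: Vec<i32>) -> i32 {
    // `move` transfers ownership of `numbers` into the closure, so the closure
    // borrows nothing from this stack frame and satisfies the 'static bound.
    let handle = thread::spawn(move || numbers.iter().sum::<i32>());
    // `numbers` can no longer be used here; it now belongs to the spawned thread.
    handle.join().unwrap()
}

fn main() {
    let total = sum_on_thread(vec![1, 2, 3]);
    println!("total = {total}");
}
```

Removing the move keyword makes the closure borrow numbers instead, and the compiler rejects the code because the borrow might not live long enough.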

The problem

Count the frequency of letters in texts using parallel computation.

Parallelism is about doing things in parallel that can also be done sequentially. A common example is counting the frequency of letters. Create a function that returns the total frequency of each letter in a list of texts and that employs parallelism.

We need to write a function called frequency. It takes a slice of strings and a worker count as parameters.

The return value is a hashmap. The keys are all the letters those strings contain, and each value is the number of times that letter appears.

This needs to be done in worker_count number of threads.

fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize>

Single threaded

This snippet shows a solution to the problem that does all the work on a single thread.

lib.rs
use std::collections::HashMap;

fn frequency(input: &[&str]) -> HashMap<char, usize> {
    let mut map = HashMap::new();
    for line in input {
        // only count letters, and treat upper and lower case as the same letter
        for c in line.chars().filter(|c| c.is_alphabetic()) {
            *map.entry(c.to_ascii_lowercase()).or_default() += 1;
        }
    }
    map
}

Strategy

We’ll divide the large problem into several smaller ones.

Each one will be solved in a thread with code similar to the code in the single-threaded example.

The results of those smaller problems have to be combined into one large result which will be the returned value of the frequency function.

The breaking up of the larger problem can be done by calling the chunks method on the input parameter.

input.chunks((input.len() / worker_count).max(1));

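The result is an iterator over roughly worker_count chunks. Because the division rounds down, there can be more chunks (and therefore more threads) than worker_count when the input length is not a multiple of it. The expression also assumes worker_count is at least 1, since dividing by zero panics. Each thread will solve the problem for a single chunk.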
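The number of chunks depends on how the chunk size is rounded. The sketch below compares the expression used in this post with a rounding-up alternative. Both helper names are made up for illustration, and both assume worker_count is greater than zero.

```rust
// Hypothetical helpers for comparing two ways of picking a chunk size.
// Both assume worker_count > 0.

/// Chunk size as used in this post: rounds down, with a minimum of 1.
fn floor_chunk_size(len: usize, worker_count: usize) -> usize {
    (len / worker_count).max(1)
}

/// Alternative: rounds up, so there are never more chunks than workers.
fn ceil_chunk_size(len: usize, worker_count: usize) -> usize {
    ((len + worker_count - 1) / worker_count).max(1)
}

fn main() {
    let input = ["a", "b", "c", "d", "e", "f", "g"];
    let workers = 3;
    let floor = input.chunks(floor_chunk_size(input.len(), workers)).count();
    let ceil = input.chunks(ceil_chunk_size(input.len(), workers)).count();
    println!("rounding down: {floor} chunks, rounding up: {ceil} chunks");
}
```

With 7 inputs and 3 workers, rounding down gives a chunk size of 2 and therefore 4 chunks, while rounding up gives a chunk size of 3 and 3 chunks. Either choice produces correct counts; the difference is only in how many threads get spawned.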

We make sure to own the data in a chunk before sending it into a thread.

chunk.join("")

Then we are ready to spawn a thread that will solve the problem for each chunk.

use std::collections::HashMap;
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    // divide the large problem into smaller problems
    let chunks = input.chunks((input.len() / worker_count).max(1));
    for chunk in chunks {
        // collect the data for the current chunk into an owned variable before sending it to the thread.
        let string = chunk.join("");
        thread::spawn(move || {
            // solve the problem for the current chunk
        });
    }
    // combine the solutions (placeholder until the sections below fill this in)
    todo!()
}

JoinHandle

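JoinHandle<T> is generic over the type the thread's closure returns. This means we can return something from a child thread, which the thread that spawned it can then access by calling .join. The value comes back wrapped in a Result, which is an Err if the child thread panicked.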
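A minimal sketch of returning a value through a JoinHandle, separate from the frequency problem. The helper name length_on_thread is an assumption made for this example.

```rust
use std::thread;

// Hypothetical helper: counts the characters of a string on another thread.
fn length_on_thread(text: String) -> usize {
    // The closure returns a usize, so spawn gives back a JoinHandle<usize>.
    let handle: thread::JoinHandle<usize> = thread::spawn(move || text.chars().count());
    // join() yields Ok(value) when the thread finished normally,
    // or Err(..) when the thread panicked.
    handle.join().expect("worker thread panicked")
}

fn main() {
    println!("{}", length_on_thread(String::from("hello")));
}
```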

lib.rs
use std::collections::HashMap;
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        // return a HashMap from each thread, the JoinHandle wraps this hashmap
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            map
        });
        handles.push(handle);
    }
    // wait for each thread to finish and combine every HashMap into the final result
    for handle in handles {
        let map = handle.join().unwrap();
        for (key, value) in map {
            *result.entry(key).or_default() += value;
        }
    }
    result
}

Channel

A popular method to ensure safe concurrency is message passing. Multiple threads communicate by sending each other messages containing data.

Rust provides channels for this.

You can think of a channel like a stream of water. Put something in one end, it comes out at the other end.

A programming channel has two halves, a sender and a receiver. Put things in the sender, take things out of the receiver.

The Rust standard library has an implementation of this called mpsc, which stands for “multiple producer, single consumer”. So the mpsc channel can have multiple senders, but only a single receiver.

You can think of this like a river delta with many smaller rivers ending at the same location.

The std::sync::mpsc::channel function returns a tuple with a sender, and a receiver.

The Sender can be cloned, creating multiple copies that are able to be sent across threads.

A Sender has a send method, which, you guessed it, sends a value down the channel.

The values sent have to be owned values, a sent value can no longer be used in the thread it is sent from. Ownership transfers to the receiver when it receives that value.

The Receiver has a recv method which blocks the current thread until a message is received.

use std::sync::mpsc;
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel();
    for i in 0..10 {
        // each thread gets its own clone of the sender
        let sender = sender.clone();
        thread::spawn(move || {
            sender.send(i).unwrap();
        });
    }
    for _ in 0..10 {
        // recv blocks until the next value arrives
        println!("Got: {}", receiver.recv().unwrap());
    }
}

A channel closes when either all the senders or the single receiver are dropped.

It is possible to iterate over the receiver. The receiver will block when the iterator asks for the next value. When the channel is closed, the iterator will return None and end.

This introduces a small problem if we want to send cloned Senders into a thread. The original Sender is never dropped, and the channel will remain open.

This isn’t an issue in the example above where we looped a set number of times, but we create an infinite wait if we use the iterator method.

use std::sync::mpsc;
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel();
    for i in 0..10 {
        let sender = sender.clone();
        thread::spawn(move || {
            sender.send(i).unwrap();
        });
    }
    // this will wait until all senders are dropped
    // the original sender is never dropped, so this waits forever
    for received in receiver {
        println!("Got: {}", received);
    }
}

The solution is to drop the original sender.

use std::mem;
use std::sync::mpsc;
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel();
    for i in 0..10 {
        let sender = sender.clone();
        thread::spawn(move || {
            sender.send(i).unwrap();
        });
    }
    // drop the original sender
    mem::drop(sender);
    // the loop ends once every cloned sender has been dropped as well
    for received in receiver {
        println!("Got: {}", received);
    }
}

You can also wrap the entire top section in a scope by using curly braces, ensuring everything leaves scope by the time the iterator is called.
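A sketch of that scope-based variant might look like the following. The helper name collect_from_threads and the sorting step are additions for illustration; sorting only makes the result predictable, since the arrival order of messages is not guaranteed.

```rust
use std::sync::mpsc;
use std::thread;

// Hypothetical helper: collects one message from each of `n` threads.
fn collect_from_threads(n: u32) -> Vec<u32> {
    let receiver = {
        let (sender, receiver) = mpsc::channel();
        for i in 0..n {
            let sender = sender.clone();
            thread::spawn(move || {
                sender.send(i).unwrap();
            });
        }
        // Only the receiver leaves this block.
        // The original sender is dropped when the block ends.
        receiver
    };
    // The iterator ends once every cloned sender has been dropped too.
    let mut received: Vec<u32> = receiver.iter().collect();
    received.sort();
    received
}

fn main() {
    println!("{:?}", collect_from_threads(10));
}
```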

Applied

Back to our frequency problem.

We make a String out of the data in a chunk and send that into a thread.

This time, we also send a sender into each thread.

Once a thread has solved its chunk, it sends the resulting HashMap through the channel.

The main thread then picks up all messages and combines them into our final result.

lib.rs
use std::collections::HashMap;
use std::mem;
use std::sync::mpsc;
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let (sender, receiver) = mpsc::channel();
    for chunk in chunks {
        let sender = sender.clone();
        let string = chunk.join("");
        thread::spawn(move || {
            // Solve each chunk and send the resulting HashMap down the channel
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            sender.send(map).unwrap();
        });
    }
    // drop the original sender, else the channel will remain open, causing the receiver to infinitely wait
    mem::drop(sender);
    // combine every received HashMap
    for received in receiver {
        for (key, value) in received {
            *result.entry(key).or_default() += value;
        }
    }
    result
}

Mutex

A Mutex wraps around other data. Any thread that wants to access that inner data has to go through the mutex first, which ensures the inner data is only accessed by one thread at a time.

That’s where the name comes from too, mutual exclusion.

While it is possible to use a Mutex in a single-threaded context, doing so adds an unnecessary layer of complexity when you can safely access the data directly.

Because the mutex will be sent into a thread, it is often wrapped in an Arc so it can be owned by multiple threads at once.

A mutex allows multiple threads to access (and change) the same piece of data, but ensures only one thread can access that data at a time.

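A mutex has a lock method that blocks until the lock is acquired and then returns a MutexGuard (wrapped in a Result, which is an Err if another thread panicked while holding the lock). While the guard exists, the mutex is “locked” and no other thread can access the inner data.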

If another thread tries to lock that mutex while the guard still exists, that thread blocks until the lock can be acquired.

That MutexGuard is a smart pointer.

The inner data can be accessed through that smart pointer. When the MutexGuard goes out of scope, the lock is released and another thread gets a chance to acquire it.
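The lifetime of the guard decides how long the lock is held. The sketch below shows two ways to end it early, using a single thread purely to keep the example small. The helper name push_twice is made up for illustration.

```rust
use std::sync::Mutex;

// Hypothetical helper showing two ways to release a lock before locking again.
fn push_twice(shared: &Mutex<Vec<i32>>) -> usize {
    {
        let mut guard = shared.lock().unwrap();
        // the guard dereferences to the inner Vec
        guard.push(1);
    } // the guard goes out of scope here, which releases the lock

    let mut guard = shared.lock().unwrap();
    guard.push(2);
    // dropping the guard explicitly releases the lock as well
    drop(guard);

    // Locking again is fine because no guard is alive at this point.
    shared.lock().unwrap().len()
}

fn main() {
    let shared = Mutex::new(Vec::new());
    println!("length after two pushes: {}", push_twice(&shared));
}
```

If either guard were still alive at the final lock call, the same thread would try to lock a mutex it already holds, which either deadlocks or panics.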

The following example spawns 10 threads, and each one increments the value inside the mutex. There is no guarantee in which order the threads run, but the final count will always be 10.

use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let mutex = Arc::new(Mutex::new(0));
    let mut handles = Vec::new();
    for _ in 0..10 {
        let mutex = Arc::clone(&mutex);
        let handle = thread::spawn(move || {
            let mut guard = mutex.lock().unwrap();
            *guard += 1;
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().unwrap();
    }
    assert_eq!(*mutex.lock().unwrap(), 10);
}

Applied

Back to our frequency problem.

We make a String out of the data in a chunk and send that into a thread.

This time, we also send a mutex into each thread.

Once a thread has solved its chunk, it locks the mutex and merges its HashMap into the one the mutex wraps.

The main thread waits for all threads to finish and returns the data the mutex wraps.

lib.rs
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let result = Arc::new(Mutex::new(HashMap::new()));
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles: Vec<_> = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        let result = Arc::clone(&result);
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            // create a HashMap for this chunk
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            // add the HashMap of this chunk to the HashMap that is wrapped by the Mutex
            let mut result = result.lock().unwrap();
            for (key, value) in map {
                *result.entry(key).or_default() += value;
            }
        });
        handles.push(handle);
    }
    // wait for each thread to finish
    for handle in handles {
        handle.join().unwrap();
    }
    // get the HashMap from the Arc<Mutex<HashMap>>
    Arc::try_unwrap(result).unwrap().into_inner().unwrap()
}

Bonus: gotchas

Care has to be taken to ensure the concurrent code you write can take maximum advantage of it being concurrent. The Rust compiler will happily let you compile code that is slower than the sequential counterpart, as long as it’s correct.

Where threads are joined and where mutexes are locked matters. Every time you block a thread and make it wait before it continues, ask yourself if there’s more work it could do first.

Sometimes that involves writing code in a different way that seems less efficient at first, but is faster because it waits less.

For example, in the mutex example, the code within each thread has 2 for loops and goes a little like this:

for loop
lock
for loop

The same thing could be done with a single for loop if we lock the mutex before that loop.

This structure would effectively make the entire calculation sequential again. Each thread locks the mutex and excludes every other thread from accessing it.

Since all the work is done after acquiring the lock, every other thread has to wait around for its turn.

By dividing the code inside a thread into 2 parts, the work specific to each thread gets done without making other threads wait.
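The two structures can be put side by side. In this sketch the function names, the Counts alias and the run driver are all made up for illustration. Both variants produce the same counts, and only the amount of work done while holding the lock differs.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

type Counts = HashMap<char, usize>;

// Variant A: lock first, then count straight into the shared map.
// All counting happens while holding the lock, so threads take turns.
fn count_while_locked(text: &str, shared: &Mutex<Counts>) {
    let mut result = shared.lock().unwrap();
    for c in text.chars().filter(|c| c.is_alphabetic()) {
        *result.entry(c.to_ascii_lowercase()).or_default() += 1;
    }
}

// Variant B: count into a thread-local map without the lock,
// then hold the lock only for the short merge step.
fn count_then_merge(text: &str, shared: &Mutex<Counts>) {
    let mut local = Counts::new();
    for c in text.chars().filter(|c| c.is_alphabetic()) {
        *local.entry(c.to_ascii_lowercase()).or_default() += 1;
    }
    let mut result = shared.lock().unwrap();
    for (key, value) in local {
        *result.entry(key).or_default() += value;
    }
}

// Hypothetical driver: runs one thread per text with the given variant.
fn run(texts: Vec<String>, worker: fn(&str, &Mutex<Counts>)) -> Counts {
    let shared = Arc::new(Mutex::new(Counts::new()));
    let mut handles = Vec::new();
    for text in texts {
        let shared = Arc::clone(&shared);
        handles.push(thread::spawn(move || worker(&text, &shared)));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    Arc::try_unwrap(shared).unwrap().into_inner().unwrap()
}

fn main() {
    let texts = vec![String::from("Hello"), String::from("World")];
    let locked = run(texts.clone(), count_while_locked);
    let merged = run(texts, count_then_merge);
    assert_eq!(locked, merged);
    println!("both variants counted {} distinct letters", locked.len());
}
```

Whether variant B is actually faster depends on the input size and the number of threads, so it is worth measuring rather than assuming.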


I’m quite the fan of iterators.

An iterator chain is executed one item at a time.

This means that for the following iterator, the entire chain gets executed for one item before moving on to the next item:

// first executes the entire chain for 1, then 2, then 3
let total: i32 = [1, 2, 3].iter().filter(|n| *n % 2 != 0).map(|n| n * 2).sum();

In a previous version of my JoinHandle code, everything was one big iterator chain.

Inside that chain I called handle.join(). That meant every other thread couldn’t even be spawned before the previous one finished running.
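That earlier version is not shown here, so the following is only a reconstruction of the pitfall with made-up function names, not the original code. The first function spawns and joins inside one lazy chain; the second collects the handles first.

```rust
use std::thread;

// Sequential by accident: the chain is lazy, so each thread is spawned
// and joined before the next item is even pulled from the iterator.
fn squares_one_at_a_time(inputs: Vec<u64>) -> u64 {
    inputs
        .into_iter()
        .map(|n| thread::spawn(move || n * n))
        .map(|handle| handle.join().unwrap())
        .sum()
}

// Collecting the handles forces every thread to be spawned
// before the first join is called.
fn squares_in_parallel(inputs: Vec<u64>) -> u64 {
    let handles: Vec<_> = inputs
        .into_iter()
        .map(|n| thread::spawn(move || n * n))
        .collect();
    handles
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .sum()
}

fn main() {
    let inputs = vec![1, 2, 3];
    println!("{}", squares_one_at_a_time(inputs.clone()));
    println!("{}", squares_in_parallel(inputs));
}
```

Both functions return the same sum. The difference is that the first never has more than one worker thread alive at a time.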

Final code

lib.rs
use std::collections::HashMap;
use std::thread;

pub fn frequency(input: &[&str], worker_count: usize) -> HashMap<char, usize> {
    let mut result: HashMap<char, usize> = HashMap::new();
    let chunks = input.chunks((input.len() / worker_count).max(1));
    let mut handles = Vec::new();
    for chunk in chunks {
        let string = chunk.join("");
        // return a HashMap from each thread, the JoinHandle wraps this hashmap
        let handle = thread::spawn(move || {
            let mut map: HashMap<char, usize> = HashMap::new();
            for c in string.chars().filter(|c| c.is_alphabetic()) {
                *map.entry(c.to_ascii_lowercase()).or_default() += 1;
            }
            map
        });
        handles.push(handle);
    }
    // wait for each thread to finish and combine every HashMap into the final result
    for handle in handles {
        let map = handle.join().unwrap();
        for (key, value) in map {
            *result.entry(key).or_default() += value;
        }
    }
    result
}