題目網址(需登入 Exercism)
簡單來說,就是要用 concurrency 的方式,將「計算文章中每個字母(數字和標點不計)出現的次數」,分配給 worker_count
個 thread 去做,最後再整合在一起。
輸入的資料,第一個是 &str
的 array slice,第二個是 worker_count
。
問題 1:如何將工作/資料平均分配為 worker_count 份?
第一個直覺的想法,就是從 array 的第 0 個開始,依序分配給不同的 worker。
假設 i
是 worker 的 index,那麼第 i
個 worker 所需要處理的資料可以利用 iterator 挑出來:
input.iter()
.skip(i)
.step_by(work_count)
//...
另外一個想法,是讓資料盡可能連續:
根據我的實驗,後者稍微比前者快一點點。但比較麻煩的是需要計算要怎麼分割。
首先,每個 worker 至少都會分配到 input.len() / worker_count
行。這樣會剩下 r = input.len() % worker_count
行。接下來只要前 r
個 worker 一人多認領一行就可以了。假設 worker 的 index 為 i
,只要 m > i
,就再多加一行。因此每個 worker 分配到的行數為:
let length = input.len() / worker_count +
if input.len() % worker_count > i {
1
} else {
0
};
至於要從哪一行開始拿,就只要把前面 worker 的 length 累加起來即可。
所以最後的程式碼如下:
let quotient = input.len() / worker_count;
let reminder = input.len() % worker_count;
let mut start_index = 0;
for i in 0..worker_count {
let length = quotient + if reminder > i { 1 } else { 0 };
if length == 0 {
break;
}
let v: Vec<_> = (&input[start_index..(start_index + length)])
.into_iter()
//...
start_index += length;
}
問題 2:如何處理傳入 thread 中的資料?
由於 spawned thread 可能會活得比 main thread 還長,因此若是 main thread 要把資料傳給 spawned thread,不能傳一般的 reference 過去(要是在 spawned thread 結束之前,main thread 就結束了,那reference 不就找不到資料了嗎?)。因此最直覺的作法,是把要傳入的資料轉成 owned data,然後利用 move
直接把所有權交給 spawned thread:
for i in 0..worker_count {
//....
let v: Vec<_> = (&input[start_index..(start_index + length)])
.into_iter()
.map(|&s| String::from(s))
.collect();
handles.push(thread::spawn(move || {
v.into_iter() // <-- v 被 move 進來
.for_each(|s| { // s 的型別為 String
//....
建立 v
是省不掉的。但把所有的 &str
轉成 String
實在很花時間。如果我們為 input: &[&str]
中的 &str
加上 static 的 lifetime 限制,就可以保證 &str
的存活時間比 spawned thread 長,這樣就可以直接把 &str
傳進 spawned thread 中了:
pub fn frequency(input: &[&'static str], worker_count: usize) -> HashMap<char, usize> {
//...
for i in 0..worker_count {
//....
let v: Vec<_> = (&input[start_index..(start_index + length)])
.into_iter()
// 到這邊 item 的 type 是 &&str,因此需要 deref 一下...
.map(|&s| s)
.collect();
handles.push(thread::spawn(move || {
v.into_iter().for_each(|s| { // 此處 s 的型別就是 &str,而非 String
//...
問題 3:如何整合所有的資料?
每個 spawned thread 都會把自己負責的部份存進自己的 hash map 中,最後還是要整合成一個的。我的作法是讓每個 spawned thread 都回傳自己的那一份 hash map,然後在 main thread 中整合。
let mut handles = vec![];
for i in 0..worker_count {
//...
handles.push(thread::spawn(move || {
let mut result = HashMap::new();
//...
result
}));
//...
}
handles
.into_iter()
.map(|h| h.join().unwrap())
.fold(HashMap::new(), |mut acc, single| {
single.into_iter().for_each(|(ch, count)| {
let entry = acc.entry(ch).or_insert(0);
*entry += count;
});
acc
})
另一個想法,是利用 Arc
和 Mutex
,把 main thread 中的 hash map 傳進 spawned thread 中,由每個 spawned thread 自行把結果塞進最後的 hash map 中。這個方法可行,但實驗結果沒有快多少,而且資料建立起來又比較麻煩,所以最後就沒這樣做了。
let result = Arc::new(Mutex::new(HashMap::new()));
for i in 0..worker_count {
//...
let r = Arc::clone(&result);
handles.push(thread::spawn(move || {
//...
result.into_iter().for_each(|(ch, count)| {
let mut m = r.lock().unwrap();
let entry = m.entry(ch).or_insert(0);
*entry += count;
})
//...
}));
//..
}
handles
.into_iter()
.for_each(|h| h.join().unwrap());
// 把資料從 Arc<Mutex<>> 中拿出來的方法。
// 參考資料:https://stackoverflow.com/questions/29177449/how-to-take-ownership-of-t-from-arcmutext
Arc::try_unwrap(result).unwrap().into_inner().unwrap()
完整程式
use std::collections::HashMap;
use std::thread;
pub fn frequency(input: &[&'static str], worker_count: usize) -> HashMap<char, usize> {
let mut handles = vec![];
let quotient = input.len() / worker_count;
let reminder = input.len() % worker_count;
let mut start_index = 0;
for i in 0..worker_count {
let length = quotient + if reminder > i { 1 } else { 0 };
if length == 0 {
break;
}
let v: Vec<_> = (&input[start_index..(start_index + length)])
.into_iter()
.map(|&s| s)
.collect();
handles.push(thread::spawn(move || {
let mut result = HashMap::new();
v.into_iter().for_each(|s| {
s.chars()
.filter(|c| c.is_alphabetic())
.map(|c| c.to_ascii_lowercase())
.for_each(|c| {
*(result.entry(c).or_insert(0)) += 1;
})
});
result
}));
start_index += length;
}
handles
.into_iter()
.map(|h| h.join().unwrap())
// merges results from different workers into one
.fold(HashMap::new(), |mut acc, single| {
single.into_iter().for_each(|(ch, count)| {
let entry = acc.entry(ch).or_insert(0);
*entry += count;
});
acc
})
}