【问题标题】:Implementing a parallel / multithreaded merge sort on Vec在 Vec 上实现并行/多线程合并排序
【发布时间】:2021-04-01 13:06:42
【问题描述】:

我正在尝试通过实现并行合并排序来学习 Rust 的多线程。一个简单的递归版本就可以了,但是这个版本:

use rand;

use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    //let mut input_line = String::new();
    // println!("Input amount of numbers to sort:");
    // let amount = match std::io::stdin().read_line(&mut input_line){
    //     Ok(_) => i64::from_str_radix(&input_line.trim(), 10).unwrap(),
    //     Err(_) => panic!("Error while reading amount of values")
    // };
    let amount = 1_000_000;

    // let mut rnd = rand::thread_rng();
    let mut arr: Vec<i64> = Vec::new();
    for _ in 0..amount {
        arr.push(rand::random::<i64>())
    }

    // println!("Vector before sort:");
    // for elem in &arr {
    //     println!("{}", elem);
    // }

    merge_sort(&mut arr);

    // println!("Vector after sort:");
    // for elem in &arr {
    //     println!("{}", elem);
    // }
}

fn merge_sort(arr: &mut Vec<i64>) {
    let arr_len = arr.len();
    let arr_slice = arr.as_mut_slice();

    // simple_merge_sort(arr, 0 as usize, arr_len - 1 as usize);

    let arc = Arc::new(Mutex::new(arr));
    par_merge_sort(&mut arc, 0 as usize, arr_len - 1 as usize, 4);
}

fn simple_merge_sort(arr: &mut Vec<i64>, lo: usize, hi: usize) {
    if lo == hi {
        return;
    }

    let mi = (hi + lo) / 2;
    simple_merge_sort(arr, lo, mi);
    simple_merge_sort(arr, mi + 1, hi);

    merge(arr, lo, mi, hi);
}

fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
    if lo == hi {
        return;
    }

    let mi = (hi + lo) / 2_usize;
    if threads == 1 {
        let mut simple_arr = arc.lock().unwrap();
        simple_merge_sort(&mut simple_arr, lo, hi);
    } else {
        let thread_arc = Arc::from(*arc);
        let thread_rest = threads / 2;
        let thread_rest_2 = threads - thread_rest;
        let thread1 = thread::spawn(move || {
            par_merge_sort(&mut thread_arc, lo, mi, thread_rest);
        });
        let thread_arc = Arc::from(*arc);
        let thread2 = thread::spawn(move || {
            par_merge_sort(&mut thread_arc, mi + 1, hi, thread_rest_2);
        });

        thread1.join().unwrap();
        thread2.join().unwrap();
    }

    let mutex = arc.lock().unwrap();
    merge(&mut *mutex, lo, mi, hi);
}

fn merge(arr: &mut Vec<i64>, lo: usize, mi: usize, hi: usize) {
    let mut lo_arr: Vec<i64> = Vec::new();
    for i in lo..(mi + 1) {
        let elem = *arr.get(i).unwrap();
        lo_arr.push(elem);
    }

    let mut hi_arr: Vec<i64> = Vec::new();
    for i in (mi + 1)..(hi + 1) {
        let elem = *arr.get(i).unwrap();
        hi_arr.push(elem);
    }

    let mut i = 0;
    let mut j = 0;
    let mut counter = lo;

    while i < lo_arr.len() && j < hi_arr.len() {
        let elem_i = *lo_arr.get(i).unwrap();
        let elem_j = *hi_arr.get(j).unwrap();

        if elem_i <= elem_j {
            arr[counter] = elem_i;
            i += 1;
        } else {
            // elem_j <= elem_i
            arr[counter] = elem_j;
            j += 1;
        }
        counter += 1;
    }

    if j == hi_arr.len() {
        while i < lo_arr.len() {
            let elem_i = *lo_arr.get(i).unwrap();
            arr[counter] = elem_i;
            i += 1;
            counter += 1;
        }
    } else {
        // i == lo_arr.len()
        while j < hi_arr.len() {
            let elem_j = *hi_arr.get(j).unwrap();
            arr[counter] = elem_j;
            j += 1;
            counter += 1;
        }
    }
}

产生错误:

error[E0621]: explicit lifetime required in the type of `arc`
  --> src/main.rs:69:23
   |
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
   |                        ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
69 |         let thread1 = thread::spawn(move || {
   |                       ^^^^^^^^^^^^^ lifetime `'static` required

error[E0621]: explicit lifetime required in the type of `arc`
  --> src/main.rs:73:23
   |
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
   |                        ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
73 |         let thread2 = thread::spawn(move || {
   |                       ^^^^^^^^^^^^^ lifetime `'static` required

【问题讨论】:

标签: multithreading rust


【解决方案1】:

由于您的问题是关于并行化而不是排序,因此我在下面的示例中省略了 serial_sortmerge 函数的实现,但您可以使用已有的代码轻松填写它们:

#![feature(is_sorted)]

use crossbeam; // 0.8.0
use rand; // 0.7.3
use rand::Rng;

fn random_vec(capacity: usize) -> Vec<i64> {
    let mut vec = vec![0; capacity];
    rand::thread_rng().fill(&mut vec[..]);
    vec
}

fn parallel_sort(data: &mut [i64], threads: usize) {
    let chunks = std::cmp::min(data.len(), threads);
    let _ = crossbeam::scope(|scope| {
        for slice in data.chunks_mut(data.len() / chunks) {
            scope.spawn(move |_| serial_sort(slice));
        }
    });
    merge(data, chunks);
}

fn serial_sort(data: &mut [i64]) {
    // actual implementation omitted for conciseness
    data.sort()
}

fn merge(data: &mut [i64], _sorted_chunks: usize) {
    // actual implementation omitted for conciseness
    data.sort()
}

fn main() {
    let mut vec = random_vec(10_000);
    parallel_sort(&mut vec, 4);
    assert!(vec.is_sorted());
}

playground

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-04-07
    • 1970-01-01
    • 2013-08-11
    • 1970-01-01
    • 1970-01-01
    • 2011-09-01
    相关资源
    最近更新 更多