【问题标题】:How to modify global numpy array safely with multithreading in Python?如何在 Python 中使用多线程安全地修改全局 numpy 数组?
【发布时间】:2021-02-23 17:44:21
【问题描述】:

我正在尝试在线程池中运行模拟并将每次重复的结果存储在全局 numpy 数组中。但是,我在这样做时遇到了问题,我观察到以下(简化)代码(python 3.7)的一个非常有趣的行为:

import numpy as np
from multiprocessing import Pool, Lock

log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)

def record_results(repetition_index, data_array, log_mutex):
    log_mutex.acquire()
    print("Start record {}".format(repetition_index))
    # Do some stuff and modify data_array, e.g.:
    data_array[repetition_index, 0, 53] = 12.34
    
    print("Finish record {}".format(repetition_index))
    log_mutex.release()

def run(repetition_index):
    global log_mutex
    global data_array

    # do some simulation

    record_results(repetition_index, data_array, log_mutex)

if __name__ == "__main__":
    random.seed()
    with Pool(thread_count) as p:
        print(p.map(run, range(repetition_count)))



问题是:我得到了正确的“开始记录和完成记录”输出,例如开始记录 1... 结束记录 1。但是,每个线程修改的 numpy 数组的不同切片不会保存在全局变量中。换句话说,线程 1 修改过的元素仍然为零,线程 4 覆盖了数组的不同部分。

补充一点,全局数组的地址,我通过它检索 print(hex(id(data_array))) 对于所有线程都是相同的,在它们的 log_mutex.acquire() ... log_mutex.release() 行内。

我错过了一点吗?就像,每个线程都存储了多个全局 data_array 副本?我正在观察一些这样的行为,但是当我使用 global 关键字时不应该是这种情况,我错了吗?

【问题讨论】:

  • 您不是在使用多个线程,而是在使用多个进程。 id 只在一个进程中是唯一的

标签: python python-3.x multithreading numpy mutex


【解决方案1】:

看起来您正在使用多个进程而不是多个线程运行 run 函数。试试这样的:

import numpy as np
from threading import Thread, Lock

log_mutex = Lock()
repetition_count = 5
data_array = np.zeros(shape=(repetition_count, 3, 200), dtype=float)

def record_results(repetition_index, data_array, log_mutex):
    log_mutex.acquire()
    print("Start record {}".format(repetition_index))
    # Do some stuff and modify data_array, e.g.:
    data_array[repetition_index, 0, 53] = 12.34
    print("Finish record {}".format(repetition_index))
    log_mutex.release()

def run(repetition_index):
    global log_mutex
    global data_array
    record_results(repetition_index, data_array, log_mutex)

if __name__ == "__main__":
    threads = []
    for i in range(repetition_count):
        t = Thread(target=run, args=[i])
        t.start()
        threads.append(t)
    for t in threads:
        t.join()

更新:

要对多个进程执行此操作,您需要使用multiprocessing.RawArray 来实例化您的数组;数组的大小是乘积repetition_count * 3 * 200。在每个进程中,使用np.frombuffer 在阵列上创建一个视图,并相应地对其进行整形。虽然这会非常快,但我不鼓励这种编程风格,因为它依赖于全局共享内存对象,这在大型程序中很容易出错。

如果可能,我建议删除全局data_array,而是在每次调用record_results 时实例化一个数组,您将在run 中返回该数组。 p.map 调用将返回一个数组列表,您可以将其转换为 numpy 数组并恢复原始实现中全局 data_array 的形状和内容。这会产生通信成本,但它是一种更简洁的并发管理方法,并且无需锁。

尽量减少进程间通信通常是个好主意,但除非性能至关重要,否则我认为共享内存不是正确的解决方案。使用p.map,您需要避免返回大对象,但您的 sn-p 中的对象大小非常小(600*8 字节)。

【讨论】:

  • 确实应该有python的ThreadPool,但是还没有文档。
  • 但是使用(进程)池的正确方法是什么?
  • @OnurA 这里是一个线程池:docs.python.org/3/library/…
  • 我更新了我的回复,用p.map 解释了几个解决方案。
猜你喜欢
  • 1970-01-01
  • 2019-11-23
  • 1970-01-01
  • 2021-03-17
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多