【问题标题】:Multiprocessing on chunks of an image对图像块进行多处理
【发布时间】:2019-04-20 02:10:01
【问题描述】:

我有一个函数必须遍历图像的各个像素并计算一些几何图形。此功能需要很长时间才能运行(在 24 兆像素图像上约 5 小时),但似乎应该很容易在多个内核上并行运行。但是,我终其一生都找不到一个有据可查、解释清楚的使用 Multiprocessing 包执行此类操作的示例。这是我现在作为玩具示例运行的代码:

import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
from skimage import color
import multiprocessing 
from multiprocessing import Process

#Some dumb stand in function for this exercise
def dumb_func(image):
    ny, nx = image.shape
    temp = np.empty_like(image)

    for y in range(ny):
        for x in range(nx):
            temp[y, x] = np.square(image[y, x])

    return temp

#Convert image to greyscale
img = color.rgb2gray(misc.ascent())

#Resize the image
ns = 2048 #Pixel size
img = misc.imresize(img, size = (ns, ns))


#Split the image into equal chunks...not sure how this works for arrays that
#are weird shapes and aren't the same size in each dimension

divs = 4
init_split = np.array_split(img, divs, axis = 0)
side = init_split[0].shape[0]
chunked = np.empty((divs, divs, side, side))
cur = 0
for i in range(divs):
    split = np.array_split(init_split[i], divs, axis = 1)
    for j in range(divs):
        chunked[i, j, :, :] = split[j]
        cur +=1

#Pull core count and divide by two to be safe
cores = int(multiprocessing.cpu_count() / 2)

result = np.empty_like(chunked)
idxs = np.array(np.meshgrid(np.arange(0, divs, 1), 
                            np.arange(0, divs, 1))).T.reshape(-1, 2)

基本上,这段代码加载到图像中,将其转换为灰度,使其变大,然后将其分块。分块数组的形状为 (i, j, ny, nx),其中 i 和 j 是标识我正在使用的图像块的索引,而 ny,nx 描述每个块的像素大小。

此外,我正在创建一个名为 idxs 的数组,它将所有可能的索引存储到分块数组中,以将分块图像拉出。

我想要做的是在块上并行运行一个函数(在本例中以dumb_func为例)并将结果存储在相同形状的结果数组中。我想象的方法是遍历 idxs 数组并分配处理属于这些索引的块,直到核心数,等待这些核心完成,然后为核心提供更多进程直到完成。我被卡住了,因为我无法 A) 弄清楚如何访问函数中的返回值,以及 B) 如何处理我可能有 16 个块和 5 个内核导致最后一次迭代只需要一个进程的情况。

我该怎么做呢?在过去的 6 到 7 个小时里,我一直在阅读有关多处理池、进程、地图、星图等方面的信息……但我一生都无法理解如何实现这一点。

为 Reedinationer 编辑:

这是我更新的代码,运行时没有错误。但是 new_data 数组永远不会更新。我用值 100 填充它,并且在例程 new_data 的末尾正是它的初始化方式。

import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
from multiprocessing import Process, JoinableQueue
from time import time

#SOme dumb stand in function for this exercise
def dumb_func(q, new_data):
    while True:
        index, image = q.get()
        temp = image **2

        new_data[index[0], index[1], :, :] = temp
        q.task_done()

if __name__ == "__main__":
    start = time()
    q = JoinableQueue()
    img = misc.ascent()
    #Resize the image
    ns = 2048 #Pixel size
    img = misc.imresize(img, size = (ns, ns))
    #Split the image into equal chunks...not sure how this works for arrays that
    #are weird shapes and aren't the same size in each dimension

    divs = 4
    init_split = np.array_split(img, divs, axis = 0)
    side = init_split[0].shape[0]
    chunked = np.empty((divs, divs, side, side))
    cur = 0
    for i in range(divs):
        split = np.array_split(init_split[i], divs, axis = 1)
        for j in range(divs):
            chunked[i, j, :, :] = split[j]
            cur +=1

    new_data = np.full(chunked.shape, 100)
    idxs = np.array(np.meshgrid(np.arange(0, divs, 1), 
                                np.arange(0, divs, 1))).T.reshape(-1, 2)

    for i in range(len(idxs)):
        q.put((idxs[i], chunked[idxs[i][0], idxs[i][1], :, :]))

    print ('starting workers')

    worker_count = len(idxs)
    processes = []
    for i in range(worker_count):
        p = Process(target=dumb_func, args=[q, new_data])
        p.daemon = True
        p.start()
    print('main thread waiting')
    q.join()

    end = time()
    print('{:.3f} seconds elapsed'.format(end - start))

【问题讨论】:

  • 请澄清您的示例,以显示您想如何以及在何处致电dumb_func。通常,您需要一个 multiprocessing.Pool(),然后是 map 您对该池的调用。

标签: python python-3.x multiprocessing


【解决方案1】:

我会做这样的事情,从依赖项开始:

from multiprocessing import Pool
import numpy as np
from PIL import Image

# and some for testing
from random import random
from time import sleep

首先我定义了一个函数来将图像分成“块”,就像你所说的那样:

def chunkit(ys, xs, blocksize=64):
    for y in range(0, ys, blocksize):
        yt = (y, min(ys, y + blocksize))
        for x in range(0, xs, blocksize):
            xt = (x, min(xs, x + blocksize))
            yield yt, xt

这是一个惰性迭代器,所以这可以持续一段时间。

然后我定义我的工作函数:

def dumb_func(cc):
    (y0,y1), (x0,x1) = cc
    # convert to floats for ease of processing
    chunk = image[y0:y1,x0:x1] / 255.
    # random slow down for testing
    # sleep(random() ** 6)
    res = chunk ** 2
    # convert back to bytes for efficiency
    return cc, (res * 255).astype(np.uint8)

为了提高效率,我确保源数组尽可能接近原始格式,并以相同的格式发回(如果您显然要处理其他像素格式,这可能需要一些麻烦)。

然后我把它放在一起:

if __name__ == '__main__':
    source = Image.open('tmp.jpeg')
    image = np.asarray(source)
    print("loaded", image.shape, image.dtype)

    with Pool() as pool:
        resit = pool.imap_unordered(
            dumb_func, chunkit(*image.shape[:2]))

        output = np.empty_like(image)
        for cc, res in resit:
            (y0,y1), (x0,x1) = cc
            output[y0:y1,x0:x1] = res

    im = Image.fromarray(output, 'RGB')
    im.save('out.jpeg')

这会在几秒钟内搅动一张 15Mpixel 的图像,其中大部分用于加载/保存图像。数组步长和缓存友好性可能会更聪明,但希望对您有所帮助!

注意:我认为这段代码依赖于 CPython Unix 风格的进程分叉语义,以确保图像在进程之间有效共享。不知道如果你在别的东西上运行它会发生什么

【讨论】:

  • 谢谢。抱歉,当我说我的问题需要遍历图像时,我不是很清楚。我使用图像作为代理,我的实际问题是循环遍历 numpy 数组的像素(恰好包含大量处理和组合的图像)。所以传入数据的形状只是二维的。我在运行您的代码时遇到错误,并且在 cc 行,res in resit 说名称“图像”未定义。我在 Windows 机器上工作,所以这可能是您所说的 CPython Unix 依赖项
  • 代码应该只适用于任何二维数组/矩阵,我只使用图像,因为您的问题提到了它们并且它们很容易获得。对于任何软件开发,Windows 通常都很尴尬,我建议尽可能远离它!我已经四处寻找将数组放入 Windows 下的“共享内存”的方法,但看不到太多。 Python 确实在 Docker 中运行良好,所以这段代码可以在那里工作...... CPython 只是 Python 的“标准”版本;有 Python 和 JITted 版本的 Java 实现,它们的工作方式略有不同。
【解决方案2】:

我一直在为基本相同的事情编写代码。现在的目标只是用透明像素替换白色像素,但它似乎替换了整个图像,所以某处存在一个错误......虽然multiprocessing模块中不再出现错误,所以也许它可以作为如何加载 Queue 然后让您的工作进程处理它的示例!

from PIL import Image
from multiprocessing import Process, JoinableQueue
from threading import Thread
from time import time

def worker_function(q, new_data):
    while True:
        # print("Items in queue: {}".format(q.qsize()))
        index, pixel = q.get()
        if pixel[0] > 240 and pixel[1] > 240 and pixel[2] > 240:
            out_pixel = (0, 0, 0, 0)
        else:
            out_pixel = pixel
        new_data[index] = out_pixel
        q.task_done()

if __name__ == "__main__":
    start = time()
    q = JoinableQueue()

    my_image = Image.open('InputImage.jpg')
    my_image = my_image.convert('RGBA')
    datas = list(my_image.getdata())
    new_data = [0] * len(datas) # make a blank array the size of our image to fill later

    print('putting image into queue')
    for count, item in enumerate(datas):
        q.put((count, item))

    print('starting workers')
    worker_count = 50
    processes = []
    for i in range(worker_count):
        p = Process(target=worker_function, args=[q, new_data])
        p.daemon = True
        p.start()
    print('main thread waiting')
    q.join()
    my_image.putdata(new_data)
    my_image.save('output.png', "PNG")

    end = time()
    print('{:.3f} seconds elapsed'.format(end - start))

我认为“保护”if __name__ == "__main__" 块内的代码很重要,否则生成的进程似乎会运行它。

更新

看起来您需要实现Manager()(或者可能还有其他我不知道的方法!)。我通过将代码更改为来运行我的代码:

from PIL import Image
from multiprocessing import Process, JoinableQueue, Manager
from threading import Thread
from time import time


def worker_function(q, new_data):
    while True:
        # print("Items in queue: {}".format(q.qsize()))
        index, pixel = q.get()
        if pixel[0] > 240 and pixel[1] > 240 and pixel[2] > 240:
            out_pixel = (0, 0, 0, 0)
        else:
            out_pixel = pixel
        new_data[index] = out_pixel
        q.task_done()


if __name__ == "__main__":
    start = time()
    q = JoinableQueue()
    my_image = Image.open('InputImage.jpg')
    my_image = my_image.convert('RGBA')
    datas = list(my_image.getdata())
    # new_data = [(0, 0, 0, 0)]*len(datas)
    manager = Manager()
    new_data = manager.list([(0, 0, 0, 0)]*len(datas))
    print(new_data)
    print('putting image into queue')
    for count, item in enumerate(datas):
        q.put((count, item))

    print('starting workers')
    worker_count = 50
    processes = []
    for i in range(worker_count):
        p = Process(target=worker_function, args=[q, new_data])
        p.daemon = True
        p.start()
    print('main thread waiting')
    q.join()
    print("Saving Image")
    my_image.putdata(new_data)
    my_image.save('output.png', "PNG")

    end = time()
    print('{:.3f} seconds elapsed'.format(end - start))

虽然这似乎不是最快的选择!我敢肯定还有其他方法可以提高速度。我用Threads 做同样事情的代码看起来非常相似:

from PIL import Image
from threading import Thread
from queue import Queue
import time

start = time.time()
q = Queue()

planeIm = Image.open('InputImage.jpg')
planeIm = planeIm.convert('RGBA')
datas = planeIm.getdata()
new_data = [0] * len(datas)

print('putting image into queue')
for count, item in enumerate(datas):
    q.put((count, item))

def worker_function():
    while True:
        # print("Items in queue: {}".format(q.qsize()))
        index, pixel = q.get()
        if pixel[0] > 240 and pixel[1] > 240 and pixel[2] > 240:
            out_pixel = (0, 0, 0, 0)
        else:
            out_pixel = pixel
        new_data[index] = out_pixel
        q.task_done()

print('starting workers')
worker_count = 100
for i in range(worker_count):
    t = Thread(target=worker_function)
    t.daemon = True
    t.start()
print('main thread waiting')
q.join()
print('Queue has been joined')
planeIm.putdata(new_data)
planeIm.save('output.png', "PNG")

end = time.time()

elapsed = end - start
print('{:3.3} seconds elapsed'.format(elapsed))

然而,使用线程处理我的图像大约需要 23 秒,使用多处理大约需要 170 秒!我怀疑这将来自启动Process 对象所需的更大开销,而且我处理每个像素的算法现在很简单(只是if pixel[0] > 240 and pixel[1] > 240 and pixel[2] > 240: 位),所以我可能不会提高速度一个复杂的像素处理算法会得到我。还要注意multiprocessing documentation

一个管理器可以由网络上不同计算机上的进程共享。但是,它们比使用共享内存要慢。

这让我相信有更快的替代方案。

【讨论】:

  • 谢谢!能稍微解释一下worker函数中index和new_data的关系吗?
  • @Will.Evo 因此,当您调用 get_data() 时,它将返回一个包含一行和多列的数组(因为它会将其展平)。因此,我在将内容放入队列时使用 enumerate() 并将其放入从中获取信息的索引中。这样我的工作函数就可以知道需要更新 new_data 的哪个索引。工作函数需要知道它正在处理的像素的位置,以便知道更新 new_data 的相同像素索引。这样,如果一个像素的计算速度较慢,它就不会像使用 new_data.append() 那样将其丢弃(它们可能会出现故障)
  • 谢谢你...我能够让代码正常运行(稍作修改以适应我的问题),但由于某种原因,返回的 new_data 数组全为零。不知道发生了什么。我将更新我的问题以包含我正在运行的新代码。当我最初说我正在迭代图像时,实际上我正在迭代由组合和处理的图像组成的数组。我只是将 new_data 更改为在各处初始化为 100 并再次运行。代码中永远不会访问 new_data
  • @Will.Evo 是的,这也是我想要弄清楚的部分 D:我不确定为什么即使它作为参数传递,它似乎也没有改变。更有趣的是,如果在new_data[index] = out_pixel I print(new_data[index]) 行下,它会显示更改的值...我想可能是Pipe 或进程间通信需要的东西,但我不确定
  • @Will.Evo 我修复了我的代码并上传了一个工作版本!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2015-10-19
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2016-10-21
相关资源
最近更新 更多