如果您有多个 GPU,则可以使用 PyTorch 的 DataParallel 在所有 GPU 上分配计算。它将在 GPU 之间拆分(并行化)矩阵 C_gpu 的列的乘法。
方法如下:
首先,导入模块并准备矩阵:
import torch
import torch.nn as nn
A_gpu = torch.from_numpy(A).float()
B_gpu = torch.from_numpy(B).float()
C_gpu = torch.from_numpy(C).float()
接下来,创建一个没有偏差的Linear“层”。这一层所做的正是矩阵乘法。输入大小将是C_gpu 每一列的大小,输出大小将是结果每一列的大小。
mat_mult = nn.Linear(in_features=C_gpu.shape[0],out_features=A_gpu.shape[0],bias=False)
将层的矩阵(=权重)设置为A_gpu @ B_gpu,这是一个无需并行化即可快速计算的小矩阵(尽管您也可以将其并行化)。
mat_mult.weight.data = A_gpu @ B_gpu
将图层转换为 DataParallel 实例。这意味着它将沿“批处理”维度自动并行计算。参数 device_ids 是您的 GPU 的索引列表(在您的情况下为 4 个)。
mat_mult_gpu = nn.DataParallel(mat_mult,device_ids=[0,1,2,3]).to('cuda:0')
现在您可以将矩阵C_gpu 输入层,计算将沿其大维度并行:
D_gpu = mat_mult_gpu(C_gpu.t())
重要提示:在撰写此答案时,我无法访问多个 GPU 来实际测试此提议的解决方案。如果有读者确认它确实有效(甚至更好 - 计时解决方案并与单个 GPU 进行比较),我将不胜感激
EDIT1:我现在在多个 GPU(四个 Nvidia Tesla P100)上尝试了这段代码,结果发现它给出了内存不足的错误。不过,我将在此处保留此解决方案作为参考,因为它确实适用于最大约 400K(而不是 3.6M)的大小。
此外,如果您将C 分成更小的块,将每个块送入mat_mult_gpu,然后在 CPU 上连接结果,则此解决方案仍适用于大小为 3.6M 的情况。请注意,您需要大量 CPU 内存才能工作,因为结果的大小为 3K-by-3.6M,在 fp32 中大约需要 40GB。 (或者,您可以将每个块保存到磁盘而不连接块)。