【问题标题】:How to implement cross-validation efficiently in Matlab using parfor如何使用 parfor 在 Matlab 中有效地实现交叉验证
【发布时间】:2015-08-26 16:30:50
【问题描述】:

交叉验证是embarrassingly parallel 问题之一。

假设您想要交叉验证线性回归模型。假设设计矩阵X 具有维度n-by-p 并且连续结果yn-by-1 向量。进一步假设foldMatrixn-by-k 矩阵logicals。每列代表一个分区:1 表示观察用于训练,0 表示用于验证。这种训练验证技巧重复k 次,以减少generalization error (GE) 估计的方差。

我(天真的?)在 Matlab 中进行并行交叉验证的方法如下所示:

matlabpool

GE = nan(k,1);  

parfor i = 1:k

   trainIndices = foldMatrix(:, i);
   b = X(trainIndices, :)\y(trainIndices, :);

   GE(i) = mean( (y(~trainIndices, :)  - X(~trainIndices, :)*b).^2 );

end

mspe = mean(GE);

当您这样做时,Matlab 将抱怨“X 已被索引但未在 PARFOR 循环中切片。这可能会导致不必要的通信开销”(y 的同义词)。

我的问题是:

  • 编辑:有什么方法可以在 Matlab 中使用并行实现来加速交叉验证?
  • 是否有一种高效/优雅的方法来解决Xy 变量未被切片的问题?

我觉得不太优雅的两个“解决方案”是:

  1. 忽略唠叨。对于小问题,比如p < 100n < 3000k < 40,顺序实现比并行实现要快。

  2. 在元胞数组或 3 维矩阵中“显式”地预分配训练验证分区。导致 k 数据的完整副本(Xy)。

【问题讨论】:

    标签: matlab cross-validation parfor


    【解决方案1】:

    忽略唠叨。

    代码分析器警告只是为了确保您知道自己在做什么。许多并行问题只对单独的数据块做一件事,因此 TMW 希望您知道您正在重用一些数据。

    这样想:数据必须以某种方式到达正确的处理器。您可以在内存中复制它,或者让处理器每次都请求同一块内存。但是,这种重复需要时间,这是我们不想要的。

    这是一个我们可以用来检查的小脚本:

    n = 20000;
    p = 6;
    k = 1000;
    
    N = [30 * (1:9), 300 * (1:9), 3000 * (1:10)];
    
    timing = nan(length(N), 2);
    
    gcp;
    
    for iN = 1:length(N)
        n = N(iN);
    
        X = rand(n, p);
        beta = rand(p, 1);
        y= X * beta;
    
        foldMatrix = logical(rand(n,k) > 0.5);
    
        GE = nan(k,1);
    
        tic;
        parfor i = 1:k
            trainIndices = foldMatrix(:, i);
            b = X(trainIndices, :)\y(trainIndices, :); %#ok<*PFBNS>
    
            GE(i) = mean( (y(~trainIndices, :)  - X(~trainIndices, :)*b).^2 );
        end
        timing(iN, 1) = toc;
    
        tic;
        X_rep = repmat(X, [1 1 k]);
        y_rep = repmat(y, [1 1 k]);
        parfor i = 1:k
            trainIndices = foldMatrix(:, i);
            b = X(trainIndices, :)\y(trainIndices, :);
    
            GE(i) = mean( (y(~trainIndices, :)  - X(~trainIndices, :)*b).^2 );
        end
        timing(iN, 2) = toc;
    end
    

    如果您针对 N 绘制时间,plot(N, timing):

    蓝线是代码分析器警告,橙线是repmat。在 k 维度上也是如此:

    所以省去计算并忽略警告。您可以在出现警告的行的末尾添加%#ok&lt;PFBNS&gt;,或在文件中的任何位置添加%#ok&lt;*PFBNS&gt;,以使它们停止显示。

    【讨论】:

      猜你喜欢
      • 2015-09-06
      • 2023-03-04
      • 2011-01-24
      • 1970-01-01
      • 2017-05-04
      • 1970-01-01
      • 2020-05-30
      • 2012-02-04
      • 2018-06-30
      相关资源
      最近更新 更多