【Question Title】: AVX version not as fast as expected
【Posted】: 2017-07-03 10:57:12
【Question】:

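I am trying to convert a function to an AVX version. The function itself basically just compares floats and returns true or false depending on the result of the comparisons.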

Here is the original function:

bool testSingle(float* thisFloat, float* otherFloat)
{   
    for (unsigned int k = 0; k < COL_COUNT / 2; k++)
    {

        if (thisFloat[k] < -otherFloat[COL_COUNT / 2 + k] || -thisFloat[COL_COUNT / 2 + k] > otherFloat[k]) 
        {
            return true;
        }
    }

    return false;
}
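To make the condition concrete, the sketch below runs the same comparison with a small, hypothetical half-length of 2 (four floats per row instead of COL_COUNT = 46). The helper name `testSingleN` and its explicit `half` parameter are assumptions made so the example stands alone; the loop body is the same as in testSingle.

```cpp
#include <cassert>
#include <cstddef>

// Same test as testSingle, with COL_COUNT / 2 passed in as `half`
// (hypothetical helper so the sketch does not depend on the macro).
bool testSingleN(const float* thisFloat, const float* otherFloat, std::size_t half)
{
    for (std::size_t k = 0; k < half; ++k)
    {
        // thisFloat[k] is compared with the negated second-half value of otherFloat,
        // and the negated second-half value of thisFloat with otherFloat[k].
        if (thisFloat[k] < -otherFloat[half + k] || -thisFloat[half + k] > otherFloat[k])
            return true;   // stop at the first k where either comparison holds
    }
    return false;          // no k satisfied either comparison
}
```

For example, with otherFloat = {1, 2, 3, 4}, the row {1, 2, 3, 4} gives false, while {-5, 2, 3, 4} gives true at k = 0 because -5 < -3.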

And here is the AVX version:

__m256 testAVX(float* thisFloat, __m256* otherFloatInAVX)
{
    __m256 vTemp1;
    __m256 vTemp2;
    __m256 vTempResult;
    __m256 vEndResult = _mm256_set1_ps(0.0f);

    for (unsigned int k = 0; k < COL_COUNT / 2; k++)
    {

        vTemp1 = _mm256_cmp_ps(_mm256_set1_ps(thisFloat[k]), otherFloatInAVX[COL_COUNT / 2 + k], _CMP_LT_OQ);

        vTemp2 = _mm256_cmp_ps(_mm256_set1_ps(-thisFloat[COL_COUNT / 2 + k]), otherFloatInAVX[k], _CMP_GT_OQ);

        vTempResult = _mm256_or_ps(vTemp1, vTemp2);
        vEndResult = _mm256_or_ps(vTempResult, vEndResult);
        if (_mm256_movemask_ps(vEndResult) == 255)
        {
            break;
        }

    }

    return vEndResult;

}
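One way to check the AVX version is to compute the same eight answers with plain scalar code and pack them the way `_mm256_movemask_ps` does: bit l holds the sign bit of lane l, and lane l is the l-th argument passed to `_mm256_setr_ps` (here otherFloat(l+1)). The name `referenceMask` and the explicit `half` parameter are hypothetical; the per-row logic is taken from testSingle.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Scalar reference for the eight answers testAVX produces for one row, packed like
// _mm256_movemask_ps: bit l <-> lane l <-> l-th argument of _mm256_setr_ps
// (otherFloat1 is lane 0, otherFloat8 is lane 7).
// `half` stands for COL_COUNT / 2; `referenceMask` is a hypothetical helper.
std::uint8_t referenceMask(const float* thisFloat, const float* const others[8], std::size_t half)
{
    std::uint8_t mask = 0;
    for (unsigned int lane = 0; lane < 8; ++lane)
    {
        const float* other = others[lane];
        for (std::size_t k = 0; k < half; ++k)
        {
            if (thisFloat[k] < -other[half + k] || -thisFloat[half + k] > other[k])
            {
                mask = static_cast<std::uint8_t>(mask | (1u << lane));
                break;   // this lane is decided, like testSingle returning true
            }
        }
    }
    return mask;
}
```

On a machine with AVX, `_mm256_movemask_ps(testAVX(thisFloat[i], otherFloatInAVX[i]))` should equal `referenceMask(thisFloat[i], others, COL_COUNT / 2)` for every row, where `others` holds `otherFloat1[i]` through `otherFloat8[i]`; a mismatch points at the data layout or the comparisons.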

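Here is the complete code. At the start I generate some random floats and also store them in __m256 vectors for the AVX version. The values in thisFloat are compared against otherFloat1, otherFloat2, ..., otherFloat8.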

#include <chrono>
#include <cstdlib>
#include <immintrin.h>
#include <iostream>

#define ROW_COUNT 1000000
#define COL_COUNT 46

float randomNumberFloat(float Min, float Max)
{
    return ((float(rand()) / float(RAND_MAX)) * (Max - Min)) + Min;
}

int main(int argc, char** argv)
{

    float** thisFloat = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        thisFloat[i] = new float[COL_COUNT];

    float** otherFloat1 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat1[i] = new float[COL_COUNT];

    float** otherFloat2 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat2[i] = new float[COL_COUNT];

    float** otherFloat3 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat3[i] = new float[COL_COUNT];

    float** otherFloat4 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat4[i] = new float[COL_COUNT];

    float** otherFloat5 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat5[i] = new float[COL_COUNT];

    float** otherFloat6 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat6[i] = new float[COL_COUNT];

    float** otherFloat7 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat7[i] = new float[COL_COUNT];

    float** otherFloat8 = new float*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloat8[i] = new float[COL_COUNT];

    // save to AVX
    __m256** otherFloatInAVX = new __m256*[ROW_COUNT];
    for (int i = 0; i < ROW_COUNT; ++i)
        otherFloatInAVX[i] = new __m256[COL_COUNT];

    // variable for results
    unsigned int* resultsSingle = new unsigned int[ROW_COUNT];
    __m256* resultsAVX = new __m256[ROW_COUNT];


    // Generate Random Values
    for (unsigned int i = 0; i < ROW_COUNT; i++)
    {
        for (unsigned int j = 0; j < COL_COUNT; j++)
        {
            thisFloat[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat1[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat2[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat3[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat4[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat5[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat6[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat7[i][j] = randomNumberFloat(-1000.0f, 1000.0f);
            otherFloat8[i][j] = randomNumberFloat(-1000.0f, 1000.0f);

        }

        for (unsigned int j = 0; j < COL_COUNT / 2; j++)
        {
            otherFloatInAVX[i][j] = _mm256_setr_ps(otherFloat1[i][j], otherFloat2[i][j], otherFloat3[i][j], otherFloat4[i][j], otherFloat5[i][j], otherFloat6[i][j], otherFloat7[i][j], otherFloat8[i][j]);
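            // negated second half, matching -otherFloat[COL_COUNT / 2 + k] in testSingle
            otherFloatInAVX[i][COL_COUNT / 2 + j] = _mm256_setr_ps(-otherFloat1[i][COL_COUNT / 2 + j], -otherFloat2[i][COL_COUNT / 2 + j], -otherFloat3[i][COL_COUNT / 2 + j], -otherFloat4[i][COL_COUNT / 2 + j], -otherFloat5[i][COL_COUNT / 2 + j], -otherFloat6[i][COL_COUNT / 2 + j], -otherFloat7[i][COL_COUNT / 2 + j], -otherFloat8[i][COL_COUNT / 2 + j]);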
        }
    }

    // do normal test
    auto start_normal = std::chrono::high_resolution_clock::now();
    for (unsigned int i = 0; i < ROW_COUNT; i++)
    {
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat1[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat2[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat3[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat4[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat5[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat6[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat7[i]);
        resultsSingle[i] = testSingle(thisFloat[i], otherFloat8[i]);
    }
    auto end_normal = std::chrono::high_resolution_clock::now();

    auto duration_normal = std::chrono::duration_cast<std::chrono::milliseconds>(end_normal - start_normal);
    std::cout << "Duration of normal test: " << duration_normal.count() << " ms \n";

    // do AVX test

    auto start_avx = std::chrono::high_resolution_clock::now();
    for (unsigned int i = 0; i < ROW_COUNT; i++)
    {
        resultsAVX[i] = testAVX(thisFloat[i], otherFloatInAVX[i]);
    }
    auto end_avx = std::chrono::high_resolution_clock::now();


    auto duration_avx = std::chrono::duration_cast<std::chrono::milliseconds>(end_avx - start_avx);
    std::cout << "Duration of AVX test: " << duration_avx.count() << " ms\n";
    return 0;
}
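testSingle compares thisFloat[k] with -otherFloat[COL_COUNT / 2 + k], so for testAVX to compute the same thing, the negated half of each row (otherFloatInAVX[i][COL_COUNT / 2 + j]) has to be built from the second half of each otherFloat row. The sketch below shows that interleaving portably; `Lanes` (a std::array<float, 8> standing in for one __m256) and `interleaveRow` are assumptions made so it runs without AVX.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Lanes = std::array<float, 8>;   // stand-in for one __m256 (lane l = otherFloat(l+1))

// Builds one row of the interleaved layout used by testAVX:
//   out[j]        = { other_0[j], ..., other_7[j] }
//   out[half + j] = { -other_0[half + j], ..., -other_7[half + j] }
// so that thisFloat[k] can be compared directly against out[half + k].
std::vector<Lanes> interleaveRow(const float* const others[8], std::size_t colCount)
{
    const std::size_t half = colCount / 2;
    std::vector<Lanes> out(colCount);   // colCount entries, like new __m256[COL_COUNT]
    for (std::size_t j = 0; j < half; ++j)
    {
        for (int lane = 0; lane < 8; ++lane)
        {
            out[j][lane] = others[lane][j];
            out[half + j][lane] = -others[lane][half + j];
        }
    }
    return out;
}
```

With real vectors, each `out[j]` corresponds to `_mm256_setr_ps(out[j][0], ..., out[j][7])`.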

Then I measured the running time of both and got:

Duration of normal test: 290 ms
Duration of AVX test: 159 ms

The AVX version is 1.82 times faster than the original.

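Is there any way to improve the AVX version further? Or am I using AVX incorrectly? Since I'm doing eight computations at once, I expected it to be perhaps 5-6 times faster.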
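One thing worth checking before comparing the numbers: in the scalar loop each assignment overwrites resultsSingle[i], so only the otherFloat8 result is kept, and because testSingle has no side effects an optimizing compiler may be able to drop the first seven calls (whether it does depends on the compiler and flags). The two loops also store different kinds of results, which makes them hard to cross-check. A hedged sketch of a timing helper that folds every result into a checksum, using the monotonic std::chrono::steady_clock, could look like this; `timeRows` is a hypothetical name:

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>

// Times rowFn(i) for i in [0, rows) and adds every returned value to `checksum`,
// so every result feeds into a value the caller can print (and the calls cannot be
// discarded as dead code as long as the checksum is used). Returns elapsed milliseconds.
// `timeRows` is a hypothetical helper, not part of the question's code.
template <typename RowFn>
long long timeRows(unsigned int rows, RowFn rowFn, std::uint64_t& checksum)
{
    checksum = 0;
    const auto start = std::chrono::steady_clock::now();   // monotonic clock
    for (unsigned int i = 0; i < rows; ++i)
        checksum += static_cast<std::uint64_t>(rowFn(i));
    const auto end = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
}
```

For example, the scalar lambda could return `testSingle(thisFloat[i], otherFloat1[i]) | testSingle(thisFloat[i], otherFloat2[i]) << 1 | ...` (bit l for otherFloat(l+1)) and the AVX lambda `_mm256_movemask_ps(testAVX(thisFloat[i], otherFloatInAVX[i]))`; if the two implementations agree, the checksums match.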

【Comments】:

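  • Have you checked whether you are simply limited by memory/cache bandwidth?
  • The key to SIMD performance is to do a lot of computation on the data once it is in registers, which helps hide the cost of loading large amounts of data from memory. Otherwise you're just writing a very complicated memcpy.
  • I think your AVX routine has a bug: it doesn't "early exit" under the same condition as the scalar code; you need to change the test from == 255 to != 0. (Note: it's early morning and I haven't had my coffee yet, but at a casual glance this looks like a bug.)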
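Whether `== 255` or `!= 0` is the right exit test depends on what the result is supposed to mean. A portable simulation of the accumulate-and-break loop in testAVX, working on hypothetical per-iteration 8-bit lane masks instead of real __m256 values (the name `accumulate` is also hypothetical), shows the difference:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simulates the accumulate-and-break loop of testAVX on hypothetical per-iteration
// lane masks (bit l set <=> lane l, i.e. otherFloat(l+1), matched at that k).
//   exitWhenAll = true  -> break when acc == 0xFF (the `== 255` test)
//   exitWhenAll = false -> break when acc != 0    (the `!= 0` test)
// Returns the accumulated mask and stores the number of iterations run in itersRun.
std::uint8_t accumulate(const std::vector<std::uint8_t>& perIter, bool exitWhenAll,
                        std::size_t& itersRun)
{
    std::uint8_t acc = 0;
    itersRun = 0;
    for (std::uint8_t m : perIter)
    {
        acc = static_cast<std::uint8_t>(acc | m);   // role of _mm256_or_ps(vTempResult, vEndResult)
        ++itersRun;
        if (exitWhenAll ? acc == 0xFF : acc != 0)
            break;
    }
    return acc;
}
```

With per-iteration masks {0x01, 0xFE, 0x00}, the `== 255` rule runs two iterations and returns 0xFF (all eight lanes true), while the `!= 0` rule stops after one iteration and returns 0x01, so lanes 1 to 7 read false even though they match at the second k. `!= 0` therefore fits only if a single combined answer for all eight comparisons is wanted; for eight separate results, `== 255` is the matching exit condition.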

Tags: simd avx


【Answer 1】:

I think the AVX version should have the same API as the scalar version (so I changed it only minimally):

bool testAVX(float * thisFloat, float * otherFloat)
{
    size_t k = 0, size = COL_COUNT / 2, sizeAligned = size / 8 * 8;

    __m256 zero = _mm256_set1_ps(0);
    for (; k < sizeAligned; k += 8)
    {
        __m256 _thisFloat1 = _mm256_loadu_ps(thisFloat + k);
        __m256 _thisFloat2 = _mm256_loadu_ps(thisFloat + k + size);
        __m256 _otherFloat1 = _mm256_loadu_ps(otherFloat + k);
        __m256 _otherFloat2 = _mm256_loadu_ps(otherFloat + k + size);

        __m256 compareMask1 = _mm256_cmp_ps(_thisFloat1, _mm256_sub_ps(zero, _otherFloat2), _CMP_LT_OQ);
        __m256 compareMask2 = _mm256_cmp_ps(_mm256_sub_ps(zero, _thisFloat2), _otherFloat1, _CMP_GT_OQ);

        __m256 compareMask = _mm256_or_ps(compareMask1, compareMask2);

        if (!_mm256_testz_ps(compareMask, compareMask))
            return true;
    }

    for (; k < size; k++)
    {
        if (thisFloat[k] < -otherFloat[size + k] || -thisFloat[size + k] > otherFloat[k])
            return true;
    }
    return false;
}
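To check that this version agrees with the scalar one, a small harness can run both on random rows. This sketch assumes GCC or Clang on x86 (it relies on `__attribute__((target("avx")))` and `__builtin_cpu_supports`); `randomIn` and `countMismatches` are hypothetical helpers. Rows with a single forced hit are included so that the scalar tail loop (columns 16 to 22 of the 23-column half) is exercised as well.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

#define COL_COUNT 46   // value used in the question

// Scalar version from the question.
bool testSingle(float* thisFloat, float* otherFloat)
{
    for (unsigned int k = 0; k < COL_COUNT / 2; k++)
        if (thisFloat[k] < -otherFloat[COL_COUNT / 2 + k] || -thisFloat[COL_COUNT / 2 + k] > otherFloat[k])
            return true;
    return false;
}

#if defined(__x86_64__) || defined(__i386__)
#include <immintrin.h>

// The answer's testAVX; the target attribute (GCC/Clang only) enables AVX for this
// function without compiling the whole file with -mavx.
__attribute__((target("avx")))
bool testAVX(float* thisFloat, float* otherFloat)
{
    std::size_t k = 0, size = COL_COUNT / 2, sizeAligned = size / 8 * 8;
    const __m256 zero = _mm256_setzero_ps();
    for (; k < sizeAligned; k += 8)   // 8 columns per iteration
    {
        __m256 t1 = _mm256_loadu_ps(thisFloat + k);
        __m256 t2 = _mm256_loadu_ps(thisFloat + k + size);
        __m256 o1 = _mm256_loadu_ps(otherFloat + k);
        __m256 o2 = _mm256_loadu_ps(otherFloat + k + size);
        __m256 m1 = _mm256_cmp_ps(t1, _mm256_sub_ps(zero, o2), _CMP_LT_OQ);
        __m256 m2 = _mm256_cmp_ps(_mm256_sub_ps(zero, t2), o1, _CMP_GT_OQ);
        __m256 m = _mm256_or_ps(m1, m2);
        if (!_mm256_testz_ps(m, m))   // any lane set: some k matched
            return true;
    }
    for (; k < size; k++)             // scalar tail (23 is not a multiple of 8)
        if (thisFloat[k] < -otherFloat[size + k] || -thisFloat[size + k] > otherFloat[k])
            return true;
    return false;
}

static float randomIn(float lo, float hi)
{
    return float(std::rand()) / float(RAND_MAX) * (hi - lo) + lo;
}

// Returns how many rows testSingle and testAVX disagree on,
// or -1 if the CPU has no AVX (check skipped).
int countMismatches(int trials)
{
    if (!__builtin_cpu_supports("avx"))
        return -1;
    float a[COL_COUNT], b[COL_COUNT];
    int mismatches = 0;
    for (int t = 0; t < trials; ++t)
    {
        const bool positiveOnly = (t % 2 == 1);   // all-positive rows give false unless a hit is forced
        for (int j = 0; j < COL_COUNT; ++j)
        {
            a[j] = positiveOnly ? randomIn(1.0f, 1000.0f) : randomIn(-1000.0f, 1000.0f);
            b[j] = positiveOnly ? randomIn(1.0f, 1000.0f) : randomIn(-1000.0f, 1000.0f);
        }
        if (t % 4 == 3)                            // force a single hit at one column k
            a[(t / 4) % (COL_COUNT / 2)] = -2000.0f;
        if (testSingle(a, b) != testAVX(a, b))
            ++mismatches;
    }
    return mismatches;
}
#else
int countMismatches(int) { return -1; }   // not x86: nothing to compare
#endif
```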

That makes it easier to compare the two versions.
