【问题标题】:Shuffle multidimensional array by columns and update list of indexes accordingly按列打乱多维数组并相应地更新索引列表
【发布时间】:2025-12-14 04:05:01
【问题描述】:

给定一个 N rows by M columns 数组,我需要按 columns 对其进行洗牌,同时更新一个单独的(唯一)列索引列表以指向新的洗牌元素的位置。

例如取下面的(3, 5)数组

a = [[ 0.15337424  0.21176979  0.19846229  0.5245618   0.24452392]
     [ 0.17460481  0.45727362  0.26914808  0.81620202  0.8898504 ]
     [ 0.50104826  0.22457154  0.24044079  0.09524352  0.95904348]]

以及列索引列表:

idxs = [0 3 4]

如果我按列对数组进行洗牌,它看起来像这样:

a = [[ 0.24452392  0.19846229  0.5245618   0.21176979  0.15337424]
     [ 0.8898504   0.26914808  0.81620202  0.45727362  0.17460481]
     [ 0.95904348  0.24044079  0.09524352  0.22457154  0.50104826]]

索引数组应该修改成如下所示:

idxs = [4 2 0]

我可以通过在洗牌之前和之后转置它来按列对数组进行洗牌(参见下面的代码),但我不确定如何更新索引列表。整个过程需要尽可能快,因为它将使用新数组执行数百万次。

import numpy as np

def getData():
    # Array of (N, M) dimensions
    N, M = 10, 500
    a = np.random.random((N, M))

    # List of unique column indexes in a.
    # This list could be empty, or it could have a length of 'M'
    # (ie: contain all the indexes in the range of 'a').
    P = int(M * np.random.uniform())
    idxs = np.arange(0, M)
    np.random.shuffle(idxs)
    idxs = idxs[:P]

    return a, idxs


a, idxs = getData()

# Shuffle a by columns
b = a.T
np.random.shuffle(b)
a = b.T

# Update the 'idxs' list?

【问题讨论】:

    标签: python arrays performance numpy random


    【解决方案1】:

    使用np.random.permutation 获取列索引的随机排列 -

    col_idx = np.random.permutation(a.shape[1])
    

    获取打乱的输入数组 -

    shuffled_a = a[:,col_idx]
    

    然后,只需将col_idx 的排序索引与idxs 索引到追溯版本 -

    shuffled_idxs = col_idx.argsort()[idxs]
    

    示例运行 -

    In [236]: a # input array
    Out[236]: 
    array([[ 0.1534,  0.2118,  0.1985,  0.5246,  0.2445],
           [ 0.1746,  0.4573,  0.2691,  0.8162,  0.8899],
           [ 0.501 ,  0.2246,  0.2404,  0.0952,  0.959 ]])
    
    In [237]: col_idx = np.random.permutation(a.shape[1])
    
    # Let's use the sample permuted column indices to verify desired o/p
    In [238]: col_idx = np.array([4,2,3,1,0])
    
    In [239]: shuffled_a = a[:,col_idx]
    
    In [240]: shuffled_a
    Out[240]: 
    array([[ 0.2445,  0.1985,  0.5246,  0.2118,  0.1534],
           [ 0.8899,  0.2691,  0.8162,  0.4573,  0.1746],
           [ 0.959 ,  0.2404,  0.0952,  0.2246,  0.501 ]])
    
    In [241]: col_idx.argsort()[idxs]
    Out[241]: array([4, 2, 0])
    

    【讨论】:

    • 谢谢迪瓦卡!我正在尝试提高函数的性能(正如您可能从我之前的问题中猜到的那样),而您在此处给出的答案 *.com/a/46079837/1391441 仍然会产生最快的结果。
    【解决方案2】:
    original_index = range(a.shape[1])
    permutation_series = pd.Series(original_index)
    permutation_series.index = np.random.permutation(original_index)
    new_idx = permutation_series[old_idx]
    a = a[:,permutation_series.index]
    

    【讨论】:

    • 请解释您的代码与 OP 的不同之处以及如何解决问题/回答他们的问题。我推荐本指南以创建有用的答案 *.com/help/how-to-answer
    【解决方案3】:

    数据数组必须使用索引数组打乱,所以首先打乱索引数组并使用它来打乱数据数组。

    import numpy as np
    
    def getData():
        # Array of (N, M) dimensions
        a = np.arange(15).reshape(3, 5)
        # [[ 0  1  2  3  4]
        # [ 5  6  7  8  9]
        # [10 11 12 13 14]]
        idxs = np.arange(a.shape[0]) #  [0 1 2]
        return a, idxs
    
    a, idxs = getData()
    
    # Shuffle a by columns
    b = a.T
    # [[ 0  5 10]
    # [ 1  6 11]
    # [ 2  7 12]
    # [ 3  8 13]
    # [ 4  9 14]]
    
    np.random.shuffle(idxs)  #  [2 0 1]
    a = b[:, idxs]
    
    # [[10  0  5]
    # [11  1  6]
    # [12  2  7]
    # [13  3  8]
    # [14  4  9]]
    

    所以如果你想洗牌任何其他数组比如 x 以匹配数组 a 的洗牌,idxs 将很有用

    【讨论】: