【问题标题】:How to delete sub-arrays from multidimensional numpy arrays using a condition?如何使用条件从多维 numpy 数组中删除子数组?
【发布时间】:2021-03-26 05:16:14
【问题描述】:

我正在尝试使用条件从多维 numpy 数组中删除子数组。在这个例子中,我想删除所有包含值 999 的子数组。这是我失败的尝试之一:

a = np.array([[[1,2,3], [1,2,3]],
              [[999,5,6], [4,5,6]],
              [[999,8,9], [7,999,9]]
              ])

for i in range(0,len(a)):
    if 999 in a[i]:
        np.delete(a, i, 0)

我想要的结果是:

array([[1,2,3], [1,2,3]])

这只是一个小例子,它应该可以帮助我理解一个更大的问题,它看起来像这样:

# win_list_hyper.shape -> (1449168, 233)
# win_list_multi.shape -> (1449168, 12, 5, 5)

win_list_hyper = np.where(win_list_hyper <= 0, -3.40282e+38, win_list_hyper)
win_list_multi = np.where(win_list_multi <= 0, -3.40282e+38, win_list_multi)


# fail!:
for i in range(0,len(win_list_multi)):
    
    if -3.40282e+38 in win_list_multi[i] or -3.40282e+38 in win_list_hyper[i]:
        
        np.delete(win_list_multi, i, 0)
        np.delete(win_list_hyper, i, 0)

(顺便说一句。如果您知道如何提高效率,请告诉我!)

【问题讨论】:

    标签: python arrays numpy for-loop conditional-statements


    【解决方案1】:

    您的第一次尝试失败,因为np.delete 没有就地操作(即它没有修改数组,而是返回一个新数组)。此外,在迭代数组时从数组中删除元素通常不是一个好主意(除非您知道自己在做什么)。

    你可以像下面这样使用np.where

    inds = np.where(a == 999)  # get indices where value equals 999
    np.delete(a, inds[0], axis=0)   # delete along first dimension
    

    结果:

    array([[[1, 2, 3],
            [1, 2, 3]]])
    

    【讨论】:

    • 感谢您的帮助。我不得不意识到我的例子并不好。我不能那么容易地将您的解决方案转换为我的数据。你对我帮助很大。您的解决方案给了我掩盖我的数组的想法。我希望我的解决方案不会弄乱我的数据(例如随机播放)。
    【解决方案2】:

    Jussi Nurminen 解决方案适用于我的示例,但我不得不意识到我的示例并不好。我不能那么容易地将给定的解决方案转换为我的数据。 Jussi Nurminen 解决方案对我帮助很大,因为它给了我屏蔽数组的想法。我希望我的解决方案不会弄乱我的数据(例如随机播放)。对于那些有兴趣的人......

    ...这是我的(坏)示例的解决方案:

    a = np.array([[[1,2,3], [1,2,3]],[[999,5,6], [4,5,6]],[[999,8,9], [7,999,9]]])
    
    a_mask = []
    
    for i in range(0,len(a)):
        if 999 in a[i]:
            x = 0
        else: x = 1
        
        a_mask.append(x)
     
    a_mask = np.asarray(a_mask)
        
    inds = np.where(a_mask == 0)
            
    b = np.delete(a, inds, axis=0) 
    
    

    ...这就是我的数据转换后的样子:

    
    # win_list_multi.shape -> (1449168, 12, 5, 5)
    # win_list_hyper.shape -> (1449168, 233
    
    
    win_list_multi = np.where(win_list_multi <= 0, -1, win_list_multi)
    
    win_list_hyper = np.where(win_list_hyper <= 0, -1, win_list_hyper)
    
    
    win_list_multi_mask = []
    
    for i in range(0,len(win_list_multi)):
        if -1 in win_list_multi[i]:
            x = 0
        else: x = 1
        
        win_list_multi_mask.append(x)
    
    win_list_multi_mask = np.asarray(win_list_multi_mask)
    
    
    
    win_list_hyper_mask = []
    
    for i in range(0,len(win_list_hyper)):
        if -1 in win_list_hyper[i]:
            x = 0
        else: x = 1
        
        win_list_hyper_mask.append(x)
    
    win_list_hyper_mask = np.asarray(win_list_hyper_mask)
    
    
    
    inds = np.where((win_list_multi_mask == 0) | (win_list_hyper_mask == 0))
    
    
    win_list_multi_nd = np.delete(win_list_multi, inds, axis=0) 
    win_list_hyper_nd = np.delete(win_list_hyper, inds, axis=0) 
    
    # win_list_multi_nd.shape -> (9679, 12, 5, 5)
    # win_list_hyper_nd.shape -> (9679, 233)
    
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2011-03-26
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2016-04-25
      • 2017-12-02
      相关资源
      最近更新 更多