【问题标题】:Numpy 3D array Indexing : Works for 2D, how to do for 3D?Numpy 3D 数组索引:适用于 2D,如何处理 3D?
【发布时间】:2019-05-15 15:24:56
【问题描述】:

我有 3 个numpy 数组,如下所示。

import numpy as np
key_idx = np.array([1, 2, 1])  # both have same shape
out_idx = np.array([0, 3, 0])
max_out = out_idx.max()
output = np.zeros(shape=(len(key_idx), max_out + 1))

# output = 
# array([[0., 0., 0., 0.],
#        [0., 0., 0., 0.],
#        [0., 0., 0., 0.]])

我想增加索引给出的值,如下所示:

key_idx = key_idx[np.newaxis, :] # convert to 2D
out_idx = out_idx[np.newaxis, :]
idx = (key_idx, out_idx)
np.add.at(output, idx, 1)

# output =
# array([[0., 0., 0., 0.],
#        [2., 0., 0., 0.],
#        [0., 0., 0., 1.]])

然后应用如下变换:

np.sum(np.amax(output, axis=1))
#3.0

但现在我想为 3D 输出数组执行此操作,其中 key_idx2D 是一个 2D 数组,第一个维度表示 table_id。请参考下图:

我尝试了什么

key_idx2D = np.array([[1, 2, 1], [2, 2, 2]])
output3D = np.zeros(shape=(key_idx2D.shape[0], len(key_idx), max_out + 1))
key_idx2D = key_idx[np.newaxis, :] # convert to 3D
out_idx = out_idx[np.newaxis, :]
idx3D = (key_idx2D, out_idx)
np.add.at(output3D, idx3D, 1)

#IndexError: index 2 is out of bounds for axis 0 with size 2

如何为 3D 案例执行此操作?任何帮助表示赞赏。它应该为每个table_id 返回一个值数组,如图所示。

注意:我可以用循环来做,但是会很慢。我需要更快的东西。

编辑key_idx2Daxis 0 = table_idaxis 1 = key_idout_idxaxis 0 = out_idkey_idx2Dout_idx 都只包含 output ndarray 的索引,这些索引需要在它们上应用 np.add.at()。 我已更新该图以澄清这一点。

【问题讨论】:

    标签: python arrays numpy


    【解决方案1】:

    如果有人觉得它有用,我会发布答案。

    key_idx2D = np.array([[1, 2, 1], [2, 2, 2]])
    output3D = np.zeros(shape=(key_idx2D.shape[0], key_idx2D.shape[1], max_out + 1))
    output3D.shape
    #(2, 3, 4)
    

    我只需要为第一个轴(即轴 0)创建一个索引数组。

    table_idx = np.array([0, 1]).reshape(-1, 1)
    out_idx = np.array([0, 3, 0])
    table_idx.shape, key_idx2D.shape, out_idx.shape
    #((2, 1), (2, 3), (3,))
    

    然后将所有索引数组以元组的形式发送给np.add.at

    np.add.at(output3D, (table_idx, key_idx2D, out_idx), 1)
    output3D
    
    # array([[[0., 0., 0., 0.],
    #         [2., 0., 0., 0.],
    #         [0., 0., 0., 1.]],
    #        [[0., 0., 0., 0.],
    #         [0., 0., 0., 0.],
    #         [2., 0., 0., 1.]]])
    
    np.sum(np.amax(output3D, axis=2), axis=1)
    #array([3., 2.])
    

    【讨论】:

      猜你喜欢
      • 2019-12-23
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-03-08
      • 2012-12-09
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多