【问题标题】:Efficiently sample vectors from numpy ndarray从 numpy ndarray 中有效地采样向量
【发布时间】:2017-12-07 13:13:34
【问题描述】:

我有一个多维 numpy 数组 Xshape: (B, dim, H, W) 我想从 X 中随机抽样 k dim 维向量。
我可以从形状为(B, 1, H, W)msk 获取样本索引:

sIdx = random.sample((msk.flat>=0).nonzero()[0], k) 

使用 的等效采样代码为:

sIdx = np.random.choice((msk.flat>=0).nonzero()[0], replace=False, size=(k,))

但是我怎样才能有效地根据“平坦”采样索引sIdxX 进行切片?
也就是msk的随机抽样和X的切片有没有一种有效的方法?

【问题讨论】:

  • random 是否来自模块 random?最终输出的形状是什么?
  • @Divakar 是的,sample 来自 random 模块。采样可以替换为sIdx = np.random.choice((msk.flat>=0).nonzero()[0], replace=False, size=(k,))
  • @Divakar 我期望的最终输出是shapek-by-dim
  • @NilsWerner 谢谢你的回答。我尽量避免reshapetranspose

标签: numpy python numpy multidimensional-array random slice


【解决方案1】:

从展平的索引中使用np.unravel_index 获取其余三个轴的相应索引,然后沿着这些轴简单地索引到X 以获得最终输出,就像这样 -

I,J,K = np.unravel_index(sIdx, (B, H, W))
out = X[I,:,J,K]

【讨论】:

  • 谢谢!有没有办法“融合”“平面”采样和unravel,使这个过程更加高效?也就是说,有没有办法直接对“解开”的索引进行采样?
  • @Shai 我没有看到任何其他方式。我们需要flat 来生成采样的线性索引。现在,flat 只是一个视图,因此不需要额外的内存,所以应该非常高效。
  • 只是试图写一个答案而不使索引变平,是的,我什至不想提交,因为它太混乱和令人困惑,根本不会提高性能。
  • @DanielF 好吧,无论如何感谢您的努力!你会考虑发布一个“否定”的结果吗?
  • 一开始是sIdx = random.sample(list(np.transpose((msk >= 0).nonzero())), k) ,然后我就迷路了。问题是你有 4 个索引进入,只需要 3 个索引出来,扁平化是一种更好的方法。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2020-09-30
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多