【问题标题】:Elements arrangement in a numpy arraynumpy数组中的元素排列
【发布时间】:2014-07-01 23:12:56
【问题描述】:
import numpy as np

data = np.array([[0, 0, 1, 1, 2, 2],
                 [1, 0, 0, 1, 2, 2],
                 [1, 0, 1, 0, 0, 0],
                 [1, 1, 0, 0, 2, 0]])

我该怎么做?

在 2 x 2 补丁内:

if any element is 2: put 2
if any element is 1: put 1
if all elements are 0: put 0

预期结果是:

np.array([[1, 1, 2],
          [1, 1, 2]])

【问题讨论】:

  • 到目前为止您尝试过什么?就目前而言,这看起来像是您在要求某人为您编写代码。
  • 如果12 都出现在一个块中怎么办?
  • @Lego Stormtroopr '就目前的样子' 是什么意思?抱歉,我的母语不是英语。
  • 看起来您要求我们为您完成工作,因为没有代码。
  • @user2357112 因为第一步给了2个优先级,所以放了2个。

标签: python arrays numpy scipy scikit-learn


【解决方案1】:

如果原始数组很大,并且性能是一个问题,则可以通过操纵原始数组的形状和步幅来将循环下推到 numpy C 代码以创建您所在的窗口作用于:

import numpy as np
from numpy.lib.stride_tricks import as_strided

data = np.array([[0, 0, 1, 1, 2, 2],
                 [1, 0, 0, 1, 2, 2],
                 [1, 0, 1, 0, 0, 0],
                 [1, 1, 0, 0, 2, 0]])

patch_shape = (2,2)
data_shape = np.array(data.shape)

# transform data to a 2x3 array of 2x2 patches/windows

# final shape of the computation on the windows can be calculated with:
# tuple(((data_shape-patch_shape) // patch_shape) + 1)
final_shape = (2,3)

# the shape of the windowed array can be calculated with:
# final_shape + patch_shape
newshape = (2, 3, 2, 2)

# the strides of the windowed array can be calculated with:
# tuple(np.array(data.strides) * patch_shape) + data.strides
newstrides = (48, 8, 24, 4)

# use as_strided to 'transform' the array
patch_array = as_strided(data, shape = newshape, strides = newstrides)

# flatten the windowed array for iteration - dim of 6x2x2
# the number of windows is the product of the 'first' dimensions of the array
# which can be calculated with:
# (np.product(newshape[:-len(patch_shape)])) + (newshape[-len(patch_array):])
dim = (6,2,2)

patch_array = patch_array.reshape(dim)

# perfom computations on the windows and reshape to final dimensions
result = [2 if np.any(patch == 2) else
          1 if np.any(patch == 1) else
          0 for patch in patch_array]
result = np.array(result).reshape(final_shape)

可以在Efficient rolling statistics with NumPy找到用于创建窗口数组的通用一维函数

一个广义的多维函数和一个很好的解释可以在Efficient Overlapping Windows with Numpy找到

【讨论】:

【解决方案2】:

使用来自 scikit-learn 的extract_patches,您可以编写如下(复制和粘贴代码):

import numpy as np
from sklearn.feature_extraction.image import extract_patches

data = np.array([[0, 0, 1, 1, 2, 2],
                 [1, 0, 0, 1, 2, 2],
                 [1, 0, 1, 0, 0, 0],
                 [1, 1, 0, 0, 2, 0]])

patches = extract_patches(data, patch_shape=(2, 2), extraction_step=(2, 2))
output = patches.max(axis=-1).max(axis=-1)

解释extract_patches 让您可以查看数组的补丁,大小为patch_shape,位于extraction_step 的网格上。结果是一个 4D 数组,其中前两个轴索引补丁,最后两个轴索引补丁内的像素。然后我们评估最后两个轴上的最大值以获得每个补丁的最大值。

EDIT这其实和this question有很大关系

【讨论】:

  • 酷 - extract_patches 使用 numpy.lib.stride_tricks.as_strided - github.com/scikit-learn/scikit-learn/blob/master/sklearn/…
  • 确实 - (感谢您的关注,我为当时的贡献感到自豪;))很酷的是,一旦您使用 .reshape((-1,) + patch_shape) .这样你就可以将补丁提取降低到 C 级别,而不是使用例如一个 for 循环。
【解决方案3】:

这是一个相当冗长的单行代码,仅依赖于重塑、转置和沿不同轴取最大值。它也相当快。

data.reshape((-1,2)).max(axis=1).reshape((data.shape[0],-1)).T.reshape((-1,2)).max(axis=1).reshape((data.shape[1]/2,data.shape[0]/2)).T

本质上,它的作用是重塑以在水平方向上以两对为一组取最大值,然后再次随机播放并在垂直方向以两对为一组取最大值,最终给出每个 4 块的最大值,与您想要的输出相匹配。

【讨论】:

  • 如果我没记错的话,这正是@Jaime 在他的评论中进一步提出的建议,并分别采取了步骤。
  • 是的。我错过了他的评论。他的版本要优雅得多。
【解决方案4】:

我不知道你来自哪里得到你的输出,或者你应该离开输出,但你能适应这一点。 P>

import numpy as np

data = np.array([[0, 0, 1, 1, 2, 2],
                 [1, 0, 0, 1, 2, 2],
                 [1, 0, 1, 0, 0, 0],
                 [1, 1, 0, 0, 2, 0]])

def patchValue(i,j):
    return max([data[i][j],
                data[i][j+1],
                data[i+1][j],
                data[i+1][j+1]])

result = np.array([[0, 0, 0],
                   [0, 0, 0]])

for (v,i) in enumerate(range(0,4,2)):
    for (w,j) in enumerate(range(0,6,2)):
        result[v][w] = patchValue(i,j)

print(result)

【讨论】:

  • 使用max是一个绝妙的主意,但您的实现不采取numpy的充分利用。这将做同样在一个更为numpythonic方式:rows, cols = data.shape; result = np.max(data.reshape(rows//2, 2, cols//2, 2), axis=(1, 3)) SPAN>
  • 你需要哪个版本numpy的,以能写axis=(1, 3)?它不会在我的1.6.1工作,但它是一个功能我一直需要的。这将使我的主张更加简洁了。 SPAN>
  • @海梅什么行之间差// 2和行/ 2 跨度>
  • @ eickenberg我认为这是一个1.7的事情。跨度>
  • @阿尔卑斯rows // 2总是返回整数除法,而rows / 2会做浮点除法在Python 3.x或如果您使用from __future__ import division在Python 2.x的。 SPAN>
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2013-01-27
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多