【问题标题】:How to split numpy array in batches?如何分批拆分numpy数组?
【发布时间】:2015-04-14 22:26:13
【问题描述】:

听起来很简单,我不知道该怎么做。

我有 numpy 的二维数组

X = (1783,30)

我想将它们分成 64 个批次。我这样编写代码。

batches = abs(len(X) / BATCH_SIZE ) + 1  // It gives 28

我正在尝试批量预测结果。所以我用零填充批次,然后用预测结果覆盖它们。

predicted = []

for b in xrange(batches): 

 data4D = np.zeros([BATCH_SIZE,1,96,96]) #create 4D array, first value is batch_size, last number of inputs
 data4DL = np.zeros([BATCH_SIZE,1,1,1]) # need to create 4D array as output, first value is  batch_size, last number of outputs
 data4D[0:BATCH_SIZE,:] = X[b*BATCH_SIZE:b*BATCH_SIZE+BATCH_SIZE,:] # fill value of input xtrain

 #predict
 #print [(k, v[0].data.shape) for k, v in net.params.items()]
 net.set_input_arrays(data4D.astype(np.float32),data4DL.astype(np.float32))
 pred = net.forward()
 print 'batch ', b
 predicted.append(pred['ip1'])

print 'Total in Batches ', data4D.shape, batches
print 'Final Output: ', predicted

但是在最后批号 28 中,只有 55 个元素而不是 64 个(总元素 1783 个),它给出了

ValueError: could not broadcast input array from shape (55,1,96,96) into shape (64,1,96,96)

有什么办法解决这个问题?

PS:网络预测需要精确的批大小为 64 才能预测。

【问题讨论】:

  • 你的问题我不清楚(考虑到浏览量和没有答案,我不是唯一一个)。 1)。 net 来自哪个模块? 2) 你有一个二维数组 X。你想处理行 0:64,然后是 64:2*64,然后是 2*64:3*64,依此类推。你知道 1783 不是 64 的倍数吗?无论如何,这就是错误的来源。尝试更明确地说明您想要什么,可能会将自己简化为一个更简单的示例,例如 5x4。

标签: python numpy


【解决方案1】:

我也不太明白你的问题,尤其是 X 的样子。 如果您想创建数组大小相等的子组,请尝试以下操作:

def group_list(l, group_size):
    """
    :param l:           list
    :param group_size:  size of each group
    :return:            Yields successive group-sized lists from l.
    """
    for i in xrange(0, len(l), group_size):
        yield l[i:i+group_size]

【讨论】:

  • 网络只能对64个批次的数据进行预测。因此,只要批量大小为 64,它就需要虚拟数据
【解决方案2】:

我找到了一种解决批次问题的简单方法,方法是生成虚拟对象,然后填充必要的数据。

data = np.zeros(batches*BATCH_SIZE,1,96,96)
// gives dummy  28*64,1,96,96

此代码将完全加载 64 批大小的数据。最后一批最后会有虚拟零,但没关系:)

pred = []
for b in batches:
 data4D[0:BATCH_SIZE,:] = data[b*BATCH_SIZE:b*BATCH_SIZE+BATCH_SIZE,:]
 pred = net.predict(data4D)
 pred.append(pred)

output =  pred[:1783] // first 1783 slice

最后,我从总共 28*64 中切出 1783 个元素。这对我有用,但我相信有很多方法。

【讨论】:

    【解决方案3】:

    这可以使用 numpy 的as_strided 来实现。

    from numpy.lib.stride_tricks import as_strided
    def batch_data(test, batch_size):
        m,n = test.shape
        S = test.itemsize
        if not batch_size:
            batch_size = m
        count_batches = m//batch_size
        # Batches which can be covered fully
        test_batches = as_strided(test, shape=(count_batches, batch_size, n), strides=(batch_size*n*S,n*S,S)).copy()
        covered = count_batches*batch_size
        if covered < m:
            rest = test[covered:,:]
            rm, rn = rest.shape
            mismatch = batch_size - rm
            last_batch = np.vstack((rest,np.zeros((mismatch,rn)))).reshape(1,-1,n)
            return np.vstack((test_batches,last_batch))
        return test_batches
    

    【讨论】:

      【解决方案4】:

      data4D[0:BATCH_SIZE,:] 应该是 data4D[b*BATCH_SIZE:b*BATCH_SIZE+BATCH_SIZE, :]

      【讨论】:

      • 你能解释一下你的答案吗?
      • 那不行,网络会取 4d 数据,批大小正好是 64。如果输入数组不相等,网络模型会抛出错误。
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2014-03-02
      • 1970-01-01
      • 2017-06-14
      • 2015-11-19
      • 1970-01-01
      • 1970-01-01
      • 2013-06-27
      相关资源
      最近更新 更多