【问题标题】:Efficient way of reshaping a ndarray using sliding window without to use too much memory使用滑动窗口重塑 ndarray 而不使用太多内存的有效方法
【发布时间】:2021-03-25 23:27:03
【问题描述】:

我必须通过对其应用两个滑动窗口来将 [17205, 21] 的 ndarray 重塑为 [17011, 96, 100, 21]。

In: arr
Out: [[ 8.  0.  0. -0.  0.  0.  8.  8.  0.  0.  0.  0.  8.  7.  6.  9.  9.  1.
   1.  1.  2.]
 [ 8.  0.  0. -0.  0.  0.  8.  8.  0.  0.  0.  0.  8.  7.  5.  9.  8.  2.
   1.  1.  2.]
.
.
.
 [ 8.  0.  0. -0.  0.  0.  8.  8.  0.  0.  0.  0.  8.  7.  5.  9.  8.  3.
   1.  1.  2.]]

我的解决方案是对其应用两次滑动窗口。 然后我两次应用以下方法:

def separate_multi(sequences, n_steps):
    X = list()
    for i in range(len(sequences)):
       # find the end of this pattern
       end_ix = i + n_steps
       # check if we are beyond the dataset
       if end_ix > len(sequences):
           break
            # gather input and output parts of the pattern
       seq_x = sequences[i:end_ix, :]           
       X.append(seq_x)
       return np.array(X)

给出[17106, 100, 21] 的形状,然后再次使用n_step=96,给出[17011, 96, 100, 21] 的形状。

DRAWBACK:它将整个数据存储在内存中,这会出错:

MemoryError: Unable to allocate 24.3 GiB for an array with shape (17011, 96, 100, 20) and data type float64 

一个可能的解决方案:

import tensorflow as tf
df = tf.data.Dataset.from_tensor_slices(df)
df = df.window(100, shift=1, stride=1, drop_remainder=True)
df = df.window(96, shift=1, stride=1, drop_remainder=True)

但是,它并没有给我想要的输出,因为“它会生成嵌套窗口的数据集”,正如 here 所说的那样。

有什么想法吗?谢谢

【问题讨论】:

    标签: multidimensional-array memory-leaks reshape sliding-window tf.data.dataset


    【解决方案1】:

    我找到了我的问题的解决方案:

    主要问题不是分两步重塑数据,而是我通过重塑数据形成的对象的大小。因此,解决方案是将输入数组分解为多个部分。为此,我设计了以下功能:

    def split_chunks(sequence, chunk=3000):
        list_seq = []
        for i in range(len(sequence)):
            if (i+1)*chunk > len(sequence):
                seq = sequence[i*chunk:-1, :]
                list_seq.append(seq)
                break
            else:
                seq = sequence[i*chunk:(i+1)*chunk, :]
                list_seq.append(seq)
        return list_seq
    

    然后重塑list_seq 内的每个数组。另一种选择是 NumPy 方法np.split(),但是,我的函数比这个快 9 倍。

    【讨论】:

      猜你喜欢
      • 2017-09-04
      • 2021-10-18
      • 2021-10-29
      • 1970-01-01
      • 2014-09-09
      • 2021-06-08
      • 2021-06-20
      • 2016-10-07
      相关资源
      最近更新 更多