【发布时间】:2019-11-24 02:20:48
【问题描述】:
我有一个形状为[batch_size, seq_len, num_features] 的ndarray。但是,顺序维度末尾的某些元素不是必需的,因此我想删除它们并将顺序维度合并到批处理维度中。比如我要操作的ndarraya就是
batch_size = 2
seq_len = 3
num_features = 1
a = np.random.randn(batch_size, seq_len, num_features)
mask = np.ones((batch_size, seq_len), dtype=np.bool)
mask[0][1:] = 0
mask[1][2:] = 0
"""
>>> a = [[[-0.3908401 ]
[ 0.89686512]
[ 0.07594243]]
[[-0.12256737]
[-1.00838131]
[ 0.56543754]]]
mask=[[ True False False]
[ True True False]]
"""
其中mask用于表示a中的元素是否有用。我可以使用以下代码得到我想要的东西
res = []
for seq, m in zip(a, mask):
res.append(seq[:sum(m)])
np.concatenate(res, axis=0)
"""
>>>array([[0.08676509],
[0.47162315],
[0.98070665]])
"""
我想知道在 numpy 中是否有更优雅的方法来做到这一点?
【问题讨论】: