【发布时间】:2021-10-30 22:11:16
【问题描述】:
我在 Tensorflow 中洗牌 numpy 数组时遇到了一个奇怪的行为(使用 Google Colab):
from matplotlib import pyplot as plt
import tensorflow as tf
import numpy as np
seed = int(np.random.randint(0, 2 ** 16))
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.cifar10.load_data()
train_x = train_x / 255.0 # this line
train_x = tf.random.shuffle(train_x, seed=seed)
train_y = tf.random.shuffle(train_y, seed=seed)
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
for i in train_dataset.take(10):
print(f"Label: {i[1].numpy()[0]}", end=', ')
plt.figure()
plt.imshow(i[0])
以这种方式对 train_x 和 train_y(都是 numpy 数组)进行混洗后,我在视觉上确认索引之间的关系得到维护,即似乎每次调用 shuffle 都会重置 rng 并且两次都得到相同的排列。但是,当我注释掉规范化步骤(标记为“这条线”)时,改组会破坏索引之间的关系。
我不理解这种行为,并想了解为什么会发生这种情况。任何帮助表示赞赏。
【问题讨论】:
标签: numpy tensorflow