【发布时间】:2018-01-01 16:09:59
【问题描述】:
我实际上是 TensorFlow 和 ML 的新手,我正在尝试从 pickle 文件加载数据集。我的数据集是 2 个列表的列表。第一个列表是 10 000 个图像,每个图像由一个 3072 字节的数组表示。每种颜色 1024 (rgb)。另一个列表是 10 000 个布尔值。我像这样加载我的数据集:
X, Y = pickle.load(open('training_dataset.pkl', 'rb'))
然后我使用以下代码创建我的网络:
network = input_data(shape=[None, 32, 32, 3])
并获得ValueError: Cannot feed value of shape (96, 3072) for Tensor 'InputData/X:0', which has shape '(?, 32, 32, 3)'
如何将我的数据集重塑为 [?, 32, 32, 3]? 我的泡菜文件格式不正确吗?
这是用于创建 pickle 文件的代码:
def unpickle(file_name):
with open(file_name, 'rb') as opened_file:
data = pickle.load(opened_file, encoding='bytes')
return data
def create_training_pkl_file():
img_arrays_list = []
is_bird_boolean_list = []
training_dataset = []
for i in range(1,6):
batch = unpickle('./cifar-10-batches-py/data_batch_' + str(i))
for img in batch[b'data']:
img_arrays_list.append(img)
for label in batch[b'labels']:
is_bird_boolean_list.append(label==2)
training_dataset.append(img_arrays_list)
training_dataset.append(is_bird_boolean_list)
save_pickle(training_dataset, './training_dataset.pkl')
我正在使用CIFAR-10 dataset
【问题讨论】:
-
当你从腌制文件中读取数据时,它的形状为
(96,3072)。您腌制的数据格式错误。你需要reshape它来匹配输入张量 -
当我列出 X 的内容时,它会给我一个包含 10 000 个 3072 字节数组的列表,其他 96 个值是多少?
-
32x3=96,所以如果我是对的,那么您在代码中的某个地方搞砸了。另外,你能分享你制作这个泡菜文件的代码吗? -
@Nain 是对的!这是您腌制数据集中的问题。我必须查看您的腌制/未腌制数据集。它是公开的吗?如果是,我可以帮你解决这个问题
-
@Saranns 我已经编辑了帖子并添加了创建我的泡菜文件的代码。数据集上的所有信息都应该在链接中。
标签: python numpy tensorflow