【问题标题】:TensorFlow wrong array's shapeTensorFlow错误数组的形状
【发布时间】:2018-01-01 16:09:59
【问题描述】:

我实际上是 TensorFlow 和 ML 的新手,我正在尝试从 pickle 文件加载数据集。我的数据集是 2 个列表的列表。第一个列表是 10 000 个图像,每个图像由一个 3072 字节的数组表示。每种颜色 1024 (rgb)。另一个列表是 10 000 个布尔值。我像这样加载我的数据集:

X, Y = pickle.load(open('training_dataset.pkl', 'rb'))

然后我使用以下代码创建我的网络:

network = input_data(shape=[None, 32, 32, 3])

并获得ValueError: Cannot feed value of shape (96, 3072) for Tensor 'InputData/X:0', which has shape '(?, 32, 32, 3)'

如何将我的数据集重塑为 [?, 32, 32, 3]? 我的泡菜文件格式不正确吗?

这是用于创建 pickle 文件的代码:

def unpickle(file_name):
    with open(file_name, 'rb') as opened_file:
        data = pickle.load(opened_file, encoding='bytes')
    return data


def create_training_pkl_file():
    img_arrays_list = []
    is_bird_boolean_list = []
    training_dataset = []

    for i in range(1,6):
        batch = unpickle('./cifar-10-batches-py/data_batch_' + str(i))
        for img in batch[b'data']:
            img_arrays_list.append(img)

        for label in batch[b'labels']:
            is_bird_boolean_list.append(label==2)

    training_dataset.append(img_arrays_list)
    training_dataset.append(is_bird_boolean_list)

    save_pickle(training_dataset, './training_dataset.pkl')

我正在使用CIFAR-10 dataset

【问题讨论】:

  • 当你从腌制文件中读取数据时,它的形状为(96,3072)。您腌制的数据格式错误。你需要reshape它来匹配输入张量
  • 当我列出 X 的内容时,它会给我一个包含 10 000 个 3072 字节数组的列表,其他 96 个值是多少?
  • 32x3=96,所以如果我是对的,那么您在代码中的某个地方搞砸了。另外,你能分享你制作这个泡菜文件的代码吗?
  • @Nain 是对的!这是您腌制数据集中的问题。我必须查看您的腌制/未腌制数据集。它是公开的吗?如果是,我可以帮你解决这个问题
  • @Saranns 我已经编辑了帖子并添加了创建我的泡菜文件的代码。数据集上的所有信息都应该在链接中。

标签: python numpy tensorflow


【解决方案1】:

这是一个简单的类,可以最好地解决您的问题。可能看起来很长,但在执行数据流图时很容易调用它们。

cwd = os.getcwd() # Should be same as the directory where you extracted the CIFAR-10 dataset

class DATA(cwd):
    def __init__(self, directory = "./"):
        self._directory = directory

        self._training_data = []
        self._training_labels = []       
        self._load_training_data()

        np.random.seed(0)
        samples_n = self._training_labels.shape[0]
        random_indices = np.random.choice(samples_n, samples_n // 10, 
                                          replace = False)
        np.random.seed()

        self._training_data = np.delete(self._training_data, random_indices, 
                                        axis = 0)
        self._training_labels = np.delete(self._training_labels, 
                                          random_indices)


    def _load_training_data(self):
        for i in range(1, 6):
            path = os.path.join(self._directory, "data_batch_" + str(i))
            with open(path, 'rb') as fd:
                cifar_data = pickle.load(fd, encoding = "bytes")
                imgs = cifar_data[b"data"].reshape([-1, 3, 32, 32]) #FLATTEN THE IMAGE
                # imgs are not 3d tensors anymore.
                imgs = imgs.transpose([0, 2, 3, 1]) # img tensors as row vectors # Resulting img.size() should equals number of neurons in the input layer.
                if i == 1:
                    self._training_data = imgs
                    self._training_labels = cifar_data[b"labels"]
                else:
                    self._training_data =np.concatenate([self._training_data, imgs], axis = 0)
                    self._training_labels = np.concatenate([self._training_labels, cifar_data[b"labels"]])

    def get_training_batch(self, batch_size):
        return self._get_batch(self._training_data, self._training_labels, batch_size)

    def _get_batch(self, data, labels, batch_size):
        samples_n = labels.shape[0]
        if batch_size <= 0:
            batch_size = samples_n

        random_indices = np.random.choice(samples_n, samples_n, replace = False)
        data = data[random_indices]
        labels = labels[random_indices]
        for i in range(samples_n // batch_size):
            on = i * batch_size
            off = on + batch_size
            yield data[on:off], labels[on:off]

创建 DATA 类的实例

dataset = DATA()

获取一个批次的训练数据及其对应的标签

training_data,training_labels = next(dataset.get_training_batch(batch_size))

我也和你一样处于学习曲线中,所以如果你需要更多关于代码的细节,你可以参考here

希望有帮助!

【讨论】:

    猜你喜欢
    • 2020-12-14
    • 1970-01-01
    • 2016-10-06
    • 1970-01-01
    • 1970-01-01
    • 2017-10-25
    • 1970-01-01
    • 1970-01-01
    • 2017-05-02
    相关资源
    最近更新 更多