【问题标题】:How can i convert mnist data to RGB format?如何将 mnist 数据转换为 RGB 格式?
【发布时间】:2020-03-05 19:30:30
【问题描述】:

我正在尝试将 MNIST 数据集转换为 RGB 格式,每个图像的实际形状是 (28, 28),但我需要 (28, 28, 3)。

import numpy as np
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()

X = np.concatenate([x_train, x_test])
X = X / 127.5 - 1

X.reshape((70000, 28, 28, 1))

tf.image.grayscale_to_rgb(
    X,
    name=None
)

但我收到以下错误:

ValueError: Dimension 1 in both shapes must be equal, but are 84 and 3. Shapes are [28,84] and [28,3].

【问题讨论】:

  • 总是将完整的错误消息(从单词“Traceback”开始)作为文本(不是屏幕截图)放在有问题的(不是评论)中。还有其他有用的信息。

标签: python numpy tensorflow keras tensorflow-datasets


【解决方案1】:

除了@DMolony 和@Aqwis01 答案之外,另一个简单的解决方案是使用numpy.repeat 方法多次复制张量的最后一个维度:

X = X.reshape((70000, 28, 28, 1))
X = X.repeat(3, -1)  # repeat the last (-1) dimension three times
X_t = tf.convert_to_tensor(X)
assert X_t.shape == (70000, 28, 28, 3)

【讨论】:

    【解决方案2】:

    您应该将重塑后的 3D [28x28x1] 图像存储在一个数组中:

    X = X.reshape((70000, 28, 28, 1))
    

    转换时,设置另一个数组为tf.image.grayscale_to_rgb()函数的返回值:

    X3 = tf.image.grayscale_to_rgb(
    X,
    name=None
    )
    

    最后,用matplotlibtf.session() 从生成的张量图像中绘制一个示例:

    import matplotlib.pyplot as plt
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
    
        image_to_plot = sess.run(image)
        plt.figure()
        plt.imshow(image_to_plot)
        plt.grid(False)
    

    完整代码:

    
    import numpy as np
    import tensorflow as tf
    
    mnist = tf.keras.datasets.mnist
    (x_train, _), (x_test, _) = mnist.load_data()
    
    X = np.concatenate([x_train, x_test])
    X = X / 127.5 - 1
    
    # Set reshaped array to X 
    X = X.reshape((70000, 28, 28, 1))
    
    # Convert images and store them in X3
    X3 = tf.image.grayscale_to_rgb(
        X,
        name=None
    )
    
    # Get one image from the 3D image array to var. image
    image = X3[0,:,:,:]
    
    # Plot it out with matplotlib.pyplot
    import matplotlib.pyplot as plt
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
    
        image_to_plot = sess.run(image)
        plt.figure()
        plt.imshow(image_to_plot)
        plt.grid(False)
    

    【讨论】:

      【解决方案3】:

      如果您在 tf.image.grayscale_to_rgb 之前打印 X 的形状,您将看到输出尺寸为 (70000, 28, 28)。 tf.image.grayscale 的输入大小必须为 1,因为它是最终维度。

      扩展X的最终维度,使其与函数兼容

      tf.image.grayscale_to_rgb(tf.expand_dims(X, axis=3))
      

      【讨论】:

      • 这并不能解决问题,因为您仍然没有返回转换函数的输出。
      • 当然,我应该返回函数输出,但是说它不能修复错误是不正确的。
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-06-17
      • 2019-12-09
      • 2020-07-24
      • 1970-01-01
      • 2019-09-29
      • 1970-01-01
      • 2011-11-19
      相关资源
      最近更新 更多