【问题标题】:Why CIFAR-10 images are not displayed properly using matplotlib?为什么使用 matplotlib 无法正确显示 CIFAR-10 图像?
【发布时间】:2016-06-29 23:46:43
【问题描述】:

我从训练集中拍摄了一张大小为 (3,32,32) 的图像 ('img')。 我用过 plt.imshow(img.T)。图像不清晰。现在我必须对 image('img') 进行更改以使其更清晰可见。 谢谢。

【问题讨论】:

    标签: python image matplotlib machine-learning computer-vision


    【解决方案1】:

    以下打印 5X5 网格的随机 Cifar10 图像。它并不模糊,尽管也不完美。欢迎任何建议。

    %matplotlib inline
    import numpy as np
    import matplotlib.pyplot as plt
    from six.moves import cPickle 
    
    f = open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb')
    datadict = cPickle.load(f,encoding='latin1')
    f.close()
    X = datadict["data"] 
    Y = datadict['labels']
    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("uint8")
    Y = np.array(Y)
    
    #Visualizing CIFAR 10
    fig, axes1 = plt.subplots(5,5,figsize=(3,3))
    for j in range(5):
        for k in range(5):
            i = np.random.choice(range(len(X)))
            axes1[j][k].set_axis_off()
            axes1[j][k].imshow(X[i:i+1][0])
    

    【讨论】:

    • 这是很好的信息,但 astype("float") 会将图像显示为负数,将其设置为 astype("uint8") 是正常的。
    • 最后一行的索引可能只是X[i, :] 对吧?
    【解决方案2】:

    确保在要显示图像时不对数据集进行规范化。

    示例:

    加载器...

    import torch
    from torchvision import datasets, transforms
    import matplotlib.pyplot as plt
    
    
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                            #  transforms.Normalize(
                            #      (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                         ])),
        batch_size=64, shuffle=True)
    

    显示图像的代码...

    img = next(iter(train_loader))[0][0]
    plt.imshow(transforms.ToPILImage()(img))
    

    标准化

    没有标准化

    【讨论】:

    • 这正是我的问题,我在标准化图像后查看图像。谢谢!
    【解决方案3】:

    此文件读取cifar10 dataset 并使用matplotlib 绘制单个图像。

    import _pickle as pickle
    import argparse
    import numpy as np
    import os
    import matplotlib.pyplot as plt
    
    cifar10 = "./cifar-10-batches-py/"
    
    parser = argparse.ArgumentParser("Plot training images in cifar10 dataset")
    parser.add_argument("-i", "--image", type=int, default=0, 
                        help="Index of the image in cifar10. In range [0, 49999]")
    args = parser.parse_args()
    
    
    def unpickle(file):
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding='bytes')
        return dict
    
    def cifar10_plot(data, meta, im_idx=0):
        im = data[b'data'][im_idx, :]
    
        im_r = im[0:1024].reshape(32, 32)
        im_g = im[1024:2048].reshape(32, 32)
        im_b = im[2048:].reshape(32, 32)
    
        img = np.dstack((im_r, im_g, im_b))
    
        print("shape: ", img.shape)
        print("label: ", data[b'labels'][im_idx])
        print("category:", meta[b'label_names'][data[b'labels'][im_idx]])         
    
        plt.imshow(img) 
        plt.show()
    
    
    def main():
        batch = (args.image // 10000) + 1
        idx = args.image - (batch-1)*10000
    
        data = unpickle(os.path.join(cifar10, "data_batch_" + str(batch)))
        meta = unpickle(os.path.join(cifar10, "batches.meta"))
    
        cifar10_plot(data, meta, im_idx=idx)
    
    
    if __name__ == "__main__":
        main()
    

    【讨论】:

    • 我喜欢这种方法 :)
    【解决方案4】:

    由于插值,图像模糊。为了防止在 matplotlib 中模糊,请使用关键字 interpolation='nearest' 调用 imshow

    plt.imshow(img.T, interpolation='nearest')
    

    此外,当您使用转置时,您的 x 轴和 y 轴似乎正在交换,因此您可能希望改为这样显示:

    plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest')
    

    【讨论】:

    • 它仍然无法正常工作。只是为了验证最后一行中的 img 是否具有形状 (3, 32, 32)。
    • 如果长度为 3 的维度是 RGB 轴,那么你想要它在最后。如果您的数组的形状为(3, R, C),那么np.transpose(img, (1, 2, 0)) 的形状为(R, C, 3)
    • 很好,还是看不到好图。有什么建议吗?
    • 如果imshow 图像仍然像链接的图像一样模糊,那么它没有使用“最近”插值。您应该能够使用imshow 返回的对象来验证它使用的是什么插值。根据您的 matplotlib 版本,您还可以尝试“无”作为插值。
    【解决方案5】:

    我使用以下代码将所有 CIFAR 数据显示为一张大图。代码显示图像,但如果你想保存它而不是模糊我建议使用plt.savefig(fname, format='png', dpi=1000)

    import numpy as np
    import matplotlib.pyplot as plt
    
    def reshape_and_print(self, cifar_data):
        # number of images in rows and columns
        rows = cols = np.sqrt(cifar_data.shape[0]).astype(np.int32)
        # Image hight and width. Divide by 3 because of 3 color channels
        imh = imw = np.sqrt(cifar_data.shape[1] // 3).astype(np.int32)
        # reshape to number of images X color channels X image size
        # transpose to color channels X number of images X image size
        timg = cifar_data.reshape(rows * cols, 3, imh * imh).transpose(1, 0, 2)
        # reshape to color channels X rows X cols X image hight X image with
        # swap axis to color channels X rows X image hight X cols X image with
        timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3)
        # reshape to color channels X combined image hight X combined image with
        # transpose to combined image hight X combined image with X color channels
        timg = timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0)
    
        plt.imshow(timg)
        plt.show()
    

    我做了一个快速的数据帮助类,用于一个小型测试项目,希望对你有用:

    import gzip
    import pickle
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    class DataSet(object):
    
        def __init__(self, seed=42, setsize=10000):
            self.seed = seed
            # set the seed for reproducability
            np.random.seed(seed)
            # load the data
            train_set, test_set = self.load_data()
            # self.split_data(train_set, valid_set, test_set)
            self.split_data(train_set, test_set, setsize)
    
        def split_data(self, data_set, test_set, split_size):
            permutation = np.random.permutation(data_set.shape[0])
            self.train = data_set[permutation[:split_size]]
            self.valid = data_set[permutation[split_size:split_size * 2]]
            self.test = test_set[:split_size]
    
        def reshape_for_print(self, data):
            raise NotImplemented
    
        def load_data(self):
            raise NotImplemented
    
        def show_all_imgs(self, data):
            raise NotImplemented
    
    
    class CIFAR(DataSet):
    
        def load_data(self):
            # try to load data
            with open('./data/cifar-100-python/train', 'rb') as f:
                data = pickle.load(f, encoding='latin1')
            train_set = data['data'].astype(np.float32) / 255.0
    
            with open('./data/cifar-100-python/test', 'rb') as f:
                data = pickle.load(f, encoding='latin1')
            test_set = data['data'].astype(np.float32) / 255.0
    
            return train_set, test_set
    
        def reshape_for_print(self, data):
            gh = gw = np.sqrt(data.shape[0]).astype(np.int32)
            imh = imw = np.sqrt(data.shape[1] // 3).astype(np.int32)
            timg = data.reshape(gh * gw, 3, imh * imh).transpose(1, 0, 2)
            timg = timg.reshape(3, gh, gw, imh, imw).swapaxes(2, 3)
            timg = timg.reshape(3, gh * imh, gw * imw).transpose(1, 2, 0)
            return timg
    
        def show_all_imgs(self, data):
            timg = self.reshape_for_print(data)
            plt.imshow(timg)
            plt.show()
    
    
    class MNIST(DataSet):
    
        def load_data(self):
            # try to load data
            with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
                train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
            return train_set[0], test_set[0]
    
        def reshape_for_print(self, data):
            gh = gw = np.sqrt(data.shape[0]).astype(np.int32)
            imh = imw = np.sqrt(data.shape[1]).astype(np.int32)
            timg = data.reshape(gh, gw, imh, imw).swapaxes(1, 2)
            timg = timg.reshape(gh * imh, gw * imw)
            return timg
    
        def show_all_imgs(self, data):
            timg = self.reshape_for_print(data)
            plt.imshow(timg, cmap=plt.cm.gray)
            plt.show()
    

    【讨论】:

      【解决方案6】:

      我创建了一个函数来绘制 CIFAR10 数据集中的一行中的 RGB 图像。由于图像的原始尺寸非常小(32px X 32px),因此图像最多会模糊。

      def unpickle(file):
          with open(file, 'rb') as fo:
              dict1 = pickle.load(fo, encoding='bytes')
          return dict1
      
      pd_tr = pd.DataFrame()
      tr_y = pd.DataFrame()
      
      for i in range(1,6):
          data = unpickle('data/data_batch_' + str(i))
          pd_tr = pd_tr.append(pd.DataFrame(data[b'data']))
          tr_y = tr_y.append(pd.DataFrame(data[b'labels']))
          pd_tr['labels'] = tr_y
      
      tr_x = np.asarray(pd_tr.iloc[:, :3072])
      tr_y = np.asarray(pd_tr['labels'])
      ts_x = np.asarray(unpickle('data/test_batch')[b'data'])
      ts_y = np.asarray(unpickle('data/test_batch')[b'labels'])    
      labels = unpickle('data/batches.meta')[b'label_names']
      
      def plot_CIFAR(ind):
          arr = tr_x[ind]
          sc_dpi = 157.35
          R = arr[0:1024].reshape(32,32)/255.0
          G = arr[1024:2048].reshape(32,32)/255.0
          B = arr[2048:].reshape(32,32)/255.0
      
          img = np.dstack((R,G,B))
          title = re.sub('[!@#$b]', '', str(labels[tr_y[ind]]))
          fig = plt.figure(figsize=(3,3))
          ax = fig.add_subplot(111)
          ax.imshow(img,interpolation='bicubic')
          ax.set_title('Category = '+ title,fontsize =15)
      
      plot_CIFAR(4)
      

      【讨论】:

        【解决方案7】:

        尝试使用

        import matplotlib.pyplot as plt
        from scipy.misc import toimage
        plt.imshow(toimage(img))
        

        我不是 100% 确定代码是如何工作的,但我认为因为图像存储在浮点 numpy 数组中,所以 imshow() 函数很难将它们映射到正确的颜色。通过使用 toimage() 将它们类型转换为图像,您可以将它们转换为 imshow() 期望的正确图像格式,即不是数组,而是编码为 .png 或 .jpg 的图像。

        每次我想在 python 中显示图像时,这段代码都适用。

        【讨论】:

        • 请详细说明这是做什么的。
        • 我加了一点解释,希望对你有帮助
        【解决方案8】:

        代码结果是:试试下面的代码。

        我发现了一个关于 mnist 和 cifar 图像可视化的非常有用的链接。您可以找到各种图像的代码: https://machinelearningmastery.com/how-to-load-and-visualize-standard-computer-vision-datasets-with-keras/ cifar10 图像代码如下: 它运作良好。图片如上。

        # example of loading the cifar10 dataset
        from matplotlib import pyplot
        from keras.datasets import cifar10
        # load dataset
        (trainX, trainy), (testX, testy) = cifar10.load_data()
        # summarize loaded dataset
        print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
        print('Test: X=%s, y=%s' % (testX.shape, testy.shape))
        # plot first few images
        for i in range(9):
            # define subplot
            pyplot.subplot(330 + 1 + i)
            # plot raw pixel data
            pyplot.imshow(trainX[i])
        # show the figure
        pyplot.show()
        

        【讨论】:

          【解决方案9】:

          加0.5:

          plt.imshow(np.transpose(img, (1, 2, 0)) + 0.5)
          

          【讨论】:

          • 请详细说明您的建议。
          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 2017-10-13
          • 2019-12-28
          • 2020-08-29
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          相关资源
          最近更新 更多