【问题标题】:any easy way to get imagenet dataset for training custom model in tensorflow?任何简单的方法来获取用于在 tensorflow 中训练自定义模型的 imagenet 数据集?
【发布时间】:2021-05-23 01:33:26
【问题描述】:

在我的实验中,我想在 imagenet 数据集上训练我的自定义模型。为简单起见,我对 10/100 类分类任务感兴趣。但是,直接从tfds 下载imagenet 数据集需要大量硬盘空间。是否有任何解决方法我们可以子集imagenet 数据集以便子集imagenet 数据集适合10/100 类分类任务?有谁知道使这种情况发生的任何方法?有什么想法吗?

一般来说,cifar10cifar100 与 TensorFlow 的函数式 api 配合使用非常方便。但是,在我的实验中,我想在imagenet 上训练我自己的模型。我想避免直接下载 imagenet 数据集,相反,我想要一些计算量更少的方法,这样我就可以在子集 imagenet(10 或 100 类分类)上训练我的自定义模型。有什么办法可以做到这一点吗?有什么想法吗?

我尝试下载imagenet

这是我尝试在本地下载 imagenet 数据集,然后在 imagenet 数据集上训练我的自定义模型。但是下载和加载训练数据非常耗时。但这就是我所做的:

import keras
import tensorflow as tf
import tensorflow_datasets as tfds

## fetch imagenet dataset directly
imagenet = tfds.image.Imagenet2012()

## describe the dataset with DatasetInfo
C = imagenet.info.features['label'].num_classes
n_train = imagenet.info.splits['train'].num_examples
n_validation = imagenet.info.splits['validation'].num_examples

assert C == 1000
assert n_train == 1281167
assert n_validation == 50000

imagenet.download_and_prepare()   ## need more space in harddrive

# load imagenet data from disk as tf.data.Datasets
datasets = imagenet.as_dataset()
train_data, validation_data= datasets['train'], datasets['validation']
assert isinstance(train_data, tf.data.Dataset)
assert isinstance(validation_data, tf.data.Dataset)

如果我这样做,下载会很耗时,并且需要更多硬盘空间。有没有更简单的方法来对 imagenet 数据集进行子集化并从 TensorFlow 中获取?有谁知道为 10/100 分类任务获取更小的 imagenet 数据集的更简单方法?有什么想法吗?

期望的输出

通常我们可以从tf.keras.datasets 得到cifar10, cifar100。我们可以将 imagenet 数据集子集到(200k ~ 500K)范围内吗?有没有更痛苦的方法来获取 imagenet 数据集以在 imagenet 数据上训练自定义模型?有什么想法吗?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    我自己想出来的。我需要使用tiny_imagenet_200:

    import os, sys, wget
    from zipfile import ZipFile
    
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    tiny_imgdataset = wget.download('http://cs231n.stanford.edu/tiny-imagenet-200.zip', out = os.getcwd())
    for file in os.listdir(os.getcwd()):
        if file.endswith(".zip"):
            zip = ZipFile(file)
            zip.extractall()
        else:
            print("not found")
    

    【讨论】:

      猜你喜欢
      • 2020-09-26
      • 2017-12-11
      • 2018-08-27
      • 1970-01-01
      • 2020-12-29
      • 2021-09-19
      • 2021-06-18
      • 2023-01-31
      • 2016-09-05
      相关资源
      最近更新 更多