【发布时间】:2020-06-03 08:59:28
【问题描述】:
我正在试验来自https://github.com/astirn/IIC 的聚类模型 (已经尝试联系他)
它像大多数研究论文一样使用 Mnist 数据集。 在这里,他们首先将数据集名称定义为“mnist”,这足以让 tensorflow 从他们的标准在线数据集中导入 mnist。 然后他使用 tensorflow_dataset.load() 函数加载数据集
我已经为我的数据集创建了一个 tfrecord 文件,现在我只需要替换前面提到的脚本指向“mnist”的部分(下面代码中的第 1 行),而是指向我的本地数据集。
我只是用第一行的文件路径替换'mnist'吗???
来自实际训练模型文件的代码:
if __name__ == '__main__':
# pick a data set
DATA_SET = 'mnist'
# define splits
DS_CONFIG = {
# mnist data set parameters
'mnist': {
'batch_size': 700,
'num_repeats': 5,
'mdl_input_dims': [24, 24, 1]}
}
# load the data set
TRAIN_SET, TEST_SET, SET_INFO = load(data_set_name=DATA_SET, **DS_CONFIG[DATA_SET])
# configure the common model elements
MDL_CONFIG = {
# mist hyper-parameters
'mnist': {
'num_classes': SET_INFO.features['label'].num_classes,
'learning_rate': 1e-4,
'num_repeats': DS_CONFIG[DATA_SET]['num_repeats'],
'save_dir': None},
}
来自“数据准备文件”的代码,他将带有 tensorflor_dataset.load 的数据集称为 tfds.load:
def load(data_set_name, **kwargs):
"""
:param data_set_name: data set name--call tfds.list_builders() for options
:return:
train_ds: TensorFlow Dataset object for the training data
test_ds: TensorFlow Dataset object for the testing data
info: data set info object
"""
# get data and its info
ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True)
感谢帮助
【问题讨论】:
标签: python tensorflow dataset tensorflow-datasets mnist