【问题标题】:why my dataset doesn't stop even if I set dataset.repeat(1)为什么即使我设置 dataset.repeat(1) 我的数据集也不会停止
【发布时间】:2019-01-24 05:45:48
【问题描述】:

我有一个训练数据集和一个测试数据集,

#training dataset
dataset_train = tf.data.TFRecordDataset(files_train)
dataset_train = dataset_train.map(...)
dataset_train = dataset_train.shuffle(...)
dataset_train = dataset_train.batch(...)
dataset_train = dataset_train.repeat(1)
iterator_train = dataset_train.make_initializable_iterator()

#test dataset
dataset_test = tf.data.TFRecordDataset(files_test)
dataset_test = dataset_test.map(...)
dataset_test = dataset_test.shuffle(...)
dataset_test = dataset_test.batch(...)
dataset_test = dataset_test.repeat(...)
iterator_test = dataset_test.make_initializable_iterator()

#for switch between two datasets.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes)
image_batch, label_batch = iterator.get_next()

在会话中,我有:

# in tf.Session()
train_iterator_handle = sess.run(train_iterator.string_handle())
val_iterator_handle = sess.run(test_iterator.string_handle())
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

#start training, switch to training dataset
sess.run(iterator_train.initializer) 
while True:
    try:
        sess.run([train_step, ...])

        if global_step % N == 0: # test
            #start test, switch to test dataset
            sess.run(iterator_test.initializer)
            while True:
                try:
                    sess.run([acc_update, ...])
                except tf.errors.OutOfRangeError:
                    print("test finished")
                    break
            #test finished, switch back to training dataset
            sess.run(iterator_train.initializer) 
    except tf.errors.OutOfRangeError:
        print("training finished")
        break

我从 TF 的 API 中读到 训练数据集迭代器可以从上次离开的地方继续,并且我认为训练数据集在迭代所有数据时应该停止,因为我使用:

dataset_train = dataset_train.repeat(1)

但实际上,我的程序会运行并且不会停止。 所以我想我一定在某个地方犯了一个严重的错误。有人能帮我吗?

【问题讨论】:

    标签: python tensorflow iterator dataset


    【解决方案1】:

    验证后的这一行 sess.run(iterator_train.initializer) 将重置您的火车生成器状态,因此它将继续从头开始读取。我想,N 比火车迭代器中的步数少,所以它不会停止

    如果您只想在验证后继续训练,请不要再次调用训练迭代器初始化器

    【讨论】:

      猜你喜欢
      • 2015-12-13
      • 2022-01-11
      • 2020-10-03
      • 1970-01-01
      • 1970-01-01
      • 2022-08-11
      • 2023-01-07
      • 2017-11-01
      • 2019-04-14
      相关资源
      最近更新 更多