【问题标题】:Tensorflow iterator fails to iterateTensorflow 迭代器无法迭代
【发布时间】:2020-04-30 09:49:07
【问题描述】:

我正在做一个与实例分割相关的项目。我正在尝试使用我自己的图像数据集训练 SegNet,该数据集包含一组图像及其相应的掩码,并且我已成功使用 tf.Dataset 加载我的数据。但是每次我使用可馈送迭代器将数据集馈送到 SegNet 时,我的程序总是会在没有任何错误或警告的情况下终止。我的代码如下所示。

load_satellite_image() 用于读取图像的文件名,dataset() 用于使用 tf.Dataset 加载图像。迭代器似乎无法更新输入管道。

     train_path = "data_example/train.txt"
     val_path = "data_example/test.txt"

     config_file = 'config.json'
     with open(config_file) as f:
         config = json.load(f)

     train_img, train_mask = load_satellite_image(train_path)   
     val_img, val_mask = load_satellite_image(val_path)   

     train_dataset = dataset(train_img, train_mask, config, True, 0, 1)  
     val_dataset = dataset(val_img, val_mask, config, True, 0, 1)  

     train_iter = train_dataset.make_initializable_iterator()
     validation_iter = val_dataset.make_initializable_iterator()

     handle = tf.placeholder(tf.string, shape=[])    

     iterator =  tf.data.Iterator.from_string_handle(handle,
                 train_dataset.output_types,train_dataset.output_shapes)

     next_element = iterator.get_next()

     with tf.Session() as Sess:
         sess.run(train_iter.initializer)
         sess.run(validation_iter.initializer)
         train_iter_handle = sess.run(train_iter.string_handle())
         val_iter_handle = sess.run(validation_iter.string_handle())

         for i in range(2):
             print("1")
             try:
                 while True:

                     for i in range(5):
                         print(sess.run(next_element,feed_dict={handle:train_iter_handle}))
                         print('----------------------------','\n')

                     for i in range(2):
                         print(sess.run(next_element,feed_dict={handle:val_iter_handle}))
             except tf.errors.OutOfRangeError:
                     pass

运行上面的代码后,我得到了:

         In [2]: runfile('D:/python_code/tensorflow_study/SegNet/load_data.py', 
         wdir='D:/python_code/tensorflow_study/SegNet')

        (tf.float32, tf.int32)
        (TensorShape([Dimension(360), Dimension(480), Dimension(3)]), TensorShape([Dimension(360), 
        Dimension(480), Dimension(1)]))
        (tf.float32, tf.int32)
        (TensorShape([Dimension(360), Dimension(480), Dimension(3)]), TensorShape([Dimension(360), 
        Dimension(480), Dimension(1)]))
        WARNING:tensorflow:From D:\Anaconda\envs\tensorflow-gpu\lib\site- 
        packages\tensorflow\python\data\ops\dataset_ops.py:1419: colocate_with (from 
        tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
        Instructions for updating:
        Colocations handled automatically by placer.

        In [1]:

我很困惑我的代码被无故终止了。如您所见,我可以获得训练/验证图像和掩码的形状和数据类型,这意味着问题与我的数据集无关。但是,tf.Session() 中的for 循环没有执行,我无法得到print("1") 的结果。 sess.run() 也不执行迭代器。有人遇到过这个问题吗?

谢谢!!!

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    问题解决了。这是一个愚蠢的错误,浪费了我很多时间。 我的程序终止而没有错误消息的原因是我正在使用愚蠢的 Spyder 编写我的代码,我不知道为什么它没有显示错误消息。实际上,存在 TensorFlow 产生的错误消息。巧合的是,我通过 Anaconda 的命令窗口运行了我的代码,并收到了以下错误消息:

    2020-04-30 17:31:03.591207: W tensorflow/core/framework/op_kernel.cc:1401] OP_REQUIRES failed at whole_file_read_ops.cc:114 : Invalid argument: NewRandomAccessFile failed to Create/Open: D:\Study\PhD\python_code\tensorflow_study\SegNet\data_example\trainannot\ges_517405_679839_21.jpg
    

    迭代器不起作用,因为 TensorFlow 找不到掩码位置。图像和蒙版位置存储在一个文本文件中,如下所示:

    data_example\train\ges_517404_679750_21.jpg,data_example\trainannot\ges_517404_679750_21.jpg
    data_example\train\ges_517411_679762_21.jpg,data_example\trainannot\ges_517411_679762_21.jpg
    

    左侧是原始图像的位置,右侧是它们的掩码位置。一开始我用split(",")分别获取图片和蒙版的位置,但是好像蒙版的位置有问题。所以我检查了用于生成文本文件的代码:

    file.writelines([Train_path[i],',',TrainAnnot_path[i],'\n'])
    

    文本文件中的每一行都以\n 结尾,这就是Tensorflow 无法获取掩码位置的原因。所以我用file.writelines([Train_path[i],' ',TrainAnnot_path[i],'\n'])替换了file.writelines([Train_path[i],',',TrainAnnot_path[i],'\n']),并使用strip().split(" ")而不是split(" ")。这样就解决了问题。

    【讨论】:

    • 最好使用strip().split(" ")而不是split(),因为它无法处理\n
    猜你喜欢
    • 2016-08-23
    • 2022-10-15
    • 2020-07-31
    • 2020-01-03
    • 1970-01-01
    • 1970-01-01
    • 2020-10-23
    • 2021-08-24
    • 2016-10-22
    相关资源
    最近更新 更多