【问题标题】:Tensorflow : can not create dataset with corresponding label using Dataset APITensorflow:无法使用数据集 API 创建具有相应标签的数据集
【发布时间】:2019-01-20 07:46:13
【问题描述】:

我的数据集包含featureslabels,例如。

features, labels = (np.random.sample((5,2)), np.random.sample((5,1))) 

这意味着这个数据集中有 5 个数据元素(有 5 行,每行是一个 2-dim 特征和 1-dim 标签)。

我使用tf.data.Dataset 使用此代码创建数据集:

import tensorflow as tf
import numpy as np
features, labels = (np.random.sample((5,2)), np.random.sample((5,1))) 
print("feature : \n", features)
print("labels : \n", labels)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
iter = dataset.make_one_shot_iterator()            
x, y = iter.get_next()                                       
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   
    print("element:\n", sess.run(x), sess.run(y))

我使用的是 TF1.5 ,Windows 10。然后我得到结果:

feature :
 [[0.10261779 0.28041519]  # feature0
 [0.91091857 0.95644642]   # feature1
 [0.77542043 0.49631646]   # ...
 [0.33241678 0.28630983]
 [0.39095336 0.76686785]]
labels :
 [[0.54097027]             # label0
 [0.99022349]              # label1
 [0.87510303]              # ...
 [0.07331254]
 [0.10868335]]
element:
 [0.10261779 0.28041519] [0.99022349]

当我创建数据集时,我希望 feature0 [0.10261779 0.28041519] 与 label0 [0.54097027] 对应。但是使用代码,feature0 [0.10261779 0.28041519] 对应于 label1 [0.99022349]。顺序是错误的。我不知道get_next 实际是如何工作的。

不知道有没有方法可以使用tensorflow Dataset API按顺序输出特征和标签。

谢谢

【问题讨论】:

    标签: python tensorflow dataset


    【解决方案1】:

    问题在于,通过分别运行 xalso 运行 y,您将迭代器推进两次。即:当调用sess.run(x) 时,返回features 的第一个元素并且迭代器前进。然后调用sess.run(y) 将返回labelssecond 元素,因为xy 都基于相同的迭代器。如果您再次调用sess.run(x),它应该返回features第三个​​ 元素,依此类推。

    我建议你像这样重写你的代码,例如:

    ...
    next_batch_op = iter.get_next()
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feature_batch, label_batch = sess.run(next_batch_op)
        print("element:\n", feature_batch, label_batch)
    

    这只会运行迭代器一次,并让您访问相应的功能/标签。

    作为替代方案,我只是尝试了以下方法,它似乎有效:

    ...
    x, y = iter.get_next()
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("element:\n", sess.run([x, y]))
    

    与您的代码不同的是,我们在一个 run 调用中同时运行 xy。但是我发现第一个解决方案更清楚。

    【讨论】:

      猜你喜欢
      • 2021-12-11
      • 2021-06-08
      • 1970-01-01
      • 2019-02-08
      • 1970-01-01
      • 1970-01-01
      • 2023-03-09
      • 2019-08-22
      • 2018-08-22
      相关资源
      最近更新 更多