【问题标题】:Cannot get predictions of tensorflow DNNClassifier无法获得 tensorflow DNNClassifier 的预测
【发布时间】:2017-04-03 23:52:09
【问题描述】:

我正在使用 MNIST 教程中的代码:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=2,
                                            model_dir="/tmp/iris_model")

classifier.fit(x=np.array(train, dtype = 'float32'),
               y=np.array(y_tr, dtype = 'int64'),
               steps=2000)

accuracy_score = classifier.evaluate(x=np.array(test, dtype = 'float32'),
                                     y=y_test)["auc"]
print('AUC: {0:f}'.format(accuracy_score))

from tensorflow.contrib.learn import SKCompat
ds_test_ar = np.array(ds_test, dtype = 'float32')

ds_predict_tf = classifier.predict(input_fn = _my_predict_data)
print('Predictions: {}'.format(str(ds_predict_tf)))

但最后我得到了以下结果而不是预测:

Predictions: <generator object DNNClassifier.predict.<locals>.<genexpr> at 0x000002CE41101CA8>

我做错了什么?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    您收到并保存到ds_predict_tf 的是一个生成器表达式。 要打印它,你可以这样做:

    for i in ds_predict_tf:
        print i
    

    print(list(ds_predict_tf))
    

    您可以阅读有关genexpr here的更多信息。

    【讨论】:

    • 谢谢!奇怪直到现在才得到你的回复,发现自己,很抱歉这么简单的问题!
    【解决方案2】:

    DNNClassifier 预测函数默认具有 as_iterable=True。因此,它返回一个生成器。要获取预测值而不是生成器,请在 classifier.predict 方法中传递 as_iterable=False

    例如,

    classifier.predict(input_fn = _my_predict_data,as_iterable=False)



    用于了解有关分类器方法和参数的更多信息。这是 predict 方法的部分文档。

    来自DNNClassifier 文档:

    预测

    参数:

    • x:特征。
    • input_fn:输入函数。如果设置,x 必须为 None。
    • batch_size:覆盖默认批量大小。
    • 输出:str 列表,要预测的输出名称。如果没有,则返回类。
    • as_iterable:如果为 True,则返回一个迭代器,该迭代器不断为每个示例产生预测,直到输入用尽。注意:如果您希望迭代终止,输入必须终止(例如,如果您使用类似 read_batch_features 的东西,请确保传递 num_epochs=1)。

    返回:

    • 形状为 [batch_size] 的预测类的 Numpy 数组(如果 as_iterable 为 True,则为预测类的迭代)。每个预测的类都由其类索引(即从 0 到 n_classes-1 的整数)表示。如果设置了输出,则返回一个预测字典。

    【讨论】:

    • as_iterable 现已弃用。
    【解决方案3】:

    解决方案:-

    pred = classifier.fit(x=training_set.data, y=training_set.target, steps=2000).predict(test_set.data)
    
    print ("Predictions:")
    
    print(list(pred))
    

    就是这样……

    【讨论】:

      【解决方案4】:

      为了尽可能接近教程使用:

      print('Predictions: {}' .format(list(ds_predict_tf)))
      

      【讨论】:

        【解决方案5】:

        对不起,答案很简单,你需要使用predictor作为generator对象:

        g1 = ds_predict_tf
        
        [g1.__next__() for i in range(100)]
        

        【讨论】:

          猜你喜欢
          • 2017-04-21
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2017-08-07
          • 2019-07-26
          • 2018-01-23
          相关资源
          最近更新 更多