【问题标题】:Feeding numpy uint8 into tensorflow float32 placeholder将 numpy uint8 输入 tensorflow float32 占位符
【发布时间】:2018-08-24 20:49:31
【问题描述】:

我最近发现了一个概念验证实现,它使用 numpy.zeros 以单热编码方式准备功能:

data = np.zeros((len(raw_data), n_input, vocab_size),dtype=np.uint8)

如上所示,单个输入为np.uint8。 检查模型后发现,tensorflow模型的输入占位符定义为tf.float32

x = tf.placeholder(tf.float32, [None, n_input, vocab_size], name="onehotin")

我的特殊问题: tensorflow 如何处理这种输入类型的“不匹配”。这些值(0/1) 是否被张量流正确解释或强制转换。如果是这样,这是文档中提到的某个地方吗?谷歌搜索后,我找不到答案。应该提到的是,模型运行和值似乎是合理的。但是,将输入的 numpy 特征键入为 np.float32 会导致需要大量内存。

相关性: 一个正在运行但经过错误训练的模型在采用输入管道/将模型部署到生产环境后表现会有所不同。

【问题讨论】:

    标签: python numpy tensorflow type-conversion one-hot-encoding


    【解决方案1】:

    Tensorflow 支持这样的 dtype 转换。

    在诸如x + 1 之类的操作中,值1 正在通过负责验证和转换的tf.convert_to_tensor 函数。该函数有时会在后台手动调用,当设置 dtype 参数时,值会自动转换为这种类型。

    当您将数组输入到这样的占位符中时:

    session.run(..., feed_dict={x: data})
    

    ... 数据通过np.asarray 调用显式转换为正确类型的numpy 数组。请参阅python/client/session.py 的源代码。请注意,当 dtype 不同时,此方法可能会重新分配缓冲区,而这正是您的情况。所以你的内存优化并没有像你期望的那样工作:临时的 32 位 data 是在内部分配的。

    【讨论】:

    • 好的,谢谢,完美的答案。隐式转换可能会伤害...
    猜你喜欢
    • 2017-01-02
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-10-05
    • 1970-01-01
    • 2019-04-07
    • 2017-03-08
    相关资源
    最近更新 更多