【发布时间】:2017-05-01 15:41:47
【问题描述】:
代码摘自Tensorflow tutorial。该函数在 MNIST 数据集上运行操作,这是一个 0-9 的手写图片数据集。为什么要给int64打标签,我以为int32就够了。
def loss(logits,labels):
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits,labels,name='xentropy')
loss = tf.reduce_mean(cross_entropy,name='xentropy_mean')
return loss
【问题讨论】:
-
其中一些转换是为了让 TensorFlow 在以 numpy 数组形式提供数据时使用与 numpy 相同的类型(numpy 整数默认为 int64)
-
在这种情况下,转换应该会自动发生,对吧?我在没有指定
dtype的情况下定义数组时使用numpy 对其进行了测试,默认情况下它为int64。那么,为什么还要提前进行投射呢? -
如果您将
int64提供给 TensorFlowlabels节点,即int32,那么它必须在每次运行调用期间进行向下转换。 IE,它必须运行逻辑来查看输入值 int64 是否适合int32空间
标签: python tensorflow neural-network deep-learning mnist