【发布时间】:2017-11-14 10:20:27
【问题描述】:
我正在尝试将从 tfrec 文件加载的字符串标签张量转换为用于单热编码的数字。这个想法是使用一个 numpy 数组作为查找表,一旦命中,就会返回索引并将其存储在张量中。
但是,问题是张量不能直接与 python 对象进行比较。我尝试使用 tf.map_fn 来枚举我的一批标签,并使用 tf.cond 来进行比较 - 这没有用:
def elem_op(t):
global all_labels
for idx, lbl in enumerate(all_labels):
lbl_tensor = tf.constant(lbl.encode()) # tensorflow stores string as bytes, so convert the Python string object to bytes tensor
ret = tf.cond(tf.equal(lbl_tensor, t), lambda : idx, lambda : -1)
if ret != -1: # now this doesn't work because tf.cond returns a tensor
return ret
return -1
# labels is a tensor storing a batch of label strings
train_labels = tf.map_fn(fn=elem_op, elems=labels, dtype=tf.int32)
问题是 tf.cond 也返回一个张量并且不能在“if”子句中使用。我想知道解决这个问题的方法是什么?
谢谢!
【问题讨论】:
-
忘了提到这个例程将在“model_fn”中调用,因此没有明确的“会话”对象,尤其是。在 elem_op 子例程中,我不能使用 sess.run(ret) 或 ret.eval(session=sess)。
标签: python tensorflow