【问题标题】:How to compare tensorflow tensors with Python objects?如何将张量流张量与 Python 对象进行比较?
【发布时间】:2017-11-14 10:20:27
【问题描述】:

我正在尝试将从 tfrec 文件加载的字符串标签张量转换为用于单热编码的数字。这个想法是使用一个 numpy 数组作为查找表,一旦命中,就会返回索引并将其存储在张量中。

但是,问题是张量不能直接与 python 对象进行比较。我尝试使用 tf.map_fn 来枚举我的一批标签,并使用 tf.cond 来进行比较 - 这没有用:

def elem_op(t):
    global all_labels
    for idx, lbl in enumerate(all_labels):
        lbl_tensor = tf.constant(lbl.encode())  # tensorflow stores string as bytes, so convert the Python string object to bytes tensor
        ret = tf.cond(tf.equal(lbl_tensor, t), lambda : idx, lambda : -1)
        if ret != -1:  # now this doesn't work because tf.cond returns a tensor
            return ret
    return -1

# labels is a tensor storing a batch of label strings
train_labels = tf.map_fn(fn=elem_op, elems=labels, dtype=tf.int32)

问题是 tf.cond 也返回一个张量并且不能在“if”子句中使用。我想知道解决这个问题的方法是什么?

谢谢!

【问题讨论】:

  • 忘了提到这个例程将在“model_fn”中调用,因此没有明确的“会话”对象,尤其是。在 elem_op 子例程中,我不能使用 sess.run(ret) 或 ret.eval(session=sess)。

标签: python tensorflow


【解决方案1】:

您必须在会话中评估张量才能获得它的实际值。

如果条件更改为:

if sess.run(ret) != -1:

其中sess 是您的tf.Session 实例。例如:

sess = tf.Session()

同样,你可以运行:

sess.run(train_labels)

【讨论】:

  • 忘了提到 model_fn 中没有明确的会话对象,所以很遗憾这不起作用。
  • 恐怕“如何比较 tensorflow 张量与 Python 对象 [没有会话]?”的答案。是“你没有”。
  • 嗯,这是一种方法 - 在将所有字符串标签传递给 tensorflow 之前将它们预处理为整数,从而避免张量与python-obj 比较。但我想知道是否有另一种方法可以做到这一点..?
猜你喜欢
  • 2020-08-28
  • 2019-12-19
  • 1970-01-01
  • 2021-07-08
  • 2018-12-28
  • 1970-01-01
  • 2020-09-19
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多