【问题标题】:Select elemetn from another list in tensorflow从张量流中的另一个列表中选择元素
【发布时间】:2020-03-06 23:10:15
【问题描述】:

我正在研究 tensorflow,但遇到以下问题

import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn

#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]], 
         [[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])


#find the value corresponding to label index by row
output = nn.log_softmax(logit)

我有

output = tf.Tensor(
[[[-0.74085818 -1.14085818 -1.59085818]
  [-0.97945321 -1.31945321 -1.02945321]
  [-1.43897936 -0.97897936 -0.94897936]]

 [[-1.27561467 -1.12561467 -0.92561467]
  [-1.38741927 -1.14741927 -0.83741927]
  [-0.88817684 -1.22817684 -1.21817684]]], shape=(2, 3, 3), dtype=float64)

我想通过来自label 的索引从output 中选择元素。也就是我的最终结果应该是

[[1.59085822 0.97945321 0.97897935]  #2, 0, 1
[1.27561462 0.83741927 1.22817683]], #0, 2, 1
shape=(2, 3), dtype=float64)

【问题讨论】:

  • 您好,请看一下我的回答,看看是否能解决您的问题

标签: python tensorflow indexing tensorflow2.0


【解决方案1】:

您不能直接执行此操作。实现这一目标的正确方法是首先在您的标签上应用one hot encoding。然后使用tf.boolean_mask 从您的输出日志中进行选择。

这是一个例子:

import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn

#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]], 
         [[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])


#find the value corresponding to label index by row
output = nn.log_softmax(logit)

one_hot = tf.one_hot(label, 3, dtype=tf.int32)
# <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
# array([[[0, 0, 1],
#         [1, 0, 0],
#         [0, 1, 0]],
# 
#        [[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]]], dtype=int32)>

result_vec = tf.boolean_mask(output, one_hot) # The result is a vector
# <tf.Tensor: shape=(6,), dtype=float64, numpy=
# array([-1.59085818, -0.97945321, -0.97897936, -1.27561467, -0.83741927,
#        -1.22817684])>

result = tf.reshape(result_vec, label.shape)

结果将是:(您是否遗漏了问题中的负号?)

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-1.59085818, -0.97945321, -0.97897936],
       [-1.27561467, -0.83741927, -1.22817684]])>

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-12-23
    • 2016-04-01
    • 1970-01-01
    • 1970-01-01
    • 2016-09-23
    • 1970-01-01
    • 2019-09-13
    相关资源
    最近更新 更多