【发布时间】:2020-03-06 23:10:15
【问题描述】:
我正在研究 tensorflow,但遇到以下问题
import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn
#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]],
[[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])
#find the value corresponding to label index by row
output = nn.log_softmax(logit)
我有
output = tf.Tensor(
[[[-0.74085818 -1.14085818 -1.59085818]
[-0.97945321 -1.31945321 -1.02945321]
[-1.43897936 -0.97897936 -0.94897936]]
[[-1.27561467 -1.12561467 -0.92561467]
[-1.38741927 -1.14741927 -0.83741927]
[-0.88817684 -1.22817684 -1.21817684]]], shape=(2, 3, 3), dtype=float64)
我想通过来自label 的索引从output 中选择元素。也就是我的最终结果应该是
[[1.59085822 0.97945321 0.97897935] #2, 0, 1
[1.27561462 0.83741927 1.22817683]], #0, 2, 1
shape=(2, 3), dtype=float64)
【问题讨论】:
-
您好,请看一下我的回答,看看是否能解决您的问题
标签: python tensorflow indexing tensorflow2.0