【问题标题】:How to get the maximum 2D Tensor from a 3D tensor using TensorFlow 1.14?如何使用 TensorFlow 1.14 从 3D 张量中获得最大 2D 张量?
【发布时间】:2020-08-21 15:46:14
【问题描述】:

我正在寻找最佳和优化的方法(无循环),以使用 TensorFlow 1.14 根据最大​​值从 3D 张量获取 2D 最大张量。假设我们有这个张量和这个函数(为了理解-它不起作用-):

def get_Max(inputs):
    max_indices = [0,0,0]
    for i in range(16):
        for j in range(2048):
            for k in range(10):
                if(inputs[max_indices[0],max_indices[1],max_indices[2]]<inputs[i,j,k]):
                   max_indices = [i,j,k]
    return inputs[:][j]
inputs = tf.random.uniform(shape=[16,2048,10],dtype=tf.dtypes.float32)
output = get_Max(inputs)

因此,输出张量的形状必须为 [16,10],即 2048 的 16 个最大值。 那么,如何实现一个没有循环的函数呢?

我使用了tf.math.reduce_max,但这不是我想要的,因为它在下图中很清楚:

【问题讨论】:

  • 可能是tf.math.reduce_max(inputs)?或者 argmax 如果您还需要索引?
  • 但是如何返回一个 [16,10] 张量?
  • 你也可以指定一个轴来循环,所以不是在所有维度上找到最大值,而是在那个维度上循环然后丢弃它。
  • 我的意思是,tf.math.reduce_max(inputs, axis=1)
  • 我知道将这些索引传递给输入并获得形状为 [16,10] 的输入

标签: python tensorflow tensorflow2.0 tensorflow-serving


【解决方案1】:
inp = tf.random.uniform(shape=[4, 6, 2], maxval=20, dtype=tf.int32)
print(inp)

array([[[14,  8],
    [18, 10],
    [ 6, 14],
    [ 8,  9],
    [11, 11],
    [14, 13]],

   [[ 7, 18],
    [ 4, 10],
    [15,  6],
    [ 6,  2],
    [19, 11],
    [10,  4]],

   [[ 8,  1],
    [ 1,  3],
    [ 4, 17],
    [15,  7],
    [ 0,  0],
    [ 1,  4]],

   [[ 5,  0],
    [15, 12],
    [ 1, 16],
    [ 3, 17],
    [14, 17],
    [ 2, 18]]], dtype=int32)>

所以如果我理解正确,对于每个inp[i, :, :] 喜欢:

    [[14,  8],
    [18, 10],
    [ 6, 14],
    [ 8,  9],
    [11, 11],
    [14, 13]]

您想保留包含最大数量的项目,在本例中为第二行:[18, 10]。我要做的是首先找到最后一个轴的最大数量:

am = tf.math.reduce_max(inp, axis=2)
am[0, :, :]
[14,
 18,
 14,
 9,
 11,
 14]

然后找到包含最大数的行的索引:

am = tf.math.argmax(am, axis=1)

这些将是您想要的 js,然后您可以使用 tf.gather_nd 并枚举来获取这些值:

# [*enumerate(am)] = [(0, am[0]), (1, am[1]), ...]
tf.gather_nd(inp, [*enumerate(am)])

<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[18, 10],
       [19, 11],
       [ 4, 17],
       [ 2, 18]], dtype=int32)>

【讨论】:

  • 谢谢,这正是我想要的。
  • 附言。我得到一个 SyntaxError: invalid syntax in Python2.7,在行 tf.gather_nd(inp, [*enumerate(am)]) 它指向 *
  • @RaniaRano 在 Python2 中,您可以将 [*enumerate(blah)] 更改为 enumerate(blah)
  • @mohammed Jafar 它不起作用,因为我们无法在 TensorFlow 1.14 中迭代张量
  • @RaniaRano 你能在 1.14 中做am.numpy() 吗?如果没有,也许你可以使用类似于tf.stack(tf.range(tf.shape(am)[0]), am) 的东西来做enumerate 所做的事情?
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2018-05-15
  • 2016-07-02
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多