【问题标题】:slice based on a masked tensor in tensorflow基于张量流中的掩码张量切片
【发布时间】:2019-11-05 05:37:36
【问题描述】:

这是我的可重现代码:

tf_ent = tf.Variable([   [9.96,    8.65,    0.99,    0.1 ],
                         [0.7,     8.33,    0.1  ,   0.1   ],
                         [0.9,     0.1,     6,       7.33],
                         [6.60,    0.1,     3,       5.5 ],
                         [9.49,    0.2,     0.2,     0.2   ],
                         [0.4,     8.45,    0.2,     0.2 ],
                         [0.3,     0.2,     5.82,    8.28]])

tf_ent_var = tf.constant([True, False, False, False, False, True, False])

我想保留tf_enttf_ent_var 中对应索引为True 的行,并使其余行在整个矩阵中最小化。

所以预期的输出是这样的:

                    [[9.96,    8.65,    0.99,   0.1 ],
                     [0.1,     0.1,     0.1  ,  0.1 ],
                     [0.1,     0.1,     0.1,    0.1 ],
                     [0.1,     0.1,     0.1,    0.1 ],
                     [0.1,     0.1,     0.1,    0.1 ],
                     [0.4,     8.45,    0.2,      0.2 ],
                     [0.1,     0.1,     0.1,    0.1 ]]

知道我该怎么做吗?

我试图从掩码张量中获取索引,然后使用 tf.gather 来完成这个,但是我得到的索引是这样的[[0], [6]],这是有道理的,因为它给出了一个向量的索引。

【问题讨论】:

  • 当你想要的行和你不想要的行中都存在0时,“整个矩阵中的最小值”0.1如何?
  • @ImperishableNight,我不得不让矩阵变小,忘记替换零:|,我会更新我的问题

标签: python tensorflow slice


【解决方案1】:

编辑:对于 tensorflow 1.x,使用:

val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)

不幸的是,广播规则不是 2.0 规则的子集(与 numpy 相同),而是“只是不同”。就版本兼容性而言,Tensorflow 并不是最好的库。


基本思想是使用tf.where,但是你需要先把tf_ent_var变成一个形状为(7, 1)的张量,这样tensorflow就知道在第二个轴而不是第一个轴上广播它。所以:

val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var[:, tf.newaxis], tf_ent, val)

当然你也可以改成(-1, 1),不过我觉得用tf.newaxis切片更短更清晰。


这是我与 1.13.1 的 Python 交互会话,用于故障排除。

Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 16:52:21) 
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
2019-06-22 15:51:09.210852: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
>>> tf_ent = tf.Variable([   [9.96,    8.65,    0.99,    0.1 ],
...                          [0.7,     8.33,    0.1  ,   0.1   ],
...                          [0.9,     0.1,     6,       7.33],
...                          [6.60,    0.1,     3,       5.5 ],
...                          [9.49,    0.2,     0.2,     0.2   ],
...                          [0.4,     8.45,    0.2,     0.2 ],
...                          [0.3,     0.2,     5.82,    8.28]])
WARNING:tensorflow:From /Users/REDACTED/Documents/test/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
>>> 
>>> tf_ent_var = tf.constant([True, False, False, False, False, True, False])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> val = tf.math.reduce_min(tf_ent)
>>> tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
<tf.Tensor 'Select:0' shape=(7, 4) dtype=float32>
>>> _.eval()
array([[9.96, 8.65, 0.99, 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.4 , 8.45, 0.2 , 0.2 ],
       [0.1 , 0.1 , 0.1 , 0.1 ]], dtype=float32)
>>> tf.__version__
'1.13.1'

【讨论】:

  • 谢谢你的回答,虽然我得到了这个错误 'tensorflow.python.framework.errors_impl.InvalidArgumentError: Inputs to operation Select of type Select must have the same size and shape.输入 0: [7,1] != 输入 1: [7,4] [Op:Select] '
  • 哦,我使用的是 tensorflow 2.0.0-beta1(并且 tensorflow 确保我知道这一点,因为它抛出的每个错误都来自一个名为 something_v2 的函数,例如 where_v2 和 @ 987654331@)。在 tensorflow 1.x 中,where 可能没有这么灵活。我会尝试在 1.x 中寻找解决方案。
  • 我不确定我是否应该更新到 beta 版本,如果您发现任何适用于 tf.13 的方法,请告诉我
  • val 值不会被添加到 tf.zeros 张量中,因此输出在最终矩阵中没有 0.1
  • 我尝试更新它,但听起来它与 tensorflow 版本有关,你得到的输出与我在问题中分享的完全相同吗?
【解决方案2】:
min_mat = tf.broadcast_to(tf.reduce_min(tf_ent), tf_ent.shape)
output = tf.where(tf_ent_var, tf_ent, min_mat)
sess.run(output)

【讨论】:

    【解决方案3】:

    这是我使用tf.concat()if-else 语句的实现。它不像其他人的答案那么优雅,但正在工作:

    import tensorflow as tf
    tf.enable_eager_execution()
    
    def slice_tensor_based_on_mask(tf_ent, tf_ent_var):
        res = tf.fill([1, 4], 0.0)  
        min_value_tensor = tf.fill([1,int(tf_ent.shape[1])], tf.reduce_min(tf_ent))
    
        for i in range(int(tf_ent.shape[0])):
            if tf_ent_var[i:i+1].numpy()[0]: # true value in tf_ent_var
                res = tf.concat([res, tf_ent[i:i+1]], 0)
            else:
                res = tf.concat([res, min_value_tensor], 0)
        return res[1:]
    
    tf_ent = tf.Variable([[9.96,    8.65,    0.99,   0.1 ],
                         [0.7,     8.33,    0.1  ,   0.1 ],
                         [0.9,     0.1,     6,       7.33],
                         [6.60,    0.1,     3,       5.5 ],
                         [9.49,    0.2,     0.2,     0.2 ],
                         [0.4,     8.45,    0.2,     0.2 ],
                         [0.3,     0.2,     5.82,    8.28]])
    
    tf_ent_var = tf.constant([True, False, False, False, False, True, False])
    print(slice_tensor_based_on_mask(tf_ent, tf_ent_var))
    

    输出:

    tf.Tensor(
    [[9.96 8.65 0.99 0.1 ]
     [0.1  0.1  0.1  0.1 ]
     [0.1  0.1  0.1  0.1 ]
     [0.1  0.1  0.1  0.1 ]
     [0.1  0.1  0.1  0.1 ]
     [0.4  8.45 0.2  0.2 ]
     [0.1  0.1  0.1  0.1 ]], shape=(7, 4), dtype=float32)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-12-01
      • 1970-01-01
      • 2018-03-15
      • 1970-01-01
      • 2019-05-30
      • 2018-01-31
      • 1970-01-01
      • 2019-10-31
      相关资源
      最近更新 更多