【问题标题】:How to process different row in tensor based on the first column value in tensorflow如何根据张量流中的第一列值处理张量中的不同行
【发布时间】:2018-09-14 04:03:50
【问题描述】:

假设我有一个 4 x 3 张量:

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]

我想返回另一个形状为 4 的张量:[r1,r2,r3,r4] 如果 row[0] 小于 5,则 r 等于 tf.reduce_sum(row),或者 r 等于tf.reduce_mean(row) 如果 row[0] 大于或等于 5。 输出:

output = [16.67, 6, 18, 7.33]

我不是 tensorflow 方面的专家,请帮助我了解如何在没有 for 循环的情况下在 python 3 中实现上述目标。 谢谢

更新:

因此,我尝试调整@Onyambu 给出的答案以在函数中包含两个示例,但在所有情况下它都给了我一个错误。 这是第一种情况的答案:

def f(x):
    c = tf.constant(5,tf.float32)
    def fun1():
        return tf.reduce_sum(x)
    def fun2():
        return tf.reduce_mean(x)
    return tf.cond(tf.less(x[0],c),fun1,fun2)
a = tf.map_fn(f,tf.constant(sample,tf.float32))

上面的效果很好。

两个样本:

sample1 = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]
sample2 = [[0, 15, 25], [1, 2, 3], [0, 4, 10], [1, 9, 8]]

def f2(x1,x2):
    c = tf.constant(1,tf.float32)
    def fun1():
        return tf.reduce_sum(x1[:,0] - x2[:,0])
    def fun2():
        return tf.reduce_mean(x1 - x2)
    return tf.cond(tf.less(x2[0],c),fun1,fun2)
a = tf.map_fn(f2,tf.constant(sample1,tf.float32), tf.constant(sample2,tf.float32))

改编确实会报错,但原理很简单:

  • 如果 row[0] 小于 1,则计算 sample1[:,0] - sample2[:,0] 的 tf.reduce_sum

  • 如果 row[0] 大于或等于 1,则计算 sample1 - sample2 的 tf.reduce_sum

提前感谢您的帮助!

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:
    import tensorflow as tf
    def f(x):
        y = tf.constant(5,tf.float32)
        def fun1():
            return tf.reduce_sum(x)
        def fun2():
            return tf.reduce_mean(x)
        return tf.cond(tf.less(x[0],y),fun1,fun2)
    
    a = tf.map_fn(f,tf.constant(sample,tf.float32))
    
    with tf.Session() as sess: print(sess.run(a))
    
     [16.666666   6.        18.         7.3333335]
    

    如果你想缩短它:

    y = tf.constant(5,tf.float32)
    f=lambda x: tf.cond(tf.less(x[0], y), lambda: tf.reduce_sum(x),lambda: tf.reduce_mean(x))
    
    a = tf.map_fn(f,tf.constant(sample,tf.float32))
    with tf.Session() as sess: print(sess.run(a))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-08-30
      • 2019-12-28
      • 1970-01-01
      • 1970-01-01
      • 2019-12-26
      • 2020-05-29
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多