【问题标题】:Weighted sum of matrices in TensorflowTensorflow中矩阵的加权和
【发布时间】:2017-02-03 19:49:27
【问题描述】:

我有一个大小为 (M,N,N) 的 3 维张量 A 和一个大小为 M 的一维张量 p。我想计算矩阵的加权和:

在 NumPy 中,我正在实现以下代码:

import numpy as np
temp=np.array([p[m]*A[m] for m in range(M)])
B=sum(temp);

我想在 TensorFlow 中做同样的事情,但我似乎没有找到任何内置操作来执行相同的操作。我尝试了tf.matmultf.mul,但它们似乎没有给出预期的结果。有人可以建议我在 TensorFlow 中执行此操作的正确方法吗?

【问题讨论】:

    标签: numpy tensorflow


    【解决方案1】:

    好吧,使用显式广播似乎很容易。但我不确定这有多有效!

    import numpy as np
    
    A = np.array([[[1, 7],
                  [4, 4]],
                 [[2, 6],
                  [5, 3]],
                 [[3, 5],
                  [6, 2]]])
    
    p = np.array([4,3,2])
    
    M = 3
    N = 2
    
    #numpy version
    temp=np.array([p[m]*A[m] for m in range(M)])
    B=sum(temp);
    
    #tensorflow version
    import tensorflow as tf
    A_tf = tf.constant(A,dtype=tf.float64)
    p_tf = tf.constant(p,dtype=tf.float64)
    p_tf_broadcasted = tf.tile(tf.reshape(p_tf,[M,1,1]), tf.pack([1,N,N]))
    B_tf = tf.reduce_sum(tf.mul(A_tf,p_tf_broadcasted), axis=0)
    
    sess=tf.Session()
    B_tf_ = sess.run(B_tf)
    
    
    #make sure they're equal
    B_tf_ == B
    
    
    #array([[ True,  True],
    #       [ True,  True]], dtype=bool)
    

    【讨论】:

    • 是否可以做上述的矢量化版本?我有一个大小为(K,M)P 矩阵和一个大小为(M,N,N) 的张量A。我想计算一个张量B 大小为(K,N,N) 其中B[k]=\sum_{m=0}^{M-1} P[k][m] A_m
    【解决方案2】:

    当你有一个大小为(K,M)的P矩阵和一个大小为(M,N,N)的张量A时,如果你想计算一个大小为(K,N,N)的张量B,你可以按照它来。

    import tensorflow as tf
    import numpy as np
    
    K = 2
    M = 3
    N = 2
    
    np.random.seed(0)
    A = tf.constant(np.random.randint(1,5,(M,N,N)),dtype=tf.float64)
    
    # when K.shape=(K,M)
    P = tf.constant(np.random.randint(1,5,(K,M)),dtype=tf.float64)
    # when K.shape=(M,)
    # P = tf.constant(np.random.randint(1,5,(M,)),dtype=tf.float64)
    
    P_new = tf.expand_dims(tf.expand_dims(P,-1),-1)
    
    # if K.shape=(K,M) set axis=1,if K.shape=(M,) set axis=0,
    B = tf.reduce_sum(tf.multiply(P_new , A),axis=1)
    
    with tf.Session()as sess:
        print(sess.run(P))
        print(sess.run(A))
        print(sess.run(B))
    
    [[1. 4. 3.]
     [1. 1. 1.]]
    [[[1. 4.]
      [2. 1.]]
    
     [[4. 4.]
      [4. 4.]]
    
     [[2. 4.]
      [2. 3.]]]
    [[[23. 32.]
      [24. 26.]]
    
     [[ 7. 12.]
      [ 8.  8.]]]
    

    上面的代码可以修改为在你的问题中包含问题的解决方案。

    【讨论】:

      猜你喜欢
      • 2021-10-20
      • 1970-01-01
      • 2015-06-18
      • 2011-08-29
      • 2017-03-12
      • 1970-01-01
      • 1970-01-01
      • 2020-04-16
      • 1970-01-01
      相关资源
      最近更新 更多