【问题标题】:Tensorflow compute multiplication by binary matrixTensorflow 通过二进制矩阵计算乘法
【发布时间】:2018-09-17 20:56:30
【问题描述】:

我的数据张量的形状为 [batch_size,512],并且我有一个常数矩阵,其值仅为 0 和 1,其形状为 [256,512]

我想为每个批次有效地计算我的向量乘积的总和(数据张量的第二维),仅针对 1 而不是 0 的条目。

一个解释性的例子: 假设我有 1 个大小的批次:数据张量具有值 [5,4,3,7,8,2],而我的常量矩阵具有值:

[0,1,1,0,0,0]
[1,0,0,0,0,0]
[1,1,1,0,0,1]

这意味着我想计算第一行 4*3、第二行 5 和第三行 5*4*3*2。 在这批中,我得到4*3+5+5*4*3*2 等于 137。 目前,我通过迭代所有行,逐元素计算我的数据和常量矩阵行的乘积,然后求和,这运行速度很慢。

【问题讨论】:

    标签: python tensorflow binary-matrix


    【解决方案1】:

    这样的事情怎么样:

    import tensorflow as tf
    
    # Two-element batch
    data = [[5, 4, 3, 7, 8, 2],
            [4, 2, 6, 1, 6, 8]]
    mask = [[0, 1, 1, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 1]]
    with tf.Graph().as_default(), tf.Session() as sess:
        # Data as tensors
        d = tf.constant(data, dtype=tf.int32)
        m = tf.constant(mask, dtype=tf.int32)
        # Tile data as needed
        dd = tf.tile(d[:, tf.newaxis], (1, tf.shape(m)[0], 1))
        mm = tf.tile(m[tf.newaxis, :], (tf.shape(d)[0], 1, 1))
        # Replace values with 1 wherever the mask is 0
        w = tf.where(tf.cast(mm, tf.bool), dd, tf.ones_like(dd))
        # Multiply row-wise and sum
        result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
        print(sess.run(result))
    

    输出:

    [137 400]
    

    编辑:

    如果您输入的数据是单个向量,那么您只需:

    import tensorflow as tf
    
    # Two-element batch
    data = [5, 4, 3, 7, 8, 2]
    mask = [[0, 1, 1, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 1]]
    with tf.Graph().as_default(), tf.Session() as sess:
        # Data as tensors
        d = tf.constant(data, dtype=tf.int32)
        m = tf.constant(mask, dtype=tf.int32)
        # Tile data as needed
        dd = tf.tile(d[tf.newaxis], (tf.shape(m)[0], 1))
        # Replace values with 1 wherever the mask is 0
        w = tf.where(tf.cast(m, tf.bool), dd, tf.ones_like(dd))
        # Multiply row-wise and sum
        result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
        print(sess.run(result))
    

    输出:

    137
    

    【讨论】:

    • 如果批量大小大于 1,我应该在您的解决方案中替换什么? @jdehesa
    • @Codevan 在这种情况下,您需要每个 bach 元素的结果还是总和?
    • @Codevan 我已经为每个元素的结果进行了编辑,如果需要,您可以稍后聚合它。
    • 是的,现在很好。另一个问题 - 我在 w 矩阵中收到 -1 而不是 1,可能是因为我的掩码(和数据)不是浮点数的整数。我该如何解决这个问题?
    • @Codevan 问题在于输入的维数在每种情况下都不同。您可以在所有情况下使用第一个代码,在平铺之前添加类似d = tf.reshape(d, (-1, tf.shape(d)[-1])) 的内容。
    猜你喜欢
    • 2011-09-26
    • 2019-04-04
    • 1970-01-01
    • 1970-01-01
    • 2018-01-21
    • 1970-01-01
    • 1970-01-01
    • 2019-09-08
    • 1970-01-01
    相关资源
    最近更新 更多