【问题标题】:tensorflow periodic padding张量流周期性填充
【发布时间】:2016-08-22 20:47:54
【问题描述】:

在 tensorflow 中,我找不到直接的可能性来使用周期性边界条件进行卷积 (tf.nn.conv2d)。

例如取张量

[[1,2,3],
 [4,5,6],
 [7,8,9]]

和任何 3x3 过滤器。原则上可以通过对 5x5 进行周期性填充来完成具有周期性边界条件的卷积

[[9,7,8,9,7],
 [3,1,2,3,1],
 [6,4,5,6,4],
 [9,7,8,9,7],
 [3,1,2,3,1]]

随后在“有效”模式下与过滤器进行卷积。但是,很遗憾,函数tf.pad 不支持周期性填充。

有简单的解决方法吗?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    以下内容应该适用于您的情况:

    import tensorflow as tf
    a = tf.constant([[1,2,3],[4,5,6],[7,8,9]])
    b = tf.tile(a, [3, 3])
    result = b[2:7, 2:7]
    sess = tf.InteractiveSession()
    print(result.eval())
    
    # prints the following 
    array([[9, 7, 8, 9, 7],
           [3, 1, 2, 3, 1],
           [6, 4, 5, 6, 4],
           [9, 7, 8, 9, 7],
           [3, 1, 2, 3, 1]], dtype=int32)
    

    正如 cmets 中所述,这在内存方面有点低效。如果内存对您来说是个问题,但愿意花费一些计算,那么以下方法也可以:

    pre = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
    post = tf.transpose(pre)
    result = tf.matmul(tf.matmul(pre, a), post)
    print(result.eval())
    

    【讨论】:

    • 非常感谢!工作正常(除了它应该是 result.eval())。但是,如果“a”很大而过滤器很小,这对我来说似乎有点低效。或者 tensorflow 是否能够弄清楚它实际上并不需要计算 b 的所有分量?
    • 抱歉错字,已修复。是的,这种方法效率低下,因为它创建了 9 个图像副本并丢弃了其中的大部分。 TensorFlow 目前没有优化计算 result 而不具体化 b
    【解决方案2】:

    更加通用和灵活:一个或多个指定轴的周期性填充,可选择为不同轴指定不同的填充长度

    import tensorflow as tf
    
    def periodic_padding_flexible(tensor, axis,padding=1):
        """
            add periodic padding to a tensor for specified axis
            tensor: input tensor
            axis: on or multiple axis to pad along, int or tuple
            padding: number of cells to pad, int or tuple
    
            return: padded tensor
        """
    
    
        if isinstance(axis,int):
            axis = (axis,)
        if isinstance(padding,int):
            padding = (padding,)
    
        ndim = len(tensor.shape)
        for ax,p in zip(axis,padding):
            # create a slice object that selects everything from all axes,
            # except only 0:p for the specified for right, and -p: for left
    
            ind_right = [slice(-p,None) if i == ax else slice(None) for i in range(ndim)]
            ind_left = [slice(0, p) if i == ax else slice(None) for i in range(ndim)]
            right = tensor[ind_right]
            left = tensor[ind_left]
            middle = tensor
            tensor = tf.concat([right,middle,left], axis=ax)
    
        return tensor
    
    
    
    a = tf.constant([
        [[1,2,3],[4,5,6],[7,8,9]],
        [[11,12,13],[14,15,16],[17,18,19]],
    ])
    
    sess = tf.InteractiveSession()
    
    result = periodic_padding_flexible(a, axis=1,padding=1)
    print('a:')
    print(a.eval())
    print('padded a:')
    print(result.eval())
    
    result = periodic_padding_flexible(a, axis=2,padding=1)
    print('a:')
    print(a.eval())
    print('padded a:')
    print(result.eval())
    
    result = periodic_padding_flexible(a, axis=(1,2),padding=(1,2))
    print('a:')
    print(a.eval())
    print('padded a:')
    print(result.eval())
    

    输出:

    a:
    [[[ 1  2  3]
      [ 4  5  6]
      [ 7  8  9]]
     [[11 12 13]
      [14 15 16]
      [17 18 19]]]
    padded a:
    [[[ 7  8  9]
      [ 1  2  3]
      [ 4  5  6]
      [ 7  8  9]
      [ 1  2  3]]
     [[17 18 19]
      [11 12 13]
      [14 15 16]
      [17 18 19]
      [11 12 13]]]
    a:
    [[[ 1  2  3]
      [ 4  5  6]
      [ 7  8  9]]
     [[11 12 13]
      [14 15 16]
      [17 18 19]]]
    padded a:
    [[[ 3  1  2  3  1]
      [ 6  4  5  6  4]
      [ 9  7  8  9  7]]
     [[13 11 12 13 11]
      [16 14 15 16 14]
      [19 17 18 19 17]]]
    a:
    [[[ 1  2  3]
      [ 4  5  6]
      [ 7  8  9]]
     [[11 12 13]
      [14 15 16]
      [17 18 19]]]
    padded a:
    [[[ 8  9  7  8  9  7  8]
      [ 2  3  1  2  3  1  2]
      [ 5  6  4  5  6  4  5]
      [ 8  9  7  8  9  7  8]
      [ 2  3  1  2  3  1  2]]
     [[18 19 17 18 19 17 18]
      [12 13 11 12 13 11 12]
      [15 16 14 15 16 14 15]
      [18 19 17 18 19 17 18]
      [12 13 11 12 13 11 12]]]
    

    【讨论】:

      【解决方案3】:

      这是 tensorflow 中周期性填充的实现,适用于一批二维图像。它使用切片和 tf.concat:

      def periodic_padding(x, padding=1):
          '''
          x: shape (batch_size, d1, d2)
          return x padded with periodic boundaries. i.e. torus or donut
          '''
          d1 = x.shape[1] # dimension 1: height
          d2 = x.shape[2] # dimension 2: width
          p = padding
          # assemble padded x from slices
          #            tl,tc,tr
          # padded_x = ml,mc,mr
          #            bl,bc,br
          top_left = x[:, -p:, -p:] # top left
          top_center = x[:, -p:, :] # top center
          top_right = x[:, -p:, :p] # top right
          middle_left = x[:, :, -p:] # middle left
          middle_center = x # middle center
          middle_right = x[:, :, :p] # middle right
          bottom_left = x[:, :p, -p:] # bottom left
          bottom_center = x[:, :p, :] # bottom center
          bottom_right = x[:, :p, :p] # bottom right
          top = tf.concat([top_left, top_center, top_right], axis=2)
          middle = tf.concat([middle_left, middle_center, middle_right], axis=2)
          bottom = tf.concat([bottom_left, bottom_center, bottom_right], axis=2)
          padded_x = tf.concat([top, middle, bottom], axis=1)
          return padded_x
      
      import tensorflow as tf
      a = tf.constant([
          [[1,2,3],[4,5,6],[7,8,9]],
          [[11,12,13],[14,15,16],[17,18,19]],
      ])
      result = periodic_padding(a, padding=1)
      sess = tf.InteractiveSession()
      print('a:')
      print(a.eval())
      print('padded a:')
      print(result.eval())
      sess.close()
      

      例子的输出是:

      a:
      [[[ 1  2  3]
        [ 4  5  6]
        [ 7  8  9]]
      
       [[11 12 13]
        [14 15 16]
        [17 18 19]]]
      padded a:
      [[[ 9  7  8  9  7]
        [ 3  1  2  3  1]
        [ 6  4  5  6  4]
        [ 9  7  8  9  7]
        [ 3  1  2  3  1]]
      
       [[19 17 18 19 17]
        [13 11 12 13 11]
        [16 14 15 16 14]
        [19 17 18 19 17]
        [13 11 12 13 11]]]
      

      【讨论】:

        猜你喜欢
        • 2023-03-29
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2021-04-01
        • 1970-01-01
        • 1970-01-01
        • 2021-08-16
        • 2018-11-14
        相关资源
        最近更新 更多