【问题标题】:Emulating nested for loops with scan is slow用 scan 模拟嵌套的 for 循环很慢
【发布时间】:2016-06-18 00:41:50
【问题描述】:

我正在尝试使用 scan 函数来模拟嵌套 for 循环,但这很慢。有没有更好的方法来模拟使用 Tensorflow 的循环嵌套?我不是只用 numpy 来做这个计算,所以我可以进行自动微分。

具体来说,我在使用 Tensorflow 控制操作的同时使用双边滤波器对图像进行卷积。为此,我嵌套了 scan() 函数,但这让我的性能非常差——过滤小图像需要 5 分钟以上。

有没有比嵌套扫描函数更好的方法,我使用 Tensorflow 控制流操作有多糟糕? 我感兴趣的一般答案不止一个针对我的代码的具体答案。

如果你想看的话,这里是原始的、更快的代码:

def bilateralFilter(image, sigma_space=1, sigma_range=None, win_size=None):

    if sigma_range is None:
        sigma_range = sigma_space
    if win_size is None: win_size = max(5, 2 * int(np.ceil(3*sigma_space)) + 1)

    win_ext = (win_size - 1) / 2
    height = image.shape[0]
    width = image.shape[1]

    # pre-calculate spatial_gaussian
    spatial_gaussian = []
    for i in range(-win_ext, win_ext+1):
        for j in range(-win_ext, win_ext+1):
            spatial_gaussian.append(np.exp(-0.5*(i**2+j**2)/sigma_space**2))

    padded = np.pad(image, win_ext, mode="edge")

    out_image = np.zeros(image.shape)
    weight = np.zeros(image.shape)

    idx = 0
    for row in xrange(-win_ext, 1+win_ext):
        for col in xrange(-win_ext, 1+win_ext):
            slice = padded[win_ext+row:height+win_ext+row,
                                          win_ext+col:width+win_ext+col]
            value = np.exp(-0.5*((image - slice)/sigma_range)**2) \
                    * spatial_gaussian[idx]
            out_image += value*slice
            weight += value
            idx += 1

    out_image /= weight

    return out_image

这是 TensorFlow 版本:

sess = tf.InteractiveSession()
with sess.as_default():
    def bilateralFilter(image, sigma_space, sigma_range):
        win_size = max(5., 2 * np.ceil(3 * sigma_space) + 1)

        win_ext = int((win_size - 1) / 2)
        height = tf.shape(image)[0].eval()
        width = tf.shape(image)[1].eval()

        spatial_gaussian = []
        for i in range(-win_ext, win_ext + 1):
            for j in range(-win_ext, win_ext + 1):
                spatial_gaussian.append(np.exp(-0.5 * (i ** 2 +\
                 j ** 2) / sigma_space ** 2))

        # we use "symmetric" as it best approximates "edge" padding
        padded = tf.pad(image, [[win_ext, win_ext], [win_ext, win_ext]],
                 mode='SYMMETRIC')
        out_image = tf.zeros(tf.shape(image))
        weight = tf.zeros(tf.shape(image))

        spatial_index = tf.constant(0)
        row = tf.constant(-win_ext)
        col = tf.constant(-win_ext)

        def cond(padded, row, col, weight, out_image, spatial_index):
            return tf.less(row, win_ext + 1)

        def body(padded, row, col, weight, out_image, spatial_index):
            sub_image = tf.slice(padded, [win_ext + row, win_ext + col],
                        [height, width])
            value = tf.exp(-0.5 *
                    (((image - sub_image) / sigma_range) ** 2)) * 
                     spatial_gaussian[spatial_index.eval()]
            out_image += value * sub_image
            weight += value
            spatial_index += 1
            row, col = tf.cond(tf.not_equal(tf.mod(col,
                               tf.constant(2*win_ext + 1)), 0),
                               lambda: (row + 1, tf.constant(-win_ext)),
                               lambda: (row, col))
            return padded, row, col, weight, out_image, spatial_index

        padded, row, col, weight, out_image, spatial_index =
        tf.while_loop(cond, body,
        [padded, row, col, weight, out_image, spatial_index])
        out_image /= weight

        return out_image

    cat = plt.imread("cat.png")  # grayscale
    cat = tf.reshape(tf.constant(cat), [276, 276])
    cat_blurred = bilateralFilter(cat, 2., 0.25)
    cat_blurred = cat_blurred.eval()
    plt.figure()
    plt.gray()
    plt.imshow(cat_blurred)
    plt.show()

【问题讨论】:

    标签: numpy tensorflow


    【解决方案1】:

    这是您的代码的一个问题。 cols() 有一堆 python 全局变量,您似乎希望它们在每次循环迭代时更新。请查看有关图形构建和执行的 TensorFlow 教程。简而言之,那些 python 全局变量及其相关代码只会在图构建时执行,它们甚至不在 TensorFlow 的执行图中。一个操作只能包含在执行图中,如果它是一个 tf 操作符。

    此外,tf.while_loop 似乎比扫描更适合您的代码。

    【讨论】:

    • 我已经更新了代码,但我认为我已经尽我所能,因为我使用张量进行索引,而 Tensorflow 不允许这样做。 “ValueError:Fetch 参数...已被标记为不可提取。”
    猜你喜欢
    • 1970-01-01
    • 2017-08-19
    • 1970-01-01
    • 1970-01-01
    • 2014-05-17
    • 2022-06-24
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多