【问题标题】:How can I make this PyTorch heatmap function faster and more efficient?我怎样才能让这个 PyTorch 热图函数更快更高效?
【发布时间】:2021-04-17 18:30:24
【问题描述】:

我有这个函数可以为 2d 张量创建一个排序 if 热图,但是当使用更大的张量输入时它会非常缓慢。如何加快速度并提高效率?

import torch
import numpy as np
import matplotlib.pyplot as plt


def heatmap(
    tensor: torch.Tensor,
) -> torch.Tensor:
    assert tensor.dim() == 2

    def color_tensor(x: torch.Tensor) -> torch.Tensor:
        if x < 0:
            x = -x
            if x < 0.5:
                x = x * 2
                return (1 - x) * torch.tensor(
                    [0.9686, 0.9686, 0.9686]
                ) + x * torch.tensor([0.5725, 0.7725, 0.8706])
            else:
                x = (x - 0.5) * 2
                return (1 - x) * torch.tensor(
                    [0.5725, 0.7725, 0.8706]
                ) + x * torch.tensor([0.0196, 0.4431, 0.6902])
        else:
            if x < 0.5:
                x = x * 2
                return (1 - x) * torch.tensor(
                    [0.9686, 0.9686, 0.9686]
                ) + x * torch.tensor([0.9569, 0.6471, 0.5098])
            else:
                x = (x - 0.5) * 2
                return (1 - x) * torch.tensor(
                    [0.9569, 0.6471, 0.5098]
                ) + x * torch.tensor([0.7922, 0.0000, 0.1255])

    return torch.stack(
        [torch.stack([color_tensor(x) for x in t]) for t in tensor]
    ).permute(2, 0, 1)

x = torch.randn(3,3)
x = x / x.max()
x_out = heatmap(x)

x_out = (x_out.permute(1, 2, 0) * 255).numpy()
plt.imshow(x_out.astype(np.uint8))
plt.axis("off")
plt.show()

输出示例:

【问题讨论】:

    标签: python arrays numpy pytorch tensor


    【解决方案1】:

    您需要摆脱 ifs 和 for 循环并创建一个矢量化函数。为此,您可以使用掩码并计算所有内容。这里是:

    
    def heatmap(tensor: torch.Tensor) -> torch.Tensor:
        assert tensor.dim() == 2
    
        # We're expanding to create one more dimension, for mult. to work.
        xt = x.expand((3, x.shape[0], x.shape[1])).permute(1, 2, 0)
    
        # this part is the mask: (xt >= 0) * (xt < 0.5) ...
        # ... the rest is the original function translated
        color_tensor = (
            (xt >= 0) * (xt < 0.5) * ((1 - xt * 2) * torch.tensor([0.9686, 0.9686, 0.9686]) + xt * 2 * torch.tensor([0.9569, 0.6471, 0.5098]))
            +
            (xt >= 0) * (xt >= 0.5) * ((1 - (xt - 0.5) * 2) * torch.tensor([0.9569, 0.6471, 0.5098]) + (xt - 0.5) * 2 * torch.tensor([0.7922, 0.0000, 0.1255]))
            +
            (xt < 0) * (xt > -0.5) * ((1 - (-xt * 2)) * torch.tensor([0.9686, 0.9686, 0.9686]) + (-xt * 2) * torch.tensor([0.5725, 0.7725, 0.8706]))
            +
            (xt < 0) * (xt <= -0.5) * ((1 - (-xt - 0.5) * 2) * torch.tensor([0.5725, 0.7725, 0.8706]) + (-xt - 0.5) * 2 * torch.tensor([0.0196, 0.4431, 0.6902]))
        ).permute(2, 0, 1)
        
        return color_tensor
    
    

    【讨论】:

    • 你的矢量化版本比我的代码好很多。谢谢你帮助我!
    猜你喜欢
    • 1970-01-01
    • 2011-06-14
    • 2021-07-09
    • 2021-11-06
    • 2016-08-10
    • 2015-03-16
    • 2021-02-17
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多