【问题标题】:Torch - Query matrix with another matrixTorch - 用另一个矩阵查询矩阵
【发布时间】:2016-02-09 12:31:23
【问题描述】:

我有一个 m x n 张量(张量 1)和另一个 k x 2 张量(张量 2),我希望使用基于张量 2 的索引提取张量 1 的所有值。例如;

Tensor1
  1   2   3   4   5
  6   7   8   9  10
 11  12  13  14  15
 16  17  18  19  20
[torch.DoubleTensor of size 4x5]

Tensor2
 2  1
 3  5
 1  1
 4  3
[torch.DoubleTensor of size 4x2]

函数会产生;

6
15
1
18

【问题讨论】:

    标签: torch


    【解决方案1】:

    想到的第一个解决方案是简单地遍历索引并选择相应的值:

    function get_elems_simple(tensor, indices)
        local res = torch.Tensor(indices:size(1)):typeAs(tensor)
        local i = 0
        res:apply(
            function () 
                i = i + 1
                return tensor[indices[i]:clone():storage()] 
            end)
        return res
    end
    

    这里的tensor[indices[i]:clone():storage()] 只是一种从多维张量中选择元素的通用方法。在 k 维的情况下,这与tensor[{indices[i][1], ... , indices[i][k]}] 完全相同。

    如果您不必提取大量值,则此方法可以正常工作(瓶颈是 :apply 方法,它不能使用许多优化技术和 SIMD 指令,因为它执行的函数是一个黑盒)。这项工作可以更有效地完成::index 方法完全可以满足您的需求……使用一维张量。多维目标/索引张量需要展平:

    function flatten_indices(sp_indices, shape)
        sp_indices = sp_indices - 1
        local n_elem, n_dim = sp_indices:size(1), sp_indices:size(2)
        local flat_ind = torch.LongTensor(n_elem):fill(1)
    
        local mult = 1
        for d = n_dim, 1, -1 do
            flat_ind:add(sp_indices[{{}, d}] * mult)
            mult = mult * shape[d]
        end
        return flat_ind
    end
    
    function get_elems_efficient(tensor, sp_indices)
        local flat_indices = flatten_indices(sp_indices, tensor:size()) 
        local flat_tensor = tensor:view(-1)
        return flat_tensor:index(1, flat_indices)
    end
    

    差别很大:

    n = 500000
    k = 100
    a = torch.rand(n, k)
    ind = torch.LongTensor(n, 2)
    ind[{{}, 1}]:random(1, n)
    ind[{{}, 2}]:random(1, k)
    
    elems1 = get_elems_simple(a, ind)      # 4.53 sec
    elems2 = get_elems_efficient(a, ind)   # 0.05 sec
    
    print(torch.all(elems1:eq(elems2)))    # true
    

    【讨论】:

    • 非常感谢!我这样做是为了对图像进行 2D/3D 插值,最初使用了类似的 for 循环结构,但正如你所说,它有很大的不同!
    • 双张量可以正常工作,但您知道为什么这不适用于 cuda 张量吗?
    • @mattdns,“无法正常工作”是什么意思?我无法访问 CUDA,因此无法重现您的问题。我只能猜测。 Cpu 张量以某种方式存储在内存中 - 逐行。这定义了如何展平索引。 Cuda 张量可能有自己的结构。如果是真的,你需要不同的flatten_indices 函数。要确定我是否正确,请比较扁平的 cuda 和 cpu 张量:torch.all(float_tens:view(-1):eq(cuda_tens:view(-1)))
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-06-16
    相关资源
    最近更新 更多