【问题标题】:How to iterate over PyTorch tensor如何迭代 PyTorch 张量
【发布时间】:2021-05-19 01:59:19
【问题描述】:

我有一个大小为 (1000,110) 的张量数据,我想遍历张量的第一个索引并计算以下内容。

    data = torch.randn(size=(1000,110)).to(device)
    
    male_poor = torch.tensor(0).float().to(device)
    male_rich = torch.tensor(0).float().to(device)
    
    female_poor = torch.tensor(0).float().to(device)
    female_rich = torch.tensor(0).float().to(device)
    
    for i in data:
    
        if torch.argmax(i[64:66]) == 0 and torch.argmax(i[108:110]) == 0:
          female_poor += 1
        if torch.argmax(i[64:66]) == 0 and torch.argmax(i[108:110]) == 1:
          female_rich += 1
        if torch.argmax(i[64:66]) == 1 and torch.argmax(i[108:110]) == 0:
          male_poor += 1
        if torch.argmax(i[64:66]) == 1 and torch.argmax(i[108:110]) == 1:
          male_rich += 1


    disparity = ((female_rich/(female_rich + female_poor))) / ((male_rich/(male_rich + male_poor)))

有比 for 循环更快的方法吗?

【问题讨论】:

    标签: pytorch iteration tensor


    【解决方案1】:

    pytorch(以及 numpy)中的关键是矢量化,也就是说,如果您可以通过对矩阵进行操作来删除循环,它会快得多。与底层编译的 C 代码中的循环相比,python 中的循环非常慢。在我的机器上,您的代码的执行时间约为 0.091 秒,以下矢量化代码约为 0.002 秒,因此大约快 x50:

    import torch
    torch.manual_seed(0)
    device = torch.device('cpu')
    
    data = torch.randn(size=(1000, 110)).to(device)
    
    import time
    t = time.time()
    #vectorize over first dimension
    argmax64_0 = torch.argmax(data[:, 64:66], dim=1) == 0
    argmax64_1 = torch.argmax(data[:, 64:66], dim=1) == 1
    argmax108_0 = torch.argmax(data[:, 108:110], dim=1) == 0
    argmax108_1 = torch.argmax(data[:, 108:110], dim=1) == 1
    female_poor = (argmax64_0 & argmax108_0).sum()
    female_rich = (argmax64_0 & argmax108_1).sum()
    male_poor = (argmax64_1 & argmax108_0).sum()
    male_rich = (argmax64_1 & argmax108_1).sum()
    
    disparity = ((female_rich / (female_rich + female_poor))) / ((male_rich / (male_rich + male_poor)))
    
    print(time.time()-t)
    print(disparity)
    

    【讨论】:

    • 我在 argmax 文档中没有看到 dim 参数,感谢您的帮助!
    猜你喜欢
    • 2020-10-28
    • 1970-01-01
    • 2019-08-24
    • 2021-03-05
    • 2018-10-09
    • 2021-12-25
    • 2021-11-04
    • 2021-02-01
    • 2021-02-07
    相关资源
    最近更新 更多