【问题标题】:How to do multi-task learning in torch7?如何在torch7中进行多任务学习?
【发布时间】:2018-05-26 03:02:20
【问题描述】:

Simple multi-task network can be done here. 但我想要类似enter image description here 的东西。 现在我构建模型如下:

model = nn.Sequential()
model:add(nn.Linear(3,5))
prl1 = nn.ConcatTable()
prl1:add(nn.Linear(5,1))
prl2 = nn.ConcatTable()
prl2:add(nn.Linear(5,1))
prl2:add(nn.Linear(5,1))
prl1:add(prl2)
model:add(prl1)

我的输出是:

input = torch.rand(5,3)
output = model:forward(input)
output
{
  1 : DoubleTensor - size: 5x1
  2 : 
    {
      1 : DoubleTensor - size: 5x1
      2 : DoubleTensor - size: 5x1
    }
}

我应该如何构建我的标准?

【问题讨论】:

    标签: lua torch


    【解决方案1】:

    我好像是通过两步搞定的:

    1.在上述网络中使用 nn.Concat 代替 nn.ConcatTable,这使得输出成为一个简单的 NxM 张量,例如当使用 nn.Concat 而不是 nn.ConcatTable 时,一个 5x3 的张量将进入上述网络。

    2.得到一个NxM张量后,我使用nn.ConcatTable、nn.Concat和nn.Select的组合使输出成为包含每个结果张量的简单表格。

    这是第 2 步的简单示例:

    model = nn.Sequential()
    model:add(nn.Linear(3,5))
    
    prl = nn.ConcatTable()
    
    spl1 = nn.Concat(2)
    
    seq1 = nn.Sequential()
    seq1:add(nn.Select(2, 1))
    seq1:add(nn.Reshape(1))
    
    seq2 = nn.Sequential()
    seq2:add(nn.Select(2, 2))
    seq2:add(nn.Reshape(1))
    
    seq3 = nn.Sequential()
    seq3:add(nn.Select(2, 3))
    seq3:add(nn.Reshape(1))
    
    spl1:add(seq1)
    spl1:add(seq2)
    spl1:add(seq3)
    prl:add(spl1)
    
    spl2 = nn.Concat(2)
    
    seq4 = nn.Sequential()
    seq4:add(nn.Select(2, 4))
    seq4:add(nn.Reshape(1))
    
    seq5 = nn.Sequential()
    seq5:add(nn.Select(2, 5))
    seq5:add(nn.Reshape(1))
    
    spl2:add(seq4)
    spl2:add(seq5)
    prl:add(spl2)
    
    model:add(prl)
    
    input = torch.rand(5,3)
    output = model:forward(input)
    

    输出将如下所示:

    th> output
    {
      1 : DoubleTensor - size: 5x3
      2 : DoubleTensor - size: 5x2
    }
    

    【讨论】:

      猜你喜欢
      • 2021-12-17
      • 1970-01-01
      • 2012-01-31
      • 1970-01-01
      • 2020-03-15
      • 2019-04-20
      • 1970-01-01
      • 2019-10-11
      • 2020-09-19
      相关资源
      最近更新 更多