【问题标题】:How to train LSTM with single label per "batch"如何用每个“批次”的单个标签训练 LSTM
【发布时间】:2019-05-11 23:13:33
【问题描述】:

我想训练有状态的 LSTM 模型进行时间序列预测。 最初我认为我应该写:

    for batch in range(len(features) - window_size):

        # get arrays for the batch
        fb = features[batch:batch+window_size,:]
        lb = labels[batch:batch+window_size,:]

        #reshape
        fb = fb.reshape(1, fb.shape[0], fb.shape[1])
        lb = lb.reshape(1, lb.shape[0], lb.shape[1])

        # train        
        model.train_on_batch(fb, lb) # .fit(fb, lb, epochs=1, batch_size=window_size)

以上内容应该从 10000 个样本中提取 100 个固定大小的窗口,并在每次迭代(0..99、1..100、2..101 等)时对每个移动进行一次训练。在这种情况下,x 和 y 的长度相同,均为 100。

这里的窗口和批次是一样的。但实际上,一个窗口/批次有一个标签。

将表示移动猫视频和窗口的数据视为快照/图片,它只能用一个标签分类,而不是 100 个标签。想象它是一只猫的照片。并且拥有 100 个标签意味着必须以某种方式标记每一行,但实际上这没有任何意义。例如,一张快照可以标记为猫移动的距离。所以窗口 1 标签是 0,窗口 2 - 1mm 等等。

我在描述我的模型时误解了批次定义吗?

在这种情况下正确的输入/输出形状/批量大小是多少?

编辑:引入视频是为了解释自己。实际上,数据集是天数,其中每天有 10000 个度量(正常世界中的样本),具有 7 个输入/特征和 8 个单热标签。当前 train_on_batch 的输入形状为 (1, 100, 32) [32 是 LSTM 神经元数]。

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    视频分类通常使用带有 3D 卷积核的卷积网络来完成。例如,看看谷歌和斯坦福研究员的this paper

    在您的情况下,您使用的是 LSTM,标记窗口的每一帧是没有意义的。您可以做的是只有一个输出并将其与您的标签进行比较(多对一架构)。换句话说,即使你的 LSTM 单元在每一步都产生输出,你也只考虑最后一步的输出来计算你的损失。

    for epoch in n_epoch: # number of batches to show to your LSTM
        # batch_features = numFrames x batchSize x numChannel x Width x Height
        # batch_labels   = batchSize x 1
        batch_features, batch_labels = getBatch() 
    
        # initialize cell state
        h = zeros()
        for frame in numFrames:
            # here the main loop of the LSTM. out will be constantly overwritten
            h, out = LSTM(h, batch_features[frame])
    
        # use only the final output to compute the loss
        loss = crossEntropywithLogits(out, batch_labels)
    

    【讨论】:

    • 谢谢。视频旨在帮助解释(见编辑)。我将尝试您的建议,但是该模型目前在 K80 GPU 上每天运行 1 小时的数据,而我有数百天。我希望通过不运行每个标签来缩短这些时间
    猜你喜欢
    • 1970-01-01
    • 2020-12-15
    • 1970-01-01
    • 2017-08-15
    • 2020-01-29
    • 1970-01-01
    • 2019-03-22
    • 2018-08-05
    • 2019-11-25
    相关资源
    最近更新 更多