【问题标题】:Keras: Using use_multiprocessing=True in predict_generator gives more predictions than required?Keras:在 predict_generator 中使用 use_multiprocessing=True 会提供比所需更多的预测?
【发布时间】:2018-03-12 01:41:44
【问题描述】:

由于内存问题,我正在研究语言建模问题并使用 predict_generator 函数。我面临的问题是 predict_generator 提供的预测比输入的大小更多。

我在 predict_generator 函数中提供的参数:

predictions = model.predict_generator(testDataGenerator(statements),
                                                  use_multiprocessing=True,workers=4,
                                                  steps=25,
                                                  verbose=1)

生成器函数:

def testDataGenerator(testDataFrame):
        testDataFrame.reset_index(drop=True, inplace=True)
        startPoint = 0
        endPoint = 64
        while True:
            statementSet = testDataFrame[startPoint:endPoint]
            test = buildTrainAndTestSets(statementSet)
            startPoint = endPoint
            endPoint += 64
            yield test

我总共有 1568 个输入,我将它们分批发送 64 个,但我得到了 1600 个预测。错误输出为:

25/25 [==============================] - 47s 2s/step
IndexError: Length of values does not match length of index

我认为我在生成器函数中发送语句的方式在这里有问题。

【问题讨论】:

  • 你有完整的代码吗?这有点像在黑暗中打,没有看到问题可能发生在哪里。

标签: python nlp deep-learning keras


【解决方案1】:

如果您使用自定义生成器,您必须谨慎使用预测器的最后一步。

由于您以 64 批大小执行 25 步,因此生成器希望您的数据正好是 1600,我认为在您的生成器中更改端点的简单 if 应该可以解决您的问题。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2022-07-04
    • 2021-08-25
    • 2019-12-22
    • 2019-10-09
    • 1970-01-01
    • 2018-09-11
    • 2019-02-15
    • 1970-01-01
    相关资源
    最近更新 更多