【发布时间】:2018-03-12 01:41:44
【问题描述】:
由于内存问题,我正在研究语言建模问题并使用 predict_generator 函数。我面临的问题是 predict_generator 提供的预测比输入的大小更多。
我在 predict_generator 函数中提供的参数:
predictions = model.predict_generator(testDataGenerator(statements),
use_multiprocessing=True,workers=4,
steps=25,
verbose=1)
生成器函数:
def testDataGenerator(testDataFrame):
testDataFrame.reset_index(drop=True, inplace=True)
startPoint = 0
endPoint = 64
while True:
statementSet = testDataFrame[startPoint:endPoint]
test = buildTrainAndTestSets(statementSet)
startPoint = endPoint
endPoint += 64
yield test
我总共有 1568 个输入,我将它们分批发送 64 个,但我得到了 1600 个预测。错误输出为:
25/25 [==============================] - 47s 2s/step
IndexError: Length of values does not match length of index
我认为我在生成器函数中发送语句的方式在这里有问题。
【问题讨论】:
-
你有完整的代码吗?这有点像在黑暗中打,没有看到问题可能发生在哪里。
标签: python nlp deep-learning keras