【发布时间】:2019-12-23 21:17:59
【问题描述】:
在尝试训练我的模型时,我在 tensorflow 中遇到输入形状错误。我已经检查了输入形状是否匹配,但我仍然收到错误,任何帮助将不胜感激。 X 的形状是 (1, 7),y 的形状是 (1, 2)。 我有这个代码:
import tensorflow as tf
import keras
import numpy as np
import json
with open("situations.json") as f:
data = json.load(f)
X = np.array([i[0] for i in data])
y = np.array([i[1] for i in data])
print(X)
print(y)
model = keras.Sequential([
keras.layers.InputLayer(input_shape=(7,)),
keras.layers.Dense(7),
keras.layers.Dense(2)
])
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(X, y, epochs=100)
文件“situations.json”中有这个,我尝试使用更多数据(但它被删除了):
[
[[60, 60, -1, -1, -1, -5, 0], [0, 0]]
]
我收到此错误:
Traceback (most recent call last):
File "CarSim.py", line 20, in <module>
model.fit(X, y, epochs=100)
File "\lib\site-packages\keras\engine\training.py", line 1239, in fit
validation_freq=validation_freq)
File "\lib\site-packages\keras\engine\training_arrays.py", line 196, in fit_loop
outs = fit_function(ins_batch)
File "\lib\site-packages\tensorflow_core\python\keras\backend.py", line 3740, in __call__
outputs = self._graph_fn(*converted_inputs)
File "\lib\site-packages\tensorflow_core\python\eager\function.py", line 1081, in __call__
return self._call_impl(args, kwargs)
File "\lib\site-packages\tensorflow_core\python\eager\function.py", line 1121, in _call_impl
return self._call_flat(args, self.captured_inputs, cancellation_manager)
File "\lib\site-packages\tensorflow_core\python\eager\function.py", line 1224, in _call_flat
ctx, args, cancellation_manager=cancellation_manager)
File "\lib\site-packages\tensorflow_core\python\eager\function.py", line 511, in call
ctx=ctx)
File "\lib\site-packages\tensorflow_core\python\eager\execute.py", line 67, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InternalError: Blas GEMM launch failed : a.shape=(1, 7), b.shape=(7, 7), m=1, n=7, k=7
[[node dense_1/MatMul (defined at \lib\site-packages\tensorflow_core\python\framework\ops.py:1751) ]] [Op:__inference_keras_scratch_graph_691]
Function call stack:
keras_scratch_graph
我尝试在situations.json 文件中使用更多数据,尝试不同的损失函数和神经网络架构,但总是遇到某种错误,这就是其中之一。我知道这与输入形状有关,但我无法修复它。任何帮助将不胜感激。
【问题讨论】:
标签: python numpy tensorflow machine-learning keras