【发布时间】:2020-10-25 06:20:45
【问题描述】:
我正在尝试训练一个神经网络来反弹一个球,但我在预测球的运动时遇到了问题,得到了错误 ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)
我的代码:
from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import csv
import pygame
running=True
def main():
# load training data
data_path = 'x.dat'
with open(data_path, 'r') as f:
reader = csv.reader(f, delimiter=',')
headers = next(reader)
x_train = np.array(list(reader)).astype(float)
data_path = 'y.dat'
with open(data_path, 'r') as f:
reader = csv.reader(f, delimiter=',')
headers = next(reader)
y_train = np.array(list(reader)).astype(float)
# debug print statement
print(x_train)
# define the keras model
model = Sequential()
model.add(Dense(8, input_shape=(8,), activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(4, activation='relu'))
# compile the keras model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# fit the keras model on the dataset
model.fit(x_train, y_train, epochs=1, batch_size=1000)
# pygame initialization to visualize the ball
global running, screen
pygame.init()
screen = pygame.display.set_mode((200, 200))
pygame.display.set_caption("BallPhysics")
screen.fill((255,255,255))
pygame.display.update()
# ball info
x = 25
y = 10
xVel = -2
yVel = 0
gravity = 0.1
elasticity = 0.9999
radius = 10
friction = 0.999
while running:
screen.fill((255,255,255))
ev = pygame.event.get()
# draw ball
pygame.draw.circle(screen, (255, 0, 0), (x,y), radius)
pygame.display.update()
# make input array
inp = [x/200,y/200,((xVel/10)+1)/2,((yVel/10)+1)/2,gravity,elasticity,radius/50.0,friction]
print(inp)
out = model.predict(inp)
# set ball position and velocity to NN output
x = out[0][0]
y = out[0][1]
xVel = out[0][2]
yVel = out[0][3]
# event handling
for event in ev:
if event.type == pygame.QUIT:
running = False
pygame.display.quit()
pygame.quit()
pygame.display.flip()
main()
调试打印语句打印出来
[[0.025568 0.131659 0.755605 ... 0.414219 0.094692 0.678865]
...
[0.08742 0.08742 0.5 ... 0.250432 0.699359 0.179118]]
这有点令人困惑,因为我注意到一件事是打印出来的数组没有逗号,而我制作的数组确实有逗号。这可能与它有关,但我不知道是什么。 任何帮助表示赞赏。
堆栈跟踪:
Traceback (most recent call last):
File "/Users/grimtin10/Documents/Python Projects/BallPhysics/BallPhysics.py", line 78, in <module>
main()
File "/Users/grimtin10/Documents/Python Projects/BallPhysics/BallPhysics.py", line 63, in main
out = model.predict(inp)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 1149, in predict
x, _, _ = self._standardize_user_data(x)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
exception_prefix='input')
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_utils.py", line 138, in standardize_input_data
str(data_shape))
ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)
【问题讨论】:
-
我不知道问题出在哪里,但不是逗号。 Numpy 打印出不带逗号的数组,可能是为了节省空间。我觉得这很烦人。
-
啊好的,很高兴知道这不是问题
-
其实是要打印报表的。第一个打印语句是打印
x_train,第二个是打印inp。后者只是 8 个元素的列表。似乎 model.predict 正在抱怨该输入。我不太了解这个包。这是传递给 model.predict 的正确参数吗?
标签: python tensorflow keras