【问题标题】:ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)ValueError:检查输入时出错:预期dense_1_input的形状为(8,)但得到的数组形状为(1,)
【发布时间】:2020-10-25 06:20:45
【问题描述】:

我正在尝试训练一个神经网络来反弹一个球,但我在预测球的运动时遇到了问题,得到了错误 ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)
我的代码:

from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import csv
import pygame

running=True

def main():
    # load training data
    data_path = 'x.dat'
    with open(data_path, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        headers = next(reader)
        x_train = np.array(list(reader)).astype(float)
    data_path = 'y.dat'
    with open(data_path, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        headers = next(reader)
        y_train = np.array(list(reader)).astype(float)
    
    # debug print statement
    print(x_train)
    
    # define the keras model
    model = Sequential()
    model.add(Dense(8, input_shape=(8,), activation='relu'))
    model.add(Dense(10, activation='relu'))
    model.add(Dense(4, activation='relu'))

    # compile the keras model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    # fit the keras model on the dataset
    model.fit(x_train, y_train, epochs=1, batch_size=1000)

    # pygame initialization to visualize the ball
    global running, screen

    pygame.init()
    screen = pygame.display.set_mode((200, 200))
    pygame.display.set_caption("BallPhysics")
    screen.fill((255,255,255))
    pygame.display.update()

    # ball info
    x = 25
    y = 10
    xVel = -2
    yVel = 0
    gravity = 0.1
    elasticity = 0.9999
    radius = 10
    friction = 0.999

    while running:
        screen.fill((255,255,255))
        ev = pygame.event.get()

        # draw ball
        pygame.draw.circle(screen, (255, 0, 0), (x,y), radius)
        pygame.display.update()

        # make input array
        inp = [x/200,y/200,((xVel/10)+1)/2,((yVel/10)+1)/2,gravity,elasticity,radius/50.0,friction]
        print(inp)
        out = model.predict(inp)

        # set ball position and velocity to NN output
        x = out[0][0]
        y = out[0][1]
        xVel = out[0][2]
        yVel = out[0][3]
        
        # event handling
        for event in ev:
            if event.type == pygame.QUIT:
                running = False
                pygame.display.quit()
                pygame.quit()
        
            pygame.display.flip()

main()

调试打印语句打印出来

[[0.025568 0.131659 0.755605 ... 0.414219 0.094692 0.678865]
...
[0.08742  0.08742  0.5      ... 0.250432 0.699359 0.179118]]

这有点令人困惑,因为我注意到一件事是打印出来的数组没有逗号,而我制作的数组确实有逗号。这可能与它有关,但我不知道是什么。 任何帮助表示赞赏。

堆栈跟踪:

Traceback (most recent call last):
  File "/Users/grimtin10/Documents/Python Projects/BallPhysics/BallPhysics.py", line 78, in <module>
    main()
  File "/Users/grimtin10/Documents/Python Projects/BallPhysics/BallPhysics.py", line 63, in main
    out = model.predict(inp)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 1149, in predict
    x, _, _ = self._standardize_user_data(x)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
    exception_prefix='input')
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_utils.py", line 138, in standardize_input_data
    str(data_shape))
ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)

【问题讨论】:

  • 我不知道问题出在哪里,但不是逗号。 Numpy 打印出不带逗号的数组,可能是为了节省空间。我觉得这很烦人。
  • 啊好的,很高兴知道这不是问题
  • 其实是要打印报表的。第一个打印语句是打印x_train,第二个是打印inp。后者只是 8 个元素的列表。似乎 model.predict 正在抱怨该输入。我不太了解这个包。这是传递给 model.predict 的正确参数吗?

标签: python tensorflow keras


【解决方案1】:

好吧,所以,我想通了。 事实证明,NN 期望 批量输入, 导致单个数组无法工作。我对 tf 和 keras 不是很熟悉,所以这就是问题所在。 使用np.reshape(inp,(1,8)) 修复它。

【讨论】:

    猜你喜欢
    • 2020-04-11
    • 2020-11-22
    • 2018-10-24
    • 2018-09-30
    • 2020-05-30
    • 2019-10-22
    • 1970-01-01
    • 1970-01-01
    • 2019-11-21
    相关资源
    最近更新 更多