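【Posted】: 2021-05-19 14:46:14
【Question】:
I am implementing a ConvNet to predict round wins from a game's minimap, but I am having trouble training the network.
When I run the following code, I get this error: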
ValueError: Input 0 is incompatible with layer model: expected shape=(None, 312, 312, 3), found shape=(312, 312, 3)
Code:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras.backend import relu
import numpy as np
#input layers
minimapInput = keras.Input(shape = (312, 312, 3), name="minimap_image")
#layers for mini input
mini = keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu')(minimapInput)
mini = keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu')(mini)
mini = keras.layers.AveragePooling2D(pool_size=(2,2), strides=2)(mini)
mini = keras.layers.Flatten()(mini)
mini = keras.layers.Dense(250, activation='relu')(mini)
mini = keras.layers.Dense(200, activation='relu')(mini)
mini = keras.layers.Dense(100, activation='relu')(mini)
#output
ctRoundWin = keras.layers.Dense(1, activation=keras.activations.softmax)(mini)
#model
model = keras.Model(inputs=minimapInput, outputs=ctRoundWin)
#creating/reading train data
def readFrames(file, frames, width, height, depth=3):
    output = np.zeros((frames, height, width, depth))
    for frame in range(frames):
        for i in range(depth):
            for j in range(height):
                for k in range(width):
                    try:
                        output[frame, j, k, i] = ord(file.read(1)[0])
                    except IndexError:
                        output[frame, j, k, i] = 0
    return output
y_pred = np.zeros(167)
frames = 167
minifile = open("C:\\Users\\s-wel\\OneDrive\\Desktop\\CSAI\\test\\mini.txt", "r")
mini = readFrames(minifile, frames, 312, 312)
minifile.close()
#creating tf Dataset
data = tf.data.Dataset.from_tensor_slices((mini, y_pred))
#running it
model.compile(optimizer = keras.optimizers.Adam(learning_rate=0.001), loss = keras.losses.categorical_crossentropy, metrics = keras.metrics.binary_accuracy)
model.fit(data, epochs=1)
None of the other solutions I have found really address this problem.
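The error message says the model received a single frame of shape (312, 312, 3) where it expected a batch of frames. A likely cause, sketched below on the assumption that nothing else reshapes the data, is that `tf.data.Dataset.from_tensor_slices` yields one unbatched sample at a time and `model.fit` does not add a batch dimension when it is given a dataset. The placeholder arrays and the batch size of 2 are illustrative and do not come from the question.

```python
import numpy as np
import tensorflow as tf

# Placeholder data standing in for the question's `mini` and `y_pred`
# (4 frames instead of 167 to keep the sketch light).
frames = np.zeros((4, 312, 312, 3), dtype=np.float32)
labels = np.zeros(4, dtype=np.float32)

unbatched = tf.data.Dataset.from_tensor_slices((frames, labels))
# Each element is a single frame: there is no leading batch dimension.
unbatched_shape = unbatched.element_spec[0].shape.as_list()
print(unbatched_shape)

# batch() groups elements and adds the leading dimension the model expects.
batched = unbatched.batch(2)  # batch size chosen arbitrarily for illustration
batched_shape = batched.element_spec[0].shape.as_list()
print(batched_shape)
```

Applied to the code in the question, this would mean passing something like `data.batch(2)` to `model.fit` instead of `data`. Separately, a softmax over a single output unit always evaluates to 1.0, so a sigmoid output with a binary cross-entropy loss may be closer to what a win/lose prediction needs.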
【Comments】:
-
Can you try ''shape = (,312, 312, 3)'' in keras.Input?
-
No, that gives me a syntax error. And if I enter (None, 312, 312, 3), it tells me it expected ndim = 4 but I gave it ndim = 5: (None, None, 312, 312, 3).
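The ndim = 5 message in the comment above is consistent with how `keras.Input` treats its `shape` argument: it describes one sample, and Keras prepends the batch dimension on its own. The following minimal sketch only inspects the resulting shapes; the variable names are made up for illustration.

```python
from tensorflow import keras

# shape describes ONE sample; Keras adds the batch dimension itself.
per_sample = keras.Input(shape=(312, 312, 3))
per_sample_shape = tuple(per_sample.shape)
print(per_sample_shape)

# Writing None inside shape adds a second unknown dimension on top of that,
# which makes the symbolic input 5-dimensional.
with_none = keras.Input(shape=(None, 312, 312, 3))
with_none_shape = tuple(with_none.shape)
print(len(with_none_shape))
```

So the input layer in the question already expects (None, 312, 312, 3), which suggests the missing dimension is on the data side and not in the model definition.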
Tags: python tensorflow tf.keras