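【Posted】: 2021-07-11 00:15:11
【Question】:
I am running Python code in Google Colab that trains a handwriting (digit) recognition model. The code loads the MNIST dataset of handwritten-digit images and trains a small convolutional network on it. Code: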
import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras import utils as np_utils
from tensorflow.keras import backend as K
K.set_image_data_format('channels_first')
from matplotlib import pyplot as plt
#load data from mnist
(xTrain,yTrain),(xTest,yTest)=mnist.load_data()
#reshape the 28*28 images to (samples, channels, height, width) = (N, 1, 28, 28), i.e. channels-first
xTrain=xTrain.reshape(xTrain.shape[0],1,28,28).astype('float32')
xTest=xTest.reshape(xTest.shape[0],1,28,28).astype('float32')
#normalize inputs from 0-255 to 0-1
xTrain=xTrain/255
xTest=xTest/255
#one hot encode outputs
yTrain=np_utils.to_categorical(yTrain)
yTest=np_utils.to_categorical(yTest)
num_classes=yTest.shape[1]
def baseline_model():
    #create model
    model=Sequential()
    model.add(Conv2D(32,(5,5), input_shape=(1,28,28),activation='relu'))
    model.add(MaxPooling2D(pool_size=(2,2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128,activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    #compile model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
model=baseline_model()
model.fit(xTrain, yTrain, validation_data=(xTest, yTest), epochs=1, batch_size=200, verbose=2)
The problem is that the last line (the model.fit call) raises an error. Error:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-111-ec22b5dcc4e3> in <module>()
63
64 model=baseline_model()
---> 65 model.fit(xTrain, yTrain, validation_data=(xTest, yTest), epochs=1, batch_size=200, verbose=2)
6 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Default MaxPoolingOp only supports NHWC on device type CPU
[[node sequential_36/max_pooling2d_21/MaxPool (defined at <ipython-input-111-ec22b5dcc4e3>:65) ]] [Op:__inference_train_function_10696]
Function call stack:
train_function
Any help would be appreciated. Thanks.
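The last line of the traceback says that the default CPU implementation of MaxPool only accepts NHWC (channels-last) input, while this code calls K.set_image_data_format('channels_first') and feeds arrays shaped (N, 1, 28, 28). A plausible fix, not verified on this exact Colab setup, is to switch to channels-last: drop the set_image_data_format('channels_first') call (channels_last is the Keras default), reshape the images to (N, 28, 28, 1) and pass input_shape=(28, 28, 1) to Conv2D. Running on a Colab GPU runtime, where NCHW pooling is generally supported through cuDNN, is another option. The sketch below covers only the data-layout part in NumPy; it uses random stand-in arrays (an assumption) instead of the real MNIST download so that it runs offline.

```python
import numpy as np

# Stand-in for the MNIST images: mnist.load_data() returns uint8 arrays of
# shape (num_samples, 28, 28). Random data is used here so the sketch runs offline.
rng = np.random.default_rng(0)
x_raw = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Layout used in the question: channels_first / NCHW -> (N, 1, 28, 28)
x_nchw = x_raw.reshape(x_raw.shape[0], 1, 28, 28).astype("float32") / 255

# Channels-last / NHWC layout accepted by the default CPU MaxPool kernel -> (N, 28, 28, 1)
x_nhwc = x_raw.reshape(x_raw.shape[0], 28, 28, 1).astype("float32") / 255

# Converting an existing NCHW array: move the channel axis (axis 1) to the end.
x_converted = np.transpose(x_nchw, (0, 2, 3, 1))

print(x_nchw.shape)                          # (5, 1, 28, 28)
print(x_nhwc.shape)                          # (5, 28, 28, 1)
print(np.array_equal(x_converted, x_nhwc))   # True
```

For single-channel images such as MNIST the reshape and the transpose give the same array; with more than one channel only the transpose keeps each pixel paired with the right channel.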
Tags: tensorflow google-colaboratory