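【Posted】: 2022-01-09 09:58:09
【Question】:
I am trying to recognize handwritten digits with a simple architecture. Evaluating on the MNIST test set gives an accuracy of 0.9723.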
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten
from sklearn.model_selection import train_test_split

# load the predefined MNIST train/test split
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# scale pixel values from [0, 255] to [0, 1]
x_train = x_train / 255
x_test = x_test / 255

# one-hot encode the labels
y_train_cat = keras.utils.to_categorical(y_train, 10)
y_test_cat = keras.utils.to_categorical(y_test, 10)

# creating model
model = keras.Sequential([
    Flatten(input_shape=(28, 28)),  # the arrays are 28x28 with no channel axis
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# hold out 20% of the training data for validation
x_train_split, x_val_split, y_train_split, y_val_split = train_test_split(
    x_train, y_train_cat, test_size=0.2)

model.fit(
    x_train_split,
    y_train_split,
    batch_size=32,
    epochs=6,
    validation_data=(x_val_split, y_val_split))

# saving model
model.save('mnist_model.h5')

# test
model.evaluate(x_test, y_test_cat)
However, when I try to recognize my own digits (0 to 9), some of them are not recognized correctly (image: numbers and prediction above).
This is the code I used:
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import mnist
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

model = load_model('mnist_model.h5')

# sanity check: the reloaded model should reproduce the test accuracy
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test = x_test / 255
y_test_cat = keras.utils.to_categorical(y_test, 10)
model.evaluate(x_test, y_test_cat)

filenames = [
    'project_imgs/0.png', 'project_imgs/1.png', 'project_imgs/2.png', 'project_imgs/3.png',
    'project_imgs/4.png', 'project_imgs/5.png', 'project_imgs/6.png', 'project_imgs/7.png',
    'project_imgs/8.png', 'project_imgs/9.png'
]

data = []      # 28x28 arrays for plotting
data_eds = []  # same arrays with a leading batch axis, for model.predict
for file in filenames:
    picture = Image.open(file).convert('L')  # grayscale
    pic_r = picture.resize((28, 28))
    pic = np.array(pic_r)
    pic = 255 - pic  # invert: MNIST digits are bright on a dark background
    pic = pic / 255
    pic_eds = np.expand_dims(pic, axis=0)
    data.append(pic)
    data_eds.append(pic_eds)

plt.figure(figsize=(10, 5))
for i in range(10):
    ax = plt.subplot(2, 5, i + 1)
    ax.set_title(f'Looks like {np.argmax(model.predict(data_eds[i]))}')
    plt.xticks([])
    plt.yticks([])
    plt.imshow(data[i], cmap=plt.cm.binary)
plt.show()
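I don't understand why this happens. Could it be the images themselves? I have noticed that the MNIST images are much darker, not gray like mine. Or is it the size of the digit relative to the 28x28 square?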
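Both guesses can be checked by making the custom images look more like MNIST before predicting. The MNIST digits were size-normalized to fit a 20x20 box and then centered by center of mass in the 28x28 frame. The following is only a rough sketch of that idea, not the exact MNIST pipeline: the function name `to_mnist_style`, the `threshold` value of 0.2, and the assumption that the input is a dark digit on a light background are all illustrative choices.

```python
import numpy as np
from PIL import Image


def to_mnist_style(picture, box=20, size=28, threshold=0.2):
    """Hypothetical helper: turn a dark-on-light digit image into an
    MNIST-like (size x size) float array with values in [0, 1]."""
    # Grayscale, then invert so the strokes are bright on a dark background.
    arr = 1.0 - np.asarray(picture.convert('L'), dtype=np.float32) / 255.0

    # Stretch contrast: background -> 0, brightest stroke pixel -> 1.
    arr = arr - arr.min()
    if arr.max() > 0:
        arr = arr / arr.max()
    arr[arr < threshold] = 0.0  # assumed cutoff for faint background noise

    # Crop to the bounding box of the remaining ink.
    rows = np.flatnonzero(arr.any(axis=1))
    cols = np.flatnonzero(arr.any(axis=0))
    if rows.size == 0:
        return np.zeros((size, size), dtype=np.float32)  # blank image
    arr = arr[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

    # Resize so the longer side is `box` pixels, keeping the aspect ratio.
    h, w = arr.shape
    scale = box / max(h, w)
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    digit = Image.fromarray((arr * 255).astype(np.uint8)).resize((new_w, new_h))
    digit = np.asarray(digit, dtype=np.float32) / 255.0

    # Place the digit so its center of mass is near the frame center.
    total = float(digit.sum())
    ys, xs = np.indices(digit.shape)
    cy = float((ys * digit).sum()) / total if total > 0 else (new_h - 1) / 2
    cx = float((xs * digit).sum()) / total if total > 0 else (new_w - 1) / 2
    center = (size - 1) / 2
    top = min(max(int(round(center - cy)), 0), size - new_h)
    left = min(max(int(round(center - cx)), 0), size - new_w)

    out = np.zeros((size, size), dtype=np.float32)
    out[top:top + new_h, left:left + new_w] = digit
    return out
```

In the loading loop, the resize, invert and scale steps could then be replaced by `pic = to_mnist_style(picture)`. If the predictions improve, the contrast and the placement of the digit were at least part of the problem.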
【Discussion】:
Tags: python tensorflow keras mnist