【发布时间】:2019-07-30 15:33:13
【问题描述】:
我正在编写一个简单的 CNN 来对 mnist 数字进行分类,这很简单,但是该模型过拟合非常快,而且幅度很大
我实现了对抗过拟合的技术，例如 dropout、batch norm、数据增强，但这个简单的模型始终没有改进
import tensorflow as tf
import tensorflow
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
from PIL import Image
class ConvBlock(tf.keras.layers.Layer):
    """Convolutional block: `conv_deep` stacked Conv2D -> MaxPool -> BatchNorm -> Dropout stages.

    Args:
        conv_deep: number of Conv/Pool/BN/Dropout stages in this block.
        kernels: number of filters for every Conv2D layer.
        kernel_size: spatial size of each convolution kernel.
        pool_size: pooling window for MaxPool2D.
        dropout_rate: fraction of units dropped during training.
    """

    def __init__(self, conv_deep=1, kernels=32, kernel_size=3, pool_size=2, dropout_rate=0.4):
        # BUG FIX: the original called super().__init__(self), which passes the
        # layer instance as the first positional argument (`trainable`) of
        # tf.keras.layers.Layer.__init__. Call it with no extra arguments.
        super(ConvBlock, self).__init__()
        self.conv_layers = []
        self.pooling_layers = []
        self.bnorm_layers = []
        self.dropout_layers = []
        for _ in range(conv_deep):
            self.conv_layers.append(
                tf.keras.layers.Conv2D(filters=kernels, kernel_size=kernel_size,
                                       padding="same", activation="relu"))
            self.pooling_layers.append(tf.keras.layers.MaxPool2D(pool_size=pool_size))
            self.bnorm_layers.append(tf.keras.layers.BatchNormalization())
            self.dropout_layers.append(tf.keras.layers.Dropout(dropout_rate))

    def call(self, inputs, training=False):
        """Apply every Conv -> Pool -> BatchNorm -> Dropout stage in order."""
        output = inputs
        for conv, pooling, bnorm, dropout in zip(self.conv_layers, self.pooling_layers,
                                                 self.bnorm_layers, self.dropout_layers):
            output = conv(output)
            output = pooling(output)
            # BUG FIX: BatchNormalization must receive the training flag so it
            # uses batch statistics during training and moving averages at
            # inference; the original never forwarded it.
            output = bnorm(output, training=training)
            # Dropout handles the training flag itself; gating it with a Python
            # `if training:` breaks when `training` is a symbolic tensor.
            output = dropout(output, training=training)
        return output
class DigitsClassifier(tf.keras.Model):
    """MNIST digit classifier: two conv blocks -> dense head -> 10-way softmax."""

    def __init__(self):
        # BUG FIX: super().__init__(self) passed the model instance as a
        # positional argument to tf.keras.Model.__init__; call it bare.
        super(DigitsClassifier, self).__init__()
        self.conv_input = ConvBlock(conv_deep=2, kernels=32)
        self.conv_hiden = ConvBlock(conv_deep=1, kernels=16)
        self.flatten = tf.keras.layers.Flatten()
        self.hiden = tf.keras.layers.Dense(50, "relu")
        self.bnorm = tf.keras.layers.BatchNormalization()
        self.softmax = tf.keras.layers.Dense(10, "softmax")

    def call(self, inputs, training=False):
        # BUG FIX: `call` must accept and forward `training`. In the original,
        # the BatchNormalization/Dropout sub-layers never saw the training flag,
        # so they ran with inference behavior during fit() — which defeats the
        # regularization the author added against overfitting.
        output = self.conv_input(inputs, training=training)
        output = self.conv_hiden(output, training=training)
        output = self.flatten(output)
        output = self.hiden(output)
        output = self.bnorm(output, training=training)
        output = self.softmax(output)
        return output
# ---- Load data ----
# MNIST train/eval splits via Keras, plus the Kaggle test set from CSV.
(train_digits, train_labels), (eval_digits, eval_labels) = tf.keras.datasets.mnist.load_data("./Resources")
kaggle_digits = pd.read_csv("./Resources/test.csv").values
# ---- Preprocess ----
# Add a trailing channel axis (N, 28, 28, 1) and scale pixels into [0, 1].
train_digits = train_digits.reshape(-1, 28, 28, 1) / 255.0
eval_digits = eval_digits.reshape(-1, 28, 28, 1) / 255.0
kaggle_digits = kaggle_digits.reshape(-1, 28, 28, 1) / 255.0
#Generator
def get_sample(digits, return_labels=False, labels=None):
    """Yield samples one at a time, for use with tf.data.Dataset.from_generator.

    Args:
        digits: array of images; the first axis indexes samples.
        return_labels: if truthy, yield (digit, label) pairs instead of digits.
        labels: array of labels aligned with ``digits``; required when
            ``return_labels`` is truthy.

    Yields:
        A single digit array, or a (digit, label) tuple.

    Raises:
        ValueError: if ``digits`` and ``labels`` have different sample counts.
    """
    if return_labels:
        # Guard clause first; FIX: original message had typos
        # ("dont ... numberof samples").
        if np.shape(digits)[0] != np.shape(labels)[0]:
            raise ValueError("Digits and Labels don't have the same number of samples")
        for digit, label in zip(digits, labels):
            yield (digit, label)
    else:
        for digit in digits:
            yield digit
def transform_sample(digit, label):
    """Randomly rotate a digit by a multiple of 90 degrees (data augmentation).

    BUG FIX: the original used Python's random.randint inside a tf.data map
    function. Python-level randomness executes once when the map function is
    traced into a graph, so every element received the SAME rotation instead of
    a fresh random one per sample. Using a TF op makes the randomness part of
    the graph and per-element.
    """
    # Uniform over {0, 1, 2, 3}. The original randint(-1, 2) also covered all
    # four rotations, since rot90 interprets k modulo 4 (-1 == 3).
    rot = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    t_digit = tf.compat.v2.image.rot90(digit, rot)
    return t_digit, label
#Define datasets
# Training pipeline: generator -> random 90-degree rotation via map (second
# positional arg = 100 parallel calls) -> batches of 1000 -> prefetch 2 batches.
train_ds = tf.data.Dataset.from_generator(get_sample, (tf.float32, tf.int32), args=[train_digits, True, train_labels]).map(transform_sample, 100).batch(1000).prefetch(2)
# Evaluation pipeline: same batching/prefetching, no augmentation.
eval_ds = tf.data.Dataset.from_generator(get_sample, (tf.float32, tf.int32), args=[eval_digits, True, eval_labels]).batch(1000).prefetch(2)
# Kaggle test pipeline: images only, no labels.
kaggle_ds = tf.data.Dataset.from_generator(get_sample, (tf.float32), args=[kaggle_digits]).batch(1000).prefetch(2)
# Sanity check: peek at the labels of one training batch and plot it.
for digits, label in train_ds.take(1):
    print(label)
    # NOTE(review): regplot draws a regression scatter plot; for a batch of
    # 28x28x1 images, plt.imshow(digits[0, ..., 0]) is probably what was
    # intended — confirm.
    sns.regplot(data=digits)
    plt.show()
#Define model and load weights (Pretrained on google colab notebook)
model = DigitsClassifier()
# Adadelta with learning rate 7.0; sparse categorical cross-entropy matches
# the integer (non-one-hot) labels yielded by the dataset.
model.compile(tf.keras.optimizers.Adadelta(7.0), tf.keras.losses.SparseCategoricalCrossentropy())
model.fit(train_ds, epochs=50, verbose=2, validation_data=eval_ds)
此时我真的不知道该尝试什么,我会降低模型复杂度,但我认为这不会有帮助
PS：反直觉的是，停止使用数据增强后模型反而有所改进。我的数据增强很简单：映射函数 transform_sample 对每个图像随机执行 90 度旋转，或者根本不旋转
【问题讨论】:
-
嗨@EiPapi,尝试复制您的代码,但出现错误,“ValueError:应定义输入的通道维度。找到
None。”在“model.fit”行中。另外,您能否确认“资源”文件夹中存在哪些数据。我已经用下面的代码替换了那行代码: (train_digits, train_labels), (eval_digits, eval_labels) = tf.keras.datasets.mnist.load_data()
标签: tensorflow keras conv-neural-network mnist