下面是我从cs231n上整理的神经网络的入门实现,麻雀虽小,五脏俱全,基本上神经网络涉及到的知识点都有在代码中体现。
理论看上千万遍,不如看一遍源码跑一跑。
源码上我已经加了很多注释,结合代码看一遍很容易理解。
最后可视化权重的图:
主文件,用来训练调参
two_layer_net.py
# coding: utf-8

# Train a simple two-layer neural network and evaluate it on CIFAR-10.

import numpy as np
import matplotlib.pyplot as plt
from neural_net import TwoLayerNet
from data_utils import load_CIFAR10
from vis_utils import visualize_grid


def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000):
    """Load CIFAR-10, split into train/val/test, zero-center and flatten.

    Returns:
        X_train, y_train, X_val, y_val, X_test, y_test -- each X is a
        (num_samples, 3072) array of zero-centered pixel values and each
        y is the matching 1-D label vector.
    """
    cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample: the validation split is carved from the tail of the
    # training data, so the two splits never overlap.
    mask = list(range(num_training, num_training + num_validation))
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = list(range(num_training))
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = list(range(num_test))
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Zero-center the data. The mean image is computed on the training
    # split only, so no statistics leak from val/test into preprocessing.
    mean_image = np.mean(X_train, axis=0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image

    # Flatten each 32x32x3 image into a length-3072 row vector.
    X_train = X_train.reshape(num_training, -1)
    X_val = X_val.reshape(num_validation, -1)
    X_test = X_test.reshape(num_test, -1)

    return X_train, y_train, X_val, y_val, X_test, y_test


def plot_training_stats(stats):
    """Plot the loss curve and the train/val accuracy curves from the
    stats dict returned by TwoLayerNet.train()."""
    plt.subplot(2, 1, 1)
    plt.plot(stats['loss_history'])
    plt.title('Loss history')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')

    plt.subplot(2, 1, 2)
    plt.plot(stats['train_acc_history'], label='train')
    plt.plot(stats['val_acc_history'], label='val')
    plt.title('Classification accuracy history')
    plt.xlabel('Epoch')
    plt.ylabel('Classification accuracy')
    plt.show()


def show_net_weights(net):
    """Visualize the first-layer weights W1 as a grid of 32x32x3 tiles,
    one image-shaped tile per hidden unit."""
    W1 = net.params['W1']
    # (3072, hidden) -> (hidden, 32, 32, 3) so each row is an image.
    W1 = W1.reshape(32, 32, 3, -1).transpose(3, 0, 1, 2)
    plt.imshow(visualize_grid(W1, padding=3).astype('uint8'))
    plt.gca().axis('off')
    plt.show()


X_train, y_train, X_val, y_val, X_test, y_test = get_CIFAR10_data()
print('Train data shape: ', X_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', X_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)


# First training run: a small network with conservative hyperparameters.
input_size = 32 * 32 * 3
hidden_size = 50
num_classes = 10
net = TwoLayerNet(input_size, hidden_size, num_classes)
stats = net.train(X_train, y_train, X_val, y_val,
                  num_iters=1000, batch_size=200,
                  learning_rate=1e-4, learning_rate_decay=0.95,
                  reg=0.25, verbose=True)
val_acc = (net.predict(X_val) == y_val).mean()
print('Validation accuracy: ', val_acc)

# The result is not great -- debug by inspecting the curves and weights.
plot_training_stats(stats)
show_net_weights(net)

# Diagnosis from the curves: the loss is still decreasing roughly
# linearly, so training has not converged -- raise the learning rate
# and/or train for more iterations. Train and validation accuracy track
# each other closely, which suggests underfitting: a larger hidden layer
# (more model capacity) may help.


# Hyperparameter search. These were tuned by hand over many runs; the
# best network reaches roughly 55% accuracy on the test set.
# NOTE: these must be iterables -- the original code assigned bare
# scalars here, which makes the `for hs in hidden_sizes` loops below
# raise TypeError.
hidden_sizes = [150]                 # candidates tried: [50, 70, 100, 130]
learning_rates = [1e-3]              # candidates tried: np.array([0.5, 1, 1.5]) * 1e-3
regularization_strengths = [0.2]     # candidates tried: [0.1, 0.2, 0.3]
best_net = None
results = {}
best_val_acc = 0


for hs in hidden_sizes:
    for lr in learning_rates:
        for reg in regularization_strengths:

            net = TwoLayerNet(input_size, hs, num_classes)
            # Train the network
            stats = net.train(X_train, y_train, X_val, y_val,
                              num_iters=3000, batch_size=200,
                              learning_rate=lr, learning_rate_decay=0.95,
                              reg=reg, verbose=False)
            # Model selection is done on the validation split only.
            val_acc = (net.predict(X_val) == y_val).mean()
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_net = net
            results[(hs, lr, reg)] = val_acc

            plot_training_stats(stats)


for hs, lr, reg in sorted(results):
    val_acc = results[(hs, lr, reg)]
    print('hs %d lr %e reg %e val accuracy: %f' % (hs, lr, reg, val_acc))

print('best validation accuracy achieved during cross-validation: %f' % best_val_acc)


show_net_weights(best_net)
test_acc = (best_net.predict(X_test) == y_test).mean()
print('Test accuracy: ', test_acc)