【问题标题】:How to improve neural network through dropout layers?如何通过 dropout 层改进神经网络?
【发布时间】:2020-08-25 11:51:37
【问题描述】:

我正在研究预测心脏病的神经网络。数据来自 kaggle 并经过预处理。我使用了各种模型,例如逻辑回归、随机森林和 SVM,它们都产生了可靠的结果。我正在尝试将相同的数据用于神经网络,以查看 NN 是否可以胜过其他 ML 模型(数据集相当小,这可能解释了结果不佳的原因)。下面是我的网络代码。下面的模型产生 50% 的准确率,显然,这太低而无用。据您所知,是否有任何看起来会破坏模型准确性的事情?

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.layers import Dense, Dropout
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping

df = pd.read_csv(r"C:\Users\***\Desktop\heart.csv")

X = df[['age','sex','cp','trestbps','chol','fbs','restecg','thalach']].values
y = df['target'].values

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30)

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()

scaler.fit_transform(X_train)
scaler.transform(X_test)


nn = tf.keras.Sequential()

nn.add(Dense(30, activation='relu'))

nn.add(Dropout(0.2))

nn.add(Dense(15, activation='relu'))

nn.add(Dropout(0.2))


nn.add(Dense(1, activation='sigmoid'))


nn.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics= 
 ['accuracy'])


early_stop = EarlyStopping(monitor='val_loss',mode='min', verbose=1, 
patience=25)

nn.fit(X_train, y_train, epochs = 1000, validation_data=(X_test, y_test),
     callbacks=[early_stop])

model_loss = pd.DataFrame(nn.history.history)
model_loss.plot()

predictions = nn.predict_classes(X_test)

from sklearn.metrics import classification_report,confusion_matrix

print(classification_report(y_test,predictions))
print(confusion_matrix(y_test,predictions))

【问题讨论】:

    标签: python tensorflow keras neural-network dropout


    【解决方案1】:

    缩放器未到位;您需要保存缩放后的结果。

    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    

    然后您将获得更符合您预期的结果。

                  precision    recall  f1-score   support
    
               0       0.93      0.98      0.95       144
               1       0.98      0.93      0.96       164
    
        accuracy                           0.95       308
       macro avg       0.95      0.96      0.95       308
    weighted avg       0.96      0.95      0.95       308
    

    【讨论】:

      【解决方案2】:

      使用 EarlyStopping 运行模型后,

      Epoch 324/1000
      23/23 [==============================] - 0s 3ms/step - loss: 0.5051 - accuracy: 0.7364 - val_loss: 0.4402 - val_accuracy: 0.8182
      Epoch 325/1000
      23/23 [==============================] - 0s 3ms/step - loss: 0.4716 - accuracy: 0.7643 - val_loss: 0.4366 - val_accuracy: 0.7922
      Epoch 00325: early stopping
      WARNING:tensorflow:From <ipython-input-54-2ee8517852a8>:54: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
      Instructions for updating:
      Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).
                    precision    recall  f1-score   support
      
                 0       0.90      0.66      0.76       154
                 1       0.73      0.93      0.82       154
      
          accuracy                           0.79       308
         macro avg       0.82      0.79      0.79       308
      weighted avg       0.82      0.79      0.79       308
      
      

      它暗示了这样一个简单的 MLP 的合理准确度和 f1 分数。

      我使用了这个数据集:https://www.kaggle.com/abdulhakimrony/heartcsv/data

      1. 对所有 epoch 进行训练,初始精度可能较低,但模型会在几个 epoch 后很快收敛。

      2. 在 random、tensorflow 和 numpy 中使用seed,每次都能获得可重现的结果。

      3. 如果简单模型显示出良好的准确性,则 NN 很有可能会表现出色,但您必须确保 NN 没有过度拟合。

      4. 检查您的数据是否不平衡,如果是,请尝试使用class_weights

      5. 您可以尝试 tuner 进行交叉验证以获得最佳性能模型。

      【讨论】:

        猜你喜欢
        • 2016-11-02
        • 2023-02-07
        • 2011-03-28
        • 2022-01-14
        • 2016-09-02
        • 2019-09-01
        • 2017-02-25
        • 2012-09-22
        • 2018-11-05
        相关资源
        最近更新 更多