【问题标题】:how made cross-validation with python?如何使用 python 进行交叉验证?
【发布时间】:2022-01-25 03:36:22
【问题描述】:

您好,我制作了一个神经网络,我需要进行交叉验证。 我不知道它是怎么做到的,特别是如何训练或做到的。

如果有人知道,请写信或给我一些指示。

这是我的代码:

###Division Train / Test
X = df.drop('Peso secado',axis=1)  #Variables de entrada, menos la variable de salida
y = df['Peso secado']              #Variable de salida

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3,random_state=101)

###

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
X_train= scaler.fit_transform(X_train)
X_train
X_test = scaler.transform(X_test)
X_test



###Creacion del modelo###
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

model = Sequential()
num_neuronas = 50
model.add(tf.keras.layers.Dense(units=6, activation='sigmoid', input_shape=(6, )))
model.add(Dense(num_neuronas,activation='relu'))
model.add(tf.keras.layers.Dense(units=1, activation='linear')) 

#Buscar mejor funcion de activacion para capa de salida sigmoid? o linear?
model.summary()
model.compile(optimizer='adam',loss='mse')

###Entrenamiento###
model.fit(x = X_train, y = y_train.values,
          validation_data=(X_test,y_test.values), batch_size=10, epochs=1000) 

losses = pd.DataFrame(model.history.history)  
losses
losses.plot()
   
###Evaluacion###
from sklearn.metrics import mean_squared_error,mean_absolute_error,explained_variance_score,mean_absolute_percentage_error
X_test
predictions = model.predict(X_test)
mean_absolute_error(y_test,predictions)
mean_absolute_percentage_error(y_test,predictions)

mean_squared_error(y_test,predictions)
explained_variance_score(y_test,predictions)  

mean_absolute_error(y_test,predictions)/df['Peso secado'].mean() 
mean_absolute_error(y_test,predictions)/df['Peso secado'].median()

一些培训或验证建议会有所帮助

【问题讨论】:

  • 你知道什么是交叉验证吗?你说的“我不知道是怎么做到的”到底是什么意思?你在交叉验证的哪一部分苦苦挣扎?您编写了哪些代码来实现交叉验证?你到底卡在哪里了?
  • 你可以看看一些教程:scikit-learn.org/stable/modules/cross_validation.htmlmachinelearningmastery.com/…。您进行交叉验证以分析模型在(人为)“不同”数据集上的性能。所以你首先建立你的模型,看看它是否有意义(也许尝试训练一次,看看它需要多少时间,也许就像你在这段代码中所做的那样在测试数据集上评估它),然后运行交叉验证以如果您要在不同的数据集上拟合模型,可以看看模型如何“平均”执行。
  • 感谢 ForceBru
  • Stack Overflow 不是一个论坛,在这里您应该做一些基础研究来尝试解决您的问题,特别是如果大多数 ML 库中的教程都涵盖了这一点。
  • 对于编程知识最少的人来说,它很容易研究、理解和应用。对于其余的,这是一个巨大的头痛。这就是为什么我们寻找这样或其他的论坛,寻找“专​​家”来帮助我们。不要批评我们不知道如何提问。 PD:就像下面回答的那个人一样。

标签: python tensorflow keras neural-network cross-validation


【解决方案1】:

我的第一个观察是代码非常丑陋且非结构化。您应该在代码的顶部导入模块

为了执行交叉验证,首先从 sklearn 导入模块(以及您需要的所有其他模块)

from sklearn.model_selection import StratifiedKFold

我会将模型定义放在一个单独的函数中:

def get_model():
  model = Sequential()
  model.add(Dense(4, input_dim=8, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))
  model.compile(loss='binary_crossentropy', optimizer='adam')
  return model

定义你的变量,如果你正在使用 tensorflow / Keras,请执行以下操作:

BATCH_SIZE = 64  #  128
EPOCHS = 100

k = 10
# Use stratified k-fold if the data is imbalanced
kf = StratifiedKFold(n_splits=k, shuffle=False, random_state=None)

# here comes the Cross validation
fold_index = 1
for train_index, test_index in kf.split(X, y):
            X_train = X[train_index]
            y_train = y[train_index]

            X_test = X[test_index]
            y_test = y[test_index]

            # fit the model on the training set
            model = get_model()

            model.fit(
                X_train,
                y_train,
                batch_size=BATCH_SIZE,
                epochs=EPOCHS,
                verbose=0,
                validation_data=(X_test, y_test),
            )

            # predict values
            # pred_values = model.predict(X_test)
            pred_values_prob = np.array(model(X_test))

注意:使用 tensorflow 时,您需要每次在循环中定义一个新模型。这不是 sklearn 的情况,因为 sklearn 在调用时会以新的初始化权重开始。在这里您需要单独执行此操作。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2021-09-30
    • 2020-11-30
    • 2017-06-29
    • 2015-06-11
    • 2020-07-13
    • 2019-09-06
    • 2017-04-21
    • 1970-01-01
    相关资源
    最近更新 更多