【发布时间】:2020-08-04 10:28:18
【问题描述】:
我在 STFT 数据和离散小波变换数据上创建 CNN 模型。我想在 python 的 2 个输入数据上获得我的深度学习模型的权重和偏差的数量。该怎么做??
任何帮助将不胜感激。
代码:
def createModel():
with tf.device("cpu"):
input_shape=(1, 22, 5, 3844)
model = Sequential()
model.add(Conv3D(16, (22, 5, 5), strides=(1, 2, 2), padding='same',activation='relu',data_format= "channels_first", input_shape=input_shape))
model.add(keras.layers.MaxPooling3D(pool_size=(1, 2, 2),data_format= "channels_first", padding='same'))
model.add(BatchNormalization())
model.add(Conv3D(32, (1, 3, 3), strides=(1, 1,1), padding='same',data_format= "channels_first", activation='relu'))#incertezza se togliere padding
model.add(keras.layers.MaxPooling3D(pool_size=(1,2, 2),data_format= "channels_first", ))
model.add(BatchNormalization())
model.add(Conv3D(64, (1,3, 3), strides=(1, 1,1), padding='same',data_format= "channels_first", activation='relu'))#incertezza se togliere padding
model.add(keras.layers.MaxPooling3D(pool_size=(1,2, 2),data_format= "channels_first",padding='same' ))
model.add(BatchNormalization())
model.add(Dense(64, input_dim=64,kernel_regularizer=regularizers.l2(0.0001), activity_regularizer=regularizers.l1(0.0001)))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(256, activation='sigmoid'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))
opt_adam = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='categorical_crossentropy', optimizer=opt_adam, metrics=['accuracy'])
return model
【问题讨论】:
-
您能否添加更多解释。我无法正确理解这个问题。你想获取哪些参数?
-
@AbdullahDeliogullari 我的意思是在训练阶段学习的连接权重。
-
那么您想在训练期间获得 CNN 网络每一层的权重吗?我理解对了吗?
-
@AbdullahDeliogullari 我使用这 2 行,我得到每个时代的权重,但我不知道如何打开文件?
filepath_1="weights.{epoch:02d}-{val_loss:.2f}.hdf5" call=keras.callbacks.callbacks.ModelCheckpoint(filepath_1, monitor='val_loss', verbose=1, save_best_only=False, save_weights_only=True, mode='auto', period=1)和如何知道我的模型的权重和偏差的数量?? -
好的,所以你有 hd5 文件,无法打开。此外,您正在寻找一种方法来了解此文件的格式。我说的对吗?
标签: python keras deep-learning conv-neural-network hyperparameters