【发布时间】:2020-10-19 18:39:44
【问题描述】:
在 tensorflow 中,我打算在预训练的 CNN 中调整超参数以用于图像分类任务。为此,我使用了像 vgg16 这样的预训练模型来提取特征,并使用提取的嵌入特征作为卷积神经网络 (CNN) 的输入。基本上,我将 CNN 放在预训练模型的顶部进行训练。我正在尝试使用GridSeatchCV 优化batch_size, epochs, drop-rate 之类的超参数,但出现以下类型错误:
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([200, 201, 202, 203,...
我也试过这样:
grid_search = grid_search.fit(np.array(df_train_tf), np.array(labels_tr_tf[1:1001]))
但现在我遇到以下错误:
ValueError:分类指标无法处理混合 多标签指标和多类目标
我在SO上查看了这个错误,但它无法摆脱上面的错误。如何解决这个问题?
在我的 CNN 中,我将 flatten dim 张量作为输入传递给 CNN,从预训练模型中提取的嵌入特征是 1 个 dim 特征向量,我将其转换为张量。当我尝试运行网格搜索以进行超参数优化时,出现上述类型错误。我试图理解为什么我有这样的错误。谁能指出我发生了什么事?谢谢
我的尝试:
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV
model = KerasClassifier(build_fn=myCNN)
parameters = {'dim': [256,512, 784,1024, 2048],
'epochs': [25,50,75,100,125,150,200],
'batch_size':[32,64,128,192, 256],
'drop_rate': [0.1,0.2,0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
'opt': ['rmsprop', 'adam', 'sgd'],
'actv_func': ['relu', 'tanh']}
grid_search = GridSearchCV(estimator=model,
param_grid=parameters,
scoring='accuracy',
cv=5)
grid_search = grid_search.fit(df_train_tf, labels_tr_tf[1:1001])
其中df_train_tf 是预训练嵌入特征的张量,labels_tr_tf 是 one-hot 编码标签的张量。这是df_train_tf, labels_tr_tf 的样子。
df_train_tf.shape: TensorShape([1000, 2048]) labels_tr_tf[1:1001].shape: TensorShape([1000, 100]) type(labels_tr_tf[1:1001]): tensorflow.python.framework.ops.EagerTensor type(df_train_tf): tensorflow.python.framework.ops.EagerTensor df_train_tf: <tf.Tensor: shape=(1000, 2048), dtype=float32, numpy= array([[ 2.3664525 , 6.4614077 , 22.128284 , ..., 2.8993628 , 7.6006427 , 4.022856 ], [ 2.8110769 , 0. , 21.861437 , ..., 2.8580594 , 3.8210764 , 3.4176886 ],...] labels_tr_tf[1:1001]: <tf.Tensor: shape=(1000, 100), dtype=float32, numpy= array([[0., 0., 0., ..., 0., 0., 0.], [1., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.],..]
我没有找到任何线索为什么我会收到此错误。谁能指出我如何做到这一点?任何解决上述类型错误的解决方案?任何想法?谢谢
【问题讨论】:
标签: python tensorflow error-handling conv-neural-network