【发布时间】:2021-01-16 12:56:31
【问题描述】:
我正在尝试使用 keras 来拟合 CNN 模型来分类 2 类数据。我有不平衡的数据集我想平衡数据。我不知道我可以在 model.fit_generator 中使用 class_weight 。我想知道我是否在model.fit_generator中使用了class_weight="balanced"
主要代码:
def generate_arrays_for_training(indexPat, paths, start=0, end=100):
while True:
from_=int(len(paths)/100*start)
to_=int(len(paths)/100*end)
for i in range(from_, int(to_)):
f=paths[i]
x = np.load(PathSpectogramFolder+f)
x = np.expand_dims(x, axis=0)
if('P' in f):
y = np.repeat([[0,1]],x.shape[0], axis=0)
else:
y =np.repeat([[1,0]],x.shape[0], axis=0)
yield(x,y)
history=model.fit_generator(generate_arrays_for_training(indexPat, filesPath, end=75),
validation_data=generate_arrays_for_training(indexPat, filesPath, start=75),
steps_per_epoch=int((len(filesPath)-int(len(filesPath)/100*25))),
validation_steps=int((len(filesPath)-int(len(filesPath)/100*75))),
verbose=2,
epochs=15, max_queue_size=2, shuffle=True, callbacks=[callback])
【问题讨论】:
-
你可以像implementation一样使用class_weight。
-
@HweiGeokNg 我希望数据同样平衡。我该怎么做??
-
查看此博客:androidkt.com/set-class-weight-for-imbalance-dataset-in-keras。有一个名为 compute_class_weight() 的函数可以用作 class_weight 的参数。
-
@HweiGeokNg 我知道这个函数,但我的数据集中没有 x_train 和 y_train 我使用
generate_arrays_for_training函数。请检查代码我将这个函数。 -
抱歉,我错过了这些信息。我帮不了你,希望其他人能来救援。
标签: python machine-learning keras deep-learning generator