【发布时间】:2021-01-15 00:09:30
【问题描述】:
我正在尝试解决分类问题。我不知道为什么会出现此错误:
AttributeError: 'str' object has no attribute 'keys'
这是主要代码:
def generate_arrays_for_training(indexPat, paths, start=0, end=100):
while True:
from_=int(len(paths)/100*start)
to_=int(len(paths)/100*end)
for i in range(from_, int(to_)):
f=paths[i]
x = np.load(PathSpectogramFolder+f)
if('P' in f):
y = np.repeat([[0,1]],x.shape[0], axis=0)
else:
y =np.repeat([[1,0]],x.shape[0], axis=0)
yield(x,y)
history=model.fit_generator(generate_arrays_for_training(indexPat, filesPath, end=75) ## problem here
steps_per_epoch=int((len(filesPath)-int(len(filesPath)/100*25))),
validation_steps=int((len(filesPath)-int(len(filesPath)/100*75))),
verbose=2,class_weight="balanced",
epochs=15, max_queue_size=2, shuffle=True, callbacks=[callback])
其中generate_arrays_for_training 函数返回x 和y。 x 是一个二维浮点数数组,y 是 [0,1]。
错误:
Traceback (most recent call last):
File "/home/user1/thesis2/CNN_dwt2.py", line 437, in <module>
main()
File "/home/user1/thesis2/CNN_dwt2.py", line 316, in main
history=model.fit_generator(generate_arrays_for_training(indexPat, filesPath, end=75),
File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1815, in fit_generator
return self.fit(
File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1049, in fit
data_handler = data_adapter.DataHandler(
File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1122, in __init__
dataset = dataset.map(_make_class_weight_map_fn(class_weight))
File "/home/user1/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1295, in _make_class_weight_map_fn
class_ids = list(sorted(class_weight.keys()))
AttributeError: 'str' object has no attribute 'keys'
【问题讨论】:
-
直接原因是
class_weight是一个字符串,而它应该是一个dict。但是通过keras代码跟踪该变量到您的函数调用可能需要一些挖掘。确保model.fit_generator的输入与类型中的文档相匹配。 -
model.fit_generator的参数之一是字符串。你需要纠正它。很遗憾,您的 minimal reproducible example 使用了我们无法访问的数据,我们很难为您找到问题。 -
@hpaulj 我跟踪了代码,发现从
generate_arrays_for_training返回的值不是字符串,x是一个浮点二维矩阵,y是 [[0,1]]。我不知道字符串值是什么 -
@wwii 不幸的是,我正在处理非常大的数据集,我无法提供示例。但是我已经跟踪到
model.fit_generator的输入,发现它转到了“generate_arrays_for_training”函数,它返回了x和y,所以返回值中的错误 -
@wwii 我不知道为什么会出现这个错误
标签: python numpy tensorflow keras deep-learning