【发布时间】:2022-01-01 20:00:40
【问题描述】:
我有一个自定义的并行训练函数,它采用成对的训练和测试数据,并为不同的数据构建不同的模型。问题是数组似乎无法存储以下类型的数据。如何创建可以保存以下类型数据的列表。
for i in range(0,5):
def create_dataset():
...
...
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_data = test_data.batch(batch_size).repeat()
return train_data,test_data
td[i],vd[i] = create_dataset()
model = create_model() # create the model
datasets = [(td[0],vd[0]),(td[1],vd[1]),(td[2],vd[3]),(td[3],vd[3]),(td[4],vd[4])]
parallel_trainer(model, datasets)
parallel trainer的参数是这样定义的,
def parallel_trainer(model, XY_train_datasets : list[tuple])
这样定义我的“数据集”会返回错误,
TypeError: 'type' object is not subscriptable
如何创建我的训练数据和测试数据的列表以便解决此错误。 解决方案可能很明显,但我对此相当陌生。
提前致谢。
【问题讨论】:
标签: python dataframe tensorflow