【问题标题】:How to select specific columns from tensorflow dataset?如何从 tensorflow 数据集中选择特定列?
【发布时间】:2026-02-09 02:30:01
【问题描述】:

我正在使用 tf.data.Dataset 预处理的 CSV 文件中的数据训练我的 Tensorflow 模型。但是,我希望模型分成三个分支,对应于一组不同的 csv 列,并且 model.fit 需要为每个输出提供一个单独的数据集。 CSV 文件的所有列都需要经过相同的预处理,因此最有效的准备方法是加载整个文件,对其进行处理,然后将数据集拆分为三个部分。但是,我正在努力寻找一种方法。

我希望 dataset.map 允许我使用以下操作选择一些列:

dset = dset.map(lambda x: x[[1, 2, 3, 7]])

但似乎 tensorflow 将其解释为 x[1][2][3][7]

我发现创建单独数据集的唯一可行方法是从头开始:

y = []
for cls, keys in output_classes.items():
    tmp = tf.data.experimental.CsvDataset(data_path, [tf.int32 for i in keys], select_cols=keys)
    [...]
    y.append(tmp)
y = tf.data.Dataset.zip(tuple(y))

不幸的是,它会产生大量不必要的开销并极大地减慢训练速度。

有没有办法通过特征子集拆分 tf.data.Dataset 对象?

【问题讨论】:

  • 我也有同样的问题。你找到解决办法了吗?

标签: python tensorflow tensorflow-datasets


【解决方案1】:

试试tf.gather:

tf.gather(tf.constant([1,2,3,4]), [1,2,3])
# ouputs : array([2, 3, 4])

如果您有高维数据,请使用tf.gather_nd

【讨论】:

    【解决方案2】:

    通过使用 .map() 修改 tornikeo 的答案,此解决方案对我有用。

    dataset = tf.data.Dataset.from_tensor_slices([[1,2,3,4], 
                                                  [5,6,7,8]])
    dataset_filter = dataset.map(lambda x: tf.gather(x, [0, 2], axis=0))
    result = list(dataset_filter.as_numpy_iterator())
    print(result)
    
    # Outputs array([1, 3], dtype=int32), array([5, 7])
    

    【讨论】: