【问题标题】:TensorFlow io.decode_csv and select data along one dimensionTensorFlow io.decode_csv 并沿一维选择数据
【发布时间】:2020-10-12 19:11:38
【问题描述】:

我在一组文件中有一组逗号分隔的记录,我在这些文件上调用 TensorFlow API 函数 tf.io.decode_csv()。记录如下所示:

tf.Tensor(b'249,EMR,2019-09-13,65.55,65.58,66.2099,65.16', shape=(), dtype=string)

我为该类型的记录使用默认对象:

defaults = [tf.constant([0])] + [tf.constant([], dtype=tf.string)] + [tf.constant([], dtype=tf.string)] + [tf.constant([0.0])]*4

运行 decode_csv() 函数:

ds = SP500fileNamesShuffle.map(lambda fn : tf.io.decode_csv(fn, defaults))

我得到了一个类型的数据集

<DatasetV1Adapter shapes: ((), (), (), (), (), (), ()), types: (tf.int32, tf.string, tf.string, tf.float32, tf.float32, tf.float32, tf.float32)>

每条记录有 7 种类型,因此是 7 个元素的元组。我不知道如何迭代特定元素,例如第二个元组上的元素。我会很感激你的帮助。 我试过了:

for e in ds.take(10):
    print(e[1])

我收到以下错误消息:

{{function_node __inference_Dataset_map_<lambda>_6530}} Expect 7 fields but have 1 in record 0
     [[{{node DecodeCSV}}]] [Op:IteratorGetNextSync]

【问题讨论】:

    标签: python tensorflow-datasets


    【解决方案1】:

    只是关闭这个话题,因为解决方案很简单:我没有正确指定“默认”记录。在这种特殊情况下,它应该是:

    defaults = [tf.constant([0])] + [tf.constant([''], dtype=tf.string)] + [tf.constant([''], dtype=tf.string)] + [tf.constant([0.0])]*4
    

    然后解码工作。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2013-01-08
      • 1970-01-01
      • 2019-03-01
      • 1970-01-01
      • 1970-01-01
      • 2016-06-30
      • 1970-01-01
      • 2019-08-24
      相关资源
      最近更新 更多