【发布时间】:2020-10-12 19:11:38
【问题描述】:
我在一组文件中有一组逗号分隔的记录,我在这些文件上调用 TensorFlow API 函数 tf.io.decode_csv()。记录如下所示:
tf.Tensor(b'249,EMR,2019-09-13,65.55,65.58,66.2099,65.16', shape=(), dtype=string)
我为该类型的记录使用默认对象:
defaults = [tf.constant([0])] + [tf.constant([], dtype=tf.string)] + [tf.constant([], dtype=tf.string)] + [tf.constant([0.0])]*4
运行 decode_csv() 函数:
ds = SP500fileNamesShuffle.map(lambda fn : tf.io.decode_csv(fn, defaults))
我得到了一个类型的数据集
<DatasetV1Adapter shapes: ((), (), (), (), (), (), ()), types: (tf.int32, tf.string, tf.string, tf.float32, tf.float32, tf.float32, tf.float32)>
每条记录有 7 种类型,因此是 7 个元素的元组。我不知道如何迭代特定元素,例如第二个元组上的元素。我会很感激你的帮助。 我试过了:
for e in ds.take(10):
print(e[1])
我收到以下错误消息:
{{function_node __inference_Dataset_map_<lambda>_6530}} Expect 7 fields but have 1 in record 0
[[{{node DecodeCSV}}]] [Op:IteratorGetNextSync]
【问题讨论】:
标签: python tensorflow-datasets