【发布时间】:2021-03-19 03:24:13
【问题描述】:
读取 tfrecord、添加字段并将它们写回的快速方法是什么?下面是一个执行此操作的程序,给定目录*.tfrecords。说这样做
example = tf.train.Example()
example.ParseFromString(rec_str.numpy())
output_string = example.SerializeToString()
writer.write(output_string)
对于输入中的每条记录都是一个时间单位,然后相加
features = dict(example.features.feature.items())
features["pred"] = tf.train.Feature(float_list=tf.train.FloatList(value=prediction_output))
example = tf.train.Example(features=tf.train.Features(feature=features))
在序列化和写回之前就像 3 个时间单位。
它似乎增加了很多 - 我的 tfrecord 中有大约 50 个标量字段,我只是再添加一个,有人知道更有效的方法吗?下面是一个程序。您可以尝试使用 gzip 压缩,这似乎对我的文件没有太大影响。
import os
import time
import glob
import logging
import numpy as np
import tensorflow as tf
def write_tfrecords_with_ap(input_dir, output_dir, write_zip, add_val, num_files):
t0_all = time.time()
tm_write = 0.0
tm_feat_dict = 0.0
os.makedirs(output_dir, exist_ok=True)
logging.info(f"Writing new tfrecords to {output_dir}")
input_files = glob.glob(os.path.join(input_dir, "*.tfrecords"))[0:num_files]
assert len(input_files)
prediction_output = np.empty(shape=(1,), dtype=np.float32)
num_records = 0
for input_tfrec_file in input_files:
print(f"input {input_tfrec_file}")
t0_write = time.time()
input_dset = tf.data.TFRecordDataset(input_tfrec_file, compression_type="GZIP",)
basename = os.path.basename(input_tfrec_file)
output_tfrec_file = os.path.join(output_dir, basename)
options = None
if write_zip:
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(output_tfrec_file, options=options) as writer:
for rec_str in input_dset:
num_records += 1
example = tf.train.Example()
example.ParseFromString(rec_str.numpy())
if add_val:
t0_feat_dict = time.time()
features = dict(example.features.feature.items())
tm_feat_dict += time.time() - t0_feat_dict
prediction_output[0] = 3.0
features["pred"] = tf.train.Feature(float_list=tf.train.FloatList(value=prediction_output))
example = tf.train.Example(features=tf.train.Features(feature=features))
output_string = example.SerializeToString()
writer.write(output_string)
print(f"out {output_tfrec_file}")
tm_write += time.time() - t0_write
tm_all = time.time() - t0_all
print(
f"Wrote {num_files} tfrecords files with {num_records} records. minutes: total={round(tm_all/60.0, 2)} "
f"writing={round(tm_write/60.0, 2)} "
f"feat_dict={round(tm_feat_dict/60.0, 2)}"
)
【问题讨论】:
标签: tensorflow tensorflow-datasets tfrecord