【发布时间】:2019-05-23 18:40:49
【问题描述】:
在TFRecords 上训练 TensorFlow 模型时,我需要有效地加入少量数据。如何使用已解析的 TFRecord 中的信息进行此查找?
更多细节:
我正在使用 TFRecords 在大型数据集上训练卷积网络。每个TFRecord 都包含原始图像以及目标标签,以及有关图像的一些元数据。训练的一部分是我需要使用特定于一组图像的mean 和std 来标准化图像。为此,我将mean 和std 硬编码到TFRecord 中。然后在我的parse_example 中使用它,用于映射我的input_fn 中的Dataset,如下所示:
def parse_example(..):
# ...
parsed = tf.parse_single_example(value, keys_to_features)
image_raw = tf.decode_raw(parsed['image/raw'], tf.uint16)
image = tf.reshape(image_raw, image_shape)
image.set_shape(image_shape)
# pull hardcoded pixels mean and std from the parsed TFExample
mean = parsed['mean']
std = parsed['std']
image = (tf.cast(image, tf.float32) - mean) / std
# ...
return image, label
虽然上述方法有效并且可以缩短训练时间,但它的局限性在于我经常想更改我使用的 mean 和 std。与其将mean 和std 写入TFRecords,我更愿意在训练时查找适当的汇总统计信息。这意味着当我训练时,我有一个小的 python 字典,我可以使用从TFRecord 解析的图像信息来查找适当的摘要统计信息。我遇到的问题是我似乎无法在我的张量流图中使用这个 python 字典。如果我尝试直接进行查找,它不起作用,因为我有张量对象而不是实际的原语。这是有道理的,因为 input_fn 正在为 TensorFlow 进行符号操作构建计算图(对吗?)。我该如何解决这个问题?
我尝试过的一件事是从字典中创建一个查找表,如下所示:
def create_channel_hashtable(keys, values, default_val=-1):
initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
return tf.contrib.lookup.HashTable(initializer, default_val)
可以在parse_example 函数中创建和使用哈希表来进行查找。这一切都“有效”,但它极大地减慢了训练速度。值得注意的是,这种培训是在 TPU 上进行的。使用来自TFRecords 的值的原始方法,训练速度非常快,并且不受 IO 的限制,但是当使用哈希查找时,这种情况会发生变化。处理这些情况的建议方法是什么?虽然重新打包TFRecords 是可行的,但如果要查找的数据很小并且可以提高效率,这似乎很愚蠢。
【问题讨论】:
-
input_fn实际上是在 CPU 上运行的。您的训练 (model_fn) 是在 TPU 上完成的。你key的词汇量很大吗?我用 5000 的词汇量解决了与你类似的问题。而且我没有 IO 瓶颈问题。您能否分享您的input_fn的整个代码,以便重现问题? -
@greeness 我的理解是它在 CPU 上运行。这是否意味着我应该能够使用 python 字典?还是我需要使用查找表?要回答您的问题,词汇查找的大小很小,只有 204。当您过去完成它时,您是否使用过
HashTable或者它是否适用于 python 字典? -
我使用的是同一个哈希表。
index_table = tf.contrib.lookup.HashTable( tf.contrib.lookup.KeyValueTensorInitializer( tf.convert_to_tensor(families, dtype=tf.string), tf.convert_to_tensor(family_indices, dtype=tf.int64)), 0) . family_indices = tf.map_fn( index_table.lookup, features['family'], dtype=tf.int64) -
感谢@greeness,很高兴知道这是一种可行的方法。我一直在调试这个,我认为实际的问题是我的 input_fn 和/或 tfrecords 的大小中 IO 性能的不相关回归。如果事实证明是这样,我将关闭这个问题。再次感谢您的回答。
标签: python tensorflow tfrecord tpu