【问题标题】:Using string labels in Tensorflow在 TensorFlow 中使用字符串标签
【发布时间】:2016-03-09 03:11:03
【问题描述】:

我仍在尝试使用自己的图像数据运行 Tensorflow。 我能够从这个示例link 使用 conevert_to() 函数创建一个 .tfrecords 文件@

现在我想使用来自该示例 link 的代码来训练网络。

但它在 read_and_decode() 函数中失败。我对该功能的更改是:

label = tf.decode_raw(features['label'], tf.string) 

错误是:

TypeError: DataType string for attr 'out_type' not in list of allowed values: float32, float64, int32, uint8, int16, int8, int64

那么如何 1) 阅读和 2) 在 tensorflow 中使用字符串标签进行训练。

【问题讨论】:

    标签: python labels tensorflow


    【解决方案1】:

    convert_to_records.py 脚本创建一个.tfrecords 文件,其中每条记录都是一个Example 协议缓冲区。该协议缓冲区支持使用 bytes_list kind 的字符串功能。

    tf.decode_raw op 用于将二进制字符串解析为图像数据;它不是为解析字符串(文本)标签而设计的。假设 features['label'] 是一个 tf.string 张量,您可以使用 tf.string_to_number 操作将其转换为数字。您的 TensorFlow 程序中对字符串处理的其他支持有限,因此如果您需要执行一些更复杂的函数来将字符串标签转换为整数,您应该在 Python 的 convert_to_tensor.py 修改版中执行此转换。

    【讨论】:

    【解决方案2】:

    要添加到@mrry 的答案,假设您的字符串是ascii,您可以:

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def write_proto(cls, filepath, ..., item_id): # itemid is an ascii encodable string
        # ...
        with tf.python_io.TFRecordWriter(filepath) as writer:
            example = tf.train.Example(features=tf.train.Features(feature={
                 # write it as a bytes array, supposing your string is `ascii`
                'item_id': _bytes_feature(bytes(item_id, encoding='ascii')), # python 3
                # ...
            }))
            writer.write(example.SerializeToString())
    

    然后:

    def parse_single_example(cls, example_proto, graph=None):
        features_dict = tf.parse_single_example(example_proto,
            features={'item_id': tf.FixedLenFeature([], tf.string),
            # ...
            })
        # decode as uint8 aka bytes
        instance.item_id = tf.decode_raw(features_dict['item_id'], tf.uint8)
    

    然后当你在会话中取回它时,转换回字符串:

    item_id, ... = session.run(your_tfrecords_iterator.get_next())
    print(str(item_id.flatten(), 'ascii')) # python 3
    

    我从这个related so answer 中得到了uint8 的把戏。对我有用,但欢迎 cmets/改进。

    【讨论】:

    • 我有一个由图像组成的 TFRecord,其中一个特征是磁盘上该图像的路径。该路径的格式为path\to\images\image432.jpg。这条路径的长度从8891 不等。当我将此特定功能解码为tf.decode_raw(features['train/path'], tf.uint8) 时,我得到ValueError: All shapes must be fully defined: [TensorShape([Dimension(None)]), TensorShape([Dimension(256), Dimension(256), Dimension(1)]), TensorShape([])],第一个维度对应于路径
    • 我遇到了所有需要完全定义的形状的相同问题。另外,如果 Assert 语句失败,我想返回文件名。这似乎是不可能的。
    猜你喜欢
    • 2020-01-08
    • 1970-01-01
    • 1970-01-01
    • 2012-10-18
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-09-22
    • 1970-01-01
    相关资源
    最近更新 更多