【问题标题】:How to read a datapoint with multiple lables from a tfrecord file如何从 tfrecord 文件中读取具有多个标签的数据点
【发布时间】:2019-01-04 16:28:41
【问题描述】:

我正在为每个图像编写带有多个标签的数据,在本例中是边界框和分类标签,并且正在使用以下函数将数据写入 tfrecord:

   def tfr_write_sr(data_split_name,save_dir, label_array, data_array):

       filename = os.path.join(save_dir, data_split_name + '.tfrecords')
       writer = tf.python_io.TFRecordWriter(filename)
       for index in range(data_array.shape[0]):

       image = data_array[index].tostring()
       example = tf.train.Example(features=tf.train.Features(
        feature={
            'height': tf.train.Feature(
                int64_list=tf.train.Int64List(
                    value=[data_array.shape[1]])),
            'width': tf.train.Feature(
                int64_list=tf.train.Int64List(
                    value=[data_array.shape[2]])),
            'depth': tf.train.Feature(
                int64_list=tf.train.Int64List(
                    value=[data_array.shape[3]])),
            'shape_type': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[int(label_array[index][3])])),
            'bbtl_x': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[int(label_array[index][1][0])])),
            'bbtl_y': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[int(label_array[index][1][1])])),
            'bbbr_x': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[int(label_array[index][0][0])])),
            'bbbr_y': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[int(label_array[index][0][1])])),                
            'image_raw': tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[image]))}))
         writer.write(example.SerializeToString())
       writer.close() 

我已验证记录写入正确,但我之前看到的所有示例都只读取每个图像的一个标签,我如何读取多个标签?

【问题讨论】:

    标签: tensorflow tfrecord


    【解决方案1】:

    首先我们读入我们的 tfrecord 并获得它的特征:

      reader = tf.TFRecordReader()
      _ , serialized_example = reader.read(filename_queue)
    
       features = tf.parse_single_example(serialized_example, 
            features={
                'image_raw': tf.FixedLenFeature([],tf.string),
                'shape_type' : tf.FixedLenFeature([], tf.int64),
                'bbtl_x' : tf.FixedLenFeature([], tf.int64),
                'bbtl_y' : tf.FixedLenFeature([], tf.int64),
                'bbbr_x' : tf.FixedLenFeature([], tf.int64),
                'bbbr_y' : tf.FixedLenFeature([], tf.int64)
        })
    

    现在我们有了可以使用 tf.stack() 为我们的多变量构建张量并将其添加到我们的图表的功能:

         label  = tf.stack([features['shape_type'],
                            features['bbtl_x'],
                            features['bbtl_y'],
                            features['bbbr_x'],
                            features['bbbr_y'] ], axis=0 )
    
    
          image = tf.decode_raw(features['image_raw'], tf.uint8)
    
          images_batch, labels_batch = tf.train.shuffle_batch([image,label],
                                                     batch_size=128,
                                                     capacity=2000,
                                                     min_after_dequeue=1000) 
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2015-04-30
      • 2015-06-24
      • 1970-01-01
      • 2020-09-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多