【问题标题】:Generating TFRecord format data from C+从 C+ 生成 TFRecord 格式数据
【发布时间】:2021-03-25 00:50:32
【问题描述】:

我正在尝试使用TFRecord format 记录来自 C++ 的数据,然后在 python 中使用它来提供 TensorFlow 模型。

TLDR;简单地将 proto 消息序列化为流不满足 Python TFRecordDataset 类的 .tfrecord 格式要求。在 C++(TensorFlow 或 Google Protobuf 库中)是否有与 Python TfRecordWriter 等效的工具来生成正确的 .tfrecord 数据?

详情:

简化的 C++ 代码如下所示:

tensorflow::Example sample;
sample.mutable_features()->mutable_feature()->operator[]("a").mutable_float_list()->add_value(1.0);

std::ofstream out;
out.open("cpp_example.tfrecord", std::ios::out | std::ios::binary);
sample.SerializeToOstream(&out);

在 Python 中,为了创建 TensorFlow 数据,我尝试使用 TFRecordDataset,但显然它需要 .tfrecord 文件中的额外页眉/页脚信息(而不是简单的序列化原始消息列表):

import tensorflow as tf
tfrecord_dataset = tf.data.TFRecordDataset(filenames="cpp_example.tfrecord")
next(tfrecord_dataset.as_numpy_iterator())

输出:

tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0 [Op:IteratorGetNext]

请注意,记录的二进制文件没有任何问题,因为下面的代码打印了一个有效的输出:

import tensorflow as tf
p = open("cpp_example.tfrecord", "rb")
example = tf.train.Example.FromString(p.read())

输出:

features {
  feature {
    key: "a"
    value {
      float_list {
        value: 1.0
      }
    }
  }
}

通过分析我的 C++ 示例生成的二进制输出和使用 Python TfRecordWriter 生成的输出,我观察到内容中有额外的页眉和页脚字节。不幸的是,这些额外的字节代表的是一个实现细节(可能是压缩类型和一些额外的信息),我无法比 python 库中的某些类更深入地跟踪它,这些类刚刚暴露了_pywrap_tfe.so 的接口。

this advice.tfrecord 只是一个普通的google protobuf 数据。可能是我不知道在哪里可以找到 protobuf 数据编写器(期望将 proto 消息序列化到输出流中)?

【问题讨论】:

    标签: python c++ tensorflow protocol-buffers tfrecord


    【解决方案1】:

    事实证明,TensorFlow C++ 库的 tensorflow::io::RecordWriter 类可以完成这项工作。

    #include <tensorflow/core/lib/io/record_writer.h>
    
    #include <tensorflow/core/platform/default/posix_file_system.h>
    #include <tensorflow/core/example/example.pb.h>
    
    // ...
    
    // Create WritableFile and instantiate RecordWriter.
    tensorflow::PosixFileSystem posixFileSystem;
    std::unique_ptr<tensorflow::WritableFile> writableFile;
    
    posixFileSystem.NewWritableFile("cpp_example.tfrecord", &writableFile);
    
    tensorflow::io::RecordWriter recordWriter(mWritableFile.get(), tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions(""));
    
    // ...
    tensorflow::Example sample;
    
    // ...
    
    // Serialize proto message into a buffer and record in tfrecord format.
    std::string buffer;
    sample.SerializeToString(&buffer);
    recordWriter.WriteRecord(buffer);
    
    

    如果这个类是从TFRecord documentation 的某个地方引用的,那将会很有帮助。

    【讨论】:

    • khkarens:我在序列化 example.proto 和 TFRecord 之间的差异方面遇到了同样的问题。我已经为 C++ 构建了 tensorflow(通过 floopcz 的 tensorflow_cc),但现在我在编译自己的程序时遇到了问题,因为您的答案中包含的标头需要大量 .cc 文件)。我的程序是在 tensorflow 之外构建的,只需要通过 protobufs 序列化数据,我真的很想避免在我的小程序中编译 tensorflow 的一半 .cc 文件。你是怎么编译上面的代码的?
    猜你喜欢
    • 1970-01-01
    • 2023-03-22
    • 2011-02-06
    • 2023-02-10
    • 2021-08-19
    • 1970-01-01
    • 1970-01-01
    • 2017-02-04
    • 1970-01-01
    相关资源
    最近更新 更多