【Question Title】: Error when converting string and array data from CSV file to tfrecords
【Posted】: 2019-09-09 07:38:29
【Question Description】:

I'm following these examples to convert my CSV file to tfrecords.

Here is the code I tried:

csv = pd.read_csv("ehealth.csv").values
with tf.python_io.TFRecordWriter("ehealth.tfrecords") as writer:
    for row in csv:
        question, answer, question_bert, answer_bert = row[0], row[1] , row[1], row[2]
        example = tf.train.Example()
        example.features.feature["question"].bytes_list.value.extend(question.encode("utf8"))
        example.features.feature["answer"].bytes_list.value.extend(answer.encode("utf8"))
        example.features.feature["question_bert"].float_list.value.extend(question_bert)
        example.features.feature["answer_bert"].float_list.value.append(answer_bert)
        writer.write(example.SerializeToString())

Here is my error:

TypeError                                 Traceback (most recent call last)
<ipython-input-36-0a8c5e073d84> in <module>()
      4         question, answer, question_bert, answer_bert = row[0], row[1] , row[1], row[2]
      5         example = tf.train.Example()
----> 6         example.features.feature["question"].bytes_list.value.extend(question.encode("utf8"))
      7         example.features.feature["answer"].bytes_list.value.extend(answer.encode("utf8"))
      8         example.features.feature["question_bert"].float_list.value.extend(question_bert)

TypeError: 104 has type int, but expected one of: bytes

It seems the problem happens when encoding the strings. To make sure everything else works, I commented out those two lines:

csv = pd.read_csv("ehealth.csv").values
with tf.python_io.TFRecordWriter("ehealth.tfrecords") as writer:
    for row in csv:
        question, answer, question_bert, answer_bert = row[0], row[1] , row[1], row[2]
        example = tf.train.Example()
#         example.features.feature["question"].bytes_list.value.extend(question)
#         example.features.feature["answer"].bytes_list.value.extend(answer)
        example.features.feature["question_bert"].float_list.value.extend(question_bert)
        example.features.feature["answer_bert"].float_list.value.append(answer_bert)
        writer.write(example.SerializeToString())

Then I get this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-13-565b43316ef5> in <module>()
      6 #         example.features.feature["question"].bytes_list.value.extend(question)
      7 #         example.features.feature["answer"].bytes_list.value.extend(answer)
----> 8         example.features.feature["question_bert"].float_list.value.extend(question_bert)
      9         example.features.feature["answer_bert"].float_list.value.append(answer_bert)
     10         writer.write(example.SerializeToString())

TypeError: 's' has type str, but expected one of: int, long, float

It turns out the problem is that pandas reads my array column as a string rather than an array:

type(csv[0][2])

-> str
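Since `csv[0][2]` comes back as a `str`, the `question_bert`/`answer_bert` cells would have to be parsed back into numbers before they can go into a `float_list`. A minimal sketch, assuming each cell holds a flat bracketed list such as `"[0.1, 0.2]"` or NumPy's space-separated repr `"[0.1 0.2]"` (the actual format inside `ehealth.csv` is not shown here); `parse_array_string` is a hypothetical helper name:

```python
def parse_array_string(cell):
    """Turn a string such as "[0.1, 0.2]" or "[0.1 0.2\n 0.3]" into a list of floats.

    Assumption: the cell is a flat (1-D) list of numbers, optionally wrapped
    in square brackets and separated by commas and/or whitespace.
    """
    # Drop surrounding whitespace and the outer brackets
    body = cell.strip().strip("[]")
    # Treat commas like whitespace so both separator styles work
    tokens = body.replace(",", " ").split()
    return [float(tok) for tok in tokens]


print(parse_array_string("[0.1, 0.2, 0.3]"))   # [0.1, 0.2, 0.3]
print(parse_array_string("[0.1 0.2\n 0.3]"))   # [0.1, 0.2, 0.3]
```

If the cells turn out to be nested (2-D) lists, this flat parser would not be enough and the format would need checking first.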

Also, it looks like I have to use example.SerializeToString() because I have an array, but I don't know how to go about it.

Below is the complete code to reproduce the error, including the code that downloads the CSV file from Google Drive.

import pandas as pd
import numpy as np
import requests
import tensorflow as tf

def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params = { 'id' : id }, stream = True)
    token = get_confirm_token(response)

    if token:
        params = { 'id' : id, 'confirm' : token }
        response = session.get(URL, params = params, stream = True)

    save_response_content(response, destination)    

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk: # filter out keep-alive new chunks
                f.write(chunk)

# download_file_from_google_drive('1rMjqKkMnt6_vROrGmlTGStNGmwPO4YFX', 'model.zip') #

file_id = '1anbEwfViu9Rzu7tWKgPb_We1EwbA4x1-'
destination = 'ehealth.csv'
download_file_from_google_drive(file_id, destination)

healthdata=pd.read_csv('ehealth.csv')
healthdata.head()

csv = pd.read_csv("ehealth.csv").values
with tf.python_io.TFRecordWriter("ehealth.tfrecords") as writer:
    for row in csv:
        question, answer, question_bert, answer_bert = row[0], row[1] , row[1], row[2]
        example = tf.train.Example()
        example.features.feature["question"].bytes_list.value.extend(question)
        example.features.feature["answer"].bytes_list.value.extend(answer)
        example.features.feature["question_bert"].float_list.value.extend(question_bert)
        example.features.feature["answer_bert"].float_list.value.append(answer_bert)
        writer.write(example.SerializeToString())


csv = pd.read_csv("ehealth.csv").values
with tf.python_io.TFRecordWriter("ehealth.tfrecords") as writer:
    for row in csv:
        question, answer, question_bert, answer_bert = row[0], row[1] , row[1], row[2]
        example = tf.train.Example()
#         example.features.feature["question"].bytes_list.value.extend(question)
#         example.features.feature["answer"].bytes_list.value.extend(answer)
        example.features.feature["question_bert"].float_list.value.extend(question_bert)
        example.features.feature["answer_bert"].float_list.value.append(answer_bert)
        writer.write(example.SerializeToString())

【Question Discussion】:

    Tags: python pandas tensorflow tfrecord


    【Solution 1】:

    Try

    example.features.feature["question"].bytes_list.value.extend([bytes(question, 'utf-8')])
    

    This fixes the error on line 6; apply the same change to line 7.

    Also check your indexing:
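Why the original line raised `104 has type int`: `extend()` iterates over its argument, and iterating over a `bytes` object yields integers (104 is the byte value of `h`). Wrapping the encoded string in a one-element list, as the fix above does, hands `extend()` a single `bytes` value instead. A plain-Python illustration, using a made-up sample question:

```python
question = "how do I sleep better?"  # hypothetical sample text
encoded = question.encode("utf-8")

# extend(encoded) would iterate over the bytes object, producing ints:
items_without_wrap = list(encoded)
print(items_without_wrap[:3])   # [104, 111, 119]

# extend([encoded]) iterates over a one-element list, producing one bytes value:
items_with_wrap = list([encoded])
print(items_with_wrap)          # [b'how do I sleep better?']
```

`bytes(question, 'utf-8')` and `question.encode('utf-8')` produce the same value, so either spelling works inside the list.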

    question, answer, question_bert, answer_bert = row[0], row[1] , row[1], row[2]
    

    I think it should be 0, 1, 2, and 3.

    Even after fixing the indices you will still get an error, so add
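With `row[1]` used twice, `question_bert` receives the answer text, which is consistent with the second traceback complaining about a single character (`'s'`) being pushed into a `float_list`. A small sketch with a hypothetical four-column row, assuming the column order the question implies:

```python
# Hypothetical row with the four columns in the order the question assumes
row = ("how do I sleep better?", "see a doctor", "[0.1, 0.2]", "[0.3, 0.4]")

# Original indexing: row[1] is used twice, so question_bert gets the answer text
_, _, wrong_question_bert, wrong_answer_bert = row[0], row[1], row[1], row[2]
print(wrong_question_bert)   # see a doctor

# Corrected indexing: one column per variable
question, answer, question_bert, answer_bert = row[0], row[1], row[2], row[3]
print(question_bert)         # [0.1, 0.2]
```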

    print(type(question_bert))
    

    It says it is a string. If it really is a string, you need to change

    float_list.value.append

    to

    bytes_list.value.extend

    If you have an array, then you need to use

    tf.serialize_tensor
    

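    Here is a simple tf.serialize_tensor example:

    import numpy as np
    import tensorflow as tf

    a = np.array([[1.0, 2, 46], [0, 0, 1]])
    b = tf.serialize_tensor(a)  # TF 1.x name; in TF 2.x use tf.io.serialize_tensor
    b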
    

    The output is:

    <tf.Tensor: id=25, shape=(), dtype=string, numpy=b'\x08\x02\x12\x08\x12\x02\x08\x02\x12\x02\x08\x03"0\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00G@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?'>
    

    You then need to store this value as bytes (that is, in a bytes_list feature).
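Putting that together: the serialized tensor is a scalar string tensor, so its `.numpy()` value is a `bytes` object that fits in a `bytes_list`, and `tf.io.parse_tensor` recovers the array when reading it back. A sketch assuming TensorFlow 2.x eager execution (in 1.x graph mode the same ops need a session); the feature name `question_bert` is reused from the question:

```python
import numpy as np
import tensorflow as tf

a = np.array([[1.0, 2, 46], [0, 0, 1]])  # float64 array from the example above

# serialize_tensor returns a scalar string tensor; .numpy() gives the raw bytes
serialized = tf.io.serialize_tensor(a).numpy()

example = tf.train.Example(features=tf.train.Features(feature={
    "question_bert": tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized])),
}))
record = example.SerializeToString()

# Reading back: pull the bytes out and parse them with the original dtype
parsed_example = tf.train.Example.FromString(record)
restored_bytes = parsed_example.features.feature["question_bert"].bytes_list.value[0]
restored = tf.io.parse_tensor(restored_bytes, out_type=tf.float64).numpy()
print(restored.shape)  # (2, 3)
```

`out_type` must match the dtype that was serialized (here `float64`); parsing with a different dtype fails.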
