【发布时间】:2018-08-01 08:22:58
【问题描述】:
我正在尝试将编码的字节字符串转换回张量流图中的原始数组(使用张量流操作),以便在张量流模型中进行预测。数组到字节的转换基于this answer,它是谷歌云机器学习引擎上张量流模型预测的建议输入。
def array_request_example(input_array):
input_array = input_array.astype(np.float32)
byte_string = input_array.tostring()
string_encoded_contents = base64.b64encode(byte_string)
return string_encoded_contents.decode('utf-8')}
张量流代码
byte_string = tf.placeholder(dtype=tf.string)
audio_samples = tf.decode_raw(byte_string, tf.float32)
audio_array = np.array([1, 2, 3, 4])
bstring = array_request_example(audio_array)
fdict = {byte_string: bstring}
with tf.Session() as sess:
[tf_samples] = sess.run([audio_samples], feed_dict=fdict)
我尝试过使用decode_raw 和decode_base64,但都没有返回原始值。
我已尝试将 decode raw 的 out_type 设置为不同的可能数据类型,并尝试更改将原始数组转换为的数据类型。
那么,我将如何在 tensorflow 中读取字节数组?谢谢:)
额外信息
这背后的目的是为自定义 Estimator 创建服务输入函数,以使用 gcloud ml-engine local predict(用于测试)和对存储在云上的模型使用 REST API 进行预测。
Estimator 的服务输入函数是
def serving_input_fn():
feature_placeholders = {'b64': tf.placeholder(dtype=tf.string,
shape=[None],
name='source')}
audio_samples = tf.decode_raw(feature_placeholders['b64'], tf.float32)
# Dummy function to save space
power_spectrogram = create_spectrogram_from_audio(audio_samples)
inputs = {'spectrogram': power_spectrogram}
return tf.estimator.export.ServingInputReceiver(inputs, feature_placeholders)
Json 请求
我使用 .decode('utf-8') 是因为在尝试 json 转储 base64 编码的字节字符串时,我收到此错误
raise TypeError(repr(o) + " is not JSON serializable")
TypeError: b'longbytestring'
预测错误
使用 gcloud local 传递 json 请求 {'audio_bytes': 'b64': bytestring} 时出现错误
PredictionError: Invalid inputs: Expected tensor name: b64, got tensor name: [u'audio_bytes']
那么也许 google cloud local predict 不会自动处理音频字节和 base64 转换?或者我的 Estimator 设置可能有问题。
而对 REST API 的请求 {'instances': [{'audio_bytes': 'b64': bytestring}]} 给出了
{'error': 'Prediction failed: Error during model execution: AbortionError(code=StatusCode.INVALID_ARGUMENT, details="Input to DecodeRaw has length 793713 that is not a multiple of 4, the size of float\n\t [[Node: DecodeRaw = DecodeRaw[_output_shapes=[[?,?]], little_endian=true, out_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_source_0_0)]]")'}
这让我很困惑,因为我明确地将请求定义为浮点数并在服务输入接收器中执行相同操作。
从请求中删除 audio_bytes 并对字节字符串进行 utf-8 编码可以让我得到预测,但在本地测试解码时,我认为音频从字节字符串转换不正确。
【问题讨论】:
标签: tensorflow google-cloud-platform byte google-cloud-ml