【发布时间】:2021-06-24 20:35:36
【问题描述】:
我想将 Pytorch 训练的模型转换为 tensorflow 模型并在移动设备上使用该模型。为此,我遵循以下步骤;首先,我将 pytorch 训练的模型转换为 onnx 格式。然后我把onnx格式转成tensorflow模型。
首先用pytorch训练模型到onnx;
import torch
import torch.onnx
from detectron2.modeling import build_model
from detectron2.modeling import build_backbone
from torch.autograd import Variable
model= build_backbone(cfg)
model.eval()
dummy_input = torch.randn(1,3,224, 224,requires_grad=True)
torch.onnx.export(model,dummy_input,"drive/Detectron2/model_final.onnx")
然后onnx转tflite模型;
import onnx
import warnings
from onnx_tf.backend import prepare
model = onnx.load("drive/Detectron2/model_final.onnx")
tf_rep = prepare(model)
tf_rep.export_graph("drive/Detectron2/tf_rep.pb")
import tensorflow as tf
## TFLite Conversion
# Before conversion, fix the model input size
model = tf.saved_model.load("drive/Detectron2/tf_rep.pb")
model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs[0].set_shape([1, 3,224, 224])
tf.saved_model.save(model, "saved_model_updated", signatures=model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY])
# Convert
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_updated', signature_keys=['serving_default'])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
# Save the model.
with open('drive/Detectron2/model.tflite', 'wb') as f:
f.write(tflite_model)
## TFLite Interpreter to check input shape
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test the model on random input data.
input_shape = input_details[0]['shape']
print(input_shape)
但是当我在移动设备上使用模型时,出现以下错误;
java.lang.AssertionError: Error occurred when initializing ImageSegmenter: The input tensor should have dimensions 1 x height x width x 3. Got 1 x 3 x 224 x 224.
我哪里做错了?
【问题讨论】:
-
用 numpy
reshape重塑你的输入 -
Pytorch 卷积层通常需要输入形状
(batch, num_channels, height, width),在您的情况下为:(1, 3, 224, 224)。然而 tensorflow 需要输入形状(batch, height, width, num_channels)。您需要将输入转置为 (1, height, width, num_channels) -
明确一点:不要不重塑输入。正如@Alka 所说,您需要转置它。
-
@Alka 我在进行您所说的更改时收到此错误;
dummy_input = torch.randn(1, 224, 224,3, requires_grad=True)RuntimeError: 给定组=1,大小为 [64, 3, 7, 7] 的权重,预期输入 [1, 224, 224, 3] 有 3 个通道,但有 224 个通道而不是 @Berriel -
我认为您不应该将 Pytorch 修改为 ONNX 部分,看起来已经可以了。但是,我看到了两个潜在的变革候选地点?您可以先尝试将
model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs[0].set_shape([1, 3,224, 224])中的形状更改为model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs[0].set_shape([1,224, 224, 3])之类的东西
标签: python tensorflow machine-learning pytorch onnx