【问题标题】:full-quatization does not except int8 data to change model input layer to int8全量化不排除 int8 数据将模型输入层更改为 int8
【发布时间】:2020-09-10 13:17:55
【问题描述】:

我正在将 keras h5 模型量化为 uint8。为了得到完整的uint8量化,用户dtlam26this post告诉我代表数据集应该已经在uint8中,否则输入层仍然在float32中。

问题是,如果我提供 uint8 数据,我会在调用 converter.convert() 期间收到以下错误

ValueError: 无法设置张量: 得到了 INT8 类型的张量但预期 输入 FLOAT32 输入 178,名称:input_1

看来,模型仍然需要 float32。所以我检查了基础 keras_vggface 预训练模型 (from here) 与

from keras_vggface.vggface import VGGFace
import keras

pretrained_model = VGGFace(model='resnet50', include_top=False, input_shape=(224, 224, 3), pooling='avg')  # pooling: None, avg or max

pretrained_model.save()

得到的 h5 模型具有 float32 的输入层。 接下来,我使用 uint8 作为输入 dtype 更改了模型定义:

def RESNET50(include_top=True, weights='vggface',
             ...)

    if input_tensor is None:
        img_input = Input(shape=input_shape, dtype='uint8')

但对于 int,只允许使用 int32。但是,使用 int32 会导致以下层需要 float32 的问题。

这似乎不是为所有层手动执行此操作的正确方法。

为什么我的模型在量化过程中除了 uint8 数据之外,并自动将输入改为 uint8?

我错过了什么?你知道解决办法吗?非常感谢。

【问题讨论】:

    标签: python tensorflow keras quantization tensorflow-lite


    【解决方案1】:

    来自用户 dtlam26 的解决方案

    尽管模型仍然无法使用 google NNAPI 运行,但使用 TF 1.15.3 或 TF2.2.0 使用 int8 量化模型并在 int8 中输出的解决方案是,感谢 delan:

    ...
    converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_model_dir + modelname) 
            
    def representative_dataset_gen():
      for _ in range(10):
        pfad='pathtoimage/000001.jpg'
        img=cv2.imread(pfad)
        img = np.expand_dims(img,0).astype(np.float32) 
        # Get sample input data as a numpy array in a method of your choosing.
        yield [img]
        
    converter.representative_dataset = representative_dataset_gen
    
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.experimental_new_converter = True
    
    converter.target_spec.supported_types = [tf.int8]
    converter.inference_input_type = tf.int8 
    converter.inference_output_type = tf.int8 
    quantized_tflite_model = converter.convert()
    if tf.__version__.startswith('1.'):
        open("test153.tflite", "wb").write(quantized_tflite_model)
    if tf.__version__.startswith('2.'):
        with open("test220.tflite", 'wb') as f:
            f.write(quantized_tflite_model)
    

    【讨论】:

      猜你喜欢
      • 2020-12-27
      • 1970-01-01
      • 1970-01-01
      • 2022-11-11
      • 1970-01-01
      • 2022-12-07
      • 1970-01-01
      • 2018-10-21
      • 1970-01-01
      相关资源
      最近更新 更多