【问题标题】:Logging detected objects in TensorFlow Object Detection model (and improving low FPS)在 TensorFlow 对象检测模型中记录检测到的对象(并改善低 FPS)
【发布时间】:2020-11-26 11:09:06
【问题描述】:

下面的代码用于运行我用花园鸟类图片训练出来的目标检测模型。我有两个问题想要解决:

  1. 视频上的 FPS 很慢,大约 10fps。我该如何改进?
  2. 我想将它检测到的对象(鸟类)记录到带有时间戳的 CSV 文件中。

提前致谢。

import cv2
from object_detection.builders import model_builder
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import config_util
from object_detection.utils import label_map_util
import tensorflow as tf
from PIL import Image
from six import BytesIO
import numpy as np
import os
import matplotlib
matplotlib.use('TkAgg')  


def load_image_into_numpy_array(path):
    """Load an image file into a ``(height, width, 3)`` uint8 numpy array.

    Args:
        path: Location of the image file, opened through ``tf.io.gfile``.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with three colour channels.
    """
    # Context manager so the file handle is closed as soon as the bytes
    # have been read, instead of lingering until garbage collection.
    with tf.io.gfile.GFile(path, 'rb') as image_file:
        img_data = image_file.read()
    # Force three channels: without this, an RGBA or grayscale file does not
    # hold height * width * 3 values and the array would not have 3 channels.
    image = Image.open(BytesIO(img_data)).convert('RGB')
    # Direct conversion replaces the per-pixel getdata() + reshape round trip.
    return np.array(image, dtype=np.uint8)


def get_keypoint_tuples(eval_config):
    """Return the entries of ``eval_config.keypoint_edge`` as tuples.

    Args:
        eval_config: Object exposing a ``keypoint_edge`` iterable whose items
            have ``start`` and ``end`` attributes.

    Returns:
        A list of ``(start, end)`` tuples, in the order the edges appear.
    """
    return [(edge.start, edge.end) for edge in eval_config.keypoint_edge]

# Paths to the pipeline config file and to the checkpoint directory.
pipeline_config = 'inference_graph/pipeline.config'
model_dir = 'inference_graph/checkpoint/'

# Parse the pipeline file and build the detection model from its 'model'
# section, with is_training=False.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(
    model_config=model_config, is_training=False)

# Restore the 'ckpt-0' checkpoint found under model_dir into the model.
# NOTE(review): expect_partial() presumably suppresses warnings about
# checkpoint values that are not matched by this object - confirm against
# the TensorFlow documentation.
ckpt = tf.compat.v2.train.Checkpoint(
    model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()


def get_model_detection_function(model):
    """Create a ``tf.function`` that runs ``model`` on an input image tensor.

    Args:
        model: Object providing ``preprocess``, ``predict`` and
            ``postprocess`` methods.

    Returns:
        A callable taking one image tensor and returning a tuple of
        ``(detections, prediction_dict, shapes reshaped to 1-D)``.
    """

    @tf.function
    def detect_fn(image):
        """Run preprocess, predict and postprocess on ``image``."""
        preprocessed, true_shapes = model.preprocess(image)
        raw_predictions = model.predict(preprocessed, true_shapes)
        processed = model.postprocess(raw_predictions, true_shapes)
        flat_shapes = tf.reshape(true_shapes, [-1])
        return processed, raw_predictions, flat_shapes

    return detect_fn

# Build the detection callable for the model restored above.
detect_fn = get_model_detection_function(detection_model)

# Load the label map referenced by the eval input config and derive lookup
# structures from it. category_index is what the visualisation call in the
# capture loop receives.
label_map_path = configs['eval_input_config'].label_map_path
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
category_index = label_map_util.create_category_index(categories)
# NOTE(review): label_map_dict is not read by any later line visible in this
# file - confirm whether it is still needed.
label_map_dict = label_map_util.get_label_map_dict(
    label_map, use_display_name=True)

# Load one sample image from disk into a numpy array.
# NOTE(review): image_np is not read by any later line visible in this file;
# the capture loop works on webcam frames instead - confirm whether this
# load is still needed.
image_dir = 'images/bluetit/'
image_path = os.path.join(image_dir, 'bluetit_5.jpg')
image_np = load_image_into_numpy_array(image_path)

# Capture frames from video device 0, run detection on each frame and show
# the annotated result in a window until Escape is pressed or a read fails.
cap = cv2.VideoCapture(0)

# Loop-invariant values, computed once instead of on every frame.
label_id_offset = 1
keypoint_edges = get_keypoint_tuples(configs['eval_config'])

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # read() reported failure, so `frame` is not a usable image;
            # stop instead of feeding it into the model.
            break

        # Add a leading batch axis and convert to float32 for detect_fn.
        input_tensor = tf.convert_to_tensor(
            np.expand_dims(frame, 0), dtype=tf.float32)
        detections, predictions_dict, shapes = detect_fn(input_tensor)

        # Draw on a copy so the captured frame itself stays untouched.
        image_np_with_detections = frame.copy()

        # Keypoints are optional: only present for models that emit them.
        keypoints, keypoint_scores = None, None
        if 'detection_keypoints' in detections:
            keypoints = detections['detection_keypoints'][0].numpy()
            keypoint_scores = (
                detections['detection_keypoint_scores'][0].numpy())

        viz_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_detections,
            detections['detection_boxes'][0].numpy(),
            (detections['detection_classes']
             [0].numpy() + label_id_offset).astype(int),
            detections['detection_scores'][0].numpy(),
            category_index,
            use_normalized_coordinates=True,
            max_boxes_to_draw=200,
            min_score_thresh=.80,
            agnostic_mode=False,
            keypoints=keypoints,
            keypoint_scores=keypoint_scores,
            keypoint_edges=keypoint_edges)

        cv2.imshow('img', image_np_with_detections)

        # wait for escape key to stop
        if cv2.waitKey(1) == 27:
            break
finally:
    # Runs on normal exit and when the loop raises, so the capture device
    # and the display window are always released.
    cap.release()
    cv2.destroyAllWindows()

【问题讨论】:

    标签: tensorflow machine-learning tensorflow2.0 object-detection


    【解决方案1】:

    我借用了 visualization_utils.py 中的部分代码来实现这一点。我写了下面这个函数,它可以把视频中检测到的每个对象打印到控制台。

    def LogBirds(boxes,
                 classes,
                 scores,
                 category_index,
                 min_score_thresh=.8):
        """Print one console line per detection that passes the threshold.

        Args:
            boxes: Array-like with a ``shape`` attribute; ``shape[0]`` is the
                number of detections to inspect.
            classes: Indexable of class ids, one per detection.
            scores: Indexable of scores, one per detection, or ``None``. With
                ``None`` every detection is reported and no percentage is
                appended.
            category_index: Mapping from class id to a dict with a ``'name'``
                entry. Ids missing from the mapping are reported as ``'N/A'``.
            min_score_thresh: A detection is reported only when its score is
                strictly greater than this value.
        """
        for i in range(boxes.shape[0]):
            # `not (a > b)` rather than `a <= b` so a NaN score is skipped,
            # exactly as the strict greater-than test would skip it.
            if scores is not None and not scores[i] > min_score_thresh:
                continue

            # Plain dict membership; no dependency on the `six` module.
            if classes[i] in category_index:
                class_name = category_index[classes[i]]['name']
            else:
                class_name = 'N/A'
            display_str = str(class_name)

            # Only format a percentage when scores exist; indexing None
            # here would raise a TypeError.
            if scores is not None:
                percent = round(100*scores[i])
                if not display_str:
                    display_str = '{}%'.format(percent)
                else:
                    display_str = '{}: {}%'.format(display_str, percent)

            print("Bird Detected: ", display_str)
    

    【讨论】:

      猜你喜欢
      • 2019-11-09
      • 2019-12-13
      • 2020-06-03
      • 1970-01-01
      • 2018-01-17
      • 1970-01-01
      • 2020-11-02
      • 2018-03-30
      • 2019-04-05
      相关资源
      最近更新 更多