【发布时间】:2020-11-26 11:09:06
【问题描述】:
我有以下代码，用于运行我在花园鸟类图像上训练的目标检测模型。我想解决两个问题：
- 视频的帧率很低，只有大约 10 FPS。我该如何提高？
- 我想将它检测到的对象(鸟类)记录到带有时间戳的 CSV 文件中。
提前致谢。
import cv2
from object_detection.builders import model_builder
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import config_util
from object_detection.utils import label_map_util
import tensorflow as tf
from PIL import Image
from six import BytesIO
import numpy as np
import os
import matplotlib
matplotlib.use('TkAgg')
def load_image_into_numpy_array(path):
    """Load an image file into a uint8 numpy array of shape (height, width, 3).

    The file is read through ``tf.io.gfile``, so any path TensorFlow's file
    system layer can open is accepted. The image is converted to RGB first,
    so grayscale, palette and RGBA files load instead of failing the
    3-channel layout.

    Args:
      path: path of the image file.

    Returns:
      A writable ``np.uint8`` array of shape (height, width, 3), RGB order.
    """
    # Close the file handle deterministically instead of leaking it.
    with tf.io.gfile.GFile(path, 'rb') as image_file:
        img_data = image_file.read()
    image = Image.open(BytesIO(img_data))
    # convert('RGB') guarantees 3 channels. np.array on the PIL image avoids
    # the slow per-pixel getdata() path and returns a writable copy.
    return np.array(image.convert('RGB'), dtype=np.uint8)
def get_keypoint_tuples(eval_config):
    """Return the keypoint edges of *eval_config* as ``(start, end)`` tuples.

    Args:
      eval_config: an object whose ``keypoint_edge`` attribute is an iterable
        of items exposing ``start`` and ``end`` attributes (in this script,
        the ``eval_config`` entry of the parsed pipeline config).

    Returns:
      A list of ``(start, end)`` tuples in the order the edges are listed;
      empty when no edges are configured.
    """
    return [(edge.start, edge.end) for edge in eval_config.keypoint_edge]
# Locations of the exported training artefacts.
pipeline_config = 'inference_graph/pipeline.config'
model_dir = 'inference_graph/checkpoint/'

# Parse the pipeline config and build the detection model in inference mode.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(model_config=model_config,
                                      is_training=False)

# Restore the trained weights from the 'ckpt-0' checkpoint in model_dir.
# Per the tf.train.Checkpoint API, expect_partial() suppresses warnings about
# checkpoint values that this object graph does not consume.
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
checkpoint_prefix = os.path.join(model_dir, 'ckpt-0')
ckpt.restore(checkpoint_prefix).expect_partial()
def get_model_detection_function(model):
    """Wrap *model* in a ``tf.function`` that runs full detection inference.

    The returned callable takes an image batch tensor (in this script a
    float32 tensor of shape (1, H, W, 3) built from a webcam frame; other
    callers should confirm their input matches) and runs the model's
    preprocess -> predict -> postprocess pipeline.

    Args:
      model: a detection model exposing ``preprocess``, ``predict`` and
        ``postprocess`` (as built by ``model_builder.build``).

    Returns:
      A ``tf.function`` returning ``(detections, prediction_dict, shapes)``:
      the postprocessed detections, the raw prediction dict, and the
      preprocessing shapes tensor flattened to 1-D.
    """
    def detect_fn(image):
        """Detect objects in image."""
        preprocessed, true_shapes = model.preprocess(image)
        raw_predictions = model.predict(preprocessed, true_shapes)
        postprocessed = model.postprocess(raw_predictions, true_shapes)
        flat_shapes = tf.reshape(true_shapes, [-1])
        return postprocessed, raw_predictions, flat_shapes

    # Equivalent to decorating detect_fn with @tf.function.
    return tf.function(detect_fn)
# Build the compiled inference function once. Per tf.function semantics, the
# traced graph is reused while the input shape/dtype stay the same.
detect_fn = get_model_detection_function(detection_model)

# Load the label map referenced by the eval input config and derive the
# lookup tables: category_index (keyed by category id) is used for drawing.
label_map_path = configs['eval_input_config'].label_map_path
label_map = label_map_util.load_labelmap(label_map_path)
max_label_id = label_map_util.get_max_label_map_index(label_map)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=max_label_id, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
# NOTE(review): label_map_dict is not read anywhere else in this script.
label_map_dict = label_map_util.get_label_map_dict(
    label_map, use_display_name=True)
# Sample still image from the training data.
# NOTE(review): image_np is not used by the webcam loop below; loading it only
# adds start-up time and raises if the file is missing. Consider removing it.
image_dir = 'images/bluetit/'
image_path = os.path.join(image_dir, 'bluetit_5.jpg')
image_np = load_image_into_numpy_array(image_path)

# Imported here so this script section (CSV logging) is self-contained.
import csv  # noqa: E402
from datetime import datetime  # noqa: E402

# Detections must score above this to be drawn and logged (same strict
# comparison viz_utils applies to min_score_thresh).
MIN_SCORE_THRESH = .80
# Set to a file path such as 'detections.csv' to append one row per detection
# (timestamp, class name, score, normalized ymin/xmin/ymax/xmax box) to a CSV
# file. None disables logging.
DETECTIONS_CSV_PATH = None
# detection_classes are shifted by this before the category_index lookup
# (presumably 0-based model class ids vs. 1-based label-map ids).
label_id_offset = 1
# Loop-invariant: build the keypoint edge list once, not on every frame.
keypoint_edges = get_keypoint_tuples(configs['eval_config'])

# NOTE(review): frame rate is presumably bounded by detect_fn itself. The work
# hoisted out of the loop only trims Python overhead. Profile first, then
# consider a lighter model, a GPU build of TensorFlow, or smaller inputs.
csv_file = None
cap = cv2.VideoCapture(0)
try:
    csv_writer = None
    if DETECTIONS_CSV_PATH:
        # Only write the header when starting a new (or empty) log file.
        write_header = (not os.path.exists(DETECTIONS_CSV_PATH)
                        or os.path.getsize(DETECTIONS_CSV_PATH) == 0)
        csv_file = open(DETECTIONS_CSV_PATH, 'a', newline='',
                        encoding='utf-8')
        csv_writer = csv.writer(csv_file)
        if write_header:
            csv_writer.writerow(
                ['timestamp', 'class', 'score', 'ymin', 'xmin', 'ymax', 'xmax'])

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # Camera disconnected or stream ended: frame would be None.
            break
        # Timestamp the frame at capture time, local time with UTC offset.
        frame_time = datetime.now().astimezone()

        # OpenCV captures BGR; feed the model RGB, matching the PIL-decoded
        # images that load_image_into_numpy_array produces.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_tensor = tf.convert_to_tensor(
            np.expand_dims(rgb_frame, 0), dtype=tf.float32)
        detections, predictions_dict, shapes = detect_fn(input_tensor)

        # Convert each output to numpy once; reused for drawing and logging.
        boxes = detections['detection_boxes'][0].numpy()
        classes = (detections['detection_classes'][0].numpy()
                   + label_id_offset).astype(int)
        scores = detections['detection_scores'][0].numpy()

        keypoints, keypoint_scores = None, None
        if 'detection_keypoints' in detections:
            keypoints = detections['detection_keypoints'][0].numpy()
            keypoint_scores = (
                detections['detection_keypoint_scores'][0].numpy())

        if csv_writer is not None:
            timestamp = frame_time.isoformat(timespec='milliseconds')
            for box, class_id, score in zip(boxes, classes, scores):
                if score <= MIN_SCORE_THRESH:
                    continue
                name = category_index.get(int(class_id), {}).get('name', 'N/A')
                csv_writer.writerow(
                    [timestamp, name, f'{score:.4f}', *box.tolist()])

        # Draw directly on the BGR frame. Inference used the separate
        # rgb_frame, so a defensive copy is not needed.
        image_np_with_detections = frame
        viz_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_detections,
            boxes,
            classes,
            scores,
            category_index,
            use_normalized_coordinates=True,
            max_boxes_to_draw=200,
            min_score_thresh=MIN_SCORE_THRESH,
            agnostic_mode=False,
            keypoints=keypoints,
            keypoint_scores=keypoint_scores,
            keypoint_edges=keypoint_edges)
        cv2.imshow('img', image_np_with_detections)
        # wait for escape key to stop
        if cv2.waitKey(1) == 27:
            break
finally:
    # Release the camera, window and log file even if inference raises.
    cap.release()
    cv2.destroyAllWindows()
    if csv_file is not None:
        csv_file.close()
【问题讨论】:
标签: tensorflow machine-learning tensorflow2.0 object-detection