【发布时间】:2023-01-13 04:45:30
【问题描述】:
如何加载自定义 yolo v-7 模型。
这就是我知道加载 yolo v-5 模型的方式:
model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5/runs/train/exp15/weights/last.pt', force_reload=True)
我在网上看到视频,他们建议使用这个:
!python detect.py --weights runs/train/yolov7x-custom/weights/best.pt --conf 0.5 --img-size 640 --source final_test_v1.mp4
但我希望它像普通模型一样加载,并给我找到对象的边界框坐标。
这就是我在 yolo v-5 中的做法:
from models.experimental import attempt_load
yolov5_weight_file = r'weights/rider_helmet_number_medium.pt' # ... may need full path
model = attempt_load(yolov5_weight_file, map_location=device)
def object_detection(frame):
img = torch.from_numpy(frame)
img = img.permute(2, 0, 1).float().to(device) #convert to required shape based on index
img /= 255.0
if img.ndimension() == 3:
img = img.unsqueeze(0)
pred = model(img, augment=False)[0]
pred = non_max_suppression(pred, conf_set, 0.20) # prediction, conf, iou
# print(pred)
detection_result = []
for i, det in enumerate(pred):
if len(det):
for d in det: # d = (x1, y1, x2, y2, conf, cls)
x1 = int(d[0].item())
y1 = int(d[1].item())
x2 = int(d[2].item())
y2 = int(d[3].item())
conf = round(d[4].item(), 2)
c = int(d[5].item())
detected_name = names[c]
# print(f'Detected: {detected_name} conf: {conf} bbox: x1:{x1} y1:{y1} x2:{x2} y2:{y2}')
detection_result.append([x1, y1, x2, y2, conf, c])
frame = cv2.rectangle(frame, (x1, y1), (x2, y2), (255,0,0), 1) # box
if c!=1: # if it is not head bbox, then write use putText
frame = cv2.putText(frame, f'{names[c]} {str(conf)}', (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 1, cv2.LINE_AA)
return (frame, detection_result)
【问题讨论】:
-
你试试
model.load_state_dict(torch.load(PATH)) -
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval() 我应该在模型类中写什么。帮助将不胜感激。
-
您的模型类基本上是 yolov7 的检测器类,例如 yolov6 或 v5 的 DetectBackend
-
唯一自定义的是您的自定义模型训练权重
标签: python object-detection yolo yolov5 yolov4