【Question title】: How to load a custom-trained YOLOv7 model
【Posted】: 2023-01-13 04:45:30
【Question description】:

How do I load a custom-trained YOLOv7 model?

This is how I load a YOLOv5 model:

model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5/runs/train/exp15/weights/last.pt', force_reload=True)

I've seen videos online that suggest using this:

!python detect.py --weights runs/train/yolov7x-custom/weights/best.pt --conf 0.5 --img-size 640 --source final_test_v1.mp4 

But I want to load it like a normal model and have it give me the bounding-box coordinates of the detected objects.
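If you do stay with detect.py, YOLOv7's detect.py has a --save-txt flag (plus --save-conf to append the confidence) that writes one line per detected box under runs/detect/<run name>/labels/, in the form "class x_center y_center width height [conf]", normalized by the original image's width and height. A minimal sketch for turning such a line back into pixel corner coordinates; the function name is hypothetical and you must supply the image size yourself:

```python
def yolo_txt_line_to_xyxy(line, img_w, img_h):
    """Convert one detect.py --save-txt line to pixel corner coordinates.

    Assumed line format (values normalized to 0-1):
        "cls x_center y_center width height [conf]"
    Returns (x1, y1, x2, y2, conf, cls); conf is None when --save-conf was not used.
    """
    parts = line.split()
    cls = int(float(parts[0]))
    xc, yc, bw, bh = (float(v) for v in parts[1:5])
    conf = float(parts[5]) if len(parts) > 5 else None
    # centre/size -> corners, then scale from 0-1 back to pixels
    x1 = round((xc - bw / 2) * img_w)
    y1 = round((yc - bh / 2) * img_h)
    x2 = round((xc + bw / 2) * img_w)
    y2 = round((yc + bh / 2) * img_h)
    return x1, y1, x2, y2, conf, cls
```

For a 640×480 frame, the line "0 0.5 0.5 0.25 0.5 0.87" becomes (240, 120, 400, 360, 0.87, 0).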

This is how I do it in YOLOv5:

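import cv2
import torch
from models.experimental import attempt_load
from utils.general import non_max_suppression

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
yolov5_weight_file = r'weights/rider_helmet_number_medium.pt' # ... may need full path
model = attempt_load(yolov5_weight_file, map_location=device)  # returns the model in eval mode
names = model.module.names if hasattr(model, 'module') else model.names  # class names saved in the checkpoint
conf_set = 0.25  # confidence threshold; placeholder value, not given in the question

def object_detection(frame):
    # assumes an OpenCV frame: BGR, HWC, uint8; the model expects RGB, NCHW, float in [0, 1]
    img = torch.from_numpy(frame[:, :, ::-1].copy())  # BGR -> RGB (copy() removes the negative stride)
    img = img.permute(2, 0, 1).float().to(device)  # HWC -> CHW
    img /= 255.0  # 0-255 -> 0.0-1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():  # inference only, no gradients needed
        pred = model(img, augment=False)[0]
    pred = non_max_suppression(pred, conf_set, 0.20) # prediction, conf, iou
    # print(pred)
    detection_result = []
    for det in pred:  # one tensor of detections per image in the batch
        for d in det: # d = (x1, y1, x2, y2, conf, cls)
            x1 = int(d[0].item())
            y1 = int(d[1].item())
            x2 = int(d[2].item())
            y2 = int(d[3].item())
            conf = round(d[4].item(), 2)
            c = int(d[5].item())

            detected_name = names[c]

            # print(f'Detected: {detected_name} conf: {conf}  bbox: x1:{x1}    y1:{y1}    x2:{x2}    y2:{y2}')
            detection_result.append([x1, y1, x2, y2, conf, c])

            frame = cv2.rectangle(frame, (x1, y1), (x2, y2), (255,0,0), 1) # box
            if c != 1: # class 1 is the head box; label every other class with putText
                frame = cv2.putText(frame, f'{detected_name} {conf}', (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 1, cv2.LINE_AA)

    return (frame, detection_result)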
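object_detection() above feeds the raw frame straight into the network, so the frame's height and width generally need to be multiples of the model stride (32 for the standard models), and the boxes come back in the frame's own pixels. The repos' detect.py instead letterboxes each frame (resize with the aspect ratio kept, then pad) and afterwards maps the boxes back with scale_coords. A plain-Python sketch of that arithmetic, mirroring letterbox(auto=True) rather than importing it; the helper names are hypothetical:

```python
def letterbox_geometry(height, width, new_size=640, stride=32):
    """Return (gain, (new_w, new_h), (pad_w, pad_h)) for a YOLO-style letterbox.

    gain:  one resize factor for both axes, so the aspect ratio is kept.
    pad_*: padding on EACH side in pixels (half of the total padding).
    As with letterbox(auto=True), the total padding is taken modulo the stride,
    so each side is rounded up to a multiple of stride instead of to new_size.
    """
    gain = min(new_size / height, new_size / width)
    new_w, new_h = int(round(width * gain)), int(round(height * gain))
    pad_w = ((new_size - new_w) % stride) / 2
    pad_h = ((new_size - new_h) % stride) / 2
    return gain, (new_w, new_h), (pad_w, pad_h)


def _clip(value, upper):
    """Clamp value to the range [0, upper]."""
    return max(0.0, min(value, upper))


def box_to_original(box, gain, pad, height, width):
    """Map an (x1, y1, x2, y2) box from letterboxed coordinates back to the original frame."""
    pad_w, pad_h = pad
    x1, y1, x2, y2 = box
    # undo the padding offset, then undo the resize
    x1, x2 = (x1 - pad_w) / gain, (x2 - pad_w) / gain
    y1, y2 = (y1 - pad_h) / gain, (y2 - pad_h) / gain
    # keep the box inside the frame, as scale_coords does
    return _clip(x1, width), _clip(y1, height), _clip(x2, width), _clip(y2, height)
```

For a 1280×720 frame this gives gain 0.5, a 640×360 resize and 12 px of padding at the top and bottom (a 640×384 input), so a box at (100, 112, 200, 212) in the network's output maps back to (200, 200, 400, 400) in the frame.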

【Question discussion】:

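  • Have you tried model.load_state_dict(torch.load(PATH))?
  • model = TheModelClass(*args, **kwargs); model.load_state_dict(torch.load(PATH)); model.eval(). What should I write as the model class? Any help would be appreciated.
  • Your model class is basically YOLOv7's detector class (Model in models/yolo.py), the counterpart of DetectBackend in YOLOv6 or DetectMultiBackend in YOLOv5.
  • The only custom part is the weights from your own training.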
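In practice no model class has to be written by hand: a checkpoint saved by YOLOv7's train.py is a dict whose 'model' entry (and, for unstripped checkpoints, 'ema') holds the whole pickled network, which is what custom() in Solution 2 unpacks. A minimal sketch, assuming it runs from inside the yolov7 repository (its models and utils packages are needed to unpickle the network); both helper names are hypothetical:

```python
def model_from_checkpoint(ckpt):
    """Return the network stored in a YOLOv7-style checkpoint.

    Prefers the EMA weights when present, else falls back to 'model'
    (stripped checkpoints set 'ema' to None). A non-dict input is assumed
    to already be the model and is returned unchanged.
    """
    if isinstance(ckpt, dict):
        return ckpt["ema"] if ckpt.get("ema") is not None else ckpt["model"]
    return ckpt


def load_trained_yolov7(path, device="cpu"):
    """Load a trained YOLOv7 .pt file as a ready-to-use float model in eval mode."""
    import torch  # deferred import: only this loader needs PyTorch

    # PyTorch 2.6+ defaults torch.load(weights_only=True), which rejects pickled
    # modules; pass weights_only=False there, and only for files you trust.
    ckpt = torch.load(path, map_location=device)
    return model_from_checkpoint(ckpt).float().eval()
```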

Tags: python object-detection yolo yolov5 yolov4


【Solution 1】:
import torch as th

def loadModel(path: str):
    # path: your custom-trained weights, e.g. './best.pt'
    model = th.hub.load("WongKinYiu/yolov7", "custom", path, trust_repo=True)
    return model

This works. With trust_repo=True, torch.hub won't prompt you to answer y or n. For the path, pass your custom-trained model, e.g. ./best.pt.
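trust_repo only matters when hubconf.py is fetched from GitHub. If you already have a clone of the repository (for example from training), torch.hub.load can read hubconf.py from disk with source='local', which also works offline. A sketch under that assumption; the helper name and the local path are hypothetical:

```python
def load_custom_yolov7(weights, local_repo=None):
    """Load custom-trained YOLOv7 weights through torch.hub.

    weights:    path to the trained checkpoint, e.g. './best.pt'.
    local_repo: optional path to an existing clone of WongKinYiu/yolov7; when set,
                hubconf.py is read from that directory instead of GitHub.
    """
    import torch  # deferred so the helper can be defined before PyTorch is needed

    if local_repo is not None:
        # source='local': the first argument is a directory that contains hubconf.py
        return torch.hub.load(local_repo, "custom", weights, source="local")
    # trust_repo=True: skip the interactive "do you trust this repo?" prompt
    return torch.hub.load("WongKinYiu/yolov7", "custom", weights, trust_repo=True)
```

For example, load_custom_yolov7('./best.pt', local_repo='yolov7') when the clone sits in ./yolov7.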

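【Discussion】:

    【Solution 2】:

    Predicting with YOLOv7 using the custom() loader from its torch.hub entry point (hubconf.py), reproduced in a notebook: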

    # Download the YOLOv7 code
    !git clone https://github.com/WongKinYiu/yolov7
    %cd yolov7
    from pathlib import Path
    
    import torch
    
    from models.yolo import Model
    from utils.general import check_requirements, set_logging
    from utils.google_utils import attempt_download
    from utils.torch_utils import select_device
    
    dependencies = ['torch', 'yaml']
    check_requirements(Path("/content/yolov7") / 'requirements.txt', exclude=('pycocotools', 'thop'))  # requirements.txt inside the cloned repo (.parent would point at /content)
    set_logging()
    
    def custom(path_or_model='path/to/model.pt', autoshape=True):
        """custom mode
    
        Arguments (3 options):
            path_or_model (str): 'path/to/model.pt'
            path_or_model (dict): torch.load('path/to/model.pt')
            path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
    
        Returns:
            pytorch model
        """
        model = torch.load(path_or_model, map_location=torch.device('cpu')) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
        if isinstance(model, dict):
            model = model['ema' if model.get('ema') else 'model']  # load model
    
        hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
        hub_model.load_state_dict(model.float().state_dict())  # load state_dict
        hub_model.names = model.names  # class names
        if autoshape:
            hub_model = hub_model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
        return hub_model.to(device)
    
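    # yolov7.pt must already be on disk (e.g. fetched with attempt_download('yolov7.pt')); swap in your own weights, e.g. 'runs/train/yolov7x-custom/weights/best.pt'
    model = custom(path_or_model='yolov7.pt')  # custom example
    # model = create(name='yolov7', pretrained=True, channels=3, classes=80, autoshape=True)  # pretrained example; create() is defined in hubconf.py, not in this snippet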
    
    # Verify inference
    import numpy as np
    from PIL import Image
    
    imgs = [np.zeros((640, 480, 3))]
    
    results = model(imgs)  # batched inference
    results.print()
    results.save()
    df_prediction = results.pandas().xyxy
    df_prediction
    

    Link to Colab:

    https://colab.research.google.com/drive/1nKoC-_areXmc_20Xn7z6kcqHEKU7SJsX#scrollTo=yyB_MQW1OWhZ

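    【Discussion】:

      【Solution 3】:

      You can't use attempt_load from the YOLOv5 repository, because its attempt_download points to the ultralytics/yolov5 release files. Use attempt_load from the YOLOv7 repository instead, whose attempt_download points to the correct (WongKinYiu/yolov7) files: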

      # yolov7
      def attempt_download(file, repo='WongKinYiu/yolov7'):
          # Attempt file download if does not exist
          file = Path(str(file).strip().replace("'", '').lower())
      ...
      
      # yolov5
      def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
          # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
          from utils.general import LOGGER
      
          def github_assets(repository, version='latest'):
      ...
      

      Then you can load it like this:

      # load with the YOLOv7 repo's attempt_load (run from inside the yolov7 directory)
      from models.experimental import attempt_load
      
      model = attempt_load('yolov7.pt', map_location='cuda:0')  # load FP32 model
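Both repositories expose top-level packages named models and utils, so "from models.experimental import attempt_load" picks up whichever repo comes first on sys.path, or whichever was already imported in the current session. When both are cloned side by side (say, in a notebook that used YOLOv5 earlier), a small helper like the hypothetical use_repo below can make the YOLOv7 copy win; the repo location is an assumption:

```python
import importlib
import sys
from pathlib import Path


def use_repo(repo_dir):
    """Put repo_dir first on sys.path and drop cached 'models'/'utils' packages.

    YOLOv5 and YOLOv7 both ship top-level 'models' and 'utils' packages, so the
    module cache must be cleared before importing from the other repository.
    """
    repo_dir = str(Path(repo_dir).resolve())
    if repo_dir in sys.path:
        sys.path.remove(repo_dir)
    sys.path.insert(0, repo_dir)
    # forget previously imported copies so the next import searches sys.path again
    for name in list(sys.modules):
        if name in ("models", "utils") or name.startswith(("models.", "utils.")):
            del sys.modules[name]
    importlib.invalidate_caches()
```

Call use_repo('yolov7') (the path to your clone) before the attempt_load import.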
      

      【Discussion】:

        【Solution 4】:

        You can do it like this:

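        import torch
        
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        path = '/path/to/your/file.pt'  # your custom-trained weights, e.g. best.pt
        model = torch.hub.load("WongKinYiu/yolov7", "custom", path, trust_repo=True)
        model = model.to(device)  # custom() already defaults to the GPU; this makes the device explicit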
        

        To get the results, you can run:

        results = model("/path/to/your/photo")
        

        To get the bounding boxes, you can use:

        results.pandas().xyxy
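results.pandas().xyxy is a list with one DataFrame per input image (columns xmin, ymin, xmax, ymax, confidence, class, name). To get the same plain [x1, y1, x2, y2, conf, c] lists that the question's object_detection() returns, the raw results.xyxy tensors (one n×6 tensor per image) can be converted directly, e.g. detections_to_rows(results.xyxy[0].tolist(), model.names). A sketch with a hypothetical helper name:

```python
def detections_to_rows(xyxy_rows, names):
    """Convert [[x1, y1, x2, y2, conf, cls], ...] into the question's row format.

    Each output row is [x1, y1, x2, y2, conf, class_id, class_name]: coordinates
    are truncated to int and the confidence rounded to 2 decimals, matching how
    the question's object_detection() builds detection_result.
    """
    rows = []
    for x1, y1, x2, y2, conf, cls in xyxy_rows:
        c = int(cls)
        rows.append([int(x1), int(y1), int(x2), int(y2), round(float(conf), 2), c, names[c]])
    return rows
```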
        

        【Discussion】:
